├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── pyproject.toml ├── src └── jax_sourceror │ ├── __init__.py │ ├── interpreter.py │ └── utils.py └── tests └── test_jaxpr_to_source.py /.gitignore: -------------------------------------------------------------------------------- 1 | /scratch 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # JetBrains 134 | .idea/ 135 | 136 | 137 | # Wandb stuff 138 | /wandb 139 | 140 | # dataset cache files 141 | *.parquet 142 | ledger.json 143 | 144 | /checkpoints 145 | *.jaxpr 146 | 147 | # local execution commands 148 | local_*.sh 149 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | exclude: ".git" 4 | default_stages: 5 | - commit 6 | fail_fast: true 7 | 8 | repos: 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v4.0.1 11 | hooks: 12 | - id: trailing-whitespace 13 | - id: end-of-file-fixer 14 | - id: check-yaml 15 | - id: check-toml 16 | - id: check-merge-conflict 17 | - id: check-added-large-files 18 | 19 | - repo: https://github.com/psf/black 20 | rev: 22.3.0 21 | hooks: 22 | - id: black 23 | 24 | - repo: https://github.com/timothycrosley/isort 25 | rev: 5.11.5 26 | hooks: 27 | - id: isort 28 | 29 | - repo: 
https://gitlab.com/pycqa/flake8 30 | rev: 3.9.2 31 | hooks: 32 | - id: flake8 33 | additional_dependencies: [flake8-isort] 34 | 35 | - repo: https://github.com/pre-commit/mirrors-mypy 36 | rev: 'v0.960' 37 | hooks: 38 | - id: mypy 39 | args: [--ignore-missing-imports] 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Stanford University 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Jax Sourceror 2 | 3 | Jax Sourceror is a Python library that allows you to recreate JAX source code from a jitted jax function (specifically its `jaxpr`) 4 | and a set of inputs. This is useful for minimizing bugs, debugging, teaching, and understanding how JAX works under the hood. 5 | 6 | The code this generates is definitely not going to be clean, idiomatic, or sometimes even correct, but it should be a good starting point for understanding what's going on. 7 | 8 | I created it mostly as a learning exercise and to minimize bugs in framework-heavy code (i.e. removing layers of equinox or flax abstraction to get to the JAX code). 9 | 10 | This is more of a "submit a PR" or "fork it" repo than a "this doesn't work for me" repo, but I'm happy to help out if you're stuck. 
11 | 12 | ## Example 13 | 14 | Jax Sourceror can turn this: 15 | 16 | ```python 17 | import jax 18 | import jax.numpy as jnp 19 | 20 | block_len = 64 21 | seq_len = 128 22 | batch = 4 23 | num_heads = 2 24 | embed_size = 32 25 | num_layers = 2 26 | head_size = 16 27 | def pseudo_sliding_window_attention(x): 28 | # (this is not attention, but is minimized from attn) 29 | # dims are [batch, len, num_heads, head_dim] 30 | # having num_heads is important. num_heads = 1, no boom 31 | def block(block_idx): 32 | query_block = jax.lax.dynamic_slice_in_dim(x, block_idx, block_len, axis=1) 33 | weights = jnp.sum(query_block, axis=3) # [batch, len, num_heads] 34 | weights = jax.lax.broadcast_in_dim(weights, (batch, block_len, num_heads, block_len), 35 | (0, 1, 2)) # [batch, len, num_heads, len] 36 | # weights = with_sharding_constraint(weights, P('data', None, None, None)) 37 | # without "bias", no boom 38 | bias = jnp.ones(block_len) 39 | bias = jnp.broadcast_to(bias, (batch, block_len, num_heads, block_len)) 40 | weights = weights + bias 41 | return jnp.einsum('bqhk,bkhd->bqhd', weights, query_block).astype(query_block.dtype) 42 | 43 | num_blocks = seq_len // block_len 44 | blocked_attn = jax.lax.map(block, jnp.arange(0, num_blocks)) # [num_blocks, batch, len, num_heads, head_dim] 45 | blocked_attn = jnp.concatenate(blocked_attn, axis=1) 46 | 47 | return blocked_attn 48 | 49 | def fwd(params, x): 50 | @partial(jax.checkpoint, prevent_cse=False) 51 | def layer(x, params): 52 | qkv, o = params 53 | y = jnp.einsum('bte,hde->bthd', x, qkv) 54 | y = pseudo_sliding_window_attention(y) 55 | z = jnp.einsum('bthd,hde->bte', y, o) 56 | return z, None 57 | 58 | x, _ = jax.lax.scan(layer, x, params) 59 | 60 | return x 61 | 62 | def loss_fn(params, x): 63 | x = fwd(params, x) 64 | l = jnp.mean(x) 65 | return l 66 | 67 | def grad_fn(params, x): 68 | loss, grad = jax.value_and_grad(loss_fn)(params, x) 69 | # we can't reasonably sourcerize pytrees so just get the leaves 70 | return loss, 
*jax.tree_util.tree_leaves(grad) 71 | 72 | 73 | 74 | qkv = jnp.ones((num_layers, num_heads, head_size, embed_size), dtype=jnp.bfloat16) 75 | o = jnp.ones((num_layers, num_heads, head_size, embed_size), dtype=jnp.bfloat16) 76 | x = jnp.ones((batch, seq_len, embed_size), dtype=jnp.bfloat16) 77 | 78 | params = (qkv, o) 79 | 80 | grad_fn(params, x) 81 | ``` 82 | 83 | into this: 84 | 85 | ```python 86 | def grad_fn(*args, **kwargs): 87 | 88 | def grad_fn(a, b, c): 89 | d = jax.numpy.zeros((4, 128, 32), jax.numpy.bfloat16) 90 | e = jax.numpy.ones((64,), jax.numpy.float32) 91 | f = jax.lax.broadcast_in_dim(e, shape=(4, 64, 2, 64), broadcast_dimensions=(3,)) 92 | 93 | def fn_1(carry, x): 94 | # (I would like to make this part nicer) 95 | (g, h, i) = (carry, *x) 96 | j = jax.lax.dot_general(g, h, (((2,), (2,)), ((), ())), None, jax.numpy.bfloat16) 97 | 98 | def fn_2(k, l): 99 | 100 | def fn_3(carry, x): 101 | (m,) = (*carry, x) 102 | n = jax.lax.dynamic_slice(l, (0, m, 0, 0), slice_sizes=(4, 64, 2, 16)) 103 | o = n.astype(jax.numpy.float32) 104 | p = jax.numpy.sum(o, axis=(3,)) 105 | q = p.astype(jax.numpy.bfloat16) 106 | r = jax.lax.broadcast_in_dim(q, shape=(4, 64, 2, 64), broadcast_dimensions=(0, 1, 2)) 107 | s = r.astype(jax.numpy.float32) 108 | t = s + k 109 | u = jax.lax.dot_general(n, t, (((1,), (3,)), ((0, 2), (0, 2))), None, jax.numpy.float32) 110 | v = jax.lax.transpose(u, permutation=(0, 3, 1, 2)) 111 | w = v.astype(jax.numpy.bfloat16) 112 | return ((), w) 113 | (final_carry, ys) = jax.lax.scan(fn_3, (), jax.numpy.array([0, 1], dtype=jax.numpy.int32), length=2, unroll=1, reverse=False) 114 | x = ys 115 | return x 116 | y = fn_2(f, j) 117 | z = jax.numpy.reshape(jax.numpy.transpose(y, (1, 0, 2, 3, 4)), (4, 128, 2, 16)) 118 | ba = jax.lax.dot_general(z, i, (((3, 2), (1, 0)), ((), ())), None, jax.numpy.bfloat16) 119 | return (ba, g) 120 | (final_carry, ys) = jax.lax.scan(fn_1, c, (a, b), length=2, unroll=1, reverse=False) 121 | bb = final_carry 122 | bc = ys 123 | 
bd = bb.astype(jax.numpy.float32) 124 | be = jax.numpy.sum(bd, axis=(0, 1, 2)) 125 | bf = be / 16384.0 126 | bg = bf.astype(jax.numpy.bfloat16) 127 | bh = jax.lax.broadcast_in_dim(6.103515625e-05, shape=(4, 128, 32), broadcast_dimensions=()) 128 | bi = bh.astype(jax.numpy.bfloat16) 129 | 130 | def fn_4(carry, x): 131 | (bj, bk, bl, bm) = (carry, *x) 132 | 133 | def fn_5(bn, bo, bp, bq): 134 | br = jax.lax.dot_general(bn, bo, (((2,), (2,)), ((), ())), None, jax.numpy.bfloat16) 135 | bs = jax.numpy.ones((64,), jax.numpy.float32) 136 | bt = jax.lax.broadcast_in_dim(bs, shape=(4, 64, 2, 64), broadcast_dimensions=(3,)) 137 | 138 | def fn_6(carry, x): 139 | (bu,) = (*carry, x) 140 | bv = bu < 0 141 | bw = bu + 128 142 | bx = jax.lax.select_n(bv, bu, bw) 143 | by = jax.lax.dynamic_slice(br, (0, bx, 0, 0), slice_sizes=(4, 64, 2, 16)) 144 | bz = by.astype(jax.numpy.float32) 145 | ca = jax.numpy.sum(bz, axis=(3,)) 146 | cb = ca.astype(jax.numpy.bfloat16) 147 | cc = jax.lax.broadcast_in_dim(cb, shape=(4, 64, 2, 64), broadcast_dimensions=(0, 1, 2)) 148 | cd = cc.astype(jax.numpy.float32) 149 | ce = cd + bt 150 | cf = jax.lax.dot_general(by, ce, (((1,), (3,)), ((0, 2), (0, 2))), None, jax.numpy.float32) 151 | cg = jax.lax.transpose(cf, permutation=(0, 3, 1, 2)) 152 | ch = cg.astype(jax.numpy.bfloat16) 153 | return ((), (ch, bx, ce, by)) 154 | (final_carry, ys) = jax.lax.scan(fn_6, (), jax.numpy.array([0, 1], dtype=jax.numpy.int32), length=2, unroll=1, reverse=False) 155 | (ci, cj, ck, cl) = ys 156 | cm = jax.numpy.reshape(jax.numpy.transpose(ci, (1, 0, 2, 3, 4)), (4, 128, 2, 16)) 157 | cn = jax.lax.dot_general(bq, cm, (((0, 1), (0, 1)), ((), ())), None, jax.numpy.bfloat16) 158 | co = jax.lax.transpose(cn, permutation=(1, 2, 0)) 159 | cp = jax.lax.dot_general(bq, bp, (((2,), (2,)), ((), ())), None, jax.numpy.bfloat16) 160 | cq = jax.numpy.reshape(cp, (4, 2, 64, 2, 16)) 161 | cr = jax.lax.transpose(cq, permutation=(1, 0, 2, 3, 4)) 162 | cs = jax.numpy.zeros((4, 128, 2, 16), 
jax.numpy.bfloat16) 163 | 164 | def fn_7(carry, x): 165 | (ct, cu, cv, cw, cx) = (carry, *x) 166 | cy = cu.astype(jax.numpy.float32) 167 | cz = jax.lax.transpose(cy, permutation=(0, 2, 3, 1)) 168 | da = jax.lax.dot_general(cz, cx, (((2,), (3,)), ((0, 1), (0, 2))), None, jax.numpy.float32) 169 | db = jax.lax.transpose(da, permutation=(0, 2, 1, 3)) 170 | dc = db.astype(jax.numpy.bfloat16) 171 | dd = jax.numpy.sum(dc, axis=(3,)) 172 | de = dd.astype(jax.numpy.float32) 173 | df = jax.lax.broadcast_in_dim(de, shape=(4, 64, 2, 16), broadcast_dimensions=(0, 1, 2)) 174 | dg = df.astype(jax.numpy.bfloat16) 175 | dh = jax.lax.dot_general(cz, cw, (((3,), (1,)), ((0, 1), (0, 2))), None, jax.numpy.float32) 176 | di = jax.lax.transpose(dh, permutation=(0, 3, 1, 2)) 177 | dj = di.astype(jax.numpy.bfloat16) 178 | dk = dg + dj 179 | dl = jax.numpy.zeros((4, 128, 2, 16), jax.numpy.bfloat16) 180 | dm = jax.lax.dynamic_update_slice(dl, dk, (0, cv, 0, 0)) 181 | dn = ct + dm 182 | return (dn, ()) 183 | (final_carry, ys) = jax.lax.scan(fn_7, cs, (cr, cj, ck, cl), length=2, unroll=1, reverse=True) 184 | do = final_carry 185 | dp = jax.lax.dot_general(do, bn, (((0, 1), (0, 1)), ((), ())), None, jax.numpy.bfloat16) 186 | dq = jax.lax.dot_general(do, bo, (((2, 3), (0, 1)), ((), ())), None, jax.numpy.bfloat16) 187 | return (dq, dp, co) 188 | ckpt_fn_5 = jax.checkpoint(fn_5) 189 | (dr, ds, dt) = ckpt_fn_5(bk, bl, bm, bj) 190 | return (dr, (ds, dt)) 191 | (final_carry, ys) = jax.lax.scan(fn_4, bi, (bc, a, b), length=2, unroll=1, reverse=True) 192 | du = final_carry 193 | (dv, dw) = ys 194 | return (bg, dv, dw) 195 | return grad_fn(*jax.tree_leaves((args, kwargs))) 196 | ``` 197 | 198 | Is this pretty code? No. Is it even readable? If you try hard enough. 199 | Is it correct? I think so. (It definitely passes my unit test!) 
200 | 201 | 202 | ## Usage 203 | 204 | ```python 205 | from jax_sourceror import sourcerize 206 | 207 | source_code = sourcerize(grad_fn)(*args, **kwargs) 208 | 209 | print(source_code) 210 | ``` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=58.0.4", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "jax-sourceror" 7 | version = "0.0.1" 8 | authors = [ 9 | { name="David Hall", email="dlwh@cs.stanford.edu" }, 10 | ] 11 | description = "Recreate JAX source code from a jitted function's jaxpr" 12 | readme = "README.md" 13 | requires-python = ">=3.10" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: Apache Software License", 17 | "Operating System :: POSIX :: Linux", 18 | "Operating System :: MacOS :: MacOS X", 19 | "Development Status :: 4 - Beta", 20 | "Intended Audience :: Science/Research", 21 | ] 22 | dependencies = [ 23 | # we'll require that you install jax yourself, since the extras vary by system.
24 | # jax = {version = ">=0.4.10,<0.5.0"} 25 | "jax", 26 | "numpy", 27 | "ast_comments", 28 | "equinox" 29 | ] 30 | 31 | 32 | [project.urls] 33 | "Homepage" = "https://github.com/dlwh/jax-sourceror" 34 | "Bug Tracker" = "https://github.com/dlwh/jax-sourceror/issues" 35 | -------------------------------------------------------------------------------- /src/jax_sourceror/__init__.py: -------------------------------------------------------------------------------- 1 | from .interpreter import sourcerize, register_prim_handler, primitive_handler -------------------------------------------------------------------------------- /src/jax_sourceror/interpreter.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import enum 3 | import warnings 4 | from dataclasses import dataclass 5 | from typing import Callable, Optional, Union 6 | 7 | import ast_comments as ast 8 | import equinox as eqx 9 | import jax 10 | import jax.numpy as jnp 11 | import numpy as np 12 | from jax._src.core import ClosedJaxpr 13 | from jax._src.source_info_util import user_frame, user_frames 14 | from jax.sharding import NamedSharding 15 | from jax._src.custom_derivatives import CustomJVPCallPrimitive 16 | from jax.experimental.pjit import _UNSPECIFIED 17 | from jax.core import Literal, Var, Jaxpr 18 | 19 | from jax_sourceror.utils import IdentityMap, IdentitySet 20 | 21 | @dataclass 22 | class SourcerorState(): 23 | """State for the auto-minimizer. 
Basically just in charge of naming variables.""" 24 | _var_names: IdentityMap[Var, str] = dataclasses.field(default_factory=IdentityMap) 25 | _used_fn_names: set[str] = dataclasses.field(default_factory=set) 26 | _skolem_count: int = 0 27 | use_jax_typing: bool = True 28 | 29 | def name(self, var, ctx=ast.Load()) -> ast.Name: 30 | return ast.Name(id=self.str_name(var), ctx=ctx) 31 | 32 | def str_name(self, var: Var): 33 | # Names things in a way vaguely compatible with JAX's naming scheme, which is 'a'-'z' followed by 'aa'-'az' etc. 34 | if var in self._var_names: 35 | return self._var_names[var] 36 | else: 37 | cur_count = len(self._var_names) 38 | name = "" 39 | while cur_count >= 26: 40 | name += chr(ord('a') + cur_count % 26) 41 | cur_count //= 26 42 | 43 | name += chr(ord('a') + cur_count) 44 | 45 | name = name[::-1] 46 | 47 | if name in ['def', 'if', 'or', 'and', 'not', 'for', 'as', 'in', 'is']: 48 | name = f"{name}_" 49 | 50 | 51 | self._var_names[var] = name 52 | 53 | return name 54 | 55 | def skolem(self, prefix: str): 56 | self._skolem_count += 1 57 | return f"{prefix}_{self._skolem_count}" 58 | 59 | def heuristic_fn_skolem(self, jaxpr: Jaxpr, default: Optional[str] = None): 60 | if default is None: 61 | default = "fn" 62 | 63 | name = _attempt_to_sniff_fn_name_for_jaxpr(jaxpr) or default 64 | if name in self._used_fn_names: 65 | return self.skolem(name) 66 | else: 67 | self._used_fn_names.add(name) 68 | return name 69 | 70 | 71 | 72 | def sourcerize(f, *, use_jax_typing: bool = False): 73 | def return_fn(*args, **kwargs): 74 | closed_jaxpr = eqx.filter_make_jaxpr(f)(*args, **kwargs)[0] 75 | jaxpr = closed_jaxpr.jaxpr 76 | state = SourcerorState(use_jax_typing=use_jax_typing) 77 | try: 78 | name = f.__name__ 79 | except AttributeError: 80 | name = "unknown" 81 | node = jaxpr_to_py_ast(state, jaxpr, fn_name=name, unique_fn_name=False) 82 | node = _maybe_wrap_fn_for_leaves(node, f, len(args) + len(kwargs)) 83 | return _render_ast(node) 84 | 85 | 86 | return 
return_fn 87 | 88 | 89 | def _render_ast(node): 90 | ast.fix_missing_locations(node) 91 | source = ast.unparse(node) 92 | return source 93 | 94 | 95 | def register_prim_handler(prim_name, handler): 96 | """ 97 | Register a handler for a primitive for automin 98 | :param prim_name: 99 | :param handler: 100 | :return: 101 | """ 102 | if prim_name in prim_to_python: 103 | warnings.warn(f"Overwriting handler for primitive {prim_name}") 104 | prim_to_python[prim_name] = handler 105 | 106 | 107 | def primitive_handler(prim_name): 108 | """ 109 | Decorator to register a handler for a primitive. 110 | :param prim_name: 111 | :return: 112 | """ 113 | def decorator(fn): 114 | register_prim_handler(prim_name, fn) 115 | return fn 116 | return decorator 117 | 118 | 119 | def _assign_stmt(call_expr: Callable): 120 | """ 121 | Create a handler for a primitive that is a simple assignment. 122 | :param call_expr: 123 | :return: 124 | """ 125 | def binop_fn(state, eqn): 126 | invars = [_astify_atom(state, v) for v in eqn.invars] 127 | outvars = _astify_outvars(state, eqn.outvars) 128 | return ast.Assign(outvars, call_expr(*invars, 129 | **{k: _astify_value(v) for k, v in eqn.params.items()} 130 | )) 131 | return binop_fn 132 | 133 | def _binop_fn(op: ast.operator): 134 | return _assign_stmt(lambda x, y: ast.BinOp(left=x, op=op, right=y)) 135 | 136 | def _cmpop_fn(op: ast.cmpop): 137 | return _assign_stmt(lambda x, y: ast.Compare(left=x, ops=[op], comparators=[y])) 138 | 139 | 140 | def normal_fn(fn_name): 141 | """ 142 | Create a handler for a normal function call. 
143 | :param fn_name: 144 | :return: 145 | """ 146 | return _assign_stmt(lambda *args, **kwargs: ast.Call( 147 | func=ast.Name(id=fn_name, ctx=ast.Load()), 148 | args=list(args), 149 | keywords=[ast.keyword(arg=k, value=v) for k, v in kwargs.items()] 150 | )) 151 | 152 | 153 | 154 | def _reduce_fn(fn_name: str): 155 | def reduce_fn_inner(state: SourcerorState, eqn): 156 | invars = [_astify_atom(state, v) for v in eqn.invars] 157 | outvars = _astify_outvars(state, eqn.outvars) 158 | if eqn.params: 159 | params = eqn.params.copy() 160 | params['axis'] = tuple(params['axes']) 161 | del params['axes'] 162 | call_op = ast.Call( 163 | func=ast.Name(id=fn_name, ctx=ast.Load()), 164 | args=invars, 165 | keywords=[ast.keyword(arg=k, value=_astify_value(v)) for k, v in params.items()] 166 | ) 167 | else: 168 | call_op = ast.Call( 169 | func=ast.Name(id=fn_name, ctx=ast.Load()), 170 | args=invars, 171 | keywords=[] 172 | ) 173 | 174 | return ast.Assign(outvars, call_op) 175 | 176 | return reduce_fn_inner 177 | 178 | 179 | prim_to_python = { 180 | } 181 | 182 | register_prim_handler('add', _binop_fn(ast.Add())) 183 | register_prim_handler('sub', _binop_fn(ast.Sub())) 184 | register_prim_handler('mul', _binop_fn(ast.Mult())) 185 | register_prim_handler('div', _binop_fn(ast.Div())) 186 | register_prim_handler('lt', _cmpop_fn(ast.Lt())) 187 | register_prim_handler('gt', _cmpop_fn(ast.Gt())) 188 | register_prim_handler('le', _cmpop_fn(ast.LtE())) 189 | register_prim_handler('ge', _cmpop_fn(ast.GtE())) 190 | register_prim_handler('eq', _cmpop_fn(ast.Eq())) 191 | register_prim_handler('ne', _cmpop_fn(ast.NotEq())) 192 | # register_prim_handler('min', normal_fn('jax.lax.min')) 193 | # register_prim_handler('max', normal_fn('jax.lax.max')) 194 | # register_prim_handler('select_n', normal_fn('jax.lax.select_n')) 195 | # register_prim_handler('squeeze', normal_fn('jax.lax.squeeze')) 196 | # register_prim_handler('broadcast', normal_fn('jax.lax.broadcast')) 197 | 
register_prim_handler('reduce_sum', _reduce_fn('jnp.sum')) 198 | # register_prim_handler('transpose', normal_fn('jax.lax.transpose')) 199 | # register_prim_handler('clamp', normal_fn('jax.lax.clamp')) 200 | 201 | normal_fns = { 202 | 'min': 'jax.lax.min', 203 | 'max': 'jax.lax.max', 204 | 'select_n': 'jax.lax.select_n', 205 | 'squeeze': 'jax.lax.squeeze', 206 | 'broadcast': 'jax.lax.broadcast', 207 | 'transpose': 'jax.lax.transpose', 208 | 'clamp': 'jax.lax.clamp', 209 | # 'reduce_sum': 'jnp.sum', 210 | 'reduce_max': 'jnp.max', 211 | 'reduce_min': 'jnp.min', 212 | 'is_finite': 'jnp.isfinite', 213 | # misc jax.lax functions 214 | 'integer_pow': 'jax.lax.integer_pow', 215 | 'stop_gradient': 'jax.lax.stop_gradient', 216 | 'neg': 'jnp.negative', 217 | 'abs': 'jnp.abs', 218 | 'sin': 'jnp.sin', 219 | 'cos': 'jnp.cos', 220 | 'tan': 'jnp.tan', 221 | 'asin': 'jnp.arcsin', 222 | 'acos': 'jnp.arccos', 223 | 'atan': 'jnp.arctan', 224 | 'sinh': 'jnp.sinh', 225 | 'cosh': 'jnp.cosh', 226 | 'tanh': 'jnp.tanh', 227 | 'asinh': 'jnp.arcsinh', 228 | 'acosh': 'jnp.arccosh', 229 | 'atanh': 'jnp.arctanh', 230 | 'exp': 'jnp.exp', 231 | 'log': 'jnp.log', 232 | 'log1p': 'jnp.log1p', 233 | 'expm1': 'jnp.expm1', 234 | 'sqrt': 'jnp.sqrt', 235 | 'square': 'jnp.square', 236 | 'reciprocal': 'jnp.reciprocal', 237 | 'sign': 'jnp.sign', 238 | 'rsqrt': 'jax.lax.rsqrt', 239 | # 'concatenate': 'jnp.concatenate', 240 | } 241 | 242 | 243 | 244 | for k, v in normal_fns.items(): 245 | register_prim_handler(k, normal_fn(v)) 246 | 247 | 248 | @primitive_handler('cumsum') 249 | def _astify_cumsum(state, eqn): 250 | invars = [_astify_atom(state, v) for v in eqn.invars] 251 | outvars = _astify_outvars(state, eqn.outvars) 252 | axis = eqn.params['axis'] 253 | reverse = eqn.params['reverse'] 254 | 255 | if reverse: 256 | return ast.Assign(outvars, ast.Call( 257 | func=ast.Name(id='jax.lax.cumsum', ctx=ast.Load()), 258 | args=[invars[0]], 259 | keywords=[ast.keyword(arg='axis', value=_astify_value(axis)), 
ast.keyword(arg='reverse', value=ast.NameConstant(value=True))] 260 | )) 261 | else: 262 | return ast.Assign(outvars, ast.Call( 263 | func=ast.Name(id='jnp.cumsum', ctx=ast.Load()), 264 | args=[invars[0]], 265 | keywords=[ast.keyword(arg='axis', value=_astify_value(axis))] 266 | )) 267 | 268 | 269 | @primitive_handler('cumprod') 270 | def _astify_cumprod(state, eqn): 271 | invars = [_astify_atom(state, v) for v in eqn.invars] 272 | outvars = _astify_outvars(state, eqn.outvars) 273 | axis = eqn.params['axis'] 274 | reverse = eqn.params['reverse'] 275 | 276 | if reverse: 277 | return ast.Assign(outvars, ast.Call( 278 | func=ast.Name(id='jax.lax.cumprod', ctx=ast.Load()), 279 | args=[invars[0]], 280 | keywords=[ast.keyword(arg='axis', value=_astify_value(axis)), ast.keyword(arg='reverse', value=ast.NameConstant(value=True))] 281 | )) 282 | else: 283 | return ast.Assign(outvars, ast.Call( 284 | func=ast.Name(id='jnp.cumprod', ctx=ast.Load()), 285 | args=[invars[0]], 286 | keywords=[ast.keyword(arg='axis', value=_astify_value(axis))] 287 | )) 288 | 289 | 290 | @primitive_handler('concatenate') 291 | def _astify_concatenate(state, eqn): 292 | invars = [_astify_atom(state, v) for v in eqn.invars] 293 | outvars = _astify_outvars(state, eqn.outvars) 294 | axis = eqn.params['dimension'] 295 | return ast.Assign(outvars, ast.Call( 296 | func=ast.Attribute(value=ast.Name(id='jnp', ctx=ast.Load()), attr='concatenate', ctx=ast.Load()), 297 | args=[ast.Tuple(elts=invars, ctx=ast.Load())], 298 | keywords=[ast.keyword(arg='axis', value=_astify_value(axis))] 299 | )) 300 | 301 | 302 | 303 | def _maybe_wrap_fn_for_leaves(node, f, num_args): 304 | if len(node.args.args) == num_args: 305 | return node 306 | 307 | wrapped_node = ast.FunctionDef(name=f.__name__, 308 | args=ast.arguments( 309 | args=[], 310 | vararg=ast.arg(arg="args", annotation=None), 311 | kwarg=ast.arg(arg="kwargs", annotation=None), 312 | kwonlyargs=[], kw_defaults=[], defaults=[], 313 | posonlyargs=[]), 314 | body=[ 
315 | node, 316 | ast.Return(ast.Call(func=ast.Name(id=node.name, ctx=ast.Load()), 317 | args=[ast.Starred(ast.Call(func=ast.Attribute(value=ast.Name(id="jax", ctx=ast.Load()), 318 | attr="tree_leaves", 319 | ctx=ast.Load()), 320 | args=[ast.Tuple(elts=[ast.Name(id="args", ctx=ast.Load()), 321 | ast.Name(id="kwargs", ctx=ast.Load())], 322 | ctx=ast.Load())], 323 | keywords=[]))], 324 | keywords=[])), 325 | ], 326 | decorator_list=[]) 327 | 328 | return wrapped_node 329 | 330 | 331 | def _astify_jax_typing_annotation(state, aval): 332 | # jaxtyping annotations are like Float32[Array, "128 32"] 333 | if not state.use_jax_typing: 334 | return None 335 | 336 | dtype = aval.dtype 337 | shape = aval.shape 338 | 339 | if dtype == jnp.float32: 340 | dtype_str = "Float32" 341 | elif dtype == jnp.float64: 342 | dtype_str = "Float64" 343 | elif dtype == jnp.int32: 344 | dtype_str = "Int32" 345 | elif dtype == jnp.int64: 346 | dtype_str = "Int64" 347 | elif dtype == jnp.bool_: 348 | dtype_str = "Bool" 349 | elif dtype == jnp.bfloat16: 350 | dtype_str = "BFloat16" 351 | elif dtype == jnp.float16: 352 | dtype_str = "Float16" 353 | else: 354 | warnings.warn(f"Unknown dtype for jaxtyping {dtype}") 355 | dtype_str = "Shaped" 356 | 357 | if len(shape) == 0: 358 | return ast.Subscript( 359 | value=ast.Name(id="Scalar", ctx=ast.Load()), 360 | slice=ast.Name(id=dtype_str, ctx=ast.Load()), 361 | ) 362 | 363 | shape_str = " ".join(str(s) for s in shape) 364 | 365 | return ast.Subscript( 366 | value=ast.Name(id=dtype_str, ctx=ast.Load()), 367 | slice=ast.Tuple(elts=[ast.Name(id="Array", ctx=ast.Load()), ast.Str(shape_str)], ctx=ast.Load()), 368 | ) 369 | 370 | 371 | 372 | def jaxpr_to_py_ast(state: SourcerorState, jaxpr, fn_name: Optional[str] = None, *, unique_fn_name: bool = True): 373 | if isinstance(jaxpr, ClosedJaxpr): 374 | jaxpr = jaxpr.jaxpr 375 | if fn_name is None or unique_fn_name: 376 | fn_name = state.heuristic_fn_skolem(jaxpr, default=fn_name) 377 | 378 | # Generate argument 
declarations 379 | jaxpr = constant_fold_jaxpr(jaxpr) 380 | annotations = [_astify_jax_typing_annotation(state, v.aval) for v in jaxpr.invars] 381 | ast_args = [ast.arg(arg=state.str_name(var), annotation=ann) for var, ann in zip(jaxpr.invars, annotations)] 382 | ast_args = ast.arguments(args=ast_args, vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[], posonlyargs=[]) 383 | 384 | stmts = [] 385 | 386 | # Generate body of the function 387 | for eqn in jaxpr.eqns: 388 | prim = str(eqn.primitive) 389 | if prim in prim_to_python: 390 | eqn_stmts = prim_to_python[prim](state, eqn) 391 | else: 392 | try: 393 | eqn_stmts = normal_fn(prim)(state, eqn) 394 | except Exception: 395 | raise ValueError(f"Could not handle primitive {prim}") 396 | 397 | if isinstance(eqn_stmts, list): 398 | stmts.extend(eqn_stmts) 399 | else: 400 | stmts.append(eqn_stmts) 401 | 402 | # Generate return statement 403 | if len(jaxpr.outvars) == 1: 404 | returns = state.name(jaxpr.outvars[0]) 405 | else: 406 | returns = ast.Tuple(elts=[_name_or_literal(state, var) for var in jaxpr.outvars], ctx=ast.Load()) 407 | stmts.append(ast.Return(value=returns)) 408 | 409 | return ast.FunctionDef(name=fn_name, args=ast_args, body=stmts, decorator_list=[]) 410 | 411 | 412 | def _name_or_literal(state, var): 413 | if isinstance(var, Literal): 414 | return _astify_value(var.val) 415 | else: 416 | return state.name(var) 417 | 418 | 419 | def constant_fold_jaxpr(jaxpr: jax.core.Jaxpr): 420 | """ 421 | Given a jaxpr, return a new jaxpr with all constant folding done. 
422 | """ 423 | return partial_eval_jaxpr(jaxpr, {}, elide_unused_invars=False) 424 | 425 | def partial_eval_jaxpr(jaxpr, env, elide_unused_invars): 426 | env = env.copy() 427 | new_eqns = [] 428 | 429 | def read(var): 430 | if isinstance(var, Literal): 431 | return var.val 432 | else: 433 | return env.get(var, None) 434 | 435 | def read_or_self(var): 436 | out = read(var) 437 | if out is None: 438 | return var 439 | elif isinstance(out, Var): 440 | return out 441 | elif isinstance(out, Literal): 442 | return Literal(out.val, var.aval) 443 | else: 444 | assert not isinstance(out, Jaxpr) 445 | return Literal(out, var.aval) 446 | 447 | for eqn in jaxpr.eqns: 448 | vals = [read(var) for var in eqn.invars] 449 | if eqn.primitive.name in constant_fold_blacklist: 450 | new_eqns.append(eqn) 451 | elif all(val is not None for val in vals): 452 | # go ahead and eval it 453 | out = _eval_eqn(eqn, vals) 454 | 455 | # two options: either it's a jaxpr result (partial eval) or it's a value or a list of values 456 | if isinstance(out, Jaxpr): 457 | # we need to inline this 458 | new_eqns.extend(out.eqns) 459 | out = out.outvars 460 | elif not isinstance(out, tuple) and not isinstance(out, list): 461 | out = (out,) 462 | 463 | for var, val in zip(eqn.outvars, out): 464 | assert not isinstance(val, Jaxpr) 465 | if isinstance(val, Literal): 466 | env[var] = val.val 467 | else: 468 | env[var] = val 469 | else: 470 | new_eqns.append(eqn) 471 | 472 | # now that we've evaled everything, inline all the constants 473 | out_eqns = [] 474 | for eqn in new_eqns: 475 | eqn = eqn.replace(invars=tuple(read_or_self(var) for var in eqn.invars)) 476 | out_eqns.append(eqn) 477 | 478 | invars_still_used = IdentitySet() 479 | for eqn in out_eqns: 480 | for var in eqn.invars: 481 | invars_still_used.add(var) 482 | 483 | if elide_unused_invars: 484 | invars = tuple(var for var in jaxpr.invars if var in invars_still_used) 485 | else: 486 | invars = jaxpr.invars 487 | 488 | # sub in any constants for 
outvars 489 | outvars = tuple(read_or_self(var) for var in jaxpr.outvars) 490 | 491 | return jaxpr.replace(eqns=out_eqns, outvars=outvars, invars=invars) 492 | 493 | 494 | def _eval_eqn(eqn, vals) -> Union[Jaxpr, tuple, list, jnp.ndarray]: 495 | if eqn.primitive.name == "closed_call": 496 | assert eqn.primitive.call_primitive == True 497 | assert eqn.primitive.map_primitive == False 498 | 499 | out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}, elide_unused_invars=True) 500 | elif eqn.primitive.name == "scan": 501 | out = eqn.primitive.bind(*vals, **eqn.params) 502 | elif isinstance(eqn.primitive, CustomJVPCallPrimitive): 503 | # out = eqn.primitive.bind(*vals, **eqn.params) 504 | closed_jaxpr = eqn.params['call_jaxpr'] 505 | out = partial_eval_jaxpr(closed_jaxpr.jaxpr, {var: val for var, val in zip(closed_jaxpr.jaxpr.invars, vals)}, elide_unused_invars=True) 506 | else: 507 | out = eqn.primitive.bind(*vals, **eqn.params) 508 | return out 509 | 510 | 511 | @primitive_handler('dot_general') 512 | def _astify_dot_general(state, eqn): 513 | x, y = eqn.invars 514 | d = eqn.params['dimension_numbers'] 515 | precision = eqn.params['precision'] 516 | preferred_element_type = eqn.params['preferred_element_type'] 517 | 518 | has_dtype = preferred_element_type is None or x.aval.dtype == y.aval.dtype == preferred_element_type 519 | 520 | # recognize simple matmul case 521 | if d == (((1,), (0,)), ((), ())) and precision == None: 522 | invars = [_astify_atom(state, x), _astify_atom(state, y)] 523 | outvars = _astify_outvars(state, eqn.outvars) 524 | out = ast.Assign(targets=outvars, value=ast.Call(func=ast.Attribute(value=ast.Name(id='jnp', ctx=ast.Load()), attr='matmul', ctx=ast.Load()), args=invars, keywords=[])) 525 | if not has_dtype: 526 | out = ast.Assign(targets=outvars, value=ast.Call(func=ast.Attribute(value=out.value, attr='astype', ctx=ast.Load()), 
args=[_astify_value(preferred_element_type)], keywords=[])) 527 | 528 | return out 529 | 530 | # handle einsum case 531 | contract_dims, batch_dims = d 532 | in_specs = [['0']*x.aval.ndim, ['0']*y.aval.ndim] # the 0's will be replaced with letters 533 | out_spec = '' 534 | letter = ord('a') 535 | 536 | # output ordering is batch dims in order, then remaining lhs, then remaining rhs 537 | for i in range(len(batch_dims[0])): 538 | in_specs[0][batch_dims[0][i]] = chr(letter) 539 | in_specs[1][batch_dims[1][i]] = chr(letter) 540 | out_spec += chr(letter) 541 | letter += 1 542 | 543 | for i in range(len(contract_dims[0])): 544 | in_specs[0][contract_dims[0][i]] = chr(letter) 545 | in_specs[1][contract_dims[1][i]] = chr(letter) 546 | letter += 1 547 | 548 | # remaining dims are just the rest of the dims 549 | for i in range(x.aval.ndim): 550 | if in_specs[0][i] == '0': 551 | in_specs[0][i] = chr(letter) 552 | out_spec += chr(letter) 553 | letter += 1 554 | 555 | for i in range(y.aval.ndim): 556 | if in_specs[1][i] == '0': 557 | in_specs[1][i] = chr(letter) 558 | out_spec += chr(letter) 559 | letter += 1 560 | 561 | 562 | 563 | final_spec = f"{''.join(in_specs[0])},{''.join(in_specs[1])}->{out_spec}" 564 | invars = [_astify_value(final_spec), _astify_atom(state, x), _astify_atom(state, y)] 565 | outvars = _astify_outvars(state, eqn.outvars) 566 | keywords = [] 567 | if precision is not None: 568 | keywords.append(ast.keyword(arg='precision', value=_astify_value(precision))) 569 | if preferred_element_type is not None: 570 | keywords.append(ast.keyword(arg='preferred_element_type', value=_astify_value(preferred_element_type))) 571 | 572 | return ast.Assign(targets=outvars, value=ast.Call(func=ast.Attribute(value=ast.Name(id='jnp', ctx=ast.Load()), attr='einsum', ctx=ast.Load()), args=invars, 573 | keywords=keywords)) 574 | 575 | 576 | # invars = [_astify_atom(state, x), _astify_atom(state, y), _astify_value(d), _astify_value(precision), 577 | # 
_astify_value(preferred_element_type)] 578 | # outvars = _astify_outvars(state, eqn.outvars) 579 | # return ast.Assign(targets=outvars, value=ast.Call(func=ast.Attribute(value=ast.Name(id='jax.lax', ctx=ast.Load()), attr='dot_general', ctx=ast.Load()), args=invars, keywords=[])) 580 | 581 | @primitive_handler('dynamic_slice') 582 | def _sourcify_dynamic_slice(state, eqn): 583 | sliced = eqn.invars[0] 584 | invars = ast.Tuple(elts=[_astify_atom(state, var) for var in eqn.invars[1:]], ctx=ast.Load()) 585 | outvars = _astify_outvars(state, eqn.outvars) 586 | params = [ast.keyword(arg=k, value=_astify_value(v)) for k, v in eqn.params.items()] 587 | return ast.Assign(targets=outvars, value=ast.Call( 588 | func=ast.Attribute( 589 | value=ast.Name(id='jax.lax', ctx=ast.Load()), 590 | attr='dynamic_slice', 591 | ctx=ast.Load() 592 | ), 593 | args=[_astify_atom(state, sliced), invars], 594 | keywords=params 595 | )) 596 | 597 | 598 | @primitive_handler('slice') 599 | def _sourcify_slice(state, eqn): 600 | sliced = eqn.invars[0] 601 | # invars = ast.Tuple(elts=[_astify_atom(state, var) for var in eqn.invars[1:]], ctx=ast.Load()) 602 | outvars = _astify_outvars(state, eqn.outvars) 603 | start_indices = eqn.params['start_indices'] 604 | limit_indices = eqn.params['limit_indices'] 605 | strides = eqn.params['strides'] 606 | if strides is None: 607 | strides = (None,) * len(start_indices) 608 | indices = [_astify_value(slice(s, e, stride)) for s, e, stride in zip(start_indices, limit_indices, strides)] 609 | # params = [ast.keyword(arg=k, value=_astify_value(v)) for k, v in eqn.params.items()] 610 | return ast.Assign(targets=outvars, value=ast.Subscript( 611 | value=_astify_atom(state, sliced), 612 | slice=ast.Tuple(elts=indices, ctx=ast.Load()), 613 | ctx=ast.Load() 614 | )) 615 | 616 | 617 | @primitive_handler('dynamic_update_slice') 618 | def _sourcify_dynamic_update_slice(state, eqn): 619 | sliced = eqn.invars[0] 620 | # the first two arguments are the sliced array and the 
update array 621 | # the remaining are start indices and should be packaged into a tuple 622 | target = _astify_atom(state, eqn.invars[0]) 623 | update = _astify_atom(state, eqn.invars[1]) 624 | start_indices = maybe_tuple_vars([_astify_atom(state, var) for var in eqn.invars[2:]]) 625 | outvars = _astify_outvars(state, eqn.outvars) 626 | 627 | return ast.Assign(targets=outvars, value=ast.Call( 628 | func=ast.Attribute( 629 | value=ast.Name(id='jax.lax', ctx=ast.Load()), 630 | attr='dynamic_update_slice', 631 | ctx=ast.Load() 632 | ), 633 | args=[target, update, start_indices], 634 | keywords=[] 635 | )) 636 | 637 | 638 | @primitive_handler('convert_element_type') 639 | def _astify_convert_element_type(state, eqn): 640 | # now we use ast 641 | outvars = _astify_outvars(state, eqn.outvars) 642 | assert len(eqn.invars) == 1 643 | invar = _astify_atom(state, eqn.invars[0]) 644 | dtype = _astify_value(eqn.params['new_dtype']) 645 | # return ast.Assign(targets=outvars, value=ast.Call( 646 | # func=ast.Attribute( 647 | # value=ast.Name(id='jax.lax', ctx=ast.Load()), 648 | # attr='convert_element_type', 649 | # ctx=ast.Load() 650 | # ), 651 | # args=[invars], 652 | # keywords=params 653 | # )) 654 | return ast.Assign(targets=outvars, value=ast.Call( 655 | func=ast.Attribute( 656 | value=invar, 657 | attr='astype', 658 | ctx=ast.Load() 659 | ), 660 | args=[dtype], 661 | keywords=[] 662 | )) 663 | 664 | def is_array(arr): 665 | return isinstance(arr, (np.ndarray, np.generic, jnp.ndarray)) 666 | 667 | 668 | def _astify_array(value): 669 | assert is_array(value) 670 | if isinstance(value, np.int64): 671 | return ast.Constant(value=int(value)) 672 | 673 | if value.ndim == 0 and value.dtype in (jnp.float32, jnp.int32, jnp.bool_, jnp.int64): 674 | return ast.Constant(value=value.item()) 675 | 676 | if value.ndim == 0: 677 | dtype_value = _astify_value(value.dtype) 678 | return ast.Call( 679 | dtype_value, 680 | args=[ast.Constant(value=value.item())], 681 | keywords=[], 682 | ) 
683 | 684 | values = value.tolist() 685 | 686 | def rec_astify_list(values): 687 | if isinstance(values, list): 688 | return ast.List(elts=[rec_astify_list(val) for val in values], ctx=ast.Load()) 689 | else: 690 | return ast.Constant(value=values) 691 | 692 | return ast.Call( 693 | func=ast.Attribute( 694 | value=ast.Name(id='jnp', ctx=ast.Load()), 695 | attr='array', 696 | ctx=ast.Load() 697 | ), 698 | args=[rec_astify_list(values)], 699 | keywords=[ast.keyword(arg='dtype', value=_astify_value(value.dtype))] 700 | ) 701 | 702 | def _astify_atom(state: SourcerorState, var: Union[Literal, Var]): 703 | if isinstance(var, Literal): 704 | return _astify_value(var.val) 705 | elif isinstance(var, Var): 706 | return state.name(var) 707 | else: 708 | raise NotImplementedError() 709 | 710 | def _astify_value(value): 711 | assert not isinstance(value, (Literal, Var)) 712 | 713 | if is_array(value): 714 | return _astify_array(value) 715 | elif isinstance(value, (int, bool, float, str, type(None))): 716 | return ast.Constant(value=value) 717 | elif isinstance(value, (tuple, list)): 718 | return ast.Tuple(elts=[_astify_value(v) for v in value], ctx=ast.Load()) 719 | elif isinstance(value, jnp.dtype): 720 | # return ast.Call(func=ast.Attribute(value=ast.Name(id="jnp", ctx=ast.Load()), attr='dtype', ctx=ast.Load()), args=[ast.Constant(value=str(value))], keywords=[]) 721 | if value.name in ('float32', 'float64', 'int32', 'int64', 'bfloat16', 'float16'): 722 | # return ast.Constant(value=getattr(jnp, value.name)) 723 | return ast.Attribute(value=ast.Name(id="jnp", ctx=ast.Load()), attr=value.name, ctx=ast.Load()) 724 | elif value.name == 'bool': 725 | return ast.Attribute(value=ast.Name(id="jnp", ctx=ast.Load()), attr='bool_', ctx=ast.Load()) 726 | else: 727 | return ast.Call(func=ast.Attribute(value=ast.Name(id="jnp", ctx=ast.Load()), attr='dtype', ctx=ast.Load()), args=[ast.Constant(value=str(value))], keywords=[]) 728 | elif value is _UNSPECIFIED: 729 | return 
ast.Attribute(value=ast.Name(id='jax.experimental.pjit', ctx=ast.Load()), attr='_UNSPECIFIED', ctx=ast.Load()) 730 | elif isinstance(value, jax.lax.GatherScatterMode): 731 | return ast.Attribute(value=ast.Name("jax.lax.GatherScatterMode", ctx=ast.Load()), attr=value.name, ctx=ast.Load()) 732 | elif isinstance(value, enum.Enum): 733 | return ast.Attribute(value=ast.Name(id=value.__class__.__qualname__, ctx=ast.Load()), attr=value.name, ctx=ast.Load()) 734 | elif isinstance(value, slice): 735 | return ast.Call( 736 | func=ast.Name(id='slice', ctx=ast.Load()), 737 | args=[_astify_value(value.start), _astify_value(value.stop), _astify_value(value.step)], 738 | keywords=[] 739 | ) 740 | elif isinstance(value, NamedSharding): 741 | # jax.sharding.NamedSharding(mesh=, spec=PartitionSpec(*) 742 | return ast.Call( 743 | func=ast.Attribute(value=ast.Name(id='jax.sharding', ctx=ast.Load()), attr='NamedSharding', ctx=ast.Load()), 744 | args=[_astify_value(value.mesh), _astify_value(value.spec)], 745 | keywords=[] 746 | ) 747 | elif isinstance(value, jax.sharding.Mesh): 748 | return ast.Load(name="TODO_mesh") 749 | elif isinstance(value, jax.sharding.PartitionSpec): 750 | return ast.Load(name="TODO_partition_spec") 751 | elif isinstance(value, bytes): 752 | return ast.Constant(value=value.decode('utf-8')) 753 | else: 754 | warnings.warn(f"Unknown value type {type(value)}") 755 | raise NotImplementedError(f"Unknown value type {type(value)}") 756 | return ast.parse(repr(value)).body[0] 757 | 758 | 759 | def _astify_outvars(state, outvars): 760 | out = [state.name(v, ctx=ast.Store()) for v in outvars] 761 | if len(out) == 1: 762 | return out 763 | else: 764 | return [ast.Tuple(elts=out, ctx=ast.Store())] 765 | 766 | def maybe_tuple_vars(vars): 767 | if len(vars) == 1: 768 | return vars[0] 769 | else: 770 | return ast.Tuple(elts=vars, ctx=ast.Load()) 771 | 772 | 773 | def maybe_untuple_vars(var, is_tuple): 774 | if is_tuple: 775 | return ast.Starred(value=var, ctx=ast.Load()) 776 | 
else: 777 | return var 778 | 779 | 780 | @primitive_handler('scan') 781 | def _astify_scan(state, eqn): 782 | assert eqn.primitive.name == 'scan' 783 | 784 | # the args to scan are [constants, carry, xs] 785 | # constants aren't exposed in the Python API, so we need to handle them specially (we use a lambda) 786 | num_consts = eqn.params['num_consts'] 787 | num_carry = eqn.params['num_carry'] 788 | 789 | # TODO: bring back map 790 | # if num_carry == 0: 791 | # this is a map 792 | # return _astify_map(eqn) 793 | 794 | constant_args = eqn.invars[:num_consts] 795 | carries = eqn.invars[num_consts:num_consts + num_carry] 796 | xs = eqn.invars[num_consts + num_carry:] 797 | 798 | jaxpr = eqn.params['jaxpr'].jaxpr 799 | 800 | if num_consts != 0: 801 | # we want to construct an environment where we partial eval the function using the constants as the env 802 | env = dict(zip(jaxpr.invars, constant_args)) 803 | jaxpr = partial_eval_jaxpr(jaxpr, env, elide_unused_invars=True) 804 | 805 | fn_ast = jaxpr_to_py_ast(state, jaxpr) 806 | fn_name = fn_ast.name 807 | 808 | length = _astify_value(eqn.params['length']) 809 | unroll = _astify_value(eqn.params['unroll']) 810 | reverse = _astify_value(eqn.params['reverse']) 811 | 812 | stmts = [] 813 | 814 | if num_carry != 1 or len(jaxpr.invars) != 2: 815 | # what we want is something like: 816 | # fn_name = lambda carry, xs: fn_name(*carry, *xs) 817 | # jax.lax.scan(fn_name, (carries...), (xs...)) 818 | 819 | modified_signature = ast.arguments( 820 | args=[ast.arg(arg='carry'), ast.arg(arg='x')], 821 | vararg=None, 822 | kwonlyargs=[], 823 | kw_defaults=[], 824 | kwarg=None, 825 | defaults=[], 826 | posonlyargs=[] 827 | ) 828 | 829 | initial_assign = ast.Assign( 830 | targets=[ast.Tuple(elts=[ast.Name(a.arg) for a in fn_ast.args.args], ctx=ast.Store())], 831 | value=ast.Tuple(elts=[maybe_untuple_vars(ast.Name(id='carry', ctx=ast.Load()), num_carry != 1), 832 | maybe_untuple_vars(ast.Name(id='x', ctx=ast.Load()), len(xs) != 1)]) 833 | 
) 834 | 835 | fn_return = fn_ast.body[-1] 836 | assert isinstance(fn_return, ast.Return) 837 | 838 | fn_return_value = fn_return.value 839 | 840 | if isinstance(fn_return_value, ast.Tuple): 841 | fn_return_value = fn_return_value.elts 842 | ret_carries = maybe_tuple_vars(fn_return_value[:num_carry]) 843 | ret_ys = maybe_tuple_vars(fn_return_value[num_carry:]) 844 | elif num_carry == 0: 845 | ret_carries = _astify_value(()) 846 | ret_ys = fn_return_value 847 | else: 848 | ret_carries = fn_return_value 849 | ret_ys = _astify_value(()) 850 | 851 | scan_return = ast.Return( 852 | value = ast.Tuple(elts=[ret_carries, ret_ys], ctx=ast.Load()) 853 | ) 854 | 855 | new_body = [initial_assign] + list(fn_ast.body[:-1]) + [scan_return] 856 | 857 | fn_ast = ast.FunctionDef( 858 | name=fn_name, 859 | args=modified_signature, 860 | body=new_body, 861 | decorator_list=[] 862 | ) 863 | 864 | stmts.append(fn_ast) 865 | 866 | scan_call = ast.Assign( 867 | # targets=_astify_outvars(eqn.outvars), 868 | targets=[ast.Tuple(elts=[ast.Name(id='final_carry', ctx=ast.Store()), ast.Name(id='ys', ctx=ast.Store())], ctx=ast.Store())], 869 | value=ast.Call( 870 | func=ast.Name(id='jax.lax.scan', ctx=ast.Load()), 871 | args=[ast.Name(id=fn_name, ctx=ast.Load()), 872 | maybe_tuple_vars([_astify_atom(state, v) for v in carries]), 873 | maybe_tuple_vars([_astify_atom(state, v) for v in xs])], 874 | keywords=[ast.keyword(arg='length', value=length), ast.keyword(arg='unroll', value=unroll), ast.keyword(arg='reverse', value=reverse)] 875 | )) 876 | stmts.append(scan_call) 877 | 878 | if num_carry > 0: 879 | assign_carry = ast.Assign( 880 | targets=_astify_outvars(state, eqn.outvars[:num_carry]), 881 | value=ast.Name(id='final_carry', ctx=ast.Load()) 882 | ) 883 | 884 | stmts.append(assign_carry) 885 | 886 | if num_carry < len(eqn.outvars): 887 | assign_ys = ast.Assign( 888 | targets=_astify_outvars(state, eqn.outvars[num_carry:]), 889 | value=ast.Name(id='ys', ctx=ast.Load()) 890 | ) 891 | 892 | 
stmts.append(assign_ys) 893 | else: 894 | stmts.append(fn_ast) 895 | 896 | scan_call = ast.Assign( 897 | targets=_astify_outvars(state, eqn.outvars), 898 | value=ast.Call( 899 | func=ast.Name(id='jax.lax.scan', ctx=ast.Load()), 900 | args=[ast.Name(id=fn_name, ctx=ast.Load())] + [_astify_atom(state, v) for v in eqn.invars], 901 | keywords=[ast.keyword(arg='length', value=length), ast.keyword(arg='unroll', value=unroll), ast.keyword(arg='reverse', value=reverse)] 902 | )) 903 | 904 | stmts.append(scan_call) 905 | 906 | return stmts 907 | 908 | def _astify_map(state, eqn): 909 | assert eqn.primitive.name == 'scan' 910 | assert eqn.params['num_carry'] == 0 911 | 912 | jaxpr = eqn.params['jaxpr'] 913 | 914 | fn_ast = jaxpr_to_py_ast(state, jaxpr) 915 | fn_name = fn_ast.name 916 | 917 | # map is a bit funny, because the jaxpr takes K args, but the jax.lax.map function takes a single tuple arg 918 | # so we need to use a lambda to redirect the call 919 | lam = ast.parse(f"lambda args: {fn_name}(*args)").body[0] 920 | 921 | assign = ast.Assign( 922 | targets=_astify_outvars(state, eqn.outvars), 923 | value=ast.Call( 924 | func=ast.Name(id='jax.lax.map', ctx=ast.Load()), 925 | args=[lam, ast.Tuple(elts=[_astify_atom(state, v) for v in eqn.invars], ctx=ast.Load())], 926 | keywords=[] 927 | )) 928 | 929 | return [fn_ast, assign] 930 | 931 | 932 | def _attempt_to_sniff_fn_name_for_jaxpr(jaxpr): 933 | # this is necessarily very hacky. 
934 | eqns = jaxpr.eqns 935 | if len(eqns) == 0: 936 | return None 937 | source_info = eqns[0].source_info 938 | try: 939 | name = None 940 | for frame in user_frames(source_info): 941 | name = frame.function_name 942 | 943 | if name and "<" not in name: 944 | return name 945 | 946 | if not name: 947 | name = frame.file_name 948 | return name 949 | except: 950 | return None 951 | 952 | 953 | 954 | 955 | 956 | @primitive_handler('closed_call') 957 | def _astify_closed_call(state, eqn): 958 | # out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, 959 | # {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}) 960 | raw_jaxpr = eqn.params['call_jaxpr'].jaxpr 961 | literal_args = {k: v.val for k, v in zip(raw_jaxpr.invars, eqn.invars) if isinstance(v, Literal)} 962 | call_jaxpr = partial_eval_jaxpr(raw_jaxpr, literal_args, elide_unused_invars=False) 963 | fn_ast = jaxpr_to_py_ast(state, call_jaxpr) 964 | fn_name = fn_ast.name 965 | 966 | invars = [_astify_atom(state, v) for v in eqn.invars if not isinstance(v, Literal)] 967 | outvars = _astify_outvars(state, eqn.outvars) 968 | 969 | assign = ast.Assign( 970 | targets=outvars, 971 | value=ast.Call( 972 | func=ast.Name(id=fn_name, ctx=ast.Load()), 973 | args=invars, 974 | keywords=[] 975 | )) 976 | 977 | return [fn_ast, assign] 978 | 979 | @primitive_handler('pjit') 980 | def _astify_pjit(state, eqn): 981 | # this one's a real pain. 
982 | # pjit's params are : 983 | # jaxpr 984 | # donated_invars: 985 | # in_shardings, out_shardings 986 | # resource env 987 | # name (yay) 988 | # keep_unused, inline (which we won't use) 989 | 990 | jaxpr = eqn.params['jaxpr'] 991 | donated_invars = eqn.params['donated_invars'] 992 | in_shardings = eqn.params['in_shardings'] 993 | out_shardings = eqn.params['out_shardings'] 994 | resource_env = eqn.params['resource_env'] 995 | name = eqn.params['name'] 996 | 997 | can_ignore_donated = not any(donated_invars) 998 | 999 | keywords = [] 1000 | 1001 | if in_shardings and any(s != jax.experimental.pjit._UNSPECIFIED for s in in_shardings): 1002 | in_shardings = _astify_value(in_shardings) 1003 | keywords.append(ast.keyword(arg='in_shardings', value=in_shardings)) 1004 | 1005 | if out_shardings and any(s != jax.experimental.pjit._UNSPECIFIED for s in out_shardings): 1006 | out_shardings = _astify_value(out_shardings) 1007 | keywords.append(ast.keyword(arg='out_shardings', value=out_shardings)) 1008 | 1009 | if not can_ignore_donated: 1010 | donated_invars = _astify_value(donated_invars) 1011 | keywords.append(ast.keyword(arg='donated_invars', value=donated_invars)) 1012 | 1013 | 1014 | # preprocess the function 1015 | fn_ast = jaxpr_to_py_ast(state, jaxpr) 1016 | fn_name = fn_ast.name 1017 | 1018 | jitted_fn = ast.Call( 1019 | func= 1020 | ast.Attribute( 1021 | ast.Name(id='jax', ctx=ast.Load()), 1022 | attr='jit'), 1023 | args=[ast.Name(id=fn_name, ctx=ast.Load())], 1024 | keywords=keywords 1025 | ) 1026 | 1027 | assign = ast.Assign( 1028 | targets=_astify_outvars(state, eqn.outvars), 1029 | value=ast.Call( 1030 | func=jitted_fn, 1031 | args=[_astify_atom(state, v) for v in eqn.invars], 1032 | keywords=[] 1033 | )) 1034 | 1035 | return [fn_ast, assign] 1036 | 1037 | 1038 | @primitive_handler('remat2') 1039 | def _astify_remat(state: SourcerorState, eqn): 1040 | # out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, 1041 | # {var: val for var, val in 
@primitive_handler('custom_vjp_call_jaxpr')
def _astify_custom_vjp_call_jaxpr(state, eqn):
    """Lower a ``custom_vjp_call_jaxpr`` equation to Python source AST.

    Only the forward (primal) jaxpr is reconstructed and called; the custom
    backward rule is dropped, so differentiating the generated source falls
    back to ordinary autodiff. Returns ``[fn_def, vjp_alias, assign]``
    statements for the caller to splice into the output module.
    """
    closed_jaxpr = eqn.params['fun_jaxpr']
    fn_ast = jaxpr_to_py_ast(state, closed_jaxpr)
    fn_name = fn_ast.name

    invars = [_astify_atom(state, v) for v in eqn.invars]
    outvars = _astify_outvars(state, eqn.outvars)

    # Bug fix: the wrapper lambda previously declared a single positional
    # parameter ``primals`` but was invoked below with *all* invars as separate
    # positional arguments, so any primal function of more than one argument
    # raised TypeError when the generated source ran. A vararg accepts any
    # arity and forwards the arguments unchanged.
    lam = ast.Assign(
        targets=[ast.Name(id=f"vjp_{fn_name}", ctx=ast.Store())],
        value=ast.Lambda(
            args=ast.arguments(
                args=[],
                vararg=ast.arg(arg='primals'),
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[],
                posonlyargs=[]
            ),
            body=ast.Call(
                func=ast.Name(id=fn_name, ctx=ast.Load()),
                args=[ast.Starred(value=ast.Name(id='primals', ctx=ast.Load()), ctx=ast.Load())],
                keywords=[]
            )
        )
    )

    assign = ast.Assign(
        targets=outvars,
        value=ast.Call(
            func=ast.Name(id=f"vjp_{fn_name}", ctx=ast.Load()),
            args=invars,
            keywords=[]
        ))

    return [fn_ast, lam, assign]


@primitive_handler('custom_jvp_call')
def _astify_custom_jvp_call(state, eqn):
    """Lower a ``custom_jvp_call`` equation by emitting the primal function and
    calling it directly; the custom JVP rule is intentionally dropped (see
    ``_astify_custom_vjp_call_jaxpr``)."""
    closed_jaxpr = eqn.params['call_jaxpr']
    fn_ast = jaxpr_to_py_ast(state, closed_jaxpr)
    fn_name = fn_ast.name

    invars = [_astify_atom(state, v) for v in eqn.invars]
    outvars = _astify_outvars(state, eqn.outvars)

    assign = ast.Assign(
        targets=outvars,
        value=ast.Call(
            func=ast.Name(id=fn_name, ctx=ast.Load()),
            args=invars,
            keywords=[]
        ))

    return [fn_ast, assign]
def _splat_state_lambda(fn_name):
    """Build ``lambda state: <fn_name>(*state)``.

    jax.lax.while_loop passes the carry as a single tuple, while the generated
    cond/body functions take one parameter per carry element; this adapter
    bridges the two calling conventions. Extracted because the identical
    construction was duplicated for both the body and the cond lambdas.
    """
    return ast.Lambda(
        args=ast.arguments(
            args=[ast.arg(arg='state')],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
            posonlyargs=[]
        ),
        body=ast.Call(
            func=ast.Name(id=fn_name, ctx=ast.Load()),
            args=[ast.Starred(value=ast.Name(id='state', ctx=ast.Load()), ctx=ast.Load())],
            keywords=[]
        )
    )


@primitive_handler('while')
def _astify_while(state, eqn):
    """Lower a ``while`` equation to a ``jax.lax.while_loop`` call plus the two
    generated cond/body function definitions."""
    body_jaxpr = eqn.params['body_jaxpr']
    cond_jaxpr = eqn.params['cond_jaxpr']

    body_nconsts = eqn.params['body_nconsts']
    cond_nconsts = eqn.params['cond_nconsts']

    # Fold loop-invariant "consts" into the cond/body jaxprs so the generated
    # functions close over them instead of threading them through the carry.
    # NOTE(review): the env is keyed by in_avals — presumably partial_eval_jaxpr
    # matches these positionally against invars; confirm against its definition.
    if cond_nconsts != 0:
        env = dict(zip(cond_jaxpr.in_avals, eqn.invars[:cond_nconsts]))
        cond_jaxpr = partial_eval_jaxpr(cond_jaxpr.jaxpr, env, elide_unused_invars=False)

    cond_fn_ast = jaxpr_to_py_ast(state, cond_jaxpr)
    cond_fn_name = cond_fn_ast.name

    if body_nconsts != 0:
        env = dict(zip(body_jaxpr.in_avals, eqn.invars[cond_nconsts:cond_nconsts + body_nconsts]))
        body_jaxpr = partial_eval_jaxpr(body_jaxpr.jaxpr, env, elide_unused_invars=False)

    body_fn_ast = jaxpr_to_py_ast(state, body_jaxpr)
    body_fn_name = body_fn_ast.name

    # whatever is left after the consts is the actual initial carry
    true_args = eqn.invars[cond_nconsts + body_nconsts:]

    invars = [_astify_atom(state, v) for v in true_args]
    outvars = _astify_outvars(state, eqn.outvars)

    body_lambda = _splat_state_lambda(body_fn_name)
    cond_lam = _splat_state_lambda(cond_fn_name)

    args = ast.Tuple(elts=invars, ctx=ast.Load())

    assign = ast.Assign(
        targets=outvars,
        value=ast.Call(
            func=ast.Name(id='jax.lax.while_loop', ctx=ast.Load()),
            args=[cond_lam, body_lambda, args],
            keywords=[]
        ))

    return [body_fn_ast, cond_fn_ast, assign]


def _astize_fn(state, jaxpr, name):
    # Thin alias over jaxpr_to_py_ast, kept for callers using the older name.
    return jaxpr_to_py_ast(state, jaxpr, name)


@primitive_handler('cond')
def _astify_cond(state, eqn):
    """Lower a ``cond`` equation to ``jax.lax.cond`` (exactly two branches) or
    ``jax.lax.switch`` (any other branch count)."""
    branches = eqn.params['branches']

    ast_branches = [jaxpr_to_py_ast(state, jaxpr) for jaxpr in branches]

    # the first invar is the predicate / branch index; the rest are operands
    pred_var, *rest_args = eqn.invars
    pred_ast = _astify_atom(state, pred_var)
    invars = [_astify_atom(state, v) for v in rest_args]
    outvars = _astify_outvars(state, eqn.outvars)

    branch_names = [ast.Name(id=ast_branch.name, ctx=ast.Load()) for ast_branch in ast_branches]

    if len(branches) == 2:
        # branches are stored (false_fn, true_fn), per the unpack order here
        false_fn_ast, true_fn_ast = ast_branches
        false_name, true_name = branch_names

        assign = ast.Assign(
            targets=outvars,
            value=ast.Call(
                func=ast.Name(id='jax.lax.cond', ctx=ast.Load()),
                args=[pred_ast, true_name, false_name, *invars],
                keywords=[]
            ))

        return [true_fn_ast, false_fn_ast, assign]

    else:
        # Bug fix: ast.List previously had no ctx, unlike every other node in
        # this file; ctx is a required field when the generated AST is compiled.
        assign = ast.Assign(
            targets=outvars,
            value=ast.Call(
                func=ast.Name(id='jax.lax.switch', ctx=ast.Load()),
                args=[pred_ast, ast.List(elts=branch_names, ctx=ast.Load()), *invars],
                keywords=[]
            ))

        return ast_branches + [assign]
@primitive_handler('iota')
def _astify_iota(state, eqn):
    """Lower ``iota``: an array of ``shape`` whose values increment along axis
    ``dimension``. 1-D iota becomes a plain ``jnp.arange``; otherwise we emit
    ``jax.lax.broadcasted_iota``."""
    dimension = eqn.params['dimension']  # axis along which to increment
    shape = eqn.params['shape']
    dtype = eqn.params['dtype']

    if len(shape) == 1:
        # a 1-d iota is just an arange
        arange = ast.Call(
            func=ast.Attribute(
                value=ast.Name(id="jnp", ctx=ast.Load()),
                attr='arange',
                ctx=ast.Load()
            ),
            args=[_astify_value(shape[0])],
            keywords=[ast.keyword(arg='dtype', value=_astify_value(dtype))]
        )
        return ast.Assign(
            targets=_astify_outvars(state, eqn.outvars),
            value=arange
        )

    # Bug fix: the previous lowering emitted
    # jnp.broadcast_to(jnp.arange(shape[0]), shape), which ignores `dimension`
    # and is wrong for rank > 1: numpy broadcasting aligns *trailing* axes, so
    # the arange would have to match shape[-1] and increment along the last
    # axis. jax.lax.broadcasted_iota implements the general case directly.
    call = ast.Call(
        func=ast.Name(id='jax.lax.broadcasted_iota', ctx=ast.Load()),
        args=[_astify_value(dtype), _astify_value(shape), _astify_value(dimension)],
        keywords=[]
    )

    return ast.Assign(
        targets=_astify_outvars(state, eqn.outvars),
        value=call
    )


@primitive_handler('reshape')
def _astify_reshape(state, eqn):
    """Lower ``reshape``. lax's reshape can fuse a transpose with the reshape,
    so we emit ``jnp.reshape(jnp.transpose(operand, dimensions), new_sizes)``
    (the transpose only when ``dimensions`` is given)."""
    dimensions = eqn.params['dimensions']
    new_sizes = eqn.params['new_sizes']

    source = _astify_atom(state, eqn.invars[0])

    if dimensions is not None:
        source = ast.Call(
            func=ast.Name(id='jnp.transpose', ctx=ast.Load()),
            args=[source, _astify_value(dimensions)],
            keywords=[]
        )

    assign = ast.Assign(
        targets=_astify_outvars(state, eqn.outvars),
        value=ast.Call(
            func=ast.Name(id='jnp.reshape', ctx=ast.Load()),
            args=[source, _astify_value(new_sizes)],
            keywords=[]
        ))

    return [assign]


@primitive_handler('add_any')
def _astify_add_any(state, eqn):
    # add_any is an undocumented jax primitive; we lower it as ordinary
    # addition (its observed behavior matches `+`).
    return _binop_fn(ast.Add())(state, eqn)


@primitive_handler('broadcast_in_dim')
def _astify_broadcast_in_dim(state, eqn):
    """Lower ``broadcast_in_dim``. Since zeros/ones/full are implemented with
    this primitive, prefer emitting those friendlier forms when the operand is
    a scalar literal; otherwise fall back to the raw lax call."""
    assert len(eqn.invars) == 1
    value = eqn.invars[0]
    shape = eqn.params['shape']
    broadcast_dimensions = eqn.params['broadcast_dimensions']

    if not isinstance(value, Literal) or broadcast_dimensions != ():
        return normal_fn('jax.lax.broadcast_in_dim')(state, eqn)

    if not isinstance(value.val, np.ndarray) or value.val.ndim != 0:
        return normal_fn('jax.lax.broadcast_in_dim')(state, eqn)

    constant_value = value.val.item()
    # pick the friendliest constructor for the constant being broadcast
    if constant_value == 0:
        attr = 'zeros'
        args = [_astify_value(shape), _astify_value(value.val.dtype)]
    elif constant_value == 1:
        attr = 'ones'
        args = [_astify_value(shape), _astify_value(value.val.dtype)]
    else:
        attr = 'full'
        args = [_astify_value(shape), _astify_value(constant_value), _astify_value(value.val.dtype)]

    call = ast.Call(
        ast.Attribute(
            value=ast.Name(id="jnp", ctx=ast.Load()),
            attr=attr,
            ctx=ast.Load()
        ),
        args=args,
        keywords=[]
    )

    return [ast.Assign(
        targets=_astify_outvars(state, eqn.outvars),
        value=call
    )]


@primitive_handler('random_wrap')
def _astify_random_wrap(state, eqn):
    # random_wrap converts raw key data into a typed key; in generated source
    # we treat it as a no-op pass-through of the input value.
    return ast.Assign(
        targets=_astify_outvars(state, eqn.outvars),
        value=_astify_atom(state, eqn.invars[0])
    )


# Primitives whose equations must not be constant-folded — presumably because
# folding would materialize large constant arrays in the generated source;
# TODO(review): confirm intent against constant_fold_jaxpr.
constant_fold_blacklist = {
    'broadcast_in_dim',
    'broadcast',
    'iota',
}
class IdentitySet(MutableSet):
    """Set that compares elements by identity (``id``) instead of equality.

    Useful for storing objects that are not hashable or that should be
    compared by identity. It is a mutable set and is itself unhashable, so it
    cannot be used as a dictionary key or as an element of another set.
    """

    def __init__(self, iterable=None):
        # id(value) -> value; holding the value keeps it alive, which also
        # keeps its id from being reused by another object
        self._data = {}
        if iterable is not None:
            self.update(iterable)

    def __contains__(self, value):
        return id(value) in self._data

    def __iter__(self):
        return iter(self._data.values())

    def __len__(self):
        return len(self._data)

    def add(self, value):
        self._data[id(value)] = value

    def discard(self, value):
        # pop with default: discarding an absent element is a no-op, per the
        # MutableSet contract
        self._data.pop(id(value), None)

    def __repr__(self):
        return f"IdentitySet({list(repr(x) for x in self._data.values())})"

    def __str__(self):
        return f"IdentitySet({list(str(x) for x in self._data.values())})"


class IdentityMap(MutableMapping):
    """Map that compares keys by identity (``id``) instead of equality.

    Useful for keys that are not hashable or that should be compared by
    identity. It is a mutable mapping and is itself unhashable.
    """

    def __init__(self, iterable=None):
        # id(key) -> (key, value). The key object must be retained: it keeps
        # the id stable and lets __iter__ yield the original key objects.
        self._data = {}
        if iterable is not None:
            self.update(iterable)

    def __contains__(self, key):
        return id(key) in self._data

    def __getitem__(self, key):
        return self._data[id(key)][1]

    def __setitem__(self, key, value):
        self._data[id(key)] = (key, value)

    def __delitem__(self, key):
        del self._data[id(key)]

    def __iter__(self):
        # Bug fix: a Mapping's iterator must yield its *keys*. The previous
        # implementation stored only values and yielded them here, which broke
        # the keys()/items()/values() mixins inherited from MutableMapping
        # (items() would try self[value]).
        return (key for key, _ in self._data.values())

    def __len__(self):
        return len(self._data)

    def __repr__(self):
        items = ", ".join(f"{k!r}: {v!r}" for k, v in self._data.values())
        return f"IdentityMap({{{items}}})"

    def __str__(self):
        items = ", ".join(f"{k}: {v}" for k, v in self._data.values())
        return f"IdentityMap({{{items}}})"
def _parse_sandboxed(source, fn_name):
    """Exec generated source in a minimal namespace and hand back the function
    named ``fn_name``."""
    global_ns = {'jax': jax, 'jaxtyping': jaxtyping, 'jnp': jnp, 'functools': functools, 'partial': partial}
    local_ns = {}
    wrapped = f"""
from jaxtyping import *

{source}"""
    exec(wrapped, global_ns, local_ns)
    return local_ns[fn_name]


def test_slice_squeeze():
    # slicing plus an integer index exercises slice + squeeze lowering
    def f(x):
        return x[0:2, 0:1, 3]

    data = jnp.arange(4 * 5 * 6).reshape(4, 5, 6)
    f2 = check_roundtrip(f)(data)
    # the regenerated function should itself round-trip
    check_roundtrip(f2)(data)
def test_scan():
    # scan body: returns (new_carry, per-step output)
    def scanfn(x, y):
        return x + y, x * y

    x = jnp.arange(10)
    y = jnp.arange(10)

    def f(x, y):
        return jax.lax.scan(scanfn, x, y)

    f2 = check_roundtrip(f)(x, y)

    # compare both the final carry and the stacked outputs
    assert jnp.all(f(x, y)[0] == f2(x, y)[0])
    assert jnp.all(f(x, y)[1] == f2(x, y)[1])


def test_map():
    def f(x):
        return x + 1

    x = jnp.arange(10)

    def g(x):
        return jax.lax.map(f, x)

    g2 = check_roundtrip(g)(x)

    assert jnp.all(g(x) == g2(x))


def test_map_pytree():
    # map over a tuple (pytree) argument
    def f(x):
        return x[0] + 1, x[1] + 1

    x = jnp.arange(10)

    def g(x, y):
        return jax.lax.map(f, (x, y))

    g2 = check_roundtrip(g)(x, x)

    assert jnp.all(g(x, x)[0] == g2(x, x)[0])
    assert jnp.all(g(x, x)[1] == g2(x, x)[1])


def test_pseudo_sliding_window_attention():
    block_len = 64
    seq_len = 128
    batch = 4
    num_heads = 2
    embed_size = 32
    num_layers = 2
    head_size = 16
    def pseudo_sliding_window_attention(x):
        # (this is not attention, but is minimized from attn)
        # dims are [batch, len, num_heads, head_dim]
        # having num_heads is important. num_heads = 1, no boom
        def block(block_idx):
            query_block = jax.lax.dynamic_slice_in_dim(x, block_idx, block_len, axis=1)
            weights = jnp.sum(query_block, axis=3)  # [batch, len, num_heads]
            weights = jax.lax.broadcast_in_dim(weights, (batch, block_len, num_heads, block_len),
                                               (0, 1, 2))  # [batch, len, num_heads, len]
            # weights = with_sharding_constraint(weights, P('data', None, None, None))
            # without "bias", no boom
            bias = jnp.ones(block_len)
            bias = jnp.broadcast_to(bias, (batch, block_len, num_heads, block_len))
            weights = weights + bias
            return jnp.einsum('bqhk,bkhd->bqhd', weights, query_block).astype(query_block.dtype)

        num_blocks = seq_len // block_len
        blocked_attn = jax.lax.map(block, jnp.arange(0, num_blocks))  # [num_blocks, batch, len, num_heads, head_dim]
        blocked_attn = jnp.concatenate(blocked_attn, axis=1)

        return blocked_attn

    def fwd(params, x):
        # remat'd layer: exercises the remat/checkpoint lowering path
        @partial(jax.checkpoint, prevent_cse=False)
        def layer(x, params):
            qkv, o = params
            y = jnp.einsum('bte,hde->bthd', x, qkv)
            y = pseudo_sliding_window_attention(y)
            z = jnp.einsum('bthd,hde->bte', y, o)
            return z, None

        # scan over the stacked per-layer parameters
        x, _ = jax.lax.scan(layer, x, params)

        return x

    def loss_fn(params, x):
        x = fwd(params, x)
        l = jnp.mean(x)
        return l

    def grad_fn(params, x):
        loss, grad = jax.value_and_grad(loss_fn)(params, x)
        # we can't reasonably sourcerize pytrees so just get the leaves
        return loss, *jax.tree_util.tree_leaves(grad)

    qkv = jnp.ones((num_layers, num_heads, head_size, embed_size), dtype=jnp.bfloat16)
    o = jnp.ones((num_layers, num_heads, head_size, embed_size), dtype=jnp.bfloat16)
    x = jnp.ones((batch, seq_len, embed_size), dtype=jnp.bfloat16)

    params = (qkv, o)

    f2 = check_roundtrip(grad_fn)(params, x)

def test_einsum():
    def f(x, y):
        return jnp.einsum('cij,cjk->ik', x, y)

    x = jnp.arange(8).reshape(2, 2, 2)
    y = jnp.arange(8).reshape(2, 2, 2)

    check_roundtrip(f)(x, y)

    # also pin the exact generated source when jaxtyping annotations are on
    source = sourcerize(f, use_jax_typing=True)(x, y)

    assert source.strip() == \
        textwrap.dedent("""
def f(a: Int32[Array, '2 2 2'], b: Int32[Array, '2 2 2']):
    c = jnp.einsum('acb,abd->cd', a, b, preferred_element_type=jnp.int32)
    return c""".strip())


def test_while_loop():
    def f(y):
        # carry is a tuple (x, y); loop until x reaches 10
        def loop(args):
            x, y = args
            return x + y, y

        def cond(args):
            return args[0] < 10

        return jax.lax.while_loop(cond, loop, (0, y))

    y = jnp.array(1)
    check_roundtrip(f)(y)


def test_cond():
    def f(x):
        def true_fn(x):
            return x + 1

        def false_fn(x):
            return x + 2

        return jax.lax.cond(x > 0, true_fn, false_fn, x)

    x = jnp.array(1)
    check_roundtrip(f)(x)


def test_switch():
    # three branches forces the jax.lax.switch lowering (not cond)
    def f(x):
        def fn_a(x):
            return x + 1

        def fn_b(x):
            return x + 2

        def fn_c(x):
            return x + 3

        return jax.lax.switch(x, [fn_a, fn_b, fn_c], x)


    x = jnp.array(1)

    check_roundtrip(f)(x)


def test_gather():
    # fancy indexing lowers to the gather primitive
    def f(x, y):
        return x[y]

    x = jnp.arange(8).reshape(2, 2, 2)
    y = jnp.array([0, 1])

    check_roundtrip(f)(x, y)

@pytest.mark.parametrize('fn', [jnp.cumsum, jnp.cumprod])
def test_cumulative(fn):
    def f(x):
        return fn(x, axis=0)

    x = jnp.arange(8).reshape(2, 2, 2) + 3

    check_roundtrip(f)(x)

def test_concatenate():
    def f(x, y):
        return jnp.concatenate([x, y], axis=0)

    x = jnp.arange(8).reshape(2, 2, 2)
    y = jnp.arange(8).reshape(2, 2, 2)

    check_roundtrip(f)(x, y)
jax.lax.fori_loop 309 | # jax.lax.dynamic_update_slice 310 | # jax.lax.dynamic_update_index_in_dim 311 | # jax.lax.dynamic_slice 312 | # jax.lax.dynamic_index_in_dim 313 | # jax.lax.dynamic_update_index_in_dim 314 | # jax.lax.dynamic_slice_in_dim 315 | # jax.lax.dynamic_update_slice_in_dim 316 | # jax.lax.gather 317 | # jax.lax.scatter 318 | # jax.lax.scatter_add 319 | # jax.lax.scatter_mul 320 | 321 | 322 | 323 | 324 | 325 | 326 | # want to handle this (complex) case: 327 | # { lambda a:u32[2] b:f32[128,72] c:f32[16,72] d:i32[4] e:f32[4,72] f:f32[4,72] g:f32[4,72,3,8,9] 328 | # h:f32[4,3,8,9] i:f32[4,8,9,72] j:f32[4,72] k:f32[4,72] l:f32[4,72] m:f32[4,72,288] 329 | # n:f32[4,288] o:f32[4,288,72] p:f32[4,72] q:bool[16,16] r:f32[72] s:f32[72]; t:i32[16]. let 330 | # u:key[] = random_wrap[impl=fry] a 331 | # v:key[2] = random_split[count=2] u 332 | # w:u32[2,2] = random_unwrap v 333 | # x:u32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1)] w 334 | # _:u32[2] = squeeze[dimensions=(0,)] x 335 | # y:u32[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=(1, 1)] w 336 | # z:u32[2] = squeeze[dimensions=(0,)] y 337 | # ba:f32[16,72] = pjit[ 338 | # jaxpr={ lambda ; bb:f32[128,72] bc:i32[16]. let 339 | # bd:bool[16] = lt bc 0 340 | # be:i32[16] = add bc 128 341 | # bf:i32[16] = pjit[ 342 | # jaxpr={ lambda ; bg:bool[16] bh:i32[16] bi:i32[16]. 
let 343 | # bj:i32[16] = select_n bg bi bh 344 | # in (bj,) } 345 | # name=_where 346 | # ] bd be bc 347 | # bk:i32[16,1] = broadcast_in_dim[ 348 | # broadcast_dimensions=(0,) 349 | # shape=(16, 1) 350 | # ] bf 351 | # bl:f32[16,72] = gather[ 352 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) 353 | # fill_value=nan 354 | # indices_are_sorted=False 355 | # mode=GatherScatterMode.FILL_OR_DROP 356 | # slice_sizes=(1, 72) 357 | # unique_indices=False 358 | # ] bb bk 359 | # in (bl,) } 360 | # name=_take 361 | # ] b t 362 | # bm:f32[16,72] = add ba c 363 | # bn:key[] = random_wrap[impl=fry] z 364 | # bo:key[4] = random_split[count=4] bn 365 | # bp:u32[4,2] = random_unwrap bo 366 | # _:f32[72] = pjit[ 367 | # jaxpr={ lambda ; bq:f32[4,72] br:i32[]. let 368 | # bs:bool[] = lt br 0 369 | # bt:i32[] = add br 4 370 | # bu:i32[] = pjit[ 371 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. let 372 | # by:i32[] = select_n bv bx bw 373 | # in (by,) } 374 | # name=_where 375 | # ] bs bt br 376 | # bz:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bu 377 | # ca:f32[72] = gather[ 378 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0,), start_index_map=(0,)) 379 | # fill_value=nan 380 | # indices_are_sorted=False 381 | # mode=GatherScatterMode.FILL_OR_DROP 382 | # slice_sizes=(1, 72) 383 | # unique_indices=False 384 | # ] bq bz 385 | # in (ca,) } 386 | # name=_take 387 | # ] e 0 388 | # _:f32[72] = pjit[ 389 | # jaxpr={ lambda ; bq:f32[4,72] br:i32[]. let 390 | # bs:bool[] = lt br 0 391 | # bt:i32[] = add br 4 392 | # bu:i32[] = pjit[ 393 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. 
let 394 | # by:i32[] = select_n bv bx bw 395 | # in (by,) } 396 | # name=_where 397 | # ] bs bt br 398 | # bz:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bu 399 | # ca:f32[72] = gather[ 400 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0,), start_index_map=(0,)) 401 | # fill_value=nan 402 | # indices_are_sorted=False 403 | # mode=GatherScatterMode.FILL_OR_DROP 404 | # slice_sizes=(1, 72) 405 | # unique_indices=False 406 | # ] bq bz 407 | # in (ca,) } 408 | # name=_take 409 | # ] f 0 410 | # _:f32[72,3,8,9] = pjit[ 411 | # jaxpr={ lambda ; cb:f32[4,72,3,8,9] cc:i32[]. let 412 | # cd:bool[] = lt cc 0 413 | # ce:i32[] = add cc 4 414 | # cf:i32[] = pjit[ 415 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. let 416 | # by:i32[] = select_n bv bx bw 417 | # in (by,) } 418 | # name=_where 419 | # ] cd ce cc 420 | # cg:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] cf 421 | # ch:f32[72,3,8,9] = gather[ 422 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2, 3), collapsed_slice_dims=(0,), start_index_map=(0,)) 423 | # fill_value=nan 424 | # indices_are_sorted=False 425 | # mode=GatherScatterMode.FILL_OR_DROP 426 | # slice_sizes=(1, 72, 3, 8, 9) 427 | # unique_indices=False 428 | # ] cb cg 429 | # in (ch,) } 430 | # name=_take 431 | # ] g 0 432 | # _:f32[3,8,9] = pjit[ 433 | # jaxpr={ lambda ; ci:f32[4,3,8,9] cj:i32[]. let 434 | # ck:bool[] = lt cj 0 435 | # cl:i32[] = add cj 4 436 | # cm:i32[] = pjit[ 437 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. 
let 438 | # by:i32[] = select_n bv bx bw 439 | # in (by,) } 440 | # name=_where 441 | # ] ck cl cj 442 | # cn:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] cm 443 | # co:f32[3,8,9] = gather[ 444 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2), collapsed_slice_dims=(0,), start_index_map=(0,)) 445 | # fill_value=nan 446 | # indices_are_sorted=False 447 | # mode=GatherScatterMode.FILL_OR_DROP 448 | # slice_sizes=(1, 3, 8, 9) 449 | # unique_indices=False 450 | # ] ci cn 451 | # in (co,) } 452 | # name=_take 453 | # ] h 0 454 | # _:f32[8,9,72] = pjit[ 455 | # jaxpr={ lambda ; cp:f32[4,8,9,72] cq:i32[]. let 456 | # cr:bool[] = lt cq 0 457 | # cs:i32[] = add cq 4 458 | # ct:i32[] = pjit[ 459 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. let 460 | # by:i32[] = select_n bv bx bw 461 | # in (by,) } 462 | # name=_where 463 | # ] cr cs cq 464 | # cu:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ct 465 | # cv:f32[8,9,72] = gather[ 466 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2), collapsed_slice_dims=(0,), start_index_map=(0,)) 467 | # fill_value=nan 468 | # indices_are_sorted=False 469 | # mode=GatherScatterMode.FILL_OR_DROP 470 | # slice_sizes=(1, 8, 9, 72) 471 | # unique_indices=False 472 | # ] cp cu 473 | # in (cv,) } 474 | # name=_take 475 | # ] i 0 476 | # _:f32[72] = pjit[ 477 | # jaxpr={ lambda ; bq:f32[4,72] br:i32[]. let 478 | # bs:bool[] = lt br 0 479 | # bt:i32[] = add br 4 480 | # bu:i32[] = pjit[ 481 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. 
let 482 | # by:i32[] = select_n bv bx bw 483 | # in (by,) } 484 | # name=_where 485 | # ] bs bt br 486 | # bz:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bu 487 | # ca:f32[72] = gather[ 488 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0,), start_index_map=(0,)) 489 | # fill_value=nan 490 | # indices_are_sorted=False 491 | # mode=GatherScatterMode.FILL_OR_DROP 492 | # slice_sizes=(1, 72) 493 | # unique_indices=False 494 | # ] bq bz 495 | # in (ca,) } 496 | # name=_take 497 | # ] j 0 498 | # _:f32[72] = pjit[ 499 | # jaxpr={ lambda ; bq:f32[4,72] br:i32[]. let 500 | # bs:bool[] = lt br 0 501 | # bt:i32[] = add br 4 502 | # bu:i32[] = pjit[ 503 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. let 504 | # by:i32[] = select_n bv bx bw 505 | # in (by,) } 506 | # name=_where 507 | # ] bs bt br 508 | # bz:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bu 509 | # ca:f32[72] = gather[ 510 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0,), start_index_map=(0,)) 511 | # fill_value=nan 512 | # indices_are_sorted=False 513 | # mode=GatherScatterMode.FILL_OR_DROP 514 | # slice_sizes=(1, 72) 515 | # unique_indices=False 516 | # ] bq bz 517 | # in (ca,) } 518 | # name=_take 519 | # ] k 0 520 | # _:f32[72] = pjit[ 521 | # jaxpr={ lambda ; bq:f32[4,72] br:i32[]. let 522 | # bs:bool[] = lt br 0 523 | # bt:i32[] = add br 4 524 | # bu:i32[] = pjit[ 525 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. 
let 526 | # by:i32[] = select_n bv bx bw 527 | # in (by,) } 528 | # name=_where 529 | # ] bs bt br 530 | # bz:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bu 531 | # ca:f32[72] = gather[ 532 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0,), start_index_map=(0,)) 533 | # fill_value=nan 534 | # indices_are_sorted=False 535 | # mode=GatherScatterMode.FILL_OR_DROP 536 | # slice_sizes=(1, 72) 537 | # unique_indices=False 538 | # ] bq bz 539 | # in (ca,) } 540 | # name=_take 541 | # ] l 0 542 | # _:f32[72,288] = pjit[ 543 | # jaxpr={ lambda ; cw:f32[4,72,288] cx:i32[]. let 544 | # cy:bool[] = lt cx 0 545 | # cz:i32[] = add cx 4 546 | # da:i32[] = pjit[ 547 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. let 548 | # by:i32[] = select_n bv bx bw 549 | # in (by,) } 550 | # name=_where 551 | # ] cy cz cx 552 | # db:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] da 553 | # dc:f32[72,288] = gather[ 554 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(0,), start_index_map=(0,)) 555 | # fill_value=nan 556 | # indices_are_sorted=False 557 | # mode=GatherScatterMode.FILL_OR_DROP 558 | # slice_sizes=(1, 72, 288) 559 | # unique_indices=False 560 | # ] cw db 561 | # in (dc,) } 562 | # name=_take 563 | # ] m 0 564 | # _:f32[288] = pjit[ 565 | # jaxpr={ lambda ; dd:f32[4,288] de:i32[]. let 566 | # df:bool[] = lt de 0 567 | # dg:i32[] = add de 4 568 | # dh:i32[] = pjit[ 569 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. 
let 570 | # by:i32[] = select_n bv bx bw 571 | # in (by,) } 572 | # name=_where 573 | # ] df dg de 574 | # di:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] dh 575 | # dj:f32[288] = gather[ 576 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0,), start_index_map=(0,)) 577 | # fill_value=nan 578 | # indices_are_sorted=False 579 | # mode=GatherScatterMode.FILL_OR_DROP 580 | # slice_sizes=(1, 288) 581 | # unique_indices=False 582 | # ] dd di 583 | # in (dj,) } 584 | # name=_take 585 | # ] n 0 586 | # _:f32[288,72] = pjit[ 587 | # jaxpr={ lambda ; dk:f32[4,288,72] dl:i32[]. let 588 | # dm:bool[] = lt dl 0 589 | # dn:i32[] = add dl 4 590 | # do:i32[] = pjit[ 591 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. let 592 | # by:i32[] = select_n bv bx bw 593 | # in (by,) } 594 | # name=_where 595 | # ] dm dn dl 596 | # dp:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] do 597 | # dq:f32[288,72] = gather[ 598 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(0,), start_index_map=(0,)) 599 | # fill_value=nan 600 | # indices_are_sorted=False 601 | # mode=GatherScatterMode.FILL_OR_DROP 602 | # slice_sizes=(1, 288, 72) 603 | # unique_indices=False 604 | # ] dk dp 605 | # in (dq,) } 606 | # name=_take 607 | # ] o 0 608 | # _:f32[72] = pjit[ 609 | # jaxpr={ lambda ; bq:f32[4,72] br:i32[]. let 610 | # bs:bool[] = lt br 0 611 | # bt:i32[] = add br 4 612 | # bu:i32[] = pjit[ 613 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. 
let 614 | # by:i32[] = select_n bv bx bw 615 | # in (by,) } 616 | # name=_where 617 | # ] bs bt br 618 | # bz:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bu 619 | # ca:f32[72] = gather[ 620 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0,), start_index_map=(0,)) 621 | # fill_value=nan 622 | # indices_are_sorted=False 623 | # mode=GatherScatterMode.FILL_OR_DROP 624 | # slice_sizes=(1, 72) 625 | # unique_indices=False 626 | # ] bq bz 627 | # in (ca,) } 628 | # name=_take 629 | # ] p 0 630 | # _:i32[] = pjit[ 631 | # jaxpr={ lambda ; dr:i32[4] ds:i32[]. let 632 | # dt:bool[] = lt ds 0 633 | # du:i32[] = add ds 4 634 | # dv:i32[] = pjit[ 635 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. let 636 | # by:i32[] = select_n bv bx bw 637 | # in (by,) } 638 | # name=_where 639 | # ] dt du ds 640 | # dw:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] dv 641 | # dx:i32[] = gather[ 642 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) 643 | # fill_value=-2147483648 644 | # indices_are_sorted=False 645 | # mode=GatherScatterMode.FILL_OR_DROP 646 | # slice_sizes=(1,) 647 | # unique_indices=False 648 | # ] dr dw 649 | # in (dx,) } 650 | # name=_take 651 | # ] d 0 652 | # dy:f32[16,72] = scan[ 653 | # jaxpr={ lambda ; dz:bool[16,16] ea:f32[16,72] eb:f32[72] ec:f32[72] ed:f32[72,3,8,9] 654 | # ee:f32[3,8,9] ef:f32[8,9,72] eg:f32[72] eh:f32[72] ei:f32[72] ej:f32[72,288] 655 | # ek:f32[288] el:f32[288,72] em:f32[72] en:i32[] eo:u32[2]. 
let 656 | # ep:key[] = random_wrap[impl=fry] eo 657 | # eq:key[3] = random_split[count=3] ep 658 | # er:u32[3,2] = random_unwrap eq 659 | # es:u32[1,2] = slice[ 660 | # limit_indices=(1, 2) 661 | # start_indices=(0, 0) 662 | # strides=(1, 1) 663 | # ] er 664 | # _:u32[2] = squeeze[dimensions=(0,)] es 665 | # et:u32[1,2] = slice[ 666 | # limit_indices=(2, 2) 667 | # start_indices=(1, 0) 668 | # strides=(1, 1) 669 | # ] er 670 | # _:u32[2] = squeeze[dimensions=(0,)] et 671 | # eu:u32[1,2] = slice[ 672 | # limit_indices=(3, 2) 673 | # start_indices=(2, 0) 674 | # strides=(1, 1) 675 | # ] er 676 | # _:u32[2] = squeeze[dimensions=(0,)] eu 677 | # ev:f32[16] = reduce_sum[axes=(1,)] ea 678 | # ew:f32[16] = div ev 72.0 679 | # ex:f32[16] = pjit[ 680 | # jaxpr={ lambda ; ey:f32[16,72] ez:i32[]. let 681 | # fa:f32[16] = reduce_sum[axes=(1,)] ey 682 | # fb:f32[16,1] = broadcast_in_dim[ 683 | # broadcast_dimensions=(0,) 684 | # shape=(16, 1) 685 | # ] fa 686 | # fc:f32[16,1] = div fb 72.0 687 | # fd:f32[16,72] = sub ey fc 688 | # fe:f32[16,72] = integer_pow[y=2] fd 689 | # ff:f32[] = convert_element_type[ 690 | # new_dtype=float32 691 | # weak_type=False 692 | # ] ez 693 | # fg:f32[] = sub 72.0 ff 694 | # fh:f32[16] = reduce_sum[axes=(1,)] fe 695 | # fi:f32[16] = div fh fg 696 | # in (fi,) } 697 | # name=_var 698 | # ] ea 0 699 | # fj:f32[16] = add ex 9.999999747378752e-06 700 | # fk:f32[16] = rsqrt fj 701 | # fl:f32[72,16] = broadcast_in_dim[ 702 | # broadcast_dimensions=(1,) 703 | # shape=(72, 16) 704 | # ] ew 705 | # fm:f32[16,72] = transpose[permutation=(1, 0)] fl 706 | # fn:f32[16,72] = sub ea fm 707 | # fo:f32[72,16] = broadcast_in_dim[ 708 | # broadcast_dimensions=(1,) 709 | # shape=(72, 16) 710 | # ] fk 711 | # fp:f32[16,72] = transpose[permutation=(1, 0)] fo 712 | # fq:f32[16,72] = mul fn fp 713 | # fr:f32[16,72] = broadcast_in_dim[ 714 | # broadcast_dimensions=(1,) 715 | # shape=(16, 72) 716 | # ] eb 717 | # fs:f32[16,72] = mul fr fq 718 | # ft:f32[16,72] = 
broadcast_in_dim[ 719 | # broadcast_dimensions=(1,) 720 | # shape=(16, 72) 721 | # ] ec 722 | # fu:f32[16,72] = add fs ft 723 | # fv:f32[16,3,8,9] = dot_general[ 724 | # dimension_numbers=(([1], [0]), ([], [])) 725 | # ] fu ed 726 | # fw:f32[16,3,8,9] = broadcast_in_dim[ 727 | # broadcast_dimensions=(1, 2, 3) 728 | # shape=(16, 3, 8, 9) 729 | # ] ee 730 | # fx:f32[16,3,8,9] = add fv fw 731 | # fy:f32[3,8,16,9] = transpose[permutation=(1, 2, 0, 3)] fx 732 | # fz:f32[8,16,9] = pjit[ 733 | # jaxpr={ lambda ; ga:f32[3,8,16,9] gb:i32[]. let 734 | # gc:bool[] = lt gb 0 735 | # gd:i32[] = add gb 3 736 | # ge:i32[] = pjit[ 737 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. let 738 | # by:i32[] = select_n bv bx bw 739 | # in (by,) } 740 | # name=_where 741 | # ] gc gd gb 742 | # gf:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ge 743 | # gg:f32[8,16,9] = gather[ 744 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2), collapsed_slice_dims=(0,), start_index_map=(0,)) 745 | # fill_value=nan 746 | # indices_are_sorted=False 747 | # mode=GatherScatterMode.FILL_OR_DROP 748 | # slice_sizes=(1, 8, 16, 9) 749 | # unique_indices=False 750 | # ] ga gf 751 | # in (gg,) } 752 | # name=_take 753 | # ] fy 0 754 | # gh:f32[8,16,9] = pjit[ 755 | # jaxpr={ lambda ; ga:f32[3,8,16,9] gb:i32[]. let 756 | # gc:bool[] = lt gb 0 757 | # gd:i32[] = add gb 3 758 | # ge:i32[] = pjit[ 759 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. 
let 760 | # by:i32[] = select_n bv bx bw 761 | # in (by,) } 762 | # name=_where 763 | # ] gc gd gb 764 | # gf:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ge 765 | # gg:f32[8,16,9] = gather[ 766 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2), collapsed_slice_dims=(0,), start_index_map=(0,)) 767 | # fill_value=nan 768 | # indices_are_sorted=False 769 | # mode=GatherScatterMode.FILL_OR_DROP 770 | # slice_sizes=(1, 8, 16, 9) 771 | # unique_indices=False 772 | # ] ga gf 773 | # in (gg,) } 774 | # name=_take 775 | # ] fy 1 776 | # gi:f32[8,16,9] = pjit[ 777 | # jaxpr={ lambda ; ga:f32[3,8,16,9] gb:i32[]. let 778 | # gc:bool[] = lt gb 0 779 | # gd:i32[] = add gb 3 780 | # ge:i32[] = pjit[ 781 | # jaxpr={ lambda ; bv:bool[] bw:i32[] bx:i32[]. let 782 | # by:i32[] = select_n bv bx bw 783 | # in (by,) } 784 | # name=_where 785 | # ] gc gd gb 786 | # gf:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ge 787 | # gg:f32[8,16,9] = gather[ 788 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2), collapsed_slice_dims=(0,), start_index_map=(0,)) 789 | # fill_value=nan 790 | # indices_are_sorted=False 791 | # mode=GatherScatterMode.FILL_OR_DROP 792 | # slice_sizes=(1, 8, 16, 9) 793 | # unique_indices=False 794 | # ] ga gf 795 | # in (gg,) } 796 | # name=_take 797 | # ] fy 2 798 | # gj:f32[] = rsqrt 9.0 799 | # gk:f32[] = convert_element_type[new_dtype=float32 weak_type=False] gj 800 | # gl:f32[8,16,9] = mul fz gk 801 | # gm:f32[8,16,16] = dot_general[ 802 | # dimension_numbers=(([2], [2]), ([0], [0])) 803 | # ] gl gh 804 | # gn:f32[16,16] = convert_element_type[new_dtype=float32 weak_type=True] dz 805 | # go:f32[16,16] = sub 1.0 gn 806 | # gp:f32[16,16] = mul go -1000000000.0 807 | # gq:f32[8,16,16] = broadcast_in_dim[ 808 | # broadcast_dimensions=(1, 2) 809 | # shape=(8, 16, 16) 810 | # ] gp 811 | # gr:f32[8,16,16] = convert_element_type[ 812 | # new_dtype=float32 813 | # weak_type=False 814 | # ] gq 815 | # 
gs:f32[8,16,16] = add gm gr 816 | # gt:f32[8,16] = reduce_max[axes=(2,)] gs 817 | # gu:f32[8,16,1] = broadcast_in_dim[ 818 | # broadcast_dimensions=(0, 1) 819 | # shape=(8, 16, 1) 820 | # ] gt 821 | # gv:f32[8,16,1] = stop_gradient gu 822 | # gw:f32[8,16,16] = sub gs gv 823 | # gx:f32[8,16,16] = exp gw 824 | # gy:f32[8,16] = reduce_sum[axes=(2,)] gx 825 | # gz:f32[8,16,1] = broadcast_in_dim[ 826 | # broadcast_dimensions=(0, 1) 827 | # shape=(8, 16, 1) 828 | # ] gy 829 | # ha:f32[8,16,16] = div gx gz 830 | # hb:f32[8,16,9] = dot_general[ 831 | # dimension_numbers=(([2], [1]), ([0], [0])) 832 | # ] ha gi 833 | # hc:f32[16,72] = dot_general[ 834 | # dimension_numbers=(([0, 2], [0, 1]), ([], [])) 835 | # ] hb ef 836 | # hd:f32[16,72] = broadcast_in_dim[ 837 | # broadcast_dimensions=(1,) 838 | # shape=(16, 72) 839 | # ] eg 840 | # he:f32[16,72] = add hc hd 841 | # hf:f32[16,72] = add ea he 842 | # hg:f32[16] = reduce_sum[axes=(1,)] hf 843 | # hh:f32[16] = div hg 72.0 844 | # hi:f32[16] = pjit[ 845 | # jaxpr={ lambda ; ey:f32[16,72] ez:i32[]. 
let 846 | # fa:f32[16] = reduce_sum[axes=(1,)] ey 847 | # fb:f32[16,1] = broadcast_in_dim[ 848 | # broadcast_dimensions=(0,) 849 | # shape=(16, 1) 850 | # ] fa 851 | # fc:f32[16,1] = div fb 72.0 852 | # fd:f32[16,72] = sub ey fc 853 | # fe:f32[16,72] = integer_pow[y=2] fd 854 | # ff:f32[] = convert_element_type[ 855 | # new_dtype=float32 856 | # weak_type=False 857 | # ] ez 858 | # fg:f32[] = sub 72.0 ff 859 | # fh:f32[16] = reduce_sum[axes=(1,)] fe 860 | # fi:f32[16] = div fh fg 861 | # in (fi,) } 862 | # name=_var 863 | # ] hf 0 864 | # hj:f32[16] = add hi 9.999999747378752e-06 865 | # hk:f32[16] = rsqrt hj 866 | # hl:f32[72,16] = broadcast_in_dim[ 867 | # broadcast_dimensions=(1,) 868 | # shape=(72, 16) 869 | # ] hh 870 | # hm:f32[16,72] = transpose[permutation=(1, 0)] hl 871 | # hn:f32[16,72] = sub hf hm 872 | # ho:f32[72,16] = broadcast_in_dim[ 873 | # broadcast_dimensions=(1,) 874 | # shape=(72, 16) 875 | # ] hk 876 | # hp:f32[16,72] = transpose[permutation=(1, 0)] ho 877 | # hq:f32[16,72] = mul hn hp 878 | # hr:f32[16,72] = broadcast_in_dim[ 879 | # broadcast_dimensions=(1,) 880 | # shape=(16, 72) 881 | # ] eh 882 | # hs:f32[16,72] = mul hr hq 883 | # ht:f32[16,72] = broadcast_in_dim[ 884 | # broadcast_dimensions=(1,) 885 | # shape=(16, 72) 886 | # ] ei 887 | # hu:f32[16,72] = add hs ht 888 | # hv:f32[16,288] = dot_general[dimension_numbers=(([1], [0]), ([], []))] hu 889 | # ej 890 | # hw:f32[16,288] = broadcast_in_dim[ 891 | # broadcast_dimensions=(1,) 892 | # shape=(16, 288) 893 | # ] ek 894 | # hx:f32[16,288] = add hv hw 895 | # hy:f32[16,288] = integer_pow[y=3] hx 896 | # hz:f32[16,288] = mul 0.044714998453855515 hy 897 | # ia:f32[16,288] = add hx hz 898 | # ib:f32[16,288] = mul 0.7978845834732056 ia 899 | # ic:f32[16,288] = tanh ib 900 | # id:f32[16,288] = add 1.0 ic 901 | # ie:f32[16,288] = mul 0.5 id 902 | # if:f32[16,288] = mul hx ie 903 | # ig:f32[16,72] = dot_general[dimension_numbers=(([1], [0]), ([], []))] if 904 | # el 905 | # ih:f32[16,72] = 
broadcast_in_dim[ 906 | # broadcast_dimensions=(1,) 907 | # shape=(16, 72) 908 | # ] em 909 | # ii:f32[16,72] = add ig ih 910 | # ij:f32[16,72] = add hf ii 911 | # in (ij,) } 912 | # length=4 913 | # linear=(False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False) 914 | # num_carry=1 915 | # num_consts=1 916 | # reverse=False 917 | # unroll=1 918 | # ] q bm e f g h i j k l m n o p d bp 919 | # ik:f32[16] = reduce_sum[axes=(1,)] dy 920 | # il:f32[16] = div ik 72.0 921 | # im:f32[16] = pjit[ 922 | # jaxpr={ lambda ; ey:f32[16,72] ez:i32[]. let 923 | # fa:f32[16] = reduce_sum[axes=(1,)] ey 924 | # fb:f32[16,1] = broadcast_in_dim[ 925 | # broadcast_dimensions=(0,) 926 | # shape=(16, 1) 927 | # ] fa 928 | # fc:f32[16,1] = div fb 72.0 929 | # fd:f32[16,72] = sub ey fc 930 | # fe:f32[16,72] = integer_pow[y=2] fd 931 | # ff:f32[] = convert_element_type[new_dtype=float32 weak_type=False] ez 932 | # fg:f32[] = sub 72.0 ff 933 | # fh:f32[16] = reduce_sum[axes=(1,)] fe 934 | # fi:f32[16] = div fh fg 935 | # in (fi,) } 936 | # name=_var 937 | # ] dy 0 938 | # in:f32[16] = add im 9.999999747378752e-06 939 | # io:f32[16] = rsqrt in 940 | # ip:f32[72,16] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(72, 16)] il 941 | # iq:f32[16,72] = transpose[permutation=(1, 0)] ip 942 | # ir:f32[16,72] = sub dy iq 943 | # is:f32[72,16] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(72, 16)] io 944 | # it:f32[16,72] = transpose[permutation=(1, 0)] is 945 | # iu:f32[16,72] = mul ir it 946 | # iv:f32[16,72] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(16, 72)] r 947 | # iw:f32[16,72] = mul iv iu 948 | # ix:f32[16,72] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(16, 72)] s 949 | # iy:f32[16,72] = add iw ix 950 | # iz:f32[16,128] = dot_general[dimension_numbers=(([1], [1]), ([], []))] iy b 951 | # in (iz,) } 952 | 953 | ### Jaxpr for gpt2 train_step 954 | # let _take = { lambda ; a:f32[2,32] b:i32[]. 
let 955 | # c:bool[] = lt b 0 956 | # d:i32[] = add b 2 957 | # e:i32[] = pjit[ 958 | # name=_where 959 | # jaxpr={ lambda ; f:bool[] g:i32[] h:i32[]. let 960 | # i:i32[] = select_n f h g 961 | # in (i,) } 962 | # ] c d b 963 | # j:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e 964 | # k:f32[32] = gather[ 965 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0,), start_index_map=(0,)) 966 | # fill_value=nan 967 | # indices_are_sorted=False 968 | # mode=GatherScatterMode.FILL_OR_DROP 969 | # slice_sizes=(1, 32) 970 | # unique_indices=False 971 | # ] a j 972 | # in (k,) } in 973 | # let _var = { lambda ; l:f32[32,512,32] m:i32[]. let 974 | # n:f32[32,512] = reduce_sum[axes=(2,)] l 975 | # o:f32[32,512,1] = broadcast_in_dim[ 976 | # broadcast_dimensions=(0, 1) 977 | # shape=(32, 512, 1) 978 | # ] n 979 | # p:f32[32,512,1] = div o 32.0 980 | # q:f32[32,512,32] = sub l p 981 | # r:f32[32,512,32] = integer_pow[y=2] q 982 | # s:f32[] = convert_element_type[new_dtype=float32 weak_type=False] m 983 | # t:f32[] = sub 32.0 s 984 | # u:f32[32,512] = reduce_sum[axes=(2,)] r 985 | # v:f32[32,512] = div u t 986 | # in (v,) } in 987 | # let _where = { lambda ; f:bool[] g:i32[] h:i32[]. let 988 | # i:i32[] = select_n f h g 989 | # in (i,) } in 990 | # let _take1 = { lambda ; w:f32[32,3,4,512,8] x:i32[]. 
let 991 | # y:bool[] = lt x 0 992 | # z:i32[] = add x 3 993 | # ba:i32[] = pjit[name=_where jaxpr=_where] y z x 994 | # bb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ba 995 | # bc:f32[32,4,512,8] = gather[ 996 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2, 3), collapsed_slice_dims=(1,), start_index_map=(1,)) 997 | # fill_value=nan 998 | # indices_are_sorted=False 999 | # mode=GatherScatterMode.FILL_OR_DROP 1000 | # slice_sizes=(32, 1, 4, 512, 8) 1001 | # unique_indices=False 1002 | # ] w bb 1003 | # in (bc,) } in 1004 | # { lambda ; bd:f32[2,32] be:f32[2,32] bf:f32[2,32,3,4,8] bg:f32[2,3,4,8] bh:f32[2,4,8,32] 1005 | # bi:f32[2,32] bj:f32[2,32] bk:f32[2,32] bl:f32[2,32,128] bm:f32[2,128] bn:f32[2,128,32] 1006 | # bo:f32[2,32] bp:f32[32] bq:f32[32] br:f32[50257,32] bs:f32[512,32] bt:i32[32,512] 1007 | # bu:f32[32,512]. let 1008 | # bv:f32[32,512,32] = pjit[ 1009 | # name=_take 1010 | # jaxpr={ lambda ; bw:f32[50257,32] bx:i32[32,512]. let 1011 | # by:bool[32,512] = lt bx 0 1012 | # bz:i32[32,512] = add bx 50257 1013 | # ca:i32[32,512] = pjit[ 1014 | # name=_where 1015 | # jaxpr={ lambda ; cb:bool[32,512] cc:i32[32,512] cd:i32[32,512]. 
let 1016 | # ce:i32[32,512] = select_n cb cd cc 1017 | # in (ce,) } 1018 | # ] by bz bx 1019 | # cf:i32[32,512,1] = broadcast_in_dim[ 1020 | # broadcast_dimensions=(0, 1) 1021 | # shape=(32, 512, 1) 1022 | # ] ca 1023 | # cg:f32[32,512,32] = gather[ 1024 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(2,), collapsed_slice_dims=(0,), start_index_map=(0,)) 1025 | # fill_value=nan 1026 | # indices_are_sorted=False 1027 | # mode=GatherScatterMode.FILL_OR_DROP 1028 | # slice_sizes=(1, 32) 1029 | # unique_indices=False 1030 | # ] bw cf 1031 | # in (cg,) } 1032 | # ] br bt 1033 | # ch:i32[512] = iota[dimension=0 dtype=int32 shape=(512,)] 1034 | # ci:i32[512] = mul ch 1 1035 | # cj:i32[512] = add ci 0 1036 | # ck:f32[512,32] = pjit[ 1037 | # name=_take 1038 | # jaxpr={ lambda ; cl:f32[512,32] cm:i32[512]. let 1039 | # cn:bool[512] = lt cm 0 1040 | # co:i32[512] = add cm 512 1041 | # cp:i32[512] = pjit[ 1042 | # name=_where 1043 | # jaxpr={ lambda ; cq:bool[512] cr:i32[512] cs:i32[512]. 
let 1044 | # ct:i32[512] = select_n cq cs cr 1045 | # in (ct,) } 1046 | # ] cn co cm 1047 | # cu:i32[512,1] = broadcast_in_dim[ 1048 | # broadcast_dimensions=(0,) 1049 | # shape=(512, 1) 1050 | # ] cp 1051 | # cv:f32[512,32] = gather[ 1052 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) 1053 | # fill_value=nan 1054 | # indices_are_sorted=False 1055 | # mode=GatherScatterMode.FILL_OR_DROP 1056 | # slice_sizes=(1, 32) 1057 | # unique_indices=False 1058 | # ] cl cu 1059 | # in (cv,) } 1060 | # ] bs cj 1061 | # cw:f32[32,512,32] = broadcast_in_dim[ 1062 | # broadcast_dimensions=(1, 2) 1063 | # shape=(32, 512, 32) 1064 | # ] ck 1065 | # cx:f32[32,512,32] = add bv cw 1066 | # cy:i32[2] = iota[dimension=0 dtype=int32 shape=(2,)] 1067 | # cz:i32[2] = mul cy 1 1068 | # da:i32[2] = add cz 0 1069 | # _:f32[32] = pjit[name=_take jaxpr=_take] bd 0 1070 | # _:f32[32] = pjit[name=_take jaxpr=_take] be 0 1071 | # _:f32[32,3,4,8] = pjit[ 1072 | # name=_take 1073 | # jaxpr={ lambda ; db:f32[2,32,3,4,8] dc:i32[]. let 1074 | # dd:bool[] = lt dc 0 1075 | # de:i32[] = add dc 2 1076 | # df:i32[] = pjit[name=_where jaxpr=_where] dd de dc 1077 | # dg:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] df 1078 | # dh:f32[32,3,4,8] = gather[ 1079 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2, 3), collapsed_slice_dims=(0,), start_index_map=(0,)) 1080 | # fill_value=nan 1081 | # indices_are_sorted=False 1082 | # mode=GatherScatterMode.FILL_OR_DROP 1083 | # slice_sizes=(1, 32, 3, 4, 8) 1084 | # unique_indices=False 1085 | # ] db dg 1086 | # in (dh,) } 1087 | # ] bf 0 1088 | # _:f32[3,4,8] = pjit[ 1089 | # name=_take 1090 | # jaxpr={ lambda ; di:f32[2,3,4,8] dj:i32[]. 
let 1091 | # dk:bool[] = lt dj 0 1092 | # dl:i32[] = add dj 2 1093 | # dm:i32[] = pjit[name=_where jaxpr=_where] dk dl dj 1094 | # dn:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] dm 1095 | # do:f32[3,4,8] = gather[ 1096 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2), collapsed_slice_dims=(0,), start_index_map=(0,)) 1097 | # fill_value=nan 1098 | # indices_are_sorted=False 1099 | # mode=GatherScatterMode.FILL_OR_DROP 1100 | # slice_sizes=(1, 3, 4, 8) 1101 | # unique_indices=False 1102 | # ] di dn 1103 | # in (do,) } 1104 | # ] bg 0 1105 | # _:f32[4,8,32] = pjit[ 1106 | # name=_take 1107 | # jaxpr={ lambda ; dp:f32[2,4,8,32] dq:i32[]. let 1108 | # dr:bool[] = lt dq 0 1109 | # ds:i32[] = add dq 2 1110 | # dt:i32[] = pjit[name=_where jaxpr=_where] dr ds dq 1111 | # du:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] dt 1112 | # dv:f32[4,8,32] = gather[ 1113 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2), collapsed_slice_dims=(0,), start_index_map=(0,)) 1114 | # fill_value=nan 1115 | # indices_are_sorted=False 1116 | # mode=GatherScatterMode.FILL_OR_DROP 1117 | # slice_sizes=(1, 4, 8, 32) 1118 | # unique_indices=False 1119 | # ] dp du 1120 | # in (dv,) } 1121 | # ] bh 0 1122 | # _:f32[32] = pjit[name=_take jaxpr=_take] bi 0 1123 | # _:f32[32] = pjit[name=_take jaxpr=_take] bj 0 1124 | # _:f32[32] = pjit[name=_take jaxpr=_take] bk 0 1125 | # _:f32[32,128] = pjit[ 1126 | # name=_take 1127 | # jaxpr={ lambda ; dw:f32[2,32,128] dx:i32[]. 
let 1128 | # dy:bool[] = lt dx 0 1129 | # dz:i32[] = add dx 2 1130 | # ea:i32[] = pjit[name=_where jaxpr=_where] dy dz dx 1131 | # eb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ea 1132 | # ec:f32[32,128] = gather[ 1133 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(0,), start_index_map=(0,)) 1134 | # fill_value=nan 1135 | # indices_are_sorted=False 1136 | # mode=GatherScatterMode.FILL_OR_DROP 1137 | # slice_sizes=(1, 32, 128) 1138 | # unique_indices=False 1139 | # ] dw eb 1140 | # in (ec,) } 1141 | # ] bl 0 1142 | # _:f32[128] = pjit[ 1143 | # name=_take 1144 | # jaxpr={ lambda ; ed:f32[2,128] ee:i32[]. let 1145 | # ef:bool[] = lt ee 0 1146 | # eg:i32[] = add ee 2 1147 | # eh:i32[] = pjit[name=_where jaxpr=_where] ef eg ee 1148 | # ei:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] eh 1149 | # ej:f32[128] = gather[ 1150 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0,), start_index_map=(0,)) 1151 | # fill_value=nan 1152 | # indices_are_sorted=False 1153 | # mode=GatherScatterMode.FILL_OR_DROP 1154 | # slice_sizes=(1, 128) 1155 | # unique_indices=False 1156 | # ] ed ei 1157 | # in (ej,) } 1158 | # ] bm 0 1159 | # _:f32[128,32] = pjit[ 1160 | # name=_take 1161 | # jaxpr={ lambda ; ek:f32[2,128,32] el:i32[]. 
let 1162 | # em:bool[] = lt el 0 1163 | # en:i32[] = add el 2 1164 | # eo:i32[] = pjit[name=_where jaxpr=_where] em en el 1165 | # ep:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] eo 1166 | # eq:f32[128,32] = gather[ 1167 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(0,), start_index_map=(0,)) 1168 | # fill_value=nan 1169 | # indices_are_sorted=False 1170 | # mode=GatherScatterMode.FILL_OR_DROP 1171 | # slice_sizes=(1, 128, 32) 1172 | # unique_indices=False 1173 | # ] ek ep 1174 | # in (eq,) } 1175 | # ] bn 0 1176 | # _:f32[32] = pjit[name=_take jaxpr=_take] bo 0 1177 | # _:i32[] = pjit[ 1178 | # name=_take 1179 | # jaxpr={ lambda ; er:i32[2] es:i32[]. let 1180 | # et:bool[] = lt es 0 1181 | # eu:i32[] = add es 2 1182 | # ev:i32[] = pjit[name=_where jaxpr=_where] et eu es 1183 | # ew:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ev 1184 | # ex:i32[] = gather[ 1185 | # dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) 1186 | # fill_value=-2147483648 1187 | # indices_are_sorted=False 1188 | # mode=GatherScatterMode.FILL_OR_DROP 1189 | # slice_sizes=(1,) 1190 | # unique_indices=False 1191 | # ] er ew 1192 | # in (ex,) } 1193 | # ] da 0 1194 | # ey:f32[32,512,32] = scan[ 1195 | # _split_transpose=False 1196 | # jaxpr={ lambda ; ez:f32[32,512,32] fa:f32[32] fb:f32[32] fc:f32[32,3,4,8] fd:f32[3,4,8] 1197 | # fe:f32[4,8,32] ff:f32[32] fg:f32[32] fh:f32[32] fi:f32[32,128] fj:f32[128] 1198 | # fk:f32[128,32] fl:f32[32] fm:i32[]. let 1199 | # fn:f32[32,512,32] = remat2[ 1200 | # differentiated=False 1201 | # jaxpr={ lambda ; fo:f32[32,512,32] fp:f32[32] fq:f32[32] fr:f32[32,3,4,8] 1202 | # fs:f32[3,4,8] ft:f32[4,8,32] fu:f32[32] fv:f32[32] fw:f32[32] fx:f32[32,128] 1203 | # fy:f32[128] fz:f32[128,32] ga:f32[32] gb:i32[]. 
let 1204 | # gc:f32[32,512] = reduce_sum[axes=(2,)] fo 1205 | # gd:f32[32,512] = div gc 32.0 1206 | # ge:f32[32,512] = pjit[name=_var jaxpr=_var] fo 0 1207 | # gf:f32[32,512] = add ge 9.999999747378752e-06 1208 | # gg:f32[32,512] = rsqrt gf 1209 | # gh:f32[32,32,512] = broadcast_in_dim[ 1210 | # broadcast_dimensions=(1, 2) 1211 | # shape=(32, 32, 512) 1212 | # ] gd 1213 | # gi:f32[32,512,32] = transpose[permutation=(1, 2, 0)] gh 1214 | # gj:f32[32,512,32] = sub fo gi 1215 | # gk:f32[32,32,512] = broadcast_in_dim[ 1216 | # broadcast_dimensions=(1, 2) 1217 | # shape=(32, 32, 512) 1218 | # ] gg 1219 | # gl:f32[32,512,32] = transpose[permutation=(1, 2, 0)] gk 1220 | # gm:f32[32,512,32] = mul gj gl 1221 | # gn:f32[32,512,32] = broadcast_in_dim[ 1222 | # broadcast_dimensions=(2,) 1223 | # shape=(32, 512, 32) 1224 | # ] fp 1225 | # go:f32[32,512,32] = mul gn gm 1226 | # gp:f32[32,512,32] = broadcast_in_dim[ 1227 | # broadcast_dimensions=(2,) 1228 | # shape=(32, 512, 32) 1229 | # ] fq 1230 | # gq:f32[32,512,32] = add go gp 1231 | # gr:f32[32,512,3,4,8] = dot_general[ 1232 | # dimension_numbers=(([2], [0]), ([], [])) 1233 | # preferred_element_type=float32 1234 | # ] gq fr 1235 | # gs:f32[32,512,3,4,8] = broadcast_in_dim[ 1236 | # broadcast_dimensions=(2, 3, 4) 1237 | # shape=(32, 512, 3, 4, 8) 1238 | # ] fs 1239 | # gt:f32[32,512,3,4,8] = add gr gs 1240 | # gu:f32[32,3,4,512,8] = transpose[permutation=(0, 2, 3, 1, 4)] gt 1241 | # gv:f32[32,4,512,8] = pjit[name=_take jaxpr=_take1] gu 0 1242 | # gw:f32[32,4,512,8] = pjit[name=_take jaxpr=_take1] gu 1 1243 | # gx:f32[32,4,512,8] = pjit[name=_take jaxpr=_take1] gu 2 1244 | # gy:i32[512] = iota[dimension=0 dtype=int32 shape=(512,)] 1245 | # gz:i32[512] = mul gy 1 1246 | # ha:i32[512] = add gz 0 1247 | # hb:i32[512] = iota[dimension=0 dtype=int32 shape=(512,)] 1248 | # hc:i32[512] = mul hb 1 1249 | # hd:i32[512] = add hc 0 1250 | # he:i32[512,512] = broadcast_in_dim[ 1251 | # broadcast_dimensions=(1,) 1252 | # shape=(512, 512) 
1253 | # ] hd 1254 | # hf:i32[512,512] = broadcast_in_dim[ 1255 | # broadcast_dimensions=(1,) 1256 | # shape=(512, 512) 1257 | # ] ha 1258 | # hg:i32[512,512] = transpose[permutation=(1, 0)] hf 1259 | # hh:bool[512,512] = ge hg he 1260 | # hi:f32[] = sqrt 8.0 1261 | # hj:f32[] = convert_element_type[ 1262 | # new_dtype=float32 1263 | # weak_type=False 1264 | # ] hi 1265 | # hk:f32[32,4,512,8] = div gv hj 1266 | # hl:f32[32,4,512,512] = dot_general[ 1267 | # dimension_numbers=(([3], [3]), ([0, 1], [0, 1])) 1268 | # preferred_element_type=float32 1269 | # ] hk gw 1270 | # hm:bool[32,4,512,512] = broadcast_in_dim[ 1271 | # broadcast_dimensions=(2, 3) 1272 | # shape=(32, 4, 512, 512) 1273 | # ] hh 1274 | # hn:f32[32,4,512,512] = pjit[ 1275 | # name=_where 1276 | # jaxpr={ lambda ; ho:bool[32,4,512,512] hp:f32[32,4,512,512] hq:f32[]. let 1277 | # hr:f32[] = convert_element_type[ 1278 | # new_dtype=float32 1279 | # weak_type=False 1280 | # ] hq 1281 | # hs:f32[32,4,512,512] = broadcast_in_dim[ 1282 | # broadcast_dimensions=() 1283 | # shape=(32, 4, 512, 512) 1284 | # ] hr 1285 | # ht:f32[32,4,512,512] = select_n ho hs hp 1286 | # in (ht,) } 1287 | # ] hm hl -1000000000.0 1288 | # hu:f32[32,4,512,512] = custom_jvp_call[ 1289 | # call_jaxpr={ lambda ; hv:f32[32,4,512,512]. 
let 1290 | # hw:f32[32,4,512] = reduce_max[axes=(3,)] hv 1291 | # hx:f32[32,4,512,1] = broadcast_in_dim[ 1292 | # broadcast_dimensions=(0, 1, 2) 1293 | # shape=(32, 4, 512, 1) 1294 | # ] hw 1295 | # hy:f32[32,4,512,512] = sub hv hx 1296 | # hz:f32[32,4,512,512] = exp hy 1297 | # ia:f32[32,4,512] = reduce_sum[axes=(3,)] hz 1298 | # ib:f32[32,4,512,1] = broadcast_in_dim[ 1299 | # broadcast_dimensions=(0, 1, 2) 1300 | # shape=(32, 4, 512, 1) 1301 | # ] ia 1302 | # ic:f32[32,4,512,512] = div hz ib 1303 | # in (ic,) } 1304 | # jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x348214b80> 1305 | # num_consts=0 1306 | # symbolic_zeros=False 1307 | # ] hn 1308 | # id:f32[32,4,512,8] = dot_general[ 1309 | # dimension_numbers=(([3], [2]), ([0, 1], [0, 1])) 1310 | # preferred_element_type=float32 1311 | # ] hu gx 1312 | # ie:f32[32,512,32] = dot_general[ 1313 | # dimension_numbers=(([1, 3], [0, 1]), ([], [])) 1314 | # preferred_element_type=float32 1315 | # ] id ft 1316 | # if:f32[32,512,32] = broadcast_in_dim[ 1317 | # broadcast_dimensions=(2,) 1318 | # shape=(32, 512, 32) 1319 | # ] fu 1320 | # ig:f32[32,512,32] = add ie if 1321 | # ih:f32[32,512,32] = add fo ig 1322 | # ii:f32[32,512] = reduce_sum[axes=(2,)] ih 1323 | # ij:f32[32,512] = div ii 32.0 1324 | # ik:f32[32,512] = pjit[name=_var jaxpr=_var] ih 0 1325 | # il:f32[32,512] = add ik 9.999999747378752e-06 1326 | # im:f32[32,512] = rsqrt il 1327 | # in:f32[32,32,512] = broadcast_in_dim[ 1328 | # broadcast_dimensions=(1, 2) 1329 | # shape=(32, 32, 512) 1330 | # ] ij 1331 | # io:f32[32,512,32] = transpose[permutation=(1, 2, 0)] in 1332 | # ip:f32[32,512,32] = sub ih io 1333 | # iq:f32[32,32,512] = broadcast_in_dim[ 1334 | # broadcast_dimensions=(1, 2) 1335 | # shape=(32, 32, 512) 1336 | # ] im 1337 | # ir:f32[32,512,32] = transpose[permutation=(1, 2, 0)] iq 1338 | # is:f32[32,512,32] = mul ip ir 1339 | # it:f32[32,512,32] = broadcast_in_dim[ 1340 | # broadcast_dimensions=(2,) 1341 | # shape=(32, 512, 32) 1342 | # ] fv 1343 | #
iu:f32[32,512,32] = mul it is 1344 | # iv:f32[32,512,32] = broadcast_in_dim[ 1345 | # broadcast_dimensions=(2,) 1346 | # shape=(32, 512, 32) 1347 | # ] fw 1348 | # iw:f32[32,512,32] = add iu iv 1349 | # ix:f32[32,512,128] = dot_general[ 1350 | # dimension_numbers=(([2], [0]), ([], [])) 1351 | # preferred_element_type=float32 1352 | # ] iw fx 1353 | # iy:f32[32,512,128] = broadcast_in_dim[ 1354 | # broadcast_dimensions=(2,) 1355 | # shape=(32, 512, 128) 1356 | # ] fy 1357 | # iz:f32[32,512,128] = add ix iy 1358 | # ja:f32[32,512,128] = integer_pow[y=3] iz 1359 | # jb:f32[32,512,128] = mul 0.044714998453855515 ja 1360 | # jc:f32[32,512,128] = add iz jb 1361 | # jd:f32[32,512,128] = mul 0.7978845834732056 jc 1362 | # je:f32[32,512,128] = tanh jd 1363 | # jf:f32[32,512,128] = add 1.0 je 1364 | # jg:f32[32,512,128] = mul 0.5 jf 1365 | # jh:f32[32,512,128] = mul iz jg 1366 | # ji:f32[32,512,32] = dot_general[ 1367 | # dimension_numbers=(([2], [0]), ([], [])) 1368 | # preferred_element_type=float32 1369 | # ] jh fz 1370 | # jj:f32[32,512,32] = broadcast_in_dim[ 1371 | # broadcast_dimensions=(2,) 1372 | # shape=(32, 512, 32) 1373 | # ] ga 1374 | # jk:f32[32,512,32] = add ji jj 1375 | # jl:f32[32,512,32] = add ih jk 1376 | # in (jl,) } 1377 | # policy=None 1378 | # prevent_cse=False 1379 | # ] ez fa fb fc fd fe ff fg fh fi fj fk fl fm 1380 | # in (fn,) } 1381 | # length=2 1382 | # linear=(False, False, False, False, False, False, False, False, False, False, False, False, False, False) 1383 | # num_carry=1 1384 | # num_consts=0 1385 | # reverse=False 1386 | # unroll=1 1387 | # ] cx bd be bf bg bh bi bj bk bl bm bn bo da 1388 | # jm:f32[32,512] = reduce_sum[axes=(2,)] ey 1389 | # jn:f32[32,512] = div jm 32.0 1390 | # jo:f32[32,512] = pjit[name=_var jaxpr=_var] ey 0 1391 | # jp:f32[32,512] = add jo 9.999999747378752e-06 1392 | # jq:f32[32,512] = rsqrt jp 1393 | # jr:f32[32,32,512] = broadcast_in_dim[ 1394 | # broadcast_dimensions=(1, 2) 1395 | # shape=(32, 32, 512) 1396 | # ] 
jn 1397 | # js:f32[32,512,32] = transpose[permutation=(1, 2, 0)] jr 1398 | # jt:f32[32,512,32] = sub ey js 1399 | # ju:f32[32,32,512] = broadcast_in_dim[ 1400 | # broadcast_dimensions=(1, 2) 1401 | # shape=(32, 32, 512) 1402 | # ] jq 1403 | # jv:f32[32,512,32] = transpose[permutation=(1, 2, 0)] ju 1404 | # jw:f32[32,512,32] = mul jt jv 1405 | # jx:f32[32,512,32] = broadcast_in_dim[ 1406 | # broadcast_dimensions=(2,) 1407 | # shape=(32, 512, 32) 1408 | # ] bp 1409 | # jy:f32[32,512,32] = mul jx jw 1410 | # jz:f32[32,512,32] = broadcast_in_dim[ 1411 | # broadcast_dimensions=(2,) 1412 | # shape=(32, 512, 32) 1413 | # ] bq 1414 | # ka:f32[32,512,32] = add jy jz 1415 | # kb:f32[32,512,50257] = dot_general[ 1416 | # dimension_numbers=(([2], [1]), ([], [])) 1417 | # preferred_element_type=float32 1418 | # ] ka br 1419 | # kc:i32[32,512] = pjit[ 1420 | # name=_roll_static 1421 | # jaxpr={ lambda ; kd:i32[32,512]. let 1422 | # ke:i32[32,511] = slice[ 1423 | # limit_indices=(32, 512) 1424 | # start_indices=(0, 1) 1425 | # strides=(1, 1) 1426 | # ] kd 1427 | # kf:i32[32,1] = slice[ 1428 | # limit_indices=(32, 1) 1429 | # start_indices=(0, 0) 1430 | # strides=(1, 1) 1431 | # ] kd 1432 | # kg:i32[32,512] = concatenate[dimension=1] ke kf 1433 | # in (kg,) } 1434 | # ] bt 1435 | # kh:f32[32,512,50257] = pjit[ 1436 | # name=_one_hot 1437 | # jaxpr={ lambda ; ki:i32[32,512]. 
let 1438 | # kj:i32[32,512,1] = broadcast_in_dim[ 1439 | # broadcast_dimensions=(0, 1) 1440 | # shape=(32, 512, 1) 1441 | # ] ki 1442 | # kk:i32[1,1,50257] = iota[dimension=2 dtype=int32 shape=(1, 1, 50257)] 1443 | # kl:bool[32,512,50257] = eq kj kk 1444 | # km:f32[32,512,50257] = convert_element_type[ 1445 | # new_dtype=float32 1446 | # weak_type=False 1447 | # ] kl 1448 | # in (km,) } 1449 | # ] kc 1450 | # kn:f32[32,512] = reduce_max[axes=(2,)] kb 1451 | # ko:bool[32,512] = is_finite kn 1452 | # kp:f32[32,512] = broadcast_in_dim[broadcast_dimensions=() shape=(32, 512)] 0.0 1453 | # kq:f32[32,512] = select_n ko kp kn 1454 | # kr:f32[32,512] = stop_gradient kq 1455 | # ks:f32[32,512,1] = broadcast_in_dim[ 1456 | # broadcast_dimensions=(0, 1) 1457 | # shape=(32, 512, 1) 1458 | # ] kr 1459 | # kt:f32[32,512,50257] = sub kb ks 1460 | # ku:f32[32,512,50257] = exp kt 1461 | # kv:f32[32,512] = reduce_sum[axes=(2,)] ku 1462 | # _:f32[32,512] = sign kv 1463 | # kw:f32[32,512] = abs kv 1464 | # kx:f32[32,512] = log kw 1465 | # ky:f32[32,512] = add kx kr 1466 | # kz:f32[50257,32,512] = broadcast_in_dim[ 1467 | # broadcast_dimensions=(1, 2) 1468 | # shape=(50257, 32, 512) 1469 | # ] ky 1470 | # la:f32[32,512,50257] = transpose[permutation=(1, 2, 0)] kz 1471 | # lb:f32[32,512,50257] = sub la kb 1472 | # lc:f32[32,512] = dot_general[ 1473 | # dimension_numbers=(([2], [2]), ([0, 1], [0, 1])) 1474 | # preferred_element_type=float32 1475 | # ] kh lb 1476 | # ld:f32[] = reduce_sum[axes=(0, 1)] bu 1477 | # le:f32[32,512] = pjit[ 1478 | # name=_where 1479 | # jaxpr={ lambda ; lf:f32[32,512] lg:f32[32,512] lh:f32[]. 
let 1480 | # li:bool[32,512] = ne lf 0.0 1481 | # lj:f32[32,512] = broadcast_in_dim[ 1482 | # broadcast_dimensions=() 1483 | # shape=(32, 512) 1484 | # ] lh 1485 | # lk:f32[32,512] = select_n li lj lg 1486 | # in (lk,) } 1487 | # ] bu lc 0.0 1488 | # ll:f32[] = reduce_sum[axes=(0, 1)] le 1489 | # lm:f32[] = div ll ld 1490 | # in (lm,) } --------------------------------------------------------------------------------