├── .github └── workflows │ ├── build_docs.yml │ ├── release.yml │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── .citation.md ├── .htaccess ├── _overrides │ ├── bluesky.svg │ └── partials │ │ └── source.html ├── _static │ ├── .README.md │ ├── custom_css.css │ ├── favicon.png │ └── mathjax.js ├── api │ ├── advanced-features.md │ ├── array.md │ ├── pytree.md │ └── runtime-type-checking.md ├── faq.md └── index.md ├── jaxtyping ├── __init__.py ├── _array_types.py ├── _config.py ├── _decorator.py ├── _errors.py ├── _import_hook.py ├── _indirection.py ├── _ipython_extension.py ├── _pytest_plugin.py ├── _pytree_type.py ├── _storage.py ├── _typeguard │ ├── LICENSE │ ├── README.md │ └── __init__.py └── py.typed ├── mkdocs.yml ├── pyproject.toml └── test ├── __init__.py ├── conftest.py ├── helpers.py ├── import_hook_tester.py ├── requirements.txt ├── test_all_importable.py ├── test_array.py ├── test_decorator.py ├── test_generators.py ├── test_import_hook.py ├── test_ipython_extension.py ├── test_messages.py ├── test_no_jax_dependency.py ├── test_pytree.py ├── test_serialisation.py ├── test_tf_dtype.py ├── test_threading.py └── types ├── __init__.py └── decorator.py /.github/workflows/build_docs.yml: -------------------------------------------------------------------------------- 1 | name: Build docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | strategy: 11 | matrix: 12 | python-version: [ 3.11 ] 13 | os: [ ubuntu-latest ] 14 | runs-on: ${{ matrix.os }} 15 | steps: 16 | - name: Checkout code 17 | uses: actions/checkout@v2 18 | 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | python -m pip install '.[docs]' 28 | python -m pip install jax[cpu] 29 | 30 | - name: Build docs 31 | 
run: | 32 | mkdocs build 33 | 34 | - name: Upload docs 35 | uses: actions/upload-artifact@v4 36 | with: 37 | name: docs 38 | path: site # where `mkdocs build` puts the built site 39 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Google LLC 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | 20 | name: Release 21 | 22 | on: 23 | push: 24 | branches: 25 | - main 26 | 27 | jobs: 28 | build: 29 | runs-on: ubuntu-latest 30 | steps: 31 | - name: Release 32 | uses: patrick-kidger/action_update_python_project@v6 33 | with: 34 | python-version: "3.11" 35 | test-script: | 36 | python -m pip install -r ${{ github.workspace }}/test/requirements.txt 37 | python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu 38 | cp -r ${{ github.workspace }}/test ./test 39 | pytest 40 | pypi-token: ${{ secrets.pypi_token }} 41 | github-user: patrick-kidger 42 | github-token: ${{ github.token }} 43 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Google LLC 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | 20 | name: Run tests 21 | 22 | on: 23 | pull_request: 24 | 25 | jobs: 26 | run-tests: 27 | strategy: 28 | matrix: 29 | python-version: [ "3.10", "3.12" ] 30 | os: [ ubuntu-latest ] 31 | fail-fast: false 32 | runs-on: ${{ matrix.os }} 33 | steps: 34 | - name: Checkout code 35 | uses: actions/checkout@v2 36 | 37 | - name: Set up Python ${{ matrix.python-version }} 38 | uses: actions/setup-python@v2 39 | with: 40 | python-version: ${{ matrix.python-version }} 41 | 42 | - name: Install dependencies 43 | run: | 44 | python -m pip install --upgrade pip 45 | python -m pip install -r test/requirements.txt 46 | python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu 47 | 48 | - name: Checks with pre-commit 49 | uses: pre-commit/action@v3.0.1 50 | 51 | - name: Test with pytest 52 | run: | 53 | python -m pip install . 54 | python -m pytest --durations=0 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | *.egg-info 3 | build/ 4 | dist/ 5 | site/ 6 | .all_objects.cache 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Google LLC 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the 
Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | repos: 21 | - repo: https://github.com/astral-sh/ruff-pre-commit 22 | rev: v0.7.3 23 | hooks: 24 | - id: ruff-format # formatter 25 | types_or: [ python, pyi, jupyter ] 26 | - id: ruff # linter 27 | types_or: [ python, pyi, jupyter ] 28 | args: [ --fix ] 29 | - repo: https://github.com/RobertCraigie/pyright-python 30 | rev: v1.1.391 31 | hooks: 32 | - id: pyright 33 | files: ^test/types/ 34 | additional_dependencies: 35 | [beartype, numpy<2] 36 | - repo: https://github.com/pre-commit/mirrors-mypy 37 | rev: v1.14.1 38 | hooks: 39 | - id: mypy 40 | files: ^test/types/ 41 | additional_dependencies: 42 | [beartype, numpy<2] 43 | args: ["--ignore-missing-imports", "--follow-imports=skip"] -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions (pull requests) are very welcome! Here's how to get started. 4 | 5 | --- 6 | 7 | First fork the library on GitHub. 8 | 9 | Then clone and install the library in development mode: 10 | 11 | ```bash 12 | git clone https://github.com/your-username-here/jaxtyping.git 13 | cd jaxtyping 14 | pip install -e . 15 | ``` 16 | 17 | Then install the pre-commit hook: 18 | 19 | ```bash 20 | pip install pre-commit 21 | pre-commit install 22 | ``` 23 | 24 | These hooks use ruff to lint and format the code. 25 | 26 | Now make your changes. 
Make sure to include additional tests if necessary. 27 | 28 | Next verify the tests all pass: 29 | 30 | ```bash 31 | pip install -r test/requirements.txt 32 | pip install torch --extra-index-url https://download.pytorch.org/whl/cpu 33 | pytest 34 | ``` 35 | 36 | Then push your changes back to your fork of the repository: 37 | 38 | ```bash 39 | git push 40 | ``` 41 | 42 | Finally, open a pull request on GitHub! 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Google LLC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | 25 | 26 | --- 27 | Sections of the code were modified from https://github.com/agronholm/typeguard 28 | under the terms of the MIT license, reproduced below. 
29 | --- 30 | 31 | MIT License 32 | 33 | Copyright (c) Alex Grönholm 34 | 35 | Permission is hereby granted, free of charge, to any person obtaining a copy 36 | of this software and associated documentation files (the "Software"), to deal 37 | in the Software without restriction, including without limitation the rights 38 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 39 | copies of the Software, and to permit persons to whom the Software is 40 | furnished to do so, subject to the following conditions: 41 | 42 | The above copyright notice and this permission notice shall be included in all 43 | copies or substantial portions of the Software. 44 | 45 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 46 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 47 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 48 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 49 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 50 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 51 | SOFTWARE. 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

jaxtyping

2 | 3 | Type annotations **and runtime type-checking** for: 4 | 5 | 1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)* 6 | 2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html). 7 | 8 | **For example:** 9 | ```python 10 | from jaxtyping import Array, Float, PyTree 11 | 12 | # Accepts floating-point 2D arrays with matching axes 13 | # You can replace `Array` with `torch.Tensor` etc. 14 | def matrix_multiply(x: Float[Array, "dim1 dim2"], 15 | y: Float[Array, "dim2 dim3"] 16 | ) -> Float[Array, "dim1 dim3"]: 17 | ... 18 | 19 | def accepts_pytree_of_ints(x: PyTree[int]): 20 | ... 21 | 22 | def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]): 23 | ... 24 | ``` 25 | 26 | ## Installation 27 | 28 | ```bash 29 | pip install jaxtyping 30 | ``` 31 | 32 | Requires Python 3.10+. 33 | 34 | JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc. 35 | 36 | The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are [typeguard](https://github.com/agronholm/typeguard) (which exhaustively checks every argument) and [beartype](https://github.com/beartype/beartype) (which checks random pieces of arguments). 37 | 38 | ## Documentation 39 | 40 | Available at [https://docs.kidger.site/jaxtyping](https://docs.kidger.site/jaxtyping). 41 | 42 | ## See also: other libraries in the JAX ecosystem 43 | 44 | **Always useful** 45 | [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX! 46 | 47 | **Deep learning** 48 | [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers. 
49 | [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device). 50 | [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs). 51 | [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees. 52 | 53 | **Scientific computing** 54 | [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers. 55 | [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares. 56 | [Lineax](https://github.com/patrick-kidger/lineax): linear solvers. 57 | [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling. 58 | [sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent. 59 | [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!) 60 | 61 | **Awesome JAX** 62 | [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects. 63 | -------------------------------------------------------------------------------- /docs/.citation.md: -------------------------------------------------------------------------------- 1 | If you found this library to be useful in academic work, then please cite: ([arXiv link](https://arxiv.org/abs/2111.00254)) 2 | 3 | ```bibtex 4 | @article{kidger2021equinox, 5 | author={Patrick Kidger and Cristian Garcia}, 6 | title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations}, 7 | year={2021}, 8 | journal={Differentiable Programming workshop at Neural Information Processing Systems 2021} 9 | } 10 | ``` 11 | 12 | (Also consider starring the project [on GitHub](https://github.com/patrick-kidger/equinox).) 
13 | -------------------------------------------------------------------------------- /docs/.htaccess: -------------------------------------------------------------------------------- 1 | ErrorDocument 404 /jaxtyping/404.html 2 | -------------------------------------------------------------------------------- /docs/_overrides/bluesky.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/_overrides/partials/source.html: -------------------------------------------------------------------------------- 1 | {% import "partials/language.html" as lang with context %} 2 | 3 |
4 | {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} 5 | {% include ".icons/" ~ icon ~ ".svg" %} 6 |
7 |
8 | {{ config.repo_name }} 9 |
10 |
11 | 12 |
13 | {% include ".icons/fontawesome/brands/twitter.svg" %} 14 |
15 |
16 | 17 |
18 | {% include "bluesky.svg" %} 19 |
20 |
21 | {{ config.theme.twitter_bluesky_name }} 22 |
23 |
24 | -------------------------------------------------------------------------------- /docs/_static/.README.md: -------------------------------------------------------------------------------- 1 | The favicon is `math-integral` from https://materialdesignicons.com, found by way of https://pictogrammers.com. 2 | (The logo is `math-integral-box`.) 3 | -------------------------------------------------------------------------------- /docs/_static/custom_css.css: -------------------------------------------------------------------------------- 1 | /* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ 2 | html { 3 | scroll-padding-top: 50px; 4 | } 5 | 6 | /* Fit the Twitter handle alongside the GitHub one in the top right. */ 7 | 8 | div.md-header__source { 9 | width: revert; 10 | max-width: revert; 11 | } 12 | 13 | a.md-source { 14 | display: inline-block; 15 | } 16 | 17 | .md-source__repository { 18 | max-width: 100%; 19 | } 20 | 21 | /* Emphasise sections of nav on left hand side */ 22 | 23 | nav.md-nav { 24 | padding-left: 5px; 25 | } 26 | 27 | nav.md-nav--secondary { 28 | border-left: revert !important; 29 | } 30 | 31 | .md-nav__title { 32 | font-size: 0.9rem; 33 | } 34 | 35 | .md-nav__item--section > .md-nav__link { 36 | font-size: 0.9rem; 37 | } 38 | 39 | /* Indent autogenerated documentation */ 40 | 41 | div.doc-contents { 42 | padding-left: 25px; 43 | border-left: 4px solid rgba(230, 230, 230); 44 | } 45 | 46 | /* Increase visibility of splitters "---" */ 47 | 48 | [data-md-color-scheme="default"] .md-typeset hr { 49 | border-bottom-color: rgb(0, 0, 0); 50 | border-bottom-width: 1pt; 51 | } 52 | 53 | [data-md-color-scheme="slate"] .md-typeset hr { 54 | border-bottom-color: rgb(230, 230, 230); 55 | } 56 | 57 | /* More space at the bottom of the page */ 58 | 59 | .md-main__inner { 60 | margin-bottom: 1.5rem; 61 | } 62 | 63 | /* Remove prev/next footer buttons */ 64 | 65 | .md-footer__inner { 66 | display: none; 67 | } 68 | 69 | /* 
Change font sizes */ 70 | 71 | html { 72 | /* Decrease font size for overall webpage 73 | Down from 137.5% which is the Material default */ 74 | font-size: 110%; 75 | } 76 | 77 | .md-typeset .admonition { 78 | /* Increase font size in admonitions */ 79 | font-size: 100% !important; 80 | } 81 | 82 | .md-typeset details { 83 | /* Increase font size in details */ 84 | font-size: 100% !important; 85 | } 86 | 87 | .md-typeset h1 { 88 | font-size: 1.6rem; 89 | } 90 | 91 | .md-typeset h2 { 92 | font-size: 1.5rem; 93 | } 94 | 95 | .md-typeset h3 { 96 | font-size: 1.3rem; 97 | } 98 | 99 | .md-typeset h4 { 100 | font-size: 1.1rem; 101 | } 102 | 103 | .md-typeset h5 { 104 | font-size: 0.9rem; 105 | } 106 | 107 | .md-typeset h6 { 108 | font-size: 0.8rem; 109 | } 110 | 111 | /* Bugfix: remove the superfluous parts generated when doing: 112 | 113 | ??? Blah 114 | 115 | ::: library.something 116 | */ 117 | 118 | .md-typeset details .mkdocstrings > h4 { 119 | display: none; 120 | } 121 | 122 | .md-typeset details .mkdocstrings > h5 { 123 | display: none; 124 | } 125 | 126 | /* Change default colours for tags */ 127 | 128 | [data-md-color-scheme="default"] { 129 | --md-typeset-a-color: rgb(0, 189, 164) !important; 130 | } 131 | [data-md-color-scheme="slate"] { 132 | --md-typeset-a-color: rgb(0, 189, 164) !important; 133 | } 134 | 135 | /* Highlight functions, classes etc. type signatures. Really helps to make clear where 136 | one item ends and another begins. 
*/ 137 | 138 | [data-md-color-scheme="default"] { 139 | --doc-heading-color: #DDD; 140 | --doc-heading-border-color: #CCC; 141 | --doc-heading-color-alt: #F0F0F0; 142 | } 143 | [data-md-color-scheme="slate"] { 144 | --doc-heading-color: rgb(25,25,33); 145 | --doc-heading-border-color: rgb(25,25,33); 146 | --doc-heading-color-alt: rgb(33,33,44); 147 | --md-code-bg-color: rgb(38,38,50); 148 | } 149 | 150 | h4.doc-heading { 151 | /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ 152 | background-color: var(--doc-heading-color); 153 | border: solid var(--doc-heading-border-color); 154 | border-width: 1.5pt; 155 | border-radius: 2pt; 156 | padding: 0pt 5pt 2pt 5pt; 157 | } 158 | h5.doc-heading, h6.heading { 159 | background-color: var(--doc-heading-color-alt); 160 | border-radius: 2pt; 161 | padding: 0pt 5pt 2pt 5pt; 162 | } 163 | 164 | /* Make errors in notebooks have scrolling */ 165 | .output_error > pre { 166 | overflow: auto; 167 | } 168 | -------------------------------------------------------------------------------- /docs/_static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrick-kidger/jaxtyping/fe61644be8590cf0a89bacc278283e1e4b9ea3e4/docs/_static/favicon.png -------------------------------------------------------------------------------- /docs/_static/mathjax.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | inlineMath: [["\\(", "\\)"]], 4 | displayMath: [["\\[", "\\]"]], 5 | processEscapes: true, 6 | processEnvironments: true 7 | }, 8 | options: { 9 | ignoreHtmlClass: ".*|", 10 | processHtmlClass: "arithmatex" 11 | } 12 | }; 13 | 14 | document$.subscribe(() => { 15 | MathJax.typesetPromise() 16 | }) 17 | -------------------------------------------------------------------------------- /docs/api/advanced-features.md: 
-------------------------------------------------------------------------------- 1 | # Advanced features 2 | 3 | ## Creating your own dtypes 4 | 5 | ::: jaxtyping.AbstractDtype 6 | options: 7 | members: [] 8 | 9 | ::: jaxtyping.make_numpy_struct_dtype 10 | 11 | ## Printing axis bindings 12 | 13 | ::: jaxtyping.print_bindings 14 | 15 | ## Introspection 16 | 17 | If you're writing your own type hint parser, then you may wish to detect if some Python object is a jaxtyping-provided type. 18 | 19 | You can check for dtypes by doing `issubclass(x, AbstractDtype)`. For example, `issubclass(Float32, AbstractDtype)` will pass. 20 | 21 | You can check for arrays by doing `issubclass(x, AbstractArray)`. Here, `AbstractArray` is the base class for all shape-and-dtype specified arrays, e.g. it's a base class for `Float32[Array, "foo"]`. 22 | 23 | You can check for pytrees by doing `issubclass(x, PyTree)`. For example, `issubclass(PyTree[int], PyTree)` will pass. 24 | -------------------------------------------------------------------------------- /docs/api/array.md: -------------------------------------------------------------------------------- 1 | # Array annotations 2 | 3 | The shape and dtypes of arrays can be annotated in the form `dtype[array, shape]`, such as `Float[Array, "batch channels"]`. 4 | 5 | ## Shape 6 | 7 | **Symbols** 8 | 9 | The shape should be a string of space-separated symbols, such as `"a b c d"`. Each symbol can be either an: 10 | 11 | - `int`: fixed-size axis, e.g. `"28 28"`. 12 | - `str`: variable-size axis, e.g. `"channels"`. 13 | - A symbolic expression in terms of other variable-size axes, e.g. 14 | `def remove_last(x: Float[Array, "dim"]) -> Float[Array, "dim-1"]`. 15 | Symbolic expressions must not use any spaces, otherwise each piece is treated as a separate axis. 16 | 17 | When calling a function, variable-size axes and symbolic axes will be matched up across all arguments and checked for consistency. 
(See [Runtime type checking](./runtime-type-checking.md).) 18 | 19 | **Modifiers** 20 | 21 | In addition some modifiers can be applied: 22 | 23 | - Prepend `*` to an axis to indicate that it can match multiple axes, e.g. `"*batch"` will match zero or more batch axes. 24 | - Prepend `#` to an axis to indicate that it can be that size *or* equal to one -- i.e. broadcasting is acceptable, e.g. 25 | `def add(x: Float[Array, "#foo"], y: Float[Array, "#foo"]) -> Float[Array, "#foo"]`. 26 | - Prepend `_` to an axis to disable any runtime checking of that axis (so that it can be used just as documentation). This can also be used as just `_` on its own: e.g. `"b c _ _"`. 27 | - Documentation-only names (i.e. they're ignored by jaxtyping) can be handled by prepending a name followed by `=` e.g. `Float[Array, "rows=4 cols=3"]`. 28 | - Prepend `?` to an axis to indicate that its size can vary within a PyTree structure. (See [PyTree annotations](./pytree.md).) 29 | 30 | When using multiple modifiers, their order does not matter. 31 | 32 | As a special case: 33 | 34 | - `...`: anonymous zero or more axes (equivalent to `*_`) e.g. `"... c h w"` 35 | 36 | **Notes** 37 | 38 | - To denote a scalar shape use `""`, e.g. `Float[Array, ""]`. 39 | - To denote an arbitrary shape (and only check dtype) use `"..."`, e.g. `Float[Array, "..."]`. 40 | - You cannot have more than one use of multiple-axes, i.e. you can only use `...` or `*name` at most once in each array. 41 | - A symbolic expression cannot be evaluated unless all of the axes sizes it refers to have already been processed. In practice this usually means that they should only be used in annotations for the return type, and only use axes declared in the arguments. 42 | - Symbolic expressions are evaluated in two stages: they are first evaluated as f-strings using the arguments of the function, and second are evaluated using the processed axis sizes. 
The f-string evaluation means that they can use local variables by enclosing them with curly braces, e.g. `{variable}`, e.g. 43 | ```python 44 | def full(size: int, fill: float) -> Float[Array, "{size}"]: 45 | return jax.numpy.full((size,), fill) 46 | 47 | class SomeClass: 48 | some_value = 5 49 | 50 | def full(self, fill: float) -> Float[Array, "{self.some_value}+3"]: 51 | return jax.numpy.full((self.some_value + 3,), fill) 52 | ``` 53 | 54 | ## Dtype 55 | 56 | The dtype should be any one of (all imported from `jaxtyping`): 57 | 58 | - Any dtype at all: `Shaped` 59 | - Boolean: `Bool` 60 | - PRNG key: `Key` 61 | - Any integer, unsigned integer, floating, or complex: `Num` 62 | - Any floating or complex: `Inexact` 63 | - Any floating point: `Float` 64 | - Of particular precision: `BFloat16`, `Float16`, `Float32`, `Float64` 65 | - Any complex: `Complex` 66 | - Of particular precision: `Complex64`, `Complex128` 67 | - Any integer or unsigned integer: `Integer` 68 | - Any unsigned integer: `UInt` 69 | - Of particular precision: `UInt2`, `UInt4`, `UInt8`, `UInt16`, `UInt32`, `UInt64` 70 | - Any signed integer: `Int` 71 | - Of particular precision: `Int2`, `Int4`, `Int8`, `Int16`, `Int32`, `Int64` 72 | - Any floating, integer, or unsigned integer: `Real`. 73 | 74 | Unless you really want to force a particular precision, then for most applications you should probably allow any floating-point, any integer, etc. That is, use 75 | ```python 76 | from jaxtyping import Array, Float 77 | Float[Array, "some_shape"] 78 | ``` 79 | rather than 80 | ```python 81 | from jaxtyping import Array, Float32 82 | Float32[Array, "some_shape"] 83 | ``` 84 | 85 | ## Array 86 | 87 | The array should typically be either one of: 88 | ```python 89 | jaxtyping.Array / jax.Array / jax.numpy.ndarray # these are all aliases of one another 90 | np.ndarray 91 | torch.Tensor 92 | tf.Tensor 93 | mx.array 94 | ``` 95 | That is -- despite the now-historical name! 
-- jaxtyping also supports NumPy + PyTorch + TensorFlow + MLX. 96 | 97 | Some other types are also supported here: 98 | 99 | **Unions:** these are unpacked. For example, `SomeDtype[Union[A, B], "some shape"]` is equivalent to `Union[SomeDtype[A, "some shape"], SomeDtype[B, "some shape"]]`. A common example of a union type here is `np.typing.ArrayLike`. 100 | 101 | **Any:** use `typing.Any` to check just the shape/dtype, but not the array type. 102 | 103 | **Duck-type arrays:** anything with `.shape` and `.dtype` attributes. For example, 104 | ```python 105 | class MyDuckArray: 106 | @property 107 | def shape(self) -> tuple[int, ...]: 108 | return (3, 4, 5) 109 | 110 | @property 111 | def dtype(self) -> str: 112 | return "my_dtype" 113 | 114 | class MyDtype(jaxtyping.AbstractDtype): 115 | dtypes = ["my_dtype"] 116 | 117 | x = MyDuckArray() 118 | assert isinstance(x, MyDtype[MyDuckArray, "3 4 5"]) 119 | # checks that `type(x) == MyDuckArray` 120 | # and that `x.shape == (3, 4, 5)` 121 | # and that `x.dtype == "my_dtype"` 122 | ``` 123 | 124 | **TypeVars:** in this case the runtime array is checked for matching the bounds or constraints of the `typing.TypeVar`. 125 | 126 | **Existing jaxtyped annotations:** 127 | ```python 128 | Image = Float[Array, "channels height width"] 129 | BatchImage = Float[Image, "batch"] 130 | ``` 131 | in which case the additional shape is prepended, and the acceptable dtypes are the intersection of the two dtype specifiers used. (So that e.g. `BatchImage = Shaped[Image, "batch"]` would work just as well. But `Bool[Image, "batch"]` would throw an error, as there are no dtypes that are both bools and floats.) Thus the above is equivalent to 132 | ```python 133 | BatchImage = Float[Array, "batch channels height width"] 134 | ``` 135 | 136 | Note that `jaxtyping.{Array, ArrayLike}` are only available if JAX has been installed. 
137 | 138 | ## Scalars, PRNG keys 139 | 140 | For convenience, jaxtyping also includes `jaxtyping.Scalar`, `jaxtyping.ScalarLike`, and `jaxtyping.PRNGKeyArray`, defined as: 141 | ```python 142 | Scalar = Shaped[Array, ""] 143 | ScalarLike = Shaped[ArrayLike, ""] 144 | 145 | # Left: new-style typed keys; right: old-style keys. See JEP 9263. 146 | PRNGKeyArray = Union[Key[Array, ""], UInt32[Array, "2"]] 147 | ``` 148 | 149 | Recalling that shape-and-dtype specified jaxtyping arrays can be nested, this means that e.g. you can annotate the output of `jax.random.split` with `Shaped[PRNGKeyArray, "2"]`, or e.g. an integer scalar with `Int[Scalar, ""]`. 150 | 151 | Note that `jaxtyping.{Scalar, ScalarLike, PRNGKeyArray}` are only available if JAX has been installed. 152 | -------------------------------------------------------------------------------- /docs/api/pytree.md: -------------------------------------------------------------------------------- 1 | # PyTree annotations 2 | 3 | :::jaxtyping.PyTree 4 | options: 5 | members: [] 6 | 7 | --- 8 | 9 | :::jaxtyping.PyTreeDef 10 | options: 11 | members: [] 12 | 13 | --- 14 | 15 | ## Path-dependent shapes 16 | 17 | The prefix `?` may be used to indicate that the axis size can depend on which leaf of a PyTree the array is at. For example: 18 | ```python 19 | def f( 20 | x: PyTree[Shaped[Array, "?foo"], "T"], 21 | y: PyTree[Shaped[Array, "?foo"], "T"], 22 | ): 23 | pass 24 | ``` 25 | The above demands that `x` and `y` have matching PyTree structures (due to the `T` annotation), and that their leaves must all be one-dimensional arrays, *and that the corresponding pairs of leaves in `x` and `y` must have the same size as each other*. 26 | 27 | Thus the following is allowed: 28 | ```python 29 | x0 = jnp.arange(3) 30 | x1 = jnp.arange(5) 31 | 32 | y0 = jnp.arange(3) + 1 33 | y1 = jnp.arange(5) + 1 34 | 35 | f((x0, x1), (y0, y1)) # x0 matches y0, and x1 matches y1. All good! 
36 | ``` 37 | 38 | But this is not: 39 | ```python 40 | f((x1, x1), (y0, y1)) # x1 does not have a size matching y0! 41 | ``` 42 | 43 | Internally, all that is happening is that `foo` is replaced with `0foo` for the first leaf, `1foo` for the next leaf, etc., so that each leaf gets a unique version of the name. 44 | 45 | --- 46 | 47 | Note that `jaxtyping.{PyTree, PyTreeDef}` are only available if JAX has been installed. 48 | -------------------------------------------------------------------------------- /docs/api/runtime-type-checking.md: -------------------------------------------------------------------------------- 1 | # Runtime type checking 2 | 3 | (See the [FAQ](../faq.md) for details on static type checking.) 4 | 5 | Runtime type checking **synergises beautifully with `jax.jit`!** All shape checks will be performed only whilst tracing, and will not impact runtime performance. 6 | 7 | There are two approaches: either use [`jaxtyping.jaxtyped`][] to typecheck a single function, or [`jaxtyping.install_import_hook`][] to typecheck a whole codebase. 8 | 9 | In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are [beartype](https://github.com/beartype/beartype) and [typeguard](https://github.com/agronholm/typeguard). (If using typeguard, then specifically the version `2.*` series should be used. Later versions -- `3` and `4` -- have some known issues.) 10 | 11 | !!! warning 12 | 13 | Avoid using `from __future__ import annotations`, or stringified type annotations, where possible. These are largely incompatible with runtime type checking. See also [this FAQ entry](../faq.md#dataclass-annotations-arent-being-checked-properly). 14 | 15 | --- 16 | 17 | ::: jaxtyping.jaxtyped 18 | 19 | --- 20 | 21 | ::: jaxtyping.install_import_hook 22 | 23 | --- 24 | 25 | #### Pytest hook 26 | 27 | The import hook can be installed at test-time only, as a pytest hook. 
From the command line the syntax is: 28 | ``` 29 | pytest --jaxtyping-packages=foo,bar.baz,beartype.beartype 30 | ``` 31 | or in `pyproject.toml`: 32 | ```toml 33 | [tool.pytest.ini_options] 34 | addopts = "--jaxtyping-packages=foo,bar.baz,beartype.beartype" 35 | ``` 36 | or in `pytest.ini`: 37 | ```ini 38 | [pytest] 39 | addopts = --jaxtyping-packages=foo,bar.baz,beartype.beartype 40 | ``` 41 | This example will apply the import hook to all modules whose names start with either `foo` or `bar.baz`. The typechecker used in this example is `beartype.beartype`. 42 | 43 | #### IPython extension 44 | 45 | If you are running in an IPython environment (for example a Jupyter or Colab notebook), then the jaxtyping hook can be automatically run via a custom magic: 46 | ```python 47 | import jaxtyping 48 | %load_ext jaxtyping 49 | %jaxtyping.typechecker beartype.beartype # or any other runtime type checker 50 | ``` 51 | Place this at the start of your notebook -- everything that is directly defined in the notebook, after this magic is run, will be hook'd. 52 | 53 | #### Other runtime type-checking libraries 54 | 55 | Beartype and typeguard happen to be the two most popular runtime type-checking libraries (at least at time of writing), but jaxtyping should be compatible with all runtime type checkers out-of-the-box. The runtime type-checking library just needs to provide a type-checking decorator (analogous to `beartype.beartype` or `typeguard.typechecked`), and perform `isinstance` checks against jaxtyping's types. 56 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | # FAQ 2 | 3 | ## Is jaxtyping compatible with static type checkers like `mypy`/`pyright`/`pytype`? 4 | 5 | There is partial support for these. An annotation of the form `dtype[array, shape]` should be treated as just `array` by a static type checker.
Unfortunately full dtype/shape checking is beyond the scope of what static type checking is currently capable of. 6 | 7 | (Note that at time of writing, `pytype` has a bug in that `dtype[array, shape]` is sometimes treated as `Any` rather than `array`. `mypy` and `pyright` both work fine.) 8 | 9 | ## How does jaxtyping interact with `jax.jit`? 10 | 11 | jaxtyping and `jax.jit` synergise beautifully. 12 | 13 | When calling JAX operations wrapped in a `jax.jit`, then the dtype/shape-checking will happen at trace time. (When JAX traces your function prior to compiling it.) The actual compiled code does not have any dtype/shape-checking, and will therefore still be just as fast as before! 14 | 15 | ## `flake8` or Ruff are throwing an error. 16 | 17 | In type annotations, strings are used for two different things. Sometimes they're strings. Sometimes they're "forward references", used to refer to a type that will be defined later. 18 | 19 | Some tooling in the Python ecosystem assumes that only the latter is true, and will throw spurious errors if you try to use a string just as a string (like we do). 20 | 21 | In the case of `flake8`, or Ruff, this can be resolved. Multi-dimensional arrays (e.g. `Float32[Array, "b c"]`) will throw a very unusual error (F722, syntax error in forward annotation), so you can safely just disable this particular error globally. Uni-dimensional arrays (e.g. `Float32[Array, "x"]`) will throw an error that's actually useful (F821, undefined name), so instead of disabling this globally, you should instead prepend a space to the start of your shape, e.g. `Float32[Array, " x"]`. `jaxtyping` will treat this in the same way, whilst `flake8` will now throw an F722 error that you can disable as before. 22 | 23 | ## Dataclass annotations aren't being checked properly. 24 | 25 | Stringified dataclass annotations, e.g. 26 | ```python 27 | @dataclass() 28 | class Foo: 29 | x: "int" 30 | ``` 31 | will be silently skipped without checking them. 
This is because these are essentially impossible to resolve at runtime. Such stringified annotations typically occur either when using them for forward references, or when using `from __future__ import annotations`. (You should essentially never use the latter, it is largely incompatible with runtime type checking and as such is [being replaced in Python 3.13](https://peps.python.org/pep-0649/).) 32 | 33 | Partially stringified dataclass annotations, e.g. 34 | ```python 35 | @dataclass() 36 | class Foo: 37 | x: tuple["int"] 38 | ``` 39 | will likely raise an error, and must not be used at all. 40 | 41 | ## Does jaxtyping use [PEP 646](https://www.python.org/dev/peps/pep-0646/) (variadic generics)? 42 | 43 | The intention of PEP 646 was to make it possible for static type checkers to perform shape checks of arrays. Unfortunately, this still isn't yet practical, so jaxtyping deliberately does not use this. (Yet?) 44 | 45 | The real problem is that Python's static typing ecosystem is a complicated collection of edge cases. Many of them block ML/scientific computing in particular. For example: 46 | 47 | 1. The static type system is intrinsically not expressive enough to describe operations like concatenation, stacking, or broadcasting. 48 | 49 | 2. Axes have to be lifted to type-level variables. Meanwhile the approach taken in libraries like `jaxtyping` and [TorchTyping](https://github.com/patrick-kidger/torchtyping) is to use value-level variables for types: because that's what the underlying JAX, PyTorch etc. libraries use! As such, making a static type checker work with these libraries would require either fundamentally rewriting these libraries, or exhaustively maintaining type stubs for them, and would *still* require a `typing.cast` any time you use anything unstubbed (e.g. any third party library, or part of your codebase you haven't typed yet). This is a huge maintenance burden. 50 | 51 | 3. Static type checkers have a variety of bugs that affect this use case. 
`mypy` doesn't support `Protocol`s correctly. `pyright` doesn't support genericised subprotocols. etc. 52 | 53 | 4. Variadic generics exist. Variadic protocols do not. (It's not clear that these were contemplated.) 54 | 55 | 5. The syntax for static typing is a little verbose. You have to write things like `Array[Float32, Unpack[AnyShape], Literal[3], Height, Width]` instead of `Float32[Array, "... 3 height width"]`. 56 | 57 | 6. [The underlying type system has flaws](https://github.com/patrick-kidger/torchtyping/issues/37#issuecomment-1153294196). 58 | [The numeric tower is broken](https://stackoverflow.com/a/69383462); 59 | [int is not a number](https://github.com/python/mypy/issues/3186#issuecomment-885718629); 60 | [virtual base classes don't work](https://github.com/python/mypy/issues/2922); 61 | [complex lies about having comparison operations, so type checkers have to lie about that lie in order to remove them again](https://beartype.github.io/numerary/0.4/whytho/); 62 | `typing.*` don't work with `isinstance`; 63 | co/contra-variance are baked into containers (not specified at use-time); 64 | `dict` is variadic despite... not being variadic; 65 | bool is a subclass of int (!); 66 | ... etc. etc. 67 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Getting started 2 | 3 | jaxtyping is a library providing type annotations **and runtime type-checking** for: 4 | 5 | 1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)* 6 | 2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html). 7 | 8 | ## Installation 9 | 10 | ```bash 11 | pip install jaxtyping 12 | ``` 13 | 14 | Requires Python 3.10+. 15 | 16 | JAX is an optional dependency, required for a few JAX-specific types. 
If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc. 17 | 18 | The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are [typeguard](https://github.com/agronholm/typeguard) (which exhaustively checks every argument) and [beartype](https://github.com/beartype/beartype) (which checks random pieces of arguments). 19 | 20 | ## Example 21 | 22 | ```python 23 | from jaxtyping import Array, Float, PyTree 24 | 25 | # Accepts floating-point 2D arrays with matching axes 26 | # You can replace `Array` with `torch.Tensor` etc. 27 | def matrix_multiply(x: Float[Array, "dim1 dim2"], 28 | y: Float[Array, "dim2 dim3"] 29 | ) -> Float[Array, "dim1 dim3"]: 30 | ... 31 | 32 | def accepts_pytree_of_ints(x: PyTree[int]): 33 | ... 34 | 35 | def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]): 36 | ... 37 | ``` 38 | 39 | ## Next steps 40 | 41 | Have a read of the [Array annotations](./api/array.md) documentation on the left-hand bar! 42 | 43 | ## See also: other libraries in the JAX ecosystem 44 | 45 | **Always useful** 46 | [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX! 47 | 48 | **Deep learning** 49 | [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers. 50 | [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device). 51 | [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs). 52 | [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees. 53 | 54 | **Scientific computing** 55 | [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers. 
56 | [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares. 57 | [Lineax](https://github.com/patrick-kidger/lineax): linear solvers. 58 | [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling. 59 | [sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent. 60 | [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!) 61 | 62 | **Awesome JAX** 63 | [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects. 64 | -------------------------------------------------------------------------------- /jaxtyping/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Google LLC 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | import functools as ft 21 | import importlib.metadata 22 | import typing 23 | from typing import TypeAlias, Union 24 | 25 | from ._array_types import ( 26 | AbstractArray as AbstractArray, 27 | AbstractDtype as AbstractDtype, 28 | get_array_name_format as get_array_name_format, 29 | make_numpy_struct_dtype as make_numpy_struct_dtype, 30 | set_array_name_format as set_array_name_format, 31 | ) 32 | from ._config import config as config 33 | from ._decorator import jaxtyped as jaxtyped 34 | from ._errors import ( 35 | AnnotationError as AnnotationError, 36 | TypeCheckError as TypeCheckError, 37 | ) 38 | from ._import_hook import install_import_hook as install_import_hook 39 | from ._ipython_extension import load_ipython_extension as load_ipython_extension 40 | from ._storage import print_bindings as print_bindings 41 | 42 | 43 | if typing.TYPE_CHECKING: 44 | from jax import Array as Array 45 | from jax.tree_util import PyTreeDef as PyTreeDef 46 | from jax.typing import ArrayLike as ArrayLike, DTypeLike as DTypeLike 47 | 48 | # Introduce an indirection so that we can `import X as X` to make it clear that 49 | # these are public. 
50 | from ._indirection import ( 51 | BFloat16 as BFloat16, 52 | Bool as Bool, 53 | Complex as Complex, 54 | Complex64 as Complex64, 55 | Complex128 as Complex128, 56 | Float as Float, 57 | Float8e4m3b11fnuz as Float8e4m3b11fnuz, 58 | Float8e4m3fn as Float8e4m3fn, 59 | Float8e4m3fnuz as Float8e4m3fnuz, 60 | Float8e5m2 as Float8e5m2, 61 | Float8e5m2fnuz as Float8e5m2fnuz, 62 | Float16 as Float16, 63 | Float32 as Float32, 64 | Float64 as Float64, 65 | Inexact as Inexact, 66 | Int as Int, 67 | Int2 as Int2, 68 | Int4 as Int4, 69 | Int8 as Int8, 70 | Int16 as Int16, 71 | Int32 as Int32, 72 | Int64 as Int64, 73 | Integer as Integer, 74 | Key as Key, 75 | Num as Num, 76 | PRNGKeyArray as PRNGKeyArray, 77 | Real as Real, 78 | Scalar as Scalar, 79 | ScalarLike as ScalarLike, 80 | Shaped as Shaped, 81 | UInt as UInt, 82 | UInt2 as UInt2, 83 | UInt4 as UInt4, 84 | UInt8 as UInt8, 85 | UInt16 as UInt16, 86 | UInt32 as UInt32, 87 | UInt64 as UInt64, 88 | ) 89 | 90 | # Set up to deliberately confuse a static type checker. 91 | PyTree: TypeAlias = getattr(typing, "foo" + "bar") 92 | # What's going on with this madness? 93 | # 94 | # At static-type-checking-time, we want `PyTree` to be a type for which both 95 | # `PyTree` and `PyTree[Foo]` are equivalent to `Any`. 96 | # (The intention is that `PyTree` be a runtime-only type; there's no real way to 97 | # do more with static type checkers.) 98 | # 99 | # Unfortunately, this isn't possible: `Any` isn't subscriptable. And there's no 100 | # equivalent way we can fake this using typing annotations. (In some sense the 101 | # closest thing would be a `Protocol[T]` with no methods, but that's actually the 102 | # opposite of what we want: that ends up allowing nothing at all.) 103 | # 104 | # The good news for us is that static type checkers have an internal escape hatch. 105 | # If they can't figure out what a type is, then they just give up and allow 106 | # anything. (I believe this is sometimes called `Unknown`.) 
Thus, this odd-looking 107 | # annotation, which static type checkers aren't smart enough to resolve. 108 | else: 109 | from ._array_types import ( 110 | BFloat16 as BFloat16, 111 | Bool as Bool, 112 | Complex as Complex, 113 | Complex64 as Complex64, 114 | Complex128 as Complex128, 115 | Float as Float, 116 | Float8e4m3b11fnuz as Float8e4m3b11fnuz, 117 | Float8e4m3fn as Float8e4m3fn, 118 | Float8e4m3fnuz as Float8e4m3fnuz, 119 | Float8e5m2 as Float8e5m2, 120 | Float8e5m2fnuz as Float8e5m2fnuz, 121 | Float16 as Float16, 122 | Float32 as Float32, 123 | Float64 as Float64, 124 | Inexact as Inexact, 125 | Int as Int, 126 | Int2 as Int2, 127 | Int4 as Int4, 128 | Int8 as Int8, 129 | Int16 as Int16, 130 | Int32 as Int32, 131 | Int64 as Int64, 132 | Integer as Integer, 133 | Key as Key, 134 | Num as Num, 135 | Real as Real, 136 | Shaped as Shaped, 137 | UInt as UInt, 138 | UInt2 as UInt2, 139 | UInt4 as UInt4, 140 | UInt8 as UInt8, 141 | UInt16 as UInt16, 142 | UInt32 as UInt32, 143 | UInt64 as UInt64, 144 | ) 145 | 146 | if hasattr(typing, "GENERATING_DOCUMENTATION"): 147 | 148 | class Array: 149 | pass 150 | 151 | Array.__module__ = "builtins" 152 | Array.__qualname__ = "Array" 153 | 154 | class ArrayLike: 155 | pass 156 | 157 | ArrayLike.__module__ = "builtins" 158 | ArrayLike.__qualname__ = "ArrayLike" 159 | 160 | class PRNGKeyArray: 161 | pass 162 | 163 | PRNGKeyArray.__module__ = "builtins" 164 | PRNGKeyArray.__qualname__ = "PRNGKeyArray" 165 | 166 | from ._pytree_type import PyTree as PyTree 167 | 168 | class PyTreeDef: 169 | """Alias for `jax.tree_util.PyTreeDef`, which is the type of the 170 | return from `jax.tree_util.tree_structure(...)`. 171 | """ 172 | 173 | if typing.GENERATING_DOCUMENTATION != "jaxtyping": 174 | # Equinox etc. docs get just `PyTreeDef`. 175 | # jaxtyping docs get `jaxtyping.PyTreeDef`. 
176 | PyTreeDef.__qualname__ = "PyTreeDef" 177 | PyTreeDef.__module__ = "builtins" 178 | 179 | @ft.cache 180 | def __getattr__(item): 181 | if item == "Array": 182 | import jax 183 | 184 | return jax.Array 185 | elif item == "ArrayLike": 186 | import jax.typing 187 | 188 | return jax.typing.ArrayLike 189 | elif item == "PRNGKeyArray": 190 | # New-style `jax.random.key` have scalar shape and dtype `key`. 191 | # Old-style `jax.random.PRNGKey` have shape `(2,)` and dtype 192 | # `uint32`. 193 | import jax 194 | 195 | return Union[Key[jax.Array, ""], UInt32[jax.Array, "2"]] 196 | elif item == "DTypeLike": 197 | import jax.typing 198 | 199 | return jax.typing.DTypeLike 200 | elif item == "Scalar": 201 | import jax 202 | 203 | return Shaped[jax.Array, ""] 204 | elif item == "ScalarLike": 205 | from . import ArrayLike 206 | 207 | return Shaped[ArrayLike, ""] 208 | elif item == "PyTree": 209 | from ._pytree_type import PyTree 210 | 211 | return PyTree 212 | elif item == "PyTreeDef": 213 | import jax.tree_util 214 | 215 | return jax.tree_util.PyTreeDef 216 | else: 217 | raise AttributeError(f"module jaxtyping has no attribute {item!r}") 218 | 219 | 220 | __version__ = importlib.metadata.version("jaxtyping") 221 | -------------------------------------------------------------------------------- /jaxtyping/_array_types.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Google LLC 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall 
be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | import copyreg 21 | import enum 22 | import functools as ft 23 | import importlib.util 24 | import re 25 | import sys 26 | import types 27 | import typing 28 | from dataclasses import dataclass 29 | from typing import ( 30 | Any, 31 | get_args, 32 | get_origin, 33 | Literal, 34 | NoReturn, 35 | Optional, 36 | TypeVar, 37 | Union, 38 | ) 39 | 40 | 41 | # Bit of a hack, but jaxtyping provides nicer error messages than typeguard. This means 42 | # we sometimes want to use it as our runtime type checker everywhere, even in non-array 43 | # use-cases, for which numpy is too heavy a dependency. 44 | # Honestly we should probably consider factoring out part of jaxtyping into a separate 45 | # package. (Specifically (a) the multi-argument checking and (b) the better error 46 | # messages and (c) the import hook that places the checker on the bottom of the 47 | # decorator stack.) And resist the urge to write our own runtime type-checker, I really 48 | # don't want to have to keep that up-to-date with changes in the Python typing spec... 
49 | if importlib.util.find_spec("numpy") is not None: 50 | import numpy as np 51 | 52 | from ._errors import AnnotationError 53 | from ._storage import ( 54 | get_shape_memo, 55 | get_treeflatten_memo, 56 | get_treepath_memo, 57 | set_shape_memo, 58 | ) 59 | 60 | 61 | _array_name_format = "dtype_and_shape" 62 | 63 | 64 | def get_array_name_format(): 65 | return _array_name_format 66 | 67 | 68 | def set_array_name_format(value): 69 | global _array_name_format 70 | _array_name_format = value 71 | 72 | 73 | _any_dtype = object() 74 | 75 | _anonymous_dim = object() 76 | _anonymous_variadic_dim = object() 77 | 78 | 79 | class _DimType(enum.Enum): 80 | named = enum.auto() 81 | fixed = enum.auto() 82 | symbolic = enum.auto() 83 | 84 | 85 | @dataclass(frozen=True) 86 | class _NamedDim: 87 | name: str 88 | broadcastable: bool 89 | treepath: Any 90 | 91 | 92 | @dataclass(frozen=True) 93 | class _NamedVariadicDim: 94 | name: str 95 | broadcastable: bool 96 | treepath: Any 97 | 98 | 99 | @dataclass(frozen=True) 100 | class _FixedDim: 101 | size: int 102 | broadcastable: bool 103 | 104 | 105 | @dataclass(frozen=True) 106 | class _SymbolicDim: 107 | elem: Any 108 | broadcastable: bool 109 | 110 | 111 | _AbstractDimOrVariadicDim = Union[ 112 | Literal[_anonymous_dim], 113 | Literal[_anonymous_variadic_dim], 114 | _NamedDim, 115 | _NamedVariadicDim, 116 | _FixedDim, 117 | _SymbolicDim, 118 | ] 119 | _AbstractDim = Union[Literal[_anonymous_dim], _NamedDim, _FixedDim, _SymbolicDim] 120 | 121 | 122 | def _check_dims( 123 | cls_dims: list[_AbstractDim], 124 | obj_shape: tuple[int, ...], 125 | single_memo: dict[str, int], 126 | arg_memo: dict[str, Any], 127 | ) -> str: 128 | assert len(cls_dims) == len(obj_shape) 129 | for cls_dim, obj_size in zip(cls_dims, obj_shape): 130 | if cls_dim is _anonymous_dim: 131 | pass 132 | elif cls_dim.broadcastable and obj_size == 1: 133 | pass 134 | elif type(cls_dim) is _FixedDim: 135 | if cls_dim.size != obj_size: 136 | return f"the dimension size 
{obj_size} does not equal {cls_dim.size} as expected by the type hint" # noqa: E501 137 | elif type(cls_dim) is _SymbolicDim: 138 | try: 139 | # Support f-string syntax. 140 | # https://stackoverflow.com/a/53671539/22545467 141 | elem = eval(f"f'{cls_dim.elem}'", arg_memo.copy()) 142 | # Make a copy to avoid `__builtins__` getting added as a key. 143 | eval_size = eval(elem, single_memo.copy()) 144 | except NameError as e: 145 | raise AnnotationError( 146 | f"Cannot process symbolic axis '{cls_dim.elem}' as " 147 | "some axis names have not been processed. " 148 | "Have you applied the `jaxtyped` decorator? " 149 | "In practice you should usually only use symbolic axes in " 150 | "annotations for return types, referring only to axes " 151 | "annotated for arguments." 152 | ) from e 153 | if eval_size != obj_size: 154 | return f"the dimension size {obj_size} does not equal the existing value of {cls_dim.elem}={eval_size}" # noqa: E501 155 | else: 156 | assert type(cls_dim) is _NamedDim 157 | if cls_dim.treepath: 158 | name = get_treepath_memo() + cls_dim.name 159 | else: 160 | name = cls_dim.name 161 | try: 162 | cls_size = single_memo[name] 163 | except KeyError: 164 | single_memo[name] = obj_size 165 | else: 166 | if cls_size != obj_size: 167 | return f"the size of dimension {cls_dim.name} is {obj_size} which does not equal the existing value of {cls_size}" # noqa: E501 168 | return "" 169 | 170 | 171 | def _dtype_is_numpy_struct_array(dtype): 172 | return dtype.type.__name__ == "void" and dtype is not np.dtype(np.void) 173 | 174 | 175 | class _MetaAbstractArray(type): 176 | _skip_instancecheck: bool = False 177 | 178 | def make_transparent(cls): 179 | cls._skip_instancecheck = True 180 | 181 | def __instancecheck__(cls, obj: Any) -> bool: 182 | return cls.__instancecheck_str__(obj) == "" 183 | 184 | def __instancecheck_str__(cls, obj: Any) -> str: 185 | if cls._skip_instancecheck: 186 | return "" 187 | if cls.array_type is Any: 188 | if not (hasattr(obj, "shape") 
and hasattr(obj, "dtype")): 189 | return "this value does not have both `shape` and `dtype` attributes." 190 | else: 191 | if not isinstance(obj, cls.array_type): 192 | return f"this value is not an instance of the underlying array type {cls.array_type}" # noqa: E501 193 | if get_treeflatten_memo(): 194 | return "" 195 | 196 | if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"): 197 | # JAX, numpy 198 | dtype = obj.dtype.type.__name__ 199 | # numpy structured array is strictly a subtype of np.void 200 | if _dtype_is_numpy_struct_array(obj.dtype): 201 | dtype = str(obj.dtype) 202 | elif hasattr(obj.dtype, "as_numpy_dtype"): 203 | # TensorFlow 204 | dtype = obj.dtype.as_numpy_dtype.__name__ 205 | else: 206 | # Everyone else, including PyTorch. 207 | # This offers an escape hatch for anyone looking to use jaxtyping for their 208 | # own array-like types. 209 | dtype = obj.dtype 210 | if not isinstance(dtype, str): 211 | *_, dtype = repr(obj.dtype).rsplit(".", 1) 212 | 213 | if cls.dtypes is not _any_dtype: 214 | in_dtypes = False 215 | for cls_dtype in cls.dtypes: 216 | if type(cls_dtype) is str: 217 | in_dtypes = dtype == cls_dtype 218 | elif type(cls_dtype) is re.Pattern: 219 | in_dtypes = bool(cls_dtype.match(dtype)) 220 | else: 221 | assert False 222 | if in_dtypes: 223 | break 224 | if not in_dtypes: 225 | if len(cls.dtypes) == 1: 226 | return f"this array has dtype {dtype}, not {cls.dtypes[0]} as expected by the type hint" # noqa: E501 227 | else: 228 | return f"this array has dtype {dtype}, not any of {cls.dtypes} as expected by the type hint" # noqa: E501 229 | 230 | single_memo, variadic_memo, pytree_memo, arg_memo = get_shape_memo() 231 | single_memo_bak = single_memo.copy() 232 | variadic_memo_bak = variadic_memo.copy() 233 | pytree_memo_bak = pytree_memo.copy() 234 | arg_memo_bak = arg_memo.copy() 235 | try: 236 | check = cls._check_shape(obj, single_memo, variadic_memo, arg_memo) 237 | except Exception: 238 | set_shape_memo( 239 | 
single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak 240 | ) 241 | raise 242 | if check == "": 243 | return check 244 | else: 245 | set_shape_memo( 246 | single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak 247 | ) 248 | return check 249 | 250 | def _check_shape( 251 | cls, 252 | obj, 253 | single_memo: dict[str, int], 254 | variadic_memo: dict[str, tuple[bool, tuple[int, ...]]], 255 | arg_memo: dict[str, Any], 256 | ) -> str: 257 | if cls.index_variadic is None: 258 | if len(obj.shape) != len(cls.dims): 259 | return f"this array has {len(obj.shape)} dimensions, not the {len(cls.dims)} expected by the type hint" # noqa: E501 260 | return _check_dims(cls.dims, obj.shape, single_memo, arg_memo) 261 | else: 262 | if len(obj.shape) < len(cls.dims) - 1: 263 | return f"this array has {len(obj.shape)} dimensions, which is fewer than {len(cls.dims) - 1} that is the minimum expected by the type hint" # noqa: E501 264 | i = cls.index_variadic 265 | j = -(len(cls.dims) - i - 1) 266 | if j == 0: 267 | j = None 268 | prefix_check = _check_dims( 269 | cls.dims[:i], obj.shape[:i], single_memo, arg_memo 270 | ) 271 | if prefix_check != "": 272 | return prefix_check 273 | if j is not None: 274 | suffix_check = _check_dims( 275 | cls.dims[j:], obj.shape[j:], single_memo, arg_memo 276 | ) 277 | if suffix_check != "": 278 | return suffix_check 279 | variadic_dim = cls.dims[i] 280 | if variadic_dim is _anonymous_variadic_dim: 281 | return "" 282 | else: 283 | assert type(variadic_dim) is _NamedVariadicDim 284 | if variadic_dim.treepath: 285 | name = get_treepath_memo() + variadic_dim.name 286 | else: 287 | name = variadic_dim.name 288 | broadcastable = variadic_dim.broadcastable 289 | try: 290 | prev_broadcastable, prev_shape = variadic_memo[name] 291 | except KeyError: 292 | variadic_memo[name] = (broadcastable, obj.shape[i:j]) 293 | return "" 294 | else: 295 | new_shape = obj.shape[i:j] 296 | if prev_broadcastable: 297 | try: 298 | broadcast_shape = 
np.broadcast_shapes(new_shape, prev_shape) 299 | except ValueError: # not broadcastable e.g. (3, 4) and (5,) 300 | return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast with the existing value of {prev_shape}" # noqa: E501 301 | if not broadcastable and broadcast_shape != new_shape: 302 | return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which the existing value of {prev_shape} cannot be broadcast to" # noqa: E501 303 | variadic_memo[name] = (broadcastable, broadcast_shape) 304 | else: 305 | if broadcastable: 306 | try: 307 | broadcast_shape = np.broadcast_shapes( 308 | new_shape, prev_shape 309 | ) 310 | except ValueError: # not broadcastable e.g. (3, 4) and (5,) 311 | return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast with the existing value of {prev_shape}" # noqa: E501 312 | if broadcast_shape != prev_shape: 313 | return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast to the existing value of {prev_shape}" # noqa: E501 314 | else: 315 | if new_shape != prev_shape: 316 | return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which does not equal the existing value of {prev_shape}" # noqa: E501 317 | return "" 318 | assert False 319 | 320 | 321 | def _return_abstractarray(): 322 | return AbstractArray 323 | 324 | 325 | def _pickle_array_annotation(x: type["AbstractArray"]): 326 | if x is AbstractArray: 327 | return _return_abstractarray, () 328 | else: 329 | return x.dtype.__getitem__, ((x.array_type, x.dim_str),) 330 | 331 | 332 | copyreg.pickle(_MetaAbstractArray, _pickle_array_annotation) 333 | 334 | 335 | def _check_scalar(dtype, dtypes, dims): 336 | for dim in dims: 337 | if dim is not _anonymous_variadic_dim and not isinstance( 338 | dim, _NamedVariadicDim 339 | ): 340 | return False 341 | return (_any_dtype is dtypes) or 
@ft.lru_cache(maxsize=None)
def _make_array_cached(array_type, dim_str, dtypes, name):
    """Parse a shape string and work out what an annotation should become.

    Results are memoised by `functools.lru_cache`, so every argument must be
    hashable.

    **Arguments:**

    - `array_type`: the array class being annotated; a generic alias is
        replaced by its origin. May itself be an `AbstractArray` subclass, in
        which case the new dimensions are prepended to the existing ones.
    - `dim_str`: the whitespace-separated axis specification.
    - `dtypes`: the permitted dtypes, or the `_any_dtype` sentinel.
    - `name`: the dtype class name used when building the display name.

    **Returns:**

    One of three things:

    - `array_type` itself, when it is `bool`/`int`/`float`/`complex`/
        `np.bool_`/`np.generic`/`np.number` and `_check_scalar` accepts it;
    - `_not_made`, when it is one of those scalar types and `_check_scalar`
        rejects it;
    - otherwise the tuple
        `(array_type, name, dtypes, dims, index_variadic, dim_str)`.

    **Raises:**

    `ValueError` for a malformed shape string, for a nested annotation with no
    overlapping dtypes or with two variadic specifiers, and for an
    unrecognised `_array_name_format`.
    """
    if not isinstance(dim_str, str):
        raise ValueError(
            "Shape specification must be a string. Axes should be separated with "
            "spaces."
        )
    dims = []
    index_variadic = None
    for index, elem in enumerate(dim_str.split()):
        if "," in elem and "(" not in elem:
            # Common mistake.
            # Disable in the case that there's brackets to allow for function calls,
            # e.g. `min(foo,bar)`, in symbolic axes.
            raise ValueError("Axes should be separated with spaces, not commas")
        if elem.endswith("#"):
            raise ValueError(
                "As of jaxtyping v0.1.0, broadcastable axes are now denoted "
                "with a # at the start, rather than at the end"
            )

        if "..." in elem:
            if elem != "...":
                raise ValueError(
                    "Anonymous multiple axes '...' must be used on its own; "
                    f"got {elem}"
                )
            # `...` is handled as an anonymous, variadic, "named" axis so that
            # it falls into the `_anonymous_variadic_dim` branch below.
            broadcastable = False
            variadic = True
            anonymous = True
            treepath = False
            dim_type = _DimType.named
        else:
            broadcastable = False
            variadic = False
            anonymous = False
            treepath = False
            # Strip leading modifier characters one at a time, in any order,
            # rejecting a repeated modifier.
            while True:
                if len(elem) == 0:
                    # This branch needed as just `_` is valid
                    break
                first_char = elem[0]
                if first_char == "#":
                    if broadcastable:
                        raise ValueError(
                            "Do not use # twice to denote broadcastability, e.g. "
                            "`##foo` is not allowed"
                        )
                    broadcastable = True
                    elem = elem[1:]
                elif first_char == "*":
                    if variadic:
                        raise ValueError(
                            "Do not use * twice to denote accepting multiple "
                            "axes, e.g. `**foo` is not allowed"
                        )
                    variadic = True
                    elem = elem[1:]
                elif first_char == "_":
                    if anonymous:
                        raise ValueError(
                            "Do not use _ twice to denote anonymity, e.g. `__foo` "
                            "is not allowed"
                        )
                    anonymous = True
                    elem = elem[1:]
                elif first_char == "?":
                    if treepath:
                        raise ValueError(
                            "Do not use ? twice to denote dependence on location "
                            "within a PyTree, e.g. `??foo` is not allowed"
                        )
                    treepath = True
                    elem = elem[1:]
                # Allow e.g. `foo=4` as an alternate syntax for just `4`, so that one
                # can write e.g. `Float[Array, "rows=3 cols=4"]`
                elif elem.count("=") == 1:
                    _, elem = elem.split("=")
                else:
                    break
            # Classify what is left: empty or identifier -> named; parses as
            # an int -> fixed; anything else -> symbolic expression.
            if len(elem) == 0 or elem.isidentifier():
                dim_type = _DimType.named
            else:
                try:
                    elem = int(elem)
                except ValueError:
                    dim_type = _DimType.symbolic
                else:
                    dim_type = _DimType.fixed

        if variadic:
            if index_variadic is not None:
                raise ValueError(
                    "Cannot use variadic specifiers (`*name` or `...`) "
                    "more than once."
                )
            index_variadic = index

        # Turn the parsed flags into a dimension object, rejecting modifier
        # combinations that make no sense for this kind of axis.
        if dim_type is _DimType.fixed:
            if variadic:
                raise ValueError(
                    "Cannot have a fixed axis bind to multiple axes, e.g. "
                    "`*4` is not allowed."
                )
            if anonymous:
                raise ValueError(
                    "Cannot have a fixed axis be anonymous, e.g. `_4` is not allowed."
                )
            if treepath:
                raise ValueError(
                    "Cannot have a fixed axis have tree-path dependence, e.g. `?4` is "
                    "not allowed."
                )
            elem = _FixedDim(elem, broadcastable)
        elif dim_type is _DimType.named:
            if anonymous:
                if broadcastable:
                    raise ValueError(
                        "Cannot have an axis be both anonymous and "
                        "broadcastable, e.g. `#_` is not allowed."
                    )
                if variadic:
                    elem = _anonymous_variadic_dim
                else:
                    elem = _anonymous_dim
            else:
                if variadic:
                    elem = _NamedVariadicDim(elem, broadcastable, treepath)
                else:
                    elem = _NamedDim(elem, broadcastable, treepath)
        else:
            assert dim_type is _DimType.symbolic
            if anonymous:
                raise ValueError(
                    "Cannot have a symbolic axis be anonymous, e.g. "
                    "`_foo+bar` is not allowed"
                )
            if variadic:
                raise ValueError(
                    "Cannot have symbolic multiple-axes, e.g. "
                    "`*foo+bar` is not allowed"
                )
            if treepath:
                raise ValueError(
                    "Cannot have a symbolic axis with tree-path dependence, e.g. "
                    "`?foo+bar` is not allowed"
                )
            elem = _SymbolicDim(elem, broadcastable)
        dims.append(elem)
    dims = tuple(dims)

    # Allow Python built-in numeric types.
    # TODO: do something more generic than this? Should we _make all types
    # that have `shape` and `dtype` attributes or something?
    array_origin = get_origin(array_type)
    if array_origin is not None:
        array_type = array_origin
    # Scalar types short-circuit: either the plain type is returned as the
    # annotation, or `_not_made` signals that this combination is invalid.
    if array_type is bool:
        if _check_scalar("bool", dtypes, dims):
            return array_type
        else:
            return _not_made
    elif array_type is int:
        if _check_scalar("int", dtypes, dims):
            return array_type
        else:
            return _not_made
    elif array_type is float:
        if _check_scalar("float", dtypes, dims):
            return array_type
        else:
            return _not_made
    elif array_type is complex:
        if _check_scalar("complex", dtypes, dims):
            return array_type
        else:
            return _not_made
    elif array_type is np.bool_:
        if _check_scalar("bool", dtypes, dims):
            return array_type
        else:
            return _not_made
    elif array_type is np.generic or array_type is np.number:
        if _check_scalar("", dtypes, dims):
            return array_type
        else:
            return _not_made
    # Nested annotation: merge the outer spec with the inner one.
    if array_type is not Any and issubclass(array_type, AbstractArray):
        if dtypes is _any_dtype:
            dtypes = array_type.dtypes
        elif array_type.dtypes is not _any_dtype:
            dtypes = tuple(x for x in dtypes if x in array_type.dtypes)
            if len(dtypes) == 0:
                raise ValueError(
                    "A jaxtyping annotation cannot be extended with no overlapping "
                    "dtypes. For example, `Bool[Float[Array, 'dim1'], 'dim2']` is an "
                    "error. You probably want to make the outer wrapper be `Shaped`."
                )
        if array_type.index_variadic is not None:
            if index_variadic is None:
                # The inner dims are appended after the new ones, so the inner
                # variadic index shifts by the number of new dims.
                index_variadic = array_type.index_variadic + len(dims)
            else:
                raise ValueError(
                    "Cannot use variadic specifiers (`*name` or `...`) "
                    "in both the original array and the extended array"
                )
        dims = dims + array_type.dims
        dim_str = dim_str + " " + array_type.dim_str
        array_type = array_type.array_type
    try:
        type_str = array_type.__name__
    except AttributeError:
        type_str = repr(array_type)
    if _array_name_format == "dtype_and_shape":
        name = f"{name}[{type_str}, '{dim_str}']"
    elif _array_name_format == "array":
        name = type_str
    else:
        raise ValueError(f"array_name_format {_array_name_format} not recognised")

    return (array_type, name, dtypes, dims, index_variadic, dim_str)
631 | ) 632 | 633 | def __getitem__(cls, item: tuple[Any, str]): 634 | if not isinstance(item, tuple) or len(item) != 2: 635 | raise ValueError( 636 | "As of jaxtyping v0.2.0, type annotations must now include both an " 637 | "array type and a shape. For example `Float[Array, 'foo bar']`.\n" 638 | "Ellipsis can be used to accept any shape: `Float[Array, '...']`." 639 | ) 640 | array_type, dim_str = item 641 | dim_str = dim_str.strip() 642 | if isinstance(array_type, TypeVar): 643 | bound = array_type.__bound__ 644 | if bound is None: 645 | constraints = array_type.__constraints__ 646 | if constraints == (): 647 | array_type = Any 648 | else: 649 | array_type = Union[constraints] 650 | else: 651 | array_type = bound 652 | del item 653 | if get_origin(array_type) in _union_types: 654 | out = [_make_array(x, dim_str, cls) for x in get_args(array_type)] 655 | out = tuple(x for x in out if x is not _not_made) 656 | if len(out) == 0: 657 | raise ValueError("Invalid jaxtyping type annotation.") 658 | elif len(out) == 1: 659 | (out,) = out 660 | else: 661 | out = Union[out] 662 | else: 663 | out = _make_array(array_type, dim_str, cls) 664 | if out is _not_made: 665 | raise ValueError("Invalid jaxtyping type annotation.") 666 | return out 667 | 668 | 669 | class AbstractDtype(metaclass=_MetaAbstractDtype): 670 | """This is the base class of all dtypes. This can be used to create your own custom 671 | collection of dtypes (analogous to `Float`, `Inexact` etc.) 672 | 673 | You must specify the class attribute `dtypes`. This can either be a string, a 674 | regex (as returned by `re.compile(...)`), or a tuple/list of strings/regexes. 675 | 676 | At runtime, the array or tensor's dtype is converted to a string and compared 677 | against the string (an exact match is required) or regex. (String matching is 678 | performed, rather than just e.g. `array.dtype == dtype`, to provide cross-library 679 | compatibility between JAX/PyTorch/etc.) 680 | 681 | !!! 
def _make_dtype(_dtypes, name):
    """Create a fresh `AbstractDtype` subclass called `name`.

    The new class gets `_dtypes` as its `dtypes` attribute. Its `__module__`
    is reported as `"jaxtyping"` unless `typing.GENERATING_DOCUMENTATION` is
    set to something other than `""` or `"jaxtyping"`, in which case it is
    reported as `"builtins"`.
    """
    doc_target = getattr(typing, "GENERATING_DOCUMENTATION", "")
    reported_module = "jaxtyping" if doc_target in {"", "jaxtyping"} else "builtins"

    class _Cls(AbstractDtype):
        dtypes = _dtypes

    _Cls.__name__ = _Cls.__qualname__ = name
    _Cls.__module__ = reported_module
    return _Cls
See the diagram at 794 | # https://numpy.org/doc/stable/reference/arrays.scalars.html#scalars 795 | 796 | Bool = _make_dtype(bools, "Bool") 797 | UInt = _make_dtype(uints, "UInt") 798 | Int = _make_dtype(ints, "Int") 799 | Integer = _make_dtype(uints + ints, "Integer") 800 | Float = _make_dtype(floats, "Float") 801 | Complex = _make_dtype(complexes, "Complex") 802 | Inexact = _make_dtype(floats + complexes, "Inexact") 803 | Real = _make_dtype(floats + uints + ints, "Real") 804 | Num = _make_dtype(uints + ints + floats + complexes, "Num") 805 | 806 | Shaped = _make_dtype(_any_dtype, "Shaped") 807 | 808 | Key = _make_dtype(_prng_key, "Key") 809 | 810 | 811 | def make_numpy_struct_dtype(dtype: "np.dtype", name: str): 812 | """Creates a type annotation for [numpy structured array](https://numpy.org/doc/stable/user/basics.rec.html#structured-arrays) 813 | It performs an exact match on the name, order, and dtype of all its fields. 814 | 815 | !!! Example 816 | 817 | ```python 818 | label_t = np.dtype([('first', np.uint8), ('second', np.int8)]) 819 | Label = make_numpy_struct_dtype(label_t, 'Label') 820 | ``` 821 | 822 | after that, you can use it just like any other [`jaxtyping.AbstractDtype`][]: 823 | 824 | ```python 825 | a: Label[np.ndarray, 'a b'] = np.array([[(1, 0), (0, 1)]], dtype=label_t) 826 | ``` 827 | 828 | **Arguments:** 829 | 830 | - `dtype`: The numpy structured dtype to use. 831 | - `name`: The name to use for the returned Python class. 832 | 833 | **Returns:** 834 | 835 | A type annotation with classname `name` that matches exactly `dtype` when used like 836 | any other [`jaxtyping.AbstractDtype`][]. 
837 | """ 838 | if not (isinstance(dtype, np.dtype) and _dtype_is_numpy_struct_array(dtype)): 839 | raise ValueError(f"Expecting a numpy structured array dtype, not {dtype}") 840 | return _make_dtype(str(dtype), name) 841 | -------------------------------------------------------------------------------- /jaxtyping/_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union 3 | 4 | 5 | def _maybestr2bool(value: Union[bool, str], error: str) -> bool: 6 | if isinstance(value, bool): 7 | return value 8 | elif isinstance(value, str): 9 | if value.lower() in ("0", "false"): 10 | return False 11 | elif value.lower() in ("1", "true"): 12 | return True 13 | else: 14 | raise ValueError(error) 15 | else: 16 | raise ValueError(error) 17 | 18 | 19 | class _JaxtypingConfig: 20 | def __init__(self): 21 | self.update("jaxtyping_disable", os.environ.get("JAXTYPING_DISABLE", "0")) 22 | self.update( 23 | "jaxtyping_remove_typechecker_stack", 24 | os.environ.get("JAXTYPING_REMOVE_TYPECHECKER_STACK", "0"), 25 | ) 26 | 27 | def update(self, item: str, value): 28 | if item.lower() == "jaxtyping_disable": 29 | msg = ( 30 | "Unrecognised value for `JAXTYPING_DISABLE`. Valid values are " 31 | "`JAXTYPING_DISABLE=0` (the default) or `JAXTYPING_DISABLE=1` (to " 32 | "disable runtime type checking)." 33 | ) 34 | self.jaxtyping_disable = _maybestr2bool(value, msg) 35 | elif item.lower() == "jaxtyping_remove_typechecker_stack": 36 | msg = ( 37 | "Unrecognised value for `JAXTYPING_REMOVE_TYPECHECKER_STACK`. Valid " 38 | "values are `JAXTYPING_REMOVE_TYPECHECKER_STACK=0` (the default) or " 39 | "`JAXTYPING_REMOVE_TYPECHECKER_STACK=1` (to remove the stack frames " 40 | "from the typechecker in `jaxtyped(typechecker=...)`, when it raises a " 41 | "runtime type-checking error)." 
42 | ) 43 | self.jaxtyping_remove_typechecker_stack = _maybestr2bool(value, msg) 44 | else: 45 | raise ValueError(f"Unrecognised config value {item}") 46 | 47 | 48 | config = _JaxtypingConfig() 49 | -------------------------------------------------------------------------------- /jaxtyping/_errors.py: -------------------------------------------------------------------------------- 1 | class TypeCheckError(TypeError): 2 | pass 3 | 4 | 5 | # Not inheriting from TypeError as that gets caught and re-reraised as just a TypeError 6 | # when using typeguard<3. 7 | class AnnotationError(Exception): 8 | pass 9 | 10 | 11 | TypeCheckError.__module__ = "jaxtyping" 12 | AnnotationError.__module__ = "jaxtyping" 13 | -------------------------------------------------------------------------------- /jaxtyping/_import_hook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Google LLC 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | # This source code is adapted from typeguard: 21 | # https://github.com/agronholm/typeguard/blob/0dd7f7510b7c694e66a0d17d1d58d185125bad5d/src/typeguard/importhook.py 22 | # 23 | # Copied and adapted in compliance with the terms of typeguard's MIT license. 24 | # The original license is reproduced here. 25 | # 26 | # --------- 27 | # 28 | # This is the MIT license: http://www.opensource.org/licenses/mit-license.php 29 | # 30 | # Copyright (c) Alex Grönholm 31 | # 32 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this 33 | # software and associated documentation files (the "Software"), to deal in the Software 34 | # without restriction, including without limitation the rights to use, copy, modify, 35 | # merge, publish, distribute, sublicense, and/or sell copies of the Software, and to 36 | # permit persons to whom the Software is furnished to do so, subject to the following 37 | # conditions: 38 | # 39 | # The above copyright notice and this permission notice shall be included in all copies 40 | # or substantial portions of the Software. 41 | # 42 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 43 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 44 | # PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 45 | # HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF 46 | # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE 47 | # OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
48 | # 49 | # --------- 50 | 51 | 52 | import ast 53 | import functools as ft 54 | import hashlib 55 | import sys 56 | from collections.abc import Sequence 57 | from importlib.abc import MetaPathFinder 58 | from importlib.machinery import SourceFileLoader 59 | from importlib.util import cache_from_source, decode_source 60 | from inspect import isclass 61 | from typing import Optional, Union 62 | from unittest.mock import patch 63 | 64 | 65 | # The name of this function is magical 66 | def _call_with_frames_removed(f, *args, **kwargs): 67 | return f(*args, **kwargs) 68 | 69 | 70 | def _optimized_cache_from_source(typechecker_hash, /, path, debug_override=None): 71 | # Version 2: change the position of the `@jaxtyped` decorator, so need a 72 | # different name to avoid hitting old __pycache__. 73 | # Version 3: now also annotating classes. 74 | # Version 4: I'm honestly not sure, but bumping this fixed some kind of odd error. 75 | # Maybe I changed something with hte classes part way through version 3? 76 | # Version 5: Added support for string-based `typechecker` argument. 77 | # Version 6: optimization tag now depends on `typechecker` argument, so that 78 | # changing the typechecker will hit a different cache. 79 | # Version 7: Using the same md5 hash of the `typechecker` argument 80 | # for importlib and decorator lookup. 81 | # Version 8: Now using new-style `jaxtyped(typechecker=...)` rather than old-style 82 | # double-decorators. 83 | # Version 9: Now reporting the correct source code lines. (Important when used with 84 | # a debugger.) 
class Typechecker:
    """Wraps the `typechecker` argument of the import hook.

    Each instance registers a callable in the class-level `lookup` table
    under a stable string key (`self.hash`). Instrumented modules refer to
    that key, see `get_ast`, and the same key is what `get_hash` returns.
    """

    # Maps hash -> decorator callable. Shared by all instances on purpose: the
    # code produced by `get_ast` looks entries up through the class.
    lookup = {}

    def __init__(self, typechecker):
        """Register `typechecker`.

        **Arguments:**

        - `typechecker`: either a string such as `"beartype.beartype"`, or
            `None` for a no-op decorator.

        **Raises:**

        `TypeError` if `typechecker` is neither a string nor `None`.
        """
        if isinstance(typechecker, str):
            # If the typechecker is a string, then we parse it
            # NOTE(review): the string is spliced into source and exec'd, so
            # it must come from trusted code, never from external input.
            string_to_eval = (
                "def f(x, *args, **kwargs):\n"
                + f"    import {typechecker.split('.', 1)[0]}\n"
                + f"    return {typechecker}(x, *args, **kwargs)"
            )

            # md5 hashing instead of __hash__
            # because __hash__ is different for each Python session.
            # This is only a cache key, so mark it as not used for security;
            # otherwise md5 is refused on FIPS-restricted builds.
            self.hash = hashlib.md5(
                typechecker.encode("utf-8"), usedforsecurity=False
            ).hexdigest()

            namespace = {}
            exec(string_to_eval, {}, namespace)
            Typechecker.lookup[self.hash] = namespace["f"]

        elif typechecker is None:
            # If it is None, ignore it silently (use dummy decorator)
            self.hash = "0"
            Typechecker.lookup[self.hash] = lambda x, *_, **__: x
        else:
            # Passed typechecker is invalid
            raise TypeError(
                "Jaxtyping typechecker has to be either a string or a None."
            )

    def get_hash(self):
        """Return the key under which this typechecker is stored in `lookup`."""
        return self.hash

    def get_ast(self):
        """Return a new AST node for the `jaxtyping.jaxtyped(...)` decorator call."""
        # Note that we compile AST only if we missed importlib cache.
        # No caching on this function! We modify the return type every time, with
        # its appropriate source code location.
        return (
            ast.parse(
                f"@jaxtyping.jaxtyped(typechecker=jaxtyping._import_hook.Typechecker.lookup['{self.hash}'])\n"
                "def _():\n    ..."
            )
            .body[0]
            .decorator_list[0]
        )
imports 144 | for i, child in enumerate(node.body): 145 | if isinstance(child, ast.ImportFrom) and child.module == "__future__": 146 | continue 147 | elif isinstance(child, ast.Expr) and isinstance(child.value, ast.Constant): 148 | continue # module docstring 149 | else: 150 | node.body.insert(i, ast.Import(names=[ast.alias("jaxtyping", None)])) 151 | break 152 | 153 | self._parents.append(node) 154 | self.generic_visit(node) 155 | self._parents.pop() 156 | return node 157 | 158 | def visit_ClassDef(self, node: ast.ClassDef): 159 | # Place at the start of the decorator list, so that `@dataclass` decorators get 160 | # called first. 161 | decorator = self._typechecker.get_ast() 162 | ast.copy_location(decorator, node) 163 | node.decorator_list.insert(0, decorator) 164 | self._parents.append(node) 165 | self.generic_visit(node) 166 | self._parents.pop() 167 | return node 168 | 169 | def visit_FunctionDef(self, node: ast.FunctionDef): 170 | # Originally, we had some code here to explicitly check if the function 171 | # had any annotated arguments or annotated return types, and if not, we 172 | # would skip adding the `@jaxtyped` decorator. 173 | # However, this has been removed because it would ignore functions that 174 | # had type annotations in the body of the function (or 175 | # `assert isinstance(..., SomeType)`). 176 | 177 | decorator = self._typechecker.get_ast() 178 | ast.copy_location(decorator, node) 179 | # Place at the end of the decorator list, because: 180 | # - as otherwise we wrap e.g. `jax.custom_{jvp,vjp}` and lose the ability 181 | # to `defjvp` etc. 182 | # - decorators frequently remove annotations from functions, and we'd like 183 | # to use those annotations. 184 | # - typeguard in particular wants to be at the end of the decorator list, as 185 | # it works by recompling the wrapped function. 
class _JaxtypingLoader(SourceFileLoader):
    """A `SourceFileLoader` that rewrites each module's AST before compiling it.

    The rewrite is done by `JaxtypingTransformer`, configured with the
    `typechecker` given at construction.
    """

    def __init__(self, *args, typechecker: Typechecker, **kwargs):
        self._typechecker = typechecker
        super().__init__(*args, **kwargs)

    def source_to_code(self, data, path, *, _optimize=-1):
        """Parse `data`, instrument the tree, and compile it to a code object."""
        parsed = _call_with_frames_removed(
            compile,
            decode_source(data),
            path,
            "exec",
            ast.PyCF_ONLY_AST,
            dont_inherit=True,
            optimize=_optimize,
        )
        transformer = JaxtypingTransformer(typechecker=self._typechecker)
        # The inserted nodes may lack positions; fill them in before compiling.
        instrumented = ast.fix_missing_locations(transformer.visit(parsed))
        return _call_with_frames_removed(
            compile, instrumented, path, "exec", dont_inherit=True, optimize=_optimize
        )

    def exec_module(self, module):
        """Execute `module` using a bytecode cache path specific to the typechecker."""
        hashed_cache_from_source = ft.partial(
            _optimized_cache_from_source, self._typechecker.get_hash()
        )
        # Use a custom optimization marker - the import lock should make this monkey
        # patch safe
        with patch(
            "importlib._bootstrap_external.cache_from_source",
            hashed_cache_from_source,
        ):
            return super().exec_module(module)
class ImportHookManager:
    """Handle for an installed import hook.

    Works as a context manager, removing `hook` from `sys.meta_path` on exit,
    or can be uninstalled by hand with `uninstall()`.
    """

    def __init__(self, hook: MetaPathFinder):
        self.hook = hook

    def __enter__(self):
        # Return the manager so that `with install_import_hook(...) as hook:`
        # binds something usable rather than `None`.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.uninstall()

    def uninstall(self):
        """Remove the hook from `sys.meta_path`; does nothing if already removed."""
        try:
            sys.meta_path.remove(self.hook)
        except ValueError:
            pass  # already removed


# Deliberately no default for `typechecker` so that folks must opt-in to not having
# a typechecker.
def install_import_hook(modules: Union[str, Sequence[str]], typechecker: Optional[str]):
    """Automatically apply the `@jaxtyped(typechecker=typechecker)` decorator to every
    function and dataclass over a whole codebase.

    !!! Tip "Usage"

        ```python
        from jaxtyping import install_import_hook
        # Plus any one of the following:

        # decorate `@jaxtyped(typechecker=typeguard.typechecked)`
        with install_import_hook("foo", "typeguard.typechecked"):
            import foo          # Any module imported inside this `with` block, whose
            import foo.bar      # name begins with the specified string, will
            import foo.bar.qux  # automatically have both `@jaxtyped` and the specified
                                # typechecker applied to all of their functions and
                                # dataclasses.

        # decorate `@jaxtyped(typechecker=beartype.beartype)`
        with install_import_hook("foo", "beartype.beartype"):
            ...

        # decorate only `@jaxtyped` (if you want that for some reason)
        with install_import_hook("foo", None):
            ...
        ```

        If you don't like using the `with` block, the hook can be used without that:
        ```python
        hook = install_import_hook(...)
        import ...
        hook.uninstall()
        ```

        The import hook can be applied to multiple packages via
        ```python
        install_import_hook(["foo", "bar.baz"], ...)
        ```

    **Arguments:**

    - `modules`: the names of the modules in which to automatically apply `@jaxtyped`.
    - `typechecker`: the module and function of the typechecker you want to use, as a
        string. For example `typechecker="typeguard.typechecked"`, or
        `typechecker="beartype.beartype"`. You may pass `typechecker=None` if you do not
        want to automatically decorate with a typechecker as well.

    **Returns:**

    A context manager that uninstalls the hook on exit, or when you call `.uninstall()`.

    !!! Example "Example: end-user script"

        ```python
        ### entry_point.py
        from jaxtyping import install_import_hook
        with install_import_hook("main", "typeguard.typechecked"):
            import main

        ### main.py
        from jaxtyping import Array, Float32

        def f(x: Float32[Array, "batch channels"]):
            ...
        ```

    !!! Example "Example: writing a library"

        ```python
        ### __init__.py
        from jaxtyping import install_import_hook
        with install_import_hook("my_library_name", "beartype.beartype"):
            from .subpackage import foo  # full name is my_library_name.subpackage so
                                         # will be hook'd
            from .another_subpackage import bar  # full name is my_library_name.another_subpackage
                                                 # so will be hook'd.
        ```

    !!! warning

        If a function already has any decorators on it, then `@jaxtyped` will get added
        at the bottom of the decorator list, e.g.
        ```python
        @some_other_decorator
        @jaxtyped(typechecker=beartype.beartype)
        def foo(...): ...
        ```
        This is to support the common case in which
        `some_other_decorator = jax.custom_jvp` etc.

        If a class already has any decorators in it, then `@jaxtyped` will get added to
        the top of the decorator list, e.g.
        ```python
        @jaxtyped(typechecker=beartype.beartype)
        @some_other_decorator
        class A:
            ...
        ```
        This is to support the common case in which
        `some_other_decorator = dataclasses.dataclass`.
    """  # noqa: E501

    if isinstance(modules, str):
        modules = [modules]

    # Support old less-flexible API.
    if isinstance(typechecker, tuple):
        typechecker = ".".join(typechecker)

    # Find the class named `PathFinder` on `sys.meta_path`; the hook delegates
    # the actual spec lookup to it.
    for finder in sys.meta_path:
        if (
            isclass(finder)
            and finder.__name__ == "PathFinder"
            and hasattr(finder, "find_spec")
        ):
            break
    else:
        raise RuntimeError("Cannot find a PathFinder in sys.meta_path")

    wrapped_typechecker = Typechecker(typechecker)
    hook = _JaxtypingFinder(modules, finder, wrapped_typechecker)
    # Insert at the front so this hook is consulted before the other finders.
    sys.meta_path.insert(0, hook)
    return ImportHookManager(hook)
19 | 20 | # Note that `from typing import Annotated; Bool = Annotated` 21 | # does not work with static type checkers. `Annotated` is a typeform rather 22 | # than a type, meaning it cannot be assigned. 23 | from typing import ( 24 | Annotated as BFloat16, # noqa: F401 25 | Annotated as Bool, # noqa: F401 26 | Annotated as Complex, # noqa: F401 27 | Annotated as Complex64, # noqa: F401 28 | Annotated as Complex128, # noqa: F401 29 | Annotated as Float, # noqa: F401 30 | Annotated as Float8e4m3b11fnuz, # noqa: F401 31 | Annotated as Float8e4m3fn, # noqa: F401 32 | Annotated as Float8e4m3fnuz, # noqa: F401 33 | Annotated as Float8e5m2, # noqa: F401 34 | Annotated as Float8e5m2fnuz, # noqa: F401 35 | Annotated as Float16, # noqa: F401 36 | Annotated as Float32, # noqa: F401 37 | Annotated as Float64, # noqa: F401 38 | Annotated as Inexact, # noqa: F401 39 | Annotated as Int, # noqa: F401 40 | Annotated as Int2, # noqa: F401 41 | Annotated as Int4, # noqa: F401 42 | Annotated as Int8, # noqa: F401 43 | Annotated as Int16, # noqa: F401 44 | Annotated as Int32, # noqa: F401 45 | Annotated as Int64, # noqa: F401 46 | Annotated as Integer, # noqa: F401 47 | Annotated as Key, # noqa: F401 48 | Annotated as Num, # noqa: F401 49 | Annotated as Real, # noqa: F401 50 | Annotated as Shaped, # noqa: F401 51 | Annotated as UInt, # noqa: F401 52 | Annotated as UInt2, # noqa: F401 53 | Annotated as UInt4, # noqa: F401 54 | Annotated as UInt8, # noqa: F401 55 | Annotated as UInt16, # noqa: F401 56 | Annotated as UInt32, # noqa: F401 57 | Annotated as UInt64, # noqa: F401 58 | TYPE_CHECKING, 59 | ) 60 | 61 | 62 | if not TYPE_CHECKING: 63 | assert False 64 | 65 | from jax import ( 66 | Array as PRNGKeyArray, # noqa: F401 67 | Array as Scalar, # noqa: F401 68 | ) 69 | from jax.typing import ArrayLike as ScalarLike # noqa: F401 70 | -------------------------------------------------------------------------------- /jaxtyping/_ipython_extension.py: 
def choose_typechecker_magics():
    """Build and return the `Magics` subclass providing `%jaxtyping.typechecker`.

    The class is created inside this function so that IPython is only imported
    when the magic is actually requested.
    """
    # The import is local to avoid degrading import times when the magic is
    # not needed.
    from IPython.core.magic import line_magic, Magics, magics_class

    @magics_class
    class ChooseTypecheckerMagics(Magics):
        @line_magic("jaxtyping.typechecker")
        def typechecker(self, typechecker):
            # Keep every AST transformer except a previously installed
            # JaxtypingTransformer, so that re-running the magic replaces the
            # typechecker instead of stacking transformers.
            self.shell.ast_transformers = [
                transformer
                for transformer in self.shell.ast_transformers
                if not isinstance(transformer, JaxtypingTransformer)
            ]

            # Then register a transformer for the newly chosen typechecker.
            replacement = JaxtypingTransformer(typechecker=Typechecker(typechecker))
            self.shell.ast_transformers.append(replacement)

    return ChooseTypecheckerMagics


def load_ipython_extension(ipython):
    """Register the `%jaxtyping.typechecker` line magic on the `ipython` shell.

    **Raises:**

    `RuntimeError` (chained to the underlying error) if the magics class cannot
    be defined.
    """
    try:
        magics_cls = choose_typechecker_magics()
    except Exception as e:
        # Very broad exception-handling, as e.g. IPython will sometimes be
        # present but fail to import for mysterious reasons.
        raise RuntimeError("Failed to define jaxtyping.typechecker magic") from e

    ipython.register_magics(magics_cls)
def pytest_addoption(parser):
    """Add the `--jaxtyping-packages` command-line option to pytest."""
    help_text = (
        "comma separated name list of packages and modules to instrument for "
        "type checking with jaxtyping. The last element in the list should be the "
        "type checker to use, e.g. "
        "--jaxtyping-packages=foopackage,barpackage,typeguard.typechecked"
    )
    parser.getgroup("jaxtyping").addoption(
        "--jaxtyping-packages",
        action="store",
        help=help_text,
    )


def pytest_configure(config):
    """Install the jaxtyping import hook as requested by `--jaxtyping-packages`.

    Does nothing when the option is absent or empty. Otherwise the option is split
    on commas; the final entry names the typechecker and every preceding entry
    names a package to instrument.

    **Raises:**

    `RuntimeError` if any requested package is already present in `sys.modules`.
    """
    option = config.getoption("jaxtyping_packages")
    if not option:
        return

    entries = [entry.strip() for entry in option.split(",")]
    packages, typechecker = entries[:-1], entries[-1]

    # A module that has already been imported would never pass through the hook.
    clashes = sorted(name for name in packages if name in sys.modules)
    if clashes:
        template = (
            "jaxtyping cannot check these packages because they "
            "are already imported: {}"
        )
        raise RuntimeError(template.format(", ".join(clashes)))

    install_import_hook(packages, typechecker)
associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
_T = TypeVar("_T")
_S = TypeVar("_S")


# Stand-in generic used only to pretty-print `PyTree[leaftype]`.
class _FakePyTree1(Generic[_T]):
    pass


_FakePyTree1.__name__ = "PyTree"
_FakePyTree1.__qualname__ = "PyTree"
_FakePyTree1.__module__ = "builtins"


# Stand-in generic used only to pretty-print `PyTree[leaftype, "structure"]`.
class _FakePyTree2(Generic[_T, _S]):
    pass


_FakePyTree2.__name__ = "PyTree"
_FakePyTree2.__qualname__ = "PyTree"
_FakePyTree2.__module__ = "builtins"


class _MetaPyTree(type):
    """Metaclass of `PyTree`, providing subscripting and `isinstance` checks."""

    def __call__(self, *args, **kwargs):
        raise RuntimeError("PyTree cannot be instantiated")

    def __instancecheck__(cls, obj):
        """Implement `isinstance(obj, PyTree[...])`.

        The shape memo is snapshotted beforehand and restored if the check fails
        or raises, so an unsuccessful check does not leave bindings behind.
        """
        if not hasattr(cls, "leaftype"):
            return True  # Just `isinstance(x, PyTree)`
        # Handle beartype doing `isinstance(None, hint)` to check if
        # is `instance`able.
        if obj is None:
            return True

        single_memo, variadic_memo, pytree_memo, arg_memo = get_shape_memo()
        single_memo_bak = single_memo.copy()
        variadic_memo_bak = variadic_memo.copy()
        pytree_memo_bak = pytree_memo.copy()
        arg_memo_bak = arg_memo.copy()
        try:
            out = cls._check(obj, pytree_memo)
        except Exception:
            set_shape_memo(
                single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak
            )
            raise
        if out:
            return True
        else:
            set_shape_memo(
                single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak
            )
            return False

    def _check(cls, obj, pytree_memo):
        """Return whether `obj` matches this `PyTree[leaftype, structure]`.

        `pytree_memo` maps structure names to previously bound tree structures; a
        plain identifier structure that has not been seen yet is bound here.

        **Raises:**

        `AnnotationError` if a composite structure refers to a name that has not
        been bound yet.
        """
        if cls.leaftype is Any:
            # With `Any` leaves, flatten all the way down and accept every leaf.

            def is_flatten_leaftype(x):
                return False

            def is_check_leaftype(x):
                return True

        else:
            # We could use `isinstance` here but that would fail for more complicated
            # types, e.g. PyTree[tuple[int]]. So at least internally we make a
            # particular choice of typechecker.
            #
            # Deliberately not using @jaxtyped so that we share the same `memo` as
            # whatever dynamic context we're currently in.
            from ._typeguard import typechecked

            @typechecked
            def accepts_leaftype(x: cls.leaftype):
                pass

            def is_leaftype(x):
                try:
                    accepts_leaftype(x)
                except TypeError:
                    return False
                else:
                    return True

            is_flatten_leaftype = is_check_leaftype = is_leaftype

        set_treeflatten_memo()
        try:
            leaves, structure = jtu.tree_flatten(obj, is_leaf=is_flatten_leaftype)
        finally:
            clear_treeflatten_memo()
        if cls.structure is not None:
            if cls.structure.isidentifier():
                # Single name: bind on first sight, compare afterwards.
                try:
                    prev_structure = pytree_memo[cls.structure]
                except KeyError:
                    pytree_memo[cls.structure] = structure
                else:
                    if prev_structure != structure:
                        return False
            else:
                # Composite structure, optionally with a leading or trailing `...`.
                named_pytree = 0
                pieces = cls.structure.split()
                if pieces[0] == "...":
                    pieces = pieces[1:]
                    prefix = False
                    suffix = True
                elif pieces[-1] == "...":
                    pieces = pieces[:-1]
                    prefix = True
                    suffix = False
                else:
                    prefix = False
                    suffix = False
                for identifier in pieces:
                    try:
                        prev_structure = pytree_memo[identifier]
                    except KeyError as e:
                        raise AnnotationError(
                            f"Cannot process composite structure '{cls.structure}' "
                            f"as the structure name {identifier} has not been seen "
                            "before."
                        ) from e
                    # Not using `PyTreeDef.compose` due to JAX bug #18218.
                    prev_pytree = jtu.tree_unflatten(
                        prev_structure, [0] * prev_structure.num_leaves
                    )
                    named_pytree = jtu.tree_map(lambda _: prev_pytree, named_pytree)
                named_structure = jtu.tree_structure(named_pytree)
                if prefix:
                    dummy_pytree = jtu.tree_unflatten(structure, [0] * len(leaves))
                    dummy_named = jtu.tree_unflatten(
                        named_structure, [0] * named_structure.num_leaves
                    )
                    try:
                        jtu.tree_map(lambda _, __: 0, dummy_named, dummy_pytree)
                    except ValueError:
                        return False
                elif suffix:
                    has_structure = lambda x: jtu.tree_structure(x) == named_structure
                    dummy_pytree = jtu.tree_unflatten(structure, [0] * len(leaves))
                    dummy_leaves = jtu.tree_leaves(dummy_pytree, is_leaf=has_structure)
                    if any(not has_structure(x) for x in dummy_leaves):
                        return False
                else:
                    if structure != named_structure:
                        return False

        try:
            for leaf_index, leaf in enumerate(leaves):
                if cls.structure is not None:
                    set_treepath_memo(leaf_index, cls.structure)
                if not is_check_leaftype(leaf):
                    return False
                clear_treepath_memo()
        finally:
            # Also clears the memo on the early `return False` and on exceptions.
            clear_treepath_memo()
        return True

    # Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
    # the custom __instancecheck__ that we want.
    # We can't add that __instancecheck__ via subclassing, e.g.
    # type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms
    # isn't allowed.
    # Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
    # has __module__ "types", e.g. we get types.PyTree[int].
    @ft.lru_cache(maxsize=None)
    def __getitem__(cls, item):
        """Create the class for `PyTree[leaftype]` or `PyTree[leaftype, "struct"]`.

        Results are cached, so equal subscripts return the same class object.

        **Raises:**

        `ValueError` if `item` is a tuple whose length is not 2, if the structure
        is not a string, if it is empty, or if any piece of it is neither an
        identifier nor a leading/trailing `...`.
        """
        if isinstance(item, tuple):
            if len(item) == 2:
                # Validate the type *before* calling `.strip()`; otherwise a
                # non-string structure fails with an opaque `AttributeError` and
                # the error below can never be reached.
                if not isinstance(item[1], str):
                    raise ValueError(
                        "The structure annotation `struct` in "
                        "`jaxtyping.PyTree[leaftype, struct]` must be be a string, "
                        f"e.g. `jaxtyping.PyTree[leaftype, 'T']`. Got '{item[1]}'."
                    )

                class X(PyTree):
                    leaftype = item[0]
                    structure = item[1].strip()

                pieces = X.structure.split()
                if len(pieces) == 0:
                    raise ValueError(
                        "The string `struct` in `jaxtyping.PyTree[leaftype, struct]` "
                        "cannot be the empty string."
                    )
                for piece_index, piece in enumerate(pieces):
                    # `...` is only permitted as the first or last piece.
                    if (piece_index == 0) or (piece_index == len(pieces) - 1):
                        if piece == "...":
                            continue
                    if not piece.isidentifier():
                        raise ValueError(
                            "The string `struct` in "
                            "`jaxtyping.PyTree[leaftype, struct]` must be be a "
                            "whitespace-separated sequence of identifiers, e.g. "
                            "`jaxtyping.PyTree[leaftype, 'T']` or "
                            "`jaxtyping.PyTree[leaftype, 'foo bar']`.\n"
                            "(Here, 'identifier' is used in the same sense as in "
                            "regular Python, i.e. a valid variable name.)\n"
                            f"Got piece '{piece}' in overall structure '{X.structure}'."
                        )

                # Throwaway class whose name is the quoted structure string, so
                # that it is pretty-printed as e.g. PyTree[int, 'T'].
                class Y:
                    pass

                Y.__module__ = "builtins"
                Y.__name__ = repr(X.structure)
                Y.__qualname__ = repr(X.structure)
                name = wl.pformat(_FakePyTree2[X.leaftype, Y], width=9999)
                del Y
            else:
                raise ValueError(
                    "The subscript `foo` in `jaxtyping.PyTree[foo]` must either be a "
                    "leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and "
                    "structure, e.g. `PyTree[int, 'T']`. Received a tuple of length "
                    f"{len(item)}."
                )
        else:
            name = wl.pformat(_FakePyTree1[item], width=9999)

            class X(PyTree):
                leaftype = item
                structure = None

        X.__name__ = name
        X.__qualname__ = name
        if getattr(typing, "GENERATING_DOCUMENTATION", "") in {"", "jaxtyping"}:
            X.__module__ = "jaxtyping"
        else:
            X.__module__ = "builtins"
        return X

    def __pdoc__(self, **kwargs):
        """Pretty-printing hook: render as `PyTree` or `PyTree[leaftype, structure]`."""
        if self is PyTree:
            return wl.TextDoc("PyTree")
        else:
            indent = kwargs["indent"]
            docs = [wl.pdoc(self.leaftype, **kwargs)]
            if self.structure is not None:
                docs.append(wl.pdoc(self.structure, **kwargs))
            return wl.bracketed(
                begin=wl.TextDoc("PyTree["),
                docs=docs,
                sep=wl.comma,
                end=wl.TextDoc("]"),
                indent=indent,
            )


# Can't do `class PyTree(Generic[_T]): ...` because we need to override the
# instancecheck for PyTree[foo], but subclassing
# `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
PyTree = _MetaPyTree("PyTree", (), {})
if getattr(typing, "GENERATING_DOCUMENTATION", "") in {"", "jaxtyping"}:
    PyTree.__module__ = "jaxtyping"
else:
    PyTree.__module__ = "builtins"
PyTree.__doc__ = """Represents a PyTree.

Annotations of the following sorts are supported:
```python
a: PyTree
b: PyTree[LeafType]
c: PyTree[LeafType, "T"]
d: PyTree[LeafType, "S T"]
e: PyTree[LeafType, "... T"]
f: PyTree[LeafType, "T ..."]
```

These correspond to:

a. A plain `PyTree` can be used as an annotation, in which case `PyTree` is simply a
    suggestively-named alternative to `Any`.
    ([By definition all types are PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html))

b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For
    example, `PyTree[int]` or `PyTree[Union[str, Float32[Array, "b c"]]]`.

c. A structure name can also be passed. In this case
    `jax.tree_util.tree_structure(...)` will be called, and bound to the structure name.
    This can be used to mark that multiple PyTrees all have the same structure:
    ```python
    def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
        ...
    ```
    Structures are bound to names in the same way as array shape annotations, i.e.
    within the thread-local dynamic context of a [`jaxtyping.jaxtyped`][] decorator.

d. A composite structure can be declared. In this case the variable must have a PyTree
    structure equal to the composition of multiple previously-bound PyTree structures.
    For example:
    ```python
    def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
        ...

    x = (1, 2)
    y = {"key": 3}
    z = {"key": (4, 5)}  # structure is the composition of the structures of `y` and `x`
    f(x, y, z)
    ```
    When performing runtime type-checking, all the individual pieces must have already
    been bound to structures, otherwise the composite structure check will throw an error.

e. A structure can begin with a `...`, to denote that the lower levels of the PyTree
    must match the declared structure, but the upper levels can be arbitrary. As in the
    previous case, all named pieces must already have been seen and their structures
    bound.

f. A structure can end with a `...`, to denote that the PyTree must be a prefix of the
    declared structure, but the lower levels can be arbitrary. As in the previous two
    cases, all named pieces must already have been seen and their structures bound.
"""  # noqa: E501
# Per-thread stack of memo frames; one frame is pushed per active `@jaxtyped` call.
_shape_storage = threading.local()


def _has_shape_memo():
    """Return whether the current thread has a non-empty memo stack."""
    if not hasattr(_shape_storage, "memo_stack"):
        return False
    return len(_shape_storage.memo_stack) != 0


def get_shape_memo():
    """Return `(single_memo, variadic_memo, pytree_memo, arguments)` for this thread.

    When no memo frame is active, four fresh empty dictionaries are returned.
    """
    if not _has_shape_memo():
        # `isinstance` happening outside any @jaxtyped decorators, e.g. at the
        # global scope. In this case just create a temporary memo, since we're not
        # going to be comparing against any stored values anyway.
        return {}, {}, {}, {}
    single_memo, variadic_memo, pytree_memo, arguments = _shape_storage.memo_stack[-1]
    return single_memo, variadic_memo, pytree_memo, arguments


def set_shape_memo(single_memo, variadic_memo, pytree_memo, arg_memo) -> None:
    """Replace the top memo frame of this thread; does nothing if there is none."""
    if not _has_shape_memo():
        return
    frame = (single_memo, variadic_memo, pytree_memo, arg_memo)
    _shape_storage.memo_stack[-1] = frame


def push_shape_memo(arguments: dict[str, Any]):
    """Push a new memo frame, holding a copy of `arguments`, and return it."""
    if not hasattr(_shape_storage, "memo_stack"):
        # Can't be done when `_shape_storage` is created for reasons I forget.
        _shape_storage.memo_stack = []
    frame = ({}, {}, {}, arguments.copy())
    _shape_storage.memo_stack.append(frame)
    return frame


def pop_shape_memo() -> None:
    """Discard the top memo frame of this thread."""
    _shape_storage.memo_stack.pop()


def shape_str(memos) -> str:
    """Gives debug information on the current state of jaxtyping's internal memos.
    Used in type-checking error messages.

    **Arguments:**

    - `memos`: as returned by `get_shape_memo` or `push_shape_memo`.

    **Returns:**

    A newline-joined string; empty if there is nothing to report.
    """
    single_memo, variadic_memo, pytree_memo, _ = memos
    hidden_prefix = "~~delete~~"

    # Entries whose name carries the hidden prefix are left out of the report.
    axis_lines = []
    for name, size in single_memo.items():
        if not name.startswith(hidden_prefix):
            axis_lines.append(f"{name}={size}")
    for name, (_flag, shape) in variadic_memo.items():
        if not name.startswith(hidden_prefix):
            axis_lines.append(f"{name}={shape}")

    report = []
    if axis_lines:
        report.append(
            "The current values for each jaxtyping axis annotation are as follows."
        )
        report.extend(axis_lines)
    if pytree_memo:
        report.append(
            "The current values for each jaxtyping PyTree structure annotation are as "
            "follows."
        )
        report.extend(f"{name}={structure}" for name, structure in pytree_memo.items())
    return "\n".join(report)


def print_bindings():
    """Prints the values of the current jaxtyping axis bindings. Intended for debugging.

    For example, this can be used to find the values bound to `foo` and `bar` in

    ```python
    @jaxtyped(typechecker=...)
    def f(x: Float[Array, "foo bar"]):
        print_bindings()
        ...
    ```

    noting that these values are bound during runtime typechecking, so that the
    [`jaxtyping.jaxtyped`][] decorator is required.

    **Arguments:**

    Nothing.

    **Returns:**

    Nothing.
    """
    bindings = shape_str(get_shape_memo())
    print(bindings)


# Per-thread description of the PyTree leaf currently being checked (or `None`).
_treepath_storage = threading.local()


def clear_treepath_memo() -> None:
    """Reset the treepath memo of this thread to `None`."""
    _treepath_storage.value = None


def set_treepath_memo(index: Optional[int], structure: str) -> None:
    """Record a tag for `index`/`structure` as the treepath memo of this thread.

    **Raises:**

    `AnnotationError` if a treepath memo is already set on this thread.
    """
    if getattr(_treepath_storage, "value", None) is not None:
        raise AnnotationError(
            "Cannot typecheck annotations of the form "
            "`PyTree[PyTree[Shaped[Array, '?foo'], 'T'], 'S']` as it is ambiguous "
            "which PyTree the `?` annotation refers to."
        )
    _treepath_storage.value = (
        f"~~delete~~({structure}) "
        if index is None
        # Appears in error messages, so human-readable
        else f"(Leaf {index} in structure {structure}) "
    )


def get_treepath_memo() -> str:
    """Return the treepath memo of this thread.

    **Raises:**

    `AnnotationError` if no treepath memo is currently set.
    """
    current = getattr(_treepath_storage, "value", None)
    if current is None:
        raise AnnotationError(
            "Cannot use `?` annotations, e.g. `Shaped[Array, '?foo']`, except "
            "when contained with structured `PyTree` annotations, e.g. "
            "`PyTree[Shaped[Array, '?foo'], 'T']`."
        )
    return current


# Per-thread boolean flag; see the set/clear/get helpers below.
_treeflatten_storage = threading.local()


def clear_treeflatten_memo() -> None:
    """Set the treeflatten flag of this thread to `False`."""
    _treeflatten_storage.value = False


def set_treeflatten_memo():
    """Set the treeflatten flag of this thread to `True`."""
    _treeflatten_storage.value = True


def get_treeflatten_memo():
    """Return the treeflatten flag of this thread, or `False` if it was never set."""
    return getattr(_treeflatten_storage, "value", False)
13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 15 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 16 | PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 17 | FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | DEALINGS IN THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /jaxtyping/_typeguard/README.md: -------------------------------------------------------------------------------- 1 | # Vendored copy of typeguard v2.13.3 2 | 3 | We include a vendored copy of typeguard v2.13.3. The reason we need a runtime typechecker is to be able to define `isinstance(..., PyTree[Foo])`. 4 | 5 | Of the available options: 6 | 7 | - `beartype` does not support `O(n)` checking. 8 | - `typeguard` v4 is notorious for having some bugs (they seem to re-parse the AST or something?? And then they die on the fact that we have strings in our annotations.) 9 | - `typeguard` v2 is what we use here... but we vendor it instead of depending on it, because people may still wish to use typeguard v4 in their own environments. (Notably a number of other packages depend on this, and it's just inconvenient to be incompatible at the package level, when the combinations which don't mix at runtime might never actually be used.) 10 | 11 | This is vendored under the terms of the MIT license, which is also reproduced here. 
12 | -------------------------------------------------------------------------------- /jaxtyping/py.typed: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | theme: 2 | name: material 3 | features: 4 | - navigation.sections # Sections are included in the navigation on the left. 5 | - toc.integrate # Table of contents is integrated on the left; does not appear separately on the right. 6 | - header.autohide # header disappears as you scroll 7 | palette: 8 | # Light mode / dark mode 9 | # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as 10 | # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. 11 | - scheme: default 12 | primary: white 13 | accent: amber 14 | toggle: 15 | icon: material/weather-night 16 | name: Switch to dark mode 17 | - scheme: slate 18 | primary: black 19 | accent: amber 20 | toggle: 21 | icon: material/weather-sunny 22 | name: Switch to light mode 23 | icon: 24 | repo: fontawesome/brands/github # GitHub logo in top right 25 | logo: "material/check-network-outline" # jaxtyping logo in top left 26 | favicon: "_static/favicon.png" 27 | custom_dir: "docs/_overrides" # Overriding part of the HTML 28 | 29 | # These additions are my own custom ones, having overridden a partial. 30 | twitter_bluesky_name: "@PatrickKidger" 31 | twitter_url: "https://twitter.com/PatrickKidger" 32 | bluesky_url: "https://PatrickKidger.bsky.social" 33 | 34 | site_name: jaxtyping 35 | site_description: The documentation for the jaxtyping software library. 
36 | site_author: Patrick Kidger 37 | site_url: https://docs.kidger.site/jaxtyping 38 | 39 | repo_url: https://github.com/patrick-kidger/jaxtyping 40 | repo_name: patrick-kidger/jaxtyping 41 | edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate 42 | 43 | strict: true # Don't allow warnings during the build process 44 | 45 | extra_javascript: 46 | # The below two make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ 47 | - _static/mathjax.js 48 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 49 | 50 | extra_css: 51 | - _static/custom_css.css 52 | 53 | markdown_extensions: 54 | - pymdownx.arithmatex: # Render LaTeX via MathJax 55 | generic: true 56 | - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. 57 | - pymdownx.details # Allowing hidden expandable regions denoted by ??? 58 | - pymdownx.snippets: # Include one Markdown file into another 59 | base_path: docs 60 | - admonition 61 | - toc: 62 | permalink: "¤" # Adds a clickable permalink to each section heading 63 | toc_depth: 4 64 | 65 | plugins: 66 | - search: 67 | separator: '[\s\-,:!=\[\]()"/]+|(?!\b)(?=[A-Z][a-z])|\.(?!\d)|&[lg]t;' 68 | - autorefs # Cross-links to headings 69 | - include_exclude_files: 70 | include: 71 | - ".htaccess" 72 | exclude: 73 | - "_overrides" 74 | - ipynb 75 | - hippogriffe: 76 | # Our docs are generated pretty much independently of the runtime code, these links aren't useful. 
77 | show_source_links: none 78 | extra_public_objects: 79 | - numpy.dtype 80 | - mkdocstrings: 81 | handlers: 82 | python: 83 | options: 84 | force_inspection: true 85 | heading_level: 4 86 | inherited_members: true 87 | members_order: source 88 | show_bases: false 89 | show_if_no_docstring: true 90 | show_overloads: false 91 | show_root_heading: true 92 | show_signature_annotations: true 93 | show_source: false 94 | show_symbol_type_heading: true 95 | show_symbol_type_toc: true 96 | 97 | nav: 98 | - 'index.md' 99 | - API: 100 | - 'api/array.md' 101 | - 'api/pytree.md' 102 | - 'api/runtime-type-checking.md' 103 | - 'api/advanced-features.md' 104 | - 'faq.md' 105 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "jaxtyping" 3 | version = "0.3.2" 4 | description = "Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays." 
5 | readme = "README.md" 6 | requires-python =">=3.10" 7 | license = {file = "LICENSE"} 8 | authors = [ 9 | {name = "Patrick Kidger", email = "contact@kidger.site"}, 10 | ] 11 | keywords = ["jax", "neural-networks", "deep-learning", "equinox", "typing"] 12 | classifiers = [ 13 | "Development Status :: 3 - Alpha", 14 | "Intended Audience :: Developers", 15 | "Intended Audience :: Financial and Insurance Industry", 16 | "Intended Audience :: Information Technology", 17 | "Intended Audience :: Science/Research", 18 | "License :: OSI Approved :: MIT License", 19 | "Natural Language :: English", 20 | "Programming Language :: Python :: 3", 21 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 22 | "Topic :: Scientific/Engineering :: Information Analysis", 23 | "Topic :: Scientific/Engineering :: Mathematics", 24 | ] 25 | urls = {repository = "https://github.com/patrick-kidger/jaxtyping" } 26 | dependencies = ["wadler_lindig>=0.1.3"] 27 | entry-points = {pytest11 = {jaxtyping = "jaxtyping._pytest_plugin"}} 28 | 29 | [project.optional-dependencies] 30 | docs = [ 31 | "hippogriffe==0.2.0", 32 | "mkdocs==1.6.1", 33 | "mkdocs-include-exclude-files==0.1.0", 34 | "mkdocs-ipynb==0.1.0", 35 | "mkdocs-material==9.6.7", 36 | "mkdocstrings[python]==0.28.3", 37 | "pymdown-extensions==10.14.3", 38 | ] 39 | 40 | [build-system] 41 | requires = ["hatchling"] 42 | build-backend = "hatchling.build" 43 | 44 | [tool.hatch.build] 45 | include = ["jaxtyping/*"] 46 | 47 | [tool.ruff.lint] 48 | select = ["E", "F", "I001"] 49 | ignore = ["E721", "E731", "F722"] 50 | 51 | [tool.ruff.lint.per-file-ignores] 52 | "jaxtyping/_typeguard/__init__.py" = ["E", "F", "I001"] 53 | 54 | [tool.ruff.lint.isort] 55 | combine-as-imports = true 56 | lines-after-imports = 2 57 | extra-standard-library = ["typing_extensions"] 58 | order-by-type = false 59 | -------------------------------------------------------------------------------- /test/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrick-kidger/jaxtyping/fe61644be8590cf0a89bacc278283e1e4b9ea3e4/test/__init__.py -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Google LLC 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | 20 | import random 21 | 22 | import jax.random as jr 23 | import pytest 24 | import typeguard 25 | 26 | 27 | try: 28 | import beartype 29 | except ImportError: 30 | 31 | def skip(*args, **kwargs): 32 | pytest.skip("Beartype not installed") 33 | 34 | typecheck_params = [typeguard.typechecked, skip] 35 | else: 36 | typecheck_params = [typeguard.typechecked, beartype.beartype] 37 | 38 | 39 | @pytest.fixture(params=typecheck_params) 40 | def typecheck(request): 41 | return request.param 42 | 43 | 44 | @pytest.fixture(params=(False, True)) 45 | def jaxtyp(request): 46 | import jaxtyping 47 | 48 | if request.param: 49 | # New-style 50 | # @jaxtyping.jaxtyped(typechecker=typechecker) 51 | # def f(...) 52 | return lambda typechecker: jaxtyping.jaxtyped(typechecker=typechecker) 53 | else: 54 | # Old-style 55 | # @jaxtyping.jaxtyped 56 | # @typechecker 57 | # def f(...) 58 | def impl(typechecker): 59 | def decorator(fn): 60 | with pytest.warns(match="As of jaxtyping version 0.2.24"): 61 | return jaxtyping.jaxtyped(typechecker(fn)) 62 | 63 | return decorator 64 | 65 | return impl 66 | 67 | 68 | @pytest.fixture() 69 | def getkey(): 70 | def _getkey(): 71 | # Not sure what the maximum actually is but this will do 72 | return jr.PRNGKey(random.randint(0, 2**31 - 1)) 73 | 74 | return _getkey 75 | 76 | 77 | @pytest.fixture(scope="module") 78 | def beartype_or_skip(): 79 | yield pytest.importorskip("beartype") 80 | 81 | 82 | @pytest.fixture(scope="module") 83 | def typeguard_or_skip(): 84 | yield pytest.importorskip("typeguard") 85 | -------------------------------------------------------------------------------- /test/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Google LLC 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without 
limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | import contextlib 21 | import gc 22 | from collections.abc import Callable 23 | from typing import Any 24 | 25 | import equinox as eqx 26 | import typeguard 27 | 28 | 29 | ParamError = [] 30 | ReturnError = [] 31 | ParamError.append(TypeError) # old typeguard 32 | ReturnError.append(TypeError) # old typeguard 33 | 34 | try: 35 | # new typeguard 36 | ParamError.append(typeguard.TypeCheckError) 37 | ReturnError.append(typeguard.TypeCheckError) 38 | except AttributeError: 39 | pass 40 | 41 | try: 42 | import beartype 43 | except ImportError: 44 | pass 45 | else: 46 | ParamError.append(beartype.roar.BeartypeCallHintParamViolation) 47 | ReturnError.append(beartype.roar.BeartypeCallHintReturnViolation) 48 | 49 | ParamError = tuple(ParamError) 50 | ReturnError = tuple(ReturnError) 51 | 52 | 53 | @eqx.filter_jit 54 | def make_mlp(key): 55 | return eqx.nn.MLP(2, 2, 2, 2, key=key) 56 | 57 | 58 | @contextlib.contextmanager 59 | def assert_no_garbage( 60 | allowed_garbage_predicate: Callable[[Any], bool] = lambda _: False, 61 | ): 62 | try: 63 | gc.disable() 64 | gc.collect() 65 | # 
It's unclear why, but a second GC is necessary to fully collect 66 | # existing garbage. 67 | gc.collect() 68 | gc.garbage.clear() 69 | 70 | yield 71 | 72 | # Do a GC collection, saving collected objects in gc.garbage. 73 | gc.set_debug(gc.DEBUG_SAVEALL) 74 | gc.collect() 75 | 76 | disallowed_garbage = [ 77 | obj for obj in gc.garbage if not allowed_garbage_predicate(obj) 78 | ] 79 | assert not disallowed_garbage 80 | finally: 81 | # Reset the GC back to normal. 82 | gc.set_debug(0) 83 | gc.garbage.clear() 84 | gc.collect() 85 | gc.enable() 86 | -------------------------------------------------------------------------------- /test/import_hook_tester.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Google LLC 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | 20 | import dataclasses 21 | from typing import no_type_check 22 | 23 | import equinox as eqx 24 | import jax.numpy as jnp 25 | import pytest 26 | from helpers import ParamError, ReturnError 27 | 28 | import jaxtyping 29 | from jaxtyping import Float32, Int 30 | 31 | 32 | # 33 | # Test that functions get checked 34 | # 35 | 36 | 37 | def g(x: Float32[jnp.ndarray, " b"]): 38 | pass 39 | 40 | 41 | g(jnp.array([1.0])) 42 | with pytest.raises(ParamError): 43 | g(jnp.array(1)) 44 | 45 | 46 | # 47 | # Test that Equinox modules get checked 48 | # 49 | 50 | 51 | # Dataclass `__init__`, no converter 52 | class Mod1(eqx.Module): 53 | foo: int 54 | bar: Float32[jnp.ndarray, " a"] 55 | 56 | 57 | Mod1(1, jnp.array([1.0])) 58 | with pytest.raises(ParamError): 59 | Mod1(1.0, jnp.array([1.0])) 60 | with pytest.raises(ParamError): 61 | Mod1(1, jnp.array(1.0)) 62 | 63 | 64 | # Dataclass `__init__`, converter 65 | class Mod2(eqx.Module): 66 | a: jnp.ndarray = eqx.field(converter=jnp.asarray) 67 | 68 | 69 | Mod2(1) # This will fail unless we run typechecking after conversion 70 | 71 | 72 | # This silently passes -- the untyped `lambda x: x` launders the value through. 73 | # No easy way to tackle this. That's okay. 
74 | 75 | # class BadMod2(eqx.Module): 76 | # a: jnp.ndarray = eqx.field(converter=lambda x: x) 77 | 78 | 79 | # with pytest.raises(ParamError): 80 | # BadMod2(1) 81 | # with pytest.raises(ParamError): 82 | # BadMod2("asdf") 83 | 84 | 85 | # Custom `__init__`, no converter 86 | class Mod3(eqx.Module): 87 | foo: int 88 | bar: Float32[jnp.ndarray, " a"] 89 | 90 | def __init__(self, foo: str, bar: Float32[jnp.ndarray, " a"]): 91 | self.foo = int(foo) 92 | self.bar = bar 93 | 94 | 95 | Mod3("1", jnp.array([1.0])) 96 | with pytest.raises(ParamError): 97 | Mod3(1, jnp.array([1.0])) 98 | with pytest.raises(ParamError): 99 | Mod3("1", jnp.array(1.0)) 100 | 101 | 102 | # Custom `__init__`, converter 103 | class Mod4(eqx.Module): 104 | a: Int[jnp.ndarray, ""] = eqx.field(converter=jnp.asarray) 105 | 106 | def __init__(self, a: str): 107 | self.a = int(a) 108 | 109 | 110 | Mod4("1") # This will fail unless we run typechecking after conversion 111 | 112 | 113 | # Custom `__post_init__`, no converter 114 | class Mod5(eqx.Module): 115 | foo: int 116 | bar: Float32[jnp.ndarray, " a"] 117 | 118 | def __post_init__(self): 119 | pass 120 | 121 | 122 | Mod5(1, jnp.array([1.0])) 123 | with pytest.raises(ParamError): 124 | Mod5(1.0, jnp.array([1.0])) 125 | with pytest.raises(ParamError): 126 | Mod5(1, jnp.array(1.0)) 127 | 128 | 129 | # Custom `__post_init__`, converter 130 | class Mod6(eqx.Module): 131 | a: jnp.ndarray = eqx.field(converter=jnp.asarray) 132 | 133 | def __post_init__(self): 134 | pass 135 | 136 | 137 | Mod6(1) # This will fail unless we run typechecking after conversion 138 | 139 | 140 | # 141 | # Test that dataclasses get checked 142 | # 143 | 144 | 145 | @dataclasses.dataclass 146 | class D: 147 | foo: int 148 | bar: Float32[jnp.ndarray, " a"] 149 | 150 | 151 | D(1, jnp.array([1.0])) 152 | with pytest.raises(ParamError): 153 | D(1.0, jnp.array([1.0])) 154 | with pytest.raises(ParamError): 155 | D(1, jnp.array(1.0)) 156 | 157 | 158 | # 159 | # Test that methods get
checked 160 | # 161 | 162 | 163 | class N(eqx.Module): 164 | a: jnp.ndarray 165 | 166 | def __init__(self, foo: str): 167 | self.a = jnp.array(1) 168 | 169 | def foo(self, x: jnp.ndarray): 170 | pass 171 | 172 | def bar(self) -> jnp.ndarray: 173 | return self.a 174 | 175 | 176 | n = N("hi") 177 | with pytest.raises(ParamError): 178 | N(123) 179 | with pytest.raises(ParamError): 180 | n.foo("not_an_array_either") 181 | bad_n = eqx.tree_at(lambda x: x.a, n, "not_an_array") 182 | with pytest.raises(ReturnError): 183 | bad_n.bar() 184 | 185 | 186 | # 187 | # Test that we don't get called in `super()`. 188 | # 189 | 190 | 191 | called = False 192 | 193 | 194 | class Base(eqx.Module): 195 | x: int 196 | 197 | def __init__(self): 198 | self.x = "not an int" 199 | global called 200 | assert not called 201 | called = True 202 | 203 | 204 | class Derived(Base): 205 | def __init__(self): 206 | assert not called 207 | super().__init__() 208 | assert called 209 | self.x = 2 210 | 211 | 212 | Derived() 213 | 214 | 215 | # 216 | # Test that stringified type annotations work 217 | 218 | 219 | class Foo: 220 | pass 221 | 222 | 223 | class Bar(eqx.Module): 224 | x: type[Foo] 225 | y: "type[Foo]" 226 | # Partially-stringified hints not tested; not supported. 
227 | 228 | 229 | Bar(Foo, Foo) 230 | 231 | with pytest.raises(ParamError): 232 | Bar(1, Foo) 233 | 234 | 235 | # 236 | # Test that assert isinstance works (even if no arg/return annotations) 237 | 238 | 239 | def isinstance_test(x): 240 | assert isinstance(x, Float32[jnp.ndarray, " b"]) 241 | _ = x 242 | 243 | 244 | isinstance_test(jnp.array([1.0])) 245 | with pytest.raises(AssertionError): 246 | isinstance_test(jnp.array(1)) 247 | 248 | 249 | @no_type_check 250 | def f(_: Float32[jnp.ndarray, "foo bar"]): 251 | pass 252 | 253 | 254 | f("not an array") 255 | 256 | 257 | # Record that we've finished our checks successfully 258 | 259 | jaxtyping._test_import_hook_counter += 1 260 | -------------------------------------------------------------------------------- /test/requirements.txt: -------------------------------------------------------------------------------- 1 | beartype 2 | cloudpickle 3 | equinox 4 | IPython 5 | jax 6 | numpy<2 7 | pytest 8 | pytest-asyncio 9 | tensorflow 10 | typeguard<3 11 | mlx 12 | -------------------------------------------------------------------------------- /test/test_all_importable.py: -------------------------------------------------------------------------------- 1 | # We have some pretty complicated semantics in `__init__.py`. 2 | # Here we check that we didn't miss one of them on our runtime branch. 3 | def test_all_importable(): 4 | # Ordered according to their appearance in the documentation. 
5 | from jaxtyping import ( # noqa: I001 6 | Shaped, # noqa: F401 7 | Bool, # noqa: F401 8 | Key, # noqa: F401 9 | Num, # noqa: F401 10 | Inexact, # noqa: F401 11 | Float, # noqa: F401 12 | BFloat16, # noqa: F401 13 | Float16, # noqa: F401 14 | Float32, # noqa: F401 15 | Float64, # noqa: F401 16 | Complex, # noqa: F401 17 | Complex64, # noqa: F401 18 | Complex128, # noqa: F401 19 | Integer, # noqa: F401 20 | UInt, # noqa: F401 21 | UInt2, # noqa: F401 22 | UInt4, # noqa: F401 23 | UInt8, # noqa: F401 24 | UInt16, # noqa: F401 25 | UInt32, # noqa: F401 26 | UInt64, # noqa: F401 27 | Int, # noqa: F401 28 | Int2, # noqa: F401 29 | Int4, # noqa: F401 30 | Int8, # noqa: F401 31 | Int16, # noqa: F401 32 | Int32, # noqa: F401 33 | Int64, # noqa: F401 34 | Real, # noqa: F401 35 | Array, # noqa: F401 36 | ArrayLike, # noqa: F401 37 | Scalar, # noqa: F401 38 | ScalarLike, # noqa: F401 39 | PRNGKeyArray, # noqa: F401 40 | PyTreeDef, # noqa: F401 41 | PyTree, # noqa: F401 42 | jaxtyped, # noqa: F401 43 | install_import_hook, # noqa: F401 44 | AbstractArray, # noqa: F401 45 | AbstractDtype, # noqa: F401 46 | print_bindings, # noqa: F401 47 | get_array_name_format, # noqa: F401 48 | set_array_name_format, # noqa: F401 49 | ) 50 | -------------------------------------------------------------------------------- /test/test_array.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Google LLC 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice 
shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | import dataclasses as dc 21 | import sys 22 | from typing import Any, get_args, get_origin, TypeVar, Union 23 | 24 | import jax.numpy as jnp 25 | import jax.random as jr 26 | import numpy as np 27 | import pytest 28 | 29 | 30 | try: 31 | import torch 32 | except ImportError: 33 | torch = None 34 | 35 | from jaxtyping import ( 36 | AbstractArray, 37 | AbstractDtype, 38 | AnnotationError, 39 | Array, 40 | ArrayLike, 41 | Bool, 42 | Float, 43 | Float32, 44 | jaxtyped, 45 | Key, 46 | PRNGKeyArray, 47 | Scalar, 48 | Shaped, 49 | ) 50 | 51 | from .helpers import ParamError, ReturnError 52 | 53 | 54 | def test_basic(jaxtyp, typecheck): 55 | @jaxtyp(typecheck) 56 | def g(x: Shaped[Array, "..."]): 57 | pass 58 | 59 | g(jnp.array(1.0)) 60 | 61 | 62 | def test_dtypes(): 63 | from jaxtyping import ( # noqa: F401 64 | Array, 65 | BFloat16, 66 | Bool, 67 | Complex, 68 | Complex64, 69 | Complex128, 70 | Float, 71 | Float8e4m3b11fnuz, 72 | Float8e4m3fn, 73 | Float8e4m3fnuz, 74 | Float8e5m2, 75 | Float8e5m2fnuz, 76 | Float16, 77 | Float32, 78 | Float64, 79 | Inexact, 80 | Int, 81 | Int2, 82 | Int4, 83 | Int8, 84 | Int16, 85 | Int32, 86 | Int64, 87 | Num, 88 | Shaped, 89 | UInt, 90 | UInt2, 91 | UInt4, 92 | UInt8, 93 | UInt16, 94 | UInt32, 95 | UInt64, 96 | ) 97 | 98 | for key, val in locals().items(): 99 | if issubclass(val, AbstractDtype): 100 | assert key == val.__name__ 101 | 102 | 103 | 
def test_numpy_struct_dtype(): 104 | from jaxtyping import make_numpy_struct_dtype 105 | 106 | dtype1 = np.dtype([("first", np.uint8), ("second", bool)]) 107 | Dtype1 = make_numpy_struct_dtype(dtype1, "Dtype1") 108 | arr = np.array([0, False], dtype=dtype1) 109 | 110 | assert isinstance(arr, Dtype1[np.ndarray, "_"]) 111 | 112 | dtype2 = np.dtype([("third", np.uint8), ("second", bool)]) 113 | Dtype2 = make_numpy_struct_dtype(dtype2, "Dtype2") 114 | assert not isinstance(arr, Dtype2[np.ndarray, "_"]) 115 | 116 | dtype3 = np.dtype([("second", bool), ("first", np.uint8)]) 117 | Dtype3 = make_numpy_struct_dtype(dtype3, "Dtype3") 118 | assert not isinstance(arr, Dtype3[np.ndarray, "_"]) 119 | 120 | 121 | def test_return(jaxtyp, typecheck, getkey): 122 | @jaxtyp(typecheck) 123 | def g(x: Float[Array, "b c"]) -> Float[Array, "c b"]: 124 | return jnp.transpose(x) 125 | 126 | g(jr.normal(getkey(), (3, 4))) 127 | 128 | @jaxtyp(typecheck) 129 | def h(x: Float[Array, "b c"]) -> Float[Array, "b c"]: 130 | return jnp.transpose(x) 131 | 132 | with pytest.raises(ReturnError): 133 | h(jr.normal(getkey(), (3, 4))) 134 | 135 | 136 | def test_two_args(jaxtyp, typecheck, getkey): 137 | @jaxtyp(typecheck) 138 | def g(x: Shaped[Array, "b c"], y: Shaped[Array, "c d"]): 139 | return x @ y 140 | 141 | g(jr.normal(getkey(), (3, 4)), jr.normal(getkey(), (4, 5))) 142 | with pytest.raises(ParamError): 143 | g(jr.normal(getkey(), (3, 4)), jr.normal(getkey(), (5, 4))) 144 | 145 | @jaxtyp(typecheck) 146 | def h(x: Shaped[Array, "b c"], y: Shaped[Array, "c d"]) -> Shaped[Array, "b d"]: 147 | return x @ y 148 | 149 | h(jr.normal(getkey(), (3, 4)), jr.normal(getkey(), (4, 5))) 150 | with pytest.raises(ParamError): 151 | h(jr.normal(getkey(), (3, 4)), jr.normal(getkey(), (5, 4))) 152 | 153 | 154 | def test_any_dtype(jaxtyp, typecheck, getkey): 155 | @jaxtyp(typecheck) 156 | def g(x: Shaped[Array, "a b"]) -> Shaped[Array, "a b"]: 157 | return x 158 | 159 | g(jr.normal(getkey(), (3, 4))) 160 | 
g(jnp.array([[True, False]])) 161 | g(jnp.array([[1, 2], [3, 4]], dtype=jnp.int2)) 162 | g(jnp.array([[1, 2], [3, 4]], dtype=jnp.int4)) 163 | g(jnp.array([[1, 2], [3, 4]], dtype=jnp.int8)) 164 | g(jnp.array([[1, 2], [3, 4]], dtype=jnp.uint2)) 165 | g(jnp.array([[1, 2], [3, 4]], dtype=jnp.uint4)) 166 | g(jnp.array([[1, 2], [3, 4]], dtype=jnp.uint16)) 167 | g(jr.normal(getkey(), (3, 4), dtype=jnp.complex64)) 168 | g(jr.normal(getkey(), (3, 4), dtype=jnp.bfloat16)) 169 | 170 | with pytest.raises(ParamError): 171 | g(jr.normal(getkey(), (1,))) 172 | 173 | 174 | def test_nested_jaxtyped(jaxtyp, typecheck, getkey): 175 | @jaxtyp(typecheck) 176 | def g(x: Float32[Array, "b c"], transpose: bool) -> Float32[Array, "c b"]: 177 | return h(x, transpose) 178 | 179 | @jaxtyp(typecheck) 180 | def h(x: Float32[Array, "c b"], transpose: bool) -> Float32[Array, "b c"]: 181 | if transpose: 182 | return jnp.transpose(x) 183 | else: 184 | return x 185 | 186 | g(jr.normal(getkey(), (2, 3)), True) 187 | g(jr.normal(getkey(), (3, 3)), True) 188 | g(jr.normal(getkey(), (3, 3)), False) 189 | with pytest.raises(ReturnError): 190 | g(jr.normal(getkey(), (2, 3)), False) 191 | 192 | 193 | def test_nested_nojaxtyped(jaxtyp, typecheck, getkey): 194 | @jaxtyp(typecheck) 195 | def g(x: Float32[Array, "b c"]): 196 | return h(x) 197 | 198 | @typecheck 199 | def h(x: Float32[Array, "c b"]): 200 | return x 201 | 202 | with pytest.raises(ParamError): 203 | g(jr.normal(getkey(), (2, 3))) 204 | 205 | 206 | def test_isinstance(jaxtyp, typecheck, getkey): 207 | @jaxtyp(typecheck) 208 | def g(x: Float32[Array, "b c"]) -> Float32[Array, " z"]: 209 | y = jnp.transpose(x) 210 | assert isinstance(y, Float32[Array, "c b"]) 211 | assert not isinstance( 212 | y, Float32[Array, "b z"] 213 | ) # z left unbound as b!=c (unless x symmetric, which it isn't) 214 | out = jr.normal(getkey(), (500,)) 215 | assert isinstance(out, Float32[Array, "z"]) # z now bound 216 | return out 217 | 218 | g(jr.normal(getkey(), (2, 3))) 
219 | 220 | 221 | def test_fixed(jaxtyp, typecheck, getkey): 222 | @jaxtyp(typecheck) 223 | def g( 224 | x: Float32[Array, "4 5 foo"], y: Float32[Array, " foo"] 225 | ) -> Float32[Array, "4 5"]: 226 | return x @ y 227 | 228 | a = jr.normal(getkey(), (4, 5, 2)) 229 | b = jr.normal(getkey(), (2,)) 230 | assert g(a, b).shape == (4, 5) 231 | 232 | c = jr.normal(getkey(), (3, 5, 2)) 233 | with pytest.raises(ParamError): 234 | g(c, b) 235 | 236 | 237 | def test_anonymous(jaxtyp, typecheck, getkey): 238 | @jaxtyp(typecheck) 239 | def g(x: Float32[Array, "foo _"], y: Float32[Array, " _"]): 240 | pass 241 | 242 | a = jr.normal(getkey(), (3, 4)) 243 | b = jr.normal(getkey(), (5,)) 244 | g(a, b) 245 | 246 | 247 | def test_named_variadic(jaxtyp, typecheck, getkey): 248 | @jaxtyp(typecheck) 249 | def g( 250 | x: Float32[Array, "*batch foo"], 251 | y: Float32[Array, " *batch"], 252 | z: Float32[Array, " foo"], 253 | ): 254 | pass 255 | 256 | c = jr.normal(getkey(), (5,)) 257 | 258 | a1 = jr.normal(getkey(), (5,)) 259 | b1 = jr.normal(getkey(), ()) 260 | g(a1, b1, c) 261 | 262 | a2 = jr.normal(getkey(), (3, 5)) 263 | b2 = jr.normal(getkey(), (3,)) 264 | g(a2, b2, c) 265 | 266 | with pytest.raises(ParamError): 267 | g(a1, b2, c) 268 | with pytest.raises(ParamError): 269 | g(a2, b1, c) 270 | 271 | @jaxtyp(typecheck) 272 | def h(x: Float32[Array, " foo *batch"], y: Float32[Array, " foo *batch bar"]): 273 | pass 274 | 275 | a = jr.normal(getkey(), (4,)) 276 | b = jr.normal(getkey(), (4, 3)) 277 | c = jr.normal(getkey(), (3, 4)) 278 | h(a, b) 279 | with pytest.raises(ParamError): 280 | h(a, c) 281 | with pytest.raises(ParamError): 282 | h(b, c) 283 | 284 | 285 | def test_anonymous_variadic(jaxtyp, typecheck, getkey): 286 | @jaxtyp(typecheck) 287 | def g(x: Float32[Array, "... 
foo"], y: Float32[Array, " foo"]): 288 | pass 289 | 290 | a1 = jr.normal(getkey(), (5,)) 291 | a2 = jr.normal(getkey(), (3, 5)) 292 | a3 = jr.normal(getkey(), (3, 4, 5)) 293 | b = jr.normal(getkey(), (5,)) 294 | c = jr.normal(getkey(), (1,)) 295 | g(a1, b) 296 | g(a2, b) 297 | g(a3, b) 298 | with pytest.raises(ParamError): 299 | g(a1, c) 300 | with pytest.raises(ParamError): 301 | g(a2, c) 302 | with pytest.raises(ParamError): 303 | g(a3, c) 304 | 305 | 306 | def test_broadcast_fixed(jaxtyp, typecheck, getkey): 307 | @jaxtyp(typecheck) 308 | def g(x: Float32[Array, "#4"]): 309 | pass 310 | 311 | g(jr.normal(getkey(), (4,))) 312 | g(jr.normal(getkey(), (1,))) 313 | 314 | with pytest.raises(ParamError): 315 | g(jr.normal(getkey(), (3,))) 316 | 317 | 318 | def test_broadcast_named(jaxtyp, typecheck, getkey): 319 | @jaxtyp(typecheck) 320 | def g(x: Float32[Array, " #foo"], y: Float32[Array, " #foo"]): 321 | pass 322 | 323 | a = jr.normal(getkey(), (3,)) 324 | b = jr.normal(getkey(), (4,)) 325 | c = jr.normal(getkey(), (1,)) 326 | 327 | g(a, a) 328 | g(b, b) 329 | g(c, c) 330 | g(a, c) 331 | g(b, c) 332 | g(c, a) 333 | g(c, b) 334 | 335 | with pytest.raises(ParamError): 336 | g(a, b) 337 | with pytest.raises(ParamError): 338 | g(b, a) 339 | 340 | 341 | def test_broadcast_variadic_named(jaxtyp, typecheck, getkey): 342 | @jaxtyp(typecheck) 343 | def g(x: Float32[Array, " *#foo"], y: Float32[Array, " *#foo"]): 344 | pass 345 | 346 | a = jr.normal(getkey(), (3,)) 347 | b = jr.normal(getkey(), (4,)) 348 | c = jr.normal(getkey(), (4, 4)) 349 | d = jr.normal(getkey(), (5, 6)) 350 | 351 | j = jr.normal(getkey(), (1,)) 352 | k = jr.normal(getkey(), (1, 4)) 353 | l = jr.normal(getkey(), (5, 1)) # noqa: E741 354 | m = jr.normal(getkey(), (1, 1)) 355 | n = jr.normal(getkey(), (2, 1)) 356 | o = jr.normal(getkey(), (1, 6)) 357 | 358 | g(a, a) 359 | g(b, b) 360 | g(c, c) 361 | g(d, d) 362 | g(b, c) 363 | with pytest.raises(ParamError): 364 | g(a, b) 365 | with 
pytest.raises(ParamError): 366 | g(a, c) 367 | with pytest.raises(ParamError): 368 | g(a, b) 369 | with pytest.raises(ParamError): 370 | g(d, b) 371 | 372 | g(a, j) 373 | g(b, j) 374 | g(c, j) 375 | g(d, j) 376 | g(b, k) 377 | g(c, k) 378 | with pytest.raises(ParamError): 379 | g(d, k) 380 | with pytest.raises(ParamError): 381 | g(c, l) 382 | g(d, l) 383 | g(a, m) 384 | g(c, m) 385 | g(d, m) 386 | g(a, n) 387 | g(b, n) 388 | with pytest.raises(ParamError): 389 | g(c, n) 390 | with pytest.raises(ParamError): 391 | g(d, n) 392 | g(o, d) 393 | with pytest.raises(ParamError): 394 | g(o, c) 395 | with pytest.raises(ParamError): 396 | g(o, a) 397 | 398 | 399 | def test_variadic_mixed_broadcast(jaxtyp, typecheck, getkey): 400 | @jaxtyp(typecheck) 401 | def f(x: Float[Array, " *foo"], y: Float[Array, " #*foo"]): 402 | pass 403 | 404 | a = jr.normal(getkey(), (3, 4)) 405 | b = jr.normal(getkey(), (5,)) 406 | with pytest.raises(ParamError): 407 | f(a, b) 408 | 409 | c = jr.normal(getkey(), (7, 3, 2)) 410 | d = jr.normal(getkey(), (1, 2)) 411 | f(c, d) 412 | 413 | 414 | def test_variadic_mixed_broadcast2(jaxtyp, typecheck, getkey): 415 | @jaxtyp(typecheck) 416 | def f(x: Float[Array, " *#foo"], y: Float[Array, " *foo"]): 417 | pass 418 | 419 | a = jr.normal(getkey(), (3, 4)) 420 | b = jr.normal(getkey(), (5,)) 421 | with pytest.raises(ParamError): 422 | f(a, b) 423 | 424 | c = jr.normal(getkey(), (1, 2)) 425 | d = jr.normal(getkey(), (7, 3, 2)) 426 | f(c, d) 427 | 428 | 429 | def test_variadic_mixed_broadcast3(jaxtyp, typecheck, getkey): 430 | @jaxtyp(typecheck) 431 | def f( 432 | x: Float[Array, "*B L D"], 433 | *, 434 | y: Float[Array, "*#B J d"], 435 | z: Bool[Array, "*B L J"], 436 | ) -> Float[Array, "*B L D"]: 437 | return x 438 | 439 | x = jr.normal(getkey(), (2, 7, 3, 2, 2)) 440 | y = jr.bernoulli(getkey(), shape=(2, 7, 3, 2, 2)) 441 | z = jr.normal(getkey(), (2, 7, 1, 2, 2)) 442 | f(x, y=z, z=y) 443 | 444 | 445 | def test_no_commas(): 446 | with 
pytest.raises(ValueError): 447 | Float32[Array, "foo, bar"] 448 | 449 | 450 | def test_symbolic(jaxtyp, typecheck, getkey): 451 | @jaxtyp(typecheck) 452 | def make_slice(x: Float32[Array, " dim"]) -> Float32[Array, " dim-1"]: 453 | return x[1:] 454 | 455 | @jaxtyp(typecheck) 456 | def cat(x: Float32[Array, " dim"]) -> Float32[Array, " 2*dim"]: 457 | return jnp.concatenate([x, x]) 458 | 459 | @jaxtyp(typecheck) 460 | def bad_make_slice(x: Float32[Array, " dim"]) -> Float32[Array, " dim-1"]: 461 | return x 462 | 463 | @jaxtyp(typecheck) 464 | def bad_cat(x: Float32[Array, " dim"]) -> Float32[Array, " 2*dim"]: 465 | return jnp.concatenate([x, x, x]) 466 | 467 | x = jr.normal(getkey(), (5,)) 468 | assert make_slice(x).shape == (4,) 469 | assert cat(x).shape == (10,) 470 | 471 | y = jr.normal(getkey(), (3, 4)) 472 | with pytest.raises(ParamError): 473 | make_slice(y) 474 | with pytest.raises(ParamError): 475 | cat(y) 476 | 477 | with pytest.raises(ReturnError): 478 | bad_make_slice(x) 479 | with pytest.raises(ReturnError): 480 | bad_cat(x) 481 | 482 | 483 | def test_incomplete_symbolic(jaxtyp, typecheck, getkey): 484 | @jaxtyp(typecheck) 485 | def foo(x: Float32[Array, " 2*dim"]): 486 | pass 487 | 488 | x = jr.normal(getkey(), (4,)) 489 | with pytest.raises(AnnotationError): 490 | foo(x) 491 | 492 | 493 | def test_deferred_symbolic_good(jaxtyp, typecheck): 494 | @jaxtyp(typecheck) 495 | def foo(dim: int, fill: Float[Array, ""]) -> Float[Array, " {dim}"]: 496 | return jnp.full((dim,), fill) 497 | 498 | class A: 499 | size = 5 500 | 501 | @jaxtyp(typecheck) 502 | def bar(self, fill: Float[Array, ""]) -> Float[Array, " {self.size}"]: 503 | return jnp.full((self.size,), fill) 504 | 505 | foo(3, jnp.array(0.0)) 506 | A().bar(jnp.array(0.0)) 507 | 508 | 509 | def test_deferred_symbolic_bad(jaxtyp, typecheck): 510 | @jaxtyp(typecheck) 511 | def foo(dim: int, fill: Float[Array, ""]) -> Float[Array, " {dim-1}"]: 512 | return jnp.full((dim,), fill) 513 | 514 | class A: 515 | size 
= 5 516 | 517 | @jaxtyp(typecheck) 518 | def bar(self, fill: Float[Array, ""]) -> Float[Array, " {self.size}-1"]: 519 | return jnp.full((self.size,), fill) 520 | 521 | with pytest.raises(ReturnError): 522 | foo(3, jnp.array(0.0)) 523 | 524 | with pytest.raises(ReturnError): 525 | A().bar(jnp.array(0.0)) 526 | 527 | 528 | def test_deferred_symbolic_dataclass(typecheck): 529 | @jaxtyped(typechecker=typecheck) 530 | @dc.dataclass 531 | class A: 532 | value: int 533 | array: Float[Array, " {value}"] 534 | 535 | A(3, jnp.zeros(3)) 536 | 537 | with pytest.raises(ParamError): 538 | A(3, jnp.zeros(4)) 539 | 540 | 541 | def _to_set(x) -> set[tuple]: 542 | return { 543 | (xi.index_variadic, xi.dims, xi.array_type, xi.dtypes, xi.dim_str) 544 | if issubclass(xi, AbstractArray) 545 | else xi 546 | for xi in x 547 | } 548 | 549 | 550 | def test_arraylike(typecheck, getkey): 551 | floatlike1 = Float32[ArrayLike, ""] 552 | floatlike2 = Float[ArrayLike, ""] 553 | floatlike3 = Float32[ArrayLike, "4"] 554 | 555 | assert get_origin(floatlike1) is Union 556 | assert get_origin(floatlike2) is Union 557 | assert get_origin(floatlike3) is Union 558 | assert _to_set(get_args(floatlike1)) == _to_set( 559 | [ 560 | Float32[Array, ""], 561 | Float32[np.ndarray, ""], 562 | Float32[np.number, ""], 563 | float, 564 | ] 565 | ) 566 | assert _to_set(get_args(floatlike2)) == _to_set( 567 | [ 568 | Float[Array, ""], 569 | Float[np.ndarray, ""], 570 | Float[np.number, ""], 571 | float, 572 | ] 573 | ) 574 | assert _to_set(get_args(floatlike3)) == _to_set( 575 | [ 576 | Float32[Array, "4"], 577 | Float32[np.ndarray, "4"], 578 | ] 579 | ) 580 | 581 | shaped1 = Shaped[ArrayLike, ""] 582 | shaped2 = Shaped[ArrayLike, "4"] 583 | assert get_origin(shaped1) is Union 584 | assert get_origin(shaped2) is Union 585 | assert _to_set(get_args(shaped1)) == _to_set( 586 | [ 587 | Shaped[Array, ""], 588 | Shaped[np.ndarray, ""], 589 | Shaped[np.bool_, ""], 590 | Shaped[np.number, ""], 591 | bool, 592 | int, 593 | 
float, 594 | complex, 595 | ] 596 | ) 597 | assert _to_set(get_args(shaped2)) == _to_set( 598 | [ 599 | Shaped[Array, "4"], 600 | Shaped[np.ndarray, "4"], 601 | ] 602 | ) 603 | 604 | 605 | def test_ignored_names(): 606 | x = Float[np.ndarray, "foo=4"] 607 | 608 | assert isinstance(np.zeros(4), x) 609 | assert not isinstance(np.zeros(5), x) 610 | assert not isinstance(np.zeros((4, 5)), x) 611 | 612 | y = Float[np.ndarray, "bar qux foo=bar+qux"] 613 | 614 | assert isinstance(np.zeros((2, 3, 5)), y) 615 | assert not isinstance(np.zeros((2, 3, 6)), y) 616 | 617 | z = Float[np.ndarray, "bar #foo=bar"] 618 | 619 | assert isinstance(np.zeros((3, 3)), z) 620 | assert isinstance(np.zeros((3, 1)), z) 621 | assert not isinstance(np.zeros((3, 4)), z) 622 | 623 | # Weird but legal 624 | w = Float[np.ndarray, "bar foo=#bar"] 625 | 626 | assert isinstance(np.zeros((3, 3)), w) 627 | assert isinstance(np.zeros((3, 1)), w) 628 | assert not isinstance(np.zeros((3, 4)), w) 629 | 630 | 631 | def test_symbolic_functions(): 632 | x = Float[np.ndarray, "foo bar min(foo,bar)"] 633 | 634 | assert isinstance(np.zeros((2, 3, 2)), x) 635 | assert isinstance(np.zeros((3, 2, 2)), x) 636 | assert not isinstance(np.zeros((3, 2, 4)), x) 637 | 638 | 639 | @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires Python 3.10") 640 | def test_py310_unions(): 641 | x = np.zeros(3) 642 | y = Shaped[Array | np.ndarray, "_"] 643 | assert isinstance(x, get_args(y)) 644 | 645 | 646 | def test_key(jaxtyp, typecheck): 647 | @jaxtyp(typecheck) 648 | def f(x: PRNGKeyArray): 649 | pass 650 | 651 | f(jr.key(0)) 652 | f(jr.PRNGKey(0)) 653 | 654 | with pytest.raises(ParamError): 655 | f(object()) 656 | with pytest.raises(ParamError): 657 | f(1) 658 | with pytest.raises(ParamError): 659 | f(jnp.array(3)) 660 | with pytest.raises(ParamError): 661 | f(jnp.array(3.0)) 662 | 663 | 664 | def test_key_dtype(jaxtyp, typecheck): 665 | @jaxtyp(typecheck) 666 | def f1(x: Key[Array, ""]): 667 | pass 668 | 669 | 
@jaxtyp(typecheck) 670 | def f2(x: Key[Scalar, ""]): 671 | pass 672 | 673 | for f in (f1, f2): 674 | f(jr.key(0)) 675 | 676 | with pytest.raises(ParamError): 677 | f(jr.PRNGKey(0)) 678 | with pytest.raises(ParamError): 679 | f(object()) 680 | with pytest.raises(ParamError): 681 | f(1) 682 | with pytest.raises(ParamError): 683 | f(jnp.array(3)) 684 | with pytest.raises(ParamError): 685 | f(jnp.array(3.0)) 686 | 687 | 688 | def test_extension(jaxtyp, typecheck, getkey): 689 | X = Shaped[Array, "a b"] 690 | Y = Shaped[X, "c d"] 691 | Z = Shaped[Array, "c d a b"] 692 | assert str(Z) == str(Y) 693 | 694 | X = Float[Array, "a"] 695 | Y = Float[X, "b"] 696 | 697 | @jaxtyp(typecheck) 698 | def f(a: X, b: Y): ... 699 | 700 | a = jr.normal(getkey(), (3, 4)) 701 | b = jr.normal(getkey(), (4,)) 702 | c = jr.normal(getkey(), (3,)) 703 | 704 | f(b, a) 705 | with pytest.raises(ParamError): 706 | f(c, a) 707 | with pytest.raises(ParamError): 708 | f(a, a) 709 | 710 | @typecheck 711 | def g(a: Shaped[PRNGKeyArray, "2"]): ... 712 | 713 | with pytest.raises(ParamError): 714 | g(jr.PRNGKey(0)) 715 | g(jr.split(jr.PRNGKey(0))) 716 | with pytest.raises(ParamError): 717 | g(jr.split(jr.PRNGKey(0), 3)) 718 | 719 | 720 | def test_scalar_variadic_dim(): 721 | assert Float[float, "..."] is float 722 | assert Float[float, "#*shape"] is float 723 | 724 | # This one is a bit weird -- it should really also assert that shape==(), but we 725 | # don't implement that. 
726 | assert Float[float, "*shape"] is float 727 | 728 | 729 | def test_scalar_dtype_mismatch(): 730 | with pytest.raises(ValueError): 731 | Float[bool, "..."] 732 | 733 | 734 | def test_custom_array(jaxtyp, typecheck): 735 | class MyArray1: 736 | @property 737 | def dtype(self): 738 | return "foo" 739 | 740 | @property 741 | def shape(self): 742 | return (3,) 743 | 744 | class MyArray2: 745 | @property 746 | def dtype(self): 747 | return "bar" 748 | 749 | @property 750 | def shape(self): 751 | return (3,) 752 | 753 | class MyArray3: 754 | @property 755 | def dtype(self): 756 | return "foo" 757 | 758 | @property 759 | def shape(self): 760 | return (4,) 761 | 762 | class FooDtype(AbstractDtype): 763 | dtypes = ["foo"] 764 | 765 | @jaxtyp(typecheck) 766 | def f(x: FooDtype[MyArray1, "3"]): 767 | pass 768 | 769 | f(MyArray1()) 770 | with pytest.raises(ParamError): 771 | f(MyArray2()) 772 | with pytest.raises(ParamError): 773 | f(MyArray3()) 774 | 775 | @jaxtyp(typecheck) 776 | def g(x: FooDtype[MyArray1, "3"], y: FooDtype[MyArray1, "4"]): 777 | pass 778 | 779 | with pytest.raises(ParamError): 780 | g(MyArray1(), MyArray1()) 781 | 782 | @jaxtyp(typecheck) 783 | def h(x: FooDtype[MyArray1, "3"], y: FooDtype[MyArray3, "4"]): 784 | pass 785 | 786 | with pytest.raises(ParamError): 787 | g(MyArray1(), MyArray3()) 788 | 789 | 790 | @pytest.mark.parametrize( 791 | "array_type", [Any, TypeVar("T"), TypeVar("T", bound=ArrayLike)] 792 | ) 793 | def test_any(array_type, jaxtyp, typecheck): 794 | class DuckArray1: 795 | @property 796 | def shape(self): 797 | return 3, 4 798 | 799 | @property 800 | def dtype(self): 801 | return np.array([], dtype=np.float32).dtype 802 | 803 | class DuckArray2: 804 | @property 805 | def shape(self): 806 | return 3, 4, 5 807 | 808 | @property 809 | def dtype(self): 810 | return np.array([], dtype=np.float32).dtype 811 | 812 | class DuckArray3: 813 | @property 814 | def shape(self): 815 | return 3, 4 816 | 817 | @property 818 | def dtype(self): 819 | 
return np.array([], dtype=np.int32).dtype 820 | 821 | @jaxtyp(typecheck) 822 | def f(x: Float[array_type, "foo bar"]): 823 | del x 824 | 825 | f(np.arange(12.0).reshape(3, 4)) 826 | f(jnp.arange(12.0).reshape(3, 4)) 827 | if isinstance(array_type, TypeVar) and array_type.__bound__ is ArrayLike: 828 | with pytest.raises(ParamError): 829 | f(DuckArray1()) 830 | else: 831 | f(DuckArray1()) 832 | 833 | # Wrong shape 834 | with pytest.raises(ParamError): 835 | f(np.arange(12.0).reshape(3, 2, 2)) 836 | with pytest.raises(ParamError): 837 | f(jnp.arange(12.0).reshape(3, 2, 2)) 838 | with pytest.raises(ParamError): 839 | f(DuckArray2()) 840 | 841 | # Wrong dtype 842 | with pytest.raises(ParamError): 843 | f(np.arange(12).reshape(3, 4)) 844 | with pytest.raises(ParamError): 845 | f(jnp.arange(12).reshape(3, 4)) 846 | with pytest.raises(ParamError): 847 | f(DuckArray3()) 848 | 849 | # Not an array 850 | with pytest.raises(ParamError): 851 | f(1) 852 | 853 | 854 | def test_non_instantiation(): 855 | with pytest.raises(RuntimeError, match="cannot be instantiated"): 856 | Float[Array, ""]() # pyright: ignore[reportCallIssue] 857 | -------------------------------------------------------------------------------- /test/test_decorator.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import dataclasses 3 | import sys 4 | from typing import no_type_check 5 | 6 | import equinox as eqx 7 | import jax.numpy as jnp 8 | import jax.random as jr 9 | import pytest 10 | 11 | from jaxtyping import Array, Float, jaxtyped, print_bindings 12 | 13 | from .helpers import assert_no_garbage, ParamError, ReturnError 14 | 15 | 16 | class M(metaclass=abc.ABCMeta): 17 | @jaxtyped(typechecker=None) 18 | def f(self): ... 
19 | 20 | @jaxtyped(typechecker=None) 21 | @classmethod 22 | def g1(cls): 23 | return 3 24 | 25 | @classmethod 26 | @jaxtyped(typechecker=None) 27 | def g2(cls): 28 | return 4 29 | 30 | @jaxtyped(typechecker=None) 31 | @staticmethod 32 | def h1(): 33 | return 3 34 | 35 | @staticmethod 36 | @jaxtyped(typechecker=None) 37 | def h2(): 38 | return 4 39 | 40 | @jaxtyped(typechecker=None) 41 | @abc.abstractmethod 42 | def i1(self): ... 43 | 44 | @abc.abstractmethod 45 | @jaxtyped(typechecker=None) 46 | def i2(self): ... 47 | 48 | 49 | class N: 50 | @jaxtyped(typechecker=None) 51 | @property 52 | def j1(self): 53 | return 3 54 | 55 | @property 56 | @jaxtyped(typechecker=None) 57 | def j2(self): 58 | return 4 59 | 60 | 61 | def test_identity(): 62 | assert M.f is M.f 63 | 64 | 65 | def test_classmethod(): 66 | assert M.g1() == 3 67 | assert M.g2() == 4 68 | 69 | 70 | def test_staticmethod(): 71 | assert M.h1() == 3 72 | assert M.h2() == 4 73 | 74 | 75 | # Check that the @jaxtyped decorator doesn't blat the __isabstractmethod__ of 76 | # @abstractmethod 77 | def test_abstractmethod(): 78 | assert M.i1.__isabstractmethod__ 79 | assert M.i2.__isabstractmethod__ 80 | 81 | 82 | def test_property(): 83 | assert N().j1 == 3 84 | assert N().j2 == 4 85 | 86 | 87 | def test_context(getkey): 88 | a = jr.normal(getkey(), (3, 4)) 89 | b = jr.normal(getkey(), (5,)) 90 | with jaxtyped("context"): 91 | assert isinstance(a, Float[Array, "foo bar"]) 92 | assert not isinstance(b, Float[Array, "foo"]) 93 | assert isinstance(a, Float[Array, "foo bar"]) 94 | assert isinstance(b, Float[Array, "foo"]) 95 | 96 | 97 | def test_varargs(jaxtyp, typecheck): 98 | @jaxtyp(typecheck) 99 | def f(*args) -> None: 100 | pass 101 | 102 | f(1, 2) 103 | 104 | 105 | def test_varkwargs(jaxtyp, typecheck): 106 | @jaxtyp(typecheck) 107 | def f(**kwargs) -> None: 108 | pass 109 | 110 | f(a=1, b=2) 111 | 112 | 113 | def test_defaults(jaxtyp, typecheck): 114 | @jaxtyp(typecheck) 115 | def f(x: int, y=1) -> None: 116 | 
pass 117 | 118 | f(1) 119 | 120 | 121 | def test_default_bindings(getkey, jaxtyp, typecheck): 122 | @jaxtyp(typecheck) 123 | def f(x: int, y: int = 1) -> Float[Array, "x {y}"]: 124 | return jr.normal(getkey(), (x, y)) 125 | 126 | f(1) 127 | f(1, 1) 128 | f(1, 0) 129 | f(1, 5) 130 | 131 | 132 | class _GlobalFoo: 133 | pass 134 | 135 | 136 | def test_global_stringified_annotation(jaxtyp, typecheck): 137 | @jaxtyp(typecheck) 138 | def f(x: "_GlobalFoo") -> "_GlobalFoo": 139 | return x 140 | 141 | f(_GlobalFoo()) 142 | 143 | @jaxtyp(typecheck) 144 | def g(x: int) -> "_GlobalFoo": 145 | return x 146 | 147 | @jaxtyp(typecheck) 148 | def h(x: "_GlobalFoo") -> int: 149 | return x 150 | 151 | with pytest.raises(ReturnError): 152 | g(1) 153 | 154 | with pytest.raises(ParamError): 155 | h(1) 156 | 157 | 158 | # This test does not use `jaxtyp(typecheck)` because typeguard does some evil stack 159 | # frame introspection to try and grab local variables. 160 | def test_local_stringified_annotation(typecheck): 161 | class LocalFoo: 162 | pass 163 | 164 | @jaxtyped(typechecker=typecheck) 165 | def f(x: "LocalFoo") -> "LocalFoo": 166 | return x 167 | 168 | f(LocalFoo()) 169 | 170 | with pytest.warns(match="As of jaxtyping version 0.2.24"): 171 | 172 | @jaxtyped 173 | @typecheck 174 | def g(x: "LocalFoo") -> "LocalFoo": 175 | return x 176 | 177 | g(LocalFoo()) 178 | 179 | # We don't check that errors are raised if it goes wrong, since we can't usually 180 | # resolve local type annotations at runtime. Best we can hope for is not to raise 181 | # a spurious error about not being able to find the type. 182 | 183 | 184 | def test_print_bindings(typecheck, capfd): 185 | @jaxtyped(typechecker=typecheck) 186 | def f(x: Float[Array, "foo bar"]): 187 | print_bindings() 188 | 189 | capfd.readouterr() 190 | f(jnp.zeros((3, 4))) 191 | text, _ = capfd.readouterr() 192 | assert text == ( 193 | "The current values for each jaxtyping axis annotation are as follows." 
194 | "\nfoo=3\nbar=4\n" 195 | ) 196 | 197 | 198 | def test_no_type_check(typecheck): 199 | @jaxtyped(typechecker=typecheck) 200 | @no_type_check 201 | def f(x: Float[Array, "foo bar"]): 202 | pass 203 | 204 | @no_type_check 205 | @jaxtyped(typechecker=typecheck) 206 | def g(x: Float[Array, "foo bar"]): 207 | pass 208 | 209 | f("not an array") 210 | g("not an array") 211 | 212 | 213 | def test_no_garbage(typecheck): 214 | with assert_no_garbage(): 215 | 216 | @jaxtyped(typechecker=typecheck) 217 | @dataclasses.dataclass 218 | class _Obj: 219 | x: int 220 | 221 | _Obj(x=5) 222 | 223 | 224 | def test_no_garbage_identity_typecheck(): 225 | with assert_no_garbage(): 226 | 227 | @jaxtyped(typechecker=lambda x: x) 228 | @dataclasses.dataclass 229 | class _Obj: 230 | x: int 231 | 232 | _Obj(x=5) 233 | 234 | 235 | def test_no_garbage_frame_capture_typecheck(): 236 | with assert_no_garbage(): 237 | # Some typechecker implementations (e.g., typeguard 2.13.3) capture the calling 238 | # frame's f_locals. This test checks that the calling frames in jaxtyping are 239 | # sufficiently isolated to avoid introducing reference cycles when a 240 | # typechecker does this. 241 | def frame_locals_capture(fn): 242 | locals = sys._getframe(1).f_locals 243 | 244 | def wrapper(*args, **kwargs): 245 | # Required to ensure wrapper holds a reference to f_locals, which is 246 | # the scenario under test. 
247 | _ = locals 248 | return fn(*args, **kwargs) 249 | 250 | return wrapper 251 | 252 | @jaxtyped(typechecker=frame_locals_capture) 253 | @dataclasses.dataclass 254 | class _Obj: 255 | x: int 256 | 257 | _Obj(x=5) 258 | 259 | 260 | def test_equinox_converter(typecheck): 261 | def _typed_str(x: int) -> str: 262 | return str(x) 263 | 264 | @jaxtyped(typechecker=typecheck) 265 | class X(eqx.Module): 266 | x: str = eqx.field(converter=_typed_str) 267 | 268 | X(1) 269 | with pytest.raises(ParamError): 270 | X("1") 271 | 272 | 273 | def test_mlx(jaxtyp, typecheck): 274 | import mlx.core as mx 275 | import numpy as np 276 | 277 | @jaxtyp(typecheck) 278 | def hello(x: Float[mx.array, "8 16"]): 279 | pass 280 | 281 | hello(mx.zeros((8, 16), dtype=mx.float32)) 282 | 283 | with pytest.raises(ParamError): 284 | hello(mx.zeros((8, 14), dtype=mx.float32)) 285 | 286 | with pytest.raises(ParamError): 287 | hello(np.zeros((8, 16), dtype=np.float32)) 288 | 289 | with pytest.raises(ParamError): 290 | hello(mx.zeros((8, 16), dtype=mx.int32)) 291 | 292 | 293 | # In particular the below scenario occurs during the import hook when we have an 294 | # explicit `__init__`. 
def test_no_rewrapping_of_dataclass_init(typecheck):
    """A dataclass `__init__` that is already jaxtyped is wrapped only once.

    Both the class and its explicit `__init__` are decorated with `jaxtyped`.
    Unwrapping `Foo.__init__` once must give an object with no further
    `__wrapped__` attribute.
    """

    @jaxtyped(typechecker=typecheck)
    @dataclasses.dataclass
    class Foo:
        x: int

        @jaxtyped(typechecker=typecheck)
        def __init__(self, x: int):
            self.x = x

    wrapped = Foo.__init__.__wrapped__
    # A second layer of wrapping would expose another `__wrapped__` here.
    with pytest.raises(AttributeError):
        wrapped.__wrapped__


def test_stringified_multiple_varaidic(typecheck):
    """Calling a function whose stringified return annotation has two variadic
    axes (`*foo *bar`) does not raise."""

    @jaxtyped(typechecker=typecheck)
    def foo() -> 'Float[Array, "*foo *bar"]':
        return jnp.arange(3)

    foo()


# --------------------------------------------------------------------------------
# /test/test_generators.py:
# --------------------------------------------------------------------------------
from collections.abc import AsyncIterator, Iterator

import jax.numpy as jnp
import pytest

from jaxtyping import Array, Float, Shaped

from .helpers import ParamError


# torch is optional: tests that need it skip themselves when it is missing.
try:
    import torch
except ImportError:
    torch = None


def test_generators_simple(jaxtyp, typecheck):
    """A typechecked generator can be called with differently shaped arrays
    from within a single typechecked caller."""

    @jaxtyp(typecheck)
    def gen(x: Float[Array, "*"]) -> Iterator[Float[Array, "*"]]:
        yield x

    @jaxtyp(typecheck)
    def foo() -> None:
        next(gen(jnp.zeros(2)))
        next(gen(jnp.zeros((3, 4))))

    foo()


def test_generators_return_no_annotations(jaxtyp, typecheck):
    """Same scenario as `test_generators_simple`, but neither the generator
    nor its caller has a return annotation."""

    @jaxtyp(typecheck)
    def gen(x: Float[Array, "*"]):
        yield x

    @jaxtyp(typecheck)
    def foo():
        next(gen(jnp.zeros(2)))
        next(gen(jnp.zeros((3, 4))))

    foo()


@pytest.mark.asyncio
async def test_async_generators_simple(jaxtyp, typecheck):
    """Async-generator counterpart of `test_generators_simple`."""

    @jaxtyp(typecheck)
    async def gen(x: Float[Array, "*"]) -> AsyncIterator[Float[Array, "*"]]:
        yield x

    @jaxtyp(typecheck)
    async def foo():
        async for _ in gen(jnp.zeros(2)):
            pass
        async for _ in gen(jnp.zeros((3, 4))):
            pass

    await foo()


def test_generators_dont_modify_same_annotations(jaxtyp, typecheck):
    """A generator and a plain function using the same `Float[Array, "1"]`
    annotation both reject a length-2 array with `ParamError`."""

    @jaxtyp(typecheck)
    def g(x: Float[Array, "1"]) -> Iterator[Float[Array, "1"]]:
        yield x

    @jaxtyp(typecheck)
    def m(x: Float[Array, "1"]) -> Float[Array, "1"]:
        return x

    with pytest.raises(ParamError):
        next(g(jnp.zeros(2)))
    with pytest.raises(ParamError):
        m(jnp.zeros(2))


def test_generators_original_issue(jaxtyp, typecheck):
    """Torch variant of the generator scenario; skipped when torch is absent."""
    # Effectively the same as https://github.com/patrick-kidger/jaxtyping/issues/91
    if torch is None:
        pytest.skip("torch is not available")

    @jaxtyp(typecheck)
    def g(x: Shaped[torch.Tensor, "*"]) -> Iterator[Shaped[torch.Tensor, "*"]]:
        yield x

    @jaxtyp(typecheck)
    def f() -> None:
        next(g(torch.zeros(1)))
        next(g(torch.zeros(2)))

    f()


# --------------------------------------------------------------------------------
# /test/test_import_hook.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022 Google LLC
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import importlib
import importlib.metadata
import pathlib
import shutil
import sys
import tempfile

import pytest

import jaxtyping


_here = pathlib.Path(__file__).parent


# Refuse to collect this module unless typeguard major version 2 is installed.
try:
    typeguard_version = importlib.metadata.version("typeguard")
except Exception as e:
    raise ImportError("Could not find typeguard version") from e
else:
    try:
        major, _, _ = typeguard_version.split(".")
        major = int(major)
    except Exception as e:
        raise ImportError(
            f"Unexpected typeguard version {typeguard_version}; not formatted as "
            "`major.minor.patch`"
        ) from e
    if major != 2:
        raise ImportError(
            "jaxtyping's tests required typeguard version 2. (Versions 3 and 4 are both "
            "known to have bugs.)"
        )


# Counter stored on the `jaxtyping` module; `_test_import_hook` expects it to
# grow by one per hooked import (presumably bumped by import_hook_tester.py,
# which is not visible here).
assert not hasattr(jaxtyping, "_test_import_hook_counter")
jaxtyping._test_import_hook_counter = 0


@pytest.fixture(scope="module")
def importhook_tempdir():
    """Yield a temporary directory that is importable for this module's tests.

    The directory is appended to `sys.path` and seeded with a copy of
    `helpers.py`. On teardown the `sys.path` entry is removed again, so the
    interpreter is not left with a path to a directory that has been deleted.

    Yields:
        pathlib.Path: the temporary directory.
    """
    with tempfile.TemporaryDirectory() as tmp:
        sys.path.append(tmp)
        try:
            tmp_path = pathlib.Path(tmp)
            shutil.copyfile(_here / "helpers.py", tmp_path / "helpers.py")
            yield tmp_path
        finally:
            # Guarded so teardown cannot fail if the entry is already gone.
            if tmp in sys.path:
                sys.path.remove(tmp)


def _test_import_hook(importhook_tempdir, typechecker):
    """Import a fresh copy of `import_hook_tester.py` under the import hook.

    The tester is copied under a name suffixed with the current counter value,
    so every call imports a module that has not been imported before.

    Args:
        importhook_tempdir: importable directory from the fixture of that name.
        typechecker: typechecker spec passed to `jaxtyping.install_import_hook`.

    Raises:
        AssertionError: if the counter did not increase by exactly one.
    """
    counter = jaxtyping._test_import_hook_counter
    stem = f"import_hook_tester{counter}"
    shutil.copyfile(_here / "import_hook_tester.py", importhook_tempdir / f"{stem}.py")

    # The file was created after the finders may have cached the directory.
    importlib.invalidate_caches()
    with jaxtyping.install_import_hook(stem, typechecker):
        importlib.import_module(stem)
    assert counter + 1 == jaxtyping._test_import_hook_counter


# Tests start below...
80 | 81 | 82 | def test_import_hook_typeguard(importhook_tempdir, typeguard_or_skip): 83 | _test_import_hook(importhook_tempdir, "typeguard.typechecked") 84 | 85 | 86 | def test_import_hook_beartype(importhook_tempdir, beartype_or_skip): 87 | _test_import_hook(importhook_tempdir, "beartype.beartype") 88 | 89 | 90 | def test_import_hook_beartype_full(importhook_tempdir, beartype_or_skip): 91 | bearchecker = "beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))" # noqa: E501 92 | _test_import_hook(importhook_tempdir, bearchecker) 93 | 94 | 95 | def test_import_hook_typeguard_old(importhook_tempdir, typeguard_or_skip): 96 | _test_import_hook(importhook_tempdir, ("typeguard", "typechecked")) 97 | 98 | 99 | def test_import_hook_beartype_old(importhook_tempdir, beartype_or_skip): 100 | _test_import_hook(importhook_tempdir, ("beartype", "beartype")) 101 | 102 | 103 | def test_import_hook_broken_checker(importhook_tempdir): 104 | with pytest.raises(AttributeError): 105 | _test_import_hook(importhook_tempdir, "jaxtyping.does_not_exist") 106 | 107 | 108 | def test_import_hook_transitive(importhook_tempdir, typeguard_or_skip): 109 | counter = jaxtyping._test_import_hook_counter 110 | transitive_name = "jaxtyping_transitive_test" 111 | transitive_dir = importhook_tempdir / transitive_name 112 | transitive_dir.mkdir() 113 | shutil.copyfile(_here / "import_hook_tester.py", transitive_dir / "tester.py") 114 | with open(transitive_dir / "__init__.py", "w") as f: 115 | f.write("from . 
import tester") 116 | f.flush() 117 | 118 | importlib.invalidate_caches() 119 | with jaxtyping.install_import_hook(transitive_name, "typeguard.typechecked"): 120 | importlib.import_module(transitive_name) 121 | assert counter + 1 == jaxtyping._test_import_hook_counter 122 | -------------------------------------------------------------------------------- /test/test_ipython_extension.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from IPython.testing.globalipapp import start_ipython 3 | 4 | from .helpers import ParamError 5 | 6 | 7 | @pytest.fixture(scope="session") 8 | def session_ip(): 9 | yield start_ipython() 10 | 11 | 12 | @pytest.fixture(scope="function") 13 | def ip(session_ip): 14 | session_ip.run_cell(raw_cell="import jaxtyping") 15 | session_ip.run_line_magic(magic_name="load_ext", line="jaxtyping") 16 | session_ip.run_line_magic( 17 | magic_name="jaxtyping.typechecker", line="typeguard.typechecked" 18 | ) 19 | yield session_ip 20 | 21 | 22 | def test_that_ipython_works(ip): 23 | ip.run_cell(raw_cell="x = 1").raise_error() 24 | assert ip.user_global_ns["x"] == 1 25 | 26 | 27 | def test_function_beartype(ip): 28 | ip.run_cell( 29 | raw_cell=""" 30 | def f(x: int): 31 | pass 32 | """ 33 | ).raise_error() 34 | ip.run_cell(raw_cell="f(1)").raise_error() 35 | 36 | with pytest.raises(ParamError): 37 | ip.run_cell(raw_cell='f("x")').raise_error() 38 | 39 | 40 | def test_function_none(ip): 41 | ip.run_cell( 42 | raw_cell=""" 43 | def f(a,b,c): 44 | pass 45 | """ 46 | ).raise_error() 47 | ip.run_cell(raw_cell='f(1,2,"k")').raise_error() 48 | 49 | 50 | def test_function_jaxtyped(ip): 51 | ip.run_cell( 52 | raw_cell=""" 53 | from jaxtyping import Float, Array, Int 54 | import jax 55 | 56 | def g(x: Float[Array, "1"]): 57 | return x + 1 58 | 59 | """ 60 | ).raise_error() 61 | 62 | ip.run_cell(raw_cell="g(jax.numpy.array([1.0]))").raise_error() 63 | 64 | with pytest.raises(ParamError): 65 | 
ip.run_cell(raw_cell="g(jax.numpy.array(1.0))").raise_error() 66 | 67 | with pytest.raises(ParamError): 68 | ip.run_cell(raw_cell="g(jax.numpy.array([1]))").raise_error() 69 | 70 | with pytest.raises(ParamError): 71 | ip.run_cell(raw_cell="g(jax.numpy.array([2, 3]))").raise_error() 72 | 73 | with pytest.raises(ParamError): 74 | ip.run_cell(raw_cell='g("string")').raise_error() 75 | 76 | 77 | def test_function_jaxtyped_and_jitted(ip): 78 | ip.run_cell( 79 | raw_cell=""" 80 | from jaxtyping import Float, Array, Int 81 | import jax 82 | 83 | @jax.jit 84 | def g(x: Float[Array, "1"]): 85 | return x + 1 86 | 87 | """ 88 | ).raise_error() 89 | 90 | ip.run_cell(raw_cell="g(jax.numpy.array([1.0]))").raise_error() 91 | 92 | with pytest.raises(ParamError): 93 | ip.run_cell(raw_cell="g(jax.numpy.array(1.0))").raise_error() 94 | 95 | with pytest.raises(ParamError): 96 | ip.run_cell(raw_cell="g(jax.numpy.array([1]))").raise_error() 97 | 98 | with pytest.raises(ParamError): 99 | ip.run_cell(raw_cell="g(jax.numpy.array([2, 3]))").raise_error() 100 | 101 | with pytest.raises(ParamError): 102 | ip.run_cell(raw_cell='g("string")').raise_error() 103 | 104 | 105 | def test_class_jaxtyped(ip): 106 | ip.run_cell( 107 | raw_cell=""" 108 | from jaxtyping import Float, Array, Int 109 | import equinox as eqx 110 | import jax 111 | 112 | class A(eqx.Module): 113 | x: Float[Array, "2"] 114 | 115 | def do_something(self, y: Int[Array, ""]): 116 | return self.x + y 117 | """ 118 | ).raise_error() 119 | 120 | ip.run_cell(raw_cell="a = A(jax.numpy.array([1.0, 2.0]))").raise_error() 121 | ip.run_cell(raw_cell="a.do_something(jax.numpy.array(2))").raise_error() 122 | 123 | with pytest.raises(ParamError): 124 | ip.run_cell(raw_cell="A(jax.numpy.array([1.0]))").raise_error() 125 | 126 | with pytest.raises(ParamError): 127 | ip.run_cell( 128 | raw_cell="a.do_something(jax.numpy.array([2.0, 3.0]))" 129 | ).raise_error() 130 | 131 | 132 | def test_class_not_dataclass(ip): 133 | ip.run_cell( 134 | 
raw_cell=""" 135 | from jaxtyping import Float, Array, Int 136 | import equinox as eqx 137 | import jax 138 | 139 | class A: 140 | def __init__(self, x): 141 | self.x = x 142 | 143 | def do_something(self, y): 144 | return x + y 145 | """ 146 | ).raise_error() 147 | 148 | ip.run_cell(raw_cell="a = A(jax.numpy.array([1.0, 2.0]))").raise_error() 149 | ip.run_cell(raw_cell="a.do_something(jax.numpy.array(2))").raise_error() 150 | ip.run_cell(raw_cell="A(jax.numpy.array([1.0]))").raise_error() 151 | ip.run_cell(raw_cell="a.do_something(jax.numpy.array([2.0, 3.0]))").raise_error() 152 | -------------------------------------------------------------------------------- /test/test_messages.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import equinox as eqx 4 | import jax.numpy as jnp 5 | import pytest 6 | 7 | from jaxtyping import Array, Float, jaxtyped, PyTree, TypeCheckError 8 | 9 | 10 | def test_arg_localisation(typecheck): 11 | @jaxtyped(typechecker=typecheck) 12 | def f(x: str, y: str, z: int): 13 | pass 14 | 15 | matches = [ 16 | "Type-check error whilst checking the parameters of .*.f", 17 | "The problem arose whilst typechecking parameter 'z'.", 18 | "Called with parameters: {'x': 'hi', 'y': 'bye', 'z': 'not-an-int'}", 19 | r"Parameter annotations: \(x: str, y: str, z: int\).", 20 | ] 21 | for match in matches: 22 | with pytest.raises(TypeCheckError, match=match): 23 | f("hi", "bye", "not-an-int") 24 | 25 | @jaxtyped(typechecker=typecheck) 26 | def g(x: Float[Array, "a b"], y: Float[Array, "b c"]): 27 | pass 28 | 29 | x = jnp.zeros((2, 3)) 30 | y = jnp.zeros((4, 3)) 31 | matches = [ 32 | "Type-check error whilst checking the parameters of .*..g", 33 | "The problem arose whilst typechecking parameter 'y'.", 34 | r"Called with parameters: {'x': f32\[2,3\], 'y': f32\[4,3\]}", 35 | ( 36 | r"Parameter annotations: \(x: Float\[Array, 'a b'\], y: " 37 | r"Float\[Array, 'b c'\]\)." 
38 | ), 39 | "The current values for each jaxtyping axis annotation are as follows.", 40 | "a=2", 41 | "b=3", 42 | ] 43 | for match in matches: 44 | with pytest.raises(TypeCheckError, match=match): 45 | g(x, y=y) 46 | 47 | 48 | def test_return(typecheck): 49 | @jaxtyped(typechecker=typecheck) 50 | def f(x: PyTree[Any, " T"], y: PyTree[Any, " S"]) -> PyTree[Any, "T S"]: 51 | return "foo" 52 | 53 | x = (1, 2) 54 | y = {"a": 1} 55 | matches = [ 56 | "Type-check error whilst checking the return value of .*..f", 57 | r"Called with parameters: {'x': \(1, 2\), 'y': {'a': 1}}", 58 | "Actual value: 'foo'", 59 | r"Expected type: PyTree\[Any, 'T S'\].", 60 | ( 61 | "The current values for each jaxtyping PyTree structure annotation are as " 62 | "follows." 63 | ), 64 | r"T=PyTreeDef\(\(\*, \*\)\)", 65 | r"S=PyTreeDef\({'a': \*}\)", 66 | ] 67 | for match in matches: 68 | with pytest.raises(TypeCheckError, match=match): 69 | f(x, y=y) 70 | 71 | 72 | def test_dataclass_init(typecheck): 73 | @jaxtyped(typechecker=typecheck) 74 | class M(eqx.Module): 75 | x: Float[Array, " *foo"] 76 | y: PyTree[Any, " T"] 77 | z: int 78 | 79 | x = jnp.zeros((2, 3)) 80 | y = (1, (3, 4)) 81 | z = "not-an-int" 82 | 83 | matches = [ 84 | "Type-check error whilst checking the parameters of .*..M", 85 | "The problem arose whilst typechecking parameter 'z'.", 86 | ( 87 | r"Called with parameters: {'self': M\(\.\.\.\), 'x': f32\[2,3\], " 88 | r"'y': \(1, \(3, 4\)\), 'z': 'not-an-int'}" 89 | ), 90 | ( 91 | r"Parameter annotations: \(self, x: Float\[Array, '\*foo'\], " 92 | r"y: PyTree\[Any, 'T'\], z: int\)." 93 | ), 94 | "The current values for each jaxtyping axis annotation are as follows.", 95 | r"foo=\(2, 3\)", 96 | ( 97 | "The current values for each jaxtyping PyTree structure annotation are as " 98 | "follows." 
99 | ), 100 | r"T=PyTreeDef\(\(\*, \(\*, \*\)\)\)", 101 | ] 102 | for match in matches: 103 | with pytest.raises(TypeCheckError, match=match): 104 | M(x, y, z) 105 | -------------------------------------------------------------------------------- /test/test_no_jax_dependency.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | import unittest 4 | 5 | 6 | _py_path = sys.executable 7 | 8 | 9 | @unittest.skipIf(not _py_path, "test requires sys.executable") 10 | def test_no_jax_dependency(): 11 | result = subprocess.run( 12 | f"{_py_path} -c " 13 | "'import jaxtyping; import sys; sys.exit(\"jax\" in sys.modules)'", 14 | shell=True, 15 | ) 16 | assert result.returncode == 0 17 | 18 | 19 | # Meta-test: test that the above test will work. (i.e. that I haven't messed up using 20 | # subprocess.) 21 | @unittest.skipIf(not _py_path, "test requires sys.executable") 22 | def test_meta(): 23 | result = subprocess.run( 24 | f"{_py_path} -c 'import jaxtyping; import jax; import sys; " 25 | 'sys.exit("jax" in sys.modules)\'', 26 | shell=True, 27 | ) 28 | assert result.returncode == 1 29 | -------------------------------------------------------------------------------- /test/test_pytree.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Google LLC 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the 
Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | from collections.abc import Callable 21 | from typing import NamedTuple, Tuple, Union 22 | 23 | import equinox as eqx 24 | import jax 25 | import jax.numpy as jnp 26 | import jax.random as jr 27 | import pytest 28 | import wadler_lindig as wl 29 | 30 | import jaxtyping 31 | from jaxtyping import AnnotationError, Array, Float, PyTree 32 | 33 | from .helpers import make_mlp, ParamError 34 | 35 | 36 | def test_direct(typecheck): 37 | @typecheck 38 | def g(x: PyTree): 39 | pass 40 | 41 | g(1) 42 | g({"a": jnp.array(1), "b": [object()]}) 43 | g(object()) 44 | 45 | @typecheck 46 | def h() -> PyTree: 47 | return object() 48 | 49 | h() 50 | 51 | 52 | def test_subscript(getkey, typecheck): 53 | @typecheck 54 | def g(x: PyTree[int]): 55 | pass 56 | 57 | g(1) 58 | g([1, 2, {"a": 3}]) 59 | g(jax.tree.map(lambda _: 1, make_mlp(getkey()))) 60 | 61 | with pytest.raises(ParamError): 62 | g(object()) 63 | with pytest.raises(ParamError): 64 | g("hi") 65 | with pytest.raises(ParamError): 66 | g([1, 2, {"a": 3}, "bye"]) 67 | 68 | 69 | def test_leaf_pytrees(getkey, typecheck): 70 | @typecheck 71 | def g(x: PyTree[eqx.nn.MLP]): 72 | pass 73 | 74 | g(make_mlp(getkey())) 75 | g([make_mlp(getkey()), make_mlp(getkey()), {"a": make_mlp(getkey())}]) 76 | 77 | with pytest.raises(ParamError): 78 | g([1, 2]) 79 | with pytest.raises(ParamError): 80 | g([1, 2, make_mlp()]) 81 | 82 | 83 | def test_nested_pytrees(getkey, typecheck): 84 | # PyTree[...] 
is logically equivalent to PyTree[PyTree[...]] 85 | @typecheck 86 | def g(x: PyTree[PyTree[eqx.nn.MLP]]): 87 | pass 88 | 89 | g(make_mlp(getkey())) 90 | g([make_mlp(getkey()), make_mlp(getkey()), {"a": make_mlp(getkey())}]) 91 | 92 | with pytest.raises(ParamError): 93 | g([1, 2]) 94 | with pytest.raises(ParamError): 95 | g([1, 2, make_mlp()]) 96 | 97 | 98 | def test_pytree_array(jaxtyp, typecheck): 99 | @jaxtyp(typecheck) 100 | def g(x: PyTree[Float[jnp.ndarray, "..."]]): 101 | pass 102 | 103 | g(jnp.array(1.0)) 104 | g([jnp.array(1.0), jnp.array(1.0), {"a": jnp.array(1.0)}]) 105 | with pytest.raises(ParamError): 106 | g(jnp.array(1)) 107 | with pytest.raises(ParamError): 108 | g(1.0) 109 | 110 | 111 | def test_pytree_shaped_array(jaxtyp, typecheck, getkey): 112 | @jaxtyp(typecheck) 113 | def g(x: PyTree[Float[jnp.ndarray, "b c"]]): 114 | pass 115 | 116 | g(jnp.array([[1.0]])) 117 | g([jr.normal(getkey(), (1, 1)), jr.normal(getkey(), (1, 1))]) 118 | g([jr.normal(getkey(), (3, 4)), jr.normal(getkey(), (3, 4))]) 119 | with pytest.raises(ParamError): 120 | g(jnp.array(1.0)) 121 | with pytest.raises(ParamError): 122 | g(jnp.array([[1]])) 123 | with pytest.raises(ParamError): 124 | g([jr.normal(getkey(), (3, 4)), jr.normal(getkey(), (1, 1))]) 125 | with pytest.raises(ParamError): 126 | g([jr.normal(getkey(), (1, 1)), jr.normal(getkey(), (3, 4))]) 127 | 128 | 129 | def test_pytree_union(typecheck): 130 | @typecheck 131 | def g(x: PyTree[Union[int, str]]): 132 | pass 133 | 134 | g([1]) 135 | g(["hi"]) 136 | g([1, "hi"]) 137 | with pytest.raises(ParamError): 138 | g(object()) 139 | 140 | 141 | def test_pytree_tuple(typecheck): 142 | @typecheck 143 | def g(x: PyTree[Tuple[int, int]]): 144 | pass 145 | 146 | g((1, 1)) 147 | g([(1, 1)]) 148 | g([(1, 1), {"a": (1, 1)}]) 149 | with pytest.raises(ParamError): 150 | g(object()) 151 | with pytest.raises(ParamError): 152 | g(1) 153 | with pytest.raises(ParamError): 154 | g((1, 2, 3)) 155 | with pytest.raises(ParamError): 156 | 
g([1, 1]) 157 | with pytest.raises(ParamError): 158 | g([(1, 1), "hi"]) 159 | 160 | 161 | def test_pytree_namedtuple(typecheck): 162 | class CustomNamedTuple(NamedTuple): 163 | x: Float[jnp.ndarray, "a b"] 164 | y: Float[jnp.ndarray, "b c"] 165 | 166 | class OtherCustomNamedTuple(NamedTuple): 167 | x: Float[jnp.ndarray, "a b"] 168 | y: Float[jnp.ndarray, "b c"] 169 | 170 | @typecheck 171 | def g(x: PyTree[CustomNamedTuple]): ... 172 | 173 | g( 174 | CustomNamedTuple( 175 | x=jax.random.normal(jax.random.PRNGKey(42), (3, 2)), 176 | y=jax.random.normal(jax.random.PRNGKey(420), (2, 5)), 177 | ) 178 | ) 179 | with pytest.raises(ParamError): 180 | g(object()) 181 | with pytest.raises(ParamError): 182 | g( 183 | OtherCustomNamedTuple( 184 | x=jax.random.normal(jax.random.PRNGKey(42), (3, 2)), 185 | y=jax.random.normal(jax.random.PRNGKey(420), (2, 5)), 186 | ) 187 | ) 188 | 189 | 190 | def test_subclass_pytree(): 191 | x = PyTree 192 | y = PyTree[int] 193 | assert issubclass(x, PyTree) 194 | assert issubclass(y, PyTree) 195 | assert not issubclass(int, PyTree) 196 | 197 | 198 | def test_structure_match(jaxtyp, typecheck): 199 | @jaxtyp(typecheck) 200 | def f(x: PyTree[int, " T"], y: PyTree[str, " T"]): 201 | pass 202 | 203 | f(1, "hi") 204 | f((3, 4, {"a": 5}), ("a", "b", {"a": "c"})) 205 | 206 | with pytest.raises(ParamError): 207 | f(1, ("hi",)) 208 | 209 | 210 | def test_structure_prefix(jaxtyp, typecheck): 211 | @jaxtyp(typecheck) 212 | def f(x: PyTree[int, " T"], y: PyTree[str, "T ..."]): 213 | pass 214 | 215 | f(1, "hi") 216 | f((3, 4, {"a": 5}), ("a", "b", {"a": "c"})) 217 | f(1, ("hi",)) 218 | f((1, 2), ({"a": "hi"}, {"a": "bye"})) 219 | f((1, 2), ({"a": "hi"}, {"not-a": "bye"})) 220 | 221 | with pytest.raises(ParamError): 222 | f((1, 2), ({"a": "hi"}, {"a": "bye"}, {"a": "oh-no"})) 223 | 224 | with pytest.raises(ParamError): 225 | f((3, 4, 5), {"a": ("hi", "bye")}) 226 | 227 | 228 | def test_structure_suffix(jaxtyp, typecheck): 229 | @jaxtyp(typecheck) 230 | def 
f(x: PyTree[int, " T"], y: PyTree[str, "... T"]): 231 | pass 232 | 233 | f(1, "hi") 234 | f((3, 4, {"a": 5}), ("a", "b", {"a": "c"})) 235 | f(1, ("hi",)) 236 | 237 | with pytest.raises(ParamError): 238 | f((3, 4), {"a": (1, 2)}) 239 | 240 | with pytest.raises(ParamError): 241 | f((3, 4, 5), {"a": ("hi", "bye")}) 242 | 243 | 244 | def test_structure_compose(jaxtyp, typecheck): 245 | @jaxtyp(typecheck) 246 | def f(x: PyTree[int, " T"], y: PyTree[int, " S"], z: PyTree[str, "S T"]): 247 | pass 248 | 249 | f(1, 2, "hi") 250 | f((1, 2), 2, ("a", "b")) 251 | 252 | with pytest.raises(ParamError): 253 | f((1, 2), 2, (1, 2)) 254 | 255 | f((1, 2), {"a": 3}, {"a": ("hi", "bye")}) 256 | 257 | with pytest.raises(ParamError): 258 | f((1, 2), {"a": 3}, ({"a": "hi"}, {"a": "bye"})) 259 | 260 | @jaxtyp(typecheck) 261 | def g(x: PyTree[int, " T"], y: PyTree[int, " S"], z: PyTree[str, "T S"]): 262 | pass 263 | 264 | with pytest.raises(ParamError): 265 | g((1, 2), {"a": 3}, {"a": ("hi", "bye")}) 266 | 267 | g((1, 2), {"a": 3}, ({"a": "hi"}, {"a": "bye"})) 268 | 269 | 270 | @pytest.mark.parametrize("variadic", (False, True)) 271 | def test_treepath_dependence_function(variadic, jaxtyp, typecheck, getkey): 272 | if variadic: 273 | jtshape = "*?foo" 274 | shape = (2, 3) 275 | else: 276 | jtshape = "?foo" 277 | shape = (4,) 278 | 279 | @jaxtyp(typecheck) 280 | def f( 281 | x: PyTree[Float[Array, jtshape], " T"], y: PyTree[Float[Array, jtshape], " T"] 282 | ): 283 | pass 284 | 285 | x1 = jr.normal(getkey(), shape) 286 | y1 = jr.normal(getkey(), shape) 287 | x2 = jr.normal(getkey(), (5,)) 288 | y2 = jr.normal(getkey(), (5,)) 289 | f(x1, y1) 290 | f((x1, x2), (y1, y2)) 291 | 292 | with pytest.raises(ParamError): 293 | f(x1, y2) 294 | 295 | with pytest.raises(ParamError): 296 | f((x1, x2), (y2, y1)) 297 | 298 | 299 | @pytest.mark.parametrize("variadic", (False, True)) 300 | def test_treepath_dependence_dataclass(variadic, typecheck, getkey): 301 | if variadic: 302 | jtshape = "*?foo" 303 | 
shape = (2, 3) 304 | else: 305 | jtshape = "?foo" 306 | shape = (4,) 307 | 308 | @jaxtyping.jaxtyped(typechecker=typecheck) 309 | class A(eqx.Module): 310 | x: PyTree[Float[Array, jtshape], " T"] 311 | y: PyTree[Float[Array, jtshape], " T"] 312 | 313 | x1 = jr.normal(getkey(), shape) 314 | y1 = jr.normal(getkey(), shape) 315 | x2 = jr.normal(getkey(), (5,)) 316 | y2 = jr.normal(getkey(), (5,)) 317 | A(x1, y1) 318 | A((x1, x2), (y1, y2)) 319 | 320 | with pytest.raises(ParamError): 321 | A(x1, y2) 322 | 323 | with pytest.raises(ParamError): 324 | A((x1, x2), (y2, y1)) 325 | 326 | 327 | def test_treepath_dependence_missing_structure_annotation(jaxtyp, typecheck, getkey): 328 | @jaxtyp(typecheck) 329 | def f(x: PyTree[Float[Array, "?foo"], " T"], y: PyTree[Float[Array, "?foo"]]): 330 | pass 331 | 332 | x1 = jr.normal(getkey(), (2,)) 333 | y1 = jr.normal(getkey(), (2,)) 334 | with pytest.raises(AnnotationError, match="except when contained with structured"): 335 | f(x1, y1) 336 | 337 | 338 | def test_treepath_dependence_multiple_structure_annotation(jaxtyp, typecheck, getkey): 339 | @jaxtyp(typecheck) 340 | def f(x: PyTree[PyTree[Float[Array, "?foo"], " S"], " T"]): 341 | pass 342 | 343 | x1 = jr.normal(getkey(), (2,)) 344 | with pytest.raises(AnnotationError, match="ambiguous which PyTree"): 345 | f(x1) 346 | 347 | 348 | def test_name(): 349 | assert PyTree.__name__ == "PyTree" 350 | assert PyTree[int].__name__ == "PyTree[int]" 351 | assert PyTree[int, "foo"].__name__ == "PyTree[int, 'foo']" 352 | assert PyTree[PyTree[str], "foo"].__name__ == "PyTree[PyTree[str], 'foo']" 353 | assert ( 354 | PyTree[PyTree[str, "bar"], "foo"].__name__ 355 | == "PyTree[PyTree[str, 'bar'], 'foo']" 356 | ) 357 | assert PyTree[PyTree[str, "bar"]].__name__ == "PyTree[PyTree[str, 'bar']]" 358 | assert ( 359 | PyTree[None | Callable[[PyTree[int, " T"]], str]].__name__ 360 | == "PyTree[None | Callable[[PyTree[int, 'T']], str]]" 361 | ) 362 | 363 | 364 | def test_pdoc(): 365 | assert 
wl.pformat(PyTree) == "PyTree" 366 | assert wl.pformat(PyTree[int]) == "PyTree[int]" 367 | assert wl.pformat(PyTree[int, "foo"]) == "PyTree[int, 'foo']" 368 | assert wl.pformat(PyTree[PyTree[str], "foo"]) == "PyTree[PyTree[str], 'foo']" 369 | assert ( 370 | wl.pformat(PyTree[PyTree[str, "bar"], "foo"]) 371 | == "PyTree[PyTree[str, 'bar'], 'foo']" 372 | ) 373 | assert wl.pformat(PyTree[PyTree[str, "bar"]]) == "PyTree[PyTree[str, 'bar']]" 374 | assert ( 375 | wl.pformat(PyTree[None | Callable[[PyTree[int, " T"]], str]]) 376 | == "PyTree[None | Callable[[PyTree[int, 'T']], str]]" 377 | ) 378 | expected = """ 379 | PyTree[ 380 | None 381 | | Callable[ 382 | [ 383 | PyTree[ 384 | int, 385 | 'T' 386 | ] 387 | ], 388 | str 389 | ] 390 | ] 391 | """.strip() 392 | assert ( 393 | wl.pformat(PyTree[None | Callable[[PyTree[int, " T"]], str]], width=2).strip() 394 | == expected 395 | ) 396 | -------------------------------------------------------------------------------- /test/test_serialisation.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import cloudpickle 4 | import numpy as np 5 | 6 | 7 | try: 8 | import torch 9 | except ImportError: 10 | torch = None 11 | 12 | from jaxtyping import AbstractArray, Array, Float, Shaped 13 | 14 | 15 | def test_pickle(): 16 | for p in (pickle, cloudpickle): 17 | x = p.dumps(Shaped[Array, ""]) 18 | y = p.loads(x) 19 | assert y.dtype is Shaped 20 | assert y.dim_str == "" 21 | 22 | x = p.dumps(AbstractArray) 23 | y = p.loads(x) 24 | assert y is AbstractArray 25 | 26 | x = p.dumps(Shaped[np.ndarray, "3 4 hi"]) 27 | y = p.loads(x) 28 | assert y.dtype is Shaped 29 | assert y.dim_str == "3 4 hi" 30 | 31 | if torch is not None: 32 | x = p.dumps(Float[torch.Tensor, "batch length"]) 33 | y = p.loads(x) 34 | assert y.dtype is Float 35 | assert y.dim_str == "batch length" 36 | -------------------------------------------------------------------------------- /test/test_tf_dtype.py: 
-------------------------------------------------------------------------------- 1 | # Tensorflow dependency kept in a separate file, so that we can optionally exclude it 2 | # more easily. 3 | import tensorflow as tf 4 | 5 | from jaxtyping import UInt 6 | 7 | 8 | def test_tf_dtype(): 9 | x = tf.constant(1, dtype=tf.uint8) 10 | y = tf.constant(1, dtype=tf.float32) 11 | hint = UInt[tf.Tensor, "..."] 12 | assert isinstance(x, hint) 13 | assert not isinstance(y, hint) 14 | -------------------------------------------------------------------------------- /test/test_threading.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Google LLC 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | 20 | import threading 21 | 22 | import jax.numpy as jnp 23 | from typeguard import typechecked 24 | 25 | from jaxtyping import Array, Float, jaxtyped 26 | 27 | 28 | class _ErrorableThread(threading.Thread): 29 | def run(self): 30 | try: 31 | super().run() 32 | except Exception as e: 33 | self.exc = e 34 | 35 | def join(self, timeout=None): 36 | super().join(timeout) 37 | if hasattr(self, "exc"): 38 | raise self.exc 39 | 40 | 41 | def test_threading_jaxtyped(): 42 | @jaxtyped(typechecker=typechecked) 43 | def add(x: Float[Array, "a b"], y: Float[Array, "a b"]) -> Float[Array, "a b"]: 44 | return x + y 45 | 46 | def run(): 47 | a = jnp.array([[1.0, 2.0]]) 48 | b = jnp.array([[2.0, 3.0]]) 49 | add(a, b) 50 | 51 | thread = _ErrorableThread(target=run) 52 | thread.start() 53 | thread.join() 54 | 55 | 56 | def test_threading_nojaxtyped(): 57 | def run(): 58 | a = jnp.array([[1.0, 2.0]]) 59 | assert isinstance(a, Float[Array, "..."]) 60 | 61 | thread = _ErrorableThread(target=run) 62 | thread.start() 63 | thread.join() 64 | -------------------------------------------------------------------------------- /test/types/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrick-kidger/jaxtyping/fe61644be8590cf0a89bacc278283e1e4b9ea3e4/test/types/__init__.py -------------------------------------------------------------------------------- /test/types/decorator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import numpy as np 4 | from beartype import beartype 5 | 6 | from jaxtyping import Float, Int, jaxtyped 7 | 8 | 9 | @jaxtyped(typechecker=beartype) 10 | @dataclass 11 | class User: 12 | name: str 13 | age: int 14 | items: Float[np.ndarray, " N"] 15 | timestamps: Int[np.ndarray, " N"] 16 | 17 | 18 | @jaxtyped(typechecker=beartype) 19 | def transform_user(user: User, increment_age: int = 1) -> User: 20 | user.age += 
increment_age 21 | return user 22 | 23 | 24 | user = User( 25 | name="John", 26 | age=20, 27 | items=np.random.normal(size=10), 28 | timestamps=np.random.randint(0, 100, size=10), 29 | ) 30 | 31 | new_user = transform_user(user, increment_age=2) 32 | --------------------------------------------------------------------------------