├── .github
└── workflows
│ ├── build_docs.yml
│ ├── release.yml
│ └── run_tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
├── .citation.md
├── .htaccess
├── _overrides
│ ├── bluesky.svg
│ └── partials
│ │ └── source.html
├── _static
│ ├── .README.md
│ ├── custom_css.css
│ ├── favicon.png
│ └── mathjax.js
├── api
│ ├── advanced-features.md
│ ├── array.md
│ ├── pytree.md
│ └── runtime-type-checking.md
├── faq.md
└── index.md
├── jaxtyping
├── __init__.py
├── _array_types.py
├── _config.py
├── _decorator.py
├── _errors.py
├── _import_hook.py
├── _indirection.py
├── _ipython_extension.py
├── _pytest_plugin.py
├── _pytree_type.py
├── _storage.py
├── _typeguard
│ ├── LICENSE
│ ├── README.md
│ └── __init__.py
└── py.typed
├── mkdocs.yml
├── pyproject.toml
└── test
├── __init__.py
├── conftest.py
├── helpers.py
├── import_hook_tester.py
├── requirements.txt
├── test_all_importable.py
├── test_array.py
├── test_decorator.py
├── test_generators.py
├── test_import_hook.py
├── test_ipython_extension.py
├── test_messages.py
├── test_no_jax_dependency.py
├── test_pytree.py
├── test_serialisation.py
├── test_tf_dtype.py
├── test_threading.py
└── types
├── __init__.py
└── decorator.py
/.github/workflows/build_docs.yml:
--------------------------------------------------------------------------------
1 | name: Build docs
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | build:
10 | strategy:
11 | matrix:
12 | python-version: [ 3.11 ]
13 | os: [ ubuntu-latest ]
14 | runs-on: ${{ matrix.os }}
15 | steps:
16 | - name: Checkout code
17 | uses: actions/checkout@v2
18 |
19 | - name: Set up Python ${{ matrix.python-version }}
20 | uses: actions/setup-python@v2
21 | with:
22 | python-version: ${{ matrix.python-version }}
23 |
24 | - name: Install dependencies
25 | run: |
26 | python -m pip install --upgrade pip
27 | python -m pip install '.[docs]'
28 | python -m pip install jax[cpu]
29 |
30 | - name: Build docs
31 | run: |
32 | mkdocs build
33 |
34 | - name: Upload docs
35 | uses: actions/upload-artifact@v4
36 | with:
37 | name: docs
38 | path: site # where `mkdocs build` puts the built site
39 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | name: Release
21 |
22 | on:
23 | push:
24 | branches:
25 | - main
26 |
27 | jobs:
28 | build:
29 | runs-on: ubuntu-latest
30 | steps:
31 | - name: Release
32 | uses: patrick-kidger/action_update_python_project@v6
33 | with:
34 | python-version: "3.11"
35 | test-script: |
36 | python -m pip install -r ${{ github.workspace }}/test/requirements.txt
37 | python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
38 | cp -r ${{ github.workspace }}/test ./test
39 | pytest
40 | pypi-token: ${{ secrets.pypi_token }}
41 | github-user: patrick-kidger
42 | github-token: ${{ github.token }}
43 |
--------------------------------------------------------------------------------
/.github/workflows/run_tests.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | name: Run tests
21 |
22 | on:
23 | pull_request:
24 |
25 | jobs:
26 | run-tests:
27 | strategy:
28 | matrix:
29 | python-version: [ "3.10", "3.12" ]
30 | os: [ ubuntu-latest ]
31 | fail-fast: false
32 | runs-on: ${{ matrix.os }}
33 | steps:
34 | - name: Checkout code
35 | uses: actions/checkout@v2
36 |
37 | - name: Set up Python ${{ matrix.python-version }}
38 | uses: actions/setup-python@v2
39 | with:
40 | python-version: ${{ matrix.python-version }}
41 |
42 | - name: Install dependencies
43 | run: |
44 | python -m pip install --upgrade pip
45 | python -m pip install -r test/requirements.txt
46 | python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
47 |
48 | - name: Checks with pre-commit
49 | uses: pre-commit/action@v3.0.1
50 |
51 | - name: Test with pytest
52 | run: |
53 | python -m pip install .
54 | python -m pytest --durations=0
55 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | *.egg-info
3 | build/
4 | dist/
5 | site/
6 | .all_objects.cache
7 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | repos:
21 | - repo: https://github.com/astral-sh/ruff-pre-commit
22 | rev: v0.7.3
23 | hooks:
24 | - id: ruff-format # formatter
25 | types_or: [ python, pyi, jupyter ]
26 | - id: ruff # linter
27 | types_or: [ python, pyi, jupyter ]
28 | args: [ --fix ]
29 | - repo: https://github.com/RobertCraigie/pyright-python
30 | rev: v1.1.391
31 | hooks:
32 | - id: pyright
33 | files: ^test/types/
34 | additional_dependencies:
35 | [beartype, numpy<2]
36 | - repo: https://github.com/pre-commit/mirrors-mypy
37 | rev: v1.14.1
38 | hooks:
39 | - id: mypy
40 | files: ^test/types/
41 | additional_dependencies:
42 | [beartype, numpy<2]
43 | args: ["--ignore-missing-imports", "--follow-imports=skip"]
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | Contributions (pull requests) are very welcome! Here's how to get started.
4 |
5 | ---
6 |
7 | First fork the library on GitHub.
8 |
9 | Then clone and install the library in development mode:
10 |
11 | ```bash
12 | git clone https://github.com/your-username-here/jaxtyping.git
13 | cd jaxtyping
14 | pip install -e .
15 | ```
16 |
17 | Then install the pre-commit hook:
18 |
19 | ```bash
20 | pip install pre-commit
21 | pre-commit install
22 | ```
23 |
24 | These hooks use ruff to lint and format the code, and pyright and mypy to type-check the files under `test/types/`.
25 |
26 | Now make your changes. Make sure to include additional tests if necessary.
27 |
28 | Next verify the tests all pass:
29 |
30 | ```bash
31 | pip install -r test/requirements.txt
32 | pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
33 | pytest
34 | ```
35 |
36 | Then push your changes back to your fork of the repository:
37 |
38 | ```bash
39 | git push
40 | ```
41 |
42 | Finally, open a pull request on GitHub!
43 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Google LLC
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
24 |
25 |
26 | ---
27 | Sections of the code were modified from https://github.com/agronholm/typeguard
28 | under the terms of the MIT license, reproduced below.
29 | ---
30 |
31 | MIT License
32 |
33 | Copyright (c) Alex Grönholm
34 |
35 | Permission is hereby granted, free of charge, to any person obtaining a copy
36 | of this software and associated documentation files (the "Software"), to deal
37 | in the Software without restriction, including without limitation the rights
38 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
39 | copies of the Software, and to permit persons to whom the Software is
40 | furnished to do so, subject to the following conditions:
41 |
42 | The above copyright notice and this permission notice shall be included in all
43 | copies or substantial portions of the Software.
44 |
45 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
46 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
47 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
48 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
49 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
50 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
51 | SOFTWARE.
52 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
jaxtyping
2 |
3 | Type annotations **and runtime type-checking** for:
4 |
5 | 1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)*
6 | 2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).
7 |
8 | **For example:**
9 | ```python
10 | from jaxtyping import Array, Float, PyTree
11 |
12 | # Accepts floating-point 2D arrays with matching axes
13 | # You can replace `Array` with `torch.Tensor` etc.
14 | def matrix_multiply(x: Float[Array, "dim1 dim2"],
15 | y: Float[Array, "dim2 dim3"]
16 | ) -> Float[Array, "dim1 dim3"]:
17 | ...
18 |
19 | def accepts_pytree_of_ints(x: PyTree[int]):
20 | ...
21 |
22 | def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
23 | ...
24 | ```
25 |
26 | ## Installation
27 |
28 | ```bash
29 | pip install jaxtyping
30 | ```
31 |
32 | Requires Python 3.10+.
33 |
34 | JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc.
35 |
36 | The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are [typeguard](https://github.com/agronholm/typeguard) (which exhaustively checks every argument) and [beartype](https://github.com/beartype/beartype) (which checks random pieces of arguments).
37 |
38 | ## Documentation
39 |
40 | Available at [https://docs.kidger.site/jaxtyping](https://docs.kidger.site/jaxtyping).
41 |
42 | ## See also: other libraries in the JAX ecosystem
43 |
44 | **Always useful**
45 | [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
46 |
47 | **Deep learning**
48 | [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
49 | [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
50 | [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
51 | [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
52 |
53 | **Scientific computing**
54 | [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
55 | [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
56 | [Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
57 | [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
58 | [sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
59 | [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
60 |
61 | **Awesome JAX**
62 | [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.
63 |
--------------------------------------------------------------------------------
/docs/.citation.md:
--------------------------------------------------------------------------------
1 | If you found this library to be useful in academic work, then please cite: ([arXiv link](https://arxiv.org/abs/2111.00254))
2 |
3 | ```bibtex
4 | @article{kidger2021equinox,
5 | author={Patrick Kidger and Cristian Garcia},
6 | title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
7 | year={2021},
8 | journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
9 | }
10 | ```
11 |
12 | (Also consider starring the project [on GitHub](https://github.com/patrick-kidger/equinox).)
13 |
--------------------------------------------------------------------------------
/docs/.htaccess:
--------------------------------------------------------------------------------
1 | ErrorDocument 404 /jaxtyping/404.html
2 |
--------------------------------------------------------------------------------
/docs/_overrides/bluesky.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/_overrides/partials/source.html:
--------------------------------------------------------------------------------
1 | {% import "partials/language.html" as lang with context %}
2 |
3 |
4 | {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %}
5 | {% include ".icons/" ~ icon ~ ".svg" %}
6 |
7 |
8 | {{ config.repo_name }}
9 |
10 |
11 |
12 |
13 | {% include ".icons/fontawesome/brands/twitter.svg" %}
14 |
15 |
16 |
17 |
18 | {% include "bluesky.svg" %}
19 |
20 |
21 | {{ config.theme.twitter_bluesky_name }}
22 |
23 |
24 |
--------------------------------------------------------------------------------
/docs/_static/.README.md:
--------------------------------------------------------------------------------
1 | The favicon is `math-integral` from https://materialdesignicons.com, found by way of https://pictogrammers.com.
2 | (The logo is `math-integral-box`.)
3 |
--------------------------------------------------------------------------------
/docs/_static/custom_css.css:
--------------------------------------------------------------------------------
1 | /* Fix /page#foo going to the top of the viewport and being hidden by the navbar */
2 | html {
3 | scroll-padding-top: 50px;
4 | }
5 |
6 | /* Fit the Twitter handle alongside the GitHub one in the top right. */
7 |
8 | div.md-header__source {
9 | width: revert;
10 | max-width: revert;
11 | }
12 |
13 | a.md-source {
14 | display: inline-block;
15 | }
16 |
17 | .md-source__repository {
18 | max-width: 100%;
19 | }
20 |
21 | /* Emphasise sections of nav on left hand side */
22 |
23 | nav.md-nav {
24 | padding-left: 5px;
25 | }
26 |
27 | nav.md-nav--secondary {
28 | border-left: revert !important;
29 | }
30 |
31 | .md-nav__title {
32 | font-size: 0.9rem;
33 | }
34 |
35 | .md-nav__item--section > .md-nav__link {
36 | font-size: 0.9rem;
37 | }
38 |
39 | /* Indent autogenerated documentation */
40 |
41 | div.doc-contents {
42 | padding-left: 25px;
43 | border-left: 4px solid rgba(230, 230, 230);
44 | }
45 |
46 | /* Increase visibility of splitters "---" */
47 |
48 | [data-md-color-scheme="default"] .md-typeset hr {
49 | border-bottom-color: rgb(0, 0, 0);
50 | border-bottom-width: 1pt;
51 | }
52 |
53 | [data-md-color-scheme="slate"] .md-typeset hr {
54 | border-bottom-color: rgb(230, 230, 230);
55 | }
56 |
57 | /* More space at the bottom of the page */
58 |
59 | .md-main__inner {
60 | margin-bottom: 1.5rem;
61 | }
62 |
63 | /* Remove prev/next footer buttons */
64 |
65 | .md-footer__inner {
66 | display: none;
67 | }
68 |
69 | /* Change font sizes */
70 |
71 | html {
72 | /* Decrease font size for overall webpage
73 | Down from 137.5% which is the Material default */
74 | font-size: 110%;
75 | }
76 |
77 | .md-typeset .admonition {
78 | /* Increase font size in admonitions */
79 | font-size: 100% !important;
80 | }
81 |
82 | .md-typeset details {
83 | /* Increase font size in details */
84 | font-size: 100% !important;
85 | }
86 |
87 | .md-typeset h1 {
88 | font-size: 1.6rem;
89 | }
90 |
91 | .md-typeset h2 {
92 | font-size: 1.5rem;
93 | }
94 |
95 | .md-typeset h3 {
96 | font-size: 1.3rem;
97 | }
98 |
99 | .md-typeset h4 {
100 | font-size: 1.1rem;
101 | }
102 |
103 | .md-typeset h5 {
104 | font-size: 0.9rem;
105 | }
106 |
107 | .md-typeset h6 {
108 | font-size: 0.8rem;
109 | }
110 |
111 | /* Bugfix: remove the superfluous parts generated when doing:
112 |
113 | ??? Blah
114 |
115 | ::: library.something
116 | */
117 |
118 | .md-typeset details .mkdocstrings > h4 {
119 | display: none;
120 | }
121 |
122 | .md-typeset details .mkdocstrings > h5 {
123 | display: none;
124 | }
125 |
126 | /* Change default colours for tags */
127 |
128 | [data-md-color-scheme="default"] {
129 | --md-typeset-a-color: rgb(0, 189, 164) !important;
130 | }
131 | [data-md-color-scheme="slate"] {
132 | --md-typeset-a-color: rgb(0, 189, 164) !important;
133 | }
134 |
135 | /* Highlight functions, classes etc. type signatures. Really helps to make clear where
136 | one item ends and another begins. */
137 |
138 | [data-md-color-scheme="default"] {
139 | --doc-heading-color: #DDD;
140 | --doc-heading-border-color: #CCC;
141 | --doc-heading-color-alt: #F0F0F0;
142 | }
143 | [data-md-color-scheme="slate"] {
144 | --doc-heading-color: rgb(25,25,33);
145 | --doc-heading-border-color: rgb(25,25,33);
146 | --doc-heading-color-alt: rgb(33,33,44);
147 | --md-code-bg-color: rgb(38,38,50);
148 | }
149 |
150 | h4.doc-heading {
151 | /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/
152 | background-color: var(--doc-heading-color);
153 | border: solid var(--doc-heading-border-color);
154 | border-width: 1.5pt;
155 | border-radius: 2pt;
156 | padding: 0pt 5pt 2pt 5pt;
157 | }
158 | h5.doc-heading, h6.heading {
159 | background-color: var(--doc-heading-color-alt);
160 | border-radius: 2pt;
161 | padding: 0pt 5pt 2pt 5pt;
162 | }
163 |
164 | /* Make errors in notebooks have scrolling */
165 | .output_error > pre {
166 | overflow: auto;
167 | }
168 |
--------------------------------------------------------------------------------
/docs/_static/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrick-kidger/jaxtyping/fe61644be8590cf0a89bacc278283e1e4b9ea3e4/docs/_static/favicon.png
--------------------------------------------------------------------------------
/docs/_static/mathjax.js:
--------------------------------------------------------------------------------
1 | window.MathJax = {
2 | tex: {
3 | inlineMath: [["\\(", "\\)"]],
4 | displayMath: [["\\[", "\\]"]],
5 | processEscapes: true,
6 | processEnvironments: true
7 | },
8 | options: {
9 | ignoreHtmlClass: ".*|",
10 | processHtmlClass: "arithmatex"
11 | }
12 | };
13 |
14 | document$.subscribe(() => {
15 | MathJax.typesetPromise()
16 | })
17 |
--------------------------------------------------------------------------------
/docs/api/advanced-features.md:
--------------------------------------------------------------------------------
1 | # Advanced features
2 |
3 | ## Creating your own dtypes
4 |
5 | ::: jaxtyping.AbstractDtype
6 | options:
7 | members: []
8 |
9 | ::: jaxtyping.make_numpy_struct_dtype
10 |
11 | ## Printing axis bindings
12 |
13 | ::: jaxtyping.print_bindings
14 |
15 | ## Introspection
16 |
17 | If you're writing your own type hint parser, then you may wish to detect if some Python object is a jaxtyping-provided type.
18 |
19 | You can check for dtypes by doing `issubclass(x, AbstractDtype)`. For example, `issubclass(Float32, AbstractDtype)` will pass.
20 |
21 | You can check for arrays by doing `issubclass(x, AbstractArray)`. Here, `AbstractArray` is the base class for all shape-and-dtype specified arrays, e.g. it's a base class for `Float32[Array, "foo"]`.
22 |
23 | You can check for pytrees by doing `issubclass(x, PyTree)`. For example, `issubclass(PyTree[int], PyTree)` will pass.
24 |
--------------------------------------------------------------------------------
/docs/api/array.md:
--------------------------------------------------------------------------------
1 | # Array annotations
2 |
3 | The shape and dtypes of arrays can be annotated in the form `dtype[array, shape]`, such as `Float[Array, "batch channels"]`.
4 |
5 | ## Shape
6 |
7 | **Symbols**
8 |
9 | The shape should be a string of space-separated symbols, such as `"a b c d"`. Each symbol can be one of:
10 |
11 | - `int`: fixed-size axis, e.g. `"28 28"`.
12 | - `str`: variable-size axis, e.g. `"channels"`.
13 | - A symbolic expression in terms of other variable-size axes, e.g.
14 | `def remove_last(x: Float[Array, "dim"]) -> Float[Array, "dim-1"]`.
15 | Symbolic expressions must not use any spaces, otherwise each piece is treated as a separate axis.
16 |
17 | When calling a function, variable-size axes and symbolic axes will be matched up across all arguments and checked for consistency. (See [Runtime type checking](./runtime-type-checking.md).)
18 |
19 | **Modifiers**
20 |
21 | In addition some modifiers can be applied:
22 |
23 | - Prepend `*` to an axis to indicate that it can match multiple axes, e.g. `"*batch"` will match zero or more batch axes.
24 | - Prepend `#` to an axis to indicate that it can be that size *or* equal to one -- i.e. broadcasting is acceptable, e.g.
25 | `def add(x: Float[Array, "#foo"], y: Float[Array, "#foo"]) -> Float[Array, "#foo"]`.
26 | - Prepend `_` to an axis to disable any runtime checking of that axis (so that it can be used just as documentation). This can also be used as just `_` on its own: e.g. `"b c _ _"`.
27 | - Documentation-only names (i.e. they're ignored by jaxtyping) can be handled by prepending a name followed by `=` e.g. `Float[Array, "rows=4 cols=3"]`.
28 | - Prepend `?` to an axis to indicate that its size can vary within a PyTree structure. (See [PyTree annotations](./pytree.md).)
29 |
30 | When using multiple modifiers, their order does not matter.
31 |
32 | As a special case:
33 |
34 | - `...`: anonymous zero or more axes (equivalent to `*_`) e.g. `"... c h w"`
35 |
36 | **Notes**
37 |
38 | - To denote a scalar shape use `""`, e.g. `Float[Array, ""]`.
39 | - To denote an arbitrary shape (and only check dtype) use `"..."`, e.g. `Float[Array, "..."]`.
40 | - You cannot have more than one use of multiple-axes, i.e. you can only use `...` or `*name` at most once in each array.
41 | - A symbolic expression cannot be evaluated unless all of the axes sizes it refers to have already been processed. In practice this usually means that they should only be used in annotations for the return type, and only use axes declared in the arguments.
42 | - Symbolic expressions are evaluated in two stages: they are first evaluated as f-strings using the arguments of the function, and second are evaluated using the processed axis sizes. The f-string evaluation means that they can use local variables by enclosing them with curly braces, e.g. `{variable}`. For example:
43 | ```python
44 | def full(size: int, fill: float) -> Float[Array, "{size}"]:
45 | return jax.numpy.full((size,), fill)
46 |
47 | class SomeClass:
48 | some_value = 5
49 |
50 | def full(self, fill: float) -> Float[Array, "{self.some_value}+3"]:
51 | return jax.numpy.full((self.some_value + 3,), fill)
52 | ```
53 |
54 | ## Dtype
55 |
56 | The dtype should be any one of (all imported from `jaxtyping`):
57 |
58 | - Any dtype at all: `Shaped`
59 | - Boolean: `Bool`
60 | - PRNG key: `Key`
61 | - Any integer, unsigned integer, floating, or complex: `Num`
62 | - Any floating or complex: `Inexact`
63 | - Any floating point: `Float`
64 | - Of particular precision: `BFloat16`, `Float16`, `Float32`, `Float64`
65 | - Any complex: `Complex`
66 | - Of particular precision: `Complex64`, `Complex128`
67 | - Any integer or unsigned integer: `Integer`
68 | - Any unsigned integer: `UInt`
69 | - Of particular precision: `UInt2`, `UInt4`, `UInt8`, `UInt16`, `UInt32`, `UInt64`
70 | - Any signed integer: `Int`
71 | - Of particular precision: `Int2`, `Int4`, `Int8`, `Int16`, `Int32`, `Int64`
72 | - Any floating, integer, or unsigned integer: `Real`.
73 |
74 | Unless you really want to force a particular precision, for most applications you should probably allow any floating-point, any integer, etc. That is, use
75 | ```python
76 | from jaxtyping import Array, Float
77 | Float[Array, "some_shape"]
78 | ```
79 | rather than
80 | ```python
81 | from jaxtyping import Array, Float32
82 | Float32[Array, "some_shape"]
83 | ```
84 |
85 | ## Array
86 |
87 | The array should typically be one of:
88 | ```python
89 | jaxtyping.Array / jax.Array / jax.numpy.ndarray # these are all aliases of one another
90 | np.ndarray
91 | torch.Tensor
92 | tf.Tensor
93 | mx.array
94 | ```
95 | That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow + MLX.
96 |
97 | Some other types are also supported here:
98 |
99 | **Unions:** these are unpacked. For example, `SomeDtype[Union[A, B], "some shape"]` is equivalent to `Union[SomeDtype[A, "some shape"], SomeDtype[B, "some shape"]]`. A common example of a union type here is `np.typing.ArrayLike`.
100 |
101 | **Any:** use `typing.Any` to check just the shape/dtype, but not the array type.
102 |
103 | **Duck-type arrays:** anything with `.shape` and `.dtype` attributes. For example,
104 | ```python
105 | class MyDuckArray:
106 | @property
107 | def shape(self) -> tuple[int, ...]:
108 | return (3, 4, 5)
109 |
110 | @property
111 | def dtype(self) -> str:
112 | return "my_dtype"
113 |
114 | class MyDtype(jaxtyping.AbstractDtype):
115 | dtypes = ["my_dtype"]
116 |
117 | x = MyDuckArray()
118 | assert isinstance(x, MyDtype[MyDuckArray, "3 4 5"])
119 | # checks that `type(x) == MyDuckArray`
120 | # and that `x.shape == (3, 4, 5)`
121 | # and that `x.dtype == "my_dtype"`
122 | ```
123 |
124 | **TypeVars:** in this case the runtime array is checked for matching the bounds or constraints of the `typing.TypeVar`.
125 |
126 | **Existing jaxtyped annotations:**
127 | ```python
128 | Image = Float[Array, "channels height width"]
129 | BatchImage = Float[Image, "batch"]
130 | ```
131 | in which case the additional shape is prepended, and the acceptable dtypes are the intersection of the two dtype specifiers used. (So that e.g. `BatchImage = Shaped[Image, "batch"]` would work just as well. But `Bool[Image, "batch"]` would throw an error, as there are no dtypes that are both bools and floats.) Thus the above is equivalent to
132 | ```python
133 | BatchImage = Float[Array, "batch channels height width"]
134 | ```
135 |
136 | Note that `jaxtyping.{Array, ArrayLike}` are only available if JAX has been installed.
137 |
138 | ## Scalars, PRNG keys
139 |
140 | For convenience, jaxtyping also includes `jaxtyping.Scalar`, `jaxtyping.ScalarLike`, and `jaxtyping.PRNGKeyArray`, defined as:
141 | ```python
142 | Scalar = Shaped[Array, ""]
143 | ScalarLike = Shaped[ArrayLike, ""]
144 |
145 | # Left: new-style typed keys; right: old-style keys. See JEP 9263.
146 | PRNGKeyArray = Union[Key[Array, ""], UInt32[Array, "2"]]
147 | ```
148 |
149 | Recalling that shape-and-dtype specified jaxtyping arrays can be nested, this means that e.g. you can annotate the output of `jax.random.split` with `Shaped[PRNGKeyArray, "2"]`, or e.g. an integer scalar with `Int[Scalar, ""]`.
150 |
151 | Note that `jaxtyping.{Scalar, ScalarLike, PRNGKeyArray}` are only available if JAX has been installed.
152 |
--------------------------------------------------------------------------------
/docs/api/pytree.md:
--------------------------------------------------------------------------------
1 | # PyTree annotations
2 |
3 | :::jaxtyping.PyTree
4 | options:
5 | members: []
6 |
7 | ---
8 |
9 | :::jaxtyping.PyTreeDef
10 | options:
11 | members: []
12 |
13 | ---
14 |
15 | ## Path-dependent shapes
16 |
17 | The prefix `?` may be used to indicate that the axis size can depend on which leaf of a PyTree the array is at. For example:
18 | ```python
19 | def f(
20 | x: PyTree[Shaped[Array, "?foo"], "T"],
21 | y: PyTree[Shaped[Array, "?foo"], "T"],
22 | ):
23 | pass
24 | ```
25 | The above demands that `x` and `y` have matching PyTree structures (due to the `T` annotation), and that their leaves must all be one-dimensional arrays, *and that the corresponding pairs of leaves in `x` and `y` must have the same size as each other*.
26 |
27 | Thus the following is allowed:
28 | ```python
29 | x0 = jnp.arange(3)
30 | x1 = jnp.arange(5)
31 |
32 | y0 = jnp.arange(3) + 1
33 | y1 = jnp.arange(5) + 1
34 |
35 | f((x0, x1), (y0, y1)) # x0 matches y0, and x1 matches y1. All good!
36 | ```
37 |
38 | But this is not:
39 | ```python
40 | f((x1, x1), (y0, y1)) # x1 does not have a size matching y0!
41 | ```
42 |
43 | Internally, all that is happening is that `foo` is replaced with `0foo` for the first leaf, `1foo` for the next leaf, etc., so that each leaf gets a unique version of the name.
44 |
45 | ---
46 |
47 | Note that `jaxtyping.{PyTree, PyTreeDef}` are only available if JAX has been installed.
48 |
--------------------------------------------------------------------------------
/docs/api/runtime-type-checking.md:
--------------------------------------------------------------------------------
1 | # Runtime type checking
2 |
3 | (See the [FAQ](../faq.md) for details on static type checking.)
4 |
5 | Runtime type checking **synergises beautifully with `jax.jit`!** All shape checks will be performed only whilst tracing, and will not impact runtime performance.
6 |
7 | There are two approaches: either use [`jaxtyping.jaxtyped`][] to typecheck a single function, or [`jaxtyping.install_import_hook`][] to typecheck a whole codebase.
8 |
9 | In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are [beartype](https://github.com/beartype/beartype) and [typeguard](https://github.com/agronholm/typeguard). (If using typeguard, then specifically the version `2.*` series should be used. Later versions -- `3` and `4` -- have some known issues.)
10 |
11 | !!! warning
12 |
13 | Avoid using `from __future__ import annotations`, or stringified type annotations, where possible. These are largely incompatible with runtime type checking. See also [this FAQ entry](../faq.md#dataclass-annotations-arent-being-checked-properly).
14 |
15 | ---
16 |
17 | ::: jaxtyping.jaxtyped
18 |
19 | ---
20 |
21 | ::: jaxtyping.install_import_hook
22 |
23 | ---
24 |
25 | #### Pytest hook
26 |
27 | The import hook can be installed at test-time only, as a pytest hook. From the command line the syntax is:
28 | ```
29 | pytest --jaxtyping-packages=foo,bar.baz,beartype.beartype
30 | ```
31 | or in `pyproject.toml`:
32 | ```toml
33 | [tool.pytest.ini_options]
34 | addopts = "--jaxtyping-packages=foo,bar.baz,beartype.beartype"
35 | ```
36 | or in `pytest.ini`:
37 | ```ini
38 | [pytest]
39 | addopts = --jaxtyping-packages=foo,bar.baz,beartype.beartype
40 | ```
41 | This example will apply the import hook to all modules whose names start with either `foo` or `bar.baz`. The typechecker used in this example is `beartype.beartype`.
42 |
43 | #### IPython extension
44 |
45 | If you are running in an IPython environment (for example a Jupyter or Colab notebook), then the jaxtyping hook can be run automatically via a custom magic:
46 | ```python
47 | import jaxtyping
48 | %load_ext jaxtyping
49 | %jaxtyping.typechecker beartype.beartype # or any other runtime type checker
50 | ```
51 | Place this at the start of your notebook -- everything that is directly defined in the notebook, after this magic is run, will be hook'd.
52 |
53 | #### Other runtime type-checking libraries
54 |
55 | Beartype and typeguard happen to be the two most popular runtime type-checking libraries (at least at time of writing), but jaxtyping should be compatible with all runtime type checkers out-of-the-box. The runtime type-checking library just needs to provide a type-checking decorator (analogous to `beartype.beartype` or `typeguard.typechecked`), and perform `isinstance` checks against jaxtyping's types.
56 |
--------------------------------------------------------------------------------
/docs/faq.md:
--------------------------------------------------------------------------------
1 | # FAQ
2 |
3 | ## Is jaxtyping compatible with static type checkers like `mypy`/`pyright`/`pytype`?
4 |
5 | There is partial support for these. An annotation of the form `dtype[array, shape]` should be treated as just `array` by a static type checker. Unfortunately full dtype/shape checking is beyond the scope of what static type checking is currently capable of.
6 |
7 | (Note that at time of writing, `pytype` has a bug in that `dtype[array, shape]` is sometimes treated as `Any` rather than `array`. `mypy` and `pyright` both work fine.)
8 |
9 | ## How does jaxtyping interact with `jax.jit`?
10 |
11 | jaxtyping and `jax.jit` synergise beautifully.
12 |
13 | When JAX operations are wrapped in a `jax.jit`, the dtype/shape-checking happens at trace time (that is, when JAX traces your function prior to compiling it). The actual compiled code does not have any dtype/shape-checking, and will therefore still be just as fast as before!
14 |
15 | ## `flake8` or Ruff are throwing an error.
16 |
17 | In type annotations, strings are used for two different things. Sometimes they're strings. Sometimes they're "forward references", used to refer to a type that will be defined later.
18 |
19 | Some tooling in the Python ecosystem assumes that only the latter is true, and will throw spurious errors if you try to use a string just as a string (like we do).
20 |
21 | In the case of `flake8`, or Ruff, this can be resolved. Multi-dimensional arrays (e.g. `Float32[Array, "b c"]`) will throw a very unusual error (F722, syntax error in forward annotation), so you can safely just disable this particular error globally. Uni-dimensional arrays (e.g. `Float32[Array, "x"]`) will throw an error that's actually useful (F821, undefined name), so instead of disabling this globally, you should instead prepend a space to the start of your shape, e.g. `Float32[Array, " x"]`. `jaxtyping` will treat this in the same way, whilst `flake8` will now throw an F722 error that you can disable as before.
22 |
23 | ## Dataclass annotations aren't being checked properly.
24 |
25 | Stringified dataclass annotations, e.g.
26 | ```python
27 | @dataclass()
28 | class Foo:
29 | x: "int"
30 | ```
31 | will be silently skipped without checking them. This is because these are essentially impossible to resolve at runtime. Such stringified annotations typically occur either when using them for forward references, or when using `from __future__ import annotations`. (You should essentially never use the latter, it is largely incompatible with runtime type checking and as such is [being replaced in Python 3.14](https://peps.python.org/pep-0649/).)
32 |
33 | Partially stringified dataclass annotations, e.g.
34 | ```python
35 | @dataclass()
36 | class Foo:
37 | x: tuple["int"]
38 | ```
39 | will likely raise an error, and must not be used at all.
40 |
41 | ## Does jaxtyping use [PEP 646](https://www.python.org/dev/peps/pep-0646/) (variadic generics)?
42 |
43 | The intention of PEP 646 was to make it possible for static type checkers to perform shape checks of arrays. Unfortunately, this still isn't yet practical, so jaxtyping deliberately does not use this. (Yet?)
44 |
45 | The real problem is that Python's static typing ecosystem is a complicated collection of edge cases. Many of them block ML/scientific computing in particular. For example:
46 |
47 | 1. The static type system is intrinsically not expressive enough to describe operations like concatenation, stacking, or broadcasting.
48 |
49 | 2. Axes have to be lifted to type-level variables. Meanwhile the approach taken in libraries like `jaxtyping` and [TorchTyping](https://github.com/patrick-kidger/torchtyping) is to use value-level variables for types: because that's what the underlying JAX, PyTorch etc. libraries use! As such, making a static type checker work with these libraries would require either fundamentally rewriting these libraries, or exhaustively maintaining type stubs for them, and would *still* require a `typing.cast` any time you use anything unstubbed (e.g. any third party library, or part of your codebase you haven't typed yet). This is a huge maintenance burden.
50 |
51 | 3. Static type checkers have a variety of bugs that affect this use case. `mypy` doesn't support `Protocol`s correctly. `pyright` doesn't support genericised subprotocols. etc.
52 |
53 | 4. Variadic generics exist. Variadic protocols do not. (It's not clear that these were contemplated.)
54 |
55 | 5. The syntax for static typing is a little verbose. You have to write things like `Array[Float32, Unpack[AnyShape], Literal[3], Height, Width]` instead of `Float32[Array, "... 3 height width"]`.
56 |
57 | 6. [The underlying type system has flaws](https://github.com/patrick-kidger/torchtyping/issues/37#issuecomment-1153294196).
58 | [The numeric tower is broken](https://stackoverflow.com/a/69383462);
59 | [int is not a number](https://github.com/python/mypy/issues/3186#issuecomment-885718629);
60 | [virtual base classes don't work](https://github.com/python/mypy/issues/2922);
61 | [complex lies about having comparison operations, so type checkers have to lie about that lie in order to remove them again](https://beartype.github.io/numerary/0.4/whytho/);
62 | `typing.*` don't work with `isinstance`;
63 | co/contra-variance are baked into containers (not specified at use-time);
64 | `dict` is variadic despite... not being variadic;
65 | bool is a subclass of int (!);
66 | ... etc. etc.
67 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Getting started
2 |
3 | jaxtyping is a library providing type annotations **and runtime type-checking** for:
4 |
5 | 1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)*
6 | 2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).
7 |
8 | ## Installation
9 |
10 | ```bash
11 | pip install jaxtyping
12 | ```
13 |
14 | Requires Python 3.10+.
15 |
16 | JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc.
17 |
18 | The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are [typeguard](https://github.com/agronholm/typeguard) (which exhaustively checks every argument) and [beartype](https://github.com/beartype/beartype) (which checks random pieces of arguments).
19 |
20 | ## Example
21 |
22 | ```python
23 | from jaxtyping import Array, Float, PyTree
24 |
25 | # Accepts floating-point 2D arrays with matching axes
26 | # You can replace `Array` with `torch.Tensor` etc.
27 | def matrix_multiply(x: Float[Array, "dim1 dim2"],
28 | y: Float[Array, "dim2 dim3"]
29 | ) -> Float[Array, "dim1 dim3"]:
30 | ...
31 |
32 | def accepts_pytree_of_ints(x: PyTree[int]):
33 | ...
34 |
35 | def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
36 | ...
37 | ```
38 |
39 | ## Next steps
40 |
41 | Have a read of the [Array annotations](./api/array.md) documentation on the left-hand bar!
42 |
43 | ## See also: other libraries in the JAX ecosystem
44 |
45 | **Always useful**
46 | [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
47 |
48 | **Deep learning**
49 | [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
50 | [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
51 | [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
52 | [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
53 |
54 | **Scientific computing**
55 | [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
56 | [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
57 | [Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
58 | [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
59 | [sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
60 | [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
61 |
62 | **Awesome JAX**
63 | [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.
64 |
--------------------------------------------------------------------------------
/jaxtyping/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | import functools as ft
21 | import importlib.metadata
22 | import typing
23 | from typing import TypeAlias, Union
24 |
25 | from ._array_types import (
26 | AbstractArray as AbstractArray,
27 | AbstractDtype as AbstractDtype,
28 | get_array_name_format as get_array_name_format,
29 | make_numpy_struct_dtype as make_numpy_struct_dtype,
30 | set_array_name_format as set_array_name_format,
31 | )
32 | from ._config import config as config
33 | from ._decorator import jaxtyped as jaxtyped
34 | from ._errors import (
35 | AnnotationError as AnnotationError,
36 | TypeCheckError as TypeCheckError,
37 | )
38 | from ._import_hook import install_import_hook as install_import_hook
39 | from ._ipython_extension import load_ipython_extension as load_ipython_extension
40 | from ._storage import print_bindings as print_bindings
41 |
42 |
43 | if typing.TYPE_CHECKING:
44 | from jax import Array as Array
45 | from jax.tree_util import PyTreeDef as PyTreeDef
46 | from jax.typing import ArrayLike as ArrayLike, DTypeLike as DTypeLike
47 |
48 | # Introduce an indirection so that we can `import X as X` to make it clear that
49 | # these are public.
50 | from ._indirection import (
51 | BFloat16 as BFloat16,
52 | Bool as Bool,
53 | Complex as Complex,
54 | Complex64 as Complex64,
55 | Complex128 as Complex128,
56 | Float as Float,
57 | Float8e4m3b11fnuz as Float8e4m3b11fnuz,
58 | Float8e4m3fn as Float8e4m3fn,
59 | Float8e4m3fnuz as Float8e4m3fnuz,
60 | Float8e5m2 as Float8e5m2,
61 | Float8e5m2fnuz as Float8e5m2fnuz,
62 | Float16 as Float16,
63 | Float32 as Float32,
64 | Float64 as Float64,
65 | Inexact as Inexact,
66 | Int as Int,
67 | Int2 as Int2,
68 | Int4 as Int4,
69 | Int8 as Int8,
70 | Int16 as Int16,
71 | Int32 as Int32,
72 | Int64 as Int64,
73 | Integer as Integer,
74 | Key as Key,
75 | Num as Num,
76 | PRNGKeyArray as PRNGKeyArray,
77 | Real as Real,
78 | Scalar as Scalar,
79 | ScalarLike as ScalarLike,
80 | Shaped as Shaped,
81 | UInt as UInt,
82 | UInt2 as UInt2,
83 | UInt4 as UInt4,
84 | UInt8 as UInt8,
85 | UInt16 as UInt16,
86 | UInt32 as UInt32,
87 | UInt64 as UInt64,
88 | )
89 |
90 | # Set up to deliberately confuse a static type checker.
91 | PyTree: TypeAlias = getattr(typing, "foo" + "bar")
92 | # What's going on with this madness?
93 | #
94 | # At static-type-checking-time, we want `PyTree` to be a type for which both
95 | # `PyTree` and `PyTree[Foo]` are equivalent to `Any`.
96 | # (The intention is that `PyTree` be a runtime-only type; there's no real way to
97 | # do more with static type checkers.)
98 | #
99 | # Unfortunately, this isn't possible: `Any` isn't subscriptable. And there's no
100 | # equivalent way we can fake this using typing annotations. (In some sense the
101 | # closest thing would be a `Protocol[T]` with no methods, but that's actually the
102 | # opposite of what we want: that ends up allowing nothing at all.)
103 | #
104 | # The good news for us is that static type checkers have an internal escape hatch.
105 | # If they can't figure out what a type is, then they just give up and allow
106 | # anything. (I believe this is sometimes called `Unknown`.) Thus, this odd-looking
107 | # annotation, which static type checkers aren't smart enough to resolve.
108 | else:
109 | from ._array_types import (
110 | BFloat16 as BFloat16,
111 | Bool as Bool,
112 | Complex as Complex,
113 | Complex64 as Complex64,
114 | Complex128 as Complex128,
115 | Float as Float,
116 | Float8e4m3b11fnuz as Float8e4m3b11fnuz,
117 | Float8e4m3fn as Float8e4m3fn,
118 | Float8e4m3fnuz as Float8e4m3fnuz,
119 | Float8e5m2 as Float8e5m2,
120 | Float8e5m2fnuz as Float8e5m2fnuz,
121 | Float16 as Float16,
122 | Float32 as Float32,
123 | Float64 as Float64,
124 | Inexact as Inexact,
125 | Int as Int,
126 | Int2 as Int2,
127 | Int4 as Int4,
128 | Int8 as Int8,
129 | Int16 as Int16,
130 | Int32 as Int32,
131 | Int64 as Int64,
132 | Integer as Integer,
133 | Key as Key,
134 | Num as Num,
135 | Real as Real,
136 | Shaped as Shaped,
137 | UInt as UInt,
138 | UInt2 as UInt2,
139 | UInt4 as UInt4,
140 | UInt8 as UInt8,
141 | UInt16 as UInt16,
142 | UInt32 as UInt32,
143 | UInt64 as UInt64,
144 | )
145 |
# Documentation-build mode: when `typing.GENERATING_DOCUMENTATION` has been set,
# define plain placeholder classes for the names that would otherwise be resolved
# lazily from JAX by the module-level `__getattr__`.
# NOTE(review): the attribute is presumably set by the docs build configuration --
# confirm where.
if hasattr(typing, "GENERATING_DOCUMENTATION"):

    class Array:
        pass

    # Report the placeholder as `builtins.Array`; presumably so that rendered
    # signatures show a bare `Array` -- TODO confirm against the docs tooling.
    Array.__module__ = "builtins"
    Array.__qualname__ = "Array"

    class ArrayLike:
        pass

    # Same treatment as `Array` above.
    ArrayLike.__module__ = "builtins"
    ArrayLike.__qualname__ = "ArrayLike"

    class PRNGKeyArray:
        pass

    # Same treatment as `Array` above.
    PRNGKeyArray.__module__ = "builtins"
    PRNGKeyArray.__qualname__ = "PRNGKeyArray"

    # `PyTree` is imported for real here rather than replaced by a placeholder.
    from ._pytree_type import PyTree as PyTree

    class PyTreeDef:
        """Alias for `jax.tree_util.PyTreeDef`, which is the type of the
        return from `jax.tree_util.tree_structure(...)`.
        """

    # The value of the flag names the project whose docs are being built.
    if typing.GENERATING_DOCUMENTATION != "jaxtyping":
        # Equinox etc. docs get just `PyTreeDef`.
        # jaxtyping docs get `jaxtyping.PyTreeDef`.
        PyTreeDef.__qualname__ = "PyTreeDef"
        PyTreeDef.__module__ = "builtins"
178 |
@ft.cache
def __getattr__(item):
    """Lazily resolve the public names of this module that need JAX (or `PyTree`).

    The required import happens only when a name is first requested; `ft.cache`
    then memoises each successfully resolved name.

    **Arguments:**

    - `item`: the attribute name being looked up on the module.

    **Returns:**

    The object bound to that name.

    **Raises:**

    `AttributeError` if `item` is not one of the lazily provided names.
    """
    if item in ("Array", "Scalar", "PRNGKeyArray"):
        import jax

        if item == "Array":
            return jax.Array
        if item == "Scalar":
            return Shaped[jax.Array, ""]
        # New-style `jax.random.key` have scalar shape and dtype `key`.
        # Old-style `jax.random.PRNGKey` have shape `(2,)` and dtype
        # `uint32`.
        return Union[Key[jax.Array, ""], UInt32[jax.Array, "2"]]

    if item in ("ArrayLike", "DTypeLike"):
        import jax.typing

        if item == "ArrayLike":
            return jax.typing.ArrayLike
        return jax.typing.DTypeLike

    if item == "ScalarLike":
        # Goes back through this module, and hence through this function.
        from . import ArrayLike

        return Shaped[ArrayLike, ""]

    if item == "PyTree":
        from ._pytree_type import PyTree

        return PyTree

    if item == "PyTreeDef":
        import jax.tree_util

        return jax.tree_util.PyTreeDef

    raise AttributeError(f"module jaxtyping has no attribute {item!r}")
218 |
219 |
220 | __version__ = importlib.metadata.version("jaxtyping")
221 |
--------------------------------------------------------------------------------
/jaxtyping/_array_types.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | import copyreg
21 | import enum
22 | import functools as ft
23 | import importlib.util
24 | import re
25 | import sys
26 | import types
27 | import typing
28 | from dataclasses import dataclass
29 | from typing import (
30 | Any,
31 | get_args,
32 | get_origin,
33 | Literal,
34 | NoReturn,
35 | Optional,
36 | TypeVar,
37 | Union,
38 | )
39 |
40 |
41 | # Bit of a hack, but jaxtyping provides nicer error messages than typeguard. This means
42 | # we sometimes want to use it as our runtime type checker everywhere, even in non-array
43 | # use-cases, for which numpy is too heavy a dependency.
44 | # Honestly we should probably consider factoring out part of jaxtyping into a separate
45 | # package. (Specifically (a) the multi-argument checking and (b) the better error
46 | # messages and (c) the import hook that places the checker on the bottom of the
47 | # decorator stack.) And resist the urge to write our own runtime type-checker, I really
48 | # don't want to have to keep that up-to-date with changes in the Python typing spec...
49 | if importlib.util.find_spec("numpy") is not None:
50 | import numpy as np
51 |
52 | from ._errors import AnnotationError
53 | from ._storage import (
54 | get_shape_memo,
55 | get_treeflatten_memo,
56 | get_treepath_memo,
57 | set_shape_memo,
58 | )
59 |
60 |
61 | _array_name_format = "dtype_and_shape"
62 |
63 |
def get_array_name_format():
    """Return the current module-wide array name format.

    This is the value of the module global `_array_name_format`, which starts out
    as `"dtype_and_shape"` and is replaced by `set_array_name_format`.
    """
    return _array_name_format
66 |
67 |
def set_array_name_format(value):
    """Replace the module-wide array name format with `value`.

    The value is stored as-is in the module global `_array_name_format`; no
    validation is performed here.
    """
    # NOTE(review): the set of accepted values is not visible in this part of the
    # file; only the default "dtype_and_shape" is -- confirm against the consumer.
    global _array_name_format
    _array_name_format = value
71 |
72 |
# Sentinel for an annotation's `dtypes`: when `cls.dtypes is _any_dtype` the dtype
# check in `_MetaAbstractArray.__instancecheck_str__` is skipped.
_any_dtype = object()

# Sentinel for an unnamed single axis: `_check_dims` accepts any size for it.
_anonymous_dim = object()
# Sentinel for an unnamed variadic axis: `_check_shape` accepts whatever axes it
# spans without recording them.
_anonymous_variadic_dim = object()
77 |
78 |
# Enumerates three kinds of axis specifier: named, fixed and symbolic.
# NOTE(review): not referenced in the part of the file visible here; presumably
# used by the code that parses shape strings -- confirm.
class _DimType(enum.Enum):
    named = enum.auto()
    fixed = enum.auto()
    symbolic = enum.auto()
83 |
84 |
# A single axis identified by name. `_check_dims` binds the name to a size the
# first time it is seen and requires later occurrences to have the same size.
@dataclass(frozen=True)
class _NamedDim:
    # Key under which the size is stored in the single-axis memo.
    name: str
    # If truthy, an actual size of 1 is accepted without consulting the memo.
    broadcastable: bool
    # If truthy, the memo key is `get_treepath_memo() + name` rather than `name`.
    treepath: Any
90 |
91 |
# A named run of zero or more axes. `_MetaAbstractArray._check_shape` records the
# spanned shape in the variadic memo and compares later occurrences against it.
@dataclass(frozen=True)
class _NamedVariadicDim:
    # Key under which the spanned shape is stored in the variadic memo.
    name: str
    # Whether this occurrence may be broadcast against the recorded shape.
    broadcastable: bool
    # If truthy, the memo key is `get_treepath_memo() + name` rather than `name`.
    treepath: Any
97 |
98 |
# A single axis that must have one particular size.
@dataclass(frozen=True)
class _FixedDim:
    # The size that `_check_dims` compares the actual axis size against.
    size: int
    # If truthy, an actual size of 1 is accepted as well.
    broadcastable: bool
103 |
104 |
# A single axis whose expected size is computed from an expression.
@dataclass(frozen=True)
class _SymbolicDim:
    # Expression text. `_check_dims` first formats it as an f-string against the
    # argument memo, then evaluates the result against the single-axis memo.
    elem: Any
    # If truthy, an actual size of 1 is accepted without evaluating `elem`.
    broadcastable: bool
109 |
110 |
# Any axis specifier, including the variadic forms.
_AbstractDimOrVariadicDim = Union[
    Literal[_anonymous_dim],
    Literal[_anonymous_variadic_dim],
    _NamedDim,
    _NamedVariadicDim,
    _FixedDim,
    _SymbolicDim,
]
# The non-variadic axis specifiers only; this is what `_check_dims` accepts.
_AbstractDim = Union[Literal[_anonymous_dim], _NamedDim, _FixedDim, _SymbolicDim]
120 |
121 |
def _check_dims(
    cls_dims: list[_AbstractDim],
    obj_shape: tuple[int, ...],
    single_memo: dict[str, int],
    arg_memo: dict[str, Any],
) -> str:
    """Compare non-variadic axis specifiers against actual axis sizes, pairwise.

    **Arguments:**

    - `cls_dims`: the axis specifiers; must be as long as `obj_shape`.
    - `obj_shape`: the actual axis sizes.
    - `single_memo`: sizes already bound to axis names. Names seen for the first
        time are added to it in place.
    - `arg_memo`: namespace used to format symbolic axes as f-strings.

    **Returns:**

    The empty string if every axis matches, otherwise a message describing the
    first mismatch.

    **Raises:**

    `AnnotationError` if a symbolic axis refers to a name that cannot be resolved.
    """
    assert len(cls_dims) == len(obj_shape)
    for cls_dim, obj_size in zip(cls_dims, obj_shape):
        # Anonymous axes accept any size.
        if cls_dim is _anonymous_dim:
            continue
        # Broadcastable axes accept size one unconditionally.
        if cls_dim.broadcastable and obj_size == 1:
            continue

        dim_kind = type(cls_dim)

        if dim_kind is _FixedDim:
            if cls_dim.size != obj_size:
                return f"the dimension size {obj_size} does not equal {cls_dim.size} as expected by the type hint"  # noqa: E501
            continue

        if dim_kind is _SymbolicDim:
            try:
                # Support f-string syntax.
                # https://stackoverflow.com/a/53671539/22545467
                expression = eval(f"f'{cls_dim.elem}'", arg_memo.copy())
                # Make a copy to avoid `__builtins__` getting added as a key.
                eval_size = eval(expression, single_memo.copy())
            except NameError as e:
                raise AnnotationError(
                    f"Cannot process symbolic axis '{cls_dim.elem}' as "
                    "some axis names have not been processed. "
                    "Have you applied the `jaxtyped` decorator? "
                    "In practice you should usually only use symbolic axes in "
                    "annotations for return types, referring only to axes "
                    "annotated for arguments."
                ) from e
            if eval_size != obj_size:
                return f"the dimension size {obj_size} does not equal the existing value of {cls_dim.elem}={eval_size}"  # noqa: E501
            continue

        assert dim_kind is _NamedDim
        # Path-dependent names are stored under a key prefixed with the treepath.
        name = (
            (get_treepath_memo() + cls_dim.name)
            if cls_dim.treepath
            else cls_dim.name
        )
        try:
            cls_size = single_memo[name]
        except KeyError:
            # First sighting of this name: bind it.
            single_memo[name] = obj_size
            continue
        if cls_size != obj_size:
            return f"the size of dimension {cls_dim.name} is {obj_size} which does not equal the existing value of {cls_size}"  # noqa: E501
    return ""
169 |
170 |
171 | def _dtype_is_numpy_struct_array(dtype):
172 | return dtype.type.__name__ == "void" and dtype is not np.dtype(np.void)
173 |
174 |
class _MetaAbstractArray(type):
    """Metaclass of array annotations; implements their `isinstance` checks.

    `isinstance(obj, cls)` succeeds when `obj` passes the array-type, dtype and
    shape checks described by the attributes of `cls` (`array_type`, `dtypes`,
    `dims`, `index_variadic`).
    """

    # When True, `__instancecheck_str__` accepts every object without checking.
    _skip_instancecheck: bool = False

    def make_transparent(cls):
        """Make every subsequent instance check against `cls` succeed."""
        cls._skip_instancecheck = True

    def __instancecheck__(cls, obj: Any) -> bool:
        """Return True exactly when `__instancecheck_str__` reports no problem."""
        return cls.__instancecheck_str__(obj) == ""

    def __instancecheck_str__(cls, obj: Any) -> str:
        """Check `obj` against this annotation and explain any failure.

        **Returns:**

        The empty string if `obj` matches, otherwise a message describing the
        first check that failed.

        On failure, or if the shape check raises, the shape memo is reset to the
        state it had before the shape check started.
        """
        if cls._skip_instancecheck:
            return ""
        if cls.array_type is Any:
            # No concrete array type: duck-type on `shape` and `dtype`.
            if not (hasattr(obj, "shape") and hasattr(obj, "dtype")):
                return "this value does not have both `shape` and `dtype` attributes."
        else:
            if not isinstance(obj, cls.array_type):
                return f"this value is not an instance of the underlying array type {cls.array_type}"  # noqa: E501
        # NOTE(review): a truthy treeflatten memo skips the dtype and shape checks
        # entirely; presumably set while only the PyTree structure is being
        # inspected -- confirm in `_storage`.
        if get_treeflatten_memo():
            return ""

        # Normalise the dtype of `obj` to a string name.
        if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
            # JAX, numpy
            dtype = obj.dtype.type.__name__
            # numpy structured array is strictly a subtype of np.void
            if _dtype_is_numpy_struct_array(obj.dtype):
                dtype = str(obj.dtype)
        elif hasattr(obj.dtype, "as_numpy_dtype"):
            # TensorFlow
            dtype = obj.dtype.as_numpy_dtype.__name__
        else:
            # Everyone else, including PyTorch.
            # This offers an escape hatch for anyone looking to use jaxtyping for their
            # own array-like types.
            dtype = obj.dtype
            if not isinstance(dtype, str):
                # Keep only the text after the last "." of the repr.
                *_, dtype = repr(obj.dtype).rsplit(".", 1)

        if cls.dtypes is not _any_dtype:
            # Accept if any entry matches: strings by equality, compiled patterns
            # by `match`.
            in_dtypes = False
            for cls_dtype in cls.dtypes:
                if type(cls_dtype) is str:
                    in_dtypes = dtype == cls_dtype
                elif type(cls_dtype) is re.Pattern:
                    in_dtypes = bool(cls_dtype.match(dtype))
                else:
                    assert False
                if in_dtypes:
                    break
            if not in_dtypes:
                if len(cls.dtypes) == 1:
                    return f"this array has dtype {dtype}, not {cls.dtypes[0]} as expected by the type hint"  # noqa: E501
                else:
                    return f"this array has dtype {dtype}, not any of {cls.dtypes} as expected by the type hint"  # noqa: E501

        # Snapshot the memo: `_check_shape` mutates the dicts in place, and a
        # failed check must not leave behind the bindings it made along the way.
        single_memo, variadic_memo, pytree_memo, arg_memo = get_shape_memo()
        single_memo_bak = single_memo.copy()
        variadic_memo_bak = variadic_memo.copy()
        pytree_memo_bak = pytree_memo.copy()
        arg_memo_bak = arg_memo.copy()
        try:
            check = cls._check_shape(obj, single_memo, variadic_memo, arg_memo)
        except Exception:
            # Roll back before propagating.
            set_shape_memo(
                single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak
            )
            raise
        if check == "":
            # Success: keep whatever bindings the check recorded.
            return check
        else:
            set_shape_memo(
                single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak
            )
            return check

    def _check_shape(
        cls,
        obj,
        single_memo: dict[str, int],
        variadic_memo: dict[str, tuple[bool, tuple[int, ...]]],
        arg_memo: dict[str, Any],
    ) -> str:
        """Check `obj.shape` against `cls.dims`, updating the memos in place.

        **Arguments:**

        - `obj`: the array being checked; only `obj.shape` is read.
        - `single_memo`: sizes bound to single-axis names.
        - `variadic_memo`: for each variadic name, a `(broadcastable, shape)` pair.
        - `arg_memo`: namespace passed through to `_check_dims` for symbolic axes.

        **Returns:**

        The empty string if the shape matches, otherwise a message describing the
        mismatch.
        """
        if cls.index_variadic is None:
            # No variadic axis: the number of axes must match exactly.
            if len(obj.shape) != len(cls.dims):
                return f"this array has {len(obj.shape)} dimensions, not the {len(cls.dims)} expected by the type hint"  # noqa: E501
            return _check_dims(cls.dims, obj.shape, single_memo, arg_memo)
        else:
            # One entry of `cls.dims` is variadic and may span zero axes, hence
            # the minimum of `len(cls.dims) - 1`.
            if len(obj.shape) < len(cls.dims) - 1:
                return f"this array has {len(obj.shape)} dimensions, which is fewer than {len(cls.dims) - 1} that is the minimum expected by the type hint"  # noqa: E501
            # `i` is the position of the variadic entry; `j` is the negative index
            # at which the entries after it begin, or None if there are none.
            i = cls.index_variadic
            j = -(len(cls.dims) - i - 1)
            if j == 0:
                j = None
            prefix_check = _check_dims(
                cls.dims[:i], obj.shape[:i], single_memo, arg_memo
            )
            if prefix_check != "":
                return prefix_check
            if j is not None:
                suffix_check = _check_dims(
                    cls.dims[j:], obj.shape[j:], single_memo, arg_memo
                )
                if suffix_check != "":
                    return suffix_check
            variadic_dim = cls.dims[i]
            if variadic_dim is _anonymous_variadic_dim:
                return ""
            else:
                assert type(variadic_dim) is _NamedVariadicDim
                # Path-dependent names are stored under a treepath-prefixed key.
                if variadic_dim.treepath:
                    name = get_treepath_memo() + variadic_dim.name
                else:
                    name = variadic_dim.name
                broadcastable = variadic_dim.broadcastable
                try:
                    prev_broadcastable, prev_shape = variadic_memo[name]
                except KeyError:
                    # First sighting of this name: record the spanned shape.
                    variadic_memo[name] = (broadcastable, obj.shape[i:j])
                    return ""
                else:
                    new_shape = obj.shape[i:j]
                    if prev_broadcastable:
                        # The recorded shape may itself be broadcast.
                        try:
                            broadcast_shape = np.broadcast_shapes(new_shape, prev_shape)
                        except ValueError:  # not broadcastable e.g. (3, 4) and (5,)
                            return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast with the existing value of {prev_shape}"  # noqa: E501
                        # A non-broadcastable occurrence must already have the
                        # broadcast shape.
                        if not broadcastable and broadcast_shape != new_shape:
                            return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which the existing value of {prev_shape} cannot be broadcast to"  # noqa: E501
                        # Replace the entry with the broadcast shape and this
                        # occurrence's broadcastable flag.
                        variadic_memo[name] = (broadcastable, broadcast_shape)
                    else:
                        if broadcastable:
                            # The recorded shape is fixed, so the new shape must
                            # broadcast to exactly that shape.
                            try:
                                broadcast_shape = np.broadcast_shapes(
                                    new_shape, prev_shape
                                )
                            except ValueError:  # not broadcastable e.g. (3, 4) and (5,)
                                return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast with the existing value of {prev_shape}"  # noqa: E501
                            if broadcast_shape != prev_shape:
                                return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast to the existing value of {prev_shape}"  # noqa: E501
                        else:
                            # Neither side broadcasts: the shapes must be equal.
                            if new_shape != prev_shape:
                                return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which does not equal the existing value of {prev_shape}"  # noqa: E501
                    return ""
        # Every branch above returns.
        assert False
319 |
320 |
def _return_abstractarray():
    """Return the `AbstractArray` base class.

    Takes no arguments, so it can serve as the reconstructor in a
    `(callable, args)` reduction pair for `AbstractArray` itself. The name
    `AbstractArray` is looked up when this function is called, not when it is
    defined.
    """
    return AbstractArray
323 |
324 |
def _pickle_array_annotation(x: type["AbstractArray"]):
    """Build the `(callable, args)` reduction pair used to pickle an array annotation.

    **Arguments:**

    - `x`: either `AbstractArray` itself, or a class exposing `dtype`, `array_type`
        and `dim_str` attributes.

    **Returns:**

    A two-tuple `(callable, args)` such that `callable(*args)` recreates `x`.
    """
    if x is not AbstractArray:
        # Recreate the annotation by subscripting its dtype again, i.e.
        # `x.dtype[x.array_type, x.dim_str]`.
        subscript = (x.array_type, x.dim_str)
        return x.dtype.__getitem__, (subscript,)
    # The base class has no dtype to subscript; hand back a plain accessor instead.
    return _return_abstractarray, ()
330 |
331 |
332 | copyreg.pickle(_MetaAbstractArray, _pickle_array_annotation)
333 |
334 |
def _check_scalar(dtype, dtypes, dims):
    """Report whether a scalar type is compatible with the given dtypes and dims.

    **Arguments:**

    - `dtype`: a string prefix such as `"int"` or `"float"`. The empty string is a
        prefix of every string, so it matches any dtype name.
    - `dtypes`: either `_any_dtype`, or an iterable of dtype names.
    - `dims`: the parsed dimensions of the annotation.

    **Returns:**

    `False` if any entry of `dims` is something other than a variadic dimension
    (anonymous or named). Otherwise `True` if `dtypes` is `_any_dtype` or if at least
    one entry of `dtypes` starts with `dtype`, and `False` if not.
    """
    only_variadic = all(
        axis is _anonymous_variadic_dim or isinstance(axis, _NamedVariadicDim)
        for axis in dims
    )
    if not only_variadic:
        return False
    if dtypes is _any_dtype:
        return True
    # NOTE(review): this assumes every entry is a `str`; a compiled regex entry has no
    # `startswith` method.
    return any(candidate.startswith(dtype) for candidate in dtypes)
342 |
343 |
class AbstractArray(metaclass=_MetaAbstractArray):
    """This is the base class of all shape-and-dtype-specified arrays, e.g. it's a base
    class for `Float32[Array, "foo"]`.

    This might be useful if you're trying to inspect type annotations yourself, e.g.
    you can check `issubclass(annotation, jaxtyping.AbstractArray)`.
    """

    # The pieces the annotation was written with: `dtype[array_type, dim_str]`.
    dtype: type["AbstractDtype"]
    array_type: Any
    dim_str: str

    # The parsed form of the above, consumed when typechecking later on.
    dtypes: list[str]
    dims: tuple[_AbstractDimOrVariadicDim, ...]
    index_variadic: Optional[int]

    def __new__(cls, *args, **kwargs):
        # Annotations are classes used only as type hints; instantiation always fails.
        message = (
            "jaxtyping annotations cannot be instantiated -- they should be used for "
            "type hints only."
        )
        raise RuntimeError(message)
367 |
368 |
369 | _not_made = object()
370 |
371 | _union_types = [Union]
372 | if sys.version_info >= (3, 10):
373 | _union_types.append(types.UnionType)
374 |
375 |
@ft.lru_cache(maxsize=None)
def _make_array_cached(array_type, dim_str, dtypes, name):
    """Parse a shape string and gather everything needed to build an array annotation.

    Results are memoised with an unbounded `lru_cache`, so every argument must be
    hashable.

    **Arguments:**

    - `array_type`: the type being annotated. Subscripted generics are reduced to
        their origin. Builtin/NumPy scalar types and existing `AbstractArray`
        subclasses get special handling (see below).
    - `dim_str`: the shape specification; axes are separated by whitespace.
    - `dtypes`: either `_any_dtype` or a collection of dtype entries.
    - `name`: the name of the dtype class, used to build the annotation's name.

    **Returns:**

    One of:

    - `array_type` itself, when it is one of the handled scalar types and
        `_check_scalar` accepts it;
    - `_not_made`, when it is one of the handled scalar types and `_check_scalar`
        rejects it;
    - otherwise the tuple `(array_type, name, dtypes, dims, index_variadic, dim_str)`.

    **Raises:**

    `ValueError` if `dim_str` is not a string or is malformed, if extending an
    existing annotation leaves no overlapping dtypes or duplicates a variadic
    specifier, or if `_array_name_format` is not recognised.
    """
    if not isinstance(dim_str, str):
        raise ValueError(
            "Shape specification must be a string. Axes should be separated with "
            "spaces."
        )
    dims = []
    # Position of the (at most one) variadic axis within `dims`, if any.
    index_variadic = None
    for index, elem in enumerate(dim_str.split()):
        if "," in elem and "(" not in elem:
            # Common mistake.
            # Disable in the case that there's brackets to allow for function calls,
            # e.g. `min(foo,bar)`, in symbolic axes.
            raise ValueError("Axes should be separated with spaces, not commas")
        if elem.endswith("#"):
            raise ValueError(
                "As of jaxtyping v0.1.0, broadcastable axes are now denoted "
                "with a # at the start, rather than at the end"
            )

        if "..." in elem:
            if elem != "...":
                raise ValueError(
                    "Anonymous multiple axes '...' must be used on its own; "
                    f"got {elem}"
                )
            # `...` is handled as an anonymous, variadic, named-type axis.
            broadcastable = False
            variadic = True
            anonymous = True
            treepath = False
            dim_type = _DimType.named
        else:
            broadcastable = False
            variadic = False
            anonymous = False
            treepath = False
            # Peel off leading modifier characters one at a time, in any order:
            # `#` broadcastable, `*` variadic, `_` anonymous, `?` tree-path dependent.
            while True:
                if len(elem) == 0:
                    # This branch needed as just `_` is valid
                    break
                first_char = elem[0]
                if first_char == "#":
                    if broadcastable:
                        raise ValueError(
                            "Do not use # twice to denote broadcastability, e.g. "
                            "`##foo` is not allowed"
                        )
                    broadcastable = True
                    elem = elem[1:]
                elif first_char == "*":
                    if variadic:
                        raise ValueError(
                            "Do not use * twice to denote accepting multiple "
                            "axes, e.g. `**foo` is not allowed"
                        )
                    variadic = True
                    elem = elem[1:]
                elif first_char == "_":
                    if anonymous:
                        raise ValueError(
                            "Do not use _ twice to denote anonymity, e.g. `__foo` "
                            "is not allowed"
                        )
                    anonymous = True
                    elem = elem[1:]
                elif first_char == "?":
                    if treepath:
                        raise ValueError(
                            "Do not use ? twice to denote dependence on location "
                            "within a PyTree, e.g. `??foo` is not allowed"
                        )
                    treepath = True
                    elem = elem[1:]
                # Allow e.g. `foo=4` as an alternate syntax for just `4`, so that one
                # can write e.g. `Float[Array, "rows=3 cols=4"]`
                elif elem.count("=") == 1:
                    # Only the right-hand side is kept; the loop then re-examines it.
                    _, elem = elem.split("=")
                else:
                    break
            # Classify what remains: identifier (or empty) -> named, integer literal ->
            # fixed, anything else -> symbolic expression.
            if len(elem) == 0 or elem.isidentifier():
                dim_type = _DimType.named
            else:
                try:
                    elem = int(elem)
                except ValueError:
                    dim_type = _DimType.symbolic
                else:
                    dim_type = _DimType.fixed

        if variadic:
            if index_variadic is not None:
                raise ValueError(
                    "Cannot use variadic specifiers (`*name` or `...`) "
                    "more than once."
                )
            index_variadic = index

        # Reject modifier combinations that are not allowed for this kind of axis,
        # then replace `elem` with its parsed dimension object.
        if dim_type is _DimType.fixed:
            if variadic:
                raise ValueError(
                    "Cannot have a fixed axis bind to multiple axes, e.g. "
                    "`*4` is not allowed."
                )
            if anonymous:
                raise ValueError(
                    "Cannot have a fixed axis be anonymous, e.g. `_4` is not allowed."
                )
            if treepath:
                raise ValueError(
                    "Cannot have a fixed axis have tree-path dependence, e.g. `?4` is "
                    "not allowed."
                )
            elem = _FixedDim(elem, broadcastable)
        elif dim_type is _DimType.named:
            if anonymous:
                if broadcastable:
                    raise ValueError(
                        "Cannot have an axis be both anonymous and "
                        "broadcastable, e.g. `#_` is not allowed."
                    )
                if variadic:
                    elem = _anonymous_variadic_dim
                else:
                    elem = _anonymous_dim
            else:
                if variadic:
                    elem = _NamedVariadicDim(elem, broadcastable, treepath)
                else:
                    elem = _NamedDim(elem, broadcastable, treepath)
        else:
            assert dim_type is _DimType.symbolic
            if anonymous:
                raise ValueError(
                    "Cannot have a symbolic axis be anonymous, e.g. "
                    "`_foo+bar` is not allowed"
                )
            if variadic:
                raise ValueError(
                    "Cannot have symbolic multiple-axes, e.g. "
                    "`*foo+bar` is not allowed"
                )
            if treepath:
                raise ValueError(
                    "Cannot have a symbolic axis with tree-path dependence, e.g. "
                    "`?foo+bar` is not allowed"
                )
            elem = _SymbolicDim(elem, broadcastable)
        dims.append(elem)
    dims = tuple(dims)

    # Allow Python built-in numeric types.
    # TODO: do something more generic than this? Should we _make all types
    # that have `shape` and `dtype` attributes or something?
    array_origin = get_origin(array_type)
    if array_origin is not None:
        # A subscripted generic: continue with its unsubscripted origin.
        array_type = array_origin
    # For the scalar types below no annotation class is built: the scalar type itself
    # is returned when `_check_scalar` accepts it, and `_not_made` otherwise.
    if array_type is bool:
        if _check_scalar("bool", dtypes, dims):
            return array_type
        else:
            return _not_made
    elif array_type is int:
        if _check_scalar("int", dtypes, dims):
            return array_type
        else:
            return _not_made
    elif array_type is float:
        if _check_scalar("float", dtypes, dims):
            return array_type
        else:
            return _not_made
    elif array_type is complex:
        if _check_scalar("complex", dtypes, dims):
            return array_type
        else:
            return _not_made
    elif array_type is np.bool_:
        if _check_scalar("bool", dtypes, dims):
            return array_type
        else:
            return _not_made
    elif array_type is np.generic or array_type is np.number:
        # The empty prefix is passed here, which every dtype name starts with.
        if _check_scalar("", dtypes, dims):
            return array_type
        else:
            return _not_made
    if array_type is not Any and issubclass(array_type, AbstractArray):
        # Extending an existing annotation, e.g. `Outer[Inner[Array, "b"], "a"]`:
        # intersect the dtypes and prepend the new dims to the existing ones.
        if dtypes is _any_dtype:
            dtypes = array_type.dtypes
        elif array_type.dtypes is not _any_dtype:
            dtypes = tuple(x for x in dtypes if x in array_type.dtypes)
            if len(dtypes) == 0:
                raise ValueError(
                    "A jaxtyping annotation cannot be extended with no overlapping "
                    "dtypes. For example, `Bool[Float[Array, 'dim1'], 'dim2']` is an "
                    "error. You probably want to make the outer wrapper be `Shaped`."
                )
        if array_type.index_variadic is not None:
            if index_variadic is None:
                # The existing dims come after the new ones, so the existing variadic
                # index shifts right by the number of new dims.
                index_variadic = array_type.index_variadic + len(dims)
            else:
                raise ValueError(
                    "Cannot use variadic specifiers (`*name` or `...`) "
                    "in both the original array and the extended array"
                )
        dims = dims + array_type.dims
        dim_str = dim_str + " " + array_type.dim_str
        array_type = array_type.array_type
    try:
        type_str = array_type.__name__
    except AttributeError:
        # Not every accepted `array_type` has a `__name__`; fall back to its repr.
        type_str = repr(array_type)
    if _array_name_format == "dtype_and_shape":
        name = f"{name}[{type_str}, '{dim_str}']"
    elif _array_name_format == "array":
        name = type_str
    else:
        raise ValueError(f"array_name_format {_array_name_format} not recognised")

    return (array_type, name, dtypes, dims, index_variadic, dim_str)
597 |
598 |
def _make_array(x, dim_str, dtype):
    """Create the annotation for `dtype[x, dim_str]`.

    **Arguments:**

    - `x`: the array type being annotated.
    - `dim_str`: the shape specification string.
    - `dtype`: the dtype class; its `dtypes` and `__name__` attributes are used.

    **Returns:**

    Whatever `_make_array_cached` returned if that is not a tuple (it is passed
    through untouched). Otherwise a new class, created via `_MetaAbstractArray` with
    `AbstractArray` as its base, carrying the parsed information as class attributes.
    """
    made = _make_array_cached(x, dim_str, dtype.dtypes, dtype.__name__)
    if type(made) is not tuple:
        # Nothing to construct: hand the cached result straight back.
        return made

    array_type, name, dtypes, dims, index_variadic, dim_str = made
    namespace = {
        "dtype": dtype,
        "array_type": array_type,
        "dim_str": dim_str,
        "dtypes": dtypes,
        "dims": dims,
        "index_variadic": index_variadic,
    }
    annotation = _MetaAbstractArray(name, (AbstractArray,), namespace)

    # `typing.GENERATING_DOCUMENTATION` is an optional attribute; when it is absent,
    # empty, or "jaxtyping", the class reports `jaxtyping` as its module.
    documenting = getattr(typing, "GENERATING_DOCUMENTATION", "")
    if documenting in {"", "jaxtyping"}:
        annotation.__module__ = "jaxtyping"
    else:
        annotation.__module__ = "builtins"
    return annotation
623 |
624 |
625 | class _MetaAbstractDtype(type):
626 | def __instancecheck__(cls, obj: Any) -> NoReturn:
627 | raise AnnotationError(
628 | f"Do not use `isinstance(x, jaxtyping.{cls.__name__})`. If you want to "
629 | "check just the dtype of an array, then use "
630 | f'`jaxtyping.{cls.__name__}[jnp.ndarray, "..."]`.'
631 | )
632 |
633 | def __getitem__(cls, item: tuple[Any, str]):
634 | if not isinstance(item, tuple) or len(item) != 2:
635 | raise ValueError(
636 | "As of jaxtyping v0.2.0, type annotations must now include both an "
637 | "array type and a shape. For example `Float[Array, 'foo bar']`.\n"
638 | "Ellipsis can be used to accept any shape: `Float[Array, '...']`."
639 | )
640 | array_type, dim_str = item
641 | dim_str = dim_str.strip()
642 | if isinstance(array_type, TypeVar):
643 | bound = array_type.__bound__
644 | if bound is None:
645 | constraints = array_type.__constraints__
646 | if constraints == ():
647 | array_type = Any
648 | else:
649 | array_type = Union[constraints]
650 | else:
651 | array_type = bound
652 | del item
653 | if get_origin(array_type) in _union_types:
654 | out = [_make_array(x, dim_str, cls) for x in get_args(array_type)]
655 | out = tuple(x for x in out if x is not _not_made)
656 | if len(out) == 0:
657 | raise ValueError("Invalid jaxtyping type annotation.")
658 | elif len(out) == 1:
659 | (out,) = out
660 | else:
661 | out = Union[out]
662 | else:
663 | out = _make_array(array_type, dim_str, cls)
664 | if out is _not_made:
665 | raise ValueError("Invalid jaxtyping type annotation.")
666 | return out
667 |
668 |
class AbstractDtype(metaclass=_MetaAbstractDtype):
    """This is the base class of all dtypes. Subclass it to create your own custom
    collection of dtypes (analogous to `Float`, `Inexact` etc.)

    Subclasses must define the class attribute `dtypes`. It may be a single string, a
    single regex (as returned by `re.compile(...)`), or a tuple/list of
    strings/regexes.

    At runtime, the array or tensor's dtype is converted to a string and compared
    against the string (an exact match is required) or regex. (String matching is
    performed, rather than just e.g. `array.dtype == dtype`, to provide cross-library
    compatibility between JAX/PyTorch/etc.)

    !!! Example

        ```python
        class UInt8or16(AbstractDtype):
            dtypes = ["uint8", "uint16"]

        UInt8or16[Array, "shape"]
        ```
        which is essentially equivalent to
        ```python
        Union[UInt8[Array, "shape"], UInt16[Array, "shape"]]
        ```
    """

    dtypes: Union[Literal[_any_dtype], list[Union[str, re.Pattern]]]

    def __init__(self, *args, **kwargs):
        # Dtype classes are only ever subscripted, never instantiated.
        raise RuntimeError(
            "AbstractDtype cannot be instantiated. Perhaps you wrote e.g. "
            '`Float32("shape")` when you mean `Float32[jnp.ndarray, "shape"]`?'
        )

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

        # Normalise whatever the subclass declared: a lone string/regex becomes a
        # one-element tuple, `_any_dtype` is kept as is, and any other iterable is
        # converted to a tuple.
        declared = cls.dtypes
        if isinstance(declared, (str, re.Pattern)):
            cls.dtypes = (declared,)
        elif declared is _any_dtype:
            cls.dtypes = declared
        else:
            cls.dtypes = tuple(declared)
712 |
713 |
# String names of the individual dtypes. These are the values placed in the `dtypes`
# attribute of the dtype classes defined further down.
_prng_key = "prng_key"
# Two spellings of the boolean dtype name; both are listed in `bools` below.
_bool = "bool"
_bool_ = "bool_"
# Unsigned integers.
_uint2 = "uint2"
_uint4 = "uint4"
_uint8 = "uint8"
_uint16 = "uint16"
_uint32 = "uint32"
_uint64 = "uint64"
# Signed integers.
_int2 = "int2"
_int4 = "int4"
_int8 = "int8"
_int16 = "int16"
_int32 = "int32"
_int64 = "int64"
# fp8 types exposed in Jax, see https://github.com/jax-ml/jax/blob/main/jax/_src/dtypes.py#L92-L97
_float8_e4m3b11fnuz = "float8_e4m3b11fnuz"
_float8_e4m3fn = "float8_e4m3fn"
_float8_e4m3fnuz = "float8_e4m3fnuz"
_float8_e5m2 = "float8_e5m2"
_float8_e5m2fnuz = "float8_e5m2fnuz"
# Wider floating-point types.
_bfloat16 = "bfloat16"
_float16 = "float16"
_float32 = "float32"
_float64 = "float64"
# Complex types.
_complex64 = "complex64"
_complex128 = "complex128"
741 |
742 |
def _make_dtype(_dtypes, name):
    """Create a new `AbstractDtype` subclass called `name`.

    **Arguments:**

    - `_dtypes`: the value to use as the subclass's `dtypes` attribute.
    - `name`: used for both `__name__` and `__qualname__` of the new class.

    **Returns:**

    The new class. Its `__module__` is `"jaxtyping"` unless
    `typing.GENERATING_DOCUMENTATION` is set to something other than `""` or
    `"jaxtyping"`, in which case it is `"builtins"`.
    """

    # Note: a class-body `dtypes = dtypes` would look the right-hand name up in the
    # class/global scope rather than in this function, hence the distinct parameter
    # name.
    class _Cls(AbstractDtype):
        dtypes = _dtypes

    _Cls.__name__ = name
    _Cls.__qualname__ = name
    documenting = getattr(typing, "GENERATING_DOCUMENTATION", "")
    _Cls.__module__ = "jaxtyping" if documenting in {"", "jaxtyping"} else "builtins"
    return _Cls
754 |
755 |
# One dtype class per individual dtype name.
UInt2 = _make_dtype(_uint2, "UInt2")
UInt4 = _make_dtype(_uint4, "UInt4")
UInt8 = _make_dtype(_uint8, "UInt8")
UInt16 = _make_dtype(_uint16, "UInt16")
UInt32 = _make_dtype(_uint32, "UInt32")
UInt64 = _make_dtype(_uint64, "UInt64")
Int2 = _make_dtype(_int2, "Int2")
Int4 = _make_dtype(_int4, "Int4")
Int8 = _make_dtype(_int8, "Int8")
Int16 = _make_dtype(_int16, "Int16")
Int32 = _make_dtype(_int32, "Int32")
Int64 = _make_dtype(_int64, "Int64")
Float8e4m3b11fnuz = _make_dtype(_float8_e4m3b11fnuz, "Float8e4m3b11fnuz")
Float8e4m3fn = _make_dtype(_float8_e4m3fn, "Float8e4m3fn")
Float8e4m3fnuz = _make_dtype(_float8_e4m3fnuz, "Float8e4m3fnuz")
Float8e5m2 = _make_dtype(_float8_e5m2, "Float8e5m2")
Float8e5m2fnuz = _make_dtype(_float8_e5m2fnuz, "Float8e5m2fnuz")
BFloat16 = _make_dtype(_bfloat16, "BFloat16")
Float16 = _make_dtype(_float16, "Float16")
Float32 = _make_dtype(_float32, "Float32")
Float64 = _make_dtype(_float64, "Float64")
Complex64 = _make_dtype(_complex64, "Complex64")
Complex128 = _make_dtype(_complex128, "Complex128")

# Groups of dtype names, combined below into the broader dtype categories.
bools = [_bool, _bool_]
uints = [_uint2, _uint4, _uint8, _uint16, _uint32, _uint64]
ints = [_int2, _int4, _int8, _int16, _int32, _int64]
float8 = [
    _float8_e4m3b11fnuz,
    _float8_e4m3fn,
    _float8_e4m3fnuz,
    _float8_e5m2,
    _float8_e5m2fnuz,
]
floats = float8 + [_bfloat16, _float16, _float32, _float64]
complexes = [_complex64, _complex128]

# We match NumPy's type hierarchy in what types to provide. See the diagram at
# https://numpy.org/doc/stable/reference/arrays.scalars.html#scalars

Bool = _make_dtype(bools, "Bool")
UInt = _make_dtype(uints, "UInt")
Int = _make_dtype(ints, "Int")
Integer = _make_dtype(uints + ints, "Integer")
Float = _make_dtype(floats, "Float")
Complex = _make_dtype(complexes, "Complex")
Inexact = _make_dtype(floats + complexes, "Inexact")
Real = _make_dtype(floats + uints + ints, "Real")
Num = _make_dtype(uints + ints + floats + complexes, "Num")

# `Shaped` uses the `_any_dtype` sentinel rather than a list of dtype names.
Shaped = _make_dtype(_any_dtype, "Shaped")

Key = _make_dtype(_prng_key, "Key")
809 |
810 |
def make_numpy_struct_dtype(dtype: "np.dtype", name: str):
    """Creates a type annotation for a
    [numpy structured array](https://numpy.org/doc/stable/user/basics.rec.html#structured-arrays).
    It performs an exact match on the name, order, and dtype of all its fields.

    !!! Example

        ```python
        label_t = np.dtype([('first', np.uint8), ('second', np.int8)])
        Label = make_numpy_struct_dtype(label_t, 'Label')
        ```

        after that, you can use it just like any other [`jaxtyping.AbstractDtype`][]:

        ```python
        a: Label[np.ndarray, 'a b'] = np.array([[(1, 0), (0, 1)]], dtype=label_t)
        ```

    **Arguments:**

    - `dtype`: The numpy structured dtype to use.
    - `name`: The name to use for the returned Python class.

    **Returns:**

    A type annotation with classname `name` that matches exactly `dtype` when used like
    any other [`jaxtyping.AbstractDtype`][].

    **Raises:**

    `ValueError` if `dtype` is not a `np.dtype` describing a structured array.
    """
    is_struct = isinstance(dtype, np.dtype) and _dtype_is_numpy_struct_array(dtype)
    if is_struct:
        # The dtype's string form is the single entry the new class matches against.
        return _make_dtype(str(dtype), name)
    raise ValueError(f"Expecting a numpy structured array dtype, not {dtype}")
841 |
--------------------------------------------------------------------------------
/jaxtyping/_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Union
3 |
4 |
5 | def _maybestr2bool(value: Union[bool, str], error: str) -> bool:
6 | if isinstance(value, bool):
7 | return value
8 | elif isinstance(value, str):
9 | if value.lower() in ("0", "false"):
10 | return False
11 | elif value.lower() in ("1", "true"):
12 | return True
13 | else:
14 | raise ValueError(error)
15 | else:
16 | raise ValueError(error)
17 |
18 |
19 | class _JaxtypingConfig:
20 | def __init__(self):
21 | self.update("jaxtyping_disable", os.environ.get("JAXTYPING_DISABLE", "0"))
22 | self.update(
23 | "jaxtyping_remove_typechecker_stack",
24 | os.environ.get("JAXTYPING_REMOVE_TYPECHECKER_STACK", "0"),
25 | )
26 |
27 | def update(self, item: str, value):
28 | if item.lower() == "jaxtyping_disable":
29 | msg = (
30 | "Unrecognised value for `JAXTYPING_DISABLE`. Valid values are "
31 | "`JAXTYPING_DISABLE=0` (the default) or `JAXTYPING_DISABLE=1` (to "
32 | "disable runtime type checking)."
33 | )
34 | self.jaxtyping_disable = _maybestr2bool(value, msg)
35 | elif item.lower() == "jaxtyping_remove_typechecker_stack":
36 | msg = (
37 | "Unrecognised value for `JAXTYPING_REMOVE_TYPECHECKER_STACK`. Valid "
38 | "values are `JAXTYPING_REMOVE_TYPECHECKER_STACK=0` (the default) or "
39 | "`JAXTYPING_REMOVE_TYPECHECKER_STACK=1` (to remove the stack frames "
40 | "from the typechecker in `jaxtyped(typechecker=...)`, when it raises a "
41 | "runtime type-checking error)."
42 | )
43 | self.jaxtyping_remove_typechecker_stack = _maybestr2bool(value, msg)
44 | else:
45 | raise ValueError(f"Unrecognised config value {item}")
46 |
47 |
# Module-level configuration instance, created at import time. Its initial values are
# read from the environment by `_JaxtypingConfig.__init__`.
config = _JaxtypingConfig()
49 |
--------------------------------------------------------------------------------
/jaxtyping/_errors.py:
--------------------------------------------------------------------------------
class TypeCheckError(TypeError):
    """A `TypeError` subclass, reported as `jaxtyping.TypeCheckError`."""


# Not inheriting from TypeError as that gets caught and re-raised as just a TypeError
# when using typeguard<3.
class AnnotationError(Exception):
    """An `Exception` subclass, reported as `jaxtyping.AnnotationError`."""


# Override `__module__` on both classes so that they present as `jaxtyping.<Name>`.
for _error_cls in (TypeCheckError, AnnotationError):
    _error_cls.__module__ = "jaxtyping"
del _error_cls
13 |
--------------------------------------------------------------------------------
/jaxtyping/_import_hook.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | # This source code is adapted from typeguard:
21 | # https://github.com/agronholm/typeguard/blob/0dd7f7510b7c694e66a0d17d1d58d185125bad5d/src/typeguard/importhook.py
22 | #
23 | # Copied and adapted in compliance with the terms of typeguard's MIT license.
24 | # The original license is reproduced here.
25 | #
26 | # ---------
27 | #
28 | # This is the MIT license: http://www.opensource.org/licenses/mit-license.php
29 | #
30 | # Copyright (c) Alex Grönholm
31 | #
32 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this
33 | # software and associated documentation files (the "Software"), to deal in the Software
34 | # without restriction, including without limitation the rights to use, copy, modify,
35 | # merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
36 | # permit persons to whom the Software is furnished to do so, subject to the following
37 | # conditions:
38 | #
39 | # The above copyright notice and this permission notice shall be included in all copies
40 | # or substantial portions of the Software.
41 | #
42 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
43 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
44 | # PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
45 | # HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
46 | # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
47 | # OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
48 | #
49 | # ---------
50 |
51 |
52 | import ast
53 | import functools as ft
54 | import hashlib
55 | import sys
56 | from collections.abc import Sequence
57 | from importlib.abc import MetaPathFinder
58 | from importlib.machinery import SourceFileLoader
59 | from importlib.util import cache_from_source, decode_source
60 | from inspect import isclass
61 | from typing import Optional, Union
62 | from unittest.mock import patch
63 |
64 |
# The name of this function is magical
# NOTE(review): presumably it matches the helper of the same name inside importlib,
# whose frames importlib trims from tracebacks raised during import -- confirm
# against CPython's `importlib._bootstrap` before renaming.
def _call_with_frames_removed(f, *args, **kwargs):
    """Call `f(*args, **kwargs)` and return its result unchanged."""
    return f(*args, **kwargs)
68 |
69 |
70 | def _optimized_cache_from_source(typechecker_hash, /, path, debug_override=None):
71 | # Version 2: change the position of the `@jaxtyped` decorator, so need a
72 | # different name to avoid hitting old __pycache__.
73 | # Version 3: now also annotating classes.
74 | # Version 4: I'm honestly not sure, but bumping this fixed some kind of odd error.
75 | # Maybe I changed something with hte classes part way through version 3?
76 | # Version 5: Added support for string-based `typechecker` argument.
77 | # Version 6: optimization tag now depends on `typechecker` argument, so that
78 | # changing the typechecker will hit a different cache.
79 | # Version 7: Using the same md5 hash of the `typechecker` argument
80 | # for importlib and decorator lookup.
81 | # Version 8: Now using new-style `jaxtyped(typechecker=...)` rather than old-style
82 | # double-decorators.
83 | # Version 9: Now reporting the correct source code lines. (Important when used with
84 | # a debugger.)
85 | return cache_from_source(
86 | path, debug_override, optimization=f"jaxtyping9{typechecker_hash}"
87 | )
88 |
89 |
class Typechecker:
    """A typechecker specification, given as a dotted-path string or `None`.

    Each instance registers a callable in the class-level `lookup` dictionary under
    `self.hash`; the decorator expression produced by `get_ast` refers to that entry
    by its key.
    """

    # Shared by all instances: maps hash -> decorator callable.
    lookup = {}

    def __init__(self, typechecker):
        """Register the typechecker.

        **Arguments:**

        - `typechecker`: a string such as `"beartype.beartype"`, or `None` for a
            no-op typechecker.

        **Raises:**

        `TypeError` if `typechecker` is neither a string nor `None`.
        """
        if typechecker is None:
            # If it is None, ignore it silently (use dummy decorator)
            self.hash = "0"
            Typechecker.lookup[self.hash] = lambda x, *_, **__: x
            return
        if not isinstance(typechecker, str):
            # Passed typechecker is invalid
            raise TypeError(
                "Jaxtyping typechecker has to be either a string or a None."
            )

        # Wrap the dotted path in a small function that imports the top-level module
        # on each call and then forwards to the named callable.
        top_level_module = typechecker.split(".", 1)[0]
        wrapper_source = (
            "def f(x, *args, **kwargs):\n"
            + f" import {top_level_module}\n"
            + f" return {typechecker}(x, *args, **kwargs)"
        )

        # md5 hashing instead of __hash__
        # because __hash__ is different for each Python session
        self.hash = hashlib.md5(typechecker.encode("utf-8")).hexdigest()

        namespace = {}
        exec(wrapper_source, {}, namespace)
        Typechecker.lookup[self.hash] = namespace["f"]

    def get_hash(self):
        """Return the key under which this typechecker is stored in `lookup`."""
        return self.hash

    def get_ast(self):
        """Return a freshly parsed AST node for the `jaxtyped(typechecker=...)` call.

        The node is the decorator expression of a throwaway function definition.
        """
        # Note that we compile AST only if we missed importlib cache.
        # No caching on this function! We modify the return type every time, with
        # its appropriate source code location.
        stub_source = (
            f"@jaxtyping.jaxtyped(typechecker=jaxtyping._import_hook.Typechecker.lookup['{self.hash}'])\n"
            "def _():\n ..."
        )
        stub_function = ast.parse(stub_source).body[0]
        return stub_function.decorator_list[0]
135 |
136 |
class JaxtypingTransformer(ast.NodeVisitor):
    """Walks a module's AST, adding the `jaxtyped` decorator to every class and
    function definition, and an `import jaxtyping` statement to the module.

    Nodes are modified in place; each `visit_*` method returns the node it was given.
    """

    def __init__(self, *, typechecker: Typechecker) -> None:
        # Stack of the nodes currently being visited (outermost first).
        self._parents: list[ast.AST] = []
        self._typechecker = typechecker

    def _visit_children(self, node):
        # Visit everything beneath `node`, with `node` on the parent stack meanwhile.
        self._parents.append(node)
        self.generic_visit(node)
        self._parents.pop()
        return node

    def _decorator_for(self, node):
        # A new decorator expression carrying the source location of `node`.
        decorator = self._typechecker.get_ast()
        ast.copy_location(decorator, node)
        return decorator

    def visit_Module(self, node: ast.Module):
        # Insert "import jaxtyping" after any "from __future__ ..." imports
        for position, statement in enumerate(node.body):
            is_future_import = (
                isinstance(statement, ast.ImportFrom)
                and statement.module == "__future__"
            )
            # A constant used as a statement, i.e. the module docstring.
            is_constant_expr = isinstance(statement, ast.Expr) and isinstance(
                statement.value, ast.Constant
            )
            if is_future_import or is_constant_expr:
                continue
            import_jaxtyping = ast.Import(names=[ast.alias("jaxtyping", None)])
            node.body.insert(position, import_jaxtyping)
            break

        return self._visit_children(node)

    def visit_ClassDef(self, node: ast.ClassDef):
        # Place at the start of the decorator list, so that `@dataclass` decorators get
        # called first.
        node.decorator_list.insert(0, self._decorator_for(node))
        return self._visit_children(node)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        # Originally, we had some code here to explicitly check if the function
        # had any annotated arguments or annotated return types, and if not, we
        # would skip adding the `@jaxtyped` decorator.
        # However, this has been removed because it would ignore functions that
        # had type annotations in the body of the function (or
        # `assert isinstance(..., SomeType)`).

        # Place at the end of the decorator list, because:
        # - as otherwise we wrap e.g. `jax.custom_{jvp,vjp}` and lose the ability
        #   to `defjvp` etc.
        # - decorators frequently remove annotations from functions, and we'd like
        #   to use those annotations.
        # - typeguard in particular wants to be at the end of the decorator list, as
        #   it works by recompiling the wrapped function.
        #
        # Note that the counter-argument here is that we'd like to place this
        # at the start of the decorator list, in case a typechecking annotation
        # has been manually applied, and we'd need to be above that. In this
        # case we're just going to have to need to ask the user to remove their
        # typechecking annotation (and let this decorator do it instead).
        # It's more important we be compatible with normal JAX code.
        node.decorator_list.append(self._decorator_for(node))
        return self._visit_children(node)
199 |
200 |
class _JaxtypingLoader(SourceFileLoader):
    """A `SourceFileLoader` that passes the module's AST through
    `JaxtypingTransformer` before compiling it, and uses a jaxtyping-specific
    bytecode cache path.
    """

    def __init__(self, *args, typechecker: Typechecker, **kwargs):
        super().__init__(*args, **kwargs)
        self._typechecker = typechecker

    def source_to_code(self, data, path, *, _optimize=-1):
        """Parse `data`, instrument the resulting tree, and compile it to code."""
        text = decode_source(data)
        # First pass: parse only (`PyCF_ONLY_AST`), giving a tree we can modify.
        module_tree = _call_with_frames_removed(
            compile,
            text,
            path,
            "exec",
            ast.PyCF_ONLY_AST,
            dont_inherit=True,
            optimize=_optimize,
        )
        transformer = JaxtypingTransformer(typechecker=self._typechecker)
        module_tree = transformer.visit(module_tree)
        # The inserted nodes need line/column information before compilation.
        ast.fix_missing_locations(module_tree)
        # Second pass: compile the instrumented tree.
        return _call_with_frames_removed(
            compile, module_tree, path, "exec", dont_inherit=True, optimize=_optimize
        )

    def exec_module(self, module):
        # Use a custom optimization marker - the import lock should make this monkey
        # patch safe
        cache_path_fn = ft.partial(
            _optimized_cache_from_source, self._typechecker.get_hash()
        )
        with patch("importlib._bootstrap_external.cache_from_source", cache_path_fn):
            return super().exec_module(module)
231 |
232 |
class _JaxtypingFinder(MetaPathFinder):
    """Wraps another path finder and instruments the module with `@jaxtyped` and
    `@typechecked` if `should_instrument()` returns `True`.

    Should not be used directly, but rather via `install_import_hook`.
    """

    def __init__(self, modules, original_pathfinder, typechecker: Typechecker):
        self.modules = modules
        self._original_pathfinder = original_pathfinder
        self._typechecker = typechecker

    def find_spec(self, fullname, path=None, target=None):
        """Return a spec with an instrumenting loader, or `None` to decline.

        `None` is returned when the module is not selected for instrumentation, when
        the wrapped finder finds nothing, or when the spec it finds is not loaded
        by a `SourceFileLoader`.
        """
        if not self.should_instrument(fullname):
            return None

        spec = self._original_pathfinder.find_spec(fullname, path, target)
        if spec is None or not isinstance(spec.loader, SourceFileLoader):
            return None

        # Swap in our loader, pointing at the same module name and file.
        found_loader = spec.loader
        spec.loader = _JaxtypingLoader(
            found_loader.name, found_loader.path, typechecker=self._typechecker
        )
        return spec

    def should_instrument(self, module_name: str) -> bool:
        """Determine whether the module with the given name should be instrumented.

        **Arguments:**

        - `module_name`: the full name of the module that is about to be imported
            (e.g. ``xyz.abc``)
        """
        # Match either the package itself or anything nested beneath it.
        return any(
            module_name == package or module_name.startswith(package + ".")
            for package in self.modules
        )
269 |
270 |
class ImportHookManager:
    """Handle for an installed import hook.

    Can be used as a context manager, in which case the hook is uninstalled when the
    `with` block exits, or explicitly via `uninstall()`.
    """

    def __init__(self, hook: MetaPathFinder):
        self.hook = hook

    def __enter__(self):
        # Return the manager so that `with ... as manager:` binds something usable
        # (previously the target was bound to `None`).
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns `None`, so an exception raised inside the block is not suppressed.
        self.uninstall()

    def uninstall(self):
        """Remove the hook from `sys.meta_path`. Safe to call more than once."""
        try:
            sys.meta_path.remove(self.hook)
        except ValueError:
            pass  # already removed
286 |
287 |
# `typechecker` intentionally has no default value: running without a typechecker
# has to be an explicit choice (`typechecker=None`).
def install_import_hook(modules: Union[str, Sequence[str]], typechecker: Optional[str]):
    """Install an import hook that automatically applies
    `@jaxtyped(typechecker=typechecker)` to every function and dataclass of the
    selected modules.

    !!! Tip "Usage"

        ```python
        from jaxtyping import install_import_hook
        # Plus any one of the following:

        # decorate `@jaxtyped(typechecker=typeguard.typechecked)`
        with install_import_hook("foo", "typeguard.typechecked"):
            import foo          # Any module imported inside this `with` block, whose
            import foo.bar      # name begins with the specified string, will
            import foo.bar.qux  # automatically have both `@jaxtyped` and the specified
                                # typechecker applied to all of their functions and
                                # dataclasses.

        # decorate `@jaxtyped(typechecker=beartype.beartype)`
        with install_import_hook("foo", "beartype.beartype"):
            ...

        # decorate only `@jaxtyped` (if you want that for some reason)
        with install_import_hook("foo", None):
            ...
        ```

    The `with` block is optional; the hook can also be managed by hand:
    ```python
    hook = install_import_hook(...)
    import ...
    hook.uninstall()
    ```

    Several packages can be hooked at once by passing a list:
    ```python
    install_import_hook(["foo", "bar.baz"], ...)
    ```

    **Arguments:**

    - `modules`: the name (or a sequence of names) of the modules in which to
        automatically apply `@jaxtyped`. A name also covers all of its submodules.
    - `typechecker`: the module and function of the typechecker you want to use, as a
        string. For example `typechecker="typeguard.typechecked"`, or
        `typechecker="beartype.beartype"`. Pass `typechecker=None` if you do not
        want to automatically decorate with a typechecker as well. For backward
        compatibility a tuple of strings is also accepted and joined with `"."`.

    **Returns:**

    A context manager that uninstalls the hook on exit, or when you call `.uninstall()`.

    **Raises:**

    `RuntimeError` if no `PathFinder` class can be found on `sys.meta_path`.

    !!! Example "Example: end-user script"

        ```python
        ### entry_point.py
        from jaxtyping import install_import_hook
        with install_import_hook("main", "typeguard.typechecked"):
            import main

        ### main.py
        from jaxtyping import Array, Float32

        def f(x: Float32[Array, "batch channels"]):
            ...
        ```

    !!! Example "Example: writing a library"

        ```python
        ### __init__.py
        from jaxtyping import install_import_hook
        with install_import_hook("my_library_name", "beartype.beartype"):
            from .subpackage import foo          # full name is my_library_name.subpackage so
                                                 # will be hook'd
            from .another_subpackage import bar  # full name is my_library_name.another_subpackage
                                                 # so will be hook'd.
        ```

    !!! warning

        If a function already has any decorators on it, then `@jaxtyped` will get added
        at the bottom of the decorator list, e.g.
        ```python
        @some_other_decorator
        @jaxtyped(typechecker=beartype.beartype)
        def foo(...): ...
        ```
        This is to support the common case in which
        `some_other_decorator = jax.custom_jvp` etc.

        If a class already has any decorators on it, then `@jaxtyped` will get added to
        the top of the decorator list, e.g.
        ```python
        @jaxtyped(typechecker=beartype.beartype)
        @some_other_decorator
        class A:
            ...
        ```
        This is to support the common case in which
        `some_other_decorator = dataclasses.dataclass`.
    """  # noqa: E501

    module_names = [modules] if isinstance(modules, str) else modules

    # Support old less-flexible API, in which the typechecker was a tuple of names.
    if isinstance(typechecker, tuple):
        typechecker_path = ".".join(typechecker)
    else:
        typechecker_path = typechecker

    def _is_path_finder(candidate):
        return (
            isclass(candidate)
            and candidate.__name__ == "PathFinder"
            and hasattr(candidate, "find_spec")
        )

    # The first `PathFinder` class on `sys.meta_path` does the actual lookups; the
    # hook wraps it. A match is always a class, so `None` is a safe sentinel.
    path_finder = next(filter(_is_path_finder, sys.meta_path), None)
    if path_finder is None:
        raise RuntimeError("Cannot find a PathFinder in sys.meta_path")

    hook = _JaxtypingFinder(module_names, path_finder, Typechecker(typechecker_path))
    # Inserted at the front so it is consulted before every other finder.
    sys.meta_path.insert(0, hook)
    return ImportHookManager(hook)
413 |
--------------------------------------------------------------------------------
/jaxtyping/_indirection.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | # Note that `from typing import Annotated; Bool = Annotated`
21 | # does not work with static type checkers. `Annotated` is a typeform rather
22 | # than a type, meaning it cannot be assigned.
23 | from typing import (
24 | Annotated as BFloat16, # noqa: F401
25 | Annotated as Bool, # noqa: F401
26 | Annotated as Complex, # noqa: F401
27 | Annotated as Complex64, # noqa: F401
28 | Annotated as Complex128, # noqa: F401
29 | Annotated as Float, # noqa: F401
30 | Annotated as Float8e4m3b11fnuz, # noqa: F401
31 | Annotated as Float8e4m3fn, # noqa: F401
32 | Annotated as Float8e4m3fnuz, # noqa: F401
33 | Annotated as Float8e5m2, # noqa: F401
34 | Annotated as Float8e5m2fnuz, # noqa: F401
35 | Annotated as Float16, # noqa: F401
36 | Annotated as Float32, # noqa: F401
37 | Annotated as Float64, # noqa: F401
38 | Annotated as Inexact, # noqa: F401
39 | Annotated as Int, # noqa: F401
40 | Annotated as Int2, # noqa: F401
41 | Annotated as Int4, # noqa: F401
42 | Annotated as Int8, # noqa: F401
43 | Annotated as Int16, # noqa: F401
44 | Annotated as Int32, # noqa: F401
45 | Annotated as Int64, # noqa: F401
46 | Annotated as Integer, # noqa: F401
47 | Annotated as Key, # noqa: F401
48 | Annotated as Num, # noqa: F401
49 | Annotated as Real, # noqa: F401
50 | Annotated as Shaped, # noqa: F401
51 | Annotated as UInt, # noqa: F401
52 | Annotated as UInt2, # noqa: F401
53 | Annotated as UInt4, # noqa: F401
54 | Annotated as UInt8, # noqa: F401
55 | Annotated as UInt16, # noqa: F401
56 | Annotated as UInt32, # noqa: F401
57 | Annotated as UInt64, # noqa: F401
58 | TYPE_CHECKING,
59 | )
60 |
61 |
# Everything in this module is aimed at static type checkers, which read the
# imports at face value. `TYPE_CHECKING` is `False` at runtime, so actually
# executing this module trips the assertion below.
# NOTE(review): `assert` statements are stripped under `python -O`, in which case
# execution would continue on to the `jax` imports that follow.
if not TYPE_CHECKING:
    assert False
64 |
65 | from jax import (
66 | Array as PRNGKeyArray, # noqa: F401
67 | Array as Scalar, # noqa: F401
68 | )
69 | from jax.typing import ArrayLike as ScalarLike # noqa: F401
70 |
--------------------------------------------------------------------------------
/jaxtyping/_ipython_extension.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | from ._import_hook import JaxtypingTransformer, Typechecker
21 |
22 |
def choose_typechecker_magics():
    """Build and return the `Magics` subclass that defines the
    `%jaxtyping.typechecker` line magic.
    """
    # IPython is imported lazily so that importing this module stays cheap when
    # the magic is never used.
    from IPython.core.magic import line_magic, Magics, magics_class

    @magics_class
    class ChooseTypecheckerMagics(Magics):
        @line_magic("jaxtyping.typechecker")
        def typechecker(self, typechecker):
            # Discard any transformer left over from a previous invocation...
            self.shell.ast_transformers = [
                transformer
                for transformer in self.shell.ast_transformers
                if not isinstance(transformer, JaxtypingTransformer)
            ]

            # ...then register one for the newly requested typechecker.
            replacement = JaxtypingTransformer(typechecker=Typechecker(typechecker))
            self.shell.ast_transformers.append(replacement)

    return ChooseTypecheckerMagics
46 |
47 |
def load_ipython_extension(ipython):
    """Register the `%jaxtyping.typechecker` magic on the given IPython shell.

    **Raises:**

    `RuntimeError` (chained from the original error) if the magics class cannot
    be created.
    """
    try:
        magics_cls = choose_typechecker_magics()
    except Exception as e:
        # Deliberately broad: IPython may be installed and still fail to import
        # for reasons that are hard to predict.
        raise RuntimeError("Failed to define jaxtyping.typechecker magic") from e
    else:
        ipython.register_magics(magics_cls)
57 |
--------------------------------------------------------------------------------
/jaxtyping/_pytest_plugin.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | import sys
21 |
22 | from ._import_hook import install_import_hook
23 |
24 |
def pytest_addoption(parser):
    """Register the `--jaxtyping-packages` command line option under the
    "jaxtyping" option group.
    """
    help_text = (
        "comma separated name list of packages and modules to instrument for "
        "type checking with jaxtyping. The last element in the list should be the "
        "type checker to use, e.g. "
        "--jaxtyping-packages=foopackage,barpackage,typeguard.typechecked"
    )
    parser.getgroup("jaxtyping").addoption(
        "--jaxtyping-packages", action="store", help=help_text
    )
35 |
36 |
def pytest_configure(config):
    """Install the jaxtyping import hook requested via `--jaxtyping-packages`.

    The option value is a comma separated list whose last entry names the
    typechecker; every entry before it is a package to instrument. Nothing
    happens when the option is unset or empty.

    **Raises:**

    `RuntimeError` if any of the requested packages is already present in
    `sys.modules`, since an already-imported package cannot be instrumented.
    """
    raw_value = config.getoption("jaxtyping_packages")
    if not raw_value:
        return

    *packages, typechecker = (entry.strip() for entry in raw_value.split(","))

    too_late = sorted(name for name in packages if name in sys.modules)
    if too_late:
        raise RuntimeError(
            "jaxtyping cannot check these packages because they "
            f"are already imported: {', '.join(too_late)}"
        )

    install_import_hook(packages, typechecker)
56 |
--------------------------------------------------------------------------------
/jaxtyping/_pytree_type.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | import functools as ft
21 | import typing
22 | from typing import Any, Generic, TypeVar
23 |
24 | import jax.tree_util as jtu
25 | import wadler_lindig as wl
26 |
27 | from ._errors import AnnotationError
28 | from ._storage import (
29 | clear_treeflatten_memo,
30 | clear_treepath_memo,
31 | get_shape_memo,
32 | set_shape_memo,
33 | set_treeflatten_memo,
34 | set_treepath_memo,
35 | )
36 |
37 |
_T = TypeVar("_T")
_S = TypeVar("_S")


# Placeholder generics, subscripted by `_MetaPyTree.__getitem__` purely to obtain a
# nicely formatted name. They are presented as `PyTree` living in `builtins` so
# that no module prefix or private class name shows up.
class _FakePyTree1(Generic[_T]):
    __qualname__ = "PyTree"
    __module__ = "builtins"


# `__name__` cannot be set from inside the class body, so it is patched afterwards.
_FakePyTree1.__name__ = "PyTree"


class _FakePyTree2(Generic[_T, _S]):
    __qualname__ = "PyTree"
    __module__ = "builtins"


_FakePyTree2.__name__ = "PyTree"
58 |
59 |
class _MetaPyTree(type):
    """Metaclass behind `PyTree` and every subscripted `PyTree[...]`.

    It provides:

    - `__getitem__`, which builds (and caches) a subclass of `PyTree` carrying
        `leaftype` and `structure` class attributes;
    - `__instancecheck__`, so that `isinstance(x, PyTree[...])` performs the leaf
        type and structure checks;
    - `__call__`, which forbids instantiation.
    """

    def __call__(self, *args, **kwargs):
        """Always raises `RuntimeError`: `PyTree` types are never instantiated."""
        raise RuntimeError("PyTree cannot be instantiated")

    def __instancecheck__(cls, obj):
        """Implement `isinstance(obj, PyTree[...])`.

        The shape/structure memos are snapshotted beforehand and restored if the
        check fails or raises, so that an unsuccessful check leaves no bindings
        behind.
        """
        if not hasattr(cls, "leaftype"):
            return True  # Just `isinstance(x, PyTree)`
        # Handle beartype doing `isinstance(None, hint)` to check if
        # is `instance`able.
        if obj is None:
            return True

        single_memo, variadic_memo, pytree_memo, arg_memo = get_shape_memo()
        single_memo_bak = single_memo.copy()
        variadic_memo_bak = variadic_memo.copy()
        pytree_memo_bak = pytree_memo.copy()
        arg_memo_bak = arg_memo.copy()
        try:
            out = cls._check(obj, pytree_memo)
        except Exception:
            set_shape_memo(
                single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak
            )
            raise
        if out:
            return True
        else:
            set_shape_memo(
                single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak
            )
            return False

    def _check(cls, obj, pytree_memo):
        """Check `obj` against `cls.leaftype` and `cls.structure`.

        **Arguments:**

        - `obj`: the object being checked.
        - `pytree_memo`: mapping from structure name to the structure bound to it;
            mutated in place when a new name is bound.

        **Returns:**

        `True` if the structure matches and every leaf is accepted, else `False`.

        **Raises:**

        `AnnotationError` if a composite structure refers to a name that has not
        been bound yet.
        """
        if cls.leaftype is Any:
            # Any leaf is acceptable: flatten with the default notion of a leaf
            # and accept every leaf found.

            def is_flatten_leaftype(x):
                return False

            def is_check_leaftype(x):
                return True

        else:
            # We could use `isinstance` here but that would fail for more complicated
            # types, e.g. PyTree[tuple[int]]. So at least internally we make a
            # particular choice of typechecker.
            #
            # Deliberately not using @jaxtyped so that we share the same `memo` as
            # whatever dynamic context we're currently in.
            from ._typeguard import typechecked

            @typechecked
            def accepts_leaftype(x: cls.leaftype):
                pass

            def is_leaftype(x):
                try:
                    accepts_leaftype(x)
                except TypeError:
                    return False
                else:
                    return True

            is_flatten_leaftype = is_check_leaftype = is_leaftype

        # The treeflatten flag is raised only for the duration of the flatten.
        set_treeflatten_memo()
        try:
            leaves, structure = jtu.tree_flatten(obj, is_leaf=is_flatten_leaftype)
        finally:
            clear_treeflatten_memo()
        if cls.structure is not None:
            if cls.structure.isidentifier():
                # A single name: bind it on first sight, compare afterwards.
                try:
                    prev_structure = pytree_memo[cls.structure]
                except KeyError:
                    pytree_memo[cls.structure] = structure
                else:
                    if prev_structure != structure:
                        return False
            else:
                # Composite structure, optionally with a leading or trailing `...`.
                named_pytree = 0
                pieces = cls.structure.split()
                if pieces[0] == "...":
                    # "... T": the named structure must sit at the bottom of the
                    # tree, below arbitrary upper levels.
                    pieces = pieces[1:]
                    prefix = False
                    suffix = True
                elif pieces[-1] == "...":
                    # "T ...": the named structure must sit at the top of the
                    # tree, above arbitrary lower levels.
                    pieces = pieces[:-1]
                    prefix = True
                    suffix = False
                else:
                    prefix = False
                    suffix = False
                # Compose the named structures by substituting each one into every
                # leaf of the tree built so far.
                for identifier in pieces:
                    try:
                        prev_structure = pytree_memo[identifier]
                    except KeyError as e:
                        raise AnnotationError(
                            f"Cannot process composite structure '{cls.structure}' "
                            f"as the structure name {identifier} has not been seen "
                            "before."
                        ) from e
                    # Not using `PyTreeDef.compose` due to JAX bug #18218.
                    prev_pytree = jtu.tree_unflatten(
                        prev_structure, [0] * prev_structure.num_leaves
                    )
                    named_pytree = jtu.tree_map(lambda _: prev_pytree, named_pytree)
                named_structure = jtu.tree_structure(named_pytree)
                if prefix:
                    dummy_pytree = jtu.tree_unflatten(structure, [0] * len(leaves))
                    dummy_named = jtu.tree_unflatten(
                        named_structure, [0] * named_structure.num_leaves
                    )
                    # A `ValueError` from mapping over both trees together is
                    # treated as a mismatch.
                    try:
                        jtu.tree_map(lambda _, __: 0, dummy_named, dummy_pytree)
                    except ValueError:
                        return False
                elif suffix:
                    has_structure = lambda x: jtu.tree_structure(x) == named_structure
                    dummy_pytree = jtu.tree_unflatten(structure, [0] * len(leaves))
                    # Stop flattening at subtrees that have the named structure;
                    # whatever remains as a leaf must be one of them.
                    dummy_leaves = jtu.tree_leaves(dummy_pytree, is_leaf=has_structure)
                    if any(not has_structure(x) for x in dummy_leaves):
                        return False
                else:
                    if structure != named_structure:
                        return False

        try:
            for leaf_index, leaf in enumerate(leaves):
                if cls.structure is not None:
                    # Record which leaf is being checked while the leaf type
                    # check runs.
                    set_treepath_memo(leaf_index, cls.structure)
                if not is_check_leaftype(leaf):
                    return False
                clear_treepath_memo()
        finally:
            clear_treepath_memo()
        return True

    # Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
    # the custom __instancecheck__ that we want.
    # We can't add that __instancecheck__ via subclassing, e.g.
    # type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms
    # isn't allowed.
    # Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
    # has __module__ "types", e.g. we get types.PyTree[int].
    @ft.lru_cache(maxsize=None)
    def __getitem__(cls, item):
        """Build the class representing `PyTree[item]`.

        `item` is either a leaf type, or a 2-tuple `(leaftype, structure)` in which
        `structure` is a string of whitespace-separated identifiers, optionally
        starting or ending with `...`. Results are cached, so subscripting again
        with equal arguments returns the same class.

        **Raises:**

        `ValueError` if the structure is not a string, is empty, contains a piece
        that is not an identifier, or if `item` is a tuple whose length is not 2.
        """
        if isinstance(item, tuple):
            if len(item) == 2:
                # Validate the type before calling `.strip()`, so that a non-string
                # structure produces the intended `ValueError` instead of an
                # unrelated `AttributeError`/`TypeError`.
                if not isinstance(item[1], str):
                    raise ValueError(
                        "The structure annotation `struct` in "
                        "`jaxtyping.PyTree[leaftype, struct]` must be be a string, "
                        f"e.g. `jaxtyping.PyTree[leaftype, 'T']`. Got '{item[1]}'."
                    )

                class X(PyTree):
                    leaftype = item[0]
                    structure = item[1].strip()

                pieces = X.structure.split()
                if len(pieces) == 0:
                    raise ValueError(
                        "The string `struct` in `jaxtyping.PyTree[leaftype, struct]` "
                        "cannot be the empty string."
                    )
                for piece_index, piece in enumerate(pieces):
                    # `...` is only allowed as the first or the last piece.
                    if (piece_index == 0) or (piece_index == len(pieces) - 1):
                        if piece == "...":
                            continue
                    if not piece.isidentifier():
                        raise ValueError(
                            "The string `struct` in "
                            "`jaxtyping.PyTree[leaftype, struct]` must be be a "
                            "whitespace-separated sequence of identifiers, e.g. "
                            "`jaxtyping.PyTree[leaftype, 'T']` or "
                            "`jaxtyping.PyTree[leaftype, 'foo bar']`.\n"
                            "(Here, 'identifier' is used in the same sense as in "
                            "regular Python, i.e. a valid variable name.)\n"
                            f"Got piece '{piece}' in overall structure '{X.structure}'."
                        )

                # Throwaway class named after the quoted structure string, used
                # only as the second argument when formatting the name.
                class Y:
                    pass

                Y.__module__ = "builtins"
                Y.__name__ = repr(X.structure)
                Y.__qualname__ = repr(X.structure)
                name = wl.pformat(_FakePyTree2[X.leaftype, Y], width=9999)
                del Y
            else:
                raise ValueError(
                    "The subscript `foo` in `jaxtyping.PyTree[foo]` must either be a "
                    "leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and "
                    "structure, e.g. `PyTree[int, 'T']`. Received a tuple of length "
                    f"{len(item)}."
                )
        else:
            name = wl.pformat(_FakePyTree1[item], width=9999)

            class X(PyTree):
                leaftype = item
                structure = None

        X.__name__ = name
        X.__qualname__ = name
        if getattr(typing, "GENERATING_DOCUMENTATION", "") in {"", "jaxtyping"}:
            X.__module__ = "jaxtyping"
        else:
            X.__module__ = "builtins"
        return X

    def __pdoc__(self, **kwargs):
        """Return the `wadler_lindig` document used to display this type.

        NOTE(review): presumably the hook `wadler_lindig` looks up when
        pretty-printing an object; confirm against its documentation.
        """
        if self is PyTree:
            return wl.TextDoc("PyTree")
        else:
            indent = kwargs["indent"]
            docs = [wl.pdoc(self.leaftype, **kwargs)]
            if self.structure is not None:
                docs.append(wl.pdoc(self.structure, **kwargs))
            return wl.bracketed(
                begin=wl.TextDoc("PyTree["),
                docs=docs,
                sep=wl.comma,
                end=wl.TextDoc("]"),
                indent=indent,
            )
286 |
287 |
# Can't do `class PyTree(Generic[_T]): ...` because we need to override the
# instancecheck for PyTree[foo], but subclassing
# `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
PyTree = _MetaPyTree("PyTree", (), {})
# The reported module depends on whether (and for which package) documentation is
# currently being generated.
if getattr(typing, "GENERATING_DOCUMENTATION", "") in {"", "jaxtyping"}:
    PyTree.__module__ = "jaxtyping"
else:
    PyTree.__module__ = "builtins"
PyTree.__doc__ = """Represents a PyTree.

Annotations of the following sorts are supported:
```python
a: PyTree
b: PyTree[LeafType]
c: PyTree[LeafType, "T"]
d: PyTree[LeafType, "S T"]
e: PyTree[LeafType, "... T"]
f: PyTree[LeafType, "T ..."]
```

These correspond to:

a. A plain `PyTree` can be used as an annotation, in which case `PyTree` is simply a
    suggestively-named alternative to `Any`.
    ([By definition all types are PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html))

b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For
    example, `PyTree[int]` or `PyTree[Union[str, Float32[Array, "b c"]]]`.

c. A structure name can also be passed. In this case
    `jax.tree_util.tree_structure(...)` will be called, and bound to the structure name.
    This can be used to mark that multiple PyTrees all have the same structure:
    ```python
    def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
        ...
    ```
    Structures are bound to names in the same way as array shape annotations, i.e.
    within the thread-local dynamic context of a [`jaxtyping.jaxtyped`][] decorator.

d. A composite structure can be declared. In this case the variable must have a PyTree
    structure equal to the composition of multiple previously-bound PyTree structures.
    For example:
    ```python
    def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
        ...

    x = (1, 2)
    y = {"key": 3}
    z = {"key": (4, 5)}  # structure is the composition of the structures of `y` and `x`
    f(x, y, z)
    ```
    When performing runtime type-checking, all the individual pieces must have already
    been bound to structures, otherwise the composite structure check will throw an error.

e. A structure can begin with a `...`, to denote that the lower levels of the PyTree
    must match the declared structure, but the upper levels can be arbitrary. As in the
    previous case, all named pieces must already have been seen and their structures
    bound.

f. A structure can end with a `...`, to denote that the PyTree must be a prefix of the
    declared structure, but the lower levels can be arbitrary. As in the previous two
    cases, all named pieces must already have been seen and their structures bound.
"""  # noqa: E501
351 |
--------------------------------------------------------------------------------
/jaxtyping/_storage.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | import threading
21 | from typing import Any, Optional
22 |
23 | from ._errors import AnnotationError
24 |
25 |
# Per-thread holder for the stack of shape memos (attribute `memo_stack`).
_shape_storage = threading.local()


def _has_shape_memo():
    """Return whether the current thread has a non-empty memo stack."""
    # The empty-tuple default covers threads that have not created a stack yet.
    return len(getattr(_shape_storage, "memo_stack", ())) != 0
31 |
32 |
def get_shape_memo():
    """Return the `(single_memo, variadic_memo, pytree_memo, arguments)` memos on
    top of the current thread's stack, or four fresh empty dicts if there is none.
    """
    if not _has_shape_memo():
        # `isinstance` happening outside any @jaxtyped decorators, e.g. at the
        # global scope. A throwaway memo is enough, since there are no stored
        # values to compare against anyway.
        return {}, {}, {}, {}

    single_memo, variadic_memo, pytree_memo, arguments = _shape_storage.memo_stack[-1]
    return single_memo, variadic_memo, pytree_memo, arguments
47 |
48 |
def set_shape_memo(single_memo, variadic_memo, pytree_memo, arg_memo) -> None:
    """Replace the memos on top of the current thread's stack with the given ones.

    Does nothing when the stack is missing or empty.
    """
    if not _has_shape_memo():
        return
    replacement = (single_memo, variadic_memo, pytree_memo, arg_memo)
    _shape_storage.memo_stack[-1] = replacement
57 |
58 |
def push_shape_memo(arguments: dict[str, Any]):
    """Push a fresh set of memos onto the current thread's stack and return it.

    The new entry is three empty dicts plus a shallow copy of `arguments`.
    """
    # The stack is created lazily rather than next to `_shape_storage`: attributes
    # of a `threading.local` are per-thread, so each thread needs its own stack.
    if not hasattr(_shape_storage, "memo_stack"):
        _shape_storage.memo_stack = []
    memos = ({}, {}, {}, arguments.copy())
    _shape_storage.memo_stack.append(memos)
    return memos
68 |
69 |
def pop_shape_memo() -> None:
    """Discard the memos on top of the current thread's stack.

    Raises `AttributeError` if this thread never created a stack, and `IndexError`
    if the stack is empty.
    """
    _shape_storage.memo_stack.pop()
72 |
73 |
def shape_str(memos) -> str:
    """Describe the bindings held in jaxtyping's internal memos, as shown in
    type-checking error messages.

    Entries whose name starts with `~~delete~~` are left out of the axis listing.

    **Arguments:**

    - `memos`: as returned by `get_shape_memo` or `push_shape_memo`.

    **Returns:**

    A newline-separated string; empty if there is nothing to report.
    """
    single_memo, variadic_memo, pytree_memo, _ = memos
    hidden_prefix = "~~delete~~"

    axis_lines = [
        f"{name}={size}"
        for name, size in single_memo.items()
        if not name.startswith(hidden_prefix)
    ]
    # Variadic entries are pairs; only the second element (the shape) is shown.
    axis_lines += [
        f"{name}={shape}"
        for name, (_, shape) in variadic_memo.items()
        if not name.startswith(hidden_prefix)
    ]

    lines = []
    if axis_lines:
        lines.append(
            "The current values for each jaxtyping axis annotation are as follows."
        )
        lines.extend(axis_lines)
    if pytree_memo:
        lines.append(
            "The current values for each jaxtyping PyTree structure annotation are as "
            "follows."
        )
        lines.extend(f"{name}={structure}" for name, structure in pytree_memo.items())
    return "\n".join(lines)
110 |
111 |
def print_bindings():
    """Prints the values of the current jaxtyping axis bindings. Intended for debugging.

    For instance, to inspect what `foo` and `bar` are bound to in

    ```python
    @jaxtyped(typechecker=...)
    def f(x: Float[Array, "foo bar"]):
        print_bindings()
        ...
    ```

    Note that these values are bound during runtime typechecking, so the
    [`jaxtyping.jaxtyped`][] decorator is required.

    **Arguments:**

    Nothing.

    **Returns:**

    Nothing.
    """
    current_memos = get_shape_memo()
    description = shape_str(current_memos)
    print(description)
136 |
137 |
# Per-thread holder for the tree-path marker (attribute `value`): either `None`
# or the string set by `set_treepath_memo`.
_treepath_storage = threading.local()


def clear_treepath_memo() -> None:
    """Reset the current thread's tree-path marker to `None`."""
    _treepath_storage.value = None
143 |
144 |
def set_treepath_memo(index: Optional[int], structure: str) -> None:
    """Set the current thread's tree-path marker.

    **Arguments:**

    - `index`: the leaf index to mention, or `None` to produce a marker that starts
        with `~~delete~~`.
    - `structure`: the structure string to mention in the marker.

    **Raises:**

    `AnnotationError` if a marker is already set, i.e. structured `PyTree`
    annotations are nested.
    """
    if getattr(_treepath_storage, "value", None) is not None:
        raise AnnotationError(
            "Cannot typecheck annotations of the form "
            "`PyTree[PyTree[Shaped[Array, '?foo'], 'T'], 'S']` as it is ambiguous "
            "which PyTree the `?` annotation refers to."
        )
    # The indexed form appears in error messages, so it is human-readable.
    _treepath_storage.value = (
        f"~~delete~~({structure}) "
        if index is None
        else f"(Leaf {index} in structure {structure}) "
    )
157 |
158 |
def get_treepath_memo() -> str:
    """Return the current thread's tree-path marker.

    **Raises:**

    `AnnotationError` if no marker is set.
    """
    marker = getattr(_treepath_storage, "value", None)
    if marker is None:
        raise AnnotationError(
            "Cannot use `?` annotations, e.g. `Shaped[Array, '?foo']`, except "
            "when contained with structured `PyTree` annotations, e.g. "
            "`PyTree[Shaped[Array, '?foo'], 'T']`."
        )
    return marker
167 |
168 |
# Per-thread holder for a boolean flag (attribute `value`).
_treeflatten_storage = threading.local()


def clear_treeflatten_memo() -> None:
    """Set the current thread's treeflatten flag to `False`."""
    _treeflatten_storage.value = False
174 |
175 |
def set_treeflatten_memo():
    """Set the current thread's treeflatten flag to `True`."""
    _treeflatten_storage.value = True
178 |
179 |
def get_treeflatten_memo():
    """Return the current thread's treeflatten flag, or `False` if this thread
    has never set it.
    """
    return getattr(_treeflatten_storage, "value", False)
185 |
--------------------------------------------------------------------------------
/jaxtyping/_typeguard/LICENSE:
--------------------------------------------------------------------------------
1 | This is the MIT license: http://www.opensource.org/licenses/mit-license.php
2 |
3 | Copyright (c) Alex Grönholm
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this
6 | software and associated documentation files (the "Software"), to deal in the Software
7 | without restriction, including without limitation the rights to use, copy, modify, merge,
8 | publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
9 | to whom the Software is furnished to do so, subject to the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in all copies or
12 | substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
15 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
16 | PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
17 | FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | DEALINGS IN THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/jaxtyping/_typeguard/README.md:
--------------------------------------------------------------------------------
1 | # Vendored copy of typeguard v2.13.3
2 |
3 | We include a vendored copy of typeguard v2.13.3. The reason we need a runtime typechecker is to be able to define `isinstance(..., PyTree[Foo])`.
4 |
5 | Of the available options:
6 |
7 | - `beartype` does not support `O(n)` checking.
8 | - `typeguard` v4 is notorious for having some bugs (it seems to re-parse the AST, and then fails because we have strings in our annotations).
9 | - `typeguard` v2 is what we use here... but we vendor it instead of depending on it, because people may still wish to use typeguard v4 in their own environments. (Notably, a number of other packages depend on typeguard v4, and it would be inconvenient to be incompatible at the package level when the combinations that don't mix at runtime might never actually be used.)
10 |
11 | This is vendored under the terms of the MIT license, which is also reproduced here.
12 |
--------------------------------------------------------------------------------
/jaxtyping/py.typed:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | theme:
2 | name: material
3 | features:
4 | - navigation.sections # Sections are included in the navigation on the left.
5 | - toc.integrate # Table of contents is integrated on the left; does not appear separately on the right.
6 | - header.autohide # header disappears as you scroll
7 | palette:
8 | # Light mode / dark mode
9 | # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as
10 | # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle.
11 | - scheme: default
12 | primary: white
13 | accent: amber
14 | toggle:
15 | icon: material/weather-night
16 | name: Switch to dark mode
17 | - scheme: slate
18 | primary: black
19 | accent: amber
20 | toggle:
21 | icon: material/weather-sunny
22 | name: Switch to light mode
23 | icon:
24 | repo: fontawesome/brands/github # GitHub logo in top right
25 | logo: "material/check-network-outline" # jaxtyping logo in top left
26 | favicon: "_static/favicon.png"
27 | custom_dir: "docs/_overrides" # Overriding part of the HTML
28 |
29 | # These additions are my own custom ones, having overridden a partial.
30 | twitter_bluesky_name: "@PatrickKidger"
31 | twitter_url: "https://twitter.com/PatrickKidger"
32 | bluesky_url: "https://PatrickKidger.bsky.social"
33 |
34 | site_name: jaxtyping
35 | site_description: The documentation for the jaxtyping software library.
36 | site_author: Patrick Kidger
37 | site_url: https://docs.kidger.site/jaxtyping
38 |
39 | repo_url: https://github.com/patrick-kidger/jaxtyping
40 | repo_name: patrick-kidger/jaxtyping
41 | edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate
42 |
43 | strict: true # Don't allow warnings during the build process
44 |
45 | extra_javascript:
46 | # The below two make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/
47 | - _static/mathjax.js
48 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
49 |
50 | extra_css:
51 | - _static/custom_css.css
52 |
53 | markdown_extensions:
54 | - pymdownx.arithmatex: # Render LaTeX via MathJax
55 | generic: true
56 | - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme.
57 | - pymdownx.details # Allowing hidden expandable regions denoted by ???
58 | - pymdownx.snippets: # Include one Markdown file into another
59 | base_path: docs
60 | - admonition
61 | - toc:
62 | permalink: "¤" # Adds a clickable permalink to each section heading
63 | toc_depth: 4
64 |
65 | plugins:
66 | - search:
67 | separator: '[\s\-,:!=\[\]()"/]+|(?!\b)(?=[A-Z][a-z])|\.(?!\d)|&[lg]t;'
68 | - autorefs # Cross-links to headings
69 | - include_exclude_files:
70 | include:
71 | - ".htaccess"
72 | exclude:
73 | - "_overrides"
74 | - ipynb
75 | - hippogriffe:
76 | # Our docs are generated pretty much independently of the runtime code, these links aren't useful.
77 | show_source_links: none
78 | extra_public_objects:
79 | - numpy.dtype
80 | - mkdocstrings:
81 | handlers:
82 | python:
83 | options:
84 | force_inspection: true
85 | heading_level: 4
86 | inherited_members: true
87 | members_order: source
88 | show_bases: false
89 | show_if_no_docstring: true
90 | show_overloads: false
91 | show_root_heading: true
92 | show_signature_annotations: true
93 | show_source: false
94 | show_symbol_type_heading: true
95 | show_symbol_type_toc: true
96 |
97 | nav:
98 | - 'index.md'
99 | - API:
100 | - 'api/array.md'
101 | - 'api/pytree.md'
102 | - 'api/runtime-type-checking.md'
103 | - 'api/advanced-features.md'
104 | - 'faq.md'
105 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "jaxtyping"
3 | version = "0.3.2"
4 | description = "Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays."
5 | readme = "README.md"
6 | requires-python =">=3.10"
7 | license = {file = "LICENSE"}
8 | authors = [
9 | {name = "Patrick Kidger", email = "contact@kidger.site"},
10 | ]
11 | keywords = ["jax", "neural-networks", "deep-learning", "equinox", "typing"]
12 | classifiers = [
13 | "Development Status :: 3 - Alpha",
14 | "Intended Audience :: Developers",
15 | "Intended Audience :: Financial and Insurance Industry",
16 | "Intended Audience :: Information Technology",
17 | "Intended Audience :: Science/Research",
18 | "License :: OSI Approved :: MIT License",
19 | "Natural Language :: English",
20 | "Programming Language :: Python :: 3",
21 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
22 | "Topic :: Scientific/Engineering :: Information Analysis",
23 | "Topic :: Scientific/Engineering :: Mathematics",
24 | ]
25 | urls = {repository = "https://github.com/patrick-kidger/jaxtyping" }
26 | dependencies = ["wadler_lindig>=0.1.3"]
27 | entry-points = {pytest11 = {jaxtyping = "jaxtyping._pytest_plugin"}}
28 |
29 | [project.optional-dependencies]
30 | docs = [
31 | "hippogriffe==0.2.0",
32 | "mkdocs==1.6.1",
33 | "mkdocs-include-exclude-files==0.1.0",
34 | "mkdocs-ipynb==0.1.0",
35 | "mkdocs-material==9.6.7",
36 | "mkdocstrings[python]==0.28.3",
37 | "pymdown-extensions==10.14.3",
38 | ]
39 |
40 | [build-system]
41 | requires = ["hatchling"]
42 | build-backend = "hatchling.build"
43 |
44 | [tool.hatch.build]
45 | include = ["jaxtyping/*"]
46 |
47 | [tool.ruff.lint]
48 | select = ["E", "F", "I001"]
49 | ignore = ["E721", "E731", "F722"]
50 |
51 | [tool.ruff.lint.per-file-ignores]
52 | "jaxtyping/_typeguard/__init__.py" = ["E", "F", "I001"]
53 |
54 | [tool.ruff.lint.isort]
55 | combine-as-imports = true
56 | lines-after-imports = 2
57 | extra-standard-library = ["typing_extensions"]
58 | order-by-type = false
59 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrick-kidger/jaxtyping/fe61644be8590cf0a89bacc278283e1e4b9ea3e4/test/__init__.py
--------------------------------------------------------------------------------
/test/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | import random
21 |
22 | import jax.random as jr
23 | import pytest
24 | import typeguard
25 |
26 |
27 | try:
28 | import beartype
29 | except ImportError:
30 |
31 | def skip(*args, **kwargs):
32 | pytest.skip("Beartype not installed")
33 |
34 | typecheck_params = [typeguard.typechecked, skip]
35 | else:
36 | typecheck_params = [typeguard.typechecked, beartype.beartype]
37 |
38 |
@pytest.fixture(params=typecheck_params)
def typecheck(request):
    """Parametrised fixture: returns each entry of `typecheck_params` in turn."""
    typechecker = request.param
    return typechecker
42 |
43 |
@pytest.fixture(params=(False, True))
def jaxtyp(request):
    """Fixture returning a factory that maps a typechecker to a decorator.

    The fixture is parametrised over both ways of applying `jaxtyping.jaxtyped`:
    the new-style `jaxtyped(typechecker=...)` call, and the old-style double
    decoration `jaxtyped(typechecker(fn))`, which is expected to emit a warning
    matching "As of jaxtyping version 0.2.24".
    """
    import jaxtyping

    def new_style(typechecker):
        # Equivalent to
        #   @jaxtyping.jaxtyped(typechecker=typechecker)
        #   def f(...)
        return jaxtyping.jaxtyped(typechecker=typechecker)

    def old_style(typechecker):
        # Equivalent to
        #   @jaxtyping.jaxtyped
        #   @typechecker
        #   def f(...)
        def decorator(fn):
            with pytest.warns(match="As of jaxtyping version 0.2.24"):
                return jaxtyping.jaxtyped(typechecker(fn))

        return decorator

    return new_style if request.param else old_style
66 |
67 |
@pytest.fixture()
def getkey():
    """Fixture returning a zero-argument callable that builds a new `jr.PRNGKey`."""

    def _getkey():
        # Not sure what the maximum actually is but this will do
        seed = random.randint(0, 2**31 - 1)
        return jr.PRNGKey(seed)

    return _getkey
75 |
76 |
@pytest.fixture(scope="module")
def beartype_or_skip():
    """Module-scoped fixture yielding the result of `pytest.importorskip("beartype")`."""
    beartype_module = pytest.importorskip("beartype")
    yield beartype_module
80 |
81 |
@pytest.fixture(scope="module")
def typeguard_or_skip():
    """Module-scoped fixture yielding the result of `pytest.importorskip("typeguard")`."""
    typeguard_module = pytest.importorskip("typeguard")
    yield typeguard_module
85 |
--------------------------------------------------------------------------------
/test/helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | import contextlib
21 | import gc
22 | from collections.abc import Callable
23 | from typing import Any
24 |
25 | import equinox as eqx
26 | import typeguard
27 |
28 |
# Exception types collected for a rejected argument (`ParamError`) and for a
# rejected return value (`ReturnError`), across the typecheckers in use.
ParamError = [TypeError]  # old typeguard
ReturnError = [TypeError]  # old typeguard

if hasattr(typeguard, "TypeCheckError"):
    # new typeguard
    ParamError.append(typeguard.TypeCheckError)
    ReturnError.append(typeguard.TypeCheckError)

try:
    import beartype
except ImportError:
    pass
else:
    ParamError.append(beartype.roar.BeartypeCallHintParamViolation)
    ReturnError.append(beartype.roar.BeartypeCallHintReturnViolation)

# Freeze the collected types into tuples.
ParamError = tuple(ParamError)
ReturnError = tuple(ReturnError)
51 |
52 |
@eqx.filter_jit
def make_mlp(key):
    """Build `eqx.nn.MLP(2, 2, 2, 2, key=key)`; wrapped in `eqx.filter_jit`."""
    mlp = eqx.nn.MLP(2, 2, 2, 2, key=key)
    return mlp
56 |
57 |
@contextlib.contextmanager
def assert_no_garbage(
    allowed_garbage_predicate: Callable[[Any], bool] = lambda _: False,
):
    """Context manager asserting that its body leaves nothing for the cyclic GC.

    Automatic collection is switched off and existing garbage is collected before
    the body runs. Afterwards a collection is run with `gc.DEBUG_SAVEALL`, so that
    every object it finds is kept in `gc.garbage`, and the context fails if any of
    those objects is not accepted by `allowed_garbage_predicate`.

    The collector's previous enabled/disabled state and debug flags are restored
    on exit, whether or not the body or the check raised.

    Args:
        allowed_garbage_predicate: called on each collected object; return `True`
            to tolerate it. The default tolerates nothing.

    Raises:
        AssertionError: if the body produced garbage that is not tolerated.
    """
    # Remember the caller's GC configuration so it can be put back afterwards,
    # rather than unconditionally re-enabling the GC and zeroing the debug flags.
    was_enabled = gc.isenabled()
    saved_debug_flags = gc.get_debug()
    try:
        gc.disable()
        gc.collect()
        # It's unclear why, but a second GC is necessary to fully collect
        # existing garbage.
        gc.collect()
        gc.garbage.clear()

        yield

        # Do a GC collection, saving collected objects in gc.garbage.
        gc.set_debug(gc.DEBUG_SAVEALL)
        gc.collect()

        disallowed_garbage = [
            obj for obj in gc.garbage if not allowed_garbage_predicate(obj)
        ]
        assert not disallowed_garbage
    finally:
        # Put the GC back the way we found it.
        gc.set_debug(saved_debug_flags)
        gc.garbage.clear()
        gc.collect()
        if was_enabled:
            gc.enable()
86 |
--------------------------------------------------------------------------------
/test/import_hook_tester.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | import dataclasses
21 | from typing import no_type_check
22 |
23 | import equinox as eqx
24 | import jax.numpy as jnp
25 | import pytest
26 | from helpers import ParamError, ReturnError
27 |
28 | import jaxtyping
29 | from jaxtyping import Float32, Int
30 |
31 |
32 | #
33 | # Test that functions get checked
34 | #
35 |
36 |
37 | def g(x: Float32[jnp.ndarray, " b"]):
38 | pass
39 |
40 |
41 | g(jnp.array([1.0]))
42 | with pytest.raises(ParamError):
43 | g(jnp.array(1))
44 |
45 |
46 | #
47 | # Test that Equinox modules get checked
48 | #
49 |
50 |
51 | # Dataclass `__init__`, no converter
52 | class Mod1(eqx.Module):
53 | foo: int
54 | bar: Float32[jnp.ndarray, " a"]
55 |
56 |
57 | Mod1(1, jnp.array([1.0]))
58 | with pytest.raises(ParamError):
59 | Mod1(1.0, jnp.array([1.0]))
60 | with pytest.raises(ParamError):
61 | Mod1(1, jnp.array(1.0))
62 |
63 |
64 | # Dataclass `__init__`, converter
65 | class Mod2(eqx.Module):
66 | a: jnp.ndarray = eqx.field(converter=jnp.asarray)
67 |
68 |
69 | Mod2(1) # This will fail unless we run typechecking after conversion
70 |
71 |
72 | # This silently passes -- the untyped `lambda x: x` launders the value through.
73 | # No easy way to tackle this. That's okay.
74 |
75 | # class BadMod2(eqx.Module):
76 | # a: jnp.ndarray = eqx.field(converter=lambda x: x)
77 |
78 |
79 | # with pytest.raises(ParamError):
80 | # BadMod2(1)
81 | # with pytest.raises(ParamError):
82 | # BadMod2("asdf")
83 |
84 |
85 | # Custom `__init__`, no converter
86 | class Mod3(eqx.Module):
87 | foo: int
88 | bar: Float32[jnp.ndarray, " a"]
89 |
90 | def __init__(self, foo: str, bar: Float32[jnp.ndarray, " a"]):
91 | self.foo = int(foo)
92 | self.bar = bar
93 |
94 |
95 | Mod3("1", jnp.array([1.0]))
96 | with pytest.raises(ParamError):
97 | Mod3(1, jnp.array([1.0]))
98 | with pytest.raises(ParamError):
99 | Mod3("1", jnp.array(1.0))
100 |
101 |
102 | # Custom `__init__`, converter
103 | class Mod4(eqx.Module):
104 | a: Int[jnp.ndarray, ""] = eqx.field(converter=jnp.asarray)
105 |
106 | def __init__(self, a: str):
107 | self.a = int(a)
108 |
109 |
110 | Mod4("1") # This will fail unless we run typechecking after conversion
111 |
112 |
113 | # Custom `__post_init__`, no converter
114 | class Mod5(eqx.Module):
115 | foo: int
116 | bar: Float32[jnp.ndarray, " a"]
117 |
118 | def __post_init__(self):
119 | pass
120 |
121 |
122 | Mod5(1, jnp.array([1.0]))
123 | with pytest.raises(ParamError):
124 | Mod5(1.0, jnp.array([1.0]))
125 | with pytest.raises(ParamError):
126 | Mod5(1, jnp.array(1.0))
127 |
128 |
129 | # Custom `__post_init__`, converter
130 | class Mod6(eqx.Module):
131 | a: jnp.ndarray = eqx.field(converter=jnp.asarray)
132 |
133 | def __post_init__(self):
134 | pass
135 |
136 |
137 | Mod6(1) # This will fail unless we run typechecking after conversion
138 |
139 |
140 | #
141 | # Test that dataclasses get checked
142 | #
143 |
144 |
145 | @dataclasses.dataclass
146 | class D:
147 | foo: int
148 | bar: Float32[jnp.ndarray, " a"]
149 |
150 |
151 | D(1, jnp.array([1.0]))
152 | with pytest.raises(ParamError):
153 | D(1.0, jnp.array([1.0]))
154 | with pytest.raises(ParamError):
155 | D(1, jnp.array(1.0))
156 |
157 |
158 | #
159 | # Test that methods get checked
160 | #
161 |
162 |
163 | class N(eqx.Module):
164 | a: jnp.ndarray
165 |
166 | def __init__(self, foo: str):
167 | self.a = jnp.array(1)
168 |
169 | def foo(self, x: jnp.ndarray):
170 | pass
171 |
172 | def bar(self) -> jnp.ndarray:
173 | return self.a
174 |
175 |
176 | n = N("hi")
177 | with pytest.raises(ParamError):
178 | N(123)
179 | with pytest.raises(ParamError):
180 | n.foo("not_an_array_either")
181 | bad_n = eqx.tree_at(lambda x: x.a, n, "not_an_array")
182 | with pytest.raises(ReturnError):
183 | bad_n.bar()
184 |
185 |
186 | #
187 | # Test that we don't get called in `super()`.
188 | #
189 |
190 |
191 | called = False
192 |
193 |
194 | class Base(eqx.Module):
195 | x: int
196 |
197 | def __init__(self):
198 | self.x = "not an int"
199 | global called
200 | assert not called
201 | called = True
202 |
203 |
204 | class Derived(Base):
205 | def __init__(self):
206 | assert not called
207 | super().__init__()
208 | assert called
209 | self.x = 2
210 |
211 |
212 | Derived()
213 |
214 |
215 | #
216 | # Test that stringified type annotations work
217 |
218 |
219 | class Foo:
220 | pass
221 |
222 |
223 | class Bar(eqx.Module):
224 | x: type[Foo]
225 | y: "type[Foo]"
226 | # Partially-stringified hints not tested; not supported.
227 |
228 |
229 | Bar(Foo, Foo)
230 |
231 | with pytest.raises(ParamError):
232 | Bar(1, Foo)
233 |
234 |
235 | #
236 | # Test that assert isinstance works (even if no arg/return annotations)
237 |
238 |
239 | def isinstance_test(x):
240 | assert isinstance(x, Float32[jnp.ndarray, " b"])
241 | _ = x
242 |
243 |
244 | isinstance_test(jnp.array([1.0]))
245 | with pytest.raises(AssertionError):
246 | isinstance_test(jnp.array(1))
247 |
248 |
249 | @no_type_check
250 | def f(_: Float32[jnp.ndarray, "foo bar"]):
251 | pass
252 |
253 |
254 | f("not an array")
255 |
256 |
257 | # Record that we've finished our checks successfully
258 |
259 | jaxtyping._test_import_hook_counter += 1
260 |
--------------------------------------------------------------------------------
/test/requirements.txt:
--------------------------------------------------------------------------------
1 | beartype
2 | cloudpickle
3 | equinox
4 | IPython
5 | jax
6 | numpy<2
7 | pytest
8 | pytest-asyncio
9 | tensorflow
10 | typeguard<3
11 | mlx
12 |
--------------------------------------------------------------------------------
/test/test_all_importable.py:
--------------------------------------------------------------------------------
1 | # We have some pretty complicated semantics in `__init__.py`.
2 | # Here we check that we didn't miss one of them on our runtime branch.
3 | def test_all_importable():
4 | # Ordered according to their appearance in the documentation.
5 | from jaxtyping import ( # noqa: I001
6 | Shaped, # noqa: F401
7 | Bool, # noqa: F401
8 | Key, # noqa: F401
9 | Num, # noqa: F401
10 | Inexact, # noqa: F401
11 | Float, # noqa: F401
12 | BFloat16, # noqa: F401
13 | Float16, # noqa: F401
14 | Float32, # noqa: F401
15 | Float64, # noqa: F401
16 | Complex, # noqa: F401
17 | Complex64, # noqa: F401
18 | Complex128, # noqa: F401
19 | Integer, # noqa: F401
20 | UInt, # noqa: F401
21 | UInt2, # noqa: F401
22 | UInt4, # noqa: F401
23 | UInt8, # noqa: F401
24 | UInt16, # noqa: F401
25 | UInt32, # noqa: F401
26 | UInt64, # noqa: F401
27 | Int, # noqa: F401
28 | Int2, # noqa: F401
29 | Int4, # noqa: F401
30 | Int8, # noqa: F401
31 | Int16, # noqa: F401
32 | Int32, # noqa: F401
33 | Int64, # noqa: F401
34 | Real, # noqa: F401
35 | Array, # noqa: F401
36 | ArrayLike, # noqa: F401
37 | Scalar, # noqa: F401
38 | ScalarLike, # noqa: F401
39 | PRNGKeyArray, # noqa: F401
40 | PyTreeDef, # noqa: F401
41 | PyTree, # noqa: F401
42 | jaxtyped, # noqa: F401
43 | install_import_hook, # noqa: F401
44 | AbstractArray, # noqa: F401
45 | AbstractDtype, # noqa: F401
46 | print_bindings, # noqa: F401
47 | get_array_name_format, # noqa: F401
48 | set_array_name_format, # noqa: F401
49 | )
50 |
--------------------------------------------------------------------------------
/test/test_array.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | import dataclasses as dc
21 | import sys
22 | from typing import Any, get_args, get_origin, TypeVar, Union
23 |
24 | import jax.numpy as jnp
25 | import jax.random as jr
26 | import numpy as np
27 | import pytest
28 |
29 |
30 | try:
31 | import torch
32 | except ImportError:
33 | torch = None
34 |
35 | from jaxtyping import (
36 | AbstractArray,
37 | AbstractDtype,
38 | AnnotationError,
39 | Array,
40 | ArrayLike,
41 | Bool,
42 | Float,
43 | Float32,
44 | jaxtyped,
45 | Key,
46 | PRNGKeyArray,
47 | Scalar,
48 | Shaped,
49 | )
50 |
51 | from .helpers import ParamError, ReturnError
52 |
53 |
def test_basic(jaxtyp, typecheck):
    """A 0-d array is accepted by `Shaped[Array, "..."]`."""

    @jaxtyp(typecheck)
    def g(x: Shaped[Array, "..."]):
        pass

    scalar = jnp.array(1.0)
    g(scalar)
60 |
61 |
def test_dtypes():
    """Every imported `AbstractDtype` subclass has a `__name__` equal to its name."""
    from jaxtyping import (  # noqa: F401
        Array,
        BFloat16,
        Bool,
        Complex,
        Complex64,
        Complex128,
        Float,
        Float8e4m3b11fnuz,
        Float8e4m3fn,
        Float8e4m3fnuz,
        Float8e5m2,
        Float8e5m2fnuz,
        Float16,
        Float32,
        Float64,
        Inexact,
        Int,
        Int2,
        Int4,
        Int8,
        Int16,
        Int32,
        Int64,
        Num,
        Shaped,
        UInt,
        UInt2,
        UInt4,
        UInt8,
        UInt16,
        UInt32,
        UInt64,
    )

    # Snapshot the names bound by the import above before adding more locals.
    imported = dict(locals())
    for name, obj in imported.items():
        if not issubclass(obj, AbstractDtype):
            continue
        assert name == obj.__name__
101 |
102 |
def test_numpy_struct_dtype():
    """A struct-dtype annotation matches only arrays with the same field layout."""
    from jaxtyping import make_numpy_struct_dtype

    matching = np.dtype([("first", np.uint8), ("second", bool)])
    renamed_field = np.dtype([("third", np.uint8), ("second", bool)])
    reordered_fields = np.dtype([("second", bool), ("first", np.uint8)])

    arr = np.array([0, False], dtype=matching)

    Dtype1 = make_numpy_struct_dtype(matching, "Dtype1")
    assert isinstance(arr, Dtype1[np.ndarray, "_"])

    Dtype2 = make_numpy_struct_dtype(renamed_field, "Dtype2")
    assert not isinstance(arr, Dtype2[np.ndarray, "_"])

    Dtype3 = make_numpy_struct_dtype(reordered_fields, "Dtype3")
    assert not isinstance(arr, Dtype3[np.ndarray, "_"])
119 |
120 |
def test_return(jaxtyp, typecheck, getkey):
    """Return annotations are checked against axis names bound by the argument."""

    @jaxtyp(typecheck)
    def g(x: Float[Array, "b c"]) -> Float[Array, "c b"]:
        return jnp.transpose(x)

    @jaxtyp(typecheck)
    def h(x: Float[Array, "b c"]) -> Float[Array, "b c"]:
        return jnp.transpose(x)

    # The transpose of a (3, 4) input matches "c b" ...
    g(jr.normal(getkey(), (3, 4)))

    # ... but not "b c".
    with pytest.raises(ReturnError):
        h(jr.normal(getkey(), (3, 4)))
134 |
135 |
def test_two_args(jaxtyp, typecheck, getkey):
    """An axis name shared by two arguments must have the same size in both."""

    @jaxtyp(typecheck)
    def g(x: Shaped[Array, "b c"], y: Shaped[Array, "c d"]):
        return x @ y

    @jaxtyp(typecheck)
    def h(x: Shaped[Array, "b c"], y: Shaped[Array, "c d"]) -> Shaped[Array, "b d"]:
        return x @ y

    # The same expectations hold with and without a return annotation.
    for fn in (g, h):
        fn(jr.normal(getkey(), (3, 4)), jr.normal(getkey(), (4, 5)))
        with pytest.raises(ParamError):
            fn(jr.normal(getkey(), (3, 4)), jr.normal(getkey(), (5, 4)))
152 |
153 |
def test_any_dtype(jaxtyp, typecheck, getkey):
    """`Shaped` accepts 2-d arrays of many dtypes but still checks the rank."""

    @jaxtyp(typecheck)
    def g(x: Shaped[Array, "a b"]) -> Shaped[Array, "a b"]:
        return x

    g(jr.normal(getkey(), (3, 4)))
    g(jnp.array([[True, False]]))

    integer_dtypes = (
        jnp.int2,
        jnp.int4,
        jnp.int8,
        jnp.uint2,
        jnp.uint4,
        jnp.uint16,
    )
    for dtype in integer_dtypes:
        g(jnp.array([[1, 2], [3, 4]], dtype=dtype))

    for dtype in (jnp.complex64, jnp.bfloat16):
        g(jr.normal(getkey(), (3, 4), dtype=dtype))

    # A 1-d array does not match the two-axis annotation.
    with pytest.raises(ParamError):
        g(jr.normal(getkey(), (1,)))
172 |
173 |
def test_nested_jaxtyped(jaxtyp, typecheck, getkey):
    """A jaxtyped callee binds its axis names independently of its jaxtyped caller."""

    @jaxtyp(typecheck)
    def g(x: Float32[Array, "b c"], transpose: bool) -> Float32[Array, "c b"]:
        return h(x, transpose)

    @jaxtyp(typecheck)
    def h(x: Float32[Array, "c b"], transpose: bool) -> Float32[Array, "b c"]:
        return jnp.transpose(x) if transpose else x

    passing_cases = (
        ((2, 3), True),
        ((3, 3), True),
        ((3, 3), False),
    )
    for shape, transpose in passing_cases:
        g(jr.normal(getkey(), shape), transpose)

    # Without the transpose, a non-square input violates a return annotation.
    with pytest.raises(ReturnError):
        g(jr.normal(getkey(), (2, 3)), False)
191 |
192 |
def test_nested_nojaxtyped(jaxtyp, typecheck, getkey):
    """A callee decorated with only `typecheck` is rejected for a (2, 3) input.

    `g` annotates the input as "b c" while `h`, which is not jaxtyped, annotates
    the same array as "c b"; the call is expected to raise `ParamError`.
    """

    @jaxtyp(typecheck)
    def g(x: Float32[Array, "b c"]):
        return h(x)

    @typecheck
    def h(x: Float32[Array, "c b"]):
        return x

    non_square = jr.normal(getkey(), (2, 3))
    with pytest.raises(ParamError):
        g(non_square)
204 |
205 |
def test_isinstance(jaxtyp, typecheck, getkey):
    """`isinstance` checks inside a jaxtyped function use that function's bindings."""

    @jaxtyp(typecheck)
    def g(x: Float32[Array, "b c"]) -> Float32[Array, " z"]:
        transposed = jnp.transpose(x)
        assert isinstance(transposed, Float32[Array, "c b"])
        # z left unbound as b!=c (unless x symmetric, which it isn't)
        assert not isinstance(transposed, Float32[Array, "b z"])
        result = jr.normal(getkey(), (500,))
        assert isinstance(result, Float32[Array, "z"])  # z now bound
        return result

    g(jr.normal(getkey(), (2, 3)))
219 |
220 |
def test_fixed(jaxtyp, typecheck, getkey):
    """Integer axes in an annotation must equal the corresponding array sizes."""

    @jaxtyp(typecheck)
    def g(
        x: Float32[Array, "4 5 foo"], y: Float32[Array, " foo"]
    ) -> Float32[Array, "4 5"]:
        return x @ y

    matching = jr.normal(getkey(), (4, 5, 2))
    vector = jr.normal(getkey(), (2,))
    result = g(matching, vector)
    assert result.shape == (4, 5)

    # A leading axis of size 3 does not match the literal 4.
    wrong_leading_axis = jr.normal(getkey(), (3, 5, 2))
    with pytest.raises(ParamError):
        g(wrong_leading_axis, vector)
235 |
236 |
def test_anonymous(jaxtyp, typecheck, getkey):
    """`_` axes are not tied together: sizes 4 and 5 are accepted in one call."""

    @jaxtyp(typecheck)
    def g(x: Float32[Array, "foo _"], y: Float32[Array, " _"]):
        pass

    matrix = jr.normal(getkey(), (3, 4))
    vector = jr.normal(getkey(), (5,))
    g(matrix, vector)
245 |
246 |
def test_named_variadic(jaxtyp, typecheck, getkey):
    """A named variadic axis (`*batch`) must match the same axes in every argument."""

    @jaxtyp(typecheck)
    def g(
        x: Float32[Array, "*batch foo"],
        y: Float32[Array, " *batch"],
        z: Float32[Array, " foo"],
    ):
        pass

    foo_only = jr.normal(getkey(), (5,))

    # `*batch` bound to no axes at all.
    x_unbatched = jr.normal(getkey(), (5,))
    y_unbatched = jr.normal(getkey(), ())
    g(x_unbatched, y_unbatched, foo_only)

    # `*batch` bound to a single axis of size 3.
    x_batched = jr.normal(getkey(), (3, 5))
    y_batched = jr.normal(getkey(), (3,))
    g(x_batched, y_batched, foo_only)

    # Mixing the two bindings is rejected.
    for x, y in ((x_unbatched, y_batched), (x_batched, y_unbatched)):
        with pytest.raises(ParamError):
            g(x, y, foo_only)

    @jaxtyp(typecheck)
    def h(x: Float32[Array, " foo *batch"], y: Float32[Array, " foo *batch bar"]):
        pass

    four = jr.normal(getkey(), (4,))
    four_by_three = jr.normal(getkey(), (4, 3))
    three_by_four = jr.normal(getkey(), (3, 4))
    h(four, four_by_three)
    for x, y in ((four, three_by_four), (four_by_three, three_by_four)):
        with pytest.raises(ParamError):
            h(x, y)
283 |
284 |
def test_anonymous_variadic(jaxtyp, typecheck, getkey):
    """`...` absorbs any number of leading axes; the trailing `foo` is still checked."""

    @jaxtyp(typecheck)
    def g(x: Float32[Array, "... foo"], y: Float32[Array, " foo"]):
        pass

    leading_shapes = ((5,), (3, 5), (3, 4, 5))
    xs = [jr.normal(getkey(), shape) for shape in leading_shapes]
    same_foo = jr.normal(getkey(), (5,))
    other_foo = jr.normal(getkey(), (1,))

    for x in xs:
        g(x, same_foo)
    for x in xs:
        with pytest.raises(ParamError):
            g(x, other_foo)
304 |
305 |
def test_broadcast_fixed(jaxtyp, typecheck, getkey):
    """A `#4` axis accepts sizes 4 and 1, and rejects size 3."""

    @jaxtyp(typecheck)
    def g(x: Float32[Array, "#4"]):
        pass

    for accepted_size in (4, 1):
        g(jr.normal(getkey(), (accepted_size,)))

    with pytest.raises(ParamError):
        g(jr.normal(getkey(), (3,)))
316 |
317 |
def test_broadcast_named(jaxtyp, typecheck, getkey):
    """Two `#foo` axes are compatible when their sizes are equal or either is 1."""

    @jaxtyp(typecheck)
    def g(x: Float32[Array, " #foo"], y: Float32[Array, " #foo"]):
        pass

    three = jr.normal(getkey(), (3,))
    four = jr.normal(getkey(), (4,))
    one = jr.normal(getkey(), (1,))

    accepted = (
        (three, three),
        (four, four),
        (one, one),
        (three, one),
        (four, one),
        (one, three),
        (one, four),
    )
    for x, y in accepted:
        g(x, y)

    for x, y in ((three, four), (four, three)):
        with pytest.raises(ParamError):
            g(x, y)
339 |
340 |
341 | def test_broadcast_variadic_named(jaxtyp, typecheck, getkey):
342 | @jaxtyp(typecheck)
343 | def g(x: Float32[Array, " *#foo"], y: Float32[Array, " *#foo"]):
344 | pass
345 |
346 | a = jr.normal(getkey(), (3,))
347 | b = jr.normal(getkey(), (4,))
348 | c = jr.normal(getkey(), (4, 4))
349 | d = jr.normal(getkey(), (5, 6))
350 |
351 | j = jr.normal(getkey(), (1,))
352 | k = jr.normal(getkey(), (1, 4))
353 | l = jr.normal(getkey(), (5, 1)) # noqa: E741
354 | m = jr.normal(getkey(), (1, 1))
355 | n = jr.normal(getkey(), (2, 1))
356 | o = jr.normal(getkey(), (1, 6))
357 |
358 | g(a, a)
359 | g(b, b)
360 | g(c, c)
361 | g(d, d)
362 | g(b, c)
363 | with pytest.raises(ParamError):
364 | g(a, b)
365 | with pytest.raises(ParamError):
366 | g(a, c)
367 | with pytest.raises(ParamError):
368 | g(a, b)
369 | with pytest.raises(ParamError):
370 | g(d, b)
371 |
372 | g(a, j)
373 | g(b, j)
374 | g(c, j)
375 | g(d, j)
376 | g(b, k)
377 | g(c, k)
378 | with pytest.raises(ParamError):
379 | g(d, k)
380 | with pytest.raises(ParamError):
381 | g(c, l)
382 | g(d, l)
383 | g(a, m)
384 | g(c, m)
385 | g(d, m)
386 | g(a, n)
387 | g(b, n)
388 | with pytest.raises(ParamError):
389 | g(c, n)
390 | with pytest.raises(ParamError):
391 | g(d, n)
392 | g(o, d)
393 | with pytest.raises(ParamError):
394 | g(o, c)
395 | with pytest.raises(ParamError):
396 | g(o, a)
397 |
398 |
399 | def test_variadic_mixed_broadcast(jaxtyp, typecheck, getkey):
400 | @jaxtyp(typecheck)
401 | def f(x: Float[Array, " *foo"], y: Float[Array, " #*foo"]):
402 | pass
403 |
404 | a = jr.normal(getkey(), (3, 4))
405 | b = jr.normal(getkey(), (5,))
406 | with pytest.raises(ParamError):
407 | f(a, b)
408 |
409 | c = jr.normal(getkey(), (7, 3, 2))
410 | d = jr.normal(getkey(), (1, 2))
411 | f(c, d)
412 |
413 |
414 | def test_variadic_mixed_broadcast2(jaxtyp, typecheck, getkey):
415 | @jaxtyp(typecheck)
416 | def f(x: Float[Array, " *#foo"], y: Float[Array, " *foo"]):
417 | pass
418 |
419 | a = jr.normal(getkey(), (3, 4))
420 | b = jr.normal(getkey(), (5,))
421 | with pytest.raises(ParamError):
422 | f(a, b)
423 |
424 | c = jr.normal(getkey(), (1, 2))
425 | d = jr.normal(getkey(), (7, 3, 2))
426 | f(c, d)
427 |
428 |
429 | def test_variadic_mixed_broadcast3(jaxtyp, typecheck, getkey):
430 | @jaxtyp(typecheck)
431 | def f(
432 | x: Float[Array, "*B L D"],
433 | *,
434 | y: Float[Array, "*#B J d"],
435 | z: Bool[Array, "*B L J"],
436 | ) -> Float[Array, "*B L D"]:
437 | return x
438 |
439 | x = jr.normal(getkey(), (2, 7, 3, 2, 2))
440 | y = jr.bernoulli(getkey(), shape=(2, 7, 3, 2, 2))
441 | z = jr.normal(getkey(), (2, 7, 1, 2, 2))
442 | f(x, y=z, z=y)
443 |
444 |
445 | def test_no_commas():
446 | with pytest.raises(ValueError):
447 | Float32[Array, "foo, bar"]
448 |
449 |
450 | def test_symbolic(jaxtyp, typecheck, getkey):
451 | @jaxtyp(typecheck)
452 | def make_slice(x: Float32[Array, " dim"]) -> Float32[Array, " dim-1"]:
453 | return x[1:]
454 |
455 | @jaxtyp(typecheck)
456 | def cat(x: Float32[Array, " dim"]) -> Float32[Array, " 2*dim"]:
457 | return jnp.concatenate([x, x])
458 |
459 | @jaxtyp(typecheck)
460 | def bad_make_slice(x: Float32[Array, " dim"]) -> Float32[Array, " dim-1"]:
461 | return x
462 |
463 | @jaxtyp(typecheck)
464 | def bad_cat(x: Float32[Array, " dim"]) -> Float32[Array, " 2*dim"]:
465 | return jnp.concatenate([x, x, x])
466 |
467 | x = jr.normal(getkey(), (5,))
468 | assert make_slice(x).shape == (4,)
469 | assert cat(x).shape == (10,)
470 |
471 | y = jr.normal(getkey(), (3, 4))
472 | with pytest.raises(ParamError):
473 | make_slice(y)
474 | with pytest.raises(ParamError):
475 | cat(y)
476 |
477 | with pytest.raises(ReturnError):
478 | bad_make_slice(x)
479 | with pytest.raises(ReturnError):
480 | bad_cat(x)
481 |
482 |
483 | def test_incomplete_symbolic(jaxtyp, typecheck, getkey):
484 | @jaxtyp(typecheck)
485 | def foo(x: Float32[Array, " 2*dim"]):
486 | pass
487 |
488 | x = jr.normal(getkey(), (4,))
489 | with pytest.raises(AnnotationError):
490 | foo(x)
491 |
492 |
493 | def test_deferred_symbolic_good(jaxtyp, typecheck):  # "{...}" sizes read from call arguments
494 |     @jaxtyp(typecheck)
495 |     def foo(dim: int, fill: Float[Array, ""]) -> Float[Array, " {dim}"]:
496 |         return jnp.full((dim,), fill)
497 |
498 |     class A:
499 |         size = 5
500 |
501 |         @jaxtyp(typecheck)
502 |         def bar(self, fill: Float[Array, ""]) -> Float[Array, " {self.size}"]:
503 |             return jnp.full((self.size,), fill)
504 |
505 |     foo(3, jnp.array(0.0))  # returns shape (3,), matching "{dim}"
506 |     A().bar(jnp.array(0.0))  # returns shape (5,), matching "{self.size}"
507 |
508 |
509 | def test_deferred_symbolic_bad(jaxtyp, typecheck):  # returned length is one more than declared
510 |     @jaxtyp(typecheck)
511 |     def foo(dim: int, fill: Float[Array, ""]) -> Float[Array, " {dim-1}"]:
512 |         return jnp.full((dim,), fill)
513 |
514 |     class A:
515 |         size = 5
516 |
517 |         @jaxtyp(typecheck)
518 |         def bar(self, fill: Float[Array, ""]) -> Float[Array, " {self.size}-1"]:
519 |             return jnp.full((self.size,), fill)
520 |
521 |     with pytest.raises(ReturnError):  # returns (3,) where "{dim-1}" is declared
522 |         foo(3, jnp.array(0.0))
523 |
524 |     with pytest.raises(ReturnError):  # returns (5,) where "{self.size}-1" is declared
525 |         A().bar(jnp.array(0.0))
526 |
527 |
528 | def test_deferred_symbolic_dataclass(typecheck):  # "{value}" refers to another dataclass field
529 |     @jaxtyped(typechecker=typecheck)
530 |     @dc.dataclass
531 |     class A:
532 |         value: int
533 |         array: Float[Array, " {value}"]
534 |
535 |     A(3, jnp.zeros(3))  # array length equals value
536 |
537 |     with pytest.raises(ParamError):  # array length 4 does not equal value 3
538 |         A(3, jnp.zeros(4))
539 |
540 |
541 | def _to_set(x) -> set[tuple]:
542 | return {
543 | (xi.index_variadic, xi.dims, xi.array_type, xi.dtypes, xi.dim_str)
544 | if issubclass(xi, AbstractArray)
545 | else xi
546 | for xi in x
547 | }
548 |
549 |
550 | def test_arraylike(typecheck, getkey):
551 | floatlike1 = Float32[ArrayLike, ""]
552 | floatlike2 = Float[ArrayLike, ""]
553 | floatlike3 = Float32[ArrayLike, "4"]
554 |
555 | assert get_origin(floatlike1) is Union
556 | assert get_origin(floatlike2) is Union
557 | assert get_origin(floatlike3) is Union
558 | assert _to_set(get_args(floatlike1)) == _to_set(
559 | [
560 | Float32[Array, ""],
561 | Float32[np.ndarray, ""],
562 | Float32[np.number, ""],
563 | float,
564 | ]
565 | )
566 | assert _to_set(get_args(floatlike2)) == _to_set(
567 | [
568 | Float[Array, ""],
569 | Float[np.ndarray, ""],
570 | Float[np.number, ""],
571 | float,
572 | ]
573 | )
574 | assert _to_set(get_args(floatlike3)) == _to_set(
575 | [
576 | Float32[Array, "4"],
577 | Float32[np.ndarray, "4"],
578 | ]
579 | )
580 |
581 | shaped1 = Shaped[ArrayLike, ""]
582 | shaped2 = Shaped[ArrayLike, "4"]
583 | assert get_origin(shaped1) is Union
584 | assert get_origin(shaped2) is Union
585 | assert _to_set(get_args(shaped1)) == _to_set(
586 | [
587 | Shaped[Array, ""],
588 | Shaped[np.ndarray, ""],
589 | Shaped[np.bool_, ""],
590 | Shaped[np.number, ""],
591 | bool,
592 | int,
593 | float,
594 | complex,
595 | ]
596 | )
597 | assert _to_set(get_args(shaped2)) == _to_set(
598 | [
599 | Shaped[Array, "4"],
600 | Shaped[np.ndarray, "4"],
601 | ]
602 | )
603 |
604 |
605 | def test_ignored_names():
606 | x = Float[np.ndarray, "foo=4"]
607 |
608 | assert isinstance(np.zeros(4), x)
609 | assert not isinstance(np.zeros(5), x)
610 | assert not isinstance(np.zeros((4, 5)), x)
611 |
612 | y = Float[np.ndarray, "bar qux foo=bar+qux"]
613 |
614 | assert isinstance(np.zeros((2, 3, 5)), y)
615 | assert not isinstance(np.zeros((2, 3, 6)), y)
616 |
617 | z = Float[np.ndarray, "bar #foo=bar"]
618 |
619 | assert isinstance(np.zeros((3, 3)), z)
620 | assert isinstance(np.zeros((3, 1)), z)
621 | assert not isinstance(np.zeros((3, 4)), z)
622 |
623 | # Weird but legal
624 | w = Float[np.ndarray, "bar foo=#bar"]
625 |
626 | assert isinstance(np.zeros((3, 3)), w)
627 | assert isinstance(np.zeros((3, 1)), w)
628 | assert not isinstance(np.zeros((3, 4)), w)
629 |
630 |
631 | def test_symbolic_functions():
632 | x = Float[np.ndarray, "foo bar min(foo,bar)"]
633 |
634 | assert isinstance(np.zeros((2, 3, 2)), x)
635 | assert isinstance(np.zeros((3, 2, 2)), x)
636 | assert not isinstance(np.zeros((3, 2, 4)), x)
637 |
638 |
639 | @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires Python 3.10")
640 | def test_py310_unions():
641 | x = np.zeros(3)
642 | y = Shaped[Array | np.ndarray, "_"]
643 | assert isinstance(x, get_args(y))
644 |
645 |
646 | def test_key(jaxtyp, typecheck):  # PRNGKeyArray against keys and several non-key values
647 |     @jaxtyp(typecheck)
648 |     def f(x: PRNGKeyArray):
649 |         pass
650 |
651 |     f(jr.key(0))  # accepted
652 |     f(jr.PRNGKey(0))  # accepted as well
653 |
654 |     with pytest.raises(ParamError):  # arbitrary object
655 |         f(object())
656 |     with pytest.raises(ParamError):  # Python int
657 |         f(1)
658 |     with pytest.raises(ParamError):  # integer array
659 |         f(jnp.array(3))
660 |     with pytest.raises(ParamError):  # floating-point array
661 |         f(jnp.array(3.0))
662 |
663 |
664 | def test_key_dtype(jaxtyp, typecheck):  # the Key dtype, on Array and on Scalar
665 |     @jaxtyp(typecheck)
666 |     def f1(x: Key[Array, ""]):
667 |         pass
668 |
669 |     @jaxtyp(typecheck)
670 |     def f2(x: Key[Scalar, ""]):
671 |         pass
672 |
673 |     for f in (f1, f2):
674 |         f(jr.key(0))  # accepted
675 |
676 |         with pytest.raises(ParamError):  # unlike in test_key, jr.PRNGKey(0) is rejected here
677 |             f(jr.PRNGKey(0))
678 |         with pytest.raises(ParamError):
679 |             f(object())
680 |         with pytest.raises(ParamError):
681 |             f(1)
682 |         with pytest.raises(ParamError):
683 |             f(jnp.array(3))
684 |         with pytest.raises(ParamError):
685 |             f(jnp.array(3.0))
686 |
687 |
688 | def test_extension(jaxtyp, typecheck, getkey):
689 | X = Shaped[Array, "a b"]
690 | Y = Shaped[X, "c d"]
691 | Z = Shaped[Array, "c d a b"]
692 | assert str(Z) == str(Y)
693 |
694 | X = Float[Array, "a"]
695 | Y = Float[X, "b"]
696 |
697 | @jaxtyp(typecheck)
698 | def f(a: X, b: Y): ...
699 |
700 | a = jr.normal(getkey(), (3, 4))
701 | b = jr.normal(getkey(), (4,))
702 | c = jr.normal(getkey(), (3,))
703 |
704 | f(b, a)
705 | with pytest.raises(ParamError):
706 | f(c, a)
707 | with pytest.raises(ParamError):
708 | f(a, a)
709 |
710 | @typecheck
711 | def g(a: Shaped[PRNGKeyArray, "2"]): ...
712 |
713 | with pytest.raises(ParamError):
714 | g(jr.PRNGKey(0))
715 | g(jr.split(jr.PRNGKey(0)))
716 | with pytest.raises(ParamError):
717 | g(jr.split(jr.PRNGKey(0), 3))
718 |
719 |
720 | def test_scalar_variadic_dim():
721 | assert Float[float, "..."] is float
722 | assert Float[float, "#*shape"] is float
723 |
724 | # This one is a bit weird -- it should really also assert that shape==(), but we
725 | # don't implement that.
726 | assert Float[float, "*shape"] is float
727 |
728 |
729 | def test_scalar_dtype_mismatch():
730 | with pytest.raises(ValueError):
731 | Float[bool, "..."]
732 |
733 |
734 | def test_custom_array(jaxtyp, typecheck):
735 | class MyArray1:
736 | @property
737 | def dtype(self):
738 | return "foo"
739 |
740 | @property
741 | def shape(self):
742 | return (3,)
743 |
744 | class MyArray2:
745 | @property
746 | def dtype(self):
747 | return "bar"
748 |
749 | @property
750 | def shape(self):
751 | return (3,)
752 |
753 | class MyArray3:
754 | @property
755 | def dtype(self):
756 | return "foo"
757 |
758 | @property
759 | def shape(self):
760 | return (4,)
761 |
762 | class FooDtype(AbstractDtype):
763 | dtypes = ["foo"]
764 |
765 | @jaxtyp(typecheck)
766 | def f(x: FooDtype[MyArray1, "3"]):
767 | pass
768 |
769 | f(MyArray1())
770 | with pytest.raises(ParamError):
771 | f(MyArray2())
772 | with pytest.raises(ParamError):
773 | f(MyArray3())
774 |
775 | @jaxtyp(typecheck)
776 | def g(x: FooDtype[MyArray1, "3"], y: FooDtype[MyArray1, "4"]):
777 | pass
778 |
779 | with pytest.raises(ParamError):
780 | g(MyArray1(), MyArray1())
781 |
782 | @jaxtyp(typecheck)
783 | def h(x: FooDtype[MyArray1, "3"], y: FooDtype[MyArray3, "4"]):
784 | pass
785 |
786 | with pytest.raises(ParamError):
787 | g(MyArray1(), MyArray3())
788 |
789 |
790 | @pytest.mark.parametrize(
791 | "array_type", [Any, TypeVar("T"), TypeVar("T", bound=ArrayLike)]
792 | )
793 | def test_any(array_type, jaxtyp, typecheck):
794 | class DuckArray1:
795 | @property
796 | def shape(self):
797 | return 3, 4
798 |
799 | @property
800 | def dtype(self):
801 | return np.array([], dtype=np.float32).dtype
802 |
803 | class DuckArray2:
804 | @property
805 | def shape(self):
806 | return 3, 4, 5
807 |
808 | @property
809 | def dtype(self):
810 | return np.array([], dtype=np.float32).dtype
811 |
812 | class DuckArray3:
813 | @property
814 | def shape(self):
815 | return 3, 4
816 |
817 | @property
818 | def dtype(self):
819 | return np.array([], dtype=np.int32).dtype
820 |
821 | @jaxtyp(typecheck)
822 | def f(x: Float[array_type, "foo bar"]):
823 | del x
824 |
825 | f(np.arange(12.0).reshape(3, 4))
826 | f(jnp.arange(12.0).reshape(3, 4))
827 | if isinstance(array_type, TypeVar) and array_type.__bound__ is ArrayLike:
828 | with pytest.raises(ParamError):
829 | f(DuckArray1())
830 | else:
831 | f(DuckArray1())
832 |
833 | # Wrong shape
834 | with pytest.raises(ParamError):
835 | f(np.arange(12.0).reshape(3, 2, 2))
836 | with pytest.raises(ParamError):
837 | f(jnp.arange(12.0).reshape(3, 2, 2))
838 | with pytest.raises(ParamError):
839 | f(DuckArray2())
840 |
841 | # Wrong dtype
842 | with pytest.raises(ParamError):
843 | f(np.arange(12).reshape(3, 4))
844 | with pytest.raises(ParamError):
845 | f(jnp.arange(12).reshape(3, 4))
846 | with pytest.raises(ParamError):
847 | f(DuckArray3())
848 |
849 | # Not an array
850 | with pytest.raises(ParamError):
851 | f(1)
852 |
853 |
854 | def test_non_instantiation():  # calling an annotation class raises RuntimeError
855 |     with pytest.raises(RuntimeError, match="cannot be instantiated"):
856 |         Float[Array, ""]()  # pyright: ignore[reportCallIssue]
857 |
--------------------------------------------------------------------------------
/test/test_decorator.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import dataclasses
3 | import sys
4 | from typing import no_type_check
5 |
6 | import equinox as eqx
7 | import jax.numpy as jnp
8 | import jax.random as jr
9 | import pytest
10 |
11 | from jaxtyping import Array, Float, jaxtyped, print_bindings
12 |
13 | from .helpers import assert_no_garbage, ParamError, ReturnError
14 |
15 |
16 | class M(metaclass=abc.ABCMeta):  # methods pairing @jaxtyped with another decorator, both orders
17 |     @jaxtyped(typechecker=None)
18 |     def f(self): ...
19 |
20 |     @jaxtyped(typechecker=None)  # @jaxtyped outermost
21 |     @classmethod
22 |     def g1(cls):
23 |         return 3
24 |
25 |     @classmethod
26 |     @jaxtyped(typechecker=None)  # @jaxtyped innermost
27 |     def g2(cls):
28 |         return 4
29 |
30 |     @jaxtyped(typechecker=None)  # @jaxtyped outermost
31 |     @staticmethod
32 |     def h1():
33 |         return 3
34 |
35 |     @staticmethod
36 |     @jaxtyped(typechecker=None)  # @jaxtyped innermost
37 |     def h2():
38 |         return 4
39 |
40 |     @jaxtyped(typechecker=None)  # @jaxtyped outermost
41 |     @abc.abstractmethod
42 |     def i1(self): ...
43 |
44 |     @abc.abstractmethod
45 |     @jaxtyped(typechecker=None)  # @jaxtyped innermost
46 |     def i2(self): ...
47 |
48 |
49 | class N:  # properties pairing @jaxtyped with @property, in both orders
50 |     @jaxtyped(typechecker=None)  # @jaxtyped outermost
51 |     @property
52 |     def j1(self):
53 |         return 3
54 |
55 |     @property
56 |     @jaxtyped(typechecker=None)  # @jaxtyped innermost
57 |     def j2(self):
58 |         return 4
59 |
60 |
61 | def test_identity():  # two lookups of the decorated method on the class give the same object
62 |     assert M.f is M.f
63 |
64 |
65 | def test_classmethod():
66 | assert M.g1() == 3
67 | assert M.g2() == 4
68 |
69 |
70 | def test_staticmethod():
71 | assert M.h1() == 3
72 | assert M.h2() == 4
73 |
74 |
75 | # Check that the @jaxtyped decorator doesn't blat the __isabstractmethod__ of
76 | # @abstractmethod
77 | def test_abstractmethod():
78 | assert M.i1.__isabstractmethod__
79 | assert M.i2.__isabstractmethod__
80 |
81 |
82 | def test_property():
83 | assert N().j1 == 3
84 | assert N().j2 == 4
85 |
86 |
87 | def test_context(getkey):
88 | a = jr.normal(getkey(), (3, 4))
89 | b = jr.normal(getkey(), (5,))
90 | with jaxtyped("context"):
91 | assert isinstance(a, Float[Array, "foo bar"])
92 | assert not isinstance(b, Float[Array, "foo"])
93 | assert isinstance(a, Float[Array, "foo bar"])
94 | assert isinstance(b, Float[Array, "foo"])
95 |
96 |
97 | def test_varargs(jaxtyp, typecheck):  # a decorated function taking *args can be called
98 |     @jaxtyp(typecheck)
99 |     def f(*args) -> None:
100 |         pass
101 |
102 |     f(1, 2)
103 |
104 |
105 | def test_varkwargs(jaxtyp, typecheck):  # a decorated function taking **kwargs can be called
106 |     @jaxtyp(typecheck)
107 |     def f(**kwargs) -> None:
108 |         pass
109 |
110 |     f(a=1, b=2)
111 |
112 |
113 | def test_defaults(jaxtyp, typecheck):  # an unannotated defaulted parameter may be omitted
114 |     @jaxtyp(typecheck)
115 |     def f(x: int, y=1) -> None:
116 |         pass
117 |
118 |     f(1)
119 |
120 |
121 | def test_default_bindings(getkey, jaxtyp, typecheck):
122 | @jaxtyp(typecheck)
123 | def f(x: int, y: int = 1) -> Float[Array, "x {y}"]:
124 | return jr.normal(getkey(), (x, y))
125 |
126 | f(1)
127 | f(1, 1)
128 | f(1, 0)
129 | f(1, 5)
130 |
131 |
132 | class _GlobalFoo:  # module-level class named by the stringified annotations below
133 |     pass
134 |
135 |
136 | def test_global_stringified_annotation(jaxtyp, typecheck):  # string annotations of a global class
137 |     @jaxtyp(typecheck)
138 |     def f(x: "_GlobalFoo") -> "_GlobalFoo":
139 |         return x
140 |
141 |     f(_GlobalFoo())
142 |
143 |     @jaxtyp(typecheck)
144 |     def g(x: int) -> "_GlobalFoo":
145 |         return x
146 |
147 |     @jaxtyp(typecheck)
148 |     def h(x: "_GlobalFoo") -> int:
149 |         return x
150 |
151 |     with pytest.raises(ReturnError):  # g returns an int where "_GlobalFoo" is declared
152 |         g(1)
153 |
154 |     with pytest.raises(ParamError):  # h receives an int where "_GlobalFoo" is declared
155 |         h(1)
156 |
157 |
158 | # This test does not use `jaxtyp(typecheck)` because typeguard does some evil stack
159 | # frame introspection to try and grab local variables.
160 | def test_local_stringified_annotation(typecheck):  # string annotations naming a local class
161 |     class LocalFoo:
162 |         pass
163 |
164 |     @jaxtyped(typechecker=typecheck)
165 |     def f(x: "LocalFoo") -> "LocalFoo":
166 |         return x
167 |
168 |     f(LocalFoo())
169 |
170 |     with pytest.warns(match="As of jaxtyping version 0.2.24"):  # old double-decorator syntax
171 |
172 |         @jaxtyped
173 |         @typecheck
174 |         def g(x: "LocalFoo") -> "LocalFoo":
175 |             return x
176 |
177 |         g(LocalFoo())
178 |
179 |     # We don't check that errors are raised if it goes wrong, since we can't usually
180 |     # resolve local type annotations at runtime. Best we can hope for is not to raise
181 |     # a spurious error about not being able to find the type.
182 |
183 |
184 | def test_print_bindings(typecheck, capfd):
185 | @jaxtyped(typechecker=typecheck)
186 | def f(x: Float[Array, "foo bar"]):
187 | print_bindings()
188 |
189 | capfd.readouterr()
190 | f(jnp.zeros((3, 4)))
191 | text, _ = capfd.readouterr()
192 | assert text == (
193 | "The current values for each jaxtyping axis annotation are as follows."
194 | "\nfoo=3\nbar=4\n"
195 | )
196 |
197 |
198 | def test_no_type_check(typecheck):  # @no_type_check disables checking in either decorator order
199 |     @jaxtyped(typechecker=typecheck)
200 |     @no_type_check
201 |     def f(x: Float[Array, "foo bar"]):
202 |         pass
203 |
204 |     @no_type_check
205 |     @jaxtyped(typechecker=typecheck)
206 |     def g(x: Float[Array, "foo bar"]):
207 |         pass
208 |
209 |     f("not an array")  # no error expected despite the annotation
210 |     g("not an array")  # no error expected despite the annotation
211 |
212 |
213 | def test_no_garbage(typecheck):  # decorate and instantiate a dataclass under assert_no_garbage
214 |     with assert_no_garbage():
215 |
216 |         @jaxtyped(typechecker=typecheck)
217 |         @dataclasses.dataclass
218 |         class _Obj:
219 |             x: int
220 |
221 |         _Obj(x=5)
222 |
223 |
224 | def test_no_garbage_identity_typecheck():  # same check with a typechecker that returns its input
225 |     with assert_no_garbage():
226 |
227 |         @jaxtyped(typechecker=lambda x: x)
228 |         @dataclasses.dataclass
229 |         class _Obj:
230 |             x: int
231 |
232 |         _Obj(x=5)
233 |
234 |
235 | def test_no_garbage_frame_capture_typecheck():  # typechecker that keeps its caller's f_locals
236 |     with assert_no_garbage():
237 |         # Some typechecker implementations (e.g., typeguard 2.13.3) capture the calling
238 |         # frame's f_locals. This test checks that the calling frames in jaxtyping are
239 |         # sufficiently isolated to avoid introducing reference cycles when a
240 |         # typechecker does this.
241 |         def frame_locals_capture(fn):
242 |             locals = sys._getframe(1).f_locals  # locals of whichever frame called this
243 |
244 |             def wrapper(*args, **kwargs):
245 |                 # Required to ensure wrapper holds a reference to f_locals, which is
246 |                 # the scenario under test.
247 |                 _ = locals
248 |                 return fn(*args, **kwargs)
249 |
250 |             return wrapper
251 |
252 |         @jaxtyped(typechecker=frame_locals_capture)
253 |         @dataclasses.dataclass
254 |         class _Obj:
255 |             x: int
256 |
257 |         _Obj(x=5)
258 |
259 |
260 | def test_equinox_converter(typecheck):  # field declared str, fed through an int -> str converter
261 |     def _typed_str(x: int) -> str:
262 |         return str(x)
263 |
264 |     @jaxtyped(typechecker=typecheck)
265 |     class X(eqx.Module):
266 |         x: str = eqx.field(converter=_typed_str)
267 |
268 |     X(1)  # an int, i.e. the converter's input type, is accepted
269 |     with pytest.raises(ParamError):  # a str, i.e. the field's declared type, is rejected
270 |         X("1")
271 |
272 |
273 | def test_mlx(jaxtyp, typecheck):
274 | import mlx.core as mx
275 | import numpy as np
276 |
277 | @jaxtyp(typecheck)
278 | def hello(x: Float[mx.array, "8 16"]):
279 | pass
280 |
281 | hello(mx.zeros((8, 16), dtype=mx.float32))
282 |
283 | with pytest.raises(ParamError):
284 | hello(mx.zeros((8, 14), dtype=mx.float32))
285 |
286 | with pytest.raises(ParamError):
287 | hello(np.zeros((8, 16), dtype=np.float32))
288 |
289 | with pytest.raises(ParamError):
290 | hello(mx.zeros((8, 16), dtype=mx.int32))
291 |
292 |
293 | # In particular the below scenario occurs during the import hook when we have an
294 | # explicit `__init__`.
295 | def test_no_rewrapping_of_dataclass_init(typecheck):  # __init__ carries one wrapper layer only
296 |     @jaxtyped(typechecker=typecheck)
297 |     @dataclasses.dataclass
298 |     class Foo:
299 |         x: int
300 |
301 |         @jaxtyped(typechecker=typecheck)
302 |         def __init__(self, x: int):
303 |             self.x = x
304 |
305 |     wrapped = Foo.__init__.__wrapped__
306 |     with pytest.raises(AttributeError):  # a second __wrapped__ would mean double wrapping
307 |         wrapped.__wrapped__
308 |
309 |
310 | def test_stringified_multiple_varaidic(typecheck):  # string return annotation, two variadics
311 |     @jaxtyped(typechecker=typecheck)
312 |     def foo() -> 'Float[Array, "*foo *bar"]':
313 |         return jnp.arange(3)
314 |
315 |     foo()  # expected to complete without raising
316 |
--------------------------------------------------------------------------------
/test/test_generators.py:
--------------------------------------------------------------------------------
1 | from collections.abc import AsyncIterator, Iterator
2 |
3 | import jax.numpy as jnp
4 | import pytest
5 |
6 | from jaxtyping import Array, Float, Shaped
7 |
8 | from .helpers import ParamError
9 |
10 |
11 | try:  # torch is optional; test_generators_original_issue skips when it is None
12 |     import torch
13 | except ImportError:
14 |     torch = None
15 |
16 |
17 | def test_generators_simple(jaxtyp, typecheck):
18 | @jaxtyp(typecheck)
19 | def gen(x: Float[Array, "*"]) -> Iterator[Float[Array, "*"]]:
20 | yield x
21 |
22 | @jaxtyp(typecheck)
23 | def foo() -> None:
24 | next(gen(jnp.zeros(2)))
25 | next(gen(jnp.zeros((3, 4))))
26 |
27 | foo()
28 |
29 |
30 | def test_generators_return_no_annotations(jaxtyp, typecheck):
31 | @jaxtyp(typecheck)
32 | def gen(x: Float[Array, "*"]):
33 | yield x
34 |
35 | @jaxtyp(typecheck)
36 | def foo():
37 | next(gen(jnp.zeros(2)))
38 | next(gen(jnp.zeros((3, 4))))
39 |
40 | foo()
41 |
42 |
43 | @pytest.mark.asyncio
44 | async def test_async_generators_simple(jaxtyp, typecheck):
45 | @jaxtyp(typecheck)
46 | async def gen(x: Float[Array, "*"]) -> AsyncIterator[Float[Array, "*"]]:
47 | yield x
48 |
49 | @jaxtyp(typecheck)
50 | async def foo():
51 | async for _ in gen(jnp.zeros(2)):
52 | pass
53 | async for _ in gen(jnp.zeros((3, 4))):
54 | pass
55 |
56 | await foo()
57 |
58 |
59 | def test_generators_dont_modify_same_annotations(jaxtyp, typecheck):  # "1" is still enforced
60 |     @jaxtyp(typecheck)
61 |     def g(x: Float[Array, "1"]) -> Iterator[Float[Array, "1"]]:
62 |         yield x
63 |
64 |     @jaxtyp(typecheck)
65 |     def m(x: Float[Array, "1"]) -> Float[Array, "1"]:
66 |         return x
67 |
68 |     with pytest.raises(ParamError):  # generator: length 2 where "1" is declared
69 |         next(g(jnp.zeros(2)))
70 |     with pytest.raises(ParamError):  # plain function: the same annotation rejects it too
71 |         m(jnp.zeros(2))
72 |
73 |
74 | def test_generators_original_issue(jaxtyp, typecheck):
75 | # Effectively the same as https://github.com/patrick-kidger/jaxtyping/issues/91
76 | if torch is None:
77 | pytest.skip("torch is not available")
78 |
79 | @jaxtyp(typecheck)
80 | def g(x: Shaped[torch.Tensor, "*"]) -> Iterator[Shaped[torch.Tensor, "*"]]:
81 | yield x
82 |
83 | @jaxtyp(typecheck)
84 | def f() -> None:
85 | next(g(torch.zeros(1)))
86 | next(g(torch.zeros(2)))
87 |
88 | f()
89 |
--------------------------------------------------------------------------------
/test/test_import_hook.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | import importlib
21 | import importlib.metadata
22 | import pathlib
23 | import shutil
24 | import sys
25 | import tempfile
26 |
27 | import pytest
28 |
29 | import jaxtyping
30 |
31 |
32 | _here = pathlib.Path(__file__).parent
33 |
34 |
35 | try:
36 | typeguard_version = importlib.metadata.version("typeguard")
37 | except Exception as e:
38 | raise ImportError("Could not find typeguard version") from e
39 | else:
40 | try:
41 | major, _, _ = typeguard_version.split(".")
42 | major = int(major)
43 | except Exception as e:
44 | raise ImportError(
45 | f"Unexpected typeguard version {typeguard_version}; not formatted as "
46 | "`major.minor.patch`"
47 | ) from e
48 | if major != 2:
49 | raise ImportError(
50 | "jaxtyping's tests required typeguard version 2. (Versions 3 and 4 are both "
51 | "known to have bugs.)"
52 | )
53 |
54 |
55 | assert not hasattr(jaxtyping, "_test_import_hook_counter")  # must not already exist
56 | jaxtyping._test_import_hook_counter = 0  # the hook tests expect each import to add one
57 |
58 |
59 | @pytest.fixture(scope="module")
60 | def importhook_tempdir():
61 | with tempfile.TemporaryDirectory() as dir:
62 | sys.path.append(dir)
63 | dir = pathlib.Path(dir)
64 | shutil.copyfile(_here / "helpers.py", dir / "helpers.py")
65 | yield dir
66 |
67 |
68 | def _test_import_hook(importhook_tempdir, typechecker):
69 | counter = jaxtyping._test_import_hook_counter
70 | stem = f"import_hook_tester{counter}"
71 | shutil.copyfile(_here / "import_hook_tester.py", importhook_tempdir / f"{stem}.py")
72 |
73 | importlib.invalidate_caches()
74 | with jaxtyping.install_import_hook(stem, typechecker):
75 | importlib.import_module(stem)
76 | assert counter + 1 == jaxtyping._test_import_hook_counter
77 |
78 |
79 | # Tests start below...
80 |
81 |
82 | def test_import_hook_typeguard(importhook_tempdir, typeguard_or_skip):  # checker as dotted string
83 |     _test_import_hook(importhook_tempdir, "typeguard.typechecked")
84 |
85 |
86 | def test_import_hook_beartype(importhook_tempdir, beartype_or_skip):  # checker as dotted string
87 |     _test_import_hook(importhook_tempdir, "beartype.beartype")
88 |
89 |
90 | def test_import_hook_beartype_full(importhook_tempdir, beartype_or_skip):  # checker as a call string
91 |     bearchecker = "beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"  # noqa: E501
92 |     _test_import_hook(importhook_tempdir, bearchecker)
93 |
94 |
95 | def test_import_hook_typeguard_old(importhook_tempdir, typeguard_or_skip):  # checker as a 2-tuple
96 |     _test_import_hook(importhook_tempdir, ("typeguard", "typechecked"))
97 |
98 |
99 | def test_import_hook_beartype_old(importhook_tempdir, beartype_or_skip):  # checker as a 2-tuple
100 |     _test_import_hook(importhook_tempdir, ("beartype", "beartype"))
101 |
102 |
103 | def test_import_hook_broken_checker(importhook_tempdir):  # unknown checker -> AttributeError
104 |     with pytest.raises(AttributeError):
105 |         _test_import_hook(importhook_tempdir, "jaxtyping.does_not_exist")
106 |
107 |
108 | def test_import_hook_transitive(importhook_tempdir, typeguard_or_skip):
109 | counter = jaxtyping._test_import_hook_counter
110 | transitive_name = "jaxtyping_transitive_test"
111 | transitive_dir = importhook_tempdir / transitive_name
112 | transitive_dir.mkdir()
113 | shutil.copyfile(_here / "import_hook_tester.py", transitive_dir / "tester.py")
114 | with open(transitive_dir / "__init__.py", "w") as f:
115 | f.write("from . import tester")
116 | f.flush()
117 |
118 | importlib.invalidate_caches()
119 | with jaxtyping.install_import_hook(transitive_name, "typeguard.typechecked"):
120 | importlib.import_module(transitive_name)
121 | assert counter + 1 == jaxtyping._test_import_hook_counter
122 |
--------------------------------------------------------------------------------
/test/test_ipython_extension.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from IPython.testing.globalipapp import start_ipython
3 |
4 | from .helpers import ParamError
5 |
6 |
7 | @pytest.fixture(scope="session")
8 | def session_ip():  # one IPython instance, started once and shared by the whole session
9 |     yield start_ipython()
10 |
11 |
12 | @pytest.fixture(scope="function")
13 | def ip(session_ip):  # shared IPython with the jaxtyping extension loaded and typeguard selected
14 |     session_ip.run_cell(raw_cell="import jaxtyping")
15 |     session_ip.run_line_magic(magic_name="load_ext", line="jaxtyping")
16 |     session_ip.run_line_magic(
17 |         magic_name="jaxtyping.typechecker", line="typeguard.typechecked"
18 |     )
19 |     yield session_ip
20 |
21 |
22 | def test_that_ipython_works(ip):  # sanity check: a cell runs and defines a global
23 |     ip.run_cell(raw_cell="x = 1").raise_error()
24 |     assert ip.user_global_ns["x"] == 1
25 |
26 |
27 | def test_function_beartype(ip):  # NOTE: despite the name, `ip` selects typeguard.typechecked
28 |     ip.run_cell(
29 |         raw_cell="""
30 | def f(x: int):
31 |     pass
32 | """
33 |     ).raise_error()
34 |     ip.run_cell(raw_cell="f(1)").raise_error()
35 |
36 |     with pytest.raises(ParamError):  # a str where int is annotated
37 |         ip.run_cell(raw_cell='f("x")').raise_error()
38 |
39 |
40 | def test_function_none(ip):  # a function without annotations accepts any arguments
41 |     ip.run_cell(
42 |         raw_cell="""
43 | def f(a,b,c):
44 |     pass
45 | """
46 |     ).raise_error()
47 |     ip.run_cell(raw_cell='f(1,2,"k")').raise_error()
48 |
49 |
50 | def test_function_jaxtyped(ip):  # a jaxtyping annotation on a function defined in a cell
51 |     ip.run_cell(
52 |         raw_cell="""
53 | from jaxtyping import Float, Array, Int
54 | import jax
55 |
56 | def g(x: Float[Array, "1"]):
57 |     return x + 1
58 |
59 | """
60 |     ).raise_error()
61 |
62 |     ip.run_cell(raw_cell="g(jax.numpy.array([1.0]))").raise_error()
63 |
64 |     with pytest.raises(ParamError):  # scalar: no axis at all
65 |         ip.run_cell(raw_cell="g(jax.numpy.array(1.0))").raise_error()
66 |
67 |     with pytest.raises(ParamError):  # integer values where Float is declared
68 |         ip.run_cell(raw_cell="g(jax.numpy.array([1]))").raise_error()
69 |
70 |     with pytest.raises(ParamError):  # length 2, and integer values
71 |         ip.run_cell(raw_cell="g(jax.numpy.array([2, 3]))").raise_error()
72 |
73 |     with pytest.raises(ParamError):  # not an array
74 |         ip.run_cell(raw_cell='g("string")').raise_error()
75 |
76 |
77 | def test_function_jaxtyped_and_jitted(ip):  # same checks with the cell function under @jax.jit
78 |     ip.run_cell(
79 |         raw_cell="""
80 | from jaxtyping import Float, Array, Int
81 | import jax
82 |
83 | @jax.jit
84 | def g(x: Float[Array, "1"]):
85 |     return x + 1
86 |
87 | """
88 |     ).raise_error()
89 |
90 |     ip.run_cell(raw_cell="g(jax.numpy.array([1.0]))").raise_error()
91 |
92 |     with pytest.raises(ParamError):  # scalar: no axis at all
93 |         ip.run_cell(raw_cell="g(jax.numpy.array(1.0))").raise_error()
94 |
95 |     with pytest.raises(ParamError):  # integer values where Float is declared
96 |         ip.run_cell(raw_cell="g(jax.numpy.array([1]))").raise_error()
97 |
98 |     with pytest.raises(ParamError):  # length 2, and integer values
99 |         ip.run_cell(raw_cell="g(jax.numpy.array([2, 3]))").raise_error()
100 |
101 |     with pytest.raises(ParamError):  # not an array
102 |         ip.run_cell(raw_cell='g("string")').raise_error()
103 |
104 |
105 | def test_class_jaxtyped(ip):  # an eqx.Module defined in a cell: field and method annotations
106 |     ip.run_cell(
107 |         raw_cell="""
108 | from jaxtyping import Float, Array, Int
109 | import equinox as eqx
110 | import jax
111 |
112 | class A(eqx.Module):
113 |     x: Float[Array, "2"]
114 |
115 |     def do_something(self, y: Int[Array, ""]):
116 |         return self.x + y
117 | """
118 |     ).raise_error()
119 |
120 |     ip.run_cell(raw_cell="a = A(jax.numpy.array([1.0, 2.0]))").raise_error()
121 |     ip.run_cell(raw_cell="a.do_something(jax.numpy.array(2))").raise_error()
122 |
123 |     with pytest.raises(ParamError):  # field x declared with length 2, given length 1
124 |         ip.run_cell(raw_cell="A(jax.numpy.array([1.0]))").raise_error()
125 |
126 |     with pytest.raises(ParamError):  # y declared as an Int scalar, given a float vector
127 |         ip.run_cell(
128 |             raw_cell="a.do_something(jax.numpy.array([2.0, 3.0]))"
129 |         ).raise_error()
130 |
131 |
def test_class_not_dataclass(ip):
    """Check that a plain, unannotated class defined in a cell runs unchecked.

    ``A`` carries no annotations, so none of the calls below may raise,
    whatever the shape or dtype of the arrays passed in.
    """
    ip.run_cell(
        raw_cell="""
from jaxtyping import Float, Array, Int
import equinox as eqx
import jax

class A:
    def __init__(self, x):
        self.x = x

    def do_something(self, y):
        return self.x + y
    """
    ).raise_error()

    # ``do_something`` reads the instance attribute ``self.x``; a bare ``x``
    # would be an undefined global inside the cell and raise ``NameError``.
    ip.run_cell(raw_cell="a = A(jax.numpy.array([1.0, 2.0]))").raise_error()
    ip.run_cell(raw_cell="a.do_something(jax.numpy.array(2))").raise_error()
    ip.run_cell(raw_cell="A(jax.numpy.array([1.0]))").raise_error()
    ip.run_cell(raw_cell="a.do_something(jax.numpy.array([2.0, 3.0]))").raise_error()
152 |
--------------------------------------------------------------------------------
/test/test_messages.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import equinox as eqx
4 | import jax.numpy as jnp
5 | import pytest
6 |
7 | from jaxtyping import Array, Float, jaxtyped, PyTree, TypeCheckError
8 |
9 |
def test_arg_localisation(typecheck):
    """Check that parameter type errors name the offending parameter.

    Each expected fragment is matched as a regex against a fresh
    ``TypeCheckError``, so the failing call is repeated once per fragment.
    """

    @jaxtyped(typechecker=typecheck)
    def f(x: str, y: str, z: int):
        pass

    plain_patterns = (
        "Type-check error whilst checking the parameters of .*.f",
        "The problem arose whilst typechecking parameter 'z'.",
        "Called with parameters: {'x': 'hi', 'y': 'bye', 'z': 'not-an-int'}",
        r"Parameter annotations: \(x: str, y: str, z: int\).",
    )
    for pattern in plain_patterns:
        with pytest.raises(TypeCheckError, match=pattern):
            f("hi", "bye", "not-an-int")

    @jaxtyped(typechecker=typecheck)
    def g(x: Float[Array, "a b"], y: Float[Array, "b c"]):
        pass

    # ``x`` has shape (2, 3); ``y``'s leading axis of 4 then conflicts on ``b``.
    first = jnp.zeros((2, 3))
    second = jnp.zeros((4, 3))
    array_patterns = (
        "Type-check error whilst checking the parameters of .*..g",
        "The problem arose whilst typechecking parameter 'y'.",
        r"Called with parameters: {'x': f32\[2,3\], 'y': f32\[4,3\]}",
        (
            r"Parameter annotations: \(x: Float\[Array, 'a b'\], y: "
            r"Float\[Array, 'b c'\]\)."
        ),
        "The current values for each jaxtyping axis annotation are as follows.",
        "a=2",
        "b=3",
    )
    for pattern in array_patterns:
        with pytest.raises(TypeCheckError, match=pattern):
            g(first, y=second)
46 |
47 |
def test_return(typecheck):
    """Check the message produced when a return value fails its annotation.

    ``f`` returns the string ``"foo"`` where ``PyTree[Any, "T S"]`` is declared;
    the error is expected to report the call arguments, the actual value, the
    expected type, and the structures bound to ``T`` and ``S``.
    """

    @jaxtyped(typechecker=typecheck)
    def f(x: PyTree[Any, " T"], y: PyTree[Any, " S"]) -> PyTree[Any, "T S"]:
        return "foo"

    tuple_tree = (1, 2)
    dict_tree = {"a": 1}
    structure_header = (
        "The current values for each jaxtyping PyTree structure annotation are as "
        "follows."
    )
    expected_patterns = (
        "Type-check error whilst checking the return value of .*..f",
        r"Called with parameters: {'x': \(1, 2\), 'y': {'a': 1}}",
        "Actual value: 'foo'",
        r"Expected type: PyTree\[Any, 'T S'\].",
        structure_header,
        r"T=PyTreeDef\(\(\*, \*\)\)",
        r"S=PyTreeDef\({'a': \*}\)",
    )
    # One failing call per fragment: ``match`` checks a single regex at a time.
    for pattern in expected_patterns:
        with pytest.raises(TypeCheckError, match=pattern):
            f(tuple_tree, y=dict_tree)
70 |
71 |
def test_dataclass_init(typecheck):
    """Check the message produced when a ``jaxtyped`` dataclass ``__init__`` fails.

    ``M`` is constructed with a string for its ``int`` field ``z``; the error
    is expected to name ``z`` and to list the axis and PyTree-structure
    bindings collected from ``x`` and ``y``.
    """

    @jaxtyped(typechecker=typecheck)
    class M(eqx.Module):
        x: Float[Array, " *foo"]
        y: PyTree[Any, " T"]
        z: int

    array_field = jnp.zeros((2, 3))
    tree_field = (1, (3, 4))
    bad_int = "not-an-int"

    axis_header = (
        "The current values for each jaxtyping axis annotation are as follows."
    )
    structure_header = (
        "The current values for each jaxtyping PyTree structure annotation are as "
        "follows."
    )
    expected_patterns = (
        "Type-check error whilst checking the parameters of .*..M",
        "The problem arose whilst typechecking parameter 'z'.",
        (
            r"Called with parameters: {'self': M\(\.\.\.\), 'x': f32\[2,3\], "
            r"'y': \(1, \(3, 4\)\), 'z': 'not-an-int'}"
        ),
        (
            r"Parameter annotations: \(self, x: Float\[Array, '\*foo'\], "
            r"y: PyTree\[Any, 'T'\], z: int\)."
        ),
        axis_header,
        r"foo=\(2, 3\)",
        structure_header,
        r"T=PyTreeDef\(\(\*, \(\*, \*\)\)\)",
    )
    for pattern in expected_patterns:
        with pytest.raises(TypeCheckError, match=pattern):
            M(array_field, tree_field, bad_int)
105 |
--------------------------------------------------------------------------------
/test/test_no_jax_dependency.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import sys
3 | import unittest
4 |
5 |
6 | _py_path = sys.executable
7 |
8 |
@unittest.skipIf(not _py_path, "test requires sys.executable")
def test_no_jax_dependency():
    """Check that importing ``jaxtyping`` alone does not import ``jax``.

    Runs a fresh interpreter so that modules already imported by the test
    session cannot influence the result. The child exits with
    ``"jax" in sys.modules`` as its status, i.e. 0 when ``jax`` was not loaded.
    """
    code = 'import jaxtyping; import sys; sys.exit("jax" in sys.modules)'
    # An argv list avoids shell quoting entirely, so an interpreter path
    # containing spaces and non-POSIX shells are both handled.
    result = subprocess.run([_py_path, "-c", code], check=False)
    assert result.returncode == 0
17 |
18 |
19 | # Meta-test: test that the above test will work. (i.e. that I haven't messed up using
20 | # subprocess.)
@unittest.skipIf(not _py_path, "test requires sys.executable")
def test_meta():
    """Check that the subprocess probe really detects an imported ``jax``.

    The child imports ``jax`` explicitly and exits with
    ``"jax" in sys.modules`` as its status, so a return code of 1 is expected.
    """
    code = 'import jaxtyping; import jax; import sys; sys.exit("jax" in sys.modules)'
    # An argv list avoids shell quoting entirely, so an interpreter path
    # containing spaces and non-POSIX shells are both handled.
    result = subprocess.run([_py_path, "-c", code], check=False)
    assert result.returncode == 1
29 |
--------------------------------------------------------------------------------
/test/test_pytree.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | from collections.abc import Callable
21 | from typing import NamedTuple, Tuple, Union
22 |
23 | import equinox as eqx
24 | import jax
25 | import jax.numpy as jnp
26 | import jax.random as jr
27 | import pytest
28 | import wadler_lindig as wl
29 |
30 | import jaxtyping
31 | from jaxtyping import AnnotationError, Array, Float, PyTree
32 |
33 | from .helpers import make_mlp, ParamError
34 |
35 |
def test_direct(typecheck):
    """An unsubscripted ``PyTree`` annotation accepts arbitrary objects."""

    @typecheck
    def takes_pytree(x: PyTree):
        pass

    # A bare leaf, a nested container, and an opaque object are all accepted.
    for value in (1, {"a": jnp.array(1), "b": [object()]}, object()):
        takes_pytree(value)

    @typecheck
    def returns_pytree() -> PyTree:
        return object()

    returns_pytree()
50 |
51 |
def test_subscript(getkey, typecheck):
    """``PyTree[int]`` accepts trees whose leaves are all ints, and nothing else."""

    @typecheck
    def g(x: PyTree[int]):
        pass

    # An MLP with every leaf replaced by the int 1 is still an int-leaved tree.
    int_leaved_mlp = jax.tree.map(lambda _: 1, make_mlp(getkey()))
    for tree in (1, [1, 2, {"a": 3}], int_leaved_mlp):
        g(tree)

    for bad_tree in (object(), "hi", [1, 2, {"a": 3}, "bye"]):
        with pytest.raises(ParamError):
            g(bad_tree)
67 |
68 |
def test_leaf_pytrees(getkey, typecheck):
    """``PyTree[eqx.nn.MLP]`` accepts trees of MLPs and rejects other leaves."""

    @typecheck
    def g(x: PyTree[eqx.nn.MLP]):
        pass

    g(make_mlp(getkey()))
    g([make_mlp(getkey()), make_mlp(getkey()), {"a": make_mlp(getkey())}])

    with pytest.raises(ParamError):
        g([1, 2])

    # Build the MLP (with a key, like every other call here) outside the
    # ``raises`` block, so that only ``g``'s own check can satisfy it.
    mlp = make_mlp(getkey())
    with pytest.raises(ParamError):
        g([1, 2, mlp])
81 |
82 |
def test_nested_pytrees(getkey, typecheck):
    """``PyTree[PyTree[...]]`` behaves like ``PyTree[...]``."""

    # PyTree[...] is logically equivalent to PyTree[PyTree[...]]
    @typecheck
    def g(x: PyTree[PyTree[eqx.nn.MLP]]):
        pass

    g(make_mlp(getkey()))
    g([make_mlp(getkey()), make_mlp(getkey()), {"a": make_mlp(getkey())}])

    with pytest.raises(ParamError):
        g([1, 2])

    # Build the MLP (with a key, like every other call here) outside the
    # ``raises`` block, so that only ``g``'s own check can satisfy it.
    mlp = make_mlp(getkey())
    with pytest.raises(ParamError):
        g([1, 2, mlp])
96 |
97 |
def test_pytree_array(jaxtyp, typecheck):
    """``PyTree[Float[jnp.ndarray, "..."]]`` accepts float arrays of any shape."""

    @jaxtyp(typecheck)
    def g(x: PyTree[Float[jnp.ndarray, "..."]]):
        pass

    scalar = jnp.array(1.0)
    g(scalar)
    g([jnp.array(1.0), jnp.array(1.0), {"a": jnp.array(1.0)}])

    # An integer array and a plain Python float are both rejected.
    for bad_leaf in (jnp.array(1), 1.0):
        with pytest.raises(ParamError):
            g(bad_leaf)
109 |
110 |
def test_pytree_shaped_array(jaxtyp, typecheck, getkey):
    """Axis names in ``PyTree[Float[..., "b c"]]`` must agree across all leaves."""

    @jaxtyp(typecheck)
    def g(x: PyTree[Float[jnp.ndarray, "b c"]]):
        pass

    def rand(*shape):
        # Fresh key per call, drawn in the order the arrays are created.
        return jr.normal(getkey(), shape)

    g(jnp.array([[1.0]]))
    g([rand(1, 1), rand(1, 1)])
    g([rand(3, 4), rand(3, 4)])

    with pytest.raises(ParamError):
        g(jnp.array(1.0))  # scalar where two axes are declared
    with pytest.raises(ParamError):
        g(jnp.array([[1]]))  # integer values
    # Leaves with differing shapes, in either order.
    with pytest.raises(ParamError):
        g([rand(3, 4), rand(1, 1)])
    with pytest.raises(ParamError):
        g([rand(1, 1), rand(3, 4)])
127 |
128 |
def test_pytree_union(typecheck):
    """Leaves of ``PyTree[Union[int, str]]`` may be ints, strs, or a mix."""

    @typecheck
    def g(x: PyTree[Union[int, str]]):
        pass

    for tree in ([1], ["hi"], [1, "hi"]):
        g(tree)

    with pytest.raises(ParamError):
        g(object())
139 |
140 |
def test_pytree_tuple(typecheck):
    """``PyTree[Tuple[int, int]]`` treats each ``(int, int)`` pair as a leaf."""

    @typecheck
    def g(x: PyTree[Tuple[int, int]]):
        pass

    accepted = (
        (1, 1),
        [(1, 1)],
        [(1, 1), {"a": (1, 1)}],
    )
    for tree in accepted:
        g(tree)

    rejected = (
        object(),
        1,
        (1, 2, 3),
        [1, 1],
        [(1, 1), "hi"],
    )
    for tree in rejected:
        with pytest.raises(ParamError):
            g(tree)
159 |
160 |
def test_pytree_namedtuple(typecheck):
    """``PyTree[SomeNamedTuple]`` accepts that class only, not a look-alike."""

    class CustomNamedTuple(NamedTuple):
        x: Float[jnp.ndarray, "a b"]
        y: Float[jnp.ndarray, "b c"]

    class OtherCustomNamedTuple(NamedTuple):
        x: Float[jnp.ndarray, "a b"]
        y: Float[jnp.ndarray, "b c"]

    @typecheck
    def g(x: PyTree[CustomNamedTuple]): ...

    def make_fields():
        # Fixed seeds, so both tuple classes receive identical field values.
        return {
            "x": jax.random.normal(jax.random.PRNGKey(42), (3, 2)),
            "y": jax.random.normal(jax.random.PRNGKey(420), (2, 5)),
        }

    g(CustomNamedTuple(**make_fields()))

    with pytest.raises(ParamError):
        g(object())
    # Same fields and values, but a different NamedTuple class.
    with pytest.raises(ParamError):
        g(OtherCustomNamedTuple(**make_fields()))
188 |
189 |
def test_subclass_pytree():
    """``PyTree`` and its subscripted forms are subclasses of ``PyTree``."""
    for pytree_type in (PyTree, PyTree[int]):
        assert issubclass(pytree_type, PyTree)
    # Unrelated classes are not.
    assert not issubclass(int, PyTree)
196 |
197 |
def test_structure_match(jaxtyp, typecheck):
    """Two parameters sharing the structure name ``T`` must have equal structure."""

    @jaxtyp(typecheck)
    def f(x: PyTree[int, " T"], y: PyTree[str, " T"]):
        pass

    matching_pairs = (
        (1, "hi"),
        ((3, 4, {"a": 5}), ("a", "b", {"a": "c"})),
    )
    for ints, strs in matching_pairs:
        f(ints, strs)

    # A bare leaf against a one-element tuple: structures differ.
    with pytest.raises(ParamError):
        f(1, ("hi",))
208 |
209 |
def test_structure_prefix(jaxtyp, typecheck):
    """``"T ..."`` requires ``T`` as a prefix of the second tree's structure."""

    @jaxtyp(typecheck)
    def f(x: PyTree[int, " T"], y: PyTree[str, "T ..."]):
        pass

    accepted = (
        (1, "hi"),
        ((3, 4, {"a": 5}), ("a", "b", {"a": "c"})),
        (1, ("hi",)),
        ((1, 2), ({"a": "hi"}, {"a": "bye"})),
        # Below the prefix the sub-structures are free to differ.
        ((1, 2), ({"a": "hi"}, {"not-a": "bye"})),
    )
    for ints, strs in accepted:
        f(ints, strs)

    rejected = (
        ((1, 2), ({"a": "hi"}, {"a": "bye"}, {"a": "oh-no"})),
        ((3, 4, 5), {"a": ("hi", "bye")}),
    )
    for ints, strs in rejected:
        with pytest.raises(ParamError):
            f(ints, strs)
226 |
227 |
def test_structure_suffix(jaxtyp, typecheck):
    """``"... T"`` requires ``T`` as a suffix of the second tree's structure."""

    @jaxtyp(typecheck)
    def f(x: PyTree[int, " T"], y: PyTree[str, "... T"]):
        pass

    accepted = (
        (1, "hi"),
        ((3, 4, {"a": 5}), ("a", "b", {"a": "c"})),
        (1, ("hi",)),
    )
    for ints, strs in accepted:
        f(ints, strs)

    rejected = (
        ((3, 4), {"a": (1, 2)}),
        ((3, 4, 5), {"a": ("hi", "bye")}),
    )
    for ints, other in rejected:
        with pytest.raises(ParamError):
            f(ints, other)
242 |
243 |
def test_structure_compose(jaxtyp, typecheck):
    """Composite structure names (``"S T"`` / ``"T S"``) are order-sensitive."""

    pair = (1, 2)
    keyed = {"a": 3}
    dict_of_tuples = {"a": ("hi", "bye")}
    tuple_of_dicts = ({"a": "hi"}, {"a": "bye"})

    @jaxtyp(typecheck)
    def f(x: PyTree[int, " T"], y: PyTree[int, " S"], z: PyTree[str, "S T"]):
        pass

    f(1, 2, "hi")
    f(pair, 2, ("a", "b"))

    # Right structure, wrong leaf type for ``z``.
    with pytest.raises(ParamError):
        f(pair, 2, (1, 2))

    f(pair, keyed, dict_of_tuples)

    with pytest.raises(ParamError):
        f(pair, keyed, tuple_of_dicts)

    @jaxtyp(typecheck)
    def g(x: PyTree[int, " T"], y: PyTree[int, " S"], z: PyTree[str, "T S"]):
        pass

    # With the names swapped, the accepted and rejected trees swap as well.
    with pytest.raises(ParamError):
        g(pair, keyed, dict_of_tuples)

    g(pair, keyed, tuple_of_dicts)
268 |
269 |
@pytest.mark.parametrize("variadic", (False, True))
def test_treepath_dependence_function(variadic, jaxtyp, typecheck, getkey):
    """``?foo`` axes are matched between leaves at the same position of ``T``.

    Parametrised over the single-axis (``?foo``) and variadic (``*?foo``)
    spellings of the annotation.
    """
    jtshape, shape = ("*?foo", (2, 3)) if variadic else ("?foo", (4,))

    @jaxtyp(typecheck)
    def f(
        x: PyTree[Float[Array, jtshape], " T"], y: PyTree[Float[Array, jtshape], " T"]
    ):
        pass

    # Keys are drawn in this order: x1, y1, x2, y2.
    x1 = jr.normal(getkey(), shape)
    y1 = jr.normal(getkey(), shape)
    x2 = jr.normal(getkey(), (5,))
    y2 = jr.normal(getkey(), (5,))

    f(x1, y1)
    f((x1, x2), (y1, y2))

    # Shapes disagree at corresponding leaves.
    for left, right in ((x1, y2), ((x1, x2), (y2, y1))):
        with pytest.raises(ParamError):
            f(left, right)
297 |
298 |
@pytest.mark.parametrize("variadic", (False, True))
def test_treepath_dependence_dataclass(variadic, typecheck, getkey):
    """Same ``?foo`` per-leaf matching as for functions, on dataclass fields.

    Parametrised over the single-axis (``?foo``) and variadic (``*?foo``)
    spellings of the annotation.
    """
    jtshape, shape = ("*?foo", (2, 3)) if variadic else ("?foo", (4,))

    @jaxtyping.jaxtyped(typechecker=typecheck)
    class A(eqx.Module):
        x: PyTree[Float[Array, jtshape], " T"]
        y: PyTree[Float[Array, jtshape], " T"]

    # Keys are drawn in this order: x1, y1, x2, y2.
    x1 = jr.normal(getkey(), shape)
    y1 = jr.normal(getkey(), shape)
    x2 = jr.normal(getkey(), (5,))
    y2 = jr.normal(getkey(), (5,))

    A(x1, y1)
    A((x1, x2), (y1, y2))

    # Shapes disagree at corresponding leaves.
    for left, right in ((x1, y2), ((x1, x2), (y2, y1))):
        with pytest.raises(ParamError):
            A(left, right)
325 |
326 |
def test_treepath_dependence_missing_structure_annotation(jaxtyp, typecheck, getkey):
    """A ``?foo`` axis inside a ``PyTree`` without a structure name is an error."""

    @jaxtyp(typecheck)
    def f(x: PyTree[Float[Array, "?foo"], " T"], y: PyTree[Float[Array, "?foo"]]):
        pass

    first = jr.normal(getkey(), (2,))
    second = jr.normal(getkey(), (2,))
    expected_message = "except when contained with structured"
    with pytest.raises(AnnotationError, match=expected_message):
        f(first, second)
336 |
337 |
def test_treepath_dependence_multiple_structure_annotation(jaxtyp, typecheck, getkey):
    """A ``?foo`` axis under two nested structure names is reported as ambiguous."""

    @jaxtyp(typecheck)
    def f(x: PyTree[PyTree[Float[Array, "?foo"], " S"], " T"]):
        pass

    leaf = jr.normal(getkey(), (2,))
    expected_message = "ambiguous which PyTree"
    with pytest.raises(AnnotationError, match=expected_message):
        f(leaf)
346 |
347 |
def test_name():
    """``__name__`` of (nested) ``PyTree`` types mirrors how they were written.

    Structure names are shown quoted and with surrounding whitespace removed.
    """
    cases = (
        (PyTree, "PyTree"),
        (PyTree[int], "PyTree[int]"),
        (PyTree[int, "foo"], "PyTree[int, 'foo']"),
        (PyTree[PyTree[str], "foo"], "PyTree[PyTree[str], 'foo']"),
        (PyTree[PyTree[str, "bar"], "foo"], "PyTree[PyTree[str, 'bar'], 'foo']"),
        (PyTree[PyTree[str, "bar"]], "PyTree[PyTree[str, 'bar']]"),
        (
            PyTree[None | Callable[[PyTree[int, " T"]], str]],
            "PyTree[None | Callable[[PyTree[int, 'T']], str]]",
        ),
    )
    for pytree_type, expected_name in cases:
        assert pytree_type.__name__ == expected_name
362 |
363 |
def test_pdoc():
    """``wl.pformat`` renders ``PyTree`` types like ``__name__`` at default width.

    With ``width=2`` the same type is expected to break onto one element per
    line.
    """
    callable_tree = PyTree[None | Callable[[PyTree[int, " T"]], str]]
    one_line_cases = (
        (PyTree, "PyTree"),
        (PyTree[int], "PyTree[int]"),
        (PyTree[int, "foo"], "PyTree[int, 'foo']"),
        (PyTree[PyTree[str], "foo"], "PyTree[PyTree[str], 'foo']"),
        (PyTree[PyTree[str, "bar"], "foo"], "PyTree[PyTree[str, 'bar'], 'foo']"),
        (PyTree[PyTree[str, "bar"]], "PyTree[PyTree[str, 'bar']]"),
        (callable_tree, "PyTree[None | Callable[[PyTree[int, 'T']], str]]"),
    )
    for pytree_type, rendered in one_line_cases:
        assert wl.pformat(pytree_type) == rendered

    # NOTE(review): nesting is written at two spaces per level — confirm
    # against the actual ``wl.pformat`` output.
    expected = """
PyTree[
  None
  | Callable[
    [
      PyTree[
        int,
        'T'
      ]
    ],
    str
  ]
]
    """.strip()
    narrow = wl.pformat(callable_tree, width=2).strip()
    assert narrow == expected
396 |
--------------------------------------------------------------------------------
/test/test_serialisation.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | import cloudpickle
4 | import numpy as np
5 |
6 |
7 | try:
8 | import torch
9 | except ImportError:
10 | torch = None
11 |
12 | from jaxtyping import AbstractArray, Array, Float, Shaped
13 |
14 |
def test_pickle():
    """Array annotations survive a dumps/loads round trip.

    Run with both ``pickle`` and ``cloudpickle``; the torch case only runs
    when torch could be imported.
    """
    for pickler in (pickle, cloudpickle):

        def roundtrip(obj):
            return pickler.loads(pickler.dumps(obj))

        restored = roundtrip(Shaped[Array, ""])
        assert restored.dtype is Shaped
        assert restored.dim_str == ""

        # The unsubscripted class must come back as the very same object.
        assert roundtrip(AbstractArray) is AbstractArray

        restored = roundtrip(Shaped[np.ndarray, "3 4 hi"])
        assert restored.dtype is Shaped
        assert restored.dim_str == "3 4 hi"

        if torch is None:
            continue
        restored = roundtrip(Float[torch.Tensor, "batch length"])
        assert restored.dtype is Float
        assert restored.dim_str == "batch length"
36 |
--------------------------------------------------------------------------------
/test/test_tf_dtype.py:
--------------------------------------------------------------------------------
1 | # Tensorflow dependency kept in a separate file, so that we can optionally exclude it
2 | # more easily.
3 | import tensorflow as tf
4 |
5 | from jaxtyping import UInt
6 |
7 |
def test_tf_dtype():
    """``UInt[tf.Tensor, "..."]`` matches a uint8 tensor but not a float32 one."""
    hint = UInt[tf.Tensor, "..."]
    unsigned = tf.constant(1, dtype=tf.uint8)
    floating = tf.constant(1, dtype=tf.float32)
    assert isinstance(unsigned, hint)
    assert not isinstance(floating, hint)
14 |
--------------------------------------------------------------------------------
/test/test_threading.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Google LLC
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | import threading
21 |
22 | import jax.numpy as jnp
23 | from typeguard import typechecked
24 |
25 | from jaxtyping import Array, Float, jaxtyped
26 |
27 |
28 | class _ErrorableThread(threading.Thread):
29 | def run(self):
30 | try:
31 | super().run()
32 | except Exception as e:
33 | self.exc = e
34 |
35 | def join(self, timeout=None):
36 | super().join(timeout)
37 | if hasattr(self, "exc"):
38 | raise self.exc
39 |
40 |
def test_threading_jaxtyped():
    """A ``jaxtyped`` function can be called from a non-main thread."""

    @jaxtyped(typechecker=typechecked)
    def add(x: Float[Array, "a b"], y: Float[Array, "a b"]) -> Float[Array, "a b"]:
        return x + y

    def run():
        lhs = jnp.array([[1.0, 2.0]])
        rhs = jnp.array([[2.0, 3.0]])
        add(lhs, rhs)

    # ``join`` re-raises anything the worker raised, failing the test.
    worker = _ErrorableThread(target=run)
    worker.start()
    worker.join()
54 |
55 |
def test_threading_nojaxtyped():
    """A bare ``isinstance`` check against an array type works in a thread."""

    def run():
        value = jnp.array([[1.0, 2.0]])
        assert isinstance(value, Float[Array, "..."])

    # ``join`` re-raises anything the worker raised, failing the test.
    worker = _ErrorableThread(target=run)
    worker.start()
    worker.join()
64 |
--------------------------------------------------------------------------------
/test/types/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrick-kidger/jaxtyping/fe61644be8590cf0a89bacc278283e1e4b9ea3e4/test/types/__init__.py
--------------------------------------------------------------------------------
/test/types/decorator.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import numpy as np
4 | from beartype import beartype
5 |
6 | from jaxtyping import Float, Int, jaxtyped
7 |
8 |
# Dataclass wrapped by ``jaxtyped`` with beartype as the typechecker.
@jaxtyped(typechecker=beartype)
@dataclass
class User:
    name: str
    age: int
    # ``items`` and ``timestamps`` use the same axis name ``N`` in their
    # annotations.
    items: Float[np.ndarray, " N"]
    timestamps: Int[np.ndarray, " N"]
16 |
17 |
@jaxtyped(typechecker=beartype)
def transform_user(user: User, increment_age: int = 1) -> User:
    """Add ``increment_age`` to ``user.age`` in place and return the same object."""
    user.age = user.age + increment_age
    return user
22 |
23 |
# Module-level usage, executed on import: build a ``User`` whose two array
# fields both have 10 elements (floats for ``items``, ints for ``timestamps``).
user = User(
    name="John",
    age=20,
    items=np.random.normal(size=10),
    timestamps=np.random.randint(0, 100, size=10),
)

# Call the typechecked function with an explicit, non-default increment.
new_user = transform_user(user, increment_age=2)
32 |
--------------------------------------------------------------------------------