├── .flake8
├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── ci.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── CONTRIBUTING.md
├── FURTHER-DOCUMENTATION.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── pyproject.toml
├── setup.py
├── test
│   ├── conftest.py
│   ├── test_consistency.py
│   ├── test_details.py
│   ├── test_dtype_layout.py
│   ├── test_ellipsis.py
│   ├── test_examples.py
│   ├── test_extensions.py
│   ├── test_misc.py
│   └── test_shape.py
└── torchtyping
    ├── __init__.py
    ├── pytest_plugin.py
    ├── tensor_details.py
    ├── tensor_type.py
    ├── typechecker.py
    └── utils.py

/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | ignore = W291,W503,E203 4 | per-file-ignores = __init__.py: F401 -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [patrick-kidger] 2 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: torchtyping CI 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | formatting_check: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v1 13 | - name: Set up Python 3.9 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: 3.9 17 | - name: Install black 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install black flake8 21 | - name: Format with black 22 | run: | 23 | python -m black --check torchtyping/ 24 | - name: Lint with flake8 25 | run: | 26 | flake8 torchtyping/ 27 | 28 | test_suite: 29 | runs-on: ubuntu-latest 30 | strategy: 31 | matrix: 32 | python-version: [3.7, 3.8, 3.9] 33 | steps: 34 | - uses: actions/checkout@v2 35 | - name: Set up Python ${{ matrix.python-version }} 36 | uses: actions/setup-python@v2 37 | with: 38 |
python-version: ${{ matrix.python-version }} 39 | - name: Install dependencies 40 | run: | 41 | python -m pip install --upgrade pip 42 | pip install wheel 43 | pip install -e . 44 | pip install pytest 45 | - name: Test with pytest 46 | run: | 47 | python -m pytest test/ 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | **/.ipynb_checkpoints/ 3 | *.py[cod] 4 | .idea/ 5 | .vs/ 6 | build/ 7 | dist/ 8 | *.egg_info/ 9 | *.egg 10 | *.so 11 | *.egg-info/ 12 | **/.mypy_cache/ 13 | env/ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3.9 7 | 8 | - repo: https://github.com/pycqa/flake8 9 | rev: 3.9.2 10 | hooks: 11 | - id: flake8 -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | **0.1.4** 2 | 3 | Fixed metaclass incompatibility to work with PyTorch 1.9.0. 4 | 5 | **0.1.3** 6 | 7 | `TensorType` now inherits from `torch.Tensor` so that IDE lookup+error messages work as expected. 8 | Updated pre-commit hooks. These were failing for some reason. 9 | 10 | **0.1.2** 11 | 12 | Added support for `str: str` pairs and `None: str` pairs. 13 | 14 | **0.1.1** 15 | 16 | Added support for Python 3.7+. (Down from Python 3.9+.) 17 | Added support for `typing.Any` to indicate an arbitrary-size dimension. 18 | 19 | **0.1.0** 20 | 21 | Initial release. 
22 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions (pull requests) are very welcome. 4 | 5 | First fork the library on GitHub. 6 | 7 | Then clone and install the library in development mode: 8 | 9 | ```bash 10 | git clone https://github.com/your-username-here/torchtyping.git 11 | cd torchtyping 12 | pip install -e . 13 | ``` 14 | 15 | Then install the pre-commit hook: 16 | 17 | ```bash 18 | pip install pre-commit 19 | pre-commit install 20 | ``` 21 | 22 | These hooks automatically check the code's formatting and style, using Black and flake8. 23 | 24 | Make your changes. Make sure to include additional tests covering any new functionality. 25 | 26 | Verify the tests all pass: 27 | 28 | ```bash 29 | pip install pytest 30 | pytest 31 | ``` 32 | 33 | Push your changes back to your fork of the repository: 34 | 35 | ```bash 36 | git push 37 | ``` 38 | 39 | Then open a pull request on GitHub. 40 | -------------------------------------------------------------------------------- /FURTHER-DOCUMENTATION.md: -------------------------------------------------------------------------------- 1 | # Further documentation 2 | 3 | ## Design goals 4 | 5 | `torchtyping` had a few design goals. 6 | 7 | - **Use type annotations.** There are a few other libraries out there that do this via, essentially, syntactic sugar around `assert` statements. I wanted something neater than that. 8 | - **It should be easy to stop using `torchtyping`.** No, really! If it's not for you, then it's easy to remove afterwards. Using `torchtyping` isn't something you should have to bake into your code; just replace `from torchtyping import TensorType` with `TensorType = list` (as a dummy), and your code should still all run. 9 | - **The runtime type checking should be optional.** Runtime checks obviously impose a performance penalty.
Integrating with `typeguard` accomplishes this perfectly, in particular through its option to only activate when running tests (my favourite choice). 10 | - **`torchtyping` should be human-readable.** A big part of using type annotations in Python code is to document, for whoever's reading it, what is expected. (This is particularly valuable on large codebases with several developers.) `torchtyping`'s syntax, and the use of type annotations over some other mechanism, is deliberately chosen to fulfill this goal. 11 | 12 | ## FAQ 13 | 14 | **The runtime checking isn't working!** 15 | 16 | First make sure that you're calling `torchtyping.patch_typeguard`. 17 | 18 | Then make sure that you've enabled `typeguard`, either by decorating the function with `typeguard.typechecked`, or by using `typeguard.importhook.install_import_hook`, or by using the pytest command line flags listed in the main [README](./README.md). 19 | 20 | Make sure that the function you're checking is defined _after_ calling `torchtyping.patch_typeguard`. 21 | 22 | If you have done all of that, then feel free to raise an issue. 23 | 24 | **flake8 is giving spurious warnings.** 25 | 26 | Running flake8 will produce spurious warnings for annotations using strings: `TensorType["batch"]` gives `F821 undefined name 'batch'`. 27 | 28 | You can silence these en masse just by creating a dummy `batch = None` anywhere in the file. (Or by placing `# noqa: F821` on the relevant lines.) 29 | 30 | **Does this work with `mypy`?** 31 | 32 | Mostly. You'll need to tell `mypy` not to think too hard about `torchtyping`, by annotating its import statements with: 33 | 34 | ```python 35 | from torchtyping import TensorType # type: ignore 36 | ``` 37 | 38 | This is because the functionality provided by `torchtyping` is [currently beyond](https://www.python.org/dev/peps/pep-0646/) what `mypy` is capable of representing/understanding. (See also the [links at the end](#other-libraries-and-resources) for further material on this.)
39 | 40 | Additionally, `mypy` has a bug which causes it to crash on any file using the `str: int` or `str: ...` notation, as in `TensorType["batch": 10]`. This can be worked around by making `mypy` skip the file, which can be done by creating a `filename.pyi` stub file in the same directory. See also the corresponding [mypy issue](https://github.com/python/mypy/issues/10266). 41 | 42 | **Are nested annotations of the form `Blahblah[Moreblah[TensorType[...]]]` supported?** 43 | 44 | Yes. 45 | 46 | **Are multiple `...` supported?** 47 | 48 | Yes. For example: 49 | 50 | ```python 51 | def func(x: TensorType["dim1": ..., "dim2": ...], 52 | y: TensorType["dim2": ...] 53 | ) -> TensorType["dim1": ...]: 54 | sum_dims = [-i - 1 for i in range(y.dim())] 55 | return (x * y).sum(dim=sum_dims) 56 | ``` 57 | 58 | **`TensorType[float]` corresponds to `float32` but `torch.rand(2).to(float)` produces `float64`.** 59 | 60 | This is a deliberate asymmetry. `TensorType[float]` corresponds to `torch.get_default_dtype()`, as a convenience, but `.to(float)` always corresponds to `float64`. 61 | 62 | **How do I indicate a scalar tensor, i.e. one with zero dimensions?** 63 | 64 | Use `TensorType[()]`. This can be combined with other arguments as usual, e.g. `TensorType[(), float]`. 65 | 66 | **Support for TensorFlow/JAX/etc.?** 67 | 68 | Not at the moment. The library is called `torchtyping` after all. [There are alternatives for these libraries.](#other-libraries-and-resources) 69 | 70 | ## Custom extensions 71 | 72 | Writing custom extensions is a breeze. Checking extra properties is done by subclassing `torchtyping.TensorDetail`, and passing instances of your `detail` to `torchtyping.TensorType`.
For example, this checks that the tensor has an additional attribute `foo`, which must be a string with value `"good-foo"`: 73 | 74 | ```python 75 | from torch import rand, Tensor 76 | from torchtyping import TensorDetail, TensorType, patch_typeguard 77 | from typeguard import typechecked 78 | patch_typeguard()  # must be called before @typechecked is used 79 | # Write the extension 80 | 81 | class FooDetail(TensorDetail): 82 | def __init__(self, foo): 83 | super().__init__() 84 | self.foo = foo 85 | 86 | def check(self, tensor: Tensor) -> bool: 87 | return hasattr(tensor, "foo") and tensor.foo == self.foo 88 | 89 | # reprs used in error messages when the check fails 90 | 91 | def __repr__(self) -> str: 92 | return f"FooDetail({self.foo})" 93 | 94 | @classmethod 95 | def tensor_repr(cls, tensor: Tensor) -> str: 96 | # Should return a representation of the tensor with respect 97 | # to what this detail is checking 98 | if hasattr(tensor, "foo"): 99 | return f"FooDetail({tensor.foo})" 100 | else: 101 | return "" 102 | 103 | # Test the extension 104 | 105 | @typechecked 106 | def foo_checker(tensor: TensorType[float, FooDetail("good-foo")]): 107 | pass 108 | 109 | 110 | def valid_foo(): 111 | x = rand(3) 112 | x.foo = "good-foo" 113 | foo_checker(x) 114 | 115 | 116 | def invalid_foo_one(): 117 | x = rand(3) 118 | x.foo = "bad-foo" 119 | foo_checker(x)  # fails: the foo attribute has the wrong value 120 | 121 | 122 | def invalid_foo_two(): 123 | x = rand(2).int() 124 | x.foo = "good-foo" 125 | foo_checker(x)  # fails: the dtype is int, not float 126 | ``` 127 | 128 | As you can see, a `detail` must supply three methods. The first is a `check` method, which takes a tensor and checks whether it satisfies the detail. The second is a `__repr__`, which is used in error messages to describe the detail that wasn't satisfied. The third is a `tensor_repr`, which is also used in error messages, to describe what property the tensor had instead of the desired detail. 129 | 130 | ## Other libraries and resources 131 | 132 | `torchtyping` is one amongst a few libraries trying to do this kind of thing.
Here are some links for the curious: 133 | 134 | **Discussion** 135 | - [PEP 646](https://www.python.org/dev/peps/pep-0646/) proposes variadic generics. These are a tool needed for static checkers (like `mypy`) to be able to do the kind of shape checking that `torchtyping` does dynamically. However, at the time of writing, Python doesn't yet support this. (PEP 646 has since been accepted, and is available from Python 3.11.) 136 | - The [Ideas for array shape typing in Python](https://docs.google.com/document/d/1vpMse4c6DrWH5rq2tQSx3qwP_m_0lyn-Ij4WHqQqRHY/) document is a good overview of some of the ways to type check arrays. 137 | 138 | **Other libraries** 139 | - [TensorAnnotations](https://github.com/deepmind/tensor_annotations) is a library for statically checking JAX or TensorFlow tensor shapes. (It also has some good links to other discussions around this topic.) 140 | - [`nptyping`](https://github.com/ramonhagenaars/nptyping) does something very similar to `torchtyping`, but for NumPy. 141 | - [`tsanley`](https://github.com/ofnote/tsanley)/[`tsalib`](https://github.com/ofnote/tsalib) is an alternative dynamic shape checker, but requires a bit of extra setup. 142 | - [TensorGuard](https://github.com/Michedev/tensorguard) is an alternative, using extra function calls rather than type hints. 143 | 144 | ## More Examples 145 | 146 | **Shape checking:** 147 | 148 | ```python 149 | def func(x: TensorType["batch", 5], 150 | y: TensorType["batch", 3]): 151 | # x has shape (batch, 5) 152 | # y has shape (batch, 3) 153 | # batch dimension is the same for both 154 | 155 | def func(x: TensorType[2, -1, -1]): 156 | # x has shape (2, any_one, any_two) 157 | # -1 is a special value to represent any size.
158 | ``` 159 | 160 | **Checking arbitrary numbers of batch dimensions:** 161 | 162 | ```python 163 | def func(x: TensorType[..., 2, 3]): 164 | # x has shape (..., 2, 3) 165 | 166 | def func(x: TensorType[..., 2, "channels"], 167 | y: TensorType[..., "channels"]): 168 | # x has shape (..., 2, channels) 169 | # y has shape (..., channels) 170 | # "channels" is checked to be the same size for both arguments. 171 | 172 | def func(x: TensorType["batch": ..., "channels_x"], 173 | y: TensorType["batch": ..., "channels_y"]): 174 | # x has shape (..., channels_x) 175 | # y has shape (..., channels_y) 176 | # the ... batch dimensions are checked to be of the same size. 177 | ``` 178 | 179 | **Return value checking:** 180 | 181 | ```python 182 | def func(x: TensorType[3, 4]) -> TensorType[()]: 183 | # x has shape (3, 4) 184 | # return has shape () 185 | ``` 186 | 187 | **Dtype checking:** 188 | 189 | ```python 190 | def func(x: TensorType[float]): 191 | # x has the default floating-point dtype (usually torch.float32) 192 | ``` 193 | 194 | **Checking shape and dtype at the same time:** 195 | 196 | ```python 197 | def func(x: TensorType[3, 4, float]): 198 | # x has shape (3, 4) and the default floating-point dtype (usually torch.float32) 199 | ``` 200 | 201 | **Checking names for dimensions as per [named tensors](https://pytorch.org/docs/stable/named_tensor.html):** 202 | 203 | ```python 204 | def func(x: TensorType["a": 3, "b", is_named]): 205 | # x has names ("a", "b") 206 | # x has shape (3, Any) 207 | ``` 208 | 209 | **Checking layouts:** 210 | 211 | ```python 212 | def func(x: TensorType[torch.sparse_coo]): 213 | # x is a sparse tensor with layout sparse_coo 214 | ``` 215 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions.
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include FURTHER-DOCUMENTATION.md 3 | prune test 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Please use jaxtyping instead 2 | 3 | *Welcome! For new projects I now **strongly** recommend using my newer [jaxtyping](https://github.com/google/jaxtyping) project instead. It supports PyTorch, doesn't actually depend on JAX, and unlike TorchTyping it is compatible with static type checkers. The 'jax' in the name is now historical!* 4 | 5 |
6 | 7 | 8 | 9 | The original torchtyping README is as follows. 10 | 11 | --- 12 | 13 | # torchtyping 14 | ## Type annotations for a tensor's shape, dtype, names, ...
15 | 16 | Turn this: 17 | ```python 18 | def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 19 | # x has shape (batch, x_channels) 20 | # y has shape (batch, y_channels) 21 | # return has shape (batch, x_channels, y_channels) 22 | 23 | return x.unsqueeze(-1) * y.unsqueeze(-2) 24 | ``` 25 | into this: 26 | ```python 27 | def batch_outer_product(x: TensorType["batch", "x_channels"], 28 | y: TensorType["batch", "y_channels"] 29 | ) -> TensorType["batch", "x_channels", "y_channels"]: 30 | 31 | return x.unsqueeze(-1) * y.unsqueeze(-2) 32 | ``` 33 | **with programmatic checking that the shape (dtype, ...) specification is met.** 34 | 35 | Bye-bye bugs! Say hello to enforced, clear documentation of your code. 36 | 37 | If (like me) you find yourself littering your code with comments like `# x has shape (batch, hidden_state)` or statements like `assert x.shape == y.shape`, just to keep track of what shape everything is, **then this is for you.** 38 | 39 | --- 40 | 41 | ## Installation 42 | 43 | ```bash 44 | pip install torchtyping 45 | ``` 46 | 47 | Requires Python >=3.7 and PyTorch >=1.7.0. 48 | 49 | If using [`typeguard`](https://github.com/agronholm/typeguard), then it must be a version <3.0.0. 50 | 51 | ## Usage 52 | 53 | `torchtyping` allows for type annotating: 54 | 55 | - **shape**: size, number of dimensions; 56 | - **dtype** (float, integer, etc.); 57 | - **layout** (dense, sparse); 58 | - **names** of dimensions as per [named tensors](https://pytorch.org/docs/stable/named_tensor.html); 59 | - **arbitrary number of batch dimensions** with `...`; 60 | - **...plus anything else you like**, as `torchtyping` is highly extensible. 61 | 62 | If [`typeguard`](https://github.com/agronholm/typeguard) is (optionally) installed, then **at runtime the types can be checked** to ensure that the tensors really are of the advertised shape, dtype, etc.
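To make concrete what this runtime check means for string-named dimensions, here is a minimal, library-free sketch of the idea: shapes are plain tuples rather than tensors, each dimension name is bound to the first size seen, and every later use of that name must agree. The helper `check_named_dims` is hypothetical. It is not part of `torchtyping`, and it makes no claim about how the library is actually implemented.

```python
from typing import Dict, Sequence, Tuple, Union

# Hypothetical helper for illustration only; not part of torchtyping.
# An int fixes a dimension's size (-1 means "any size"); a str names it.
Dim = Union[int, str]


def check_named_dims(
    pairs: Sequence[Tuple[Tuple[int, ...], Tuple[Dim, ...]]],
) -> Dict[str, int]:
    """Check each (shape, spec) pair, binding dimension names consistently."""
    bound: Dict[str, int] = {}
    for shape, spec in pairs:
        if len(shape) != len(spec):
            raise TypeError(f"Expected {len(spec)} dimensions, got {len(shape)}.")
        for size, dim in zip(shape, spec):
            if isinstance(dim, str):
                # The first occurrence of a name binds it; later ones must agree.
                previous = bound.setdefault(dim, size)
                if previous != size:
                    raise TypeError(
                        f"Dimension '{dim}' of inconsistent size. "
                        f"Got both {size} and {previous}."
                    )
            elif dim != -1 and dim != size:
                raise TypeError(f"Expected a dimension of size {dim}, got {size}.")
    return bound


# Two arguments both annotated with the spec ("batch",):
print(check_named_dims([((3,), ("batch",)), ((3,), ("batch",))]))  # {'batch': 3}
try:
    check_named_dims([((3,), ("batch",)), ((1,), ("batch",))])
except TypeError as e:
    print(e)  # Dimension 'batch' of inconsistent size. Got both 1 and 3.
```

Here `-1` plays the same role as in `TensorType` (any size). Features such as `...`, dtypes and named tensors are deliberately left out to keep the sketch short.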
63 | 64 | ```python 65 | # EXAMPLE 66 | 67 | from torch import rand 68 | from torchtyping import TensorType, patch_typeguard 69 | from typeguard import typechecked 70 | 71 | patch_typeguard() # use before @typechecked 72 | 73 | @typechecked 74 | def func(x: TensorType["batch"], 75 | y: TensorType["batch"]) -> TensorType["batch"]: 76 | return x + y 77 | 78 | func(rand(3), rand(3)) # works 79 | func(rand(3), rand(1)) 80 | # TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3. 81 | ``` 82 | 83 | `typeguard` also has an import hook that can be used to automatically test an entire module, without needing to manually add `@typeguard.typechecked` decorators. 84 | 85 | If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes. If you're not already using `typeguard` for your regular Python programming, then strongly consider using it. It's a great way to squash bugs. Both `typeguard` and `torchtyping` also integrate with `pytest`, so if you're concerned about any performance penalty then they can be enabled during tests only. 86 | 87 | ## API 88 | 89 | ```python 90 | torchtyping.TensorType[shape, dtype, layout, details] 91 | ``` 92 | 93 | The core of the library. 94 | 95 | Each of `shape`, `dtype`, `layout`, `details` are optional. 96 | 97 | - The `shape` argument can be any of: 98 | - An `int`: the dimension must be of exactly this size. If it is `-1` then any size is allowed. 99 | - A `str`: the size of the dimension passed at runtime will be bound to this name, and all tensors checked that the sizes are consistent. 100 | - A `...`: An arbitrary number of dimensions of any sizes. 101 | - A `str: int` pair (technically it's a slice), combining both `str` and `int` behaviour. (Just a `str` on its own is equivalent to `str: -1`.) 
102 | - A `str: str` pair, in which case the size of the dimension passed at runtime will be bound to _both_ names, and all dimensions with either name must have the same size. (Some people like to use this as a way to associate multiple names with a dimension, for extra documentation purposes.) 103 | - A `str: ...` pair, in which case the multiple dimensions corresponding to `...` will be bound to the name specified by `str`, and again checked for consistency between arguments. 104 | - `None`, which, when used in conjunction with `is_named` below, indicates a dimension that must _not_ have a name in the sense of [named tensors](https://pytorch.org/docs/stable/named_tensor.html). 105 | - A `None: int` pair, combining both `None` and `int` behaviour. (Just a `None` on its own is equivalent to `None: -1`.) 106 | - A `None: str` pair, combining both `None` and `str` behaviour. (That is, the dimension must not be named, but must be of a size consistent with other uses of the string.) 107 | - A `typing.Any`: Any size is allowed for this dimension (equivalent to `-1`). 108 | - Any tuple of the above. For example, `TensorType["batch": ..., "length": 10, "channels", -1]`. If you just want to specify the number of dimensions, then use, for example, `TensorType[-1, -1, -1]` for a three-dimensional tensor. 109 | - The `dtype` argument can be any of: 110 | - `torch.float32`, `torch.float64`, etc. 111 | - `int`, `bool`, `float`, which are converted to their corresponding PyTorch types. `float` is specifically interpreted as `torch.get_default_dtype()`, which is usually `float32`. 112 | - The `layout` argument can be either `torch.strided` or `torch.sparse_coo`, for dense and sparse tensors respectively. 113 | - The `details` argument offers a way to pass an arbitrary number of additional flags that customise and extend `torchtyping`. Two flags are built-in by default.
`torchtyping.is_named` causes the [names of tensor dimensions](https://pytorch.org/docs/stable/named_tensor.html) to be checked, and `torchtyping.is_float` can be used to check that an arbitrary floating-point type is passed in, rather than just a specific one as with e.g. `TensorType[torch.float32]`. For discussion on how to customise `torchtyping` with your own `details`, see the [further documentation](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md#custom-extensions). 114 | - Check multiple things at once by just putting them all together inside a single `[]`. For example, `TensorType["batch": ..., "length", "channels", float, is_named]`. 115 | 116 | ```python 117 | torchtyping.patch_typeguard() 118 | ``` 119 | 120 | `torchtyping` integrates with `typeguard` to perform runtime type checking. `torchtyping.patch_typeguard()` should be called at the global level, and will patch `typeguard` to check `TensorType`s. 121 | 122 | This function is safe to run multiple times. (It does nothing after the first run.) 123 | 124 | - If using `@typeguard.typechecked`, then `torchtyping.patch_typeguard()` should be called any time before using `@typeguard.typechecked`. For example, you could call it at the start of each file using `torchtyping`. 125 | - If using `typeguard.importhook.install_import_hook`, then `torchtyping.patch_typeguard()` should be called any time before defining the functions you want checked. For example, you could call `torchtyping.patch_typeguard()` just once, at the same time as the `typeguard` import hook. (The order of the hook and the patch doesn't matter.) 126 | - If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes. 127 | 128 | ```bash 129 | pytest --torchtyping-patch-typeguard 130 | ``` 131 | 132 | `torchtyping` offers a `pytest` plugin to automatically run `torchtyping.patch_typeguard()` before your tests.
`pytest` will automatically discover the plugin; you just need to pass the `--torchtyping-patch-typeguard` flag to enable it. Packages can then be passed to `typeguard` as normal, either by using `@typeguard.typechecked`, `typeguard`'s import hook, or the `pytest` flag `--typeguard-packages="your_package_here"`. 133 | 134 | ## Further documentation 135 | 136 | See the [further documentation](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md) for: 137 | 138 | - FAQ; 139 | - Including `flake8` and `mypy` compatibility; 140 | - How to write custom extensions to `torchtyping`; 141 | - Resources and links to other libraries and materials on this topic; 142 | - More examples. 143 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | 4 | [tool.pytest.ini_options] 5 | addopts = "--torchtyping-patch-typeguard" 6 | # No running typeguard unfortunately, because we define a pytest import hook and that means torchtyping gets imported before typeguard gets a chance to run. 7 | # Ironic.
8 | #"--typeguard-packages=torchtyping" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import setuptools 4 | import sys 5 | 6 | here = os.path.realpath(os.path.dirname(__file__)) 7 | 8 | 9 | name = "torchtyping" 10 | 11 | # for simplicity we actually store the version in the __version__ attribute in the 12 | # source 13 | with open(os.path.join(here, name, "__init__.py")) as f: 14 | meta_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) 15 | if meta_match: 16 | version = meta_match.group(1) 17 | else: 18 | raise RuntimeError("Unable to find __version__ string.") 19 | 20 | author = "Patrick Kidger" 21 | 22 | author_email = "contact@kidger.site" 23 | 24 | description = "Runtime type annotations for the shape, dtype etc. of PyTorch Tensors. " 25 | 26 | with open(os.path.join(here, "README.md"), "r", encoding="utf-8") as f: 27 | readme = f.read() 28 | 29 | url = "https://github.com/patrick-kidger/torchtyping" 30 | 31 | license = "Apache-2.0" 32 | 33 | classifiers = [ 34 | "Development Status :: 3 - Alpha", 35 | "Intended Audience :: Developers", 36 | "Intended Audience :: Financial and Insurance Industry", 37 | "Intended Audience :: Information Technology", 38 | "Intended Audience :: Science/Research", 39 | "License :: OSI Approved :: Apache Software License", 40 | "Natural Language :: English", 41 | "Programming Language :: Python :: 3", 42 | "Programming Language :: Python :: 3.7", 43 | "Programming Language :: Python :: 3.8", 44 | "Programming Language :: Python :: 3.9", 45 | "Programming Language :: Python :: Implementation :: CPython", 46 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 47 | "Topic :: Scientific/Engineering :: Information Analysis", 48 | "Topic :: Scientific/Engineering :: Mathematics", 49 | "Framework :: Pytest", 50 | ] 51 | 52 | user_python_version = 
sys.version_info 53 | 54 | python_requires = ">=3.7.0" 55 | 56 | install_requires = ["torch>=1.7.0", "typeguard>=2.11.1,<3"] 57 | 58 | if user_python_version < (3, 9): 59 | install_requires += ["typing_extensions==3.7.4.3"] 60 | 61 | entry_points = dict(pytest11=["torchtyping = torchtyping.pytest_plugin"]) 62 | 63 | setuptools.setup( 64 | name=name, 65 | version=version, 66 | author=author, 67 | author_email=author_email, 68 | maintainer=author, 69 | maintainer_email=author_email, 70 | description=description, 71 | long_description=readme, 72 | long_description_content_type="text/markdown", 73 | url=url, 74 | license=license, 75 | classifiers=classifiers, 76 | zip_safe=False, 77 | python_requires=python_requires, 78 | install_requires=install_requires, 79 | entry_points=entry_points, 80 | packages=[name], 81 | ) 82 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | 4 | 5 | with warnings.catch_warnings(): 6 | warnings.simplefilter("ignore") 7 | torch.rand(2, names=("a",)) 8 | -------------------------------------------------------------------------------- /test/test_consistency.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from torch import rand 3 | from torchtyping import TensorType 4 | from typeguard import typechecked 5 | 6 | 7 | x = y = None 8 | 9 | 10 | def test_single(): 11 | @typechecked 12 | def func1(x: TensorType["x"], y: TensorType["x"]): 13 | pass 14 | 15 | @typechecked 16 | def func2(x: TensorType["x"], y: TensorType["x"]) -> TensorType["x"]: 17 | return x + y 18 | 19 | @typechecked 20 | def func3(x: TensorType["x"], y: TensorType["x"]) -> TensorType["x", "x"]: 21 | return x + y 22 | 23 | @typechecked 24 | def func4(x: TensorType["x"], y: TensorType["x"]) -> TensorType["x", "x"]: 25 | return x.unsqueeze(0) + y.unsqueeze(1) 
26 | 27 | @typechecked 28 | def func5(x: TensorType["x"], y: TensorType["x"]) -> TensorType["x", "y"]: 29 | return x 30 | 31 | @typechecked 32 | def func6(x: TensorType["x"], y: TensorType["x"]) -> TensorType["y", "x"]: 33 | return x 34 | 35 | @typechecked 36 | def func7(x: TensorType["x"]) -> TensorType["x"]: 37 | assert x.shape != (1,) 38 | return rand((1,)) 39 | 40 | func1(rand(2), rand(2)) 41 | func2(rand(2), rand(2)) 42 | with pytest.raises(TypeError): 43 | func3(rand(2), rand(2)) 44 | func4(rand(2), rand(2)) 45 | with pytest.raises(TypeError): 46 | func5(rand(2), rand(2)) 47 | with pytest.raises(TypeError): 48 | func6(rand(2), rand(2)) 49 | with pytest.raises(TypeError): 50 | func7(rand(3)) 51 | 52 | 53 | def test_multiple(): 54 | # Fun fact, this "wrong" func0 is actually a mistype of func1, that torchtyping 55 | # caught for me when I ran the tests! 56 | @typechecked 57 | def func0(x: TensorType["x"], y: TensorType["y"]) -> TensorType["x", "y"]: 58 | return x.unsqueeze(0) + y.unsqueeze(1) 59 | 60 | @typechecked 61 | def func1(x: TensorType["x"], y: TensorType["y"]) -> TensorType["x", "y"]: 62 | return x.unsqueeze(1) + y.unsqueeze(0) 63 | 64 | @typechecked 65 | def func2(x: TensorType["x", "x"]): 66 | pass 67 | 68 | @typechecked 69 | def func3(x: TensorType["x", "x", "x"]): 70 | pass 71 | 72 | @typechecked 73 | def func4(x: TensorType["x"], y: TensorType["x", "y"]): 74 | pass 75 | 76 | @typechecked 77 | def func5(x: TensorType["x", "y"], y: TensorType["y", "x"]): 78 | pass 79 | 80 | @typechecked 81 | def func6(x: TensorType["x"], y: TensorType["y"]) -> TensorType["x", "y"]: 82 | assert not (x.shape == (2,) and y.shape == (3,)) 83 | return rand(2, 3) 84 | 85 | func0(rand(2), rand(2)) # can't catch this 86 | with pytest.raises(TypeError): 87 | func0(rand(2), rand(3)) 88 | with pytest.raises(TypeError): 89 | func0(rand(10), rand(0)) 90 | 91 | func1(rand(2), rand(2)) 92 | func1(rand(2), rand(3)) 93 | func1(rand(10), rand(0)) 94 | 95 | func2(rand(0, 0)) 96 | 
func2(rand(2, 2)) 97 | func2(rand(9, 9)) 98 | with pytest.raises(TypeError): 99 | func2(rand(0, 4)) 100 | func2(rand(1, 4)) 101 | func2(rand(3, 4)) 102 | 103 | func3(rand(0, 0, 0)) 104 | func3(rand(2, 2, 2)) 105 | func3(rand(9, 9, 9)) 106 | with pytest.raises(TypeError): 107 | func3(rand(0, 4, 4)) 108 | func3(rand(1, 4, 4)) 109 | func3(rand(3, 3, 4)) 110 | 111 | func4(rand(3), rand(3, 4)) 112 | with pytest.raises(TypeError): 113 | func4(rand(3), rand(4, 3)) 114 | 115 | func5(rand(2, 3), rand(3, 2)) 116 | func5(rand(0, 5), rand(5, 0)) 117 | func5(rand(2, 2), rand(2, 2)) 118 | with pytest.raises(TypeError): 119 | func5(rand(2, 3), rand(2, 3)) 120 | func5(rand(2, 3), rand(2, 2)) 121 | 122 | with pytest.raises(TypeError): 123 | func6(rand(5), rand(3)) 124 | -------------------------------------------------------------------------------- /test/test_details.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torchtyping import TensorType, is_float, is_named 4 | import typeguard 5 | 6 | 7 | dim1 = dim2 = dim3 = None 8 | 9 | 10 | def test_float_tensor(): 11 | @typeguard.typechecked 12 | def func1(x: TensorType[is_float]): 13 | pass 14 | 15 | @typeguard.typechecked 16 | def func2(x: TensorType[2, 2, is_float]): 17 | pass 18 | 19 | @typeguard.typechecked 20 | def func3(x: TensorType[float, is_float]): 21 | pass 22 | 23 | @typeguard.typechecked 24 | def func4(x: TensorType[bool, is_float]): 25 | pass 26 | 27 | @typeguard.typechecked 28 | def func5(x: TensorType["dim1":2, 2, float, is_float]): 29 | pass 30 | 31 | @typeguard.typechecked 32 | def func6(x: TensorType[2, "dim2":2, torch.sparse_coo, is_float]): 33 | pass 34 | 35 | x = torch.rand(2, 2) 36 | y = torch.rand(1) 37 | z = torch.tensor([[0, 1], [2, 3]]) 38 | w = torch.rand(4).to_sparse() 39 | w1 = torch.rand(2, 2).to_sparse() 40 | w2 = torch.tensor([[0, 1], [2, 3]]).to_sparse() 41 | 42 | func1(x) 43 | func1(y) 44 | with pytest.raises(TypeError): 
45 | func1(z) 46 | func1(w) 47 | func1(w1) 48 | with pytest.raises(TypeError): 49 | func1(w2) 50 | 51 | func2(x) 52 | with pytest.raises(TypeError): 53 | func2(y) 54 | with pytest.raises(TypeError): 55 | func2(z) 56 | with pytest.raises(TypeError): 57 | func2(w) 58 | func2(w1) 59 | with pytest.raises(TypeError): 60 | func2(w2) 61 | 62 | func3(x) 63 | func3(y) 64 | with pytest.raises(TypeError): 65 | func3(z) 66 | func3(w) 67 | func3(w1) 68 | with pytest.raises(TypeError): 69 | func3(w2) 70 | 71 | with pytest.raises(TypeError): 72 | func4(x) 73 | with pytest.raises(TypeError): 74 | func4(y) 75 | with pytest.raises(TypeError): 76 | func4(z) 77 | with pytest.raises(TypeError): 78 | func4(w) 79 | with pytest.raises(TypeError): 80 | func4(w1) 81 | with pytest.raises(TypeError): 82 | func4(w2) 83 | 84 | func5(x) 85 | with pytest.raises(TypeError): 86 | func5(y) 87 | with pytest.raises(TypeError): 88 | func5(z) 89 | with pytest.raises(TypeError): 90 | func5(w) 91 | func5(w1) 92 | with pytest.raises(TypeError): 93 | func5(w2) 94 | 95 | with pytest.raises(TypeError): 96 | func6(x) 97 | with pytest.raises(TypeError): 98 | func6(y) 99 | with pytest.raises(TypeError): 100 | func6(z) 101 | with pytest.raises(TypeError): 102 | func6(w) 103 | func6(w1) 104 | with pytest.raises(TypeError): 105 | func6(w2) 106 | 107 | 108 | def test_named_tensor(): 109 | @typeguard.typechecked 110 | def _named_a_dim_checker(x: TensorType["dim1", is_named]): 111 | pass 112 | 113 | @typeguard.typechecked 114 | def _named_ab_dim_checker(x: TensorType["dim1", "dim2", is_named]): 115 | pass 116 | 117 | @typeguard.typechecked 118 | def _named_abc_dim_checker(x: TensorType["dim1", "dim2", "dim3", is_named]): 119 | pass 120 | 121 | @typeguard.typechecked 122 | def _named_cb_dim_checker(x: TensorType["dim3", "dim2", is_named]): 123 | pass 124 | 125 | @typeguard.typechecked 126 | def _named_am1_dim_checker(x: TensorType["dim1", -1, is_named]): 127 | pass 128 | 129 | @typeguard.typechecked 130 | def 
_named_m1b_dim_checker(x: TensorType[-1, "dim2", is_named]): 131 | pass 132 | 133 | @typeguard.typechecked 134 | def _named_abm1_dim_checker(x: TensorType["dim1", "dim2", -1, is_named]): 135 | pass 136 | 137 | @typeguard.typechecked 138 | def _named_m1bm1_dim_checker(x: TensorType[-1, "dim2", -1, is_named]): 139 | pass 140 | 141 | x = torch.rand(3, 4) 142 | named_x = torch.rand(3, 4, names=("dim1", "dim2")) 143 | 144 | with pytest.raises(TypeError): 145 | _named_ab_dim_checker(x) 146 | with pytest.raises(TypeError): 147 | _named_cb_dim_checker(x) 148 | with pytest.raises(TypeError): 149 | _named_am1_dim_checker(x) 150 | with pytest.raises(TypeError): 151 | _named_m1b_dim_checker(x) 152 | with pytest.raises(TypeError): 153 | _named_a_dim_checker(x) 154 | with pytest.raises(TypeError): 155 | _named_abc_dim_checker(x) 156 | with pytest.raises(TypeError): 157 | _named_abm1_dim_checker(x) 158 | with pytest.raises(TypeError): 159 | _named_m1bm1_dim_checker(x) 160 | 161 | _named_ab_dim_checker(named_x) 162 | _named_am1_dim_checker(named_x) 163 | _named_m1b_dim_checker(named_x) 164 | with pytest.raises(TypeError): 165 | _named_a_dim_checker(named_x) 166 | with pytest.raises(TypeError): 167 | _named_abc_dim_checker(named_x) 168 | with pytest.raises(TypeError): 169 | _named_cb_dim_checker(named_x) 170 | with pytest.raises(TypeError): 171 | _named_abm1_dim_checker(named_x) 172 | with pytest.raises(TypeError): 173 | _named_m1bm1_dim_checker(named_x) 174 | 175 | 176 | def test_named_float_tensor(): 177 | @typeguard.typechecked 178 | def func(x: TensorType["dim1", "dim2":3, is_float, is_named]): 179 | pass 180 | 181 | x = torch.rand(2, 3, names=("dim1", "dim2")) 182 | y = torch.rand(2, 2, names=("dim1", "dim2")) 183 | z = torch.rand(2, 2, names=("dim1", "dim3")) 184 | w = torch.rand(2, 3) 185 | w1 = torch.rand(2, 2, names=("dim1", None)) 186 | w2 = torch.rand(2, 3, names=("dim1", "dim2")).int() 187 | 188 | func(x) 189 | with pytest.raises(TypeError): 190 | func(y) 191 | with 
pytest.raises(TypeError): 192 | func(z) 193 | with pytest.raises(TypeError): 194 | func(w) 195 | with pytest.raises(TypeError): 196 | func(w1) 197 | with pytest.raises(TypeError): 198 | func(w2) 199 | 200 | 201 | def test_none_names(): 202 | @typeguard.typechecked 203 | def func_unnamed1(x: TensorType[None:4]): 204 | pass 205 | 206 | @typeguard.typechecked 207 | def func_unnamed2(x: TensorType[None:4, "dim1"]): 208 | pass 209 | 210 | @typeguard.typechecked 211 | def func_unnamed3(x: TensorType[None:4, "dim1"], y: TensorType["dim1", None]): 212 | pass 213 | 214 | @typeguard.typechecked 215 | def func_named1(x: TensorType[None:4, is_named]): 216 | pass 217 | 218 | @typeguard.typechecked 219 | def func_named2(x: TensorType[None:4, "dim1", is_named]): 220 | pass 221 | 222 | func_unnamed1(torch.rand(4)) 223 | func_unnamed1(torch.rand(4, names=(None,))) 224 | func_unnamed1(torch.rand(4, names=("not_none",))) 225 | with pytest.raises(TypeError): 226 | func_unnamed1(torch.rand(5)) 227 | with pytest.raises(TypeError): 228 | func_unnamed1(torch.rand(5, names=(None,))) 229 | with pytest.raises(TypeError): 230 | func_unnamed1(torch.rand(5, names=("not_none",))) 231 | with pytest.raises(TypeError): 232 | func_unnamed1(torch.rand(2, 3)) 233 | with pytest.raises(TypeError): 234 | func_unnamed1(torch.rand((), names=())) 235 | with pytest.raises(TypeError): 236 | func_unnamed1(torch.rand(1, 6, 7, 8, names=("not_none", None, None, None))) 237 | 238 | func_unnamed2(torch.rand(4, 5)) 239 | func_unnamed2(torch.rand(4, 5, names=(None, None))) 240 | func_unnamed2(torch.rand(4, 5, names=("dim1", None))) 241 | func_unnamed2(torch.rand(4, 5, names=(None, "dim1"))) 242 | 243 | func_unnamed3(torch.rand(4, 5), torch.rand(5, 3)) 244 | func_unnamed3(torch.rand(4, 5, names=(None, None)), torch.rand(5, 3)) 245 | func_unnamed3(torch.rand(4, 5), torch.rand(5, 3, names=("dim1", None))) 246 | func_unnamed3( 247 | torch.rand(4, 5, names=("another_name", "some_name")), 248 | torch.rand(5, 3, 
names=(None, "dim1")), 249 | ) 250 | with pytest.raises(TypeError): 251 | func_unnamed3(torch.rand(4, 5), torch.rand(3, 3)) 252 | 253 | func_named1(torch.rand(4)) 254 | func_named1(torch.rand(4, names=(None,))) 255 | with pytest.raises(TypeError): 256 | func_named1(torch.rand(5, names=(None,))) 257 | with pytest.raises(TypeError): 258 | func_named1(torch.rand(4, names=("dim1",))) 259 | with pytest.raises(TypeError): 260 | func_named1(torch.rand(5, names=("dim1",))) 261 | 262 | func_named2(torch.rand(4, 5, names=(None, "dim1"))) 263 | with pytest.raises(TypeError): 264 | func_named2(torch.rand(4, 5)) 265 | with pytest.raises(TypeError): 266 | func_named2(torch.rand(4, 5, names=("another_dim", "dim1"))) 267 | with pytest.raises(TypeError): 268 | func_named2(torch.rand(4, 5, names=(None, "dim2"))) 269 | with pytest.raises(TypeError): 270 | func_named2(torch.rand(4, 5, names=(None, None))) 271 | 272 | 273 | def test_named_ellipsis(): 274 | @typeguard.typechecked 275 | def func(x: TensorType["dim1":..., "dim2", is_named]): 276 | pass 277 | 278 | func(torch.rand(3, 4, names=(None, "dim2"))) 279 | func(torch.rand(3, 4, names=("another_dim", "dim2"))) 280 | with pytest.raises(TypeError): 281 | func(torch.rand(3, 4)) 282 | with pytest.raises(TypeError): 283 | func(torch.rand(3, 4, names=(None, None))) 284 | with pytest.raises(TypeError): 285 | func(torch.rand(3, 4, names=("dim2", None))) 286 | -------------------------------------------------------------------------------- /test/test_dtype_layout.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torchtyping import TensorType 4 | import typeguard 5 | 6 | from typing import Union 7 | 8 | 9 | @typeguard.typechecked 10 | def _float_checker(x: TensorType[float]): 11 | pass 12 | 13 | 14 | @typeguard.typechecked 15 | def _int_checker(x: TensorType[int]): 16 | pass 17 | 18 | 19 | @typeguard.typechecked 20 | def _union_int_float_checker(x: 
Union[TensorType[float], TensorType[int]]): 21 | pass 22 | 23 | 24 | def test_float_dtype(): 25 | x = torch.rand(2) 26 | _float_checker(x) 27 | _union_int_float_checker(x) 28 | with pytest.raises(TypeError): 29 | _int_checker(x) 30 | 31 | 32 | def test_int_dtype(): 33 | x = torch.tensor(2) 34 | _int_checker(x) 35 | _union_int_float_checker(x) 36 | with pytest.raises(TypeError): 37 | _float_checker(x) 38 | 39 | 40 | @typeguard.typechecked 41 | def _strided_checker(x: TensorType[torch.strided]): 42 | pass 43 | 44 | 45 | @typeguard.typechecked 46 | def _sparse_coo_checker(x: TensorType[torch.sparse_coo]): 47 | pass 48 | 49 | 50 | def test_strided_layout(): 51 | x = torch.rand(2) 52 | _strided_checker(x) 53 | with pytest.raises(TypeError): 54 | _sparse_coo_checker(x) 55 | 56 | 57 | def test_sparse_coo_layout(): 58 | x = torch.rand(2).to_sparse() 59 | _sparse_coo_checker(x) 60 | with pytest.raises(TypeError): 61 | _strided_checker(x) 62 | -------------------------------------------------------------------------------- /test/test_ellipsis.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torchtyping import TensorType 4 | from typeguard import typechecked 5 | 6 | 7 | dim1 = dim2 = dim3 = channel = None 8 | 9 | 10 | def test_basic_ellipsis(): 11 | @typechecked 12 | def func(x: TensorType["dim1":..., "dim2", "dim3"], y: TensorType["dim2", "dim3"]): 13 | pass 14 | 15 | func(torch.rand(2, 2), torch.rand(2, 2)) 16 | func(torch.rand(2, 3), torch.rand(2, 3)) 17 | func(torch.rand(1, 4, 2, 3), torch.rand(2, 3)) 18 | func(torch.rand(2, 3, 2, 3), torch.rand(2, 3)) 19 | 20 | with pytest.raises(TypeError): 21 | func(torch.rand(2, 3), torch.rand(3, 2)) 22 | with pytest.raises(TypeError): 23 | func(torch.rand(1, 4, 2, 3), torch.rand(3, 2)) 24 | with pytest.raises(TypeError): 25 | func(torch.rand(2, 3, 2, 3), torch.rand(3, 2)) 26 | 27 | 28 | def test_zero_size_ellipsis1(): 29 | @typechecked 30 | def func( 31 | 
x: TensorType["dim1":..., "dim2", "dim3"], 32 | y: TensorType["dim1":..., "dim2", "dim3"], 33 | ): 34 | pass 35 | 36 | with pytest.raises(TypeError): 37 | func(torch.rand(2, 2), torch.rand(2, 2, 2)) 38 | with pytest.raises(TypeError): 39 | func(torch.rand(2, 2, 2), torch.rand(2, 2)) 40 | 41 | 42 | def test_zero_size_ellipsis2(): 43 | @typechecked 44 | def func( 45 | x: TensorType["dim1":..., "dim2", "dim3"], 46 | y: TensorType["dim1":..., "dim2", "dim3"], 47 | ): 48 | pass 49 | 50 | with pytest.raises(TypeError): 51 | func(torch.rand(2, 3), torch.rand(2, 2, 3)) 52 | with pytest.raises(TypeError): 53 | func(torch.rand(2, 2, 3), torch.rand(2, 3)) 54 | with pytest.raises(TypeError): 55 | func(torch.rand(2, 2), torch.rand(2, 2, 2)) 56 | with pytest.raises(TypeError): 57 | func(torch.rand(2, 2, 2), torch.rand(2, 2)) 58 | 59 | 60 | def test_multiple_ellipsis1(): 61 | @typechecked 62 | def func( 63 | x: TensorType["dim1":..., "dim2":...], y: TensorType["dim2":...] 64 | ) -> TensorType["dim1":...]: 65 | sum_dims = [-i - 1 for i in range(y.dim())] 66 | return (x * y).sum(dim=sum_dims) 67 | 68 | func(torch.rand(1, 2), torch.rand(2)) 69 | func(torch.rand(3, 4, 5, 9), torch.rand(5, 9)) 70 | func(torch.rand(3, 4, 11, 5, 9), torch.rand(5, 9)) 71 | func(torch.rand(3, 4, 11, 5, 9), torch.rand(11, 5, 9)) 72 | with pytest.raises(TypeError): 73 | func(torch.rand(1), torch.rand(2)) 74 | with pytest.raises(TypeError): 75 | func(torch.rand(1, 3, 5), torch.rand(3)) 76 | with pytest.raises(TypeError): 77 | func(torch.rand(1, 4), torch.rand(1, 1, 4)) 78 | 79 | 80 | def test_multiple_ellipsis2(): 81 | @typechecked 82 | def func( 83 | x: TensorType["dim2":...], y: TensorType["dim1":..., "dim2":...] 
84 | ) -> TensorType["dim1":...]: 85 | sum_dims = [-i - 1 for i in range(x.dim())] 86 | return (x * y).sum(dim=sum_dims) 87 | 88 | with pytest.raises(TypeError): 89 | func(torch.rand(1, 1, 4), torch.rand(1, 4)) 90 | 91 | 92 | def test_multiple_ellipsis3(): 93 | @typechecked 94 | def func( 95 | x: TensorType["dim1":..., "dim2":..., "dim3":...], 96 | y: TensorType["dim2":..., "dim3":...], 97 | z: TensorType["dim2":...], 98 | ) -> TensorType["dim1":...]: 99 | num2 = y.dim() - z.dim() 100 | num3 = z.dim() 101 | for _ in range(num2): 102 | z = z.unsqueeze(-1) 103 | y = y * z 104 | x = x + y 105 | for _ in range(num2 + num3): 106 | x = x.sum(dim=-1) 107 | return x 108 | 109 | func(torch.rand(1, 2, 3), torch.rand(2, 3), torch.rand(2)) 110 | func(torch.rand(3, 5, 6, 7, 8, 0), torch.rand(5, 6, 7, 8, 0), torch.rand(5, 6, 7)) 111 | func(torch.rand(3, 5, 6, 7, 8, 9), torch.rand(5, 6, 7, 8, 9), torch.rand(5, 6, 7)) 112 | 113 | 114 | def test_repeat_ellipsis1(): 115 | @typechecked 116 | def func(x: TensorType["dim1":..., "dim1":...], y: TensorType["dim1":...]): 117 | pass 118 | 119 | func(torch.rand(3, 4, 3, 4), torch.rand(3, 4)) 120 | func(torch.rand(5, 5), torch.rand(5)) 121 | with pytest.raises(TypeError): 122 | func(torch.rand(7, 9), torch.rand(7)) 123 | with pytest.raises(TypeError): 124 | func(torch.rand(7, 4, 9, 4), torch.rand(7, 4)) 125 | with pytest.raises(TypeError): 126 | func(torch.rand(7, 9), torch.rand(9)) 127 | with pytest.raises(TypeError): 128 | func(torch.rand(3, 7, 3, 9), torch.rand(3, 9)) 129 | with pytest.raises(TypeError): 130 | func(torch.rand(7, 3, 3, 9), torch.rand(3, 9)) 131 | with pytest.raises(TypeError): 132 | func(torch.rand(7, 7), torch.rand(3)) 133 | 134 | 135 | def test_repeat_ellipsis2(): 136 | @typechecked 137 | def func( 138 | x: TensorType["dim1":..., "dim1":...], 139 | y: TensorType["dim1":..., "dim2":...], 140 | z: TensorType["dim2":...], 141 | ): 142 | pass 143 | 144 | func(torch.rand(4, 4), torch.rand(4, 5), torch.rand(5)) 145 | 
func(torch.rand(3, 4, 3, 4), torch.rand(3, 4, 5), torch.rand(5)) 146 | func(torch.rand(2, 3, 4, 2, 3, 4), torch.rand(2, 3, 4, 5, 6), torch.rand(5, 6)) 147 | with pytest.raises(TypeError): 148 | func(torch.rand(2, 3, 4, 2, 3), torch.rand(2, 3, 4, 5, 6), torch.rand(5, 6)) 149 | with pytest.raises(TypeError): 150 | func(torch.rand(2, 3, 4, 2, 3), torch.rand(2, 3, 4, 6), torch.rand(3, 4)) 151 | 152 | 153 | def test_ambiguous_ellipsis(): 154 | @typechecked 155 | def func1(x: TensorType["dim1":..., "dim2":...]): 156 | pass 157 | 158 | with pytest.raises(TypeError): 159 | func1(torch.rand(2, 2)) 160 | 161 | @typechecked 162 | def func2(x: TensorType["dim1":..., "dim1":...]): 163 | pass 164 | 165 | with pytest.raises(TypeError): 166 | func2(torch.rand(2, 2)) 167 | -------------------------------------------------------------------------------- /test/test_examples.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from torch import ones, rand, sparse_coo, tensor 4 | from torchtyping import TensorType, is_named 5 | from typeguard import typechecked 6 | 7 | 8 | # make flake8 happy 9 | batch = x_channels = y_channels = a = b = channels = channels_x = channels_y = None 10 | annotator = word = feature = predicate = argument = None 11 | 12 | 13 | def test_example0(): 14 | @typechecked 15 | def batch_outer_product( 16 | x: TensorType["batch", "x_channels"], y: TensorType["batch", "y_channels"] 17 | ) -> TensorType["batch", "x_channels", "y_channels"]: 18 | 19 | return x.unsqueeze(-1) * y.unsqueeze(-2) 20 | 21 | batch_outer_product(rand(2, 3), rand(2, 4)) 22 | batch_outer_product(rand(5, 2), rand(5, 2)) 23 | with pytest.raises(TypeError): 24 | batch_outer_product(rand(3, 2), rand(2, 3)) 25 | with pytest.raises(TypeError): 26 | batch_outer_product(rand(1, 2, 3), rand(2, 3)) 27 | 28 | 29 | def test_example1(): 30 | @typechecked 31 | def func(x: TensorType["batch"], y: TensorType["batch"]) -> TensorType["batch"]: 32 | return 
x + y 33 | 34 | func(rand(3), rand(3)) 35 | with pytest.raises(TypeError): 36 | func(rand(3), rand(1)) 37 | 38 | 39 | def test_example2(): 40 | @typechecked 41 | def func(x: TensorType["batch", 5], y: TensorType["batch", 3]): 42 | pass 43 | 44 | func(rand(3, 5), rand(3, 3)) 45 | func(rand(7, 5), rand(7, 3)) 46 | with pytest.raises(TypeError): 47 | func(rand(4, 5), rand(3, 5)) 48 | with pytest.raises(TypeError): 49 | func(rand(1, 3, 5), rand(3, 5)) 50 | with pytest.raises(TypeError): 51 | func(rand(3, 4), rand(3, 3)) 52 | with pytest.raises(TypeError): 53 | func(rand(1, 3, 5), rand(1, 3, 3)) 54 | 55 | 56 | def test_example3(): 57 | @typechecked 58 | def func(x: TensorType[2, -1, -1]): 59 | pass 60 | 61 | func(rand(2, 1, 1)) 62 | func(rand(2, 2, 1)) 63 | func(rand(2, 10, 1)) 64 | func(rand(2, 1, 10)) 65 | with pytest.raises(TypeError): 66 | func(rand(1, 2, 1, 1)) 67 | with pytest.raises(TypeError): 68 | func(rand(2, 1)) 69 | with pytest.raises(TypeError): 70 | func( 71 | rand( 72 | 2, 73 | ) 74 | ) 75 | with pytest.raises(TypeError): 76 | func(rand(4, 2, 2)) 77 | 78 | 79 | def test_example4(): 80 | @typechecked 81 | def func(x: TensorType[..., 2, 3]): 82 | pass 83 | 84 | func(rand(2, 3)) 85 | func(rand(1, 2, 3)) 86 | func(rand(2, 2, 3)) 87 | func(rand(3, 3, 5, 2, 3)) 88 | with pytest.raises(TypeError): 89 | func(rand(1, 3)) 90 | with pytest.raises(TypeError): 91 | func(rand(2, 1)) 92 | with pytest.raises(TypeError): 93 | func(rand(1, 1, 3)) 94 | with pytest.raises(TypeError): 95 | func(rand(2, 3, 3)) 96 | with pytest.raises(TypeError): 97 | func(rand(3)) 98 | with pytest.raises(TypeError): 99 | func(rand(2)) 100 | 101 | 102 | def test_example5(): 103 | @typechecked 104 | def func(x: TensorType[..., 2, "channels"], y: TensorType[..., "channels"]): 105 | pass 106 | 107 | func(rand(1, 2, 3), rand(3)) 108 | func(rand(1, 2, 3), rand(1, 3)) 109 | func(rand(2, 3), rand(2, 3)) 110 | with pytest.raises(TypeError): 111 | func(rand(2, 2, 2, 2), rand(2, 4)) 112 | with 
pytest.raises(TypeError): 113 | func(rand(3, 2), rand(2)) 114 | with pytest.raises(TypeError): 115 | func(rand(5, 2, 1), rand(2)) 116 | 117 | 118 | def test_example6(): 119 | @typechecked 120 | def func( 121 | x: TensorType["batch":..., "channels_x"], 122 | y: TensorType["batch":..., "channels_y"], 123 | ): 124 | pass 125 | 126 | func(rand(3, 3, 3), rand(3, 3, 4)) 127 | func(rand(1, 5, 6, 7), rand(1, 5, 6, 1)) 128 | with pytest.raises(TypeError): 129 | func(rand(2, 2, 2), rand(2, 1, 2)) 130 | with pytest.raises(TypeError): 131 | func(rand(4, 2, 2), rand(2, 2, 2)) 132 | with pytest.raises(TypeError): 133 | func(rand(2, 2, 2), rand(2, 1, 4)) 134 | with pytest.raises(TypeError): 135 | func(rand(2, 2), rand(2, 1, 2)) 136 | with pytest.raises(TypeError): 137 | func(rand(2, 2), rand(1, 2, 2)) 138 | with pytest.raises(TypeError): 139 | func(rand(4, 2), rand(3, 2)) 140 | 141 | 142 | def test_example7(): 143 | @typechecked 144 | def func(x: TensorType[3, 4]) -> TensorType[()]: 145 | return rand(()) 146 | 147 | func(rand(3, 4)) 148 | with pytest.raises(TypeError): 149 | func(rand(2, 4)) 150 | 151 | @typechecked 152 | def func2(x: TensorType[3, 4]) -> TensorType[()]: 153 | return rand((1,)) 154 | 155 | with pytest.raises(TypeError): 156 | func2(rand(3, 4)) 157 | with pytest.raises(TypeError): 158 | func2(rand(2, 4)) 159 | 160 | 161 | def test_example8(): 162 | @typechecked 163 | def func(x: TensorType[float]): 164 | pass 165 | 166 | func(rand(2, 3)) 167 | func(rand(1)) 168 | func(rand(())) 169 | with pytest.raises(TypeError): 170 | func(tensor(1)) 171 | with pytest.raises(TypeError): 172 | func(tensor([1, 2])) 173 | with pytest.raises(TypeError): 174 | func(tensor([[1, 1], [2, 2]])) 175 | with pytest.raises(TypeError): 176 | func(tensor(True)) 177 | 178 | 179 | def test_example9(): 180 | @typechecked 181 | def func(x: TensorType[3, 4, float]): 182 | pass 183 | 184 | func(rand(3, 4)) 185 | with pytest.raises(TypeError): 186 | func(rand(3, 4).long()) 187 | with 
pytest.raises(TypeError): 188 | func(rand(2, 3)) 189 | 190 | 191 | def test_example10(): 192 | @typechecked 193 | def func(x: TensorType["a":3, "b", is_named]): 194 | pass 195 | 196 | func(rand(3, 4, names=("a", "b"))) 197 | with pytest.raises(TypeError): 198 | func(rand(3, 4, names=("a", "c"))) 199 | with pytest.raises(TypeError): 200 | func(rand(3, 3, 3, names=(None, "a", "b"))) 201 | with pytest.raises(TypeError): 202 | func(rand(3, 3, 3, names=("a", None, "b"))) 203 | with pytest.raises(TypeError): 204 | func(rand(3, 3, 3, names=("a", "b", None))) 205 | with pytest.raises(TypeError): 206 | func(rand(3, 4, names=(None, "b"))) 207 | with pytest.raises(TypeError): 208 | func(rand(3, 4, names=("a", None))) 209 | with pytest.raises(TypeError): 210 | func(rand(3, 4)) 211 | with pytest.raises(TypeError): 212 | func(rand(3, 4).long()) 213 | with pytest.raises(TypeError): 214 | func(rand(3)) 215 | 216 | 217 | def test_example11(): 218 | @typechecked 219 | def func(x: TensorType[sparse_coo]): 220 | pass 221 | 222 | func(rand(3, 4).to_sparse()) 223 | with pytest.raises(TypeError): 224 | func(rand(3, 4)) 225 | with pytest.raises(TypeError): 226 | func(rand(3, 4).long()) 227 | 228 | 229 | def test_example12(): 230 | @typechecked 231 | def func( 232 | feats: TensorType["batch":..., "annotator":3, "word", "feature"], 233 | predicates: TensorType[ 234 | "batch":..., "annotator":3, "predicate":"word", "feature" 235 | ], 236 | pred_arg_pairs: TensorType[ 237 | "batch":..., "annotator":3, "predicate":"word", "argument":"word" 238 | ], 239 | ): 240 | pass 241 | 242 | func(ones(2, 1, 3, 4, 5), ones(2, 1, 3, 4, 5), ones(2, 1, 3, 4, 4)) 243 | 244 | # matches ...
    with pytest.raises(TypeError):
        func(ones(2, 1, 3, 4, 5), ones(2, 3, 4, 5), ones(2, 1, 3, 4, 4))
    with pytest.raises(TypeError):
        func(ones(2, 1, 3, 4, 5), ones(2, 2, 3, 4, 5), ones(2, 1, 3, 4, 4))

    # annotator must have size 3
    with pytest.raises(TypeError):
        func(ones(2, 1, 2, 4, 5), ones(2, 1, 2, 4, 5), ones(2, 1, 2, 4, 4))

    # predicate and argument must have the same size as word
    with pytest.raises(TypeError):
        func(ones(2, 1, 3, 3, 5), ones(2, 1, 3, 4, 5), ones(2, 1, 3, 4, 4))
    with pytest.raises(TypeError):
        func(ones(2, 1, 3, 4, 5), ones(2, 1, 3, 3, 5), ones(2, 1, 3, 4, 4))
    with pytest.raises(TypeError):
        func(ones(2, 1, 3, 4, 5), ones(2, 1, 3, 4, 5), ones(2, 1, 3, 3, 4))

--------------------------------------------------------------------------------
/test/test_extensions.py:
--------------------------------------------------------------------------------
import pytest
from torch import rand, Tensor
from torchtyping import TensorDetail, TensorType
from typeguard import typechecked

good = foo = None

# Write the extension


class FooDetail(TensorDetail):
    def __init__(self, foo):
        super().__init__()
        self.foo = foo

    def check(self, tensor: Tensor) -> bool:
        return hasattr(tensor, "foo") and tensor.foo == self.foo

    # reprs used in error messages when the check fails

    def __repr__(self) -> str:
        return f"FooDetail({self.foo})"

    @classmethod
    def tensor_repr(cls, tensor: Tensor) -> str:
        # Should return a representation of the tensor with respect
        # to what this detail is checking.
        if hasattr(tensor, "foo"):
            return f"FooDetail({tensor.foo})"
        else:
            return ""


# Test the extension


@typechecked
def foo_checker(tensor: TensorType[float, FooDetail("good-foo")]):
    pass


def valid_foo():
    x = rand(3)
    x.foo = "good-foo"
    foo_checker(x)


def invalid_foo_one():
    x = rand(3)
    x.foo = "bad-foo"
    foo_checker(x)


def invalid_foo_two():
    x = rand(2).int()
    x.foo = "good-foo"
    foo_checker(x)


def test_extensions():
    valid_foo()
    with pytest.raises(TypeError):
        invalid_foo_one()
    with pytest.raises(TypeError):
        invalid_foo_two()

--------------------------------------------------------------------------------
/test/test_misc.py:
--------------------------------------------------------------------------------
import pytest
import torch
from torchtyping import TensorType
from typeguard import typechecked
from typing import Tuple


dim1 = dim2 = dim3 = channel = None


def test_non_tensor():
    class Tensor:
        shape = torch.Size([2, 2])
        dtype = torch.float32
        layout = torch.strided

    args = (None, 4, 3.0, 3.2, "Tensor", Tensor, Tensor())

    @typechecked
    def accepts_tensor1(x: TensorType):
        pass

    @typechecked
    def accepts_tensor2(x: TensorType[2, 2]):
        pass

    @typechecked
    def accepts_tensor3(x: TensorType[...]):
        pass

    @typechecked
    def accepts_tensor4(x: TensorType[float]):
        pass

    @typechecked
    def accepts_tensor5(x: TensorType[..., float]):
        pass

    @typechecked
    def accepts_tensor6(x: TensorType[2, int]):
        pass

    @typechecked
    def accepts_tensor7(x: TensorType[torch.strided]):
        pass

    @typechecked
    def accepts_tensor8(x: TensorType[2, float, torch.sparse_coo]):
        pass

    for func in (
        accepts_tensor1,
        accepts_tensor2,
        accepts_tensor3,
        accepts_tensor4,
        accepts_tensor5,
        accepts_tensor6,
        accepts_tensor7,
        accepts_tensor8,
    ):
        for arg in args:
            with pytest.raises(TypeError):
                func(arg)

    @typechecked
    def accepts_tensors1(x: TensorType, y: TensorType):
        pass

    @typechecked
    def accepts_tensors2(x: TensorType[2, 2], y: TensorType):
        pass

    @typechecked
    def accepts_tensors3(x: TensorType[...], y: TensorType):
        pass

    @typechecked
    def accepts_tensors4(x: TensorType[float], y: TensorType):
        pass

    @typechecked
    def accepts_tensors5(x: TensorType[..., float], y: TensorType):
        pass

    @typechecked
    def accepts_tensors6(x: TensorType[2, int], y: TensorType):
        pass

    @typechecked
    def accepts_tensors7(x: TensorType[torch.strided], y: TensorType):
        pass

    @typechecked
    def accepts_tensors8(x: TensorType[torch.sparse_coo, float, 2], y: TensorType):
        pass

    for func in (
        accepts_tensors1,
        accepts_tensors2,
        accepts_tensors3,
        accepts_tensors4,
        accepts_tensors5,
        accepts_tensors6,
        accepts_tensors7,
        accepts_tensors8,
    ):
        for arg1 in args:
            for arg2 in args:
                with pytest.raises(TypeError):
                    func(arg1, arg2)


def test_nested_types():
    @typechecked
    def func(x: Tuple[TensorType[3, "channel", 4], TensorType["channel"]]):
        pass

    func((torch.rand(3, 1, 4), torch.rand(1)))
    func((torch.rand(3, 5, 4), torch.rand(5)))
    with pytest.raises(TypeError):
        func((torch.rand(3, 1, 4), torch.rand(2)))


def test_no_getitem():
    @typechecked
    def func(x: TensorType, y: TensorType):
        pass

    func(torch.rand(2), torch.rand(2))
    with pytest.raises(TypeError):
        func(torch.rand(2), None)
    with pytest.raises(TypeError):
        func(torch.rand(2), [3, 4])


def test_scalar_tensor():
    @typechecked
    def func(x: TensorType[()]):
        pass

    func(torch.rand(()))
    with pytest.raises(TypeError):
        func(torch.rand((1,)))
    with pytest.raises(TypeError):
        func(torch.rand((1, 2)))
    with pytest.raises(TypeError):
        func(torch.rand((5, 2, 2)))


def test_square():
    @typechecked
    def func(x: TensorType["dim1", "dim1"]):
        pass

    func(torch.rand(2, 2))
    func(torch.rand(5, 5))
    with pytest.raises(TypeError):
        func(torch.rand(3, 5))
    with pytest.raises(TypeError):
        func(torch.rand(5, 3))


def test_repeat():
    @typechecked
    def func(x: TensorType["dim1", "dim2", "dim2"], y: TensorType[-1, "dim2"]):
        pass

    func(torch.rand(5, 3, 3), torch.rand(9, 3))
    func(torch.rand(5, 5, 5), torch.rand(9, 5))
    func(torch.rand(4, 5, 5), torch.rand(2, 5))
    with pytest.raises(TypeError):
        func(torch.rand(4, 5, 4), torch.rand(3, 5))
    with pytest.raises(TypeError):
        func(torch.rand(4, 5, 5), torch.rand(3, 3))
    with pytest.raises(TypeError):
        func(torch.rand(4, 3, 5), torch.rand(3, 3))
    with pytest.raises(TypeError):
        func(torch.rand(4, 3, 3), torch.rand(0, 2))

--------------------------------------------------------------------------------
/test/test_shape.py:
--------------------------------------------------------------------------------
import pytest
from typing import Any
import torch
from torchtyping import TensorType, is_named
import typeguard


# make flake8 happy
a = b = c = x = y = z = None


def test_fixed_int_dim():
    @typeguard.typechecked
    def _3_dim_checker(x: TensorType[3]):
        pass

    @typeguard.typechecked
    def _3m1_dim_checker(x: TensorType[3, -1]):
        pass

    @typeguard.typechecked
    def _4_dim_checker(x: TensorType[4]):
        pass

    @typeguard.typechecked
    def _4m1_dim_checker(x: TensorType[4, -1]):
        pass

    @typeguard.typechecked
    def _m14_dim_checker(x: TensorType[-1, 4]):
        pass

    @typeguard.typechecked
    def _m1m1_dim_checker(x: TensorType[-1, -1]):
        pass

    @typeguard.typechecked
    def _34_dim_checker(x: TensorType[3, 4]):
        pass

    @typeguard.typechecked
    def _34m1_dim_checker(x: TensorType[3, 4, -1]):
        pass

    @typeguard.typechecked
    def _m14m1_dim_checker(x: TensorType[-1, 4, -1]):
        pass

    x = torch.rand(3)
    _3_dim_checker(x)
    with pytest.raises(TypeError):
        _3m1_dim_checker(x)
    with pytest.raises(TypeError):
        _4_dim_checker(x)
    with pytest.raises(TypeError):
        _4m1_dim_checker(x)
    with pytest.raises(TypeError):
        _m14_dim_checker(x)
    with pytest.raises(TypeError):
        _m1m1_dim_checker(x)
    with pytest.raises(TypeError):
        _34_dim_checker(x)
    with pytest.raises(TypeError):
        _34m1_dim_checker(x)
    with pytest.raises(TypeError):
        _m14m1_dim_checker(x)

    x = torch.rand(3, 4)
    _3m1_dim_checker(x)
    _m14_dim_checker(x)
    _m1m1_dim_checker(x)
    _34_dim_checker(x)
    with pytest.raises(TypeError):
        _3_dim_checker(x)
    with pytest.raises(TypeError):
        _4_dim_checker(x)
    with pytest.raises(TypeError):
        _4m1_dim_checker(x)
    with pytest.raises(TypeError):
        _34m1_dim_checker(x)
    with pytest.raises(TypeError):
        _m14m1_dim_checker(x)

    x = torch.rand(4, 3)
    _4m1_dim_checker(x)
    _m1m1_dim_checker(x)
    with pytest.raises(TypeError):
        _3_dim_checker(x)
    with pytest.raises(TypeError):
        _3m1_dim_checker(x)
    with pytest.raises(TypeError):
        _4_dim_checker(x)
    with pytest.raises(TypeError):
        _m14_dim_checker(x)
    with pytest.raises(TypeError):
        _34_dim_checker(x)
    with pytest.raises(TypeError):
        _34m1_dim_checker(x)
    with pytest.raises(TypeError):
        _m14m1_dim_checker(x)


def test_str_dim():
    @typeguard.typechecked
    def _a_dim_checker(x: TensorType["a"]):
        pass

    @typeguard.typechecked
    def _ab_dim_checker(x: TensorType["a", "b"]):
        pass

    @typeguard.typechecked
    def _abc_dim_checker(x: TensorType["a", "b", "c"]):
        pass

    @typeguard.typechecked
    def _cb_dim_checker(x: TensorType["c", "b"]):
        pass

    @typeguard.typechecked
    def _am1_dim_checker(x: TensorType["a", -1]):
        pass

    @typeguard.typechecked
    def _m1b_dim_checker(x: TensorType[-1, "b"]):
        pass

    @typeguard.typechecked
    def _abm1_dim_checker(x: TensorType["a", "b", -1]):
        pass

    @typeguard.typechecked
    def _m1bm1_dim_checker(x: TensorType[-1, "b", -1]):
        pass

    x = torch.rand(3, 4)
    _ab_dim_checker(x)
    _cb_dim_checker(x)
    _am1_dim_checker(x)
    _m1b_dim_checker(x)
    with pytest.raises(TypeError):
        _a_dim_checker(x)
    with pytest.raises(TypeError):
        _abc_dim_checker(x)
    with pytest.raises(TypeError):
        _abm1_dim_checker(x)
    with pytest.raises(TypeError):
        _m1bm1_dim_checker(x)


def test_str_str_dim1():
    @typeguard.typechecked
    def func(x: TensorType["a":"x"]):
        pass

    func(torch.ones(3))
    func(torch.ones(2))
    with pytest.raises(TypeError):
        func(torch.tensor(3.0))
    with pytest.raises(TypeError):
        func(torch.ones(3, 3))


def test_str_str_dim2():
    @typeguard.typechecked
    def func(x: TensorType["a":"x", "b":"x"]):
        pass

    func(torch.ones(3, 3))
    func(torch.ones(2, 2))
    with pytest.raises(TypeError):
        func(torch.tensor(3.0))
    with pytest.raises(TypeError):
        func(torch.ones(3))
    with pytest.raises(TypeError):
        func(torch.ones(3, 2))
    with pytest.raises(TypeError):
        func(torch.ones(2, 3))


def test_str_str_dim_complex():
    @typeguard.typechecked
    def func(x: TensorType["a":"x", "b":"x", "x", "a", "b"]) -> TensorType["c":"x"]:
        return torch.ones(x.shape[0])

    func(torch.ones(3, 3, 3, 3, 3))
    func(torch.ones(2, 2, 2, 2, 2))
    with pytest.raises(TypeError):
        func(torch.ones(1, 2, 2, 2, 2))
    with pytest.raises(TypeError):
        func(torch.ones(2, 1, 2, 2, 2))
    with pytest.raises(TypeError):
        func(torch.ones(2, 2, 1, 2, 2))
    with pytest.raises(TypeError):
        func(torch.ones(2, 2, 2, 1, 2))
    with pytest.raises(TypeError):
        func(torch.ones(2, 2, 2, 2, 1))

    @typeguard.typechecked
    def func2(x: TensorType["a":"x", "b":"x", "x", "a", "b"]) -> TensorType["c":"x"]:
        return torch.ones(x.shape[0] + 1)

    with pytest.raises(TypeError):
        func2(torch.ones(2, 2, 2, 2, 2))

    @typeguard.typechecked
    def func3(x: TensorType["a":"x", "b":"x", "x", "a", "b"]) -> TensorType["c":"x"]:
        return torch.ones(x.shape[0], x.shape[0])

    with pytest.raises(TypeError):
        func3(torch.ones(2, 2, 2, 2, 2))


def test_str_str_dim_fixed_num():
    @typeguard.typechecked
    def func(x: TensorType["a":"x"]) -> TensorType["x":3]:
        return torch.ones(x.shape[0])

    func(torch.ones(3))
    with pytest.raises(TypeError):
        func(torch.ones(2))


def test_str_str_dim_fixed_names():
    @typeguard.typechecked
    def func(x: TensorType["a":"x", is_named]) -> TensorType["x":3]:
        return torch.ones(x.shape[0])

    func(torch.ones(3, names=["a"]))
    with pytest.raises(TypeError):
        func(torch.ones(3))
    with pytest.raises(TypeError):
        func(torch.ones(3, names=["b"]))
    with pytest.raises(TypeError):
        func(torch.ones(2, names=["a"]))
    with pytest.raises(TypeError):
        func(torch.ones(3, names=["x"]))


def test_str_str_dim_no_early_return():
    @typeguard.typechecked
    def func(x: TensorType["a":"x", "b":"y", "c":"z", is_named]):
        pass

    func(torch.ones(2, 2, 2, names=["a", "b", "c"]))
    with pytest.raises(TypeError):
        func(torch.ones(2, 2, 2, names=["d", "b", "c"]))
    with pytest.raises(TypeError):
        func(torch.ones(2, 2, 2, names=["a", "b", "d"]))


def test_none_str():
    @typeguard.typechecked
    def func(x: TensorType[None:"x", "b":"x", is_named]):
        pass

    func(torch.ones(2, 2, names=[None, "b"]))
    func(torch.ones(3, 3, names=[None, "b"]))
    with pytest.raises(TypeError):
        func(torch.ones(2, 2, names=["a", "b"]))
    with pytest.raises(TypeError):
        func(torch.ones(2, 2, names=["x", "b"]))
    with pytest.raises(TypeError):
        func(torch.ones(2, 2, names=[None, None]))
    with pytest.raises(TypeError):
        func(torch.ones(2, 2, names=[None, "c"]))
    with pytest.raises(TypeError):
        func(torch.ones(2, 2, names=[None, "x"]))


def test_other_str_should_fail():
    with pytest.raises(TypeError):

        def func(x: TensorType[3:"x"]):
            pass


def test_int_str_dim():
    @typeguard.typechecked
    def _a_dim_checker1(x: TensorType["a":3]):
        pass

    @typeguard.typechecked
    def _a_dim_checker2(x: TensorType["a":-1]):
        pass

    @typeguard.typechecked
    def _ab_dim_checker1(x: TensorType["a":3, "b":4]):
        pass

    @typeguard.typechecked
    def _ab_dim_checker2(x: TensorType["a":3, "b":-1]):
        pass

    @typeguard.typechecked
    def _ab_dim_checker3(x: TensorType["a":-1, "b":4]):
        pass

    @typeguard.typechecked
    def _ab_dim_checker4(x: TensorType["a":3, "b"]):
        pass

    @typeguard.typechecked
    def _ab_dim_checker5(x: TensorType["a", "b":4]):
        pass

    @typeguard.typechecked
    def _ab_dim_checker6(x: TensorType["a":5, "b":4]):
        pass

    @typeguard.typechecked
    def _ab_dim_checker7(x: TensorType["a":5, "b":-1]):
        pass

    @typeguard.typechecked
    def _m1b_dim_checker(x: TensorType[-1, "b":4]):
        pass

    @typeguard.typechecked
    def _abm1_dim_checker(x: TensorType["a":3, "b":4, -1]):
        pass

    @typeguard.typechecked
    def _m1bm1_dim_checker(x: TensorType[-1, "b":4, -1]):
        pass

    x = torch.rand(3, 4)
    _ab_dim_checker1(x)
    _ab_dim_checker2(x)
    _ab_dim_checker3(x)
    _ab_dim_checker4(x)
    _ab_dim_checker5(x)
    _m1b_dim_checker(x)
    with pytest.raises(TypeError):
        _a_dim_checker1(x)
    with pytest.raises(TypeError):
        _a_dim_checker2(x)
    with pytest.raises(TypeError):
        _ab_dim_checker6(x)
    with pytest.raises(TypeError):
        _ab_dim_checker7(x)
    with pytest.raises(TypeError):
        _abm1_dim_checker(x)
    with pytest.raises(TypeError):
        _m1bm1_dim_checker(x)


def test_return():
    @typeguard.typechecked
    def f1(x: TensorType["b":4]) -> TensorType["b":4]:
        return torch.rand(3)

    @typeguard.typechecked
    def f2(x: TensorType["b"]) -> TensorType["b":4]:
        return torch.rand(3)

    @typeguard.typechecked
    def f3(x: TensorType[4]) -> TensorType["b":4]:
        return torch.rand(3)

    @typeguard.typechecked
    def f4(x: TensorType["b":4]) -> TensorType["b"]:
        return torch.rand(3)

    @typeguard.typechecked
    def f5(x: TensorType["b"]) -> TensorType["b"]:
        return torch.rand(3)

    @typeguard.typechecked
    def f6(x: TensorType[4]) -> TensorType["b"]:
        return torch.rand(3)

    @typeguard.typechecked
    def f7(x: TensorType["b":4]) -> TensorType[4]:
        return torch.rand(3)

    @typeguard.typechecked
    def f8(x: TensorType["b"]) -> TensorType[4]:
        return torch.rand(3)

    @typeguard.typechecked
    def f9(x: TensorType[4]) -> TensorType[4]:
        return torch.rand(3)

    with pytest.raises(TypeError):
        f1(torch.rand(3))
    with pytest.raises(TypeError):
        f2(torch.rand(3))
    with pytest.raises(TypeError):
        f3(torch.rand(3))
    with pytest.raises(TypeError):
        f4(torch.rand(3))
    f5(torch.rand(3))
    with pytest.raises(TypeError):
        f6(torch.rand(3))
    with pytest.raises(TypeError):
        f7(torch.rand(3))
    with pytest.raises(TypeError):
        f8(torch.rand(3))
    with pytest.raises(TypeError):
        f9(torch.rand(3))

    with pytest.raises(TypeError):
        f1(torch.rand(4))
    with pytest.raises(TypeError):
        f2(torch.rand(4))
    with pytest.raises(TypeError):
        f3(torch.rand(4))
    with pytest.raises(TypeError):
        f4(torch.rand(4))
    with pytest.raises(TypeError):
        f5(torch.rand(4))
    f6(torch.rand(4))
    with pytest.raises(TypeError):
        f7(torch.rand(4))
    with pytest.raises(TypeError):
        f8(torch.rand(4))
    with pytest.raises(TypeError):
        f9(torch.rand(4))


def test_any_dim():
    @typeguard.typechecked
    def _3any_dim_checker(x: TensorType[3, Any]):
        pass

    @typeguard.typechecked
    def _any4_dim_checker(x: TensorType[Any, 4]):
        pass

    @typeguard.typechecked
    def _anyany_dim_checker(x: TensorType[Any, Any]):
        pass

    @typeguard.typechecked
    def _34any_dim_checker(x: TensorType[3, 4, Any]):
        pass

    @typeguard.typechecked
    def _any4any_dim_checker(x: TensorType[Any, 4, Any]):
        pass

    x = torch.rand(3)
    with pytest.raises(TypeError):
        _3any_dim_checker(x)
    with pytest.raises(TypeError):
        _any4_dim_checker(x)
    with pytest.raises(TypeError):
        _anyany_dim_checker(x)
    with pytest.raises(TypeError):
        _34any_dim_checker(x)
    with pytest.raises(TypeError):
        _any4any_dim_checker(x)

    x = torch.rand((3, 4))
    _3any_dim_checker(x)
    _any4_dim_checker(x)
    _anyany_dim_checker(x)

    x = torch.rand((4, 5))
    with pytest.raises(TypeError):
        _any4_dim_checker(x)

    x = torch.rand(4, 5)
    with pytest.raises(TypeError):
        _3any_dim_checker(x)

    x = torch.rand((3, 4, 5))
    _34any_dim_checker(x)
    _any4any_dim_checker(x)

    x = torch.rand((3, 5, 5))
    with pytest.raises(TypeError):
        _any4any_dim_checker(x)
    with pytest.raises(TypeError):
        _34any_dim_checker(x)

--------------------------------------------------------------------------------
/torchtyping/__init__.py:
--------------------------------------------------------------------------------
from .tensor_details import (
    DtypeDetail,
    is_float,
    is_named,
    LayoutDetail,
    ShapeDetail,
    TensorDetail,
)

from .tensor_type import TensorType
from .typechecker import patch_typeguard

__version__ = "0.1.5"

--------------------------------------------------------------------------------
/torchtyping/pytest_plugin.py:
--------------------------------------------------------------------------------
from .typechecker import patch_typeguard


def pytest_addoption(parser):
    group = parser.getgroup("torchtyping")
    group.addoption(
        "--torchtyping-patch-typeguard",
        action="store_true",
        help="Run torchtyping's typeguard patch.",
    )


def pytest_configure(config):
    if config.getoption("torchtyping_patch_typeguard"):
        patch_typeguard()

--------------------------------------------------------------------------------
/torchtyping/tensor_details.py:
--------------------------------------------------------------------------------
from __future__ import annotations

import abc
import collections
import torch

from typing import Optional, Union


ellipsis = type(...)
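The right-aligned matching rule that `ShapeDetail.check` applies in this file can be illustrated without PyTorch. The sketch below is a simplified, hypothetical helper (`shape_matches` is not part of torchtyping) that works on plain tuples. It assumes `...` appears only at the left edge, as torchtyping requires, and it ignores dimension names and the cross-argument consistency of string sizes, which torchtyping checks separately in its typeguard patch.

```python
from typing import Sequence, Tuple


def shape_matches(spec: Sequence[object], shape: Tuple[int, ...]) -> bool:
    """Hypothetical, torch-free sketch of the rule in ShapeDetail.check.

    Each spec entry is an int (exact size), -1 (any size), a str (size left
    to a separate cross-argument consistency check), or ... (zero or more
    leading dimensions; assumed to appear only at the left edge).
    Dimension names are ignored entirely.
    """
    if ... in spec:
        # `...` can stand for zero dimensions, so only a lower bound on rank applies.
        if sum(1 for s in spec if s is not ...) > len(shape):
            return False
    elif len(spec) != len(shape):
        return False

    # Compare from the right so that the entries after `...` line up with the
    # trailing dimensions of the shape.
    for s, size in zip(reversed(spec), reversed(shape)):
        if s is ...:
            break  # everything further left is absorbed by `...`
        if isinstance(s, str):
            continue  # sizes bound to names are not checked here
        if s not in (-1, size):
            return False
    return True


print(shape_matches([..., 4], (2, 5, 4)))  # True
print(shape_matches([3], (3, 4)))  # False
```

Here `[..., 4]` accepts `(2, 5, 4)` because `...` absorbs the two leading dimensions, while `[3]` rejects `(3, 4)` on the rank check before any size is compared. Walking the dimensions from the right is what lets a single left-edge `...` work without any backtracking.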


class TensorDetail(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def __repr__(self) -> str:
        raise NotImplementedError

    @abc.abstractmethod
    def check(self, tensor: torch.Tensor) -> bool:
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def tensor_repr(cls, tensor: torch.Tensor) -> str:
        raise NotImplementedError


_no_name = object()


# inheriting from typing.NamedTuple crashes typeguard
class _Dim(collections.namedtuple("_Dim", ["name", "size"])):
    # None corresponds to a name not being set. _no_name corresponds to us not caring
    # whether a name is set.
    name: Union[None, str, type(_no_name)]
    # Technically we're supposed to use an enum to annotate singletons, but that's
    # overkill.

    size: Union[ellipsis, int]

    def __repr__(self) -> str:
        if self.name is _no_name:
            if self.size is ...:
                return "..."
            else:
                return repr(self.size)
        else:
            if self.size is ...:
                return f"{self.name}: ..."
            elif self.size == -1:
                return repr(self.name)
            else:
                return f"{self.name}: {self.size}"


class ShapeDetail(TensorDetail):
    def __init__(self, *, dims: list[_Dim], check_names: bool, **kwargs) -> None:
        super().__init__(**kwargs)
        self.dims = dims
        self.check_names = check_names

    def __repr__(self) -> str:
        if len(self.dims) == 0:
            out = "()"
        elif len(self.dims) == 1:
            out = repr(self.dims[0])
        else:
            out = repr(tuple(self.dims))[1:-1]
        if self.check_names:
            out += ", is_named"
        return out

    def check(self, tensor: torch.Tensor) -> bool:
        self_names = [self_dim.name for self_dim in self.dims]
        self_shape = [self_dim.size for self_dim in self.dims]

        if ... in self_shape:
            if sum(1 for size in self_shape if size is not ...) > len(tensor.names):
                return False
        else:
            if len(self_shape) != len(tensor.names):
                return False

        for self_name, self_size, tensor_name, tensor_size in zip(
            reversed(self_names),
            reversed(self_shape),
            reversed(tensor.names),
            reversed(tensor.shape),
        ):
            if self_size is ...:
                # This assumes that ellipses only occur at the left-hand edge,
                # so once we hit one we're done.
                break

            if (
                self.check_names
                and self_name is not _no_name
                and self_name != tensor_name
            ):
                return False
            if not isinstance(self_size, str) and self_size not in (-1, tensor_size):
                return False

        return True

    @classmethod
    def tensor_repr(cls, tensor: torch.Tensor) -> str:
        dims = []
        check_names = any(name is not None for name in tensor.names)
        for name, size in zip(tensor.names, tensor.shape):
            if not check_names:
                name = _no_name
            dims.append(_Dim(name=name, size=size))
        return repr(cls(dims=dims, check_names=check_names))

    def update(
        self,
        *,
        dims: Optional[list[_Dim]] = None,
        check_names: Optional[bool] = None,
        **kwargs,
    ) -> ShapeDetail:
        dims = self.dims if dims is None else dims
        check_names = self.check_names if check_names is None else check_names
        return type(self)(dims=dims, check_names=check_names, **kwargs)


class DtypeDetail(TensorDetail):
    def __init__(self, *, dtype, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(dtype, torch.dtype)
        self.dtype = dtype

    def __repr__(self) -> str:
        return repr(self.dtype)

    def check(self, tensor: torch.Tensor) -> bool:
        return self.dtype == tensor.dtype

    @classmethod
    def tensor_repr(cls, tensor: torch.Tensor) -> str:
        return repr(cls(dtype=tensor.dtype))


class LayoutDetail(TensorDetail):
    def __init__(self, *, layout, **kwargs) -> None:
        super().__init__(**kwargs)
        self.layout = layout

    def __repr__(self) -> str:
        return repr(self.layout)

    def check(self, tensor: torch.Tensor) -> bool:
        return self.layout == tensor.layout

    @classmethod
    def tensor_repr(cls, tensor: torch.Tensor) -> str:
        return repr(cls(layout=tensor.layout))


class _FloatDetail(TensorDetail):
    def __repr__(self) -> str:
        return "is_float"

    def check(self, tensor: torch.Tensor) -> bool:
        return tensor.is_floating_point()

    @classmethod
    def tensor_repr(cls, tensor: torch.Tensor) -> str:
        return "is_float" if tensor.is_floating_point() else ""


# is_named is special-cased and consumed by TensorType.
# It's a bit of an odd exception.
# It's only a TensorDetail for consistency, as the other
# extra flags that get passed are TensorDetails.
class _NamedTensorDetail(TensorDetail):
    def __repr__(self) -> str:
        raise RuntimeError

    def check(self, tensor: torch.Tensor) -> bool:
        raise RuntimeError

    @classmethod
    def tensor_repr(cls, tensor: torch.Tensor) -> str:
        raise RuntimeError


is_float = _FloatDetail()  # singleton flag
is_named = _NamedTensorDetail()  # singleton flag

--------------------------------------------------------------------------------
/torchtyping/tensor_type.py:
--------------------------------------------------------------------------------
from __future__ import annotations

import sys
import torch

from .tensor_details import (
    _Dim,
    _no_name,
    is_named,
    DtypeDetail,
    LayoutDetail,
    ShapeDetail,
    TensorDetail,
)
from .utils import frozendict

from typing import Any, NoReturn

# Annotated is available in typing from Python 3.9 (PEP 593).
if sys.version_info >= (3, 9):
    from typing import Annotated
else:
    # On Python versions below 3.9, import Annotated from typing_extensions.
    from typing_extensions import Annotated

# Not Type[Annotated...] as we want to use this in instance checks.
_AnnotatedType = type(Annotated[torch.Tensor, ...])


# For use when we have a plain TensorType, without any [].
class _TensorTypeMeta(type(torch.Tensor)):
    def __instancecheck__(cls, obj: Any) -> bool:
        return isinstance(obj, cls.base_cls)


# Inherit from torch.Tensor so that IDEs can find tensor methods on arguments
# annotated as TensorTypes.
class TensorType(torch.Tensor, metaclass=_TensorTypeMeta):
    base_cls = torch.Tensor

    def __new__(cls, *args, **kwargs) -> NoReturn:
        raise RuntimeError(f"Class {cls.__name__} cannot be instantiated.")

    @staticmethod
    def _type_error(item: Any) -> NoReturn:
        raise TypeError(f"{item} not a valid type argument.")

    @classmethod
    def _convert_shape_element(cls, item_i: Any) -> _Dim:
        if isinstance(item_i, int) and not isinstance(item_i, bool):
            return _Dim(name=_no_name, size=item_i)
        elif isinstance(item_i, str):
            return _Dim(name=item_i, size=-1)
        elif item_i is None:
            return _Dim(name=None, size=-1)
        elif isinstance(item_i, slice):
            if item_i.step is not None:
                cls._type_error(item_i)
            if item_i.start is not None and not isinstance(item_i.start, str):
                cls._type_error(item_i)
            if item_i.stop is not ... and not isinstance(item_i.stop, (int, str)):
                cls._type_error(item_i)
            if item_i.start is None and item_i.stop is ...:
                cls._type_error(item_i)
            return _Dim(name=item_i.start, size=item_i.stop)
        elif item_i is ...:
            return _Dim(name=_no_name, size=...)
        elif item_i is Any:
            return _Dim(name=_no_name, size=-1)
        else:
            cls._type_error(item_i)

    @staticmethod
    def _convert_dtype_element(item_i: Any) -> torch.dtype:
        if item_i is int:
            return torch.long
        elif item_i is float:
            return torch.get_default_dtype()
        elif item_i is bool:
            return torch.bool
        else:
            return item_i

    def __class_getitem__(cls, item: Any) -> _AnnotatedType:
        if isinstance(item, tuple):
            if len(item) == 0:
                item = ((),)
        else:
            item = (item,)

        scalar_shape = False
        not_ellipsis = False
        not_named_ellipsis = False
        check_names = False
        dims = []
        dtypes = []
        layouts = []
        details = []
        for item_i in item:
            if isinstance(item_i, (int, str, slice)) or item_i in (None, ..., Any):
                item_i = cls._convert_shape_element(item_i)
                if item_i.size is ...:
                    # Supporting an arbitrary number of Ellipsis in arbitrary
                    # locations feels concerningly close to writing a regex
                    # parser and I definitely don't have time for that.
                    if not_ellipsis:
                        raise NotImplementedError(
                            "Having dimensions to the left of `...` is not currently "
                            "supported."
                        )
                    if item_i.name is None:
                        if not_named_ellipsis:
                            raise NotImplementedError(
                                "Having named `...` to the left of unnamed `...` is "
                                "not currently supported."
                            )
                    else:
                        not_named_ellipsis = True
                else:
                    not_ellipsis = True
                dims.append(item_i)
            elif isinstance(item_i, tuple):
                if len(item_i) == 0:
                    scalar_shape = True
                else:
                    cls._type_error(item_i)
            elif item_i in (int, bool, float) or isinstance(item_i, torch.dtype):
                dtypes.append(cls._convert_dtype_element(item_i))
            elif isinstance(item_i, torch.layout):
                layouts.append(item_i)
            elif item_i is is_named:
                check_names = True
            elif isinstance(item_i, TensorDetail):
                details.append(item_i)
            else:
                cls._type_error(item_i)

        if scalar_shape:
            if len(dims) != 0:
                cls._type_error(item)
        else:
            if len(dims) == 0:
                dims = None

        pre_details = []
        if dims is not None:
            pre_details.append(ShapeDetail(dims=dims, check_names=check_names))

        if len(dtypes) == 0:
            pass
        elif len(dtypes) == 1:
            pre_details.append(DtypeDetail(dtype=dtypes[0]))
        else:
            raise TypeError("Cannot have multiple dtypes.")

        if len(layouts) == 0:
            pass
        elif len(layouts) == 1:
            pre_details.append(LayoutDetail(layout=layouts[0]))
        else:
            raise TypeError("Cannot have multiple layouts.")

        details = tuple(pre_details + details)

        assert len(details) > 0

        # Frozen dict needed for Union[TensorType[...], ...], as Union hashes its
        # arguments.
170 | return Annotated[ 171 | cls.base_cls, 172 | frozendict( 173 | {"__torchtyping__": True, "details": details, "cls_name": cls.__name__} 174 | ), 175 | ] 176 | -------------------------------------------------------------------------------- /torchtyping/typechecker.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import sys 3 | import torch 4 | import typeguard 5 | 6 | from .tensor_details import _Dim, _no_name, ShapeDetail 7 | from .tensor_type import _AnnotatedType 8 | 9 | from typing import Any, Dict, List, Tuple 10 | 11 | # get_args is available from Python 3.8; get_type_hints with the include_extras 12 | # parameter (PEP 593) is available from Python 3.9. 13 | if sys.version_info >= (3, 9): 14 | from typing import get_type_hints, get_args, Type 15 | else: 16 | from typing_extensions import get_type_hints, get_args, Type 17 | 18 | 19 | # TYPEGUARD PATCHER 20 | ####################### 21 | # So there are quite a lot of moving pieces here. 22 | # The logic proceeds as follows. 23 | # 24 | # Calling patch_typeguard() just monkey-patches some of its functions and classes. 25 | # 26 | # typeguard uses a `_CallMemo` object to store information about each function that it 27 | # is checking: this is what allows us to perform function-level checking (consistency 28 | # of tensor shapes) rather than just argument-level checking (simple isinstance 29 | # checks). 30 | # So the first thing we do is enhance that with a couple of extra slots to store our 31 | # information. 32 | # 33 | # Second, we patch `check_type`. typeguard traverses the [] hierarchy, e.g. from 34 | # Tuple[List[int]] to List[int] to int, recursively calling `check_type`. By patching 35 | # `check_type` we can check for our `TensorType`s and record every value-type pair. 36 | # (Actually it's a bit more than that: we record some names for use in the error 37 | # messages.) These are recorded in our enhanced `_CallMemo` object.
38 | # 39 | # (Incidentally we also have to patch typeguard's use of typing.get_type_hints, so that 40 | # our annotations aren't stripped.) 41 | # 42 | # Then we patch `check_argument_types` and `check_return_type`, to perform our extra 43 | # TensorType checking. This is the same checking in both cases so we factor that out 44 | # into _check_memo. 45 | # 46 | # _check_memo performs the real logic of the checking here. This looks at all the 47 | # recorded value-type pairs and checks for any inconsistencies. 48 | 49 | 50 | def _to_string(name, detail_reprs: List[str]) -> str: 51 | assert len(detail_reprs) > 0 52 | string = name + "[" 53 | pieces = [] 54 | for detail_repr in detail_reprs: 55 | if detail_repr != "": 56 | pieces.append(detail_repr) 57 | string += ", ".join(pieces) 58 | string += "]" 59 | return string 60 | 61 | 62 | def _check_tensor( 63 | argname: str, value: Any, origin: Type[torch.Tensor], metadata: Dict[str, Any] 64 | ): 65 | details = metadata["details"] 66 | if not isinstance(value, origin) or any( 67 | not detail.check(value) for detail in details 68 | ): 69 | expected_string = _to_string( 70 | metadata["cls_name"], [repr(detail) for detail in details] 71 | ) 72 | if isinstance(value, torch.Tensor): 73 | given_string = _to_string( 74 | metadata["cls_name"], [detail.tensor_repr(value) for detail in details] 75 | ) 76 | else: 77 | value = type(value) 78 | if hasattr(value, "__qualname__"): 79 | given_string = value.__qualname__ 80 | elif hasattr(value, "__name__"): 81 | given_string = value.__name__ 82 | else: 83 | given_string = repr(value) 84 | raise TypeError( 85 | f"{argname} must be of type {expected_string}, got type {given_string} " 86 | "instead." 87 | ) 88 | 89 | 90 | def _check_memo(memo): 91 | ########### 92 | # Parse the tensors and figure out the sizes of all labelled 93 | # dimensions. 94 | # This also performs some (and in practice most) of the consistency 95 | # checks. 
However, its job is primarily one of assigning sizes to labels. 96 | # The final checking of the inferred sizes is performed afterwards. 97 | # 98 | # This logic is a bit hairy. Most of the complexity comes from 99 | # supporting `...` (arbitrary numbers of dimensions). 100 | ########### 101 | 102 | # ordered set 103 | shape_info = { 104 | (argname, value.shape, detail): None 105 | for argname, value, _, detail in memo.value_info 106 | } 107 | while len(shape_info): 108 | for argname, shape, detail in shape_info: 109 | num_free_ellipsis = 0 110 | for dim in detail.dims: 111 | if dim.size is ... and dim.name not in memo.name_to_shape: 112 | num_free_ellipsis += 1 113 | if num_free_ellipsis <= 1: 114 | reversed_shape = enumerate(reversed(shape)) 115 | for dim in reversed(detail.dims): 116 | try: 117 | reverse_dim_index, size = next(reversed_shape) 118 | except StopIteration: 119 | if dim.size is ...: 120 | if dim.name not in (None, _no_name): 121 | try: 122 | lookup_shape = memo.name_to_shape[dim.name] 123 | except KeyError: 124 | memo.name_to_shape[dim.name] = () 125 | else: 126 | if lookup_shape != (): 127 | raise TypeError( 128 | f"Dimension group '{dim.name}' of " 129 | f"inconsistent shape. Got both () and " 130 | f"{lookup_shape}." 131 | ) 132 | else: 133 | # I don't think it's possible to get here, as the earlier 134 | # call to _check_tensor in check_type should catch 135 | # this case. 136 | raise TypeError( 137 | f"{argname} has {len(shape)} dimensions but type " 138 | "requires more than this." 139 | ) 140 | 141 | if dim.name not in (None, _no_name): 142 | if dim.size is ...: 143 | try: 144 | lookup_shape = memo.name_to_shape[dim.name] 145 | except KeyError: 146 | # Can only get here if we're the single free 147 | # ellipsis. 148 | # Therefore the number of dimensions the ellipsis 149 | # corresponds to is the number of dimensions 150 | # remaining.
151 | forward_index = 0 152 | for forward_dim in detail.dims: # now iterate forwards 153 | if forward_dim is dim: 154 | break 155 | assert forward_dim.size is ... 156 | forward_index += len( 157 | memo.name_to_shape[forward_dim.name] 158 | ) 159 | if reverse_dim_index == 0: 160 | # since [:-0] doesn't work 161 | end_index = None 162 | else: 163 | end_index = -reverse_dim_index 164 | clip_shape = shape[forward_index:end_index] 165 | memo.name_to_shape[dim.name] = tuple(clip_shape) 166 | for _ in range(len(clip_shape) - 1): 167 | next(reversed_shape) 168 | else: 169 | reversed_shape_piece = [] 170 | if len(lookup_shape) >= 1: 171 | reversed_shape_piece.append(size) 172 | for _ in range(len(lookup_shape) - 1): 173 | try: 174 | _, size = next(reversed_shape) 175 | except StopIteration: 176 | break 177 | reversed_shape_piece.append(size) 178 | 179 | shape_piece = tuple(reversed(reversed_shape_piece)) 180 | if lookup_shape != shape_piece: 181 | raise TypeError( 182 | f"Dimension group '{dim.name}' of " 183 | f"inconsistent shape. Got both {shape_piece} " 184 | f"and {lookup_shape}." 185 | ) 186 | else: 187 | names_to_check = ( 188 | [dim.name, dim.size] 189 | if isinstance(dim.size, str) 190 | else [dim.name] 191 | ) 192 | for name in names_to_check: 193 | try: 194 | lookup_size = memo.name_to_size[name] 195 | except KeyError: 196 | memo.name_to_size[name] = size 197 | else: 198 | # Technically not necessary, as one of the 199 | # sizes will override the other, and then the 200 | # instance check will fail. 201 | # This gives a nicer error message though. 202 | if lookup_size != size: 203 | raise TypeError( 204 | f"Dimension '{name}' of inconsistent" 205 | f" size. Got both {size} and " 206 | f"{lookup_size}." 207 | ) 208 | 209 | del shape_info[argname, shape, detail] 210 | break 211 | else: 212 | if len(shape_info): 213 | names = {argname for argname, _, _ in shape_info} 214 | raise TypeError( 215 | f"Could not resolve the size of all `...` in {names}.
Either:\n" 216 | "(1) the specification is ambiguous. For example " 217 | "`func(tensor: TensorType['x': ..., 'y': ...])`.\n" 218 | "(2) or repeated named `...` are used without being able to " 219 | "resolve the size of those named `...` via another argument. " 220 | "For example `func(tensor: TensorType['x': ..., 'x': ...])`. " 221 | "(But `func(tensor1: TensorType['x': ..., 'x': ...], tensor2: " 222 | "TensorType['x': ...])` would be fine.)\n" 223 | "\n" 224 | "Removing the names of the `...` should suffice to resolve this " 225 | "error. (But will of course remove that checking as well.)" 226 | ) 227 | 228 | ########### 229 | # Do the final checking with the inferred sizes filled in. 230 | # In practice, malformed inputs will usually trip one of the 231 | # checks in the previous logic, so this block doesn't actually raise 232 | # errors very often. (In 1/37 tests at time of writing.) 233 | # A potential performance improvement might be to integrate it into 234 | # the previous block. 235 | ########### 236 | 237 | for argname, value, cls_name, detail in memo.value_info: 238 | dims = [] 239 | for dim in detail.dims: 240 | size = dim.size 241 | if dim.name not in (None, _no_name): 242 | if size == -1: 243 | size = memo.name_to_size[dim.name] 244 | elif isinstance(size, str): 245 | size = memo.name_to_size[size] 246 | elif size is ...: 247 | # This assumes that named Ellipses only occur to the 248 | # right of unnamed Ellipses, to avoid filling in 249 | # Ellipses that occur to the left of other Ellipses.
250 | for size in memo.name_to_shape[dim.name]: 251 | dims.append(_Dim(name=_no_name, size=size)) 252 | continue 253 | dims.append(_Dim(name=dim.name, size=size)) 254 | detail = detail.update(dims=tuple(dims)) 255 | _check_tensor( 256 | argname, value, torch.Tensor, {"cls_name": cls_name, "details": [detail]} 257 | ) 258 | 259 | 260 | unpatched_typeguard = True 261 | 262 | 263 | def patch_typeguard(): 264 | global unpatched_typeguard 265 | if unpatched_typeguard: 266 | unpatched_typeguard = False 267 | 268 | # Defined dynamically, in case something else is doing similar levels of hackery 269 | # patching typeguard. We want to get typeguard._CallMemo at the time we patch, 270 | # not any earlier. (Someone might have replaced it since the import statement.) 271 | class _CallMemo(typeguard._CallMemo): 272 | __slots__ = ( 273 | "value_info", 274 | "name_to_size", 275 | "name_to_shape", 276 | ) 277 | value_info: List[Tuple[str, torch.Tensor, str, ShapeDetail]] 278 | name_to_size: Dict[str, int] 279 | name_to_shape: Dict[str, Tuple[int, ...]] 280 | 281 | _check_type = typeguard.check_type 282 | _check_argument_types = typeguard.check_argument_types 283 | _check_return_type = typeguard.check_return_type 284 | 285 | check_type_signature = inspect.signature(_check_type) 286 | check_argument_types_signature = inspect.signature(_check_argument_types) 287 | check_return_type_signature = inspect.signature(_check_return_type) 288 | 289 | def check_type(*args, **kwargs): 290 | bound_args = check_type_signature.bind(*args, **kwargs).arguments 291 | argname = bound_args["argname"] 292 | value = bound_args["value"] 293 | expected_type = bound_args["expected_type"] 294 | memo = bound_args["memo"] 295 | # First look for an annotated type 296 | is_torchtyping_annotation = ( 297 | memo is not None 298 | and hasattr(memo, "value_info") 299 | and isinstance(expected_type, _AnnotatedType) 300 | ) 301 | # Now check if it's annotating a tensor 302 | if is_torchtyping_annotation: 303 |
base_cls, *all_metadata = get_args(expected_type) 304 | if not issubclass(base_cls, torch.Tensor): 305 | is_torchtyping_annotation = False 306 | # Now check if the annotation's metadata is our metadata 307 | if is_torchtyping_annotation: 308 | for metadata in all_metadata: 309 | if isinstance(metadata, dict) and "__torchtyping__" in metadata: 310 | break 311 | else: 312 | is_torchtyping_annotation = False 313 | if is_torchtyping_annotation: 314 | # We call _check_tensor here -- despite calling _check_tensor again 315 | # once we've seen every argument and filled in the shape details -- 316 | # just because we want to check that `value` is in fact a tensor before 317 | # we access its `shape` field on the next line. 318 | _check_tensor(argname, value, base_cls, metadata) 319 | for detail in metadata["details"]: 320 | if isinstance(detail, ShapeDetail): 321 | memo.value_info.append( 322 | (argname, value, metadata["cls_name"], detail) 323 | ) 324 | break 325 | 326 | else: 327 | _check_type(*args, **kwargs) 328 | 329 | def check_argument_types(*args, **kwargs): 330 | bound_args = check_argument_types_signature.bind(*args, **kwargs).arguments 331 | memo = bound_args["memo"] 332 | if memo is None: 333 | return _check_argument_types(*args, **kwargs) 334 | else: 335 | memo.value_info = [] 336 | memo.name_to_size = {} 337 | memo.name_to_shape = {} 338 | retval = _check_argument_types(*args, **kwargs) 339 | try: 340 | _check_memo(memo) 341 | except TypeError as exc: # suppress long traceback 342 | raise TypeError(*exc.args) from None 343 | return retval 344 | 345 | def check_return_type(*args, **kwargs): 346 | bound_args = check_return_type_signature.bind(*args, **kwargs).arguments 347 | memo = bound_args["memo"] 348 | if memo is None: 349 | return _check_return_type(*args, **kwargs) 350 | else: 351 | # Reset the collection of things that need checking. 
352 | memo.value_info = [] 353 | # Do _not_ set memo.name_to_size or memo.name_to_shape, as we want to 354 | # keep using the same sizes inferred from the arguments. 355 | retval = _check_return_type(*args, **kwargs) 356 | try: 357 | _check_memo(memo) 358 | except TypeError as exc: # suppress long traceback 359 | raise TypeError(*exc.args) from None 360 | return retval 361 | 362 | typeguard._CallMemo = _CallMemo 363 | typeguard.check_type = check_type 364 | typeguard.check_argument_types = check_argument_types 365 | typeguard.check_return_type = check_return_type 366 | typeguard.get_type_hints = lambda *args, **kwargs: get_type_hints( 367 | *args, **kwargs, include_extras=True 368 | ) 369 | -------------------------------------------------------------------------------- /torchtyping/utils.py: -------------------------------------------------------------------------------- 1 | class frozendict(dict): 2 | def __init__(self, *args, **kwargs): 3 | super().__init__(*args, **kwargs) 4 | # Calling this immediately ensures that no unhashable types are used as 5 | # entries. 6 | # There's also no way this is an efficient hash algorithm, but we're only 7 | # planning on using this with small dictionaries. 8 | self._hash = hash(tuple(sorted(self.items()))) 9 | 10 | def __setitem__(self, key, value): 11 | raise RuntimeError(f"Cannot add items to a {type(self)}.") 12 | 13 | def __delitem__(self, item): 14 | raise RuntimeError(f"Cannot delete items from a {type(self)}.") 15 | 16 | def __hash__(self): 17 | return self._hash 18 | 19 | def __copy__(self): 20 | return self 21 | 22 | def __deepcopy__(self, memo): 23 | return self 24 | --------------------------------------------------------------------------------
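The named-dimension bookkeeping in `_check_memo` (the `memo.name_to_size` dictionary in `torchtyping/typechecker.py`) can be illustrated without PyTorch or typeguard. What follows is a minimal, simplified sketch, not torchtyping's implementation. It assumes several things: `infer_named_sizes` and the `Spec` triple are hypothetical names, shapes are plain integer tuples standing in for `torch.Size`, and `...` dimensions are not supported at all. As in `_check_memo`, the first time a name is seen its size is recorded, and every later occurrence of that name must match it.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# Hypothetical record for one checked argument: (argname, shape, dim names).
# A name of None marks an unlabelled dimension. Plain int tuples stand in
# for torch.Size, so no PyTorch is needed.
Spec = Tuple[str, Tuple[int, ...], Sequence[Optional[str]]]


def infer_named_sizes(specs: List[Spec]) -> Dict[str, int]:
    """Hypothetical helper: assign a size to every named dimension.

    Simplified illustration of the `name_to_size` bookkeeping: `...`
    dimensions are not supported, so every shape needs exactly one entry
    in `names` per dimension. Raises TypeError on any conflict.
    """
    name_to_size: Dict[str, int] = {}
    for argname, shape, names in specs:
        if len(shape) != len(names):
            raise TypeError(
                f"{argname} has {len(shape)} dimensions but type "
                f"requires {len(names)}."
            )
        for name, size in zip(names, shape):
            if name is None:
                continue  # unlabelled: nothing to keep consistent
            # The first occurrence fixes the size; later ones must agree.
            lookup_size = name_to_size.setdefault(name, size)
            if lookup_size != size:
                raise TypeError(
                    f"Dimension '{name}' of inconsistent size. "
                    f"Got both {size} and {lookup_size}."
                )
    return name_to_size


# Two arguments sharing a "batch" dimension of size 3.
sizes = infer_named_sizes(
    [
        ("x", (3, 4), ("batch", "channels")),
        ("y", (3, 5), ("batch", None)),
    ]
)
# sizes == {"batch": 3, "channels": 4}
```

Using `dict.setdefault` folds the original's `try`/`except KeyError` lookup into a single call. The hard part of `_check_memo`, working out how many dimensions each `...` absorbs, is deliberately left out of this sketch.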