├── .flake8
├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── ci.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── CONTRIBUTING.md
├── FURTHER-DOCUMENTATION.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── pyproject.toml
├── setup.py
├── test
│   ├── conftest.py
│   ├── test_consistency.py
│   ├── test_details.py
│   ├── test_dtype_layout.py
│   ├── test_ellipsis.py
│   ├── test_examples.py
│   ├── test_extensions.py
│   ├── test_misc.py
│   └── test_shape.py
└── torchtyping
    ├── __init__.py
    ├── pytest_plugin.py
    ├── tensor_details.py
    ├── tensor_type.py
    ├── typechecker.py
    └── utils.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 | ignore = W291,W503,E203
4 | per-file-ignores = __init__.py: F401
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: [patrick-kidger]
2 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yaml:
--------------------------------------------------------------------------------
1 | name: torchtyping CI
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - master
7 |
8 | jobs:
9 | formatting_check:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v1
13 | - name: Set up Python 3.9
14 | uses: actions/setup-python@v2
15 | with:
16 | python-version: 3.9
17 | - name: Install black
18 | run: |
19 | python -m pip install --upgrade pip
20 | pip install black flake8
21 | - name: Format with black
22 | run: |
23 | python -m black --check torchtyping/
24 | - name: Lint with flake8
25 | run: |
26 | flake8 torchtyping/
27 |
28 | test_suite:
29 | runs-on: ubuntu-latest
30 | strategy:
31 | matrix:
32 | python-version: [3.7, 3.8, 3.9]
33 | steps:
34 | - uses: actions/checkout@v2
35 | - name: Set up Python ${{ matrix.python-version }}
36 | uses: actions/setup-python@v2
37 | with:
38 | python-version: ${{ matrix.python-version }}
39 | - name: Install dependencies
40 | run: |
41 | python -m pip install --upgrade pip
42 | pip install wheel
43 | pip install -e .
44 | pip install pytest
45 | - name: Test with pytest
46 | run: |
47 | python -m pytest test/
48 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/
2 | **/.ipynb_checkpoints/
3 | *.py[cod]
4 | .idea/
5 | .vs/
6 | build/
7 | dist/
8 | *.egg_info/
9 | *.egg
10 | *.so
11 | *.egg-info/
12 | **/.mypy_cache/
13 | env/
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/ambv/black
3 | rev: stable
4 | hooks:
5 | - id: black
6 | language_version: python3.9
7 |
8 | - repo: https://github.com/pycqa/flake8
9 | rev: 3.9.2
10 | hooks:
11 | - id: flake8
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | **0.1.4**
2 |
3 | Fixed metaclass incompatibility to work with PyTorch 1.9.0.
4 |
5 | **0.1.3**
6 |
7 | `TensorType` now inherits from `torch.Tensor` so that IDE lookup+error messages work as expected.
8 | Updated pre-commit hooks. These were failing for some reason.
9 |
10 | **0.1.2**
11 |
12 | Added support for `str: str` pairs and `None: str` pairs.
13 |
14 | **0.1.1**
15 |
16 | Added support for Python 3.7+. (Down from Python 3.9+.)
17 | Added support for `typing.Any` to indicate an arbitrary-size dimension.
18 |
19 | **0.1.0**
20 |
21 | Initial release.
22 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | Contributions (pull requests) are very welcome.
4 |
5 | First fork the library on GitHub.
6 |
7 | Then clone and install the library in development mode:
8 |
9 | ```bash
10 | git clone https://github.com/your-username-here/torchtyping.git
11 | cd torchtyping
12 | pip install -e .
13 | ```
14 |
15 | Then install the pre-commit hook:
16 |
17 | ```bash
18 | pip install pre-commit
19 | pre-commit install
20 | ```
21 |
22 | These hooks automatically check that the code is formatted with Black and passes flake8's lint checks.
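
If you want to run all of the hooks by hand, rather than waiting for a commit, a quick check is:

```bash
pre-commit run --all-files
```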
23 |
24 | Make your changes. Make sure to include additional tests covering any new functionality.
25 |
26 | Verify the tests all pass:
27 |
28 | ```bash
29 | pip install pytest
30 | pytest
31 | ```
32 |
33 | Push your changes back to your fork of the repository:
34 |
35 | ```bash
36 | git push
37 | ```
38 |
39 | Then open a pull request on GitHub.
40 |
--------------------------------------------------------------------------------
/FURTHER-DOCUMENTATION.md:
--------------------------------------------------------------------------------
1 | # Further documentation
2 |
3 | ## Design goals
4 |
5 | `torchtyping` has a few design goals.
6 |
7 | - **Use type annotations.** There are a few other libraries out there that do this via, essentially, syntactic sugar around `assert` statements. I wanted something neater than that.
8 | - **It should be easy to stop using `torchtyping`.** No, really! If it's not for you then it's easy to remove afterwards. Using `torchtyping` isn't something you should have to bake into your code; just replace `from torchtyping import TensorType` with `TensorType = list` (as a dummy), and your code should still all run.
9 | - **The runtime type checking should be optional.** Runtime checks obviously impose a performance penalty. Integrating with `typeguard` accomplishes this perfectly, in particular through its option to only activate when running tests (my favourite choice).
10 | - **`torchtyping` should be human-readable.** A big part of using type annotations in Python code is to document -- for whoever's reading it -- what is expected. (Particularly valuable on large codebases with several developers.) `torchtyping`'s syntax, and the use of type annotations over some other mechanism, is deliberately chosen to fulfill this goal.
11 |
12 | ## FAQ
13 |
14 | **The runtime checking isn't working!**
15 |
16 | First make sure that you're calling `torchtyping.patch_typeguard`.
17 |
18 | Then make sure that you've enabled `typeguard`, either by decorating the function with `typeguard.typechecked`, or by using `typeguard.importhook.install_import_hook`, or by using the pytest command line flags listed in the main [README](./README.md).
19 |
20 | Make sure that the function you're checking is defined _after_ calling `torchtyping.patch_typeguard`.
21 |
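For example, a minimal sketch of the required ordering:

```python
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # patch first...

@typechecked
def func(x: TensorType["batch"]):  # ...then define the checked function
    pass
```
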
22 | If you have done all of that, then feel free to raise an issue.
23 |
24 | **flake8 is giving spurious warnings.**
25 |
26 | Running flake8 will produce spurious warnings for annotations using strings: `TensorType["batch"]` gives `F821 undefined name 'batch'`.
27 |
28 | You can silence these en masse just by creating a dummy `batch = None` anywhere in the file. (Or by placing `# noqa: F821` on the relevant lines.)
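
For example, a minimal sketch of the dummy-variable trick:

```python
from torchtyping import TensorType

batch = None  # dummy, exists only to silence flake8's F821

def func(x: TensorType["batch"]):
    pass
```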
29 |
30 | **Does this work with `mypy`?**
31 |
32 | Mostly. You'll need to tell `mypy` not to think too hard about `torchtyping`, by annotating its import statements with:
33 |
34 | ```python
35 | from torchtyping import TensorType # type: ignore
36 | ```
37 |
38 | This is because the functionality provided by `torchtyping` is [currently beyond](https://www.python.org/dev/peps/pep-0646/) what `mypy` is capable of representing/understanding. (See also the [links at the end](#other-libraries-and-resources) for further material on this.)
39 |
40 | Additionally `mypy` has a bug which causes it to crash on any file using the `str: int` or `str: ...` notation, as in `TensorType["batch": 10]`. This can be worked around by skipping the file: create a `filename.pyi` file in the same directory. See also the corresponding [mypy issue](https://github.com/python/mypy/issues/10266).
41 |
42 | **Are nested annotations of the form `Blahblah[Moreblah[TensorType[...]]]` supported?**
43 |
44 | Yes.
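
For example, a minimal sketch using `typing.Optional` and `typing.List` as the outer annotations:

```python
from typing import List, Optional

from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()

@typechecked
def func(x: Optional[List[TensorType["batch"]]]):
    pass

func(None)
func([rand(3), rand(3)])
```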
45 |
46 | **Are multiple `...` supported?**
47 |
48 | Yes. For example:
49 |
50 | ```python
51 | def func(x: TensorType["dim1": ..., "dim2": ...],
52 | y: TensorType["dim2": ...]
53 | ) -> TensorType["dim1": ...]:
54 | sum_dims = [-i - 1 for i in range(y.dim())]
55 | return (x * y).sum(dim=sum_dims)
56 | ```
57 |
58 | **`TensorType[float]` corresponds to `float32`, but `torch.rand(2).to(float)` produces `float64`.**
59 |
60 | This is a deliberate asymmetry. `TensorType[float]` corresponds to `torch.get_default_dtype()`, as a convenience, but `.to(float)` always corresponds to `float64`.
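
A quick way to see the asymmetry (assuming the default dtype hasn't been changed):

```python
import torch

torch.get_default_dtype()      # torch.float32 -- what TensorType[float] checks for
torch.rand(2).to(float).dtype  # torch.float64 -- Python's float maps to a double
```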
61 |
62 | **How do I indicate a scalar Tensor, i.e. one with zero dimensions?**
63 |
64 | `TensorType[()]`. Equivalently `TensorType[(), float]`, etc.
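
For example, a minimal sketch:

```python
from torch import tensor
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()

@typechecked
def scalar_checker(x: TensorType[()]):
    pass

scalar_checker(tensor(3.0))      # zero-dimensional: passes
# scalar_checker(tensor([3.0]))  # shape (1,): would raise TypeError
```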
65 |
66 | **Support for TensorFlow/JAX/etc?**
67 |
68 | Not at the moment. The library is called `torchtyping` after all. [There are alternatives for these libraries.](#other-libraries-and-resources)
69 |
70 | ## Custom extensions
71 |
72 | Writing custom extensions is a breeze. Checking extra properties is done by subclassing `torchtyping.TensorDetail`, and passing instances of your `detail` to `torchtyping.TensorType`. For example, this checks that the tensor has an additional attribute `foo`, which must be a string with value `"good-foo"`:
73 |
74 | ```python
75 | from torch import rand, Tensor
76 | from torchtyping import TensorDetail, TensorType
77 | from typeguard import typechecked
78 |
79 | # Write the extension
80 |
81 | class FooDetail(TensorDetail):
82 | def __init__(self, foo):
83 | super().__init__()
84 | self.foo = foo
85 |
86 | def check(self, tensor: Tensor) -> bool:
87 | return hasattr(tensor, "foo") and tensor.foo == self.foo
88 |
89 | # reprs used in error messages when the check fails
90 |
91 | def __repr__(self) -> str:
92 | return f"FooDetail({self.foo})"
93 |
94 | @classmethod
95 | def tensor_repr(cls, tensor: Tensor) -> str:
96 | # Should return a representation of the tensor with respect
97 | # to what this detail is checking
98 | if hasattr(tensor, "foo"):
99 | return f"FooDetail({tensor.foo})"
100 | else:
101 | return ""
102 |
103 | # Test the extension
104 |
105 | @typechecked
106 | def foo_checker(tensor: TensorType[float, FooDetail("good-foo")]):
107 | pass
108 |
109 |
110 | def valid_foo():
111 | x = rand(3)
112 | x.foo = "good-foo"
113 | foo_checker(x)
114 |
115 |
116 | def invalid_foo_one():
117 | x = rand(3)
118 | x.foo = "bad-foo"
119 | foo_checker(x)
120 |
121 |
122 | def invalid_foo_two():
123 | x = rand(2).int()
124 | x.foo = "good-foo"
125 | foo_checker(x)
126 | ```
127 |
128 | As you can see, a `detail` must supply three methods. The first is a `check` method, which takes a tensor and checks whether it satisfies the detail. Second is a `__repr__`, which is used in error messages, to describe the detail that wasn't satisfied. Third is a `tensor_repr`, which is also used in error messages, to describe what property the tensor had (instead of the desired detail).
129 |
130 | ## Other libraries and resources
131 |
132 | `torchtyping` is one amongst a few libraries trying to do this kind of thing. Here are some links for the curious:
133 |
134 | **Discussion**
135 | - [PEP 646](https://www.python.org/dev/peps/pep-0646/) proposes variadic generics. These are a tool needed for static checkers (like `mypy`) to be able to do the kind of shape checking that `torchtyping` does dynamically. However at time of writing Python doesn't yet support this.
136 | - The [Ideas for array shape typing in Python](https://docs.google.com/document/d/1vpMse4c6DrWH5rq2tQSx3qwP_m_0lyn-Ij4WHqQqRHY/) document is a good overview of some of the ways to type check arrays.
137 |
138 | **Other libraries**
139 | - [TensorAnnotations](https://github.com/deepmind/tensor_annotations) is a library for statically checking JAX or TensorFlow tensor shapes. (It also has some good links to other discussions around this topic.)
140 | - [`nptyping`](https://github.com/ramonhagenaars/nptyping) does something very similar to `torchtyping`, but for numpy.
141 | - [`tsanley`](https://github.com/ofnote/tsanley)/[`tsalib`](https://github.com/ofnote/tsalib) is an alternative dynamic shape checker, but requires a bit of extra setup.
142 | - [TensorGuard](https://github.com/Michedev/tensorguard) is an alternative, using extra function calls rather than type hints.
143 |
144 | ## More examples
145 |
146 | **Shape checking:**
147 |
148 | ```python
149 | def func(x: TensorType["batch", 5],
150 | y: TensorType["batch", 3]):
151 | # x has shape (batch, 5)
152 | # y has shape (batch, 3)
153 | # batch dimension is the same for both
154 |
155 | def func(x: TensorType[2, -1, -1]):
156 | # x has shape (2, any_one, any_two)
157 | # -1 is a special value to represent any size.
158 | ```
159 |
160 | **Checking arbitrary numbers of batch dimensions:**
161 |
162 | ```python
163 | def func(x: TensorType[..., 2, 3]):
164 | # x has shape (..., 2, 3)
165 |
166 | def func(x: TensorType[..., 2, "channels"],
167 | y: TensorType[..., "channels"]):
168 | # x has shape (..., 2, channels)
169 | # y has shape (..., channels)
170 | # "channels" is checked to be the same size for both arguments.
171 |
172 | def func(x: TensorType["batch": ..., "channels_x"],
173 | y: TensorType["batch": ..., "channels_y"]):
174 | # x has shape (..., channels_x)
175 | # y has shape (..., channels_y)
176 | # the ... batch dimensions are checked to be of the same size.
177 | ```
178 |
179 | **Return value checking:**
180 |
181 | ```python
182 | def func(x: TensorType[3, 4]) -> TensorType[()]:
183 | # x has shape (3, 4)
184 | # return has shape ()
185 | ```
186 |
187 | **Dtype checking:**
188 |
189 | ```python
190 | def func(x: TensorType[float]):
191 | # x has dtype torch.float32
192 | ```
193 |
194 | **Checking shape and dtype at the same time:**
195 |
196 | ```python
197 | def func(x: TensorType[3, 4, float]):
198 | # x has shape (3, 4) and has dtype torch.float32
199 | ```
200 |
201 | **Checking names for dimensions as per [named tensors](https://pytorch.org/docs/stable/named_tensor.html):**
202 |
203 | ```python
204 | def func(x: TensorType["a": 3, "b", is_named]):
205 | # x has names ("a", "b")
206 | # x has shape (3, Any)
207 | ```
208 |
209 | **Checking layouts:**
210 |
211 | ```python
212 | def func(x: TensorType[torch.sparse_coo]):
213 | # x is a sparse tensor with layout sparse_coo
214 | ```
215 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include FURTHER-DOCUMENTATION.md
3 | prune test
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Please use jaxtyping instead
2 |
3 | *Welcome! For new projects I now **strongly** recommend using my newer [jaxtyping](https://github.com/google/jaxtyping) project instead. It supports PyTorch, doesn't actually depend on JAX, and unlike TorchTyping it is compatible with static type checkers. The 'jax' in the name is now historical!*
4 |
5 |
6 |
7 |
8 |
9 | The original torchtyping README is as follows.
10 |
11 | ---
12 |
13 | # torchtyping
14 | Type annotations for a tensor's shape, dtype, names, ...
15 |
16 | Turn this:
17 | ```python
18 | def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
19 | # x has shape (batch, x_channels)
20 | # y has shape (batch, y_channels)
21 | # return has shape (batch, x_channels, y_channels)
22 |
23 | return x.unsqueeze(-1) * y.unsqueeze(-2)
24 | ```
25 | into this:
26 | ```python
27 | def batch_outer_product(x: TensorType["batch", "x_channels"],
28 | y: TensorType["batch", "y_channels"]
29 | ) -> TensorType["batch", "x_channels", "y_channels"]:
30 |
31 | return x.unsqueeze(-1) * y.unsqueeze(-2)
32 | ```
33 | **with programmatic checking that the shape (dtype, ...) specification is met.**
34 |
35 | Bye-bye bugs! Say hello to enforced, clear documentation of your code.
36 |
37 | If (like me) you find yourself littering your code with comments like `# x has shape (batch, hidden_state)` or statements like `assert x.shape == y.shape`, just to keep track of what shape everything is, **then this is for you.**
38 |
39 | ---
40 |
41 | ## Installation
42 |
43 | ```bash
44 | pip install torchtyping
45 | ```
46 |
47 | Requires Python >=3.7 and PyTorch >=1.7.0.
48 |
49 | If using [`typeguard`](https://github.com/agronholm/typeguard) then it must be a version <3.0.0.
50 |
51 | ## Usage
52 |
53 | `torchtyping` allows for type annotating:
54 |
55 | - **shape**: size, number of dimensions;
56 | - **dtype** (float, integer, etc.);
57 | - **layout** (dense, sparse);
58 | - **names** of dimensions as per [named tensors](https://pytorch.org/docs/stable/named_tensor.html);
59 | - **arbitrary number of batch dimensions** with `...`;
60 | - **...plus anything else you like**, as `torchtyping` is highly extensible.
61 |
62 | If [`typeguard`](https://github.com/agronholm/typeguard) is (optionally) installed then **at runtime the types can be checked** to ensure that the tensors really are of the advertised shape, dtype, etc.
63 |
64 | ```python
65 | # EXAMPLE
66 |
67 | from torch import rand
68 | from torchtyping import TensorType, patch_typeguard
69 | from typeguard import typechecked
70 |
71 | patch_typeguard() # use before @typechecked
72 |
73 | @typechecked
74 | def func(x: TensorType["batch"],
75 | y: TensorType["batch"]) -> TensorType["batch"]:
76 | return x + y
77 |
78 | func(rand(3), rand(3)) # works
79 | func(rand(3), rand(1))
80 | # TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3.
81 | ```
82 |
83 | `typeguard` also has an import hook that can be used to automatically test an entire module, without needing to manually add `@typeguard.typechecked` decorators.
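
For example, a minimal sketch of this setup, where `my_package` is a placeholder for your own package:

```python
from torchtyping import patch_typeguard
from typeguard.importhook import install_import_hook

patch_typeguard()
install_import_hook("my_package")

import my_package  # annotated functions in my_package are now checked
```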
84 |
85 | If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes. If you're not already using `typeguard` for your regular Python programming, then strongly consider using it. It's a great way to squash bugs. Both `typeguard` and `torchtyping` also integrate with `pytest`, so if you're concerned about any performance penalty then they can be enabled during tests only.
86 |
87 | ## API
88 |
89 | ```python
90 | torchtyping.TensorType[shape, dtype, layout, details]
91 | ```
92 |
93 | The core of the library.
94 |
95 | Each of `shape`, `dtype`, `layout`, `details` are optional.
96 |
97 | - The `shape` argument can be any of:
98 | - An `int`: the dimension must be of exactly this size. If it is `-1` then any size is allowed.
99 | - A `str`: the size of the dimension passed at runtime will be bound to this name, and all tensors checked that the sizes are consistent.
100 | - A `...`: An arbitrary number of dimensions, each of any size.
101 | - A `str: int` pair (technically it's a slice), combining both `str` and `int` behaviour. (Just a `str` on its own is equivalent to `str: -1`.)
102 | - A `str: str` pair, in which case the size of the dimension passed at runtime will be bound to _both_ names, and all dimensions with either name must have the same size. (Some people like to use this as a way to associate multiple names with a dimension, for extra documentation purposes.)
103 | - A `str: ...` pair, in which case the multiple dimensions corresponding to `...` will be bound to the name specified by `str`, and again checked for consistency between arguments.
104 | - `None`, which when used in conjunction with `is_named` below, indicates a dimension that must _not_ have a name in the sense of [named tensors](https://pytorch.org/docs/stable/named_tensor.html).
105 | - A `None: int` pair, combining both `None` and `int` behaviour. (Just a `None` on its own is equivalent to `None: -1`.)
106 | - A `None: str` pair, combining both `None` and `str` behaviour. (That is, it must not have a named dimension, but must be of a size consistent with other uses of the string.)
107 | - A `typing.Any`: Any size is allowed for this dimension (equivalent to `-1`).
108 | - Any tuple of the above. For example, `TensorType["batch": ..., "length": 10, "channels", -1]`. If you just want to specify the number of dimensions then use for example `TensorType[-1, -1, -1]` for a three-dimensional tensor.
109 | - The `dtype` argument can be any of:
110 | - `torch.float32`, `torch.float64` etc.
111 | - `int`, `bool`, `float`, which are converted to their corresponding PyTorch types. `float` is specifically interpreted as `torch.get_default_dtype()`, which is usually `float32`.
112 | - The `layout` argument can be either `torch.strided` or `torch.sparse_coo`, for dense and sparse tensors respectively.
113 | - The `details` argument offers a way to pass an arbitrary number of additional flags that customise and extend `torchtyping`. Two flags are built-in by default. `torchtyping.is_named` causes the [names of tensor dimensions](https://pytorch.org/docs/stable/named_tensor.html) to be checked, and `torchtyping.is_float` can be used to check that arbitrary floating point types are passed in. (Rather than just a specific one as with e.g. `TensorType[torch.float32]`.) For discussion on how to customise `torchtyping` with your own `details`, see the [further documentation](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md#custom-extensions).
114 | - Check multiple things at once by just putting them all together inside a single `[]`. For example `TensorType["batch": ..., "length", "channels", float, is_named]`.
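
For instance, a small sketch combining several of the options above (the dimension names here are arbitrary):

```python
from typing import Any

from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()

@typechecked
def func(x: TensorType["batch": ..., "length": 10, "channels", Any],
         y: TensorType["channels": "embedding"]):
    # x: any number of batch dimensions, then a dimension of exactly 10,
    #    then a "channels" dimension, then a dimension of any size.
    # y: a single dimension whose size is bound to both "channels" (shared
    #    with x) and "embedding".
    pass

func(rand(2, 10, 5, 7), rand(5))  # passes: "channels" is 5 in both
```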
115 |
116 | ```python
117 | torchtyping.patch_typeguard()
118 | ```
119 |
120 | `torchtyping` integrates with `typeguard` to perform runtime type checking. `torchtyping.patch_typeguard()` should be called at the global level, and will patch `typeguard` to check `TensorType`s.
121 |
122 | This function is safe to run multiple times. (It does nothing after the first run.)
123 |
124 | - If using `@typeguard.typechecked`, then `torchtyping.patch_typeguard()` should be called any time before using `@typeguard.typechecked`. For example you could call it at the start of each file using `torchtyping`.
125 | - If using `typeguard.importhook.install_import_hook`, then `torchtyping.patch_typeguard()` should be called any time before defining the functions you want checked. For example you could call `torchtyping.patch_typeguard()` just once, at the same time as the `typeguard` import hook. (The order of the hook and the patch doesn't matter.)
126 | - If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes.
127 |
128 | ```bash
129 | pytest --torchtyping-patch-typeguard
130 | ```
131 |
132 | `torchtyping` offers a `pytest` plugin to automatically run `torchtyping.patch_typeguard()` before your tests. `pytest` will automatically discover the plugin; you just need to pass the `--torchtyping-patch-typeguard` flag to enable it. Packages can then be passed to `typeguard` as normal: either by using `@typeguard.typechecked`, `typeguard`'s import hook, or the `pytest` flag `--typeguard-packages="your_package_here"`.
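
For example, a typical combined invocation (with `your_package_here` as a placeholder):

```bash
pytest --torchtyping-patch-typeguard --typeguard-packages="your_package_here"
```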
133 |
134 | ## Further documentation
135 |
136 | See the [further documentation](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md) for:
137 |
138 | - FAQ;
139 | - `flake8` and `mypy` compatibility;
140 | - How to write custom extensions to `torchtyping`;
141 | - Resources and links to other libraries and materials on this topic;
142 | - More examples.
143 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 88
3 |
4 | [tool.pytest.ini_options]
5 | addopts = "--torchtyping-patch-typeguard"
6 | # No running typeguard unfortunately, because we define a pytest import hook and that means torchtyping gets imported before typeguard gets a chance to run.
7 | # Ironic.
8 | #"--typeguard-packages=torchtyping"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import setuptools
4 | import sys
5 |
6 | here = os.path.realpath(os.path.dirname(__file__))
7 |
8 |
9 | name = "torchtyping"
10 |
11 | # for simplicity we actually store the version in the __version__ attribute in the
12 | # source
13 | with open(os.path.join(here, name, "__init__.py")) as f:
14 | meta_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M)
15 | if meta_match:
16 | version = meta_match.group(1)
17 | else:
18 | raise RuntimeError("Unable to find __version__ string.")
19 |
20 | author = "Patrick Kidger"
21 |
22 | author_email = "contact@kidger.site"
23 |
24 | description = "Runtime type annotations for the shape, dtype etc. of PyTorch Tensors. "
25 |
26 | with open(os.path.join(here, "README.md"), "r", encoding="utf-8") as f:
27 | readme = f.read()
28 |
29 | url = "https://github.com/patrick-kidger/torchtyping"
30 |
31 | license = "Apache-2.0"
32 |
33 | classifiers = [
34 | "Development Status :: 3 - Alpha",
35 | "Intended Audience :: Developers",
36 | "Intended Audience :: Financial and Insurance Industry",
37 | "Intended Audience :: Information Technology",
38 | "Intended Audience :: Science/Research",
39 | "License :: OSI Approved :: Apache Software License",
40 | "Natural Language :: English",
41 | "Programming Language :: Python :: 3",
42 | "Programming Language :: Python :: 3.7",
43 | "Programming Language :: Python :: 3.8",
44 | "Programming Language :: Python :: 3.9",
45 | "Programming Language :: Python :: Implementation :: CPython",
46 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
47 | "Topic :: Scientific/Engineering :: Information Analysis",
48 | "Topic :: Scientific/Engineering :: Mathematics",
49 | "Framework :: Pytest",
50 | ]
51 |
52 | user_python_version = sys.version_info
53 |
54 | python_requires = ">=3.7.0"
55 |
56 | install_requires = ["torch>=1.7.0", "typeguard>=2.11.1,<3"]
57 |
58 | if user_python_version < (3, 9):
59 | install_requires += ["typing_extensions==3.7.4.3"]
60 |
61 | entry_points = dict(pytest11=["torchtyping = torchtyping.pytest_plugin"])
62 |
63 | setuptools.setup(
64 | name=name,
65 | version=version,
66 | author=author,
67 | author_email=author_email,
68 | maintainer=author,
69 | maintainer_email=author_email,
70 | description=description,
71 | long_description=readme,
72 | long_description_content_type="text/markdown",
73 | url=url,
74 | license=license,
75 | classifiers=classifiers,
76 | zip_safe=False,
77 | python_requires=python_requires,
78 | install_requires=install_requires,
79 | entry_points=entry_points,
80 | packages=[name],
81 | )
82 |
--------------------------------------------------------------------------------
/test/conftest.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import warnings
3 |
4 |
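# Using named tensors emits an "experimental feature" UserWarning the first
# time they are used; trigger that warning here, suppressed, so that it
# doesn't show up later in the test output.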
5 | with warnings.catch_warnings():
6 | warnings.simplefilter("ignore")
7 | torch.rand(2, names=("a",))
8 |
--------------------------------------------------------------------------------
/test/test_consistency.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from torch import rand
3 | from torchtyping import TensorType
4 | from typeguard import typechecked
5 |
6 |
7 | x = y = None
8 |
9 |
10 | def test_single():
11 | @typechecked
12 | def func1(x: TensorType["x"], y: TensorType["x"]):
13 | pass
14 |
15 | @typechecked
16 | def func2(x: TensorType["x"], y: TensorType["x"]) -> TensorType["x"]:
17 | return x + y
18 |
19 | @typechecked
20 | def func3(x: TensorType["x"], y: TensorType["x"]) -> TensorType["x", "x"]:
21 | return x + y
22 |
23 | @typechecked
24 | def func4(x: TensorType["x"], y: TensorType["x"]) -> TensorType["x", "x"]:
25 | return x.unsqueeze(0) + y.unsqueeze(1)
26 |
27 | @typechecked
28 | def func5(x: TensorType["x"], y: TensorType["x"]) -> TensorType["x", "y"]:
29 | return x
30 |
31 | @typechecked
32 | def func6(x: TensorType["x"], y: TensorType["x"]) -> TensorType["y", "x"]:
33 | return x
34 |
35 | @typechecked
36 | def func7(x: TensorType["x"]) -> TensorType["x"]:
37 | assert x.shape != (1,)
38 | return rand((1,))
39 |
40 | func1(rand(2), rand(2))
41 | func2(rand(2), rand(2))
42 | with pytest.raises(TypeError):
43 | func3(rand(2), rand(2))
44 | func4(rand(2), rand(2))
45 | with pytest.raises(TypeError):
46 | func5(rand(2), rand(2))
47 | with pytest.raises(TypeError):
48 | func6(rand(2), rand(2))
49 | with pytest.raises(TypeError):
50 | func7(rand(3))
51 |
52 |
53 | def test_multiple():
54 | # Fun fact: this "wrong" func0 is actually a mistype of func1, which torchtyping
55 | # caught for me when I ran the tests!
56 | @typechecked
57 | def func0(x: TensorType["x"], y: TensorType["y"]) -> TensorType["x", "y"]:
58 | return x.unsqueeze(0) + y.unsqueeze(1)
59 |
60 | @typechecked
61 | def func1(x: TensorType["x"], y: TensorType["y"]) -> TensorType["x", "y"]:
62 | return x.unsqueeze(1) + y.unsqueeze(0)
63 |
64 | @typechecked
65 | def func2(x: TensorType["x", "x"]):
66 | pass
67 |
68 | @typechecked
69 | def func3(x: TensorType["x", "x", "x"]):
70 | pass
71 |
72 | @typechecked
73 | def func4(x: TensorType["x"], y: TensorType["x", "y"]):
74 | pass
75 |
76 | @typechecked
77 | def func5(x: TensorType["x", "y"], y: TensorType["y", "x"]):
78 | pass
79 |
80 | @typechecked
81 | def func6(x: TensorType["x"], y: TensorType["y"]) -> TensorType["x", "y"]:
82 | assert not (x.shape == (2,) and y.shape == (3,))
83 | return rand(2, 3)
84 |
85 | func0(rand(2), rand(2)) # can't catch this
86 | with pytest.raises(TypeError):
87 | func0(rand(2), rand(3))
88 | with pytest.raises(TypeError):
89 | func0(rand(10), rand(0))
90 |
91 | func1(rand(2), rand(2))
92 | func1(rand(2), rand(3))
93 | func1(rand(10), rand(0))
94 |
95 | func2(rand(0, 0))
96 | func2(rand(2, 2))
97 | func2(rand(9, 9))
98 | with pytest.raises(TypeError):
99 | func2(rand(0, 4))
100 | func2(rand(1, 4))
101 | func2(rand(3, 4))
102 |
103 | func3(rand(0, 0, 0))
104 | func3(rand(2, 2, 2))
105 | func3(rand(9, 9, 9))
106 | with pytest.raises(TypeError):
107 | func3(rand(0, 4, 4))
108 | func3(rand(1, 4, 4))
109 | func3(rand(3, 3, 4))
110 |
111 | func4(rand(3), rand(3, 4))
112 | with pytest.raises(TypeError):
113 | func4(rand(3), rand(4, 3))
114 |
115 | func5(rand(2, 3), rand(3, 2))
116 | func5(rand(0, 5), rand(5, 0))
117 | func5(rand(2, 2), rand(2, 2))
118 | with pytest.raises(TypeError):
119 | func5(rand(2, 3), rand(2, 3))
120 | func5(rand(2, 3), rand(2, 2))
121 |
122 | with pytest.raises(TypeError):
123 | func6(rand(5), rand(3))
124 |
--------------------------------------------------------------------------------
/test/test_details.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torchtyping import TensorType, is_float, is_named
4 | import typeguard
5 |
6 |
7 | dim1 = dim2 = dim3 = None
8 |
9 |
10 | def test_float_tensor():
11 | @typeguard.typechecked
12 | def func1(x: TensorType[is_float]):
13 | pass
14 |
15 | @typeguard.typechecked
16 | def func2(x: TensorType[2, 2, is_float]):
17 | pass
18 |
19 | @typeguard.typechecked
20 | def func3(x: TensorType[float, is_float]):
21 | pass
22 |
23 | @typeguard.typechecked
24 | def func4(x: TensorType[bool, is_float]):
25 | pass
26 |
27 | @typeguard.typechecked
28 | def func5(x: TensorType["dim1":2, 2, float, is_float]):
29 | pass
30 |
31 | @typeguard.typechecked
32 | def func6(x: TensorType[2, "dim2":2, torch.sparse_coo, is_float]):
33 | pass
34 |
35 | x = torch.rand(2, 2)
36 | y = torch.rand(1)
37 | z = torch.tensor([[0, 1], [2, 3]])
38 | w = torch.rand(4).to_sparse()
39 | w1 = torch.rand(2, 2).to_sparse()
40 | w2 = torch.tensor([[0, 1], [2, 3]]).to_sparse()
41 |
42 | func1(x)
43 | func1(y)
44 | with pytest.raises(TypeError):
45 | func1(z)
46 | func1(w)
47 | func1(w1)
48 | with pytest.raises(TypeError):
49 | func1(w2)
50 |
51 | func2(x)
52 | with pytest.raises(TypeError):
53 | func2(y)
54 | with pytest.raises(TypeError):
55 | func2(z)
56 | with pytest.raises(TypeError):
57 | func2(w)
58 | func2(w1)
59 | with pytest.raises(TypeError):
60 | func2(w2)
61 |
62 | func3(x)
63 | func3(y)
64 | with pytest.raises(TypeError):
65 | func3(z)
66 | func3(w)
67 | func3(w1)
68 | with pytest.raises(TypeError):
69 | func3(w2)
70 |
71 | with pytest.raises(TypeError):
72 | func4(x)
73 | with pytest.raises(TypeError):
74 | func4(y)
75 | with pytest.raises(TypeError):
76 | func4(z)
77 | with pytest.raises(TypeError):
78 | func4(w)
79 | with pytest.raises(TypeError):
80 | func4(w1)
81 | with pytest.raises(TypeError):
82 | func4(w2)
83 |
84 | func5(x)
85 | with pytest.raises(TypeError):
86 | func5(y)
87 | with pytest.raises(TypeError):
88 | func5(z)
89 | with pytest.raises(TypeError):
90 | func5(w)
91 | func5(w1)
92 | with pytest.raises(TypeError):
93 | func5(w2)
94 |
95 | with pytest.raises(TypeError):
96 | func6(x)
97 | with pytest.raises(TypeError):
98 | func6(y)
99 | with pytest.raises(TypeError):
100 | func6(z)
101 | with pytest.raises(TypeError):
102 | func6(w)
103 | func6(w1)
104 | with pytest.raises(TypeError):
105 | func6(w2)
106 |
107 |
108 | def test_named_tensor():
109 | @typeguard.typechecked
110 | def _named_a_dim_checker(x: TensorType["dim1", is_named]):
111 | pass
112 |
113 | @typeguard.typechecked
114 | def _named_ab_dim_checker(x: TensorType["dim1", "dim2", is_named]):
115 | pass
116 |
117 | @typeguard.typechecked
118 | def _named_abc_dim_checker(x: TensorType["dim1", "dim2", "dim3", is_named]):
119 | pass
120 |
121 | @typeguard.typechecked
122 | def _named_cb_dim_checker(x: TensorType["dim3", "dim2", is_named]):
123 | pass
124 |
125 | @typeguard.typechecked
126 | def _named_am1_dim_checker(x: TensorType["dim1", -1, is_named]):
127 | pass
128 |
129 | @typeguard.typechecked
130 | def _named_m1b_dim_checker(x: TensorType[-1, "dim2", is_named]):
131 | pass
132 |
133 | @typeguard.typechecked
134 | def _named_abm1_dim_checker(x: TensorType["dim1", "dim2", -1, is_named]):
135 | pass
136 |
137 | @typeguard.typechecked
138 | def _named_m1bm1_dim_checker(x: TensorType[-1, "dim2", -1, is_named]):
139 | pass
140 |
141 | x = torch.rand(3, 4)
142 | named_x = torch.rand(3, 4, names=("dim1", "dim2"))
143 |
144 | with pytest.raises(TypeError):
145 | _named_ab_dim_checker(x)
146 | with pytest.raises(TypeError):
147 | _named_cb_dim_checker(x)
148 | with pytest.raises(TypeError):
149 | _named_am1_dim_checker(x)
150 | with pytest.raises(TypeError):
151 | _named_m1b_dim_checker(x)
152 | with pytest.raises(TypeError):
153 | _named_a_dim_checker(x)
154 | with pytest.raises(TypeError):
155 | _named_abc_dim_checker(x)
156 | with pytest.raises(TypeError):
157 | _named_abm1_dim_checker(x)
158 | with pytest.raises(TypeError):
159 | _named_m1bm1_dim_checker(x)
160 |
161 | _named_ab_dim_checker(named_x)
162 | _named_am1_dim_checker(named_x)
163 | _named_m1b_dim_checker(named_x)
164 | with pytest.raises(TypeError):
165 | _named_a_dim_checker(named_x)
166 | with pytest.raises(TypeError):
167 | _named_abc_dim_checker(named_x)
168 | with pytest.raises(TypeError):
169 | _named_cb_dim_checker(named_x)
170 | with pytest.raises(TypeError):
171 | _named_abm1_dim_checker(named_x)
172 | with pytest.raises(TypeError):
173 | _named_m1bm1_dim_checker(named_x)
174 |
175 |
176 | def test_named_float_tensor():
177 | @typeguard.typechecked
178 | def func(x: TensorType["dim1", "dim2":3, is_float, is_named]):
179 | pass
180 |
181 | x = torch.rand(2, 3, names=("dim1", "dim2"))
182 | y = torch.rand(2, 2, names=("dim1", "dim2"))
183 | z = torch.rand(2, 2, names=("dim1", "dim3"))
184 | w = torch.rand(2, 3)
185 | w1 = torch.rand(2, 2, names=("dim1", None))
186 | w2 = torch.rand(2, 3, names=("dim1", "dim2")).int()
187 |
188 | func(x)
189 | with pytest.raises(TypeError):
190 | func(y)
191 | with pytest.raises(TypeError):
192 | func(z)
193 | with pytest.raises(TypeError):
194 | func(w)
195 | with pytest.raises(TypeError):
196 | func(w1)
197 | with pytest.raises(TypeError):
198 | func(w2)
199 |
200 |
201 | def test_none_names():
202 | @typeguard.typechecked
203 | def func_unnamed1(x: TensorType[None:4]):
204 | pass
205 |
206 | @typeguard.typechecked
207 | def func_unnamed2(x: TensorType[None:4, "dim1"]):
208 | pass
209 |
210 | @typeguard.typechecked
211 | def func_unnamed3(x: TensorType[None:4, "dim1"], y: TensorType["dim1", None]):
212 | pass
213 |
214 | @typeguard.typechecked
215 | def func_named1(x: TensorType[None:4, is_named]):
216 | pass
217 |
218 | @typeguard.typechecked
219 | def func_named2(x: TensorType[None:4, "dim1", is_named]):
220 | pass
221 |
222 | func_unnamed1(torch.rand(4))
223 | func_unnamed1(torch.rand(4, names=(None,)))
224 | func_unnamed1(torch.rand(4, names=("not_none",)))
225 | with pytest.raises(TypeError):
226 | func_unnamed1(torch.rand(5))
227 | with pytest.raises(TypeError):
228 | func_unnamed1(torch.rand(5, names=(None,)))
229 | with pytest.raises(TypeError):
230 | func_unnamed1(torch.rand(5, names=("not_none",)))
231 | with pytest.raises(TypeError):
232 | func_unnamed1(torch.rand(2, 3))
233 | with pytest.raises(TypeError):
234 | func_unnamed1(torch.rand((), names=()))
235 | with pytest.raises(TypeError):
236 | func_unnamed1(torch.rand(1, 6, 7, 8, names=("not_none", None, None, None)))
237 |
238 | func_unnamed2(torch.rand(4, 5))
239 | func_unnamed2(torch.rand(4, 5, names=(None, None)))
240 | func_unnamed2(torch.rand(4, 5, names=("dim1", None)))
241 | func_unnamed2(torch.rand(4, 5, names=(None, "dim1")))
242 |
243 | func_unnamed3(torch.rand(4, 5), torch.rand(5, 3))
244 | func_unnamed3(torch.rand(4, 5, names=(None, None)), torch.rand(5, 3))
245 | func_unnamed3(torch.rand(4, 5), torch.rand(5, 3, names=("dim1", None)))
246 | func_unnamed3(
247 | torch.rand(4, 5, names=("another_name", "some_name")),
248 | torch.rand(5, 3, names=(None, "dim1")),
249 | )
250 | with pytest.raises(TypeError):
251 | func_unnamed3(torch.rand(4, 5), torch.rand(3, 3))
252 |
253 | func_named1(torch.rand(4))
254 | func_named1(torch.rand(4, names=(None,)))
255 | with pytest.raises(TypeError):
256 | func_named1(torch.rand(5, names=(None,)))
257 | with pytest.raises(TypeError):
258 | func_named1(torch.rand(4, names=("dim1",)))
259 | with pytest.raises(TypeError):
260 | func_named1(torch.rand(5, names=("dim1",)))
261 |
262 | func_named2(torch.rand(4, 5, names=(None, "dim1")))
263 | with pytest.raises(TypeError):
264 | func_named2(torch.rand(4, 5))
265 | with pytest.raises(TypeError):
266 | func_named2(torch.rand(4, 5, names=("another_dim", "dim1")))
267 | with pytest.raises(TypeError):
268 | func_named2(torch.rand(4, 5, names=(None, "dim2")))
269 | with pytest.raises(TypeError):
270 | func_named2(torch.rand(4, 5, names=(None, None)))
271 |
272 |
273 | def test_named_ellipsis():
274 | @typeguard.typechecked
275 | def func(x: TensorType["dim1":..., "dim2", is_named]):
276 | pass
277 |
278 | func(torch.rand(3, 4, names=(None, "dim2")))
279 | func(torch.rand(3, 4, names=("another_dim", "dim2")))
280 | with pytest.raises(TypeError):
281 | func(torch.rand(3, 4))
282 | with pytest.raises(TypeError):
283 | func(torch.rand(3, 4, names=(None, None)))
284 | with pytest.raises(TypeError):
285 | func(torch.rand(3, 4, names=("dim2", None)))
286 |
--------------------------------------------------------------------------------
/test/test_dtype_layout.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torchtyping import TensorType
4 | import typeguard
5 |
6 | from typing import Union
7 |
8 |
9 | @typeguard.typechecked
10 | def _float_checker(x: TensorType[float]):
11 | pass
12 |
13 |
14 | @typeguard.typechecked
15 | def _int_checker(x: TensorType[int]):
16 | pass
17 |
18 |
19 | @typeguard.typechecked
20 | def _union_int_float_checker(x: Union[TensorType[float], TensorType[int]]):
21 | pass
22 |
23 |
24 | def test_float_dtype():
25 | x = torch.rand(2)
26 | _float_checker(x)
27 | _union_int_float_checker(x)
28 | with pytest.raises(TypeError):
29 | _int_checker(x)
30 |
31 |
32 | def test_int_dtype():
33 | x = torch.tensor(2)
34 | _int_checker(x)
35 | _union_int_float_checker(x)
36 | with pytest.raises(TypeError):
37 | _float_checker(x)
38 |
39 |
40 | @typeguard.typechecked
41 | def _strided_checker(x: TensorType[torch.strided]):
42 | pass
43 |
44 |
45 | @typeguard.typechecked
46 | def _sparse_coo_checker(x: TensorType[torch.sparse_coo]):
47 | pass
48 |
49 |
50 | def test_strided_layout():
51 | x = torch.rand(2)
52 | _strided_checker(x)
53 | with pytest.raises(TypeError):
54 | _sparse_coo_checker(x)
55 |
56 |
57 | def test_sparse_coo_layout():
58 | x = torch.rand(2).to_sparse()
59 | _sparse_coo_checker(x)
60 | with pytest.raises(TypeError):
61 | _strided_checker(x)
62 |
--------------------------------------------------------------------------------
/test/test_ellipsis.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torchtyping import TensorType
4 | from typeguard import typechecked
5 |
6 |
7 | dim1 = dim2 = dim3 = channel = None
8 |
9 |
10 | def test_basic_ellipsis():
11 | @typechecked
12 | def func(x: TensorType["dim1":..., "dim2", "dim3"], y: TensorType["dim2", "dim3"]):
13 | pass
14 |
15 | func(torch.rand(2, 2), torch.rand(2, 2))
16 | func(torch.rand(2, 3), torch.rand(2, 3))
17 | func(torch.rand(1, 4, 2, 3), torch.rand(2, 3))
18 | func(torch.rand(2, 3, 2, 3), torch.rand(2, 3))
19 |
20 | with pytest.raises(TypeError):
21 | func(torch.rand(2, 3), torch.rand(3, 2))
22 | with pytest.raises(TypeError):
23 | func(torch.rand(1, 4, 2, 3), torch.rand(3, 2))
24 | with pytest.raises(TypeError):
25 | func(torch.rand(2, 3, 2, 3), torch.rand(3, 2))
26 |
27 |
28 | def test_zero_size_ellipsis1():
29 | @typechecked
30 | def func(
31 | x: TensorType["dim1":..., "dim2", "dim3"],
32 | y: TensorType["dim1":..., "dim2", "dim3"],
33 | ):
34 | pass
35 |
36 | with pytest.raises(TypeError):
37 | func(torch.rand(2, 2), torch.rand(2, 2, 2))
38 | with pytest.raises(TypeError):
39 | func(torch.rand(2, 2, 2), torch.rand(2, 2))
40 |
41 |
42 | def test_zero_size_ellipsis2():
43 | @typechecked
44 | def func(
45 | x: TensorType["dim1":..., "dim2", "dim3"],
46 | y: TensorType["dim1":..., "dim2", "dim3"],
47 | ):
48 | pass
49 |
50 | with pytest.raises(TypeError):
51 | func(torch.rand(2, 3), torch.rand(2, 2, 3))
52 | with pytest.raises(TypeError):
53 | func(torch.rand(2, 2, 3), torch.rand(2, 3))
54 | with pytest.raises(TypeError):
55 | func(torch.rand(2, 2), torch.rand(2, 2, 2))
56 | with pytest.raises(TypeError):
57 | func(torch.rand(2, 2, 2), torch.rand(2, 2))
58 |
59 |
60 | def test_multiple_ellipsis1():
61 | @typechecked
62 | def func(
63 | x: TensorType["dim1":..., "dim2":...], y: TensorType["dim2":...]
64 | ) -> TensorType["dim1":...]:
65 | sum_dims = [-i - 1 for i in range(y.dim())]
66 | return (x * y).sum(dim=sum_dims)
67 |
68 | func(torch.rand(1, 2), torch.rand(2))
69 | func(torch.rand(3, 4, 5, 9), torch.rand(5, 9))
70 | func(torch.rand(3, 4, 11, 5, 9), torch.rand(5, 9))
71 | func(torch.rand(3, 4, 11, 5, 9), torch.rand(11, 5, 9))
72 | with pytest.raises(TypeError):
73 | func(torch.rand(1), torch.rand(2))
74 | with pytest.raises(TypeError):
75 | func(torch.rand(1, 3, 5), torch.rand(3))
76 | with pytest.raises(TypeError):
77 | func(torch.rand(1, 4), torch.rand(1, 1, 4))
78 |
79 |
80 | def test_multiple_ellipsis2():
81 | @typechecked
82 | def func(
83 | x: TensorType["dim2":...], y: TensorType["dim1":..., "dim2":...]
84 | ) -> TensorType["dim1":...]:
85 | sum_dims = [-i - 1 for i in range(x.dim())]
86 | return (x * y).sum(dim=sum_dims)
87 |
88 | with pytest.raises(TypeError):
89 | func(torch.rand(1, 1, 4), torch.rand(1, 4))
90 |
91 |
92 | def test_multiple_ellipsis3():
93 | @typechecked
94 | def func(
95 | x: TensorType["dim1":..., "dim2":..., "dim3":...],
96 | y: TensorType["dim2":..., "dim3":...],
97 | z: TensorType["dim2":...],
98 | ) -> TensorType["dim1":...]:
99 | num2 = y.dim() - z.dim()
100 | num3 = z.dim()
101 | for _ in range(num2):
102 | z = z.unsqueeze(-1)
103 | y = y * z
104 | x = x + y
105 | for _ in range(num2 + num3):
106 | x = x.sum(dim=-1)
107 | return x
108 |
109 | func(torch.rand(1, 2, 3), torch.rand(2, 3), torch.rand(2))
110 | func(torch.rand(3, 5, 6, 7, 8, 0), torch.rand(5, 6, 7, 8, 0), torch.rand(5, 6, 7))
111 | func(torch.rand(3, 5, 6, 7, 8, 9), torch.rand(5, 6, 7, 8, 9), torch.rand(5, 6, 7))
112 |
113 |
114 | def test_repeat_ellipsis1():
115 | @typechecked
116 | def func(x: TensorType["dim1":..., "dim1":...], y: TensorType["dim1":...]):
117 | pass
118 |
119 | func(torch.rand(3, 4, 3, 4), torch.rand(3, 4))
120 | func(torch.rand(5, 5), torch.rand(5))
121 | with pytest.raises(TypeError):
122 | func(torch.rand(7, 9), torch.rand(7))
123 | with pytest.raises(TypeError):
124 | func(torch.rand(7, 4, 9, 4), torch.rand(7, 4))
125 | with pytest.raises(TypeError):
126 | func(torch.rand(7, 9), torch.rand(9))
127 | with pytest.raises(TypeError):
128 | func(torch.rand(3, 7, 3, 9), torch.rand(3, 9))
129 | with pytest.raises(TypeError):
130 | func(torch.rand(7, 3, 3, 9), torch.rand(3, 9))
131 | with pytest.raises(TypeError):
132 | func(torch.rand(7, 7), torch.rand(3))
133 |
134 |
135 | def test_repeat_ellipsis2():
136 | @typechecked
137 | def func(
138 | x: TensorType["dim1":..., "dim1":...],
139 | y: TensorType["dim1":..., "dim2":...],
140 | z: TensorType["dim2":...],
141 | ):
142 | pass
143 |
144 | func(torch.rand(4, 4), torch.rand(4, 5), torch.rand(5))
145 | func(torch.rand(3, 4, 3, 4), torch.rand(3, 4, 5), torch.rand(5))
146 | func(torch.rand(2, 3, 4, 2, 3, 4), torch.rand(2, 3, 4, 5, 6), torch.rand(5, 6))
147 | with pytest.raises(TypeError):
148 | func(torch.rand(2, 3, 4, 2, 3), torch.rand(2, 3, 4, 5, 6), torch.rand(5, 6))
149 | with pytest.raises(TypeError):
150 | func(torch.rand(2, 3, 4, 2, 3), torch.rand(2, 3, 4, 6), torch.rand(3, 4))
151 |
152 |
153 | def test_ambiguous_ellipsis():
154 | @typechecked
155 | def func1(x: TensorType["dim1":..., "dim2":...]):
156 | pass
157 |
158 | with pytest.raises(TypeError):
159 | func1(torch.rand(2, 2))
160 |
161 | @typechecked
162 | def func2(x: TensorType["dim1":..., "dim1":...]):
163 | pass
164 |
165 | with pytest.raises(TypeError):
166 | func2(torch.rand(2, 2))
167 |
--------------------------------------------------------------------------------
/test/test_examples.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from torch import ones, rand, sparse_coo, tensor
4 | from torchtyping import TensorType, is_named
5 | from typeguard import typechecked
6 |
7 |
8 | # make flake8 happy
9 | batch = x_channels = y_channels = a = b = channels = channels_x = channels_y = None
10 | annotator = word = feature = predicate = argument = None
11 |
12 |
13 | def test_example0():
14 | @typechecked
15 | def batch_outer_product(
16 | x: TensorType["batch", "x_channels"], y: TensorType["batch", "y_channels"]
17 | ) -> TensorType["batch", "x_channels", "y_channels"]:
18 |
19 | return x.unsqueeze(-1) * y.unsqueeze(-2)
20 |
21 | batch_outer_product(rand(2, 3), rand(2, 4))
22 | batch_outer_product(rand(5, 2), rand(5, 2))
23 | with pytest.raises(TypeError):
24 | batch_outer_product(rand(3, 2), rand(2, 3))
25 | with pytest.raises(TypeError):
26 | batch_outer_product(rand(1, 2, 3), rand(2, 3))
27 |
28 |
29 | def test_example1():
30 | @typechecked
31 | def func(x: TensorType["batch"], y: TensorType["batch"]) -> TensorType["batch"]:
32 | return x + y
33 |
34 | func(rand(3), rand(3))
35 | with pytest.raises(TypeError):
36 | func(rand(3), rand(1))
37 |
38 |
39 | def test_example2():
40 | @typechecked
41 | def func(x: TensorType["batch", 5], y: TensorType["batch", 3]):
42 | pass
43 |
44 | func(rand(3, 5), rand(3, 3))
45 | func(rand(7, 5), rand(7, 3))
46 | with pytest.raises(TypeError):
47 | func(rand(4, 5), rand(3, 5))
48 | with pytest.raises(TypeError):
49 | func(rand(1, 3, 5), rand(3, 5))
50 | with pytest.raises(TypeError):
51 | func(rand(3, 4), rand(3, 3))
52 | with pytest.raises(TypeError):
53 | func(rand(1, 3, 5), rand(1, 3, 3))
54 |
55 |
56 | def test_example3():
57 | @typechecked
58 | def func(x: TensorType[2, -1, -1]):
59 | pass
60 |
61 | func(rand(2, 1, 1))
62 | func(rand(2, 2, 1))
63 | func(rand(2, 10, 1))
64 | func(rand(2, 1, 10))
65 | with pytest.raises(TypeError):
66 | func(rand(1, 2, 1, 1))
67 | with pytest.raises(TypeError):
68 | func(rand(2, 1))
69 | with pytest.raises(TypeError):
70 | func(
71 | rand(
72 | 2,
73 | )
74 | )
75 | with pytest.raises(TypeError):
76 | func(rand(4, 2, 2))
77 |
78 |
79 | def test_example4():
80 | @typechecked
81 | def func(x: TensorType[..., 2, 3]):
82 | pass
83 |
84 | func(rand(2, 3))
85 | func(rand(1, 2, 3))
86 | func(rand(2, 2, 3))
87 | func(rand(3, 3, 5, 2, 3))
88 | with pytest.raises(TypeError):
89 | func(rand(1, 3))
90 | with pytest.raises(TypeError):
91 | func(rand(2, 1))
92 | with pytest.raises(TypeError):
93 | func(rand(1, 1, 3))
94 | with pytest.raises(TypeError):
95 | func(rand(2, 3, 3))
96 | with pytest.raises(TypeError):
97 | func(rand(3))
98 | with pytest.raises(TypeError):
99 | func(rand(2))
100 |
101 |
102 | def test_example5():
103 | @typechecked
104 | def func(x: TensorType[..., 2, "channels"], y: TensorType[..., "channels"]):
105 | pass
106 |
107 | func(rand(1, 2, 3), rand(3))
108 | func(rand(1, 2, 3), rand(1, 3))
109 | func(rand(2, 3), rand(2, 3))
110 | with pytest.raises(TypeError):
111 | func(rand(2, 2, 2, 2), rand(2, 4))
112 | with pytest.raises(TypeError):
113 | func(rand(3, 2), rand(2))
114 | with pytest.raises(TypeError):
115 | func(rand(5, 2, 1), rand(2))
116 |
117 |
118 | def test_example6():
119 | @typechecked
120 | def func(
121 | x: TensorType["batch":..., "channels_x"],
122 | y: TensorType["batch":..., "channels_y"],
123 | ):
124 | pass
125 |
126 | func(rand(3, 3, 3), rand(3, 3, 4))
127 | func(rand(1, 5, 6, 7), rand(1, 5, 6, 1))
128 | with pytest.raises(TypeError):
129 | func(rand(2, 2, 2), rand(2, 1, 2))
130 | with pytest.raises(TypeError):
131 | func(rand(4, 2, 2), rand(2, 2, 2))
132 | with pytest.raises(TypeError):
133 | func(rand(2, 2, 2), rand(2, 1, 4))
134 | with pytest.raises(TypeError):
135 | func(rand(2, 2), rand(2, 1, 2))
136 | with pytest.raises(TypeError):
137 | func(rand(2, 2), rand(1, 2, 2))
138 | with pytest.raises(TypeError):
139 | func(rand(4, 2), rand(3, 2))
140 |
141 |
142 | def test_example7():
143 | @typechecked
144 | def func(x: TensorType[3, 4]) -> TensorType[()]:
145 | return rand(())
146 |
147 | func(rand(3, 4))
148 | with pytest.raises(TypeError):
149 | func(rand(2, 4))
150 |
151 | @typechecked
152 | def func2(x: TensorType[3, 4]) -> TensorType[()]:
153 | return rand((1,))
154 |
155 | with pytest.raises(TypeError):
156 | func2(rand(3, 4))
157 | with pytest.raises(TypeError):
158 | func2(rand(2, 4))
159 |
160 |
161 | def test_example8():
162 | @typechecked
163 | def func(x: TensorType[float]):
164 | pass
165 |
166 | func(rand(2, 3))
167 | func(rand(1))
168 | func(rand(()))
169 | with pytest.raises(TypeError):
170 | func(tensor(1))
171 | with pytest.raises(TypeError):
172 | func(tensor([1, 2]))
173 | with pytest.raises(TypeError):
174 | func(tensor([[1, 1], [2, 2]]))
175 | with pytest.raises(TypeError):
176 | func(tensor(True))
177 |
178 |
179 | def test_example9():
180 | @typechecked
181 | def func(x: TensorType[3, 4, float]):
182 | pass
183 |
184 | func(rand(3, 4))
185 | with pytest.raises(TypeError):
186 | func(rand(3, 4).long())
187 | with pytest.raises(TypeError):
188 | func(rand(2, 3))
189 |
190 |
191 | def test_example10():
192 | @typechecked
193 | def func(x: TensorType["a":3, "b", is_named]):
194 | pass
195 |
196 | func(rand(3, 4, names=("a", "b")))
197 | with pytest.raises(TypeError):
198 | func(rand(3, 4, names=("a", "c")))
199 | with pytest.raises(TypeError):
200 | func(rand(3, 3, 3, names=(None, "a", "b")))
201 | with pytest.raises(TypeError):
202 | func(rand(3, 3, 3, names=("a", None, "b")))
203 | with pytest.raises(TypeError):
204 | func(rand(3, 3, 3, names=("a", "b", None)))
205 | with pytest.raises(TypeError):
206 | func(rand(3, 4, names=(None, "b")))
207 | with pytest.raises(TypeError):
208 | func(rand(3, 4, names=("a", None)))
209 | with pytest.raises(TypeError):
210 | func(rand(3, 4))
211 | with pytest.raises(TypeError):
212 | func(rand(3, 4).long())
213 | with pytest.raises(TypeError):
214 | func(rand(3))
215 |
216 |
217 | def test_example11():
218 | @typechecked
219 | def func(x: TensorType[sparse_coo]):
220 | pass
221 |
222 | func(rand(3, 4).to_sparse())
223 | with pytest.raises(TypeError):
224 | func(rand(3, 4))
225 | with pytest.raises(TypeError):
226 | func(rand(3, 4).long())
227 |
228 |
229 | def test_example12():
230 | @typechecked
231 | def func(
232 | feats: TensorType["batch":..., "annotator":3, "word", "feature"],
233 | predicates: TensorType[
234 | "batch":..., "annotator":3, "predicate":"word", "feature"
235 | ],
236 | pred_arg_pairs: TensorType[
237 | "batch":..., "annotator":3, "predicate":"word", "argument":"word"
238 | ],
239 | ):
240 | pass
241 |
242 | func(ones(2, 1, 3, 4, 5), ones(2, 1, 3, 4, 5), ones(2, 1, 3, 4, 4))
243 |
244 | # matches ...
245 | with pytest.raises(TypeError):
246 | func(ones(2, 1, 3, 4, 5), ones(2, 3, 4, 5), ones(2, 1, 3, 4, 4))
247 | with pytest.raises(TypeError):
248 | func(ones(2, 1, 3, 4, 5), ones(2, 2, 3, 4, 5), ones(2, 1, 3, 4, 4))
249 |
250 | # annotator has 3
251 | with pytest.raises(TypeError):
252 | func(ones(2, 1, 2, 4, 5), ones(2, 1, 2, 4, 5), ones(2, 1, 2, 4, 4))
253 |
254 | # predicate and argument match word
255 | with pytest.raises(TypeError):
256 | func(ones(2, 1, 3, 3, 5), ones(2, 1, 3, 4, 5), ones(2, 1, 3, 4, 4))
257 | with pytest.raises(TypeError):
258 | func(ones(2, 1, 3, 4, 5), ones(2, 1, 3, 3, 5), ones(2, 1, 3, 4, 4))
259 | with pytest.raises(TypeError):
260 | func(ones(2, 1, 3, 4, 5), ones(2, 1, 3, 4, 5), ones(2, 1, 3, 3, 4))
261 |
--------------------------------------------------------------------------------
/test/test_extensions.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from torch import rand, Tensor
3 | from torchtyping import TensorDetail, TensorType
4 | from typeguard import typechecked
5 |
6 | good = foo = None
7 |
8 | # Write the extension
9 |
10 |
11 | class FooDetail(TensorDetail):
12 | def __init__(self, foo):
13 | super().__init__()
14 | self.foo = foo
15 |
16 | def check(self, tensor: Tensor) -> bool:
17 | return hasattr(tensor, "foo") and tensor.foo == self.foo
18 |
19 | # reprs used in error messages when the check fails
20 |
21 | def __repr__(self) -> str:
22 | return f"FooDetail({self.foo})"
23 |
24 | @classmethod
25 | def tensor_repr(cls, tensor: Tensor) -> str:
26 | # Should return a representation of the tensor with respect
27 | # to what this detail is checking
28 | if hasattr(tensor, "foo"):
29 | return f"FooDetail({tensor.foo})"
30 | else:
31 | return ""
32 |
33 |
34 | # Test the extension
35 |
36 |
37 | @typechecked
38 | def foo_checker(tensor: TensorType[float, FooDetail("good-foo")]):
39 | pass
40 |
41 |
42 | def valid_foo():
43 | x = rand(3)
44 | x.foo = "good-foo"
45 | foo_checker(x)
46 |
47 |
48 | def invalid_foo_one():
49 | x = rand(3)
50 | x.foo = "bad-foo"
51 | foo_checker(x)
52 |
53 |
54 | def invalid_foo_two():
55 | x = rand(2).int()
56 | x.foo = "good-foo"
57 | foo_checker(x)
58 |
59 |
60 | def test_extensions():
61 | valid_foo()
62 | with pytest.raises(TypeError):
63 | invalid_foo_one()
64 | with pytest.raises(TypeError):
65 | invalid_foo_two()
66 |
--------------------------------------------------------------------------------
/test/test_misc.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torchtyping import TensorType
4 | from typeguard import typechecked
5 | from typing import Tuple
6 |
7 |
8 | dim1 = dim2 = dim3 = channel = None
9 |
10 |
11 | def test_non_tensor():
12 | class Tensor:
13 | shape = torch.Size([2, 2])
14 | dtype = torch.float32
15 | layout = torch.strided
16 |
17 | args = (None, 4, 3.0, 3.2, "Tensor", Tensor, Tensor())
18 |
19 | @typechecked
20 | def accepts_tensor1(x: TensorType):
21 | pass
22 |
23 | @typechecked
24 | def accepts_tensor2(x: TensorType[2, 2]):
25 | pass
26 |
27 | @typechecked
28 | def accepts_tensor3(x: TensorType[...]):
29 | pass
30 |
31 | @typechecked
32 | def accepts_tensor4(x: TensorType[float]):
33 | pass
34 |
35 | @typechecked
36 | def accepts_tensor5(x: TensorType[..., float]):
37 | pass
38 |
39 | @typechecked
40 | def accepts_tensor6(x: TensorType[2, int]):
41 | pass
42 |
43 | @typechecked
44 | def accepts_tensor7(x: TensorType[torch.strided]):
45 | pass
46 |
47 | @typechecked
48 | def accepts_tensor8(x: TensorType[2, float, torch.sparse_coo]):
49 | pass
50 |
51 | for func in (
52 | accepts_tensor1,
53 | accepts_tensor2,
54 | accepts_tensor3,
55 | accepts_tensor4,
56 | accepts_tensor5,
57 | accepts_tensor6,
58 | accepts_tensor7,
59 | accepts_tensor8,
60 | ):
61 | for arg in args:
62 | with pytest.raises(TypeError):
63 | func(arg)
64 |
65 | @typechecked
66 | def accepts_tensors1(x: TensorType, y: TensorType):
67 | pass
68 |
69 | @typechecked
70 | def accepts_tensors2(x: TensorType[2, 2], y: TensorType):
71 | pass
72 |
73 | @typechecked
74 | def accepts_tensors3(x: TensorType[...], y: TensorType):
75 | pass
76 |
77 | @typechecked
78 | def accepts_tensors4(x: TensorType[float], y: TensorType):
79 | pass
80 |
81 | @typechecked
82 | def accepts_tensors5(x: TensorType[..., float], y: TensorType):
83 | pass
84 |
85 | @typechecked
86 | def accepts_tensors6(x: TensorType[2, int], y: TensorType):
87 | pass
88 |
89 | @typechecked
90 | def accepts_tensors7(x: TensorType[torch.strided], y: TensorType):
91 | pass
92 |
93 | @typechecked
94 | def accepts_tensors8(x: TensorType[torch.sparse_coo, float, 2], y: TensorType):
95 | pass
96 |
97 | for func in (
98 | accepts_tensors1,
99 | accepts_tensors2,
100 | accepts_tensors3,
101 | accepts_tensors4,
102 | accepts_tensors5,
103 | accepts_tensors6,
104 | accepts_tensors7,
105 | accepts_tensors8,
106 | ):
107 | for arg1 in args:
108 | for arg2 in args:
109 | with pytest.raises(TypeError):
110 | func(arg1, arg2)
111 |
112 |
113 | def test_nested_types():
114 | @typechecked
115 | def func(x: Tuple[TensorType[3, "channel", 4], TensorType["channel"]]):
116 | pass
117 |
118 | func((torch.rand(3, 1, 4), torch.rand(1)))
119 | func((torch.rand(3, 5, 4), torch.rand(5)))
120 | with pytest.raises(TypeError):
121 | func((torch.rand(3, 1, 4), torch.rand(2)))
122 |
123 |
124 | def test_no_getitem():
125 | @typechecked
126 | def func(x: TensorType, y: TensorType):
127 | pass
128 |
129 | func(torch.rand(2), torch.rand(2))
130 | with pytest.raises(TypeError):
131 | func(torch.rand(2), None)
132 | with pytest.raises(TypeError):
133 | func(torch.rand(2), [3, 4])
134 |
135 |
136 | def test_scalar_tensor():
137 | @typechecked
138 | def func(x: TensorType[()]):
139 | pass
140 |
141 | func(torch.rand(()))
142 | with pytest.raises(TypeError):
143 | func(torch.rand((1,)))
144 | with pytest.raises(TypeError):
145 | func(torch.rand((1, 2)))
146 | with pytest.raises(TypeError):
147 | func(torch.rand((5, 2, 2)))
148 |
149 |
150 | def test_square():
151 | @typechecked
152 | def func(x: TensorType["dim1", "dim1"]):
153 | pass
154 |
155 | func(torch.rand(2, 2))
156 | func(torch.rand(5, 5))
157 | with pytest.raises(TypeError):
158 | func(torch.rand(3, 5))
159 | with pytest.raises(TypeError):
160 | func(torch.rand(5, 3))
161 |
162 |
163 | def test_repeat():
164 | @typechecked
165 | def func(x: TensorType["dim1", "dim2", "dim2"], y: TensorType[-1, "dim2"]):
166 | pass
167 |
168 | func(torch.rand(5, 3, 3), torch.rand(9, 3))
169 | func(torch.rand(5, 5, 5), torch.rand(9, 5))
170 | func(torch.rand(4, 5, 5), torch.rand(2, 5))
171 | with pytest.raises(TypeError):
172 | func(torch.rand(4, 5, 4), torch.rand(3, 5))
173 | with pytest.raises(TypeError):
174 | func(torch.rand(4, 5, 5), torch.rand(3, 3))
175 | with pytest.raises(TypeError):
176 | func(torch.rand(4, 3, 5), torch.rand(3, 3))
177 | with pytest.raises(TypeError):
178 | func(torch.rand(4, 3, 3), torch.rand(0, 2))
179 |
--------------------------------------------------------------------------------
/test/test_shape.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from typing import Any
3 | import torch
4 | from torchtyping import TensorType, is_named
5 | import typeguard
6 |
7 |
8 | # make flake8 happy
9 | a = b = c = x = y = z = None
10 |
11 |
12 | def test_fixed_int_dim():
13 | @typeguard.typechecked
14 | def _3_dim_checker(x: TensorType[3]):
15 | pass
16 |
17 | @typeguard.typechecked
18 | def _3m1_dim_checker(x: TensorType[3, -1]):
19 | pass
20 |
21 | @typeguard.typechecked
22 | def _4_dim_checker(x: TensorType[4]):
23 | pass
24 |
25 | @typeguard.typechecked
26 | def _4m1_dim_checker(x: TensorType[4, -1]):
27 | pass
28 |
29 | @typeguard.typechecked
30 | def _m14_dim_checker(x: TensorType[-1, 4]):
31 | pass
32 |
33 | @typeguard.typechecked
34 | def _m1m1_dim_checker(x: TensorType[-1, -1]):
35 | pass
36 |
37 | @typeguard.typechecked
38 | def _34_dim_checker(x: TensorType[3, 4]):
39 | pass
40 |
41 | @typeguard.typechecked
42 | def _34m1_dim_checker(x: TensorType[3, 4, -1]):
43 | pass
44 |
45 | @typeguard.typechecked
46 | def _m14m1_dim_checker(x: TensorType[-1, 4, -1]):
47 | pass
48 |
49 | x = torch.rand(3)
50 | _3_dim_checker(x)
51 | with pytest.raises(TypeError):
52 | _3m1_dim_checker(x)
53 | with pytest.raises(TypeError):
54 | _4_dim_checker(x)
55 | with pytest.raises(TypeError):
56 | _4m1_dim_checker(x)
57 | with pytest.raises(TypeError):
58 | _m14_dim_checker(x)
59 | with pytest.raises(TypeError):
60 | _m1m1_dim_checker(x)
61 | with pytest.raises(TypeError):
62 | _34_dim_checker(x)
63 | with pytest.raises(TypeError):
64 | _34m1_dim_checker(x)
65 | with pytest.raises(TypeError):
66 | _m14m1_dim_checker(x)
67 |
68 | x = torch.rand(3, 4)
69 | _3m1_dim_checker(x)
70 | _m14_dim_checker(x)
71 | _m1m1_dim_checker(x)
72 | _34_dim_checker(x)
73 | with pytest.raises(TypeError):
74 | _3_dim_checker(x)
75 | with pytest.raises(TypeError):
76 | _4_dim_checker(x)
77 | with pytest.raises(TypeError):
78 | _4m1_dim_checker(x)
79 | with pytest.raises(TypeError):
80 | _34m1_dim_checker(x)
81 | with pytest.raises(TypeError):
82 | _m14m1_dim_checker(x)
83 |
84 | x = torch.rand(4, 3)
85 | _4m1_dim_checker(x)
86 | _m1m1_dim_checker(x)
87 | with pytest.raises(TypeError):
88 | _3_dim_checker(x)
89 | with pytest.raises(TypeError):
90 | _3m1_dim_checker(x)
91 | with pytest.raises(TypeError):
92 | _4_dim_checker(x)
93 | with pytest.raises(TypeError):
94 | _m14_dim_checker(x)
95 | with pytest.raises(TypeError):
96 | _34_dim_checker(x)
97 | with pytest.raises(TypeError):
98 | _34m1_dim_checker(x)
99 | with pytest.raises(TypeError):
100 | _m14m1_dim_checker(x)
101 |
102 |
103 | def test_str_dim():
104 | @typeguard.typechecked
105 | def _a_dim_checker(x: TensorType["a"]):
106 | pass
107 |
108 | @typeguard.typechecked
109 | def _ab_dim_checker(x: TensorType["a", "b"]):
110 | pass
111 |
112 | @typeguard.typechecked
113 | def _abc_dim_checker(x: TensorType["a", "b", "c"]):
114 | pass
115 |
116 | @typeguard.typechecked
117 | def _cb_dim_checker(x: TensorType["c", "b"]):
118 | pass
119 |
120 | @typeguard.typechecked
121 | def _am1_dim_checker(x: TensorType["a", -1]):
122 | pass
123 |
124 | @typeguard.typechecked
125 | def _m1b_dim_checker(x: TensorType[-1, "b"]):
126 | pass
127 |
128 | @typeguard.typechecked
129 | def _abm1_dim_checker(x: TensorType["a", "b", -1]):
130 | pass
131 |
132 | @typeguard.typechecked
133 | def _m1bm1_dim_checker(x: TensorType[-1, "b", -1]):
134 | pass
135 |
136 | x = torch.rand(3, 4)
137 | _ab_dim_checker(x)
138 | _cb_dim_checker(x)
139 | _am1_dim_checker(x)
140 | _m1b_dim_checker(x)
141 | with pytest.raises(TypeError):
142 | _a_dim_checker(x)
143 | with pytest.raises(TypeError):
144 | _abc_dim_checker(x)
145 | with pytest.raises(TypeError):
146 | _abm1_dim_checker(x)
147 | with pytest.raises(TypeError):
148 | _m1bm1_dim_checker(x)
149 |
150 |
151 | def test_str_str_dim1():
152 | @typeguard.typechecked
153 | def func(x: TensorType["a":"x"]):
154 | pass
155 |
156 | func(torch.ones(3))
157 | func(torch.ones(2))
158 | with pytest.raises(TypeError):
159 | func(torch.tensor(3.0))
160 | with pytest.raises(TypeError):
161 | func(torch.ones(3, 3))
162 |
163 |
164 | def test_str_str_dim2():
165 | @typeguard.typechecked
166 | def func(x: TensorType["a":"x", "b":"x"]):
167 | pass
168 |
169 | func(torch.ones(3, 3))
170 | func(torch.ones(2, 2))
171 | with pytest.raises(TypeError):
172 | func(torch.tensor(3.0))
173 | with pytest.raises(TypeError):
174 | func(torch.ones(3))
175 | with pytest.raises(TypeError):
176 | func(torch.ones(3, 2))
177 | with pytest.raises(TypeError):
178 | func(torch.ones(2, 3))
179 |
180 |
181 | def test_str_str_dim_complex():
182 | @typeguard.typechecked
183 | def func(x: TensorType["a":"x", "b":"x", "x", "a", "b"]) -> TensorType["c":"x"]:
184 | return torch.ones(x.shape[0])
185 |
186 | func(torch.ones(3, 3, 3, 3, 3))
187 | func(torch.ones(2, 2, 2, 2, 2))
188 | with pytest.raises(TypeError):
189 | func(torch.ones(1, 2, 2, 2, 2))
190 | with pytest.raises(TypeError):
191 | func(torch.ones(2, 1, 2, 2, 2))
192 | with pytest.raises(TypeError):
193 | func(torch.ones(2, 2, 1, 2, 2))
194 | with pytest.raises(TypeError):
195 | func(torch.ones(2, 2, 2, 1, 2))
196 | with pytest.raises(TypeError):
197 | func(torch.ones(2, 2, 2, 2, 1))
198 |
199 | @typeguard.typechecked
200 | def func2(x: TensorType["a":"x", "b":"x", "x", "a", "b"]) -> TensorType["c":"x"]:
201 | return torch.ones(x.shape[0] + 1)
202 |
203 | with pytest.raises(TypeError):
204 | func2(torch.ones(2, 2, 2, 2, 2))
205 |
206 | @typeguard.typechecked
207 | def func3(x: TensorType["a":"x", "b":"x", "x", "a", "b"]) -> TensorType["c":"x"]:
208 | return torch.ones(x.shape[0], x.shape[0])
209 |
210 | with pytest.raises(TypeError):
211 | func3(torch.ones(2, 2, 2, 2, 2))
212 |
213 |
214 | def test_str_str_dim_fixed_num():
215 | @typeguard.typechecked
216 | def func(x: TensorType["a":"x"]) -> TensorType["x":3]:
217 | return torch.ones(x.shape[0])
218 |
219 | func(torch.ones(3))
220 | with pytest.raises(TypeError):
221 | func(torch.ones(2))
222 |
223 |
224 | def test_str_str_dim_fixed_names():
225 | @typeguard.typechecked
226 | def func(x: TensorType["a":"x", is_named]) -> TensorType["x":3]:
227 | return torch.ones(x.shape[0])
228 |
229 | func(torch.ones(3, names=["a"]))
230 | with pytest.raises(TypeError):
231 | func(torch.ones(3))
232 | with pytest.raises(TypeError):
233 | func(torch.ones(3, names=["b"]))
234 | with pytest.raises(TypeError):
235 | func(torch.ones(2, names=["a"]))
236 | with pytest.raises(TypeError):
237 | func(torch.ones(3, names=["x"]))
238 |
239 |
240 | def test_str_str_dim_no_early_return():
241 | @typeguard.typechecked
242 | def func(x: TensorType["a":"x", "b":"y", "c":"z", is_named]):
243 | pass
244 |
245 | func(torch.ones(2, 2, 2, names=["a", "b", "c"]))
246 | with pytest.raises(TypeError):
247 | func(torch.ones(2, 2, 2, names=["d", "b", "c"]))
248 | with pytest.raises(TypeError):
249 | func(torch.ones(2, 2, 2, names=["a", "b", "d"]))
250 |
251 |
252 | def test_none_str():
253 | @typeguard.typechecked
254 | def func(x: TensorType[None:"x", "b":"x", is_named]):
255 | pass
256 |
257 | func(torch.ones(2, 2, names=[None, "b"]))
258 | func(torch.ones(3, 3, names=[None, "b"]))
259 | with pytest.raises(TypeError):
260 | func(torch.ones(2, 2, names=["a", "b"]))
261 | with pytest.raises(TypeError):
262 | func(torch.ones(2, 2, names=["x", "b"]))
263 | with pytest.raises(TypeError):
264 | func(torch.ones(2, 2, names=[None, None]))
265 | with pytest.raises(TypeError):
266 | func(torch.ones(2, 2, names=[None, "c"]))
267 | with pytest.raises(TypeError):
268 | func(torch.ones(2, 2, names=[None, "x"]))
269 |
270 |
271 | def test_other_str_should_fail():
272 | with pytest.raises(TypeError):
273 |
274 | def func(x: TensorType[3:"x"]):
275 | pass
276 |
277 |
278 | def test_int_str_dim():
279 | @typeguard.typechecked
280 | def _a_dim_checker1(x: TensorType["a":3]):
281 | pass
282 |
283 | @typeguard.typechecked
284 | def _a_dim_checker2(x: TensorType["a":-1]):
285 | pass
286 |
287 | @typeguard.typechecked
288 | def _ab_dim_checker1(x: TensorType["a":3, "b":4]):
289 | pass
290 |
291 | @typeguard.typechecked
292 | def _ab_dim_checker2(x: TensorType["a":3, "b":-1]):
293 | pass
294 |
295 | @typeguard.typechecked
296 | def _ab_dim_checker3(x: TensorType["a":-1, "b":4]):
297 | pass
298 |
299 | @typeguard.typechecked
300 | def _ab_dim_checker4(x: TensorType["a":3, "b"]):
301 | pass
302 |
303 | @typeguard.typechecked
304 | def _ab_dim_checker5(x: TensorType["a", "b":4]):
305 | pass
306 |
307 | @typeguard.typechecked
308 | def _ab_dim_checker6(x: TensorType["a":5, "b":4]):
309 | pass
310 |
311 | @typeguard.typechecked
312 | def _ab_dim_checker7(x: TensorType["a":5, "b":-1]):
313 | pass
314 |
315 | @typeguard.typechecked
316 | def _m1b_dim_checker(x: TensorType[-1, "b":4]):
317 | pass
318 |
319 | @typeguard.typechecked
320 | def _abm1_dim_checker(x: TensorType["a":3, "b":4, -1]):
321 | pass
322 |
323 | @typeguard.typechecked
324 | def _m1bm1_dim_checker(x: TensorType[-1, "b":4, -1]):
325 | pass
326 |
327 | x = torch.rand(3, 4)
328 | _ab_dim_checker1(x)
329 | _ab_dim_checker2(x)
330 | _ab_dim_checker3(x)
331 | _ab_dim_checker4(x)
332 | _ab_dim_checker5(x)
333 | _m1b_dim_checker(x)
334 | with pytest.raises(TypeError):
335 | _a_dim_checker1(x)
336 | with pytest.raises(TypeError):
337 | _a_dim_checker2(x)
338 | with pytest.raises(TypeError):
339 | _ab_dim_checker6(x)
340 | with pytest.raises(TypeError):
341 | _ab_dim_checker7(x)
342 | with pytest.raises(TypeError):
343 | _abm1_dim_checker(x)
344 | with pytest.raises(TypeError):
345 | _m1bm1_dim_checker(x)
346 |
347 |
348 | def test_return():
349 | @typeguard.typechecked
350 | def f1(x: TensorType["b":4]) -> TensorType["b":4]:
351 | return torch.rand(3)
352 |
353 | @typeguard.typechecked
354 | def f2(x: TensorType["b"]) -> TensorType["b":4]:
355 | return torch.rand(3)
356 |
357 | @typeguard.typechecked
358 | def f3(x: TensorType[4]) -> TensorType["b":4]:
359 | return torch.rand(3)
360 |
361 | @typeguard.typechecked
362 | def f4(x: TensorType["b":4]) -> TensorType["b"]:
363 | return torch.rand(3)
364 |
365 | @typeguard.typechecked
366 | def f5(x: TensorType["b"]) -> TensorType["b"]:
367 | return torch.rand(3)
368 |
369 | @typeguard.typechecked
370 | def f6(x: TensorType[4]) -> TensorType["b"]:
371 | return torch.rand(3)
372 |
373 | @typeguard.typechecked
374 | def f7(x: TensorType["b":4]) -> TensorType[4]:
375 | return torch.rand(3)
376 |
377 | @typeguard.typechecked
378 | def f8(x: TensorType["b"]) -> TensorType[4]:
379 | return torch.rand(3)
380 |
381 | @typeguard.typechecked
382 | def f9(x: TensorType[4]) -> TensorType[4]:
383 | return torch.rand(3)
384 |
385 | with pytest.raises(TypeError):
386 | f1(torch.rand(3))
387 | with pytest.raises(TypeError):
388 | f2(torch.rand(3))
389 | with pytest.raises(TypeError):
390 | f3(torch.rand(3))
391 | with pytest.raises(TypeError):
392 | f4(torch.rand(3))
393 | f5(torch.rand(3))
394 | with pytest.raises(TypeError):
395 | f6(torch.rand(3))
396 | with pytest.raises(TypeError):
397 | f7(torch.rand(3))
398 | with pytest.raises(TypeError):
399 | f8(torch.rand(3))
400 | with pytest.raises(TypeError):
401 | f9(torch.rand(3))
402 |
403 | with pytest.raises(TypeError):
404 | f1(torch.rand(4))
405 | with pytest.raises(TypeError):
406 | f2(torch.rand(4))
407 | with pytest.raises(TypeError):
408 | f3(torch.rand(4))
409 | with pytest.raises(TypeError):
410 | f4(torch.rand(4))
411 | with pytest.raises(TypeError):
412 | f5(torch.rand(4))
413 | f6(torch.rand(4))
414 | with pytest.raises(TypeError):
415 | f7(torch.rand(4))
416 | with pytest.raises(TypeError):
417 | f8(torch.rand(4))
418 | with pytest.raises(TypeError):
419 | f9(torch.rand(4))
420 |
421 |
422 | def test_any_dim():
423 | @typeguard.typechecked
424 | def _3any_dim_checker(x: TensorType[3, Any]):
425 | pass
426 |
427 | @typeguard.typechecked
428 | def _any4_dim_checker(x: TensorType[Any, 4]):
429 | pass
430 |
431 | @typeguard.typechecked
432 | def _anyany_dim_checker(x: TensorType[Any, Any]):
433 | pass
434 |
435 | @typeguard.typechecked
436 | def _34any_dim_checker(x: TensorType[3, 4, Any]):
437 | pass
438 |
439 | @typeguard.typechecked
440 | def _any4any_dim_checker(x: TensorType[Any, 4, Any]):
441 | pass
442 |
443 | x = torch.rand(3)
444 | with pytest.raises(TypeError):
445 | _3any_dim_checker(x)
446 | with pytest.raises(TypeError):
447 | _any4_dim_checker(x)
448 | with pytest.raises(TypeError):
449 | _anyany_dim_checker(x)
450 | with pytest.raises(TypeError):
451 | _34any_dim_checker(x)
452 | with pytest.raises(TypeError):
453 | _any4any_dim_checker(x)
454 |
455 | x = torch.rand((3, 4))
456 | _3any_dim_checker(x)
457 | _any4_dim_checker(x)
458 | _anyany_dim_checker(x)
459 |
460 | x = torch.rand((4, 5))
461 | with pytest.raises(TypeError):
462 | _any4_dim_checker(x)
463 |
464 | x = torch.rand(4, 5)
465 | with pytest.raises(TypeError):
466 | _3any_dim_checker(x)
467 |
468 | x = torch.rand((3, 4, 5))
469 | _34any_dim_checker(x)
470 | _any4any_dim_checker(x)
471 |
472 | x = torch.rand((3, 5, 5))
473 | with pytest.raises(TypeError):
474 | _any4any_dim_checker(x)
475 | with pytest.raises(TypeError):
476 | _34any_dim_checker(x)
477 |
--------------------------------------------------------------------------------
/torchtyping/__init__.py:
--------------------------------------------------------------------------------
1 | from .tensor_details import (
2 | DtypeDetail,
3 | is_float,
4 | is_named,
5 | LayoutDetail,
6 | ShapeDetail,
7 | TensorDetail,
8 | )
9 |
10 | from .tensor_type import TensorType
11 | from .typechecker import patch_typeguard
12 |
13 | __version__ = "0.1.5"
14 |
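15 | # Typical import surface (an illustrative note, not extra API):
16 | #
17 | #     from torchtyping import TensorType, TensorDetail, patch_typeguard, is_named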
--------------------------------------------------------------------------------
/torchtyping/pytest_plugin.py:
--------------------------------------------------------------------------------
1 | from .typechecker import patch_typeguard
2 |
3 |
4 | def pytest_addoption(parser):
5 | group = parser.getgroup("torchtyping")
6 | group.addoption(
7 | "--torchtyping-patch-typeguard",
8 | action="store_true",
9 | help="Run torchtyping's typeguard patch.",
10 | )
11 |
12 |
13 | def pytest_configure(config):
14 | if config.getoption("torchtyping_patch_typeguard"):
15 | patch_typeguard()
16 |
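17 | # Usage sketch (illustrative, not part of the plugin itself): once torchtyping
18 | # is installed, the flag defined above enables the typeguard patch for the
19 | # whole test session:
20 | #
21 | #     pytest --torchtyping-patch-typeguard
22 | #
23 | # This is roughly equivalent to calling patch_typeguard() in a conftest.py.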
--------------------------------------------------------------------------------
/torchtyping/tensor_details.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import abc
4 | import collections
5 | import torch
6 |
7 | from typing import Optional, Union
8 |
9 |
10 | ellipsis = type(...)
11 |
12 |
13 | class TensorDetail(metaclass=abc.ABCMeta):
14 | @abc.abstractmethod
15 | def __repr__(self) -> str:
16 | raise NotImplementedError
17 |
18 | @abc.abstractmethod
19 | def check(self, tensor: torch.Tensor) -> bool:
20 | raise NotImplementedError
21 |
22 | @classmethod
23 | @abc.abstractmethod
24 | def tensor_repr(cls, tensor: torch.Tensor) -> str:
25 | raise NotImplementedError
26 |
27 |
28 | _no_name = object()
29 |
30 |
31 | # inheriting from typing.NamedTuple crashes typeguard
32 | class _Dim(collections.namedtuple("_Dim", ["name", "size"])):
33 | # None corresponds to a name not being set. _no_name corresponds to us not caring
34 | # whether a name is set.
35 | name: Union[None, str, type(_no_name)]
36 | # technically supposed to use an enum to annotate singletons but that's overkill.
37 |
38 | size: Union[ellipsis, int, str]
39 |
40 | def __repr__(self) -> str:
41 | if self.name is _no_name:
42 | if self.size is ...:
43 | return "..."
44 | else:
45 | return repr(self.size)
46 | else:
47 | if self.size is ...:
48 | return f"{self.name}: ..."
49 | elif self.size == -1:
50 | return repr(self.name)
51 | else:
52 | return f"{self.name}: {self.size}"
53 |
54 |
55 | class ShapeDetail(TensorDetail):
56 | def __init__(self, *, dims: list[_Dim], check_names: bool, **kwargs) -> None:
57 | super().__init__(**kwargs)
58 | self.dims = dims
59 | self.check_names = check_names
60 |
61 | def __repr__(self) -> str:
62 | if len(self.dims) == 0:
63 | out = "()"
64 | elif len(self.dims) == 1:
65 | out = repr(self.dims[0])
66 | else:
67 | out = repr(tuple(self.dims))[1:-1]
68 | if self.check_names:
69 | out += ", is_named"
70 | return out
71 |
72 | def check(self, tensor: torch.Tensor) -> bool:
73 | self_names = [self_dim.name for self_dim in self.dims]
74 | self_shape = [self_dim.size for self_dim in self.dims]
75 |
76 | if ... in self_shape:
77 | if sum(1 for size in self_shape if size is not ...) > len(tensor.names):
78 | return False
79 | else:
80 | if len(self_shape) != len(tensor.names):
81 | return False
82 |
83 | for self_name, self_size, tensor_name, tensor_size in zip(
84 | reversed(self_names),
85 | reversed(self_shape),
86 | reversed(tensor.names),
87 | reversed(tensor.shape),
88 | ):
89 | if self_size is ...:
90 | # This assumes that Ellipses only occur on the left-hand edge.
91 | # So once we hit one we're done.
92 | break
93 |
94 | if (
95 | self.check_names
96 | and self_name is not _no_name
97 | and self_name != tensor_name
98 | ):
99 | return False
100 | if not isinstance(self_size, str) and self_size not in (-1, tensor_size):
101 | return False
102 |
103 | return True
104 |
105 | @classmethod
106 | def tensor_repr(cls, tensor: torch.Tensor) -> str:
107 | dims = []
108 | check_names = any(name is not None for name in tensor.names)
109 | for name, size in zip(tensor.names, tensor.shape):
110 | if not check_names:
111 | name = _no_name
112 | dims.append(_Dim(name=name, size=size))
113 | return repr(cls(dims=dims, check_names=check_names))
114 |
115 | def update(
116 | self,
117 | *,
118 | dims: Optional[list[_Dim]] = None,
119 | check_names: Optional[bool] = None,
120 | **kwargs,
121 | ) -> ShapeDetail:
122 | dims = self.dims if dims is None else dims
123 | check_names = self.check_names if check_names is None else check_names
124 | return type(self)(dims=dims, check_names=check_names, **kwargs)
125 |
126 |
127 | class DtypeDetail(TensorDetail):
128 | def __init__(self, *, dtype, **kwargs) -> None:
129 | super().__init__(**kwargs)
130 | assert isinstance(dtype, torch.dtype)
131 | self.dtype = dtype
132 |
133 | def __repr__(self) -> str:
134 | return repr(self.dtype)
135 |
136 | def check(self, tensor: torch.Tensor) -> bool:
137 | return self.dtype == tensor.dtype
138 |
139 | @classmethod
140 | def tensor_repr(cls, tensor: torch.Tensor) -> str:
141 | return repr(cls(dtype=tensor.dtype))
142 |
143 |
144 | class LayoutDetail(TensorDetail):
145 | def __init__(self, *, layout, **kwargs) -> None:
146 | super().__init__(**kwargs)
147 | self.layout = layout
148 |
149 | def __repr__(self) -> str:
150 | return repr(self.layout)
151 |
152 | def check(self, tensor: torch.Tensor) -> bool:
153 | return self.layout == tensor.layout
154 |
155 | @classmethod
156 | def tensor_repr(cls, tensor: torch.Tensor) -> str:
157 | return repr(cls(layout=tensor.layout))
158 |
159 |
160 | class _FloatDetail(TensorDetail):
161 | def __repr__(self) -> str:
162 | return "is_float"
163 |
164 | def check(self, tensor: torch.Tensor) -> bool:
165 | return tensor.is_floating_point()
166 |
167 | @classmethod
168 | def tensor_repr(cls, tensor: torch.Tensor) -> str:
169 | return "is_float" if tensor.is_floating_point() else ""
170 |
171 |
172 | # is_named is special-cased and consumed by TensorType.
173 | # It's a bit of an odd exception.
174 | # It's only a TensorDetail for consistency, as the other
175 | # extra flags that get passed are TensorDetails.
176 | class _NamedTensorDetail(TensorDetail):
177 | def __repr__(self) -> str:
178 | raise RuntimeError
179 |
180 | def check(self, tensor: torch.Tensor) -> bool:
181 | raise RuntimeError
182 |
183 | @classmethod
184 | def tensor_repr(cls, tensor: torch.Tensor) -> str:
185 | raise RuntimeError
186 |
187 |
188 | is_float = _FloatDetail() # singleton flag
189 | is_named = _NamedTensorDetail() # singleton flag
190 |
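191 | # Extension sketch (illustrative; the same pattern is exercised in
192 | # test/test_extensions.py). A custom check subclasses TensorDetail and is then
193 | # passed as an extra argument inside TensorType[...]:
194 | #
195 | #     class NonNegativeDetail(TensorDetail):
196 | #         def __repr__(self) -> str:
197 | #             return "NonNegativeDetail()"
198 | #
199 | #         def check(self, tensor: torch.Tensor) -> bool:
200 | #             return bool((tensor >= 0).all())
201 | #
202 | #         @classmethod
203 | #         def tensor_repr(cls, tensor: torch.Tensor) -> str:
204 | #             return "non-negative" if cls().check(tensor) else "negative entries"
205 | #
206 | #     x: TensorType[float, NonNegativeDetail()]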
--------------------------------------------------------------------------------
/torchtyping/tensor_type.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import sys
4 | import torch
5 |
6 | from .tensor_details import (
7 | _Dim,
8 | _no_name,
9 | is_named,
10 | DtypeDetail,
11 | LayoutDetail,
12 | ShapeDetail,
13 | TensorDetail,
14 | )
15 | from .utils import frozendict
16 |
17 | from typing import Any, NoReturn
18 |
19 | # Annotated is available in the typing module from Python 3.9 (PEP 593)
20 | if sys.version_info >= (3, 9):
21 | from typing import Annotated
22 | else:
23 | # On Python versions below 3.9
24 | # we import Annotated from typing_extensions instead
25 | from typing_extensions import Annotated
26 |
27 | # Not Type[Annotated...] as we want to use this in instance checks.
28 | _AnnotatedType = type(Annotated[torch.Tensor, ...])
29 |
30 |
31 | # For use when we have a plain TensorType, without any [].
32 | class _TensorTypeMeta(type(torch.Tensor)):
33 | def __instancecheck__(cls, obj: Any) -> bool:
34 | return isinstance(obj, cls.base_cls)
35 |
36 |
37 | # Inherit from torch.Tensor so that IDEs are happy to find methods on functions
38 | # annotated as TensorTypes.
39 | class TensorType(torch.Tensor, metaclass=_TensorTypeMeta):
40 | base_cls = torch.Tensor
41 |
42 | def __new__(cls, *args, **kwargs) -> NoReturn:
43 | raise RuntimeError(f"Class {cls.__name__} cannot be instantiated.")
44 |
45 | @staticmethod
46 | def _type_error(item: Any) -> NoReturn:
47 | raise TypeError(f"{item} is not a valid type argument.")
48 |
49 | @classmethod
50 | def _convert_shape_element(cls, item_i: Any) -> _Dim:
51 | if isinstance(item_i, int) and not isinstance(item_i, bool):
52 | return _Dim(name=_no_name, size=item_i)
53 | elif isinstance(item_i, str):
54 | return _Dim(name=item_i, size=-1)
55 | elif item_i is None:
56 | return _Dim(name=None, size=-1)
57 | elif isinstance(item_i, slice):
58 | if item_i.step is not None:
59 | cls._type_error(item_i)
60 | if item_i.start is not None and not isinstance(item_i.start, str):
61 | cls._type_error(item_i)
62 | if item_i.stop is not ... and not isinstance(item_i.stop, (int, str)):
63 | cls._type_error(item_i)
64 | if item_i.start is None and item_i.stop is ...:
65 | cls._type_error(item_i)
66 | return _Dim(name=item_i.start, size=item_i.stop)
67 | elif item_i is ...:
68 | return _Dim(name=_no_name, size=...)
69 | elif item_i is Any:
70 | return _Dim(name=_no_name, size=-1)
71 | else:
72 | cls._type_error(item_i)
73 |
74 | @staticmethod
75 | def _convert_dtype_element(item_i: Any) -> torch.dtype:
76 | if item_i is int:
77 | return torch.long
78 | elif item_i is float:
79 | return torch.get_default_dtype()
80 | elif item_i is bool:
81 | return torch.bool
82 | else:
83 | return item_i
84 |
85 | def __class_getitem__(cls, item: Any) -> _AnnotatedType:
86 | if isinstance(item, tuple):
87 | if len(item) == 0:
88 | item = ((),)
89 | else:
90 | item = (item,)
91 |
92 | scalar_shape = False
93 | not_ellipsis = False
94 | not_named_ellipsis = False
95 | check_names = False
96 | dims = []
97 | dtypes = []
98 | layouts = []
99 | details = []
100 | for item_i in item:
101 | if isinstance(item_i, (int, str, slice)) or item_i in (None, ..., Any):
102 | item_i = cls._convert_shape_element(item_i)
103 | if item_i.size is ...:
104 | # Supporting an arbitrary number of Ellipses in arbitrary
105 | # locations feels concerningly close to writing a regex
106 | # parser and I definitely don't have time for that.
107 | if not_ellipsis:
108 | raise NotImplementedError(
109 | "Having dimensions to the left of `...` is not currently "
110 | "supported."
111 | )
112 | if item_i.name is None:
113 | if not_named_ellipsis:
114 | raise NotImplementedError(
115 | "Having named `...` to the left of unnamed `...` is "
116 | "not currently supported."
117 | )
118 | else:
119 | not_named_ellipsis = True
120 | else:
121 | not_ellipsis = True
122 | dims.append(item_i)
123 | elif isinstance(item_i, tuple):
124 | if len(item_i) == 0:
125 | scalar_shape = True
126 | else:
127 | cls._type_error(item_i)
128 | elif item_i in (int, bool, float) or isinstance(item_i, torch.dtype):
129 | dtypes.append(cls._convert_dtype_element(item_i))
130 | elif isinstance(item_i, torch.layout):
131 | layouts.append(item_i)
132 | elif item_i is is_named:
133 | check_names = True
134 | elif isinstance(item_i, TensorDetail):
135 | details.append(item_i)
136 | else:
137 | cls._type_error(item_i)
138 |
139 | if scalar_shape:
140 | if len(dims) != 0:
141 | cls._type_error(item)
142 | else:
143 | if len(dims) == 0:
144 | dims = None
145 |
146 | pre_details = []
147 | if dims is not None:
148 | pre_details.append(ShapeDetail(dims=dims, check_names=check_names))
149 |
150 | if len(dtypes) == 0:
151 | pass
152 | elif len(dtypes) == 1:
153 | pre_details.append(DtypeDetail(dtype=dtypes[0]))
154 | else:
155 | raise TypeError("Cannot have multiple dtypes.")
156 |
157 | if len(layouts) == 0:
158 | pass
159 | elif len(layouts) == 1:
160 | pre_details.append(LayoutDetail(layout=layouts[0]))
161 | else:
162 | raise TypeError("Cannot have multiple layouts.")
163 |
164 | details = tuple(pre_details + details)
165 |
166 | assert len(details) > 0
167 |
168 | # Frozen dict needed for Union[TensorType[...], ...], as Union hashes its
169 | # arguments.
170 | return Annotated[
171 | cls.base_cls,
172 | frozendict(
173 | {"__torchtyping__": True, "details": details, "cls_name": cls.__name__}
174 | ),
175 | ]
176 |
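177 | # Illustrative note (a sketch of the runtime behaviour, not extra API):
178 | # subscripting TensorType produces a typing.Annotated alias around
179 | # torch.Tensor, e.g.
180 | #
181 | #     hint = TensorType["batch", "channels", float]
182 | #     # roughly: Annotated[torch.Tensor, frozendict({"__torchtyping__": True,
183 | #     #                                              "details": (...,),
184 | #     #                                              "cls_name": "TensorType"})]
185 | #
186 | # so static tooling sees a plain torch.Tensor, while the patched typeguard
187 | # reads the metadata back out at runtime (see typechecker.py).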
--------------------------------------------------------------------------------
/torchtyping/typechecker.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import sys
3 | import torch
4 | import typeguard
5 |
6 | from .tensor_details import _Dim, _no_name, ShapeDetail
7 | from .tensor_type import _AnnotatedType
8 |
9 | from typing import Any, Dict, List, Tuple
10 |
11 | # get_args is available in the typing module from Python 3.8.
12 | # get_type_hints with the include_extras parameter is available from 3.9 (PEP 593).
13 | if sys.version_info >= (3, 9):
14 | from typing import get_type_hints, get_args, Type
15 | else:
16 | from typing_extensions import get_type_hints, get_args, Type
17 |
18 |
19 | # TYPEGUARD PATCHER
20 | #######################
21 | # So there's quite a lot of moving pieces here.
22 | # The logic proceeds as follows.
23 | #
24 | # Calling patch_typeguard() just monkey-patches some of its functions and classes.
25 | #
26 | # typeguard uses a `_CallMemo` object to store information about each function that it
27 | # is checking: this is what allows us to perform function-level checking (consistency
28 | # of tensor shapes) rather than just argument-level checking (simple isinstance
29 | # checks).
30 | # So the first thing we do is enhance that with a couple extra slots to store our
31 | # information.
32 | #
33 | # Second, we patch `check_type`. typeguard traverses the [] hierarchy, e.g. from
34 | # Tuple[List[int]] to List[int] to int, recursively calling `check_type`. By patching
35 | # `check_type` we can check for our `TensorType`s and record every value-type pair.
36 | # (Actually it's a bit more than that: we record some names for use in the error
37 | # messages.) These are recorded in our enhanced `_CallMemo` object.
38 | #
39 | # (Incidentally we also have to patch typeguard's use of typing.get_type_hints, so that
40 | # our annotations aren't stripped.)
41 | #
42 | # Then we patch `check_argument_types` and `check_return_type`, to perform our extra
43 | # TensorType checking. This is the same checking in both cases so we factor that out
44 | # into _check_memo.
45 | #
46 | # _check_memo performs the real logic of the checking here. This looks at all the
47 | # recorded value-type pairs and checks for any inconsistencies.
48 |
49 |
50 | def _to_string(name, detail_reprs: List[str]) -> str:
51 | assert len(detail_reprs) > 0
52 | string = name + "["
53 | pieces = []
54 | for detail_repr in detail_reprs:
55 | if detail_repr != "":
56 | pieces.append(detail_repr)
57 | string += ", ".join(pieces)
58 | string += "]"
59 | return string
60 |
61 |
62 | def _check_tensor(
63 | argname: str, value: Any, origin: Type[torch.Tensor], metadata: Dict[str, Any]
64 | ):
65 | details = metadata["details"]
66 | if not isinstance(value, origin) or any(
67 | not detail.check(value) for detail in details
68 | ):
69 | expected_string = _to_string(
70 | metadata["cls_name"], [repr(detail) for detail in details]
71 | )
72 | if isinstance(value, torch.Tensor):
73 | given_string = _to_string(
74 | metadata["cls_name"], [detail.tensor_repr(value) for detail in details]
75 | )
76 | else:
77 | value = type(value)
78 | if hasattr(value, "__qualname__"):
79 | given_string = value.__qualname__
80 | elif hasattr(value, "__name__"):
81 | given_string = value.__name__
82 | else:
83 | given_string = repr(value)
84 | raise TypeError(
85 | f"{argname} must be of type {expected_string}, got type {given_string} "
86 | "instead."
87 | )
88 |
89 |
90 | def _check_memo(memo):
91 | ###########
92 | # Parse the tensors and figure out the sizes of all labelled
93 | # dimensions.
94 | # This also performs some (and in practice most) of the consistency
95 | # checks. However its job is primarily one of assigning sizes to labels.
96 | # The final checking of the inferred sizes is performed afterwards.
97 | #
98 | # This logic is a bit hairy. Most of the complexity comes from
99 | # supporting `...` arbitrary numbers of dimensions.
100 | ###########
101 |
102 | # ordered set
103 | shape_info = {
104 | (argname, value.shape, detail): None
105 | for argname, value, _, detail in memo.value_info
106 | }
107 | while len(shape_info):
108 | for argname, shape, detail in shape_info:
109 | num_free_ellipsis = 0
110 | for dim in detail.dims:
111 | if dim.size is ... and dim.name not in memo.name_to_shape:
112 | num_free_ellipsis += 1
113 | if num_free_ellipsis <= 1:
114 | reversed_shape = enumerate(reversed(shape))
115 | for dim in reversed(detail.dims):
116 | try:
117 | reverse_dim_index, size = next(reversed_shape)
118 | except StopIteration:
119 | if dim.size is ...:
120 | if dim.name not in (None, _no_name):
121 | try:
122 | lookup_shape = memo.name_to_shape[dim.name]
123 | except KeyError:
124 | memo.name_to_shape[dim.name] = ()
125 | else:
126 | if lookup_shape != ():
127 | raise TypeError(
128 | f"Dimension group '{dim.name}' of "
129 | f"inconsistent shape. Got both () and "
130 | f"{lookup_shape}."
131 | )
132 | else:
133 | # I don't think it's possible to get here, as the earlier
134 | # call to _check_tensor in check_type should catch
135 | # this case.
136 | raise TypeError(
137 | f"{argname} has {len(shape)} dimensions but type "
138 | "requires more than this."
139 | )
140 |
141 | if dim.name not in (None, _no_name):
142 | if dim.size is ...:
143 | try:
144 | lookup_shape = memo.name_to_shape[dim.name]
145 | except KeyError:
146 | # Can only get here if we're the single free
147 | # ellipsis.
148 | # Therefore the number of dimensions the ellipsis
149 | # corresponds to is the number of dimensions
150 | # remaining.
151 | forward_index = 0
152 | for forward_dim in detail.dims: # now iterate forwards
153 | if forward_dim is dim:
154 | break
155 | assert forward_dim.size is ...
156 | forward_index += len(
157 | memo.name_to_shape[forward_dim.name]
158 | )
159 | if reverse_dim_index == 0:
160 | # since [:-0] doesn't work
161 | end_index = None
162 | else:
163 | end_index = -reverse_dim_index
164 | clip_shape = shape[forward_index:end_index]
165 | memo.name_to_shape[dim.name] = tuple(clip_shape)
166 | for _ in range(len(clip_shape) - 1):
167 | next(reversed_shape)
168 | else:
169 | reversed_shape_piece = []
170 | if len(lookup_shape) >= 1:
171 | reversed_shape_piece.append(size)
172 | for _ in range(len(lookup_shape) - 1):
173 | try:
174 | _, size = next(reversed_shape)
175 | except StopIteration:
176 | break
177 | reversed_shape_piece.append(size)
178 |
179 | shape_piece = tuple(reversed(reversed_shape_piece))
180 | if lookup_shape != shape_piece:
181 | raise TypeError(
182 | f"Dimension group '{dim.name}' of "
183 | f"inconsistent shape. Got both {shape_piece} "
184 | f"and {lookup_shape}."
185 | )
186 | else:
187 | names_to_check = (
188 | [dim.name, dim.size]
189 | if isinstance(dim.size, str)
190 | else [dim.name]
191 | )
192 | for name in names_to_check:
193 | try:
194 | lookup_size = memo.name_to_size[name]
195 | except KeyError:
196 | memo.name_to_size[name] = size
197 | else:
198 | # Technically not necessary, as one of the
199 | # sizes will override the other, and then the
200 | # instance check will fail.
201 | # This gives a nicer error message though.
202 | if lookup_size != size:
203 | raise TypeError(
204 | f"Dimension '{dim.name}' of inconsistent"
205 | f" size. Got both {size} and "
206 | f"{lookup_size}."
207 | )
208 |
209 | del shape_info[argname, shape, detail]
210 | break
211 | else:
212 | if len(shape_info):
213 | names = {argname for argname, _, _ in shape_info}
214 | raise TypeError(
215 | f"Could not resolve the size of all `...` in {names}. Either:\n"
216 | "(1) the specification is ambiguous. For example "
217 | "`func(tensor: TensorType['x': ..., 'y': ...])`.\n"
218 | "(2) or repeated named `...` are used without being able to "
219 | "resolve the size of those named `...` via another argument. "
220 | "For example `func(tensor: TensorType['x': ..., 'x': ...])`. "
221 | "(But `func(tensor1: TensorType['x': ..., 'x': ...], tensor2: "
222 | "TensorType['x': ...])` would be fine.)\n"
223 | "\n"
224 | "Removing the names of the `...` should suffice to resolve this "
225 | "error. (But will of course remove that checking as well.)"
226 | )
227 |
228 | ###########
229 | # Do the final checking with the inferred sizes filled in.
230 | # In practice, malformed inputs will usually trip one of the
231 | # checks in the previous logic, so this block doesn't actually raise
232 | # errors very often. (In 1/37 tests at time of writing.)
233 | # A potential performance improvement might be to integrate it into
234 | # the previous block.
235 | ###########
236 |
237 | for argname, value, cls_name, detail in memo.value_info:
238 | dims = []
239 | for dim in detail.dims:
240 | size = dim.size
241 | if dim.name not in (None, _no_name):
242 | if size == -1:
243 | size = memo.name_to_size[dim.name]
244 | elif isinstance(size, str):
245 | size = memo.name_to_size[size]
246 | elif size is ...:
247 | # This assumes that named Ellipses only occur to the
248 | # right of unnamed Ellipses, to avoid filling in
249 | # Ellipses that occur to the left of other Ellipses.
250 | for size in memo.name_to_shape[dim.name]:
251 | dims.append(_Dim(name=_no_name, size=size))
252 | continue
253 | dims.append(_Dim(name=dim.name, size=size))
254 | detail = detail.update(dims=tuple(dims))
255 | _check_tensor(
256 | argname, value, torch.Tensor, {"cls_name": cls_name, "details": [detail]}
257 | )
258 |
259 |
260 | unpatched_typeguard = True
261 |
262 |
263 | def patch_typeguard():
264 | global unpatched_typeguard
265 | if unpatched_typeguard:
266 | unpatched_typeguard = False
267 |
268 | # Defined dynamically, in case something else is doing similar levels of hackery
269 | # patching typeguard. We want to get typeguard._CallMemo at the time we patch,
270 | # not any earlier. (Someone might have replaced it since the import statement.)
271 | class _CallMemo(typeguard._CallMemo):
272 | __slots__ = (
273 | "value_info",
274 | "name_to_size",
275 | "name_to_shape",
276 | )
277 | value_info: List[Tuple[str, torch.Tensor, str, ShapeDetail]]
278 | name_to_size: Dict[str, int]
279 | name_to_shape: Dict[str, Tuple[int, ...]]
280 |
281 | _check_type = typeguard.check_type
282 | _check_argument_types = typeguard.check_argument_types
283 | _check_return_type = typeguard.check_return_type
284 |
285 | check_type_signature = inspect.signature(_check_type)
286 | check_argument_types_signature = inspect.signature(_check_argument_types)
287 | check_return_type_signature = inspect.signature(_check_return_type)
288 |
289 | def check_type(*args, **kwargs):
290 | bound_args = check_type_signature.bind(*args, **kwargs).arguments
291 | argname = bound_args["argname"]
292 | value = bound_args["value"]
293 | expected_type = bound_args["expected_type"]
294 | memo = bound_args["memo"]
295 | # First look for an annotated type
296 | is_torchtyping_annotation = (
297 | memo is not None
298 | and hasattr(memo, "value_info")
299 | and isinstance(expected_type, _AnnotatedType)
300 | )
301 | # Now check if it's annotating a tensor
302 | if is_torchtyping_annotation:
303 | base_cls, *all_metadata = get_args(expected_type)
304 | if not issubclass(base_cls, torch.Tensor):
305 | is_torchtyping_annotation = False
306 | # Now check if the annotation's metadata is our metadata
307 | if is_torchtyping_annotation:
308 | for metadata in all_metadata:
309 | if isinstance(metadata, dict) and "__torchtyping__" in metadata:
310 | break
311 | else:
312 | is_torchtyping_annotation = False
313 | if is_torchtyping_annotation:
314 | # We call _check_tensor here -- despite calling _check_tensor again
315 | # once we've seen every argument and filled in the shape details --
316 | # just because we want to check that `value` is in fact a tensor before
317 | # we access its `shape` field on the next line.
318 | _check_tensor(argname, value, base_cls, metadata)
319 | for detail in metadata["details"]:
320 | if isinstance(detail, ShapeDetail):
321 | memo.value_info.append(
322 | (argname, value, metadata["cls_name"], detail)
323 | )
324 | break
325 |
326 | else:
327 | _check_type(*args, **kwargs)
328 |
329 | def check_argument_types(*args, **kwargs):
330 | bound_args = check_argument_types_signature.bind(*args, **kwargs).arguments
331 | memo = bound_args["memo"]
332 | if memo is None:
333 | return _check_argument_types(*args, **kwargs)
334 | else:
335 | memo.value_info = []
336 | memo.name_to_size = {}
337 | memo.name_to_shape = {}
338 | retval = _check_argument_types(*args, **kwargs)
339 | try:
340 | _check_memo(memo)
341 | except TypeError as exc: # suppress long traceback
342 | raise TypeError(*exc.args) from None
343 | return retval
344 |
345 | def check_return_type(*args, **kwargs):
346 | bound_args = check_return_type_signature.bind(*args, **kwargs).arguments
347 | memo = bound_args["memo"]
348 | if memo is None:
349 | return _check_return_type(*args, **kwargs)
350 | else:
351 | # Reset the collection of things that need checking.
352 | memo.value_info = []
353 | # Do _not_ set memo.name_to_size or memo.name_to_shape, as we want to
354 | # keep using the same sizes inferred from the arguments.
355 | retval = _check_return_type(*args, **kwargs)
356 | try:
357 | _check_memo(memo)
358 | except TypeError as exc: # suppress long traceback
359 | raise TypeError(*exc.args) from None
360 | return retval
361 |
362 | typeguard._CallMemo = _CallMemo
363 | typeguard.check_type = check_type
364 | typeguard.check_argument_types = check_argument_types
365 | typeguard.check_return_type = check_return_type
366 | typeguard.get_type_hints = lambda *args, **kwargs: get_type_hints(
367 | *args, **kwargs, include_extras=True
368 | )
369 |
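370 | # End-to-end sketch (illustrative; `add` is a made-up example):
371 | #
372 | #     import torch
373 | #     from torchtyping import TensorType, patch_typeguard
374 | #     from typeguard import typechecked
375 | #
376 | #     patch_typeguard()
377 | #
378 | #     @typechecked
379 | #     def add(x: TensorType["batch"], y: TensorType["batch"]) -> TensorType["batch"]:
380 | #         return x + y
381 | #
382 | #     add(torch.rand(3), torch.rand(3))  # OK: "batch" resolves to 3
383 | #     add(torch.rand(3), torch.rand(4))  # TypeError: "batch" is inconsistent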
--------------------------------------------------------------------------------
/torchtyping/utils.py:
--------------------------------------------------------------------------------
1 | class frozendict(dict):
2 | def __init__(self, *args, **kwargs):
3 | super().__init__(*args, **kwargs)
4 | # Calling this immediately ensures that no unhashable types are used as
5 | # entries.
6 | # There's also no way this is an efficient hash algorithm, but we're only
7 | # planning on using this with small dictionaries.
8 | self._hash = hash(tuple(sorted(self.items())))
9 |
10 | def __setitem__(self, key, value):
11 | raise RuntimeError(f"Cannot add items to a {type(self)}.")
12 |
13 | def __delitem__(self, item):
14 | raise RuntimeError(f"Cannot delete items from a {type(self)}.")
15 |
16 | def __hash__(self):
17 | return self._hash
18 |
19 | def __copy__(self):
20 | return self
21 |
22 | def __deepcopy__(self, memo):
23 | return self
24 |
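25 | # Minimal usage sketch (illustrative): the hash is computed over the sorted
26 | # items, so equal contents hash equally regardless of insertion order, which
27 | # is what hashing inside Union[TensorType[...], ...] relies on:
28 | #
29 | #     d1 = frozendict({"a": 1, "b": 2})
30 | #     d2 = frozendict({"b": 2, "a": 1})
31 | #     assert d1 == d2 and hash(d1) == hash(d2)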
--------------------------------------------------------------------------------