├── .github
├── dependabot.yml
└── workflows
│ ├── pypi_publish.yml
│ └── tox_run.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE.txt
├── README.md
├── deps
└── requirements-pyright.txt
├── pyproject.toml
├── pyrightconfig_py38.json
├── setup.cfg
├── src
└── phantom_tensors
│ ├── __init__.py
│ ├── _internals
│ ├── __init__.py
│ ├── dim_binding.py
│ ├── parse.py
│ └── utils.py
│ ├── alphabet.py
│ ├── array.py
│ ├── errors.py
│ ├── meta.py
│ ├── numpy.py
│ ├── parse.py
│ ├── py.typed
│ ├── torch.py
│ └── words.py
└── tests
├── __init__.py
├── annotations.py
├── arrlike.py
├── conftest.py
├── test_letters_and_words.py
├── test_prototype.py
├── test_third_party.py
└── test_type_properties.py
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "github-actions" # See documentation for possible values
4 | directory: "/" # Location of package manifests
5 | schedule:
6 | interval: "daily"
7 | - package-ecosystem: "pip"
8 | directory: "/deps/"
9 | schedule:
10 | interval: "daily"
11 | target-branch: "main"
--------------------------------------------------------------------------------
/.github/workflows/pypi_publish.yml:
--------------------------------------------------------------------------------
1 |
2 | # This workflows will upload a Python Package using Twine when a release is created
3 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
4 |
5 | name: Upload Python Package
6 |
7 | on:
8 | release:
9 | types: [created]
10 |
11 | jobs:
12 | deploy:
13 |
14 | runs-on: ubuntu-latest
15 |
16 | steps:
17 | - uses: actions/checkout@v4
18 | - name: Set up Python
19 | uses: actions/setup-python@v5
20 | with:
21 | python-version: '3.x'
22 | - name: Install dependencies
23 | run: |
24 | python -m pip install --upgrade pip
25 | pip install build twine
26 | - name: Build and publish
27 | env:
28 | TWINE_USERNAME: __token__
29 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
30 | run: |
31 | python -m build
32 | twine upload dist/*
--------------------------------------------------------------------------------
/.github/workflows/tox_run.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | # Trigger the workflow on push or pull request,
5 | # but only for the main branch
6 | push:
7 | branches:
8 | - main
9 | - develop
10 | pull_request:
11 | branches:
12 | - main
13 | - develop
14 |
15 | jobs:
16 | tests:
17 | runs-on: ubuntu-latest
18 |
19 | strategy:
20 | max-parallel: 5
21 | matrix:
22 | python-version: [3.8, 3.9, "3.10", 3.11, 3.12]
23 | fail-fast: false
24 |
25 | steps:
26 | - uses: actions/checkout@v4
27 | - name: Set up Python ${{ matrix.python-version }}
28 | uses: actions/setup-python@v5
29 | with:
30 | python-version: ${{ matrix.python-version }}
31 | - name: Install dependencies
32 | run: |
33 | python -m pip install --upgrade pip
34 | pip install tox tox-gh-actions
35 | - name: Test with tox
36 | run: tox -e py
37 |
38 | test-third-party:
39 | runs-on: ubuntu-latest
40 | steps:
41 | - uses: actions/checkout@v4
42 | - name: Set up Python 3.11
43 | uses: actions/setup-python@v5
44 | with:
45 | python-version: 3.11
46 | - name: Install dependencies
47 | run: |
48 | python -m pip install --upgrade pip
49 | pip install tox tox-gh-actions
50 | - name: Cache tox environments
51 | id: cache-third-party
52 | uses: actions/cache@v4
53 | with:
54 | path: .tox
 55 |           key: tox-third-party-${{ hashFiles('setup.cfg') }}-${{ hashFiles('setup.py') }}-${{ hashFiles('tests/conftest.py') }}-${{ hashFiles('.github/workflows/tox_run.yml') }}
56 | - name: Test with tox
57 | run: tox -e third-party
58 |
59 | run-pyright:
60 | runs-on: ubuntu-latest
61 | steps:
62 | - uses: actions/checkout@v4
63 | - name: Set up Python 3.11
64 | uses: actions/setup-python@v5
65 | with:
66 | python-version: 3.11
67 | - name: Install dependencies
68 | run: |
69 | python -m pip install --upgrade pip
70 | pip install tox tox-gh-actions
71 | - name: Cache tox environments
72 | id: cache-third-party
73 | uses: actions/cache@v4
74 | with:
75 | path: .tox
 76 |           key: tox-third-party-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('setup.py') }}-${{ hashFiles('.github/workflows/tox_run.yml') }}
77 | - name: Test with tox
78 | run: tox -e pyright
79 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # hydra-zen docs
132 |
133 | docs/source/_build
134 | docs/builds
135 |
136 | # pycharm
137 | .idea/
138 |
139 | # pyright
140 | node_modules/
141 |
142 | scratch/
143 |
144 | # npm
145 | package-lock.json
146 |
147 | # vscode
148 | settings.json
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/psf/black
3 | rev: 24.3.0
4 | hooks:
5 | - id: black
6 | - repo: https://github.com/PyCQA/flake8
7 | rev: 7.0.0
8 | hooks:
9 | - id: flake8
10 | - repo: https://github.com/pycqa/isort
11 | rev: 5.13.2
12 | hooks:
13 | - id: isort
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Ryan Soklaski
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Phantom Tensors
2 | > Tensor types with variadic shapes, for any array-based library, that work with both static and runtime type checkers
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | **This project is currently just a rough prototype! Inspired by: [phantom-types](https://github.com/antonagestam/phantom-types)**
15 |
16 | The goal of this project is to let users write tensor-like types with variadic shapes (via [PEP 646](https://peps.python.org/pep-0646/)) that are:
 17 | - Amenable to **static type checking (without mypy plugins)**.
18 | > E.g., pyright can tell the difference between `Tensor[Batch, Channel]` and `Tensor[Batch, Feature]`
19 | - Useful for performing **runtime checks of tensor types and shapes**.
 20 |    > E.g., can validate -- at runtime -- that arrays of types `NDArray[A, B]` and `NDArray[B, A]` indeed have transposed shapes with respect to each other.
21 | - Compatible with *any* array-based library (numpy, pytorch, xarray, cupy, mygrad, etc.)
22 | > E.g. A function annotated with `x: torch.Tensor` can be passed `phantom_tensors.torch.Tensor[N, B, D]`. It is trivial to write custom phantom-tensor flavored types for any array-based library.
23 |
24 | `phantom_tensors.parse` makes it easy to declare shaped tensor types in a way that static type checkers understand, and that are validated at runtime:
25 |
26 | ```python
27 | from typing import NewType
28 |
29 | import numpy as np
30 |
31 | from phantom_tensors import parse
32 | from phantom_tensors.numpy import NDArray
33 |
34 | A = NewType("A", int)
35 | B = NewType("B", int)
36 |
37 | # static: declare that x is of type NDArray[A, B]
38 | # declare that y is of type NDArray[B, A]
39 | # runtime: check that shapes (2, 3) and (3, 2)
40 | # match (A, B) and (B, A) pattern across
41 | # tensors
42 | x, y = parse(
43 | (np.ones((2, 3)), NDArray[A, B]),
44 | (np.ones((3, 2)), NDArray[B, A]),
45 | )
46 |
47 | x # static type checker sees: NDArray[A, B]
48 | y # static type checker sees: NDArray[B, A]
49 |
50 | ```
51 |
52 | Passing inconsistent types to `parse` will result in a runtime validation error.
53 | ```python
54 | # Runtime: Raises `ParseError` A=10 and A=2 do not match
55 | z, w = parse(
56 | (np.ones((10, 3)), NDArray[A, B]),
57 | (np.ones((3, 2)), NDArray[B, A]),
58 | )
59 | ```
60 |
61 | These shaped tensor types are amenable to static type checking:
62 |
63 | ```python
64 | from typing import Any
65 |
66 | import numpy as np
67 |
68 | from phantom_tensors import parse
69 | from phantom_tensors.numpy import NDArray
70 | from phantom_tensors.alphabet import A, B # these are just NewType(..., int) types
71 |
72 | def func_on_2d(x: NDArray[Any, Any]): ...
73 | def func_on_3d(x: NDArray[Any, Any, Any]): ...
74 | def func_on_any_arr(x: np.ndarray): ...
75 |
76 | # runtime: ensures shape of arr_3d matches (A, B, A) patterns
77 | arr_3d = parse(np.ones((3, 5, 3)), NDArray[A, B, A])
78 |
79 | func_on_2d(arr_3d) # static type checker: Error! # expects 2D arr, got 3D
80 |
81 | func_on_3d(arr_3d) # static type checker: OK
82 | func_on_any_arr(arr_3d) # static type checker: OK
83 | ```
84 |
85 |
86 | Write easy-to-understand interfaces using common dimension names (or make up your own):
87 |
88 | ```python
89 | from phantom_tensors.torch import Tensor
90 | from phantom_tensors.words import Batch, Embed, Vocab
91 |
92 | def embedder(x: Tensor[Batch, Vocab]) -> Tensor[Batch, Embed]:
93 | ...
94 | ```
95 |
96 |
97 | Using a runtime type checker, such as [beartype](https://github.com/beartype/beartype) or [typeguard](https://github.com/agronholm/typeguard), in conjunction with `phantom_tensors` means that the typed shape information will be validated at runtime across a function's inputs and outputs, whenever that function is called.
98 |
99 | ```python
100 | from typing import TypeVar, cast
101 | from typing_extensions import assert_type
102 |
103 | import torch as tr
104 | from beartype import beartype
105 |
106 | from phantom_tensors import dim_binding_scope, parse
107 | from phantom_tensors.torch import Tensor
108 | from phantom_tensors.alphabet import A, B, C
109 |
110 | T1 = TypeVar("T1")
111 | T2 = TypeVar("T2")
112 | T3 = TypeVar("T3")
113 |
114 |
115 | @dim_binding_scope
116 | @beartype # <- adds runtime type checking to function's interfaces
117 | def buggy_matmul(x: Tensor[T1, T2], y: Tensor[T2, T3]) -> Tensor[T1, T3]:
118 | # This is the wrong operation!
119 | # Will return shape-(T1, T1) tensor, not (T1, T3)
120 | out = x @ x.T
121 |
122 | # We lie to the static type checker to try to get away with it
123 | return cast(Tensor[T1, T3], out)
124 |
125 | x, y = parse(
126 | (tr.ones(3, 4), Tensor[A, B]),
127 | (tr.ones(4, 5), Tensor[B, C]),
128 | )
129 |
130 | # At runtime beartype raises:
131 | # Function should return shape-(A, C) but returned shape-(A, A)
132 | z = buggy_matmul(x, y) # Runtime validation error!
133 |
134 | ```
135 |
136 | ## Installation
137 |
138 | ```shell
139 | pip install phantom-tensors
140 | ```
141 |
142 | `typing-extensions` is the only strict dependency. Using features from `phantom_tensors.torch(numpy)` requires that `torch`(`numpy`) is installed too.
143 |
144 | ## Some Lower-Level Details and Features
145 |
146 | Everything on display here is achieved using relatively minimal hacks (no mypy plugin necessary, no monkeypatching). Presently, `torch.Tensor` and `numpy.ndarray` are explicitly supported by phantom-tensors, but it is trivial to add support for other array-like classes.
147 |
148 | > Note that mypy does not support PEP 646 yet, but pyright does. You can run pyright on the following examples to see that they do, indeed type-check as expected!
149 |
150 |
151 | ### Dimension-Binding Contexts
152 |
153 | `phantom_tensors.parse` validates inputs against types-with-shapes and performs [type narrowing](https://mypy.readthedocs.io/en/latest/type_narrowing.html) so that static type checkers are privy to the newly proven type information about those inputs. It performs inter-tensor shape consistency checks within a "dimension-binding context". Tensor-likes that are parsed simultaneously are automatically checked within a common dimension-binding context.
154 |
155 |
156 | ```python
157 | import numpy as np
158 | import torch as tr
159 |
160 | from phantom_tensors import parse
161 | from phantom_tensors.alphabet import A, B, C
162 | from phantom_tensors.numpy import NDArray
163 | from phantom_tensors.torch import Tensor
164 |
165 | t1, arr, t2 = parse(
166 | # <- Runtime: enter dimension-binding context
167 | (tr.rand(9, 2, 9), Tensor[B, A, B]), # <-binds A=2 & B=9
168 | (np.ones((2,)), NDArray[A]), # <- checks A==2
169 | (tr.rand(9), Tensor[B]), # <- checks B==9
170 | ) # <- Runtime: exit dimension-binding scope
171 | # Statically: casts t1, arr, t2 to shape-typed Tensors
172 |
173 | # static type checkers now see
174 | # t1: Tensor[B, A, B]
175 | # arr: NDArray[A]
176 | # t2: Tensor[B]
177 |
178 | w = parse(tr.rand(78), Tensor[A])  # <- binds A=78 within this context
179 | ```
180 |
181 | As indicated above, the type-checker sees the shaped-tensor/array types. Additionally, these are subclasses of their rightful parents, so we can pass these to functions typed with vanilla `torch.Tensor` and `numpy.ndarray` annotations, and type checkers will be a-ok with that.
182 |
183 | ```python
184 | def vanilla_numpy(x: np.ndarray): ...
185 | def vanilla_torch(x: tr.Tensor): ...
186 |
187 | vanilla_numpy(arr) # type checker: OK
188 | vanilla_torch(arr) # type checker: Error!
189 | vanilla_torch(t1) # type checker: OK
190 | ```
191 |
192 | #### Basic forms of runtime validation performed by `parse`
193 |
194 | ```python
195 | # runtime type checking
196 | >>> parse(1, Tensor[A])
197 | ---------------------------------------------------------------------------
198 | ParseError: Expected <class 'torch.Tensor'>, got: 1 (type: <class 'int'>)
199 |
200 | # dimensionality mismatch
201 | >>> parse(tr.ones(3), Tensor[A, A, A])
202 | ---------------------------------------------------------------------------
203 | ParseError: shape-(3,) doesn't match shape-type (A=?, A=?, A=?)
204 |
205 | # unsatisfied shape pattern
206 | >>> parse(tr.ones(1, 2), Tensor[A, A])
207 | ---------------------------------------------------------------------------
208 | ParseError: shape-(1, 2) doesn't match shape-type (A=1, A=1)
209 |
210 | # inconsistent dimension sizes across tensors
211 | >>> x, y = parse(
212 | ... (tr.ones(1, 2), Tensor[A, B]),
213 | ... (tr.ones(4, 1), Tensor[B, A]),
214 | ... )
215 |
216 | ---------------------------------------------------------------------------
217 | ParseError: shape-(4, 1) doesn't match shape-type (B=2, A=1)
218 | ```
219 |
220 | To reiterate, `parse` is able to compare shapes across multiple tensors by entering into a "dimension-binding scope".
221 | One can enter into this context explicitly:
222 |
223 | ```python
224 | >>> from phantom_tensors import dim_binding_scope
225 |
226 | >>> x = parse(np.zeros((2,)), NDArray[B]) # binds B=2
227 | >>> y = parse(np.zeros((3,)), NDArray[B]) # binds B=3
228 | >>> with dim_binding_scope:
229 | ... x = parse(np.zeros((2,)), NDArray[B]) # binds B=2
230 | ... y = parse(np.zeros((3,)), NDArray[B]) # raises!
231 | ---------------------------------------------------------------------------
232 | ParseError: shape-(3,) doesn't match shape-type (B=2,)
233 | ```
234 |
235 | #### Support for `Literal` dimensions:
236 |
237 | ```python
238 | from typing import Literal as L
239 |
240 | from phantom_tensors import parse
241 | from phantom_tensors.torch import Tensor
242 |
243 | import torch as tr
244 |
245 | parse(tr.zeros(1, 3), Tensor[L[1], L[3]]) # static + runtime: OK
246 | parse(tr.zeros(2, 3), Tensor[L[1], L[3]])  # Runtime: ParseError - mismatch at dim 0
247 | ```
248 |
249 | #### Support for `Literal` dimensions and variadic shapes:
250 |
251 | In Python 3.11 you can write shape types like `Tensor[int, *Ts, int]`, where `*Ts` represents 0 or more optional entries between two required dimensions. phantom-tensor supports this "unpack" dimension. In this README we opt for `typing_extensions.Unpack[Ts]` instead of `*Ts` for the sake of backwards compatibility.
252 |
253 | ```python
254 | from phantom_tensors import parse
255 | from phantom_tensors.torch import Tensor
256 |
257 | import torch as tr
258 | from typing_extensions import Unpack as U, TypeVarTuple
259 |
260 | Ts = TypeVarTuple("Ts")
261 |
262 | # U[Ts] represents an arbitrary number of entries
263 | parse(tr.ones(1, 3), Tensor[int, U[Ts], int])  # static + runtime: OK
264 | parse(tr.ones(1, 0, 0, 0, 3), Tensor[int, U[Ts], int]) # static + runtime: OK
265 |
266 | parse(tr.ones(1, ), Tensor[int, U[Ts], int]) # Runtime: Not enough dimensions
267 | ```
268 |
269 | #### Support for [phantom types](https://github.com/antonagestam/phantom-types):
270 |
271 | Supports phantom type dimensions (i.e. `int` subclasses that override `__instancecheck__` checks):
272 |
273 | ```python
274 | from phantom_tensors import parse
275 | from phantom_tensors.torch import Tensor
276 |
277 | import torch as tr
278 | from phantom import Phantom
279 |
280 | class EvenOnly(int, Phantom, predicate=lambda x: x%2 == 0): ...
281 |
282 | parse(tr.ones(1, 0), Tensor[int, EvenOnly]) # static return type: Tensor[int, EvenOnly]
283 | parse(tr.ones(1, 2), Tensor[int, EvenOnly]) # static return type: Tensor[int, EvenOnly]
284 | parse(tr.ones(1, 4), Tensor[int, EvenOnly]) # static return type: Tensor[int, EvenOnly]
285 |
286 | parse(tr.ones(1, 3), Tensor[int, EvenOnly]) # runtime: ParseError (3 is not an even number)
287 | ```
288 |
289 |
290 |
291 | ## Compatibility with Runtime Type Checkers
292 |
293 | `parse` is not the only way to perform runtime validation using phantom tensors – they work out of the box with 3rd party runtime type checkers like [beartype](https://github.com/beartype/beartype)! How is this possible?
294 |
295 | ...We do something tricky here! At runtime, `Tensor[A, B]` actually returns a [phantom type](https://github.com/antonagestam/phantom-types). This means that `isinstance(arr, NDArray[A, B])` is, at runtime, *actually* performing `isinstance(arr, PhantomNDArrayAB)`, which is dynamically generated and is able to perform the type and shape checks.
296 |
297 | Thanks to the ability to bind dimensions within a specified context, all `beartype` needs to do is faithfully call `isinstance(...)` within said context and we can have the inputs and outputs of a phantom-tensor-annotated function get checked!
298 |
299 | ```python
300 | from typing import Any
301 |
302 | from beartype import beartype # type: ignore
303 | import pytest
304 | import torch as tr
305 |
306 | from phantom_tensors.alphabet import A, B, C
307 | from phantom_tensors.torch import Tensor
308 | from phantom_tensors import dim_binding_scope, parse
309 |
310 | # @dim_binding_scope:
311 | # ensures A, B, C consistent across all input/output tensor shapes
312 | # within scope of function
313 | @dim_binding_scope
314 | @beartype # <-- adds isinstance checks on inputs & outputs
315 | def matrix_multiply(x: Tensor[A, B], y: Tensor[B, C]) -> Tensor[A, C]:
316 | a, _ = x.shape
317 | _, c = y.shape
318 | return parse(tr.rand(a, c), Tensor[A, C])
319 |
320 | @beartype
321 | def needs_vector(x: Tensor[Any]): ...
322 |
323 | x, y = parse(
324 | (tr.rand(3, 4), Tensor[A, B]),
325 | (tr.rand(4, 5), Tensor[B, C]),
326 | )
327 |
328 | z = matrix_multiply(x, y)
329 | z # type revealed: Tensor[A, C]
330 |
331 | with pytest.raises(Exception):
332 | # beartype raises error: input Tensor[A, C] doesn't match Tensor[A]
333 | needs_vector(z) # <- pyright also raises an error!
334 |
335 | with pytest.raises(Exception):
336 | # beartype raises error: inputs Tensor[A, B], Tensor[A, B] don't match signature
337 | matrix_multiply(x, x) # <- pyright also raises an error!
338 | ```
339 |
340 |
--------------------------------------------------------------------------------
/deps/requirements-pyright.txt:
--------------------------------------------------------------------------------
1 | pyright==1.1.374
2 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools >= 35.0.2",
4 | "wheel >= 0.29.0",
5 | "setuptools_scm[toml]==7.0.5",
6 | ]
7 | build-backend = "setuptools.build_meta"
8 |
9 |
10 | [project]
11 | name = "phantom_tensors"
12 | dynamic = ["version"]
13 | description = "Tensor-like types – with variadic shapes – that support both static and runtime type checking, and convenient parsing."
14 | readme = "README.md"
15 | requires-python = ">=3.8"
16 | dependencies = ["typing-extensions >= 4.1.0"]
17 | license = { text = "MIT" }
18 | keywords = [
19 | "machine learning",
20 | "research",
21 | "configuration",
22 | "scalable",
23 | "reproducible",
24 | "yaml",
25 | "Hydra",
26 | "dataclass",
27 | ]
28 |
29 | authors = [
30 | { name = "Ryan Soklaski", email = "rsoklaski@gmail.com" },
31 | { name = "Justin Goodwin", email = "jgoodwin@ll.mit.edu" },
32 | ]
33 | maintainers = [{ name = "Justin Goodwin", email = "jgoodwin@ll.mit.edu" }]
34 |
35 | classifiers = [
36 | "Development Status :: 4 - Beta",
37 | "License :: OSI Approved :: MIT License",
38 | "Operating System :: OS Independent",
39 | "Intended Audience :: Science/Research",
40 | "Programming Language :: Python :: 3.8",
41 | "Programming Language :: Python :: 3.9",
42 | "Programming Language :: Python :: 3.10",
43 | "Programming Language :: Python :: 3.11",
44 | "Programming Language :: Python :: 3.12",
45 | "Topic :: Scientific/Engineering",
46 | "Programming Language :: Python :: 3 :: Only",
47 | ]
48 |
49 | [project.optional-dependencies]
50 | test = ["beartype >= 0.10.4", "pytest >= 3.8", "hypothesis >= 6.28.0"]
51 |
52 |
53 | [project.urls]
54 | "Homepage" = "https://github.com/rsokl/phantom-tensors/"
55 | "Bug Reports" = "https://github.com/rsokl/phantom-tensors/issues"
56 | "Source" = "https://github.com/rsokl/phantom-tensors"
57 |
58 |
59 | [tool.setuptools_scm]
60 | write_to = "src/phantom_tensors/_version.py"
61 | version_scheme = "no-guess-dev"
62 |
63 |
64 | [tool.setuptools]
65 | package-dir = { "" = "src" }
66 |
67 | [tool.setuptools.packages.find]
68 | where = ["src"]
69 | exclude = ["tests*", "tests.*"]
70 |
71 | [tool.setuptools.package-data]
72 | phantom_tensors = ["py.typed"]
73 |
74 |
75 | [tool.isort]
76 | known_first_party = ["phantom_tensors", "tests"]
77 | profile = "black"
78 | combine_as_imports = true
79 |
80 |
81 | [tool.coverage.run]
82 | branch = true
83 | omit = ["tests/test_docs_typecheck.py"]
84 |
85 | [tool.coverage.report]
86 | omit = ["src/phantom_tensors/_version.py"]
87 | exclude_lines = [
88 | 'pragma: no cover',
89 | 'def __repr__',
90 | 'raise NotImplementedError',
91 | 'class .*\bProtocol(\[.+\])?\):',
92 | '@(abc\.)?abstractmethod',
93 | '@(typing\.)?overload',
94 | 'except ImportError:',
95 | 'except ModuleNotFoundError:',
96 | 'if (typing\.)?TYPE_CHECKING:',
97 | 'if sys\.version_info',
98 | ]
99 |
100 | [tool.pytest.ini_options]
101 | xfail_strict = true
102 |
103 |
104 | [tool.pyright]
105 | include = ["src"]
106 | exclude = [
107 | "**/node_modules",
108 | "**/__pycache__",
109 | "src/phantom_tensors/_version.py",
110 | "**/third_party",
111 | ]
112 | reportUnnecessaryTypeIgnoreComment = true
113 | reportUnnecessaryIsInstance = false
114 |
115 |
116 | [tool.codespell]
117 | skip = 'docs/build/*'
118 |
119 |
120 | [tool.tox]
121 | legacy_tox_ini = """
122 | [tox]
123 | isolated_build = True
124 | envlist = py38, py39, py310, py311, py312
125 |
126 | [gh-actions]
127 | python =
128 | 3.8: py38
129 | 3.9: py39
130 | 3.10: py310
131 | 3.11: py311
132 | 3.12: py312
133 |
134 | [testenv]
135 | description = Runs test suite parallelized in the specified python environment and
136 | against number of available processes (up to 4).
137 | Run `tox -e py39 -- -n 0` to run tests in a python 3.9 with
138 | parallelization disabled.
139 | passenv = *
140 | extras = test
141 | deps = pytest-xdist
142 | commands = pytest tests/ {posargs: -n auto --maxprocesses=4}
143 |
144 |
145 | [testenv:coverage]
146 | description = Runs test suite and measures test-coverage. Fails if coverage is
147 | below 100 prcnt. Run `tox -e coverage -- -n 0` to disable parallelization.
148 | setenv = NUMBA_DISABLE_JIT=1
149 | usedevelop = true
150 | basepython = python3.10
151 | deps = {[testenv]deps}
152 | coverage[toml]
153 | pytest-cov
154 | beartype
155 | torch
156 | numpy
157 | phantom-types
158 | commands = pytest --cov-report term-missing --cov-config=pyproject.toml --cov-fail-under=100 --cov=phantom_tensors tests {posargs: -n auto --maxprocesses=4}
159 |
160 |
161 | [testenv:third-party]
162 | description = Runs test suite against optional 3rd party packages that phantom-tensors
163 | provides specialized support for.
164 | install_command = pip install --upgrade --upgrade-strategy eager {opts} {packages}
165 | basepython = python3.11
166 | deps = {[testenv]deps}
167 | beartype
168 | torch
169 | numpy
170 | phantom-types
171 |
172 |
173 | [testenv:pyright]
174 | description = Ensure that phantom-tensors's source code and test suite scan clean
175 | under pyright, and that phantom-tensors's public API has a 100 prcnt
176 | type-completeness score.
177 | usedevelop = true
178 | basepython = python3.11
179 | deps =
180 | --requirement deps/requirements-pyright.txt
181 | beartype
182 | torch
183 | numpy
184 | phantom-types
185 |
186 | commands = pyright tests/ src/ --pythonversion=3.8 -p pyrightconfig_py38.json
187 | pyright tests/ src/ --pythonversion=3.9
188 | pyright tests/ src/ --pythonversion=3.10
189 | pyright tests/ src/ --pythonversion=3.11
190 | pyright tests/ src/ --pythonversion=3.12
191 | pyright --ignoreexternal --verifytypes phantom_tensors
192 |
193 |
194 | [testenv:format]
195 | description = Applies auto-flake (e.g. remove unused imports), black, and isort
196 | in-place on source files and test suite. Running this can help fix a
197 | failing `enforce-format` run.
198 | skip_install=true
199 | deps =
200 | autoflake
201 | black
202 | isort
203 | commands =
204 | autoflake --recursive --in-place --remove-duplicate-keys --remove-unused-variables src/ tests/
205 | isort src/ tests/
206 | black src/ tests/
207 |
208 |
209 | [testenv:enforce-format]
210 | description = Ensures that source materials code and docs and test suite adhere to
211 | formatting and code-quality standards.
212 | skip_install=true
213 | basepython=python3.11
214 | deps=black
215 | isort
216 | flake8
217 | pytest
218 | codespell
219 | commands=
220 | black src/ tests/ --diff --check
221 | isort src/ tests/ --diff --check
222 | flake8 src/ tests/
223 | codespell src/ docs/
224 | """
225 |
--------------------------------------------------------------------------------
/pyrightconfig_py38.json:
--------------------------------------------------------------------------------
1 | {}
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 |
2 | [flake8]
3 | extend-ignore = F811,D1,D205,D209,D213,D400,D401,D999,D202,E203,E501,W503,E721,F403,F405
4 | exclude = .git,__pycache__,docs/*,old,build,dis,tests/annotations/*
--------------------------------------------------------------------------------
/src/phantom_tensors/__init__.py:
--------------------------------------------------------------------------------
# pyright: strict
"""Public entry point for phantom_tensors.

Re-exports `parse` (runtime shape-validation with static type narrowing)
and `dim_binding_scope` (the dimension-binding context manager/decorator).
"""

from typing import TYPE_CHECKING

from phantom_tensors._internals.dim_binding import dim_binding_scope

from .parse import parse

# The public API surface of this package.
__all__ = ["parse", "dim_binding_scope"]


if not TYPE_CHECKING:
    try:
        # `_version.py` is generated by setuptools_scm at build time
        # (see `[tool.setuptools_scm]` in pyproject.toml).
        from ._version import version as __version__
    except ImportError:
        # e.g. running from a source tree that was never built/installed
        __version__ = "unknown version"
else:  # pragma: no cover
    __version__: str
--------------------------------------------------------------------------------
/src/phantom_tensors/_internals/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rsokl/phantom-tensors/9569cb2fee80ae12b71f398c1bdcabdb15addf36/src/phantom_tensors/_internals/__init__.py
--------------------------------------------------------------------------------
/src/phantom_tensors/_internals/dim_binding.py:
--------------------------------------------------------------------------------
1 | # pyright: strict
2 | from __future__ import annotations
3 |
4 | from collections import defaultdict
5 | from contextvars import ContextVar, Token
6 | from functools import wraps
7 | from typing import Any, Callable, Iterable, Optional, Tuple, Type, TypeVar, Union, cast
8 |
9 | from typing_extensions import TypeAlias
10 |
11 | import phantom_tensors._internals.utils as _utils
12 | from phantom_tensors._internals.utils import LiteralLike, NewTypeLike, UnpackLike
13 |
14 | ShapeDimType: TypeAlias = Union[
15 | Type[int],
16 | Type[UnpackLike],
17 | # Literal[Type[Any]] -- can't actually express this
18 | # the following all bind as dimension symbols by-reference
19 | Type[TypeVar],
20 | NewTypeLike,
21 | LiteralLike,
22 | ]
23 |
24 | LiteralCheck: TypeAlias = Callable[[Any, Iterable[Any]], bool]
25 |
26 | F = TypeVar("F", bound=Callable[..., Any])
27 |
28 |
29 | bindings: ContextVar[Optional[dict[Any, int]]] = ContextVar("bindings", default=None)
30 |
31 |
class DimBindContext:
    """Re-entrant context manager / decorator that opens a scope in which
    dimension symbols bind to concrete sizes.

    Entering installs a dict into the ``bindings`` context variable (a fresh
    dict at the outermost level; a copy of the current one when nested, so
    outer bindings are visible). The outermost exit restores the pre-entry
    state via the saved ``Token``.
    """

    __slots__ = ("_tokens", "_depth")

    def __init__(self) -> None:
        # maps nesting depth -> Token returned by `bindings.set(...)`
        self._tokens: dict[int, Token[Optional[dict[Any, int]]]] = {}
        # current nesting level; 0 means "not inside any scope"
        self._depth: int = 0

    def __enter__(self) -> None:
        b = bindings.get()
        self._depth += 1
        b = {} if b is None else b.copy()
        self._tokens[self._depth] = bindings.set(b)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Only the outermost exit resets the context variable.
        # NOTE(review): inner exits leave the inner copy installed and never
        # pop their tokens from `_tokens` -- looks intentional (bindings from
        # nested scopes propagate outward), but confirm.
        if self._depth == 1:
            bindings.reset(self._tokens.pop(self._depth))
        self._depth -= 1

    def __call__(self, func: F) -> F:
        # Decorator form: the wrapped function executes inside this scope.
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any):
            with self:
                return func(*args, **kwargs)

        return cast(F, wrapper)
57 |
58 |
59 | dim_binding_scope: DimBindContext = DimBindContext()
60 |
61 |
def check(shape_type: Tuple[ShapeDimType, ...], shape: Tuple[int, ...]) -> bool:
    """Return True if `shape` matches the pattern described by `shape_type`.

    Dimension symbols (TypeVars, NewTypes, Literals, int subclasses, and at
    most one unpacked TypeVarTuple) either bind to a concrete size -- all
    occurrences of the same symbol must agree -- or validate the size in
    place. When a `dim_binding_scope` is active, bindings are read from and
    written to the `bindings` context variable, so they persist across
    multiple `check` calls within the scope.
    """
    # Don't need to check types / values of `shape` -- assumed
    # to be downstream of `ndarray.shape` call and thus already
    # validated
    # Maybe enable extra strict mode where we do that checking

    # E.g. Tensor[A, B, B, C] -> matches == {A: [0], B: [1, 2], C: [3]}
    bound_symbols: defaultdict[ShapeDimType, list[int]] = defaultdict(list)
    validated_symbols: defaultdict[ShapeDimType, list[int]] = defaultdict(list)

    # These can be cached globally -- are independent of match pattern
    # E.g. Tensor[Literal[1]] -> validators {Literal[1]: lambda x: x == 1}
    validators: dict[ShapeDimType, Callable[[Any], bool]] = {}

    var_field_ind: Optional[int] = None  # index of *Ts, if present

    # TODO: Add caching to validation process.
    # - Should this use weakrefs?
    for n, dim_symbol in enumerate(shape_type):
        if dim_symbol is Any or dim_symbol is int:
            # E.g. Tensor[int, int] or Tensor[Any]
            # --> no constraints on shape
            continue
        if _utils.is_typevar_unpack(dim_symbol):
            # only one variadic entry is permitted per shape-type
            if var_field_ind is not None:
                raise TypeError(
                    f"Type-shape {shape_type} specifies more than one TypeVarTuple"
                )
            var_field_ind = n
            continue

        # if variadic tuple is present, need to use negative indexing to reference
        # location from the end of the tuple
        CURRENT_INDEX = n if var_field_ind is None else n - len(shape_type)

        # The following symbols bind to dimensions (by symbol-reference)
        # Some of them may also carry with them additional validation checks,
        # which need only be checked when the symbol is first bound
        _match_list = bound_symbols[dim_symbol]

        if _match_list:
            # We have already encountered symbol; do not need to validate or
            # extract validation function
            _match_list.append(CURRENT_INDEX)
            continue

        _validate_list = validated_symbols[dim_symbol]

        if _validate_list or dim_symbol in validators:
            # Symbol was previously classified as validate-only; record the
            # index and drop the empty entry the defaultdict lookup created.
            del bound_symbols[dim_symbol]
            _validate_list.append(CURRENT_INDEX)
            continue

        if _utils.is_newtype(dim_symbol):
            # NewTypes bind by reference; a non-int supertype additionally
            # installs an isinstance validator for the first occurrence.
            _supertype = dim_symbol.__supertype__
            if _supertype is not int:
                if not issubclass(_supertype, int):
                    raise TypeError(
                        f"Dimensions expressed by NewTypes must be associated with an "
                        f"int or subclass of int. shape-type {shape_type} contains a "
                        f"NewType of supertype {_supertype}"
                    )
                # bind `_supertype` as a default arg to freeze its value
                validators[dim_symbol] = lambda x, sp=_supertype: isinstance(x, sp)
            _match_list.append(CURRENT_INDEX)
            del _supertype
        elif isinstance(dim_symbol, TypeVar):
            # plain TypeVars bind by reference with no extra validation
            _match_list.append(CURRENT_INDEX)
        elif _utils.is_literal(dim_symbol) and dim_symbol:
            # NOTE(review): the trailing `and dim_symbol` guards against a
            # falsey literal object -- presumably defensive; confirm.
            _expected_literals = dim_symbol.__args__
            literal_check: LiteralCheck = lambda x, y=_expected_literals: any(
                x == val for val in y
            )
            validators[dim_symbol] = literal_check
            _validate_list.append(CURRENT_INDEX)
            del _expected_literals

        elif isinstance(dim_symbol, type) and issubclass(dim_symbol, int):
            # an int subclass used directly: validate-only, no binding
            validators[dim_symbol] = lambda x, type_=dim_symbol: isinstance(x, type_)
            _validate_list.append(CURRENT_INDEX)
        else:
            raise TypeError(
                f"Got shape-type {shape_type} with dim {dim_symbol}. Valid dimensions "
                f"are `type[int] | Unpack | TypeVar | NewType | Literal`"
            )
        # drop empty entries created by the defaultdict lookups above
        if not _match_list:
            del bound_symbols[dim_symbol]

        if not _validate_list:
            del validated_symbols[dim_symbol]

    if var_field_ind is None and len(shape_type) != len(shape):
        # E.g. type: Tensor[A, B, C]
        # vs. shape: (1, 1) or (1, 1, 1, 1)  # must have exactly 3 dim
        return False

    if var_field_ind is not None and len(shape) < len(shape_type) - 1:
        # E.g. type: Tensor[A, *Ts, C]
        # vs. shape: (1,)  # should have at least 2 dim
        return False

    _bindings = bindings.get()

    for symbol, indices in bound_symbols.items():
        validation_fn = validators.get(symbol, None)

        if len(indices) == 1 and _bindings is None and validation_fn is None:
            # symbol appears once, nothing to validate, no scope to bind in
            continue

        actual_val = shape[indices[0]]

        if _bindings is None or symbol is Any or symbol is int:
            expected_val = actual_val
        else:
            if symbol in _bindings:
                # symbol already bound earlier in this scope
                expected_val = _bindings[symbol]
            else:
                # first binding of the symbol: validate once, then record
                if validation_fn is not None and not validation_fn(actual_val):
                    return False
                _bindings[symbol] = actual_val
                expected_val = actual_val
        if not all(expected_val == shape[index] for index in indices):
            return False

    for symbol, indices in validated_symbols.items():
        validation_fn = validators[symbol]
        if not all(validation_fn(shape[index]) for index in indices):
            return False

    return True
191 |
--------------------------------------------------------------------------------
/src/phantom_tensors/_internals/parse.py:
--------------------------------------------------------------------------------
1 | # pyright: strict
2 | from __future__ import annotations
3 |
4 | from typing import Any, List, Tuple, Type, TypeVar, Union, cast, overload
5 |
6 | from typing_extensions import ClassVar, Protocol, TypeAlias, TypeVarTuple
7 |
8 | from phantom_tensors._internals import utils as _utils
9 | from phantom_tensors._internals.dim_binding import (
10 | ShapeDimType,
11 | bindings,
12 | check,
13 | dim_binding_scope,
14 | )
15 | from phantom_tensors.errors import ParseError
16 |
17 | __all__ = ["parse"]
18 |
19 | Ta = TypeVar("Ta", bound=Tuple[Any, Any])
20 | Tb = TypeVar("Tb")
21 | Ts = TypeVarTuple("Ts")
22 |
23 |
24 | class _Generic(Protocol):
25 | __origin__: Type[Any]
26 | __args__: Tuple[ShapeDimType, ...]
27 |
28 |
29 | class _Phantom(Protocol):
30 | __bound__: ClassVar[Union[Type[Any], Tuple[Type[Any], ...]]]
31 | __args__: Tuple[ShapeDimType, ...]
32 |
33 |
34 | class HasShape(Protocol):
35 | @property
36 | def shape(self) -> Any: ...
37 |
38 |
39 | TupleInt: TypeAlias = Tuple[int, ...]
40 |
41 | S1 = TypeVar("S1", bound=HasShape)
42 | S2 = TypeVar("S2", bound=HasShape)
43 | S3 = TypeVar("S3", bound=HasShape)
44 | S4 = TypeVar("S4", bound=HasShape)
45 | S5 = TypeVar("S5", bound=HasShape)
46 | S6 = TypeVar("S6", bound=HasShape)
47 |
48 | Ts1 = TypeVarTuple("Ts1")
49 | Ts2 = TypeVarTuple("Ts2")
50 | Ts3 = TypeVarTuple("Ts3")
51 | Ts4 = TypeVarTuple("Ts4")
52 | Ts5 = TypeVarTuple("Ts5")
53 | Ts6 = TypeVarTuple("Ts6")
54 |
55 | T1 = TypeVar("T1", bound=Tuple[Any, ...])
56 | T2 = TypeVar("T2", bound=Tuple[Any, ...])
57 | T3 = TypeVar("T3", bound=Tuple[Any, ...])
58 | T4 = TypeVar("T4", bound=Tuple[Any, ...])
59 | T5 = TypeVar("T5", bound=Tuple[Any, ...])
60 | T6 = TypeVar("T6", bound=Tuple[Any, ...])
61 |
62 |
63 | I1 = TypeVar("I1", bound=int)
64 | I2 = TypeVar("I2", bound=int)
65 | I3 = TypeVar("I3", bound=int)
66 | I4 = TypeVar("I4", bound=int)
67 | I5 = TypeVar("I5", bound=int)
68 | I6 = TypeVar("I6", bound=int)
69 |
70 |
def _to_tuple(x: Ta | Tuple[Ta, ...]) -> Tuple[Ta, ...]:
    """Normalize input to a tuple of pairs.

    A lone (tensor, type) pair arrives as a length-2 tuple whose first
    element is not itself a tuple; wrap it so callers always receive a
    tuple of pairs.
    """
    is_single_pair = len(x) == 2 and not isinstance(x[0], tuple)
    return (x,) if is_single_pair else x  # type: ignore
75 |
76 |
class Parser:
    """Runtime parser for (tensor, shape-type) pairs.

    Calling the parser isinstance-checks each tensor against its shape-type's
    concrete base type, validates the tensor's shape (binding dimension
    symbols consistently across all pairs via ``dim_binding_scope``), and
    returns the tensor(s) so that static type checkers see them as the
    shape-typed variety. Raises ``ParseError`` on any mismatch.
    """

    def get_shape_and_concrete_type(
        self, type_: Union[_Phantom, _Generic, HasShape]
    ) -> Tuple[Tuple[ShapeDimType, ...], Union[type, Tuple[type, ...]]]:
        """
        Extracts the shape-type and the concrete base type(s) from a generic
        tensor type.

        Overwrite this method to add support for different varieties of
        generic tensor types.

        Parameters
        ----------
        type_ : Any
            The tensor-type that contains the shape information to
            be extracted.

        Returns
        -------
        Tuple[Tuple[ShapeDimType, ...], Union[type, Tuple[type, ...]]]
            The shape-type (one symbol per dimension) followed by the
            concrete type(s) used for the runtime isinstance check.

        Examples
        --------
        >>> from phantom_tensors import parse
        >>> from phantom_tensors.numpy import NDArray
        >>> parse.get_shape_and_concrete_type(NDArray[int, int])
        ((int, int), numpy.ndarray)
        """
        if hasattr(type_, "__origin__"):
            # e.g. phantom_tensors.numpy.NDArray[A, B] -- a parameterized generic
            type_ = cast(_Generic, type_)
            type_shape = type_.__args__
            base_type = type_.__origin__

        elif hasattr(type_, "__bound__"):
            # e.g. a phantom type whose `__bound__` carries the concrete type(s)
            type_ = cast(_Phantom, type_)
            type_shape = type_.__args__
            base_type = type_.__bound__
        else:
            # unreachable for supported tensor-types
            assert False, f"Cannot extract shape information from {type_!r}"
        return type_shape, base_type

    @overload
    def __call__(
        self,
        __a: Tuple[HasShape, Type[S1]],
        __b: Tuple[HasShape, Type[S2]],
        __c: Tuple[HasShape, Type[S3]],
        __d: Tuple[HasShape, Type[S4]],
        __e: Tuple[HasShape, Type[S5]],
        __f: Tuple[HasShape, Type[S6]],
    ) -> Tuple[S1, S2, S3, S4, S5, S6]: ...

    @overload
    def __call__(
        self,
        __a: Tuple[HasShape, Type[S1]],
        __b: Tuple[HasShape, Type[S2]],
        __c: Tuple[HasShape, Type[S3]],
        __d: Tuple[HasShape, Type[S4]],
        __e: Tuple[HasShape, Type[S5]],
    ) -> Tuple[S1, S2, S3, S4, S5]: ...

    @overload
    def __call__(
        self,
        __a: Tuple[HasShape, Type[S1]],
        __b: Tuple[HasShape, Type[S2]],
        __c: Tuple[HasShape, Type[S3]],
        __d: Tuple[HasShape, Type[S4]],
    ) -> Tuple[S1, S2, S3, S4]: ...

    @overload
    def __call__(
        self,
        __a: Tuple[HasShape, Type[S1]],
        __b: Tuple[HasShape, Type[S2]],
        __c: Tuple[HasShape, Type[S3]],
    ) -> Tuple[S1, S2, S3]: ...

    @overload
    def __call__(
        self,
        __a: HasShape,
        __b: Type[S1],
    ) -> S1: ...

    @overload
    def __call__(
        self,
        __a: Tuple[HasShape, Type[S1]],
        __b: Tuple[HasShape, Type[S2]],
    ) -> Tuple[S1, S2]: ...

    @overload
    def __call__(self, __a: Tuple[HasShape, Type[S1]]) -> S1: ...

    @overload
    def __call__(
        self,
        *tensor_type_pairs: Tuple[HasShape, Type[HasShape]] | HasShape | Type[HasShape],
    ) -> HasShape | Tuple[HasShape, ...]: ...

    @dim_binding_scope
    def __call__(
        self,
        *tensor_type_pairs: Tuple[HasShape, Type[HasShape]] | HasShape | Type[HasShape],
    ) -> HasShape | Tuple[HasShape, ...]:
        if len(tensor_type_pairs) == 0:
            raise ValueError("`parse` requires at least one (tensor, type) pair")
        if len(tensor_type_pairs) == 2 and not isinstance(tensor_type_pairs[0], tuple):
            # parse(tensor, type) -> ((tensor, type),)
            tensor_type_pairs = (tensor_type_pairs,)  # type: ignore

        pairs = cast(
            Tuple[Tuple[HasShape, Type[HasShape]], ...], _to_tuple(tensor_type_pairs)  # type: ignore
        )

        out: List[HasShape] = []

        del tensor_type_pairs

        for tensor, type_ in pairs:
            type_shape, expected_type = self.get_shape_and_concrete_type(type_=type_)

            if not isinstance(tensor, expected_type):
                raise ParseError(f"Expected {expected_type}, got: {type(tensor)}")

            if not check(type_shape, tensor.shape):
                # Build a readable report of each symbol's bound value
                # (symbols that never bound show as '?').
                _bindings = bindings.get()
                assert _bindings is not None
                type_str = ", ".join(
                    (
                        f"{getattr(p, '__name__', repr(p))}={_bindings.get(p, '?')}"
                        if not _utils.is_typevar_unpack(p)
                        else "[...]"
                    )
                    for p in type_shape
                )
                if len(type_shape) == 1:
                    # (A) -> (A,)
                    type_str += ","
                raise ParseError(
                    f"shape-{tuple(tensor.shape)} doesn't match shape-type ({type_str})"
                )
            out.append(tensor)
        if len(out) == 1:
            # a single pair returns the tensor itself, not a 1-tuple
            return out[0]
        return tuple(out)
226 |
227 |
228 | parse: Parser = Parser()
229 |
230 | # @overload
231 | # def __call___ints(
232 | # __a: Tuple[int, Type[I1]],
233 | # __b: Tuple[int, Type[I2]],
234 | # __c: Tuple[int, Type[I3]],
235 | # __d: Tuple[int, Type[I4]],
236 | # __e: Tuple[int, Type[I5]],
237 | # __f: Tuple[int, Type[I6]],
238 | # ) -> Tuple[I1, I2, I3, I4, I5, I6]:
239 | # ...
240 |
241 |
242 | # @overload
243 | # def __call___ints(
244 | # __a: Tuple[int, Type[I1]],
245 | # __b: Tuple[int, Type[I2]],
246 | # __c: Tuple[int, Type[I3]],
247 | # __d: Tuple[int, Type[I4]],
248 | # __e: Tuple[int, Type[I5]],
249 | # ) -> Tuple[I1, I2, I3, I4, I5]:
250 | # ...
251 |
252 |
253 | # @overload
254 | # def __call___ints(
255 | # __a: Tuple[int, Type[I1]],
256 | # __b: Tuple[int, Type[I2]],
257 | # __c: Tuple[int, Type[I3]],
258 | # __d: Tuple[int, Type[I4]],
259 | # ) -> Tuple[I1, I2, I3, I4]:
260 | # ...
261 |
262 |
263 | # @overload
264 | # def __call___ints(
265 | # __a: Tuple[int, Type[I1]],
266 | # __b: Tuple[int, Type[I2]],
267 | # __c: Tuple[int, Type[I3]],
268 | # ) -> Tuple[I1, I2, I3]:
269 | # ...
270 |
271 |
272 | # @overload
273 | # def __call___ints(
274 | # __a: int,
275 | # __b: Type[I1],
276 | # ) -> I1:
277 | # ...
278 |
279 |
280 | # @overload
281 | # def __call___ints(
282 | # __a: Tuple[int, Type[I1]],
283 | # __b: Tuple[int, Type[I2]],
284 | # ) -> Tuple[I1, I2]:
285 | # ...
286 |
287 |
288 | # @overload
289 | # def __call___ints(__a: Tuple[int, Type[I1]]) -> I1:
290 | # ...
291 |
292 |
293 | # @overload
294 | # def __call___ints(
295 | # *tensor_type_pairs: Tuple[int, Type[int]] | int | Type[int]
296 | # ) -> int | Tuple[int, ...]:
297 | # ...
298 |
299 |
300 | # @dim_binding_scope
301 | # def __call___ints(
302 | # *tensor_type_pairs: Tuple[int, Type[int]] | int | Type[int]
303 | # ) -> int | Tuple[int, ...]:
304 | # ...
305 |
306 | # @overload
307 | # def __call___tuples(
308 | # __a: Tuple[TupleInt, Type[Tuple[U[Ts1]]]],
309 | # __b: Tuple[TupleInt, Type[Tuple[U[Ts2]]]],
310 | # __c: Tuple[TupleInt, Type[Tuple[U[Ts3]]]],
311 | # __d: Tuple[TupleInt, Type[Tuple[U[Ts4]]]],
312 | # __e: Tuple[TupleInt, Type[Tuple[U[Ts5]]]],
313 | # __f: Tuple[TupleInt, Type[Tuple[U[Ts6]]]],
314 | # ) -> Tuple[
315 | # Tuple[U[Ts1]],
316 | # Tuple[U[Ts2]],
317 | # Tuple[U[Ts3]],
318 | # Tuple[U[Ts4]],
319 | # Tuple[U[Ts5]],
320 | # Tuple[U[Ts6]],
321 | # ]:
322 | # ...
323 |
324 |
325 | # @overload
326 | # def __call___tuples(
327 | # __a: Tuple[TupleInt, Type[Tuple[U[Ts1]]]],
328 | # __b: Tuple[TupleInt, Type[Tuple[U[Ts2]]]],
329 | # __c: Tuple[TupleInt, Type[Tuple[U[Ts3]]]],
330 | # __d: Tuple[TupleInt, Type[Tuple[U[Ts4]]]],
331 | # __e: Tuple[TupleInt, Type[Tuple[U[Ts5]]]],
332 | # ) -> Tuple[Tuple[U[Ts1]], Tuple[U[Ts2]], Tuple[U[Ts3]], Tuple[U[Ts4]], Tuple[U[Ts5]]]:
333 | # ...
334 |
335 |
336 | # @overload
337 | # def __call___tuples(
338 | # __a: Tuple[TupleInt, Type[Tuple[U[Ts1]]]],
339 | # __b: Tuple[TupleInt, Type[Tuple[U[Ts2]]]],
340 | # __c: Tuple[TupleInt, Type[Tuple[U[Ts3]]]],
341 | # __d: Tuple[TupleInt, Type[Tuple[U[Ts4]]]],
342 | # ) -> Tuple[Tuple[U[Ts1]], Tuple[U[Ts2]], Tuple[U[Ts3]], Tuple[U[Ts4]]]:
343 | # ...
344 |
345 |
346 | # @overload
347 | # def __call___tuples(
348 | # __a: Tuple[TupleInt, Type[Tuple[U[Ts1]]]],
349 | # __b: Tuple[TupleInt, Type[Tuple[U[Ts2]]]],
350 | # __c: Tuple[TupleInt, Type[Tuple[U[Ts3]]]],
351 | # ) -> Tuple[Tuple[U[Ts1]], Tuple[U[Ts2]], Tuple[U[Ts3]]]:
352 | # ...
353 |
354 |
355 | # @overload
356 | # def __call___tuples(
357 | # __a: TupleInt,
358 | # __b: T1,
359 | # ) -> T1:
360 | # ...
361 |
362 |
363 | # @overload
364 | # def __call___tuples(
365 | # __a: TupleInt,
366 | # __b: Type[Tuple[U[Ts1]]],
367 | # ) -> Tuple[U[Ts1]]:
368 | # ...
369 |
370 |
371 | # @overload
372 | # def __call___tuples(
373 | # __a: Tuple[TupleInt, Type[Tuple[U[Ts1]]]],
374 | # __b: Tuple[TupleInt, Type[Tuple[U[Ts2]]]],
375 | # ) -> Tuple[Tuple[U[Ts1]], Tuple[U[Ts2]]]:
376 | # ...
377 |
378 |
379 | # @overload
380 | # def __call___tuples(__a: Tuple[TupleInt, Type[Tuple[U[Ts1]]]]) -> Tuple[U[Ts1]]:
381 | # ...
382 |
383 |
384 | # @overload
385 | # def __call___tuples(
386 | # *tensor_type_pairs: Tuple[TupleInt, Type[Tuple[Any, ...]]]
387 | # | TupleInt
388 | # | Type[Tuple[Any, ...]]
389 | # ) -> TupleInt | Tuple[Any, ...]:
390 | # ...
391 |
392 |
393 | # @dim_binding_scope
394 | # def __call___tuples(
395 | # *tensor_type_pairs: Tuple[TupleInt, Type[Tuple[Any, ...]]]
396 | # | TupleInt
397 | # | Type[Tuple[Any, ...]]
398 | # ) -> TupleInt | Tuple[Any, ...]:
399 | # ...
400 |
--------------------------------------------------------------------------------
/src/phantom_tensors/_internals/utils.py:
--------------------------------------------------------------------------------
1 | # pyright: strict
2 |
3 | import abc
4 | from typing import Any, Literal, Tuple, Type
5 |
6 | from typing_extensions import Protocol, TypeGuard, TypeVarTuple, Unpack
7 |
8 | _Ts = TypeVarTuple("_Ts")
9 |
10 | UnpackType = type(Unpack[_Ts]) # type: ignore
11 | LiteralType = type(Literal[1])
12 |
13 |
class CustomInstanceCheck(abc.ABCMeta):
    """Metaclass that delegates ``isinstance`` checks to the class itself.

    ``self`` here is the class being checked; the attribute lookup resolves
    to an ``__instancecheck__`` classmethod defined in that class's own
    namespace (as done by the phantom tensor types), not back to this
    metaclass method. NOTE(review): duplicates
    ``phantom_tensors.meta.CustomInstanceCheck``; if a class using this
    metaclass defines no ``__instancecheck__`` of its own, this would
    recurse indefinitely -- confirm whether this copy is still needed.
    """

    def __instancecheck__(self, instance: object) -> bool:
        return self.__instancecheck__(instance)
17 |
18 |
class NewTypeLike(Protocol):
    """Structural stand-in for a ``typing.NewType`` object (any supertype)."""

    __name__: str
    __supertype__: Type[Any]

    def __call__(self, x: Any) -> int: ...
24 |
25 |
class NewTypeInt(Protocol):
    """Structural stand-in for a ``typing.NewType`` whose supertype is ``int``."""

    __name__: str
    __supertype__: Type[int]

    def __call__(self, x: Any) -> int: ...
31 |
32 |
class UnpackLike(Protocol):
    """Structural stand-in for an ``Unpack[Ts]`` form (unpacked TypeVarTuple)."""

    _inst: Literal[True]
    _name: Literal[None]
    __origin__: Type[Any] = Unpack  # type: ignore
    __args__: Tuple[TypeVarTuple]
    __parameters__: Tuple[TypeVarTuple]
    __module__: str
40 |
41 |
class LiteralLike(Protocol):
    """Structural stand-in for a ``Literal[...]`` form."""

    _inst: Literal[True]
    _name: Literal[None]
    __origin__: Type[Any] = Literal  # type: ignore
    __args__: Tuple[Any, ...]
    __parameters__: Tuple[()]
    __module__: str
49 |
50 |
class TupleGeneric(Protocol):
    """Structural stand-in for a parameterized ``Tuple[...]`` alias."""

    __origin__: Type[Tuple[Any, ...]]
    __args__: Tuple[Type[Any], ...]
54 |
55 |
def is_newtype(x: Any) -> TypeGuard[NewTypeLike]:
    """Return True if ``x`` looks like a ``typing.NewType`` object
    (i.e. it carries a ``__supertype__`` attribute)."""
    try:
        x.__supertype__
    except AttributeError:
        return False
    return True
58 |
59 |
def is_newtype_int(x: Any) -> TypeGuard[NewTypeInt]:
    """Return True if ``x`` is a NewType whose supertype is ``int``
    (or an ``int`` subclass)."""
    supertype = getattr(x, "__supertype__", None)
    return supertype is not None and issubclass(supertype, int)
65 |
66 |
def is_typevar_unpack(x: Any) -> TypeGuard[UnpackLike]:
    # True for `Unpack[Ts]` objects -- same runtime type as the module-level
    # `UnpackType` specimen built above.
    return isinstance(x, UnpackType)
69 |
70 |
def is_tuple_generic(x: Any) -> TypeGuard[TupleGeneric]:
    """Return True if ``x`` is a parameterized tuple alias,
    e.g. ``Tuple[int, str]`` or ``tuple[int, str]``."""
    origin = getattr(x, "__origin__", None)
    return origin is tuple
73 |
74 |
def is_literal(x: Any) -> TypeGuard[LiteralLike]:
    # True for `Literal[...]` objects -- same runtime type as the
    # module-level `LiteralType` specimen built above.
    return isinstance(x, LiteralType)
77 |
--------------------------------------------------------------------------------
/src/phantom_tensors/alphabet.py:
--------------------------------------------------------------------------------
# pyright: strict
"""Single-letter dimension types, each defined as ``NewType("<letter>", int)``
(e.g. ``A = NewType("A", int)``), to be used as descriptive axis labels in
shape types."""

from typing import NewType as _NewType

__all__ = [
    "A",
    "B",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "I",
    "J",
    "K",
    "L",
    "M",
    "N",
    "O",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "U",
    "V",
    "W",
    "X",
    "Y",
    "Z",
]

A = _NewType("A", int)
B = _NewType("B", int)
C = _NewType("C", int)
D = _NewType("D", int)
E = _NewType("E", int)
F = _NewType("F", int)
G = _NewType("G", int)
H = _NewType("H", int)
I = _NewType("I", int)  # noqa: E741
J = _NewType("J", int)
K = _NewType("K", int)
L = _NewType("L", int)
M = _NewType("M", int)
N = _NewType("N", int)
O = _NewType("O", int)  # noqa: E741
P = _NewType("P", int)
Q = _NewType("Q", int)
R = _NewType("R", int)
S = _NewType("S", int)
T = _NewType("T", int)
U = _NewType("U", int)
V = _NewType("V", int)
W = _NewType("W", int)
X = _NewType("X", int)
Y = _NewType("Y", int)
Z = _NewType("Z", int)
65 |
--------------------------------------------------------------------------------
/src/phantom_tensors/array.py:
--------------------------------------------------------------------------------
1 | # pyright: strict
2 | from typing import Any, Tuple
3 |
4 | import typing_extensions as _te
5 | from typing_extensions import Protocol, runtime_checkable
6 |
7 | __all__ = ["SupportsArray"]
8 |
9 |
10 | Shape = _te.TypeVarTuple("Shape")
11 |
12 |
@runtime_checkable
class SupportsArray(Protocol[_te.Unpack[Shape]]):
    """Protocol for array-likes: anything exposing ``__array__`` and a
    ``shape`` tuple, parameterized by its shape."""

    def __array__(self) -> Any: ...

    @property
    def shape(self) -> Tuple[_te.Unpack[Shape]]: ...
19 |
--------------------------------------------------------------------------------
/src/phantom_tensors/errors.py:
--------------------------------------------------------------------------------
1 | # pyright: strict
class ParseError(TypeError):
    """Raised by `parse` when a tensor fails its type or shape validation."""

    pass
4 |
--------------------------------------------------------------------------------
/src/phantom_tensors/meta.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | __all__ = ["CustomInstanceCheck"]
4 |
5 |
class CustomInstanceCheck(abc.ABCMeta):
    """Used to support custom runtime type checks.

    ``isinstance(obj, SomeClass)`` invokes this metaclass method with
    ``self`` bound to ``SomeClass``; the attribute lookup below resolves to
    the ``__instancecheck__`` classmethod defined in the class's own
    namespace (see the phantom classes built in ``numpy.py`` / ``torch.py``),
    so each class controls its own instance check. NOTE(review): a class
    using this metaclass without defining its own ``__instancecheck__``
    would recurse indefinitely.
    """

    def __instancecheck__(self, instance: object) -> bool:
        return self.__instancecheck__(instance)
11 |
--------------------------------------------------------------------------------
/src/phantom_tensors/numpy.py:
--------------------------------------------------------------------------------
1 | # pyright: strict
# numpy is an optional dependency; fail at import time with a clear message
# if this submodule is used without it installed.
try:
    from numpy import ndarray as _ndarray
    from numpy.typing import NDArray as _NDArray
except ImportError:
    # fixed typo: "in order to user" -> "in order to use"
    raise ImportError("You must install numpy in order to use `phantom_tensors.numpy`")
7 |
8 | from typing import TYPE_CHECKING, Any, Generic, Sequence, SupportsIndex, Tuple, Union
9 |
10 | import typing_extensions as _te
11 |
12 | from ._internals.dim_binding import check
13 | from .meta import CustomInstanceCheck
14 |
15 | __all__ = ["NDArray"]
16 |
17 |
18 | Shape = _te.TypeVarTuple("Shape")
19 |
20 |
class NDArray(Generic[_te.Unpack[Shape]], _NDArray[Any]):
    """A shape-typed ``numpy.ndarray``: ``NDArray[A, B]`` parameterizes the
    array by its shape symbols.

    At runtime, subscripting produces a phantom subclass of ``ndarray``
    whose ``isinstance`` check validates an instance's ``.shape`` against
    the subscripted symbols via ``check``.
    """

    if not TYPE_CHECKING:
        # NOTE(review): declared but not read anywhere in this class --
        # presumably intended to memoize `__class_getitem__`; confirm.
        _cache = {}

        @classmethod
        def __class_getitem__(cls, key):
            # normalize NDArray[A] so that key is always a tuple of symbols
            if not isinstance(key, tuple):
                key = (key,)

            class PhantomNDArray(
                _ndarray,
                metaclass=CustomInstanceCheck,
            ):
                # Mirror typing-generic attributes so `parse` can extract the
                # concrete type and shape-type from this phantom type.
                __origin__ = _ndarray
                # TODO: conform with ndarray[shape, dtype]
                __args__ = key

                @classmethod
                def __instancecheck__(cls, __instance: object) -> bool:
                    # must be a real ndarray, then shape-validated
                    if not isinstance(__instance, _ndarray):
                        return False
                    return check(key, __instance.shape)

            return PhantomNDArray

    @property
    def shape(self) -> Tuple[_te.Unpack[Shape]]:  # type: ignore
        ...

    @shape.setter
    def shape(self, value: Union[SupportsIndex, Sequence[SupportsIndex]]) -> None: ...
53 |
--------------------------------------------------------------------------------
/src/phantom_tensors/parse.py:
--------------------------------------------------------------------------------
1 | from ._internals.parse import Parser, parse
2 |
3 | __all__ = ["parse", "Parser"]
4 |
--------------------------------------------------------------------------------
/src/phantom_tensors/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rsokl/phantom-tensors/9569cb2fee80ae12b71f398c1bdcabdb15addf36/src/phantom_tensors/py.typed
--------------------------------------------------------------------------------
/src/phantom_tensors/torch.py:
--------------------------------------------------------------------------------
1 | # pyright: strict
# torch is an optional dependency; fail at import time with a clear message
# if this submodule is used without it installed.
try:
    from torch import Tensor as _Tensor
except ImportError:
    # fixed typo: "in order to user" -> "in order to use"
    raise ImportError(
        "You must install pytorch in order to use `phantom_tensors.torch`"
    )
8 |
9 | from typing import TYPE_CHECKING, Generic, Tuple
10 |
11 | import typing_extensions as _te
12 |
13 | from phantom_tensors._internals.dim_binding import check
14 | from phantom_tensors.meta import CustomInstanceCheck
15 |
16 | __all__ = ["Tensor"]
17 |
18 |
19 | Shape = _te.TypeVarTuple("Shape")
20 |
21 |
22 | class _NewMeta(CustomInstanceCheck, type(_Tensor)): ...
23 |
24 |
class Tensor(Generic[_te.Unpack[Shape]], _Tensor):
    """A shape-typed ``torch.Tensor``: ``Tensor[A, B]`` parameterizes the
    tensor by its shape symbols.

    At runtime, subscripting produces a phantom subclass of ``torch.Tensor``
    whose ``isinstance`` check validates an instance's ``.shape`` against
    the subscripted symbols via ``check``.
    """

    if not TYPE_CHECKING:
        # TODO: add caching

        @classmethod
        def __class_getitem__(cls, key):
            # normalize Tensor[A] so that key is always a tuple of symbols
            if not isinstance(key, tuple):
                key = (key,)

            class PhantomTensor(
                _Tensor,
                metaclass=_NewMeta,
            ):
                # Mirror typing-generic attributes so `parse` can extract the
                # concrete type and shape-type from this phantom type.
                __origin__ = _Tensor
                __args__ = key

                @classmethod
                def __instancecheck__(cls, __instance: object) -> bool:
                    # must be a real torch.Tensor, then shape-validated
                    if not isinstance(__instance, _Tensor):
                        return False
                    return check(key, __instance.shape)

            return PhantomTensor

    @property
    def shape(self) -> Tuple[_te.Unpack[Shape]]:  # type: ignore
        ...
52 |
--------------------------------------------------------------------------------
/src/phantom_tensors/words.py:
--------------------------------------------------------------------------------
1 | # pyright: strict
2 | """Convenient definitions of
3 |
4 | = NewType('', int)
5 |
6 | to be used as descriptive axis names in shape types"""
7 | from typing import NewType as _NewType
8 |
9 | __all__ = [
10 | "Axis",
11 | "Batch",
12 | "Channel",
13 | "Embed",
14 | "Filter",
15 | "Height",
16 | "Kernel",
17 | "Logits",
18 | "Score",
19 | "Time",
20 | "Vocab",
21 | "Width",
22 | "Axis1",
23 | "Axis2",
24 | "Axis3",
25 | "Axis4",
26 | "Axis5",
27 | "Axis6",
28 | "Dim1",
29 | "Dim2",
30 | "Dim3",
31 | "Dim4",
32 | "Dim5",
33 | "Dim6",
34 | ]
35 |
36 | Axis = _NewType("Axis", int)
37 | Batch = _NewType("Batch", int)
38 | Channel = _NewType("Channel", int)
39 | Embed = _NewType("Embed", int)
40 | Filter = _NewType("Filter", int)
41 | Height = _NewType("Height", int)
42 | Kernel = _NewType("Kernel", int)
43 | Logits = _NewType("Logits", int)
44 | Score = _NewType("Score", int)
45 | Time = _NewType("Time", int)
46 | Vocab = _NewType("Vocab", int)
47 | Width = _NewType("Width", int)
48 |
49 | Axis1 = _NewType("Axis1", int)
50 | Axis2 = _NewType("Axis2", int)
51 | Axis3 = _NewType("Axis3", int)
52 | Axis4 = _NewType("Axis4", int)
53 | Axis5 = _NewType("Axis5", int)
54 | Axis6 = _NewType("Axis6", int)
55 |
56 | Dim1 = _NewType("Dim1", int)
57 | Dim2 = _NewType("Dim2", int)
58 | Dim3 = _NewType("Dim3", int)
59 | Dim4 = _NewType("Dim4", int)
60 | Dim5 = _NewType("Dim5", int)
61 | Dim6 = _NewType("Dim6", int)
62 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rsokl/phantom-tensors/9569cb2fee80ae12b71f398c1bdcabdb15addf36/tests/__init__.py
--------------------------------------------------------------------------------
/tests/annotations.py:
--------------------------------------------------------------------------------
1 | # pyright: strict
2 | from __future__ import annotations
3 |
4 | import sys
5 | from typing import Tuple as tuple
6 |
7 | import torch as tr
8 | from typing_extensions import assert_type
9 |
10 | from phantom_tensors import parse
11 | from phantom_tensors.alphabet import A, B
12 | from phantom_tensors.torch import Tensor
13 |
14 |
def parse_tensors(x: tr.Tensor):
    """Static-typing exercise: each `parse` overload (one through six
    tensor-type pairs) must return correspondingly shape-typed Tensors."""
    assert_type(parse(x, Tensor[A]), Tensor[A])
    assert_type(parse((x, Tensor[A])), Tensor[A])
    assert_type(
        parse(
            (x, Tensor[A]),
            (x, Tensor[A, B]),
        ),
        tuple[Tensor[A], Tensor[A, B]],
    )
    assert_type(
        parse(
            (x, Tensor[A]),
            (x, Tensor[A, B]),
            (x, Tensor[B, A]),
        ),
        tuple[Tensor[A], Tensor[A, B], Tensor[B, A]],
    )

    assert_type(
        parse(
            (x, Tensor[A]),
            (x, Tensor[A, B]),
            (x, Tensor[B, A]),
            (x, Tensor[B, B]),
        ),
        tuple[Tensor[A], Tensor[A, B], Tensor[B, A], Tensor[B, B]],
    )

    assert_type(
        parse(
            (x, Tensor[A]),
            (x, Tensor[A, B]),
            (x, Tensor[B, A]),
            (x, Tensor[B, B]),
            (x, Tensor[B]),
        ),
        tuple[Tensor[A], Tensor[A, B], Tensor[B, A], Tensor[B, B], Tensor[B]],
    )

    assert_type(
        parse(
            (x, Tensor[A]),
            (x, Tensor[A, B]),
            (x, Tensor[B, A]),
            (x, Tensor[B, B]),
            (x, Tensor[B]),
            (x, Tensor[A]),
        ),
        tuple[
            Tensor[A], Tensor[A, B], Tensor[B, A], Tensor[B, B], Tensor[B], Tensor[A]
        ],
    )
68 |
69 |
def check_bad_tensor_parse(x: tr.Tensor):
    """Static-typing exercise: each call below should be rejected by the
    type checker; the `type: ignore` comments assert an error is reported."""
    parse(1, tuple[A]) # type: ignore
    parse((1,), A) # type: ignore

    parse(x, A) # type: ignore
    parse((x, A)) # type: ignore
    parse((x, A), (x, tuple[A])) # type: ignore
77 |
78 |
def check_readme_blurb_one():
    """Static-typing exercise mirroring the first README example:
    multi-pair `parse` with numpy arrays."""
    import numpy as np

    from phantom_tensors import parse
    from phantom_tensors.numpy import NDArray

    # runtime: checks that shapes (2, 3) and (3, 2)
    # match (A, B) and (B, A) pattern
    if sys.version_info < (3, 8):
        return

    x, y = parse(
        (np.ones((2, 3)), NDArray[A, B]),
        (np.ones((3, 2)), NDArray[B, A]),
    )

    assert_type(x, NDArray[A, B])
    assert_type(y, NDArray[B, A])
97 |
98 |
def check_readme2():
    """Static-typing exercise mirroring the second README example: a parsed
    3D array is accepted by 3D/any-rank consumers, rejected by a 2D one."""
    from typing import Any

    import numpy as np

    from phantom_tensors import parse
    from phantom_tensors.alphabet import A, B  # these are just NewType(..., int) types
    from phantom_tensors.numpy import NDArray

    def func_on_2d(x: NDArray[Any, Any]): ...

    def func_on_3d(x: NDArray[Any, Any, Any]): ...

    def func_on_any_arr(x: np.ndarray[Any, Any]): ...

    if sys.version_info < (3, 8):
        return

    # runtime: ensures shape of arr_3d matches (A, B, A) patterns
    arr_3d = parse(np.ones((3, 5, 3)), NDArray[A, B, A])

    func_on_2d(arr_3d) # type: ignore
    func_on_3d(arr_3d)  # static type checker: OK
    func_on_any_arr(arr_3d)  # static type checker: OK
123 |
124 |
def check_readme3():
    """Static-typing exercise mirroring the third README example: beartype +
    dim_binding_scope catch a shape-mismatched return at runtime."""
    from typing import TypeVar, cast

    import torch as tr
    from beartype import beartype

    from phantom_tensors import dim_binding_scope
    from phantom_tensors.alphabet import A, B, C
    from phantom_tensors.torch import Tensor

    T1 = TypeVar("T1")
    T2 = TypeVar("T2")
    T3 = TypeVar("T3")

    @dim_binding_scope
    @beartype  # <- adds runtime type checking to function's interfaces
    def buggy_matmul(x: Tensor[T1, T2], y: Tensor[T2, T3]) -> Tensor[T1, T3]:
        out = x @ x.T  # <- wrong operation!
        # Will return shape-(A, A) tensor, not (A, C)
        # (and we lie to the static type checker to try to get away with it)
        return cast(Tensor[T1, T3], out)

    x, y = parse(
        (tr.ones(3, 4), Tensor[A, B]),
        (tr.ones(4, 5), Tensor[B, C]),
    )
    assert_type(x, Tensor[A, B])
    assert_type(y, Tensor[B, C])

    # At runtime:
    # beartype raises and catches shape-mismatch of output.
    # Function should return shape-(A, C) but, at runtime, returns
    # shape-(A, A)
    z = buggy_matmul(x, y)  # beartype roars!

    assert_type(z, Tensor[A, C])
161 |
162 |
def check_readme4():
    """README snippet: dims bound by one multi-pair ``parse`` call are checked
    consistently across every pair, and the results stay backend-compatible."""
    from typing import Any

    import numpy as np
    import torch as tr

    from phantom_tensors import parse
    from phantom_tensors.alphabet import A, B
    from phantom_tensors.numpy import NDArray
    from phantom_tensors.torch import Tensor

    if sys.version_info < (3, 8):
        return

    t1, arr, t2 = parse(
        # <- Runtime: enter dimension-binding context
        (tr.rand(9, 2, 9), Tensor[B, A, B]),  # <-binds A=2 & B=9
        (np.ones((2,)), NDArray[A]),  # <- checks A==2
        (tr.rand(9), Tensor[B]),  # <- checks B==9
    )  # <- Runtime: exit dimension-binding scope
    # Statically: casts t1, arr, t2 to shape-typed Tensors

    # static type checkers now see
    # t1: Tensor[B, A, B]
    # arr: NDArray[A]
    # t2: Tensor[B]

    w = parse(tr.rand(78), Tensor[A])
    # <- binds A=78 within this context

    assert_type(t1, Tensor[B, A, B])
    assert_type(arr, NDArray[A])
    assert_type(t2, Tensor[B])
    assert_type(w, Tensor[A])

    def vanilla_numpy(x: np.ndarray[Any, Any]): ...

    def vanilla_torch(x: tr.Tensor): ...

    vanilla_numpy(arr)  # type checker: OK
    vanilla_torch(t1)  # type checker: OK
    vanilla_torch(arr)  # type: ignore [type checker: Error!]
205 |
206 |
def check_phantom_example():
    """README snippet: phantom-types predicates validate individual dim values
    (even sizes pass, odd sizes raise ParseError at runtime)."""
    from typing import Any

    import torch as tr

    from phantom import Phantom
    from phantom_tensors import parse
    from phantom_tensors.torch import Tensor

    class EvenOnly(int, Phantom[Any], predicate=lambda x: x % 2 == 0): ...

    assert_type(parse(tr.ones(1, 0), Tensor[int, EvenOnly]), Tensor[int, EvenOnly])
    assert_type(parse(tr.ones(1, 2), Tensor[int, EvenOnly]), Tensor[int, EvenOnly])
    assert_type(parse(tr.ones(1, 4), Tensor[int, EvenOnly]), Tensor[int, EvenOnly])

    parse(tr.ones(1, 3), Tensor[int, EvenOnly])  # runtime: ParseError
223 |
224 |
def check_beartype_example():
    """README snippet: beartype enforces shape-typed inputs/outputs at runtime
    within a single dimension-binding scope.

    Fix: torch is imported locally (it is used via ``tr.rand`` below), so this
    snippet is self-contained, matching the other README checks in this file.
    """
    from typing import Any

    import pytest
    import torch as tr
    from beartype import beartype
    from typing_extensions import assert_type

    from phantom_tensors import dim_binding_scope, parse
    from phantom_tensors.alphabet import A, B, C
    from phantom_tensors.torch import Tensor

    # @dim_binding_scope:
    #   ensures A, B, C consistent across all input/output tensor shapes
    #   within scope of function
    @dim_binding_scope
    @beartype  # <-- adds isinstance checks on inputs & outputs
    def matrix_multiply(x: Tensor[A, B], y: Tensor[B, C]) -> Tensor[A, C]:
        a, _ = x.shape
        _, c = y.shape
        return parse(tr.rand(a, c), Tensor[A, C])

    @beartype
    def needs_vector(x: Tensor[Any]): ...

    x, y = parse(
        (tr.rand(3, 4), Tensor[A, B]),
        (tr.rand(4, 5), Tensor[B, C]),
    )

    z = matrix_multiply(x, y)
    assert_type(z, Tensor[A, C])

    with pytest.raises(Exception):
        # beartype raises error: input Tensor[A, C] doesn't match Tensor[A]
        needs_vector(z)  # type: ignore

    with pytest.raises(Exception):
        # beartype raises error: inputs Tensor[A, B], Tensor[A, B] don't match signature
        matrix_multiply(x, x)  # type: ignore
264 |
265 |
266 | # def check_parse_ints(x: int | Literal[1]):
267 | # assert_type(
268 | # parse_ints(x, A),
269 | # A,
270 | # )
271 | # assert_type(
272 | # parse_ints((x, A)),
273 | # A,
274 | # )
275 | # assert_type(
276 | # parse_ints(
277 | # (1, A),
278 | # (2, B),
279 | # ),
280 | # tuple[A, B],
281 | # )
282 | # assert_type(
283 | # parse_ints(
284 | # (1, A),
285 | # (2, B),
286 | # (2, B),
287 | # ),
288 | # tuple[A, B, B],
289 | # )
290 | # assert_type(
291 | # parse_ints(
292 | # (1, A),
293 | # (2, B),
294 | # (2, B),
295 | # (1, A),
296 | # ),
297 | # tuple[A, B, B, A],
298 | # )
299 | # assert_type(
300 | # parse_ints(
301 | # (1, A),
302 | # (2, B),
303 | # (2, B),
304 | # (1, A),
305 | # (2, B),
306 | # ),
307 | # tuple[A, B, B, A, B],
308 | # )
309 | # assert_type(
310 | # parse_ints((1, A), (2, B), (2, B), (1, A), (2, B), (1, A)),
311 | # tuple[A, B, B, A, B, A],
312 | # )
313 |
314 |
315 | # def check_parse_tuples(x: int):
316 |
317 | # assert_type(
318 | # parse_tuples((x,), tuple[A]),
319 | # tuple[A],
320 | # )
321 |
322 | # assert_type(
323 | # parse_tuples(((x,), tuple[A])),
324 | # tuple[A],
325 | # )
326 | # assert_type(
327 | # parse_tuples(
328 | # ((x,), tuple[A]),
329 | # ((x, x), tuple[A, A]),
330 | # ),
331 | # tuple[
332 | # tuple[A],
333 | # tuple[A, A],
334 | # ],
335 | # )
336 | # assert_type(
337 | # parse_tuples(
338 | # ((x,), tuple[A]),
339 | # ((x, x), tuple[A, A]),
340 | # ((x, x), tuple[A, A, B]),
341 | # ),
342 | # tuple[
343 | # tuple[A],
344 | # tuple[A, A],
345 | # tuple[A, A, B],
346 | # ],
347 | # )
348 | # assert_type(
349 | # parse_tuples(
350 | # ((x,), tuple[A]),
351 | # ((x, x), tuple[A, A]),
352 | # ((x, x, x), tuple[A, A, B]),
353 | # ((x, x, x), tuple[A, A, B]),
354 | # ),
355 | # tuple[
356 | # tuple[A],
357 | # tuple[A, A],
358 | # tuple[A, A, B],
359 | # tuple[A, A, B],
360 | # ],
361 | # )
362 | # assert_type(
363 | # parse_tuples(
364 | # ((x,), tuple[A]),
365 | # ((x, x), tuple[A, A]),
366 | # ((x, x, x), tuple[A, A, B]),
367 | # ((x, x, x), tuple[A, A, B]),
368 | # ((x,), tuple[A]),
369 | # ),
370 | # tuple[tuple[A], tuple[A, A], tuple[A, A, B], tuple[A, A, B], tuple[A]],
371 | # )
372 | # assert_type(
373 | # parse_tuples(
374 | # ((x,), tuple[A]),
375 | # ((x, x), tuple[A, A]),
376 | # ((x, x, x), tuple[A, A, B]),
377 | # ((x, x, x), tuple[A, A, B]),
378 | # ((x,), tuple[A]),
379 | # ((x,), tuple[B]),
380 | # ),
381 | # tuple[
382 | # tuple[A], tuple[A, A], tuple[A, A, B], tuple[A, A, B], tuple[A], tuple[B]
383 | # ],
384 | # )
385 |
386 |
387 | # def check_bad_int_parse(x: tr.Tensor, y: tuple[int, ...], z: int):
388 | # parse_ints(x, A)
389 | # parse_ints(y, A)
390 | # parse_ints((x, A))
391 | # parse_ints((y, A))
392 |
393 | # parse_ints(z, tuple[A])
394 | # parse_ints(z, Tensor[A])
395 |
396 | # def check_bad_tuple_parse(x: tr.Tensor, y: tuple[int, ...], z: int):
397 | # parse_tuples(x, tuple[A])
398 | # parse_tuples(z, tuple[A])
399 |
400 | # parse_tuples((x, tuple[A]))
401 | # parse_tuples((z, tuple[A]))
402 |
403 | # parse_ints(y, A)
404 | # parse_ints(y, Tensor[A])
405 |
--------------------------------------------------------------------------------
/tests/arrlike.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Tuple
2 |
3 |
class ImplementsArray:
    """Minimal array-like stand-in: exposes ``shape`` and ``__array__``.

    Lets the tests exercise shape parsing without depending on numpy/torch.
    """

    def __init__(self, shape: Tuple[int, ...]) -> None:
        # Check `isinstance` *first* so that a non-int entry (e.g. a str)
        # fails the validation cleanly instead of raising TypeError from
        # the `i >= 0` comparison.
        assert all(isinstance(i, int) and i >= 0 for i in shape)
        self._shape = shape

    @property
    def shape(self) -> Tuple[int, ...]:
        # The shape supplied at construction time, unchanged.
        return self._shape

    def __array__(self) -> Any: ...
14 |
15 |
def arr(*shape: int) -> ImplementsArray:
    """Shorthand factory: ``arr(2, 3)`` builds an ``ImplementsArray`` of shape (2, 3)."""
    dims = shape
    return ImplementsArray(dims)
18 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
import importlib.metadata
import re

# Skip collection of tests that require additional dependencies.
collect_ignore_glob = []

# Distributions imported by the third-party test module; if any is missing,
# that module is excluded from collection entirely.
OPTIONAL_TEST_DEPENDENCIES = (
    "numpy",
    "torch",
    "beartype",
    "phantom-types",
)


def _canonical(name: str) -> str:
    """Normalize a distribution name per PEP 503 (lowercase, -/_/. folded),
    so e.g. ``Phantom_Types`` matches ``phantom-types``."""
    return re.sub(r"[-_.]+", "-", name).lower()


_installed = {
    # `Name` can be absent on broken metadata; treat it as "" rather than crash.
    _canonical(dist.metadata["Name"] or "")
    for dist in importlib.metadata.distributions()
}

if any(_canonical(dep) not in _installed for dep in OPTIONAL_TEST_DEPENDENCIES):
    collect_ignore_glob.append("*third_party.py")
17 |
--------------------------------------------------------------------------------
/tests/test_letters_and_words.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import phantom_tensors.alphabet as alphabet
4 | import phantom_tensors.words as words
5 | from phantom_tensors.alphabet import __all__ as all_letters
6 | from phantom_tensors.words import __all__ as all_words
7 |
8 |
@pytest.mark.parametrize("name", all_letters)
def test_shipped_letters_are_named_ints(name: str):
    """Every exported letter is a NewType over ``int`` whose name matches its export."""
    letter = getattr(alphabet, name)
    assert name == letter.__name__
    assert int is letter.__supertype__
14 |
15 |
@pytest.mark.parametrize("name", all_words)
def test_shipped_words_are_named_ints(name: str):
    """Every exported word is a NewType over ``int`` whose name matches its export."""
    word = getattr(words, name)
    assert name == word.__name__
    assert int is word.__supertype__
21 |
--------------------------------------------------------------------------------
/tests/test_prototype.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Any, Literal as L, NewType, TypeVar
3 |
4 | import pytest
5 | from pytest import param
6 | from typing_extensions import TypeVarTuple, Unpack as U
7 |
8 | import phantom_tensors
9 | from phantom_tensors import dim_binding_scope, parse
10 | from phantom_tensors.array import SupportsArray as Array
11 | from phantom_tensors.errors import ParseError
12 | from tests.arrlike import arr
13 |
T = TypeVar("T")
Ts = TypeVarTuple("Ts")

# Local dimension labels: plain NewTypes over int, same shape as the ones
# shipped in phantom_tensors.alphabet.
A = NewType("A", int)
B = NewType("B", int)
C = NewType("C", int)

# Reusable marker for parametrized cases that are expected to raise ParseError.
parse_xfail = pytest.mark.xfail(raises=ParseError)
22 |
23 |
def test_version():
    """The installed package must report a real version, not the fallback."""
    assert "unknown" != phantom_tensors.__version__
26 |
27 |
def test_parse_error_msg():
    """A shape mismatch reports the actual shape and the bound shape-type."""
    expected = re.escape("shape-(2, 1) doesn't match shape-type ([...], A=2, A=2)")
    with pytest.raises(ParseError, match=expected):
        parse(arr(2, 1), Array[U[Ts], A, A])  # type: ignore
34 |
35 |
@pytest.mark.parametrize(
    "tensor_type_pairs",
    [
        # --- cases that must parse cleanly ---
        # (arr(), Array[()]),
        (arr(), Array[U[Ts]]),  # type: ignore
        (arr(2), Array[A]),
        (arr(2), Array[int]),
        (arr(2), Array[Any]),
        (arr(2), Array[U[Ts]]),  # type: ignore
        (arr(2), Array[U[Ts], A]),  # type: ignore
        (arr(2), Array[A, U[Ts]]),  # type: ignore
        (arr(2, 2), Array[A, A]),
        (arr(2, 2), Array[U[Ts], A, A]),  # type: ignore
        (arr(2, 2), Array[A, U[Ts], A]),  # type: ignore
        (arr(2, 2), Array[A, A, U[Ts]]),  # type: ignore
        (arr(1, 3, 2, 2), Array[U[Ts], A, A]),  # type: ignore
        (arr(2, 1, 3, 2), Array[A, U[Ts], A]),  # type: ignore
        (arr(2, 2, 1, 3), Array[A, A, U[Ts]]),  # type: ignore
        (arr(1, 2, 1, 3, 2), Array[A, B, U[Ts], B]),  # type: ignore
        (arr(1, 2, 3), Array[Any, Any, Any]),
        (arr(1, 2, 3), Array[int, int, int]),
        (arr(2, 1, 2), Array[A, B, A]),
        (arr(2, 1, 3), Array[A, B, C]),
        (
            (arr(5), Array[A]),
            (arr(5, 2), Array[A, B]),
        ),
        (arr(1), Array[L[1]]),
        (arr(3), Array[L[1, 2, 3]]),
        (arr(1, 2), Array[L[1], L[2]]),
        (arr(1, 2, 1), Array[L[1], L[2], L[1]]),
        # --- cases expected to raise ParseError (marked xfail) ---
        param((arr(), Array[int]), marks=parse_xfail),
        param((arr(), Array[int, U[Ts]]), marks=parse_xfail),  # type: ignore
        param((arr(2), Array[int, int]), marks=parse_xfail),
        param((arr(2, 4), Array[A, A]), marks=parse_xfail),
        param((arr(2, 1, 1), Array[A, B, A]), marks=parse_xfail),
        param((arr(1, 1, 2), Array[A, A, A]), marks=parse_xfail),
        param((arr(2, 1, 1), Array[A, U[Ts], A]), marks=parse_xfail),  # type: ignore
        param((arr(1), Array[A, B, C]), marks=parse_xfail),
        param(((arr(2, 4), Array[A, A]),), marks=parse_xfail),
        param(
            (
                (arr(4), Array[A]),
                (arr(5), Array[A]),
            ),
            marks=parse_xfail,
        ),
        param((arr(3), Array[L[1]]), marks=parse_xfail),
        param((arr(2, 2), Array[L[1], L[2]]), marks=parse_xfail),
        param((arr(1, 1), Array[L[1], L[2]]), marks=parse_xfail),
        param(
            (arr(1, 1, 1), Array[L[1], L[2], L[1]]),
            marks=parse_xfail,
        ),
    ],
)
def test_parse_consistent_types(tensor_type_pairs):
    """Each case is either a single (array, shape-type) pair or a tuple of such
    pairs; ``parse`` must accept them (or raise ParseError for xfail cases)."""
    parse(*tensor_type_pairs)
94 |
95 |
def test_parse_in_and_out_of_binding_scope():
    """Inside `dim_binding_scope`, the first parse of a NewType dim binds its
    value and later parses must match it; outside the scope, nothing persists."""
    with dim_binding_scope:

        parse(arr(2), Array[A])  # binds A=2

        with pytest.raises(ParseError):
            parse(arr(3), Array[A])

        parse(arr(2), Array[A])

        parse(arr(2, 4), Array[A, B])  # binds B=4
        parse(arr(2, 9), Array[A, int])

        with pytest.raises(ParseError):
            parse(arr(2), Array[B])

    # no dims bound
    parse(arr(1, 3, 3, 1), Array[B, A, A, B])  # no dims bound
    parse(arr(1, 4, 4, 1), Array[B, A, A, B])

    parse(
        (arr(9), Array[B]),
        (arr(9, 2, 2, 9), Array[B, A, A, B]),
    )
120 |
121 |
def test_parse_bind_multiple():
    """Multiple dims bound by a single parse call stay bound for the rest of
    the scope — rebinding raises, re-matching passes — and clear on exit."""
    with dim_binding_scope:  # enter dimension-binding scope
        parse(
            (arr(2), Array[A]),  # <-binds A=2
            (arr(9), Array[B]),  # <-binds B=9
            (arr(9, 2, 9), Array[B, A, B]),  # <-checks A & B
        )

        with pytest.raises(
            ParseError,
            match=re.escape("shape-(78,) doesn't match shape-type (A=2,)"),
        ):
            # can't re-bind A within scope
            parse(arr(78), Array[A])

        with pytest.raises(
            ParseError,
            match=re.escape("shape-(22,) doesn't match shape-type (B=9,)"),
        ):
            # can't re-bind B within scope
            parse(arr(22), Array[B])

        parse(arr(2), Array[A])
        parse(arr(9), Array[B])

    # exit dimension-binding scope

    parse(arr(78, 22), Array[A, B])  # now ok
150 |
151 |
# A NewType over str: a structurally valid NewType but an invalid dimension
# annotation, since dimension types must be int-based.
AStr = NewType("AStr", str)


@pytest.mark.parametrize(
    "bad_type",
    [
        Array[AStr],
        Array[str],
        Array[int, str],
        Array[U[Ts], str],  # type: ignore
    ],
)
def test_bad_type_validation(bad_type):
    """Non-int dimension annotations are rejected with TypeError (not ParseError)."""
    with pytest.raises(TypeError):
        parse(arr(1), bad_type)
167 |
--------------------------------------------------------------------------------
/tests/test_third_party.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import NewType, TypeVar, cast
3 |
4 | import numpy as np
5 | import pytest
6 | import torch as tr
7 | from beartype import beartype
8 | from beartype.roar import (
9 | BeartypeCallHintParamViolation,
10 | BeartypeCallHintReturnViolation,
11 | )
12 | from typing_extensions import TypeVarTuple, Unpack as U
13 |
14 | from phantom import Phantom
15 | from phantom_tensors import dim_binding_scope, parse
16 | from phantom_tensors.alphabet import A, B, C
17 | from phantom_tensors.array import SupportsArray as Array
18 | from phantom_tensors.errors import ParseError
19 | from phantom_tensors.numpy import NDArray
20 | from phantom_tensors.torch import Tensor
21 | from tests.arrlike import arr
22 |
# Generic placeholders used in the TypeVar'd / variadic signatures below.
T = TypeVar("T")
Ts = TypeVarTuple("Ts")
25 |
26 |
# Phantom type accepting ints strictly between 0 and 4 (i.e. 1, 2, or 3).
class One_to_Three(int, Phantom, predicate=lambda x: 0 < x < 4): ...
28 |
29 |
# Phantom type accepting only the ints 10 and 11.
class Ten_or_Eleven(int, Phantom, predicate=lambda x: 10 <= x <= 11): ...
31 |
32 |
# Phantom type accepting only even ints.
class EvenOnly(int, Phantom, predicate=lambda x: x % 2 == 0): ...
34 |
35 |
# NewType wrapping a phantom type: exercises validation through NewType indirection.
NewOneToThree = NewType("NewOneToThree", One_to_Three)
37 |
38 |
def test_NDArray():
    """NDArray (bare and subscripted) subclasses np.ndarray, and parse
    validates its shape against the annotation."""
    for ndarray_type in (NDArray, NDArray[A]):
        assert issubclass(ndarray_type, np.ndarray)

    parse(np.ones((2,)), NDArray[A])
    with pytest.raises(ParseError):
        parse(np.ones((2, 3)), NDArray[A, A])
46 |
47 |
def test_Tensor():
    """Tensor (bare and subscripted) subclasses torch.Tensor, and parse
    validates its shape against the annotation."""
    for tensor_type in (Tensor, Tensor[A]):
        assert issubclass(tensor_type, tr.Tensor)

    parse(tr.ones((2,)), Tensor[A])
    with pytest.raises(ParseError):
        parse(tr.ones((2, 3)), Tensor[A, A])
55 |
56 |
@pytest.mark.parametrize(
    "tensor_type_pairs",
    [
        (tr.ones(2), NDArray[int]),  # type mismatch
        (np.ones((2,)), Tensor[int]),  # type mismatch
        (arr(10, 2, 10), Array[One_to_Three, int, Ten_or_Eleven]),
        (arr(10, 2, 10), Array[NewOneToThree, int, Ten_or_Eleven]),
        (arr(2, 2, 8), Array[NewOneToThree, int, Ten_or_Eleven]),
        (arr(0, 10), Array[One_to_Three, U[Ts], Ten_or_Eleven]),  # type: ignore
        (arr(2, 2, 0), Array[One_to_Three, U[Ts], Ten_or_Eleven]),  # type: ignore
        (arr(2, 0, 0, 0), Array[One_to_Three, U[Ts], Ten_or_Eleven]),  # type: ignore
        (arr(0, 0, 2, 0), Array[U[Ts], One_to_Three, Ten_or_Eleven]),  # type: ignore
    ],
)
def test_parse_inconsistent_types(tensor_type_pairs):
    """Backend mismatches and predicate-violating dim values must all raise
    ParseError."""
    with pytest.raises(ParseError):
        parse(*tensor_type_pairs)
74 |
75 |
def test_phantom_checks():
    """isinstance rejects wrong backends (numpy vs torch) and wrong ranks alike."""
    mismatches = [
        (np.ones((2,)), Tensor[int]),  # numpy array vs torch phantom type
        (tr.ones((2,)), NDArray[int]),  # torch tensor vs numpy phantom type
        (tr.ones((2, 2)), Tensor[int]),  # rank mismatch
        (np.ones((2, 2)), NDArray[int]),  # rank mismatch
    ]
    for obj, phantom_type in mismatches:
        assert not isinstance(obj, phantom_type)  # type: ignore
82 |
83 |
def test_type_var_with_beartype():
    """A TypeVar'd signature (Tensor[T, T]) makes beartype reject non-square input."""

    @dim_binding_scope
    @beartype
    def diag(sqr: Tensor[T, T]) -> Tensor[T]:
        return cast(Tensor[T], tr.diag(sqr))

    rectangular = parse(tr.ones(2, 3), Tensor[A, B])
    with pytest.raises(BeartypeCallHintParamViolation):
        diag(rectangular)  # type: ignore
93 |
94 |
def test_matmul_example():
    """A function whose runtime output shape contradicts its annotated return
    shape-type is caught by beartype on return."""

    @dim_binding_scope
    @beartype
    def matrix_multiply(x: Tensor[A, B], y: Tensor[B, C]) -> Tensor[A, C]:
        # Deliberately wrong: x @ x.T has shape (A, A), not (A, C).
        return cast(Tensor[A, C], x @ x.T)

    x, y = parse(
        (tr.ones(3, 4), Tensor[A, B]),
        (tr.ones(4, 5), Tensor[B, C]),
    )

    with pytest.raises(BeartypeCallHintReturnViolation):
        matrix_multiply(x, y)
111 |
112 |
def test_runtime_checking_with_beartype():
    """End-to-end: beartype validates both parameter and return shape-types
    bound consistently under one dim_binding_scope."""

    @dim_binding_scope
    # ^ ensures A, B, C consistent across all input/output tensor shapes
    #   within scope of function
    @beartype
    def matrix_multiply(x: Tensor[A, B], y: Tensor[B, C]) -> Tensor[A, C]:
        a, b = x.shape
        b, c = y.shape
        return cast(Tensor[A, C], tr.ones(a, c))

    @beartype
    def needs_vector(x: Tensor[int]): ...

    x, y = parse(
        (tr.ones(3, 4), Tensor[A, B]),
        (tr.ones(4, 5), Tensor[B, C]),
    )
    # x # type revealed: Tensor[A, B]
    # y # type revealed: Tensor[B, C]

    z = matrix_multiply(x, y)
    # z # type revealed: Tensor[A, C]

    with pytest.raises(Exception):
        # rank mismatch: z is 2D, needs_vector wants 1D
        needs_vector(z)  # type: ignore

    with pytest.raises(Exception):
        # second argument's shape-type (A, B) doesn't match the (B, C) parameter
        matrix_multiply(x, x)  # type: ignore
141 |
142 |
def test_catches_wrong_instance():
    """Passing the wrong array backend to parse raises ParseError mentioning
    the expected vs. actual object."""
    # NOTE(review): the `match` strings below look truncated ("Expected , got: ");
    # the expected/actual reprs (likely angle-bracketed class names) appear to
    # have been lost in transit — confirm against the actual ParseError message.
    with pytest.raises(
        ParseError,
        match=re.escape(
            "Expected , got: "
        ),
    ):
        parse(tr.tensor(1), NDArray[A, B])

    with pytest.raises(
        ParseError,
        match=re.escape(
            "Expected , got: "
        ),
    ):
        parse(np.array(1), Tensor[A])
159 |
160 |
def test_isinstance_works():
    """isinstance participates in dimension binding: the first successful check
    binds a NewType dim, and later checks inside the scope must agree."""
    with dim_binding_scope:

        assert isinstance(tr.ones(2), Tensor[A])  # type: ignore  # binds A=2
        assert not isinstance(tr.ones(3), Tensor[A])  # type: ignore
        assert isinstance(tr.ones(2), Tensor[A])  # type: ignore

        assert isinstance(tr.ones(2, 4), Tensor[A, B])  # type: ignore  # binds B=4
        assert not isinstance(tr.ones(2), Tensor[B])  # type: ignore
        assert isinstance(tr.ones(4), Tensor[B])  # type: ignore
        assert isinstance(tr.ones(4, 2, 2, 4), Tensor[B, A, A, B])  # type: ignore

    # outside the scope nothing stays bound: A/B only need internal consistency
    assert isinstance(tr.ones(1, 3, 3, 1), Tensor[B, A, A, B])  # type: ignore
    assert isinstance(tr.ones(1, 4, 4, 1), Tensor[B, A, A, B])  # type: ignore
175 |
176 |
@pytest.mark.parametrize(
    "tensor_type_pairs",
    [
        (arr(2, 2, 10), Array[One_to_Three, int, Ten_or_Eleven]),
        (arr(2, 10), Array[One_to_Three, U[Ts], Ten_or_Eleven]),  # type: ignore
        (arr(2, 2, 10), Array[One_to_Three, U[Ts], Ten_or_Eleven]),  # type: ignore
        (arr(2, 0, 0, 10), Array[One_to_Three, U[Ts], Ten_or_Eleven]),  # type: ignore
        (arr(2, 2, 10), Array[NewOneToThree, int, Ten_or_Eleven]),
        (arr(0, 0, 2, 11), Array[U[Ts], One_to_Three, Ten_or_Eleven]),  # type: ignore
    ],
)
def test_parse_consistent_types(tensor_type_pairs):
    """Shapes whose dim values satisfy the phantom-type predicates (with any
    variadic middle) must parse without error."""
    parse(*tensor_type_pairs)
190 |
191 |
def test_non_binding_subint_dims_pass():
    """Validated (non-binding) dims accept any predicate-satisfying values;
    repeated occurrences are not forced to be equal."""
    parse(arr(2, 4, 6), Array[EvenOnly, EvenOnly, EvenOnly])

    pairs = (
        (arr(2, 4), Array[EvenOnly, EvenOnly]),
        (arr(6, 8), Array[EvenOnly, EvenOnly]),
    )
    parse(*pairs)
198 |
199 |
def test_non_binding_subint_dims_validates():
    """A predicate violation on a non-binding dim raises ParseError whose
    message shows 'EvenOnly=?' (no value is ever bound)."""

    with pytest.raises(
        ParseError,
        match=re.escape(
            r"shape-(2, 1) doesn't match shape-type (EvenOnly=?, EvenOnly=?)"
        ),
    ):
        parse(arr(2, 1), Array[EvenOnly, EvenOnly])

    with pytest.raises(
        ParseError,
        match=re.escape(
            r"shape-(6, 3) doesn't match shape-type (EvenOnly=?, EvenOnly=?)"
        ),
    ):
        # the first pair is fine; the second pair's odd dim (3) triggers the error
        parse(
            (arr(2, 4), Array[EvenOnly, EvenOnly]),
            (arr(6, 3), Array[EvenOnly, EvenOnly]),
        )
220 |
221 |
def test_binding_validated_dims_validates():
    """A binding, validated dim (NewType over a phantom type) binds its first
    value and then rejects any later mismatch, as the '=1' in the message shows."""

    with pytest.raises(
        ParseError,
        match=re.escape(
            r"shape-(1, 2, 3) doesn't match shape-type (NewOneToThree=1, NewOneToThree=1, NewOneToThree=1)"
        ),
    ):
        parse(arr(1, 2, 3), Array[NewOneToThree, NewOneToThree, NewOneToThree])

    with pytest.raises(
        ParseError,
        match=re.escape(
            r"shape-(1, 2) doesn't match shape-type (NewOneToThree=1, NewOneToThree=1)"
        ),
    ):
        # first pair binds NewOneToThree=1; second pair's 2 contradicts it
        parse(
            (arr(1, 1), Array[NewOneToThree, NewOneToThree]),
            (arr(1, 2), Array[NewOneToThree, NewOneToThree]),
        )
242 |
243 |
@pytest.mark.parametrize("good_shape", [(1, 1), (2, 2), (3, 3)])
def test_binding_validated_dims_passes(good_shape):
    """Any in-range value repeated consistently satisfies a binding, validated dim."""
    shape_type = Array[NewOneToThree, NewOneToThree]
    parse(arr(*good_shape), shape_type)
247 |
--------------------------------------------------------------------------------
/tests/test_type_properties.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import sys
4 | from typing import Literal, Tuple
5 |
6 | import pytest
7 | from hypothesis import assume, given
8 |
9 |
@pytest.mark.xfail(
    sys.version_info < (3, 9), reason="Tuple normalization introduced in 3.9"
)
def test_literal_singlet_tuple_same_as_scalar():
    """Literal[(1,)] and Literal[1] normalize to the identical cached object."""
    singlet = Literal[(1,)]  # type: ignore
    assert singlet is Literal[1]
15 |
16 |
@given(...)
def test_literals_can_check_by_identity(a: Tuple[int, ...], b: Tuple[int, ...]):
    """Distinct non-empty tuples yield distinct — but individually cached —
    Literal objects, so identity comparison is reliable."""
    assume(len(a) > 0)
    assume(len(b) > 0)
    assume(a != b)
    lit_a = Literal[a]  # type: ignore
    lit_b = Literal[b]  # type: ignore
    assert lit_a is Literal[a]  # type: ignore
    assert lit_b is Literal[b]  # type: ignore
    assert lit_a is not lit_b
25 |
26 |
def test_literal_hashes_consistently():
    """Equal Literal forms hash equally, so they can serve as dict keys."""
    table = {Literal[1]: 1}
    assert table[Literal[1]] == 1
    assert hash(Literal[1, 2]) == hash(Literal[1, 2])
30 |
31 |
@pytest.mark.parametrize("content", [1, 2, (1, 2)])
def test_literal_args_always_tuples(content):
    """``Literal[...].__args__`` is always a tuple: scalars get wrapped, tuples
    are splatted into the args.

    Bug fix: the expected value must be parenthesized. Without parentheses the
    conditional expression swallowed the whole comparison, so the tuple case
    reduced to ``assert content`` — vacuously true for any non-empty tuple.
    """
    expected = (content,) if not isinstance(content, tuple) else content
    assert Literal[content].__args__ == expected  # type: ignore
35 |
--------------------------------------------------------------------------------