├── .github ├── dependabot.yml └── workflows │ ├── pypi_publish.yml │ └── tox_run.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE.txt ├── README.md ├── deps └── requirements-pyright.txt ├── pyproject.toml ├── pyrightconfig_py38.json ├── setup.cfg ├── src └── phantom_tensors │ ├── __init__.py │ ├── _internals │ ├── __init__.py │ ├── dim_binding.py │ ├── parse.py │ └── utils.py │ ├── alphabet.py │ ├── array.py │ ├── errors.py │ ├── meta.py │ ├── numpy.py │ ├── parse.py │ ├── py.typed │ ├── torch.py │ └── words.py └── tests ├── __init__.py ├── annotations.py ├── arrlike.py ├── conftest.py ├── test_letters_and_words.py ├── test_prototype.py ├── test_third_party.py └── test_type_properties.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" # See documentation for possible values 4 | directory: "/" # Location of package manifests 5 | schedule: 6 | interval: "daily" 7 | - package-ecosystem: "pip" 8 | directory: "/deps/" 9 | schedule: 10 | interval: "daily" 11 | target-branch: "main" -------------------------------------------------------------------------------- /.github/workflows/pypi_publish.yml: -------------------------------------------------------------------------------- 1 | 2 | # This workflows will upload a Python Package using Twine when a release is created 3 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 4 | 5 | name: Upload Python Package 6 | 7 | on: 8 | release: 9 | types: [created] 10 | 11 | jobs: 12 | deploy: 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: '3.x' 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install build twine 26 | 
- name: Build and publish 27 | env: 28 | TWINE_USERNAME: __token__ 29 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 30 | run: | 31 | python -m build 32 | twine upload dist/* -------------------------------------------------------------------------------- /.github/workflows/tox_run.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | # Trigger the workflow on push or pull request, 5 | # but only for the main branch 6 | push: 7 | branches: 8 | - main 9 | - develop 10 | pull_request: 11 | branches: 12 | - main 13 | - develop 14 | 15 | jobs: 16 | tests: 17 | runs-on: ubuntu-latest 18 | 19 | strategy: 20 | max-parallel: 5 21 | matrix: 22 | python-version: [3.8, 3.9, "3.10", 3.11, 3.12] 23 | fail-fast: false 24 | 25 | steps: 26 | - uses: actions/checkout@v4 27 | - name: Set up Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@v5 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install tox tox-gh-actions 35 | - name: Test with tox 36 | run: tox -e py 37 | 38 | test-third-party: 39 | runs-on: ubuntu-latest 40 | steps: 41 | - uses: actions/checkout@v4 42 | - name: Set up Python 3.11 43 | uses: actions/setup-python@v5 44 | with: 45 | python-version: 3.11 46 | - name: Install dependencies 47 | run: | 48 | python -m pip install --upgrade pip 49 | pip install tox tox-gh-actions 50 | - name: Cache tox environments 51 | id: cache-third-party 52 | uses: actions/cache@v4 53 | with: 54 | path: .tox 55 | key: tox-third-party-${{ hashFiles('setup.cfg') }}-${{ hashFiles('setup.py') }}-${{ hashFiles('tests/conftest.py') }}-${{ hashFiles('.github/workflows/tox.yml') }} 56 | - name: Test with tox 57 | run: tox -e third-party 58 | 59 | run-pyright: 60 | runs-on: ubuntu-latest 61 | steps: 62 | - uses: actions/checkout@v4 63 | - name: Set up Python 3.11 64 | uses: actions/setup-python@v5 65 | with: 66 | 
python-version: 3.11 67 | - name: Install dependencies 68 | run: | 69 | python -m pip install --upgrade pip 70 | pip install tox tox-gh-actions 71 | - name: Cache tox environments 72 | id: cache-third-party 73 | uses: actions/cache@v4 74 | with: 75 | path: .tox 76 | key: tox-third-party-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('setup.py') }}-${{ hashFiles('.github/workflows/tox.yml') }} 77 | - name: Test with tox 78 | run: tox -e pyright 79 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # hydra-zen docs 132 | 133 | docs/source/_build 134 | docs/builds 135 | 136 | # pycharm 137 | .idea/ 138 | 139 | # pyright 140 | node_modules/ 141 | 142 | scratch/ 143 | 144 | # npm 145 | package-lock.json 146 | 147 | # vscode 148 | settings.json -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 24.3.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/PyCQA/flake8 7 | rev: 7.0.0 8 | hooks: 9 | - id: flake8 10 | - repo: https://github.com/pycqa/isort 11 | rev: 5.13.2 12 | hooks: 13 | - id: isort -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ryan Soklaski 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following 
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Phantom Tensors 2 | > Tensor types with variadic shapes, for any array-based library, that work with both static and runtime type checkers 3 | 4 |

5 | 6 | PyPI 7 | 8 | 9 | Python version support 10 | 11 |

12 | 13 | 14 | **This project is currently just a rough prototype! Inspired by: [phantom-types](https://github.com/antonagestam/phantom-types)** 15 | 16 | The goal of this project is to let users write tensor-like types with variadic shapes (via [PEP 646](https://peps.python.org/pep-0646/)) that are: 17 | - Amenable to **static type checking (without mypy plugins)**. 18 | > E.g., pyright can tell the difference between `Tensor[Batch, Channel]` and `Tensor[Batch, Feature]` 19 | - Useful for performing **runtime checks of tensor types and shapes**. 20 | > E.g., can validate -- at runtime -- that arrays of types `NDArray[A, B]` and `NDArray[B, A]` indeed have transposed shapes with respect to each other. 21 | - Compatible with *any* array-based library (numpy, pytorch, xarray, cupy, mygrad, etc.) 22 | > E.g., a function annotated with `x: torch.Tensor` can be passed `phantom_tensors.torch.Tensor[N, B, D]`. It is trivial to write custom phantom-tensor flavored types for any array-based library. 23 | 24 | `phantom_tensors.parse` makes it easy to declare shaped tensor types in a way that static type checkers understand, and that are validated at runtime: 25 | 26 | ```python 27 | from typing import NewType 28 | 29 | import numpy as np 30 | 31 | from phantom_tensors import parse 32 | from phantom_tensors.numpy import NDArray 33 | 34 | A = NewType("A", int) 35 | B = NewType("B", int) 36 | 37 | # static: declare that x is of type NDArray[A, B] 38 | # declare that y is of type NDArray[B, A] 39 | # runtime: check that shapes (2, 3) and (3, 2) 40 | # match (A, B) and (B, A) pattern across 41 | # tensors 42 | x, y = parse( 43 | (np.ones((2, 3)), NDArray[A, B]), 44 | (np.ones((3, 2)), NDArray[B, A]), 45 | ) 46 | 47 | x # static type checker sees: NDArray[A, B] 48 | y # static type checker sees: NDArray[B, A] 49 | 50 | ``` 51 | 52 | Passing inconsistent types to `parse` will result in a runtime validation error.
53 | ```python 54 | # Runtime: Raises `ParseError` A=10 and A=2 do not match 55 | z, w = parse( 56 | (np.ones((10, 3)), NDArray[A, B]), 57 | (np.ones((3, 2)), NDArray[B, A]), 58 | ) 59 | ``` 60 | 61 | These shaped tensor types are amenable to static type checking: 62 | 63 | ```python 64 | from typing import Any 65 | 66 | import numpy as np 67 | 68 | from phantom_tensors import parse 69 | from phantom_tensors.numpy import NDArray 70 | from phantom_tensors.alphabet import A, B # these are just NewType(..., int) types 71 | 72 | def func_on_2d(x: NDArray[Any, Any]): ... 73 | def func_on_3d(x: NDArray[Any, Any, Any]): ... 74 | def func_on_any_arr(x: np.ndarray): ... 75 | 76 | # runtime: ensures shape of arr_3d matches (A, B, A) patterns 77 | arr_3d = parse(np.ones((3, 5, 3)), NDArray[A, B, A]) 78 | 79 | func_on_2d(arr_3d) # static type checker: Error! # expects 2D arr, got 3D 80 | 81 | func_on_3d(arr_3d) # static type checker: OK 82 | func_on_any_arr(arr_3d) # static type checker: OK 83 | ``` 84 | 85 | 86 | Write easy-to-understand interfaces using common dimension names (or make up your own): 87 | 88 | ```python 89 | from phantom_tensors.torch import Tensor 90 | from phantom_tensors.words import Batch, Embed, Vocab 91 | 92 | def embedder(x: Tensor[Batch, Vocab]) -> Tensor[Batch, Embed]: 93 | ... 94 | ``` 95 | 96 | 97 | Using a runtime type checker, such as [beartype](https://github.com/beartype/beartype) or [typeguard](https://github.com/agronholm/typeguard), in conjunction with `phantom_tensors` means that the typed shape information will be validated at runtime across a function's inputs and outputs, whenever that function is called. 
98 | 99 | ```python 100 | from typing import TypeVar, cast 101 | from typing_extensions import assert_type 102 | 103 | import torch as tr 104 | from beartype import beartype 105 | 106 | from phantom_tensors import dim_binding_scope, parse 107 | from phantom_tensors.torch import Tensor 108 | from phantom_tensors.alphabet import A, B, C 109 | 110 | T1 = TypeVar("T1") 111 | T2 = TypeVar("T2") 112 | T3 = TypeVar("T3") 113 | 114 | 115 | @dim_binding_scope 116 | @beartype # <- adds runtime type checking to the function's interfaces 117 | def buggy_matmul(x: Tensor[T1, T2], y: Tensor[T2, T3]) -> Tensor[T1, T3]: 118 | # This is the wrong operation! 119 | # Will return shape-(T1, T1) tensor, not (T1, T3) 120 | out = x @ x.T 121 | 122 | # We lie to the static type checker to try to get away with it 123 | return cast(Tensor[T1, T3], out) 124 | 125 | x, y = parse( 126 | (tr.ones(3, 4), Tensor[A, B]), 127 | (tr.ones(4, 5), Tensor[B, C]), 128 | ) 129 | 130 | # At runtime beartype raises: 131 | # Function should return shape-(A, C) but returned shape-(A, A) 132 | z = buggy_matmul(x, y) # Runtime validation error! 133 | 134 | ``` 135 | 136 | ## Installation 137 | 138 | ```shell 139 | pip install phantom-tensors 140 | ``` 141 | 142 | `typing-extensions` is the only strict dependency. Using features from `phantom_tensors.torch` (or `phantom_tensors.numpy`) requires that `torch` (or `numpy`) is installed too. 143 | 144 | ## Some Lower-Level Details and Features 145 | 146 | Everything on display here is achieved using relatively minimal hacks (no mypy plugin necessary, no monkeypatching). Presently, `torch.Tensor` and `numpy.ndarray` are explicitly supported by phantom-tensors, but it is trivial to add support for other array-like classes. 147 | 148 | > Note that mypy does not support PEP 646 yet, but pyright does. You can run pyright on the following examples to see that they do indeed type-check as expected!
149 | 150 | 151 | ### Dimension-Binding Contexts 152 | 153 | `phantom_tensors.parse` validates inputs against types-with-shapes and performs [type narrowing](https://mypy.readthedocs.io/en/latest/type_narrowing.html) so that static type checkers are privy to the newly proven type information about those inputs. It performs inter-tensor shape consistency checks within a "dimension-binding context". Tensor-likes that are parsed simultaneously are automatically checked within a common dimension-binding context. 154 | 155 | 156 | ```python 157 | import numpy as np 158 | import torch as tr 159 | 160 | from phantom_tensors import parse 161 | from phantom_tensors.alphabet import A, B, C 162 | from phantom_tensors.numpy import NDArray 163 | from phantom_tensors.torch import Tensor 164 | 165 | t1, arr, t2 = parse( 166 | # <- Runtime: enter dimension-binding context 167 | (tr.rand(9, 2, 9), Tensor[B, A, B]), # <- binds A=2 & B=9 168 | (np.ones((2,)), NDArray[A]), # <- checks A==2 169 | (tr.rand(9), Tensor[B]), # <- checks B==9 170 | ) # <- Runtime: exit dimension-binding context 171 | # Statically: casts t1, arr, t2 to shape-typed Tensors 172 | 173 | # static type checkers now see 174 | # t1: Tensor[B, A, B] 175 | # arr: NDArray[A] 176 | # t2: Tensor[B] 177 | 178 | w = parse(tr.rand(78), Tensor[A]) # <- binds A=78 within its own, separate context 179 | ``` 180 | 181 | As indicated above, the type-checker sees the shaped-tensor/array types. Additionally, these are subclasses of their rightful parents, so we can pass these to functions typed with vanilla `torch.Tensor` and `numpy.ndarray` annotations, and type checkers will be a-ok with that. 182 | 183 | ```python 184 | def vanilla_numpy(x: np.ndarray): ... 185 | def vanilla_torch(x: tr.Tensor): ... 186 | 187 | vanilla_numpy(arr) # type checker: OK 188 | vanilla_torch(arr) # type checker: Error!
189 | vanilla_torch(t1) # type checker: OK 190 | ``` 191 | 192 | #### Basic forms of runtime validation performed by `parse` 193 | 194 | ```python 195 | # runtime type checking 196 | >>> parse(1, Tensor[A]) 197 | --------------------------------------------------------------------------- 198 | ParseError: Expected , got: 199 | 200 | # dimensionality mismatch 201 | >>> parse(tr.ones(3), Tensor[A, A, A]) 202 | --------------------------------------------------------------------------- 203 | ParseError: shape-(3,) doesn't match shape-type (A=?, A=?, A=?) 204 | 205 | # unsatisfied shape pattern 206 | >>> parse(tr.ones(1, 2), Tensor[A, A]) 207 | --------------------------------------------------------------------------- 208 | ParseError: shape-(1, 2) doesn't match shape-type (A=1, A=1) 209 | 210 | # inconsistent dimension sizes across tensors 211 | >>> x, y = parse( 212 | ... (tr.ones(1, 2), Tensor[A, B]), 213 | ... (tr.ones(4, 1), Tensor[B, A]), 214 | ... ) 215 | 216 | --------------------------------------------------------------------------- 217 | ParseError: shape-(4, 1) doesn't match shape-type (B=2, A=1) 218 | ``` 219 | 220 | To reiterate, `parse` is able to compare shapes across multiple tensors by entering into a "dimension-binding scope". 221 | One can enter into this context explicitly: 222 | 223 | ```python 224 | >>> from phantom_tensors import dim_binding_scope 225 | 226 | >>> x = parse(np.zeros((2,)), NDArray[B]) # binds B=2 227 | >>> y = parse(np.zeros((3,)), NDArray[B]) # binds B=3 228 | >>> with dim_binding_scope: 229 | ... x = parse(np.zeros((2,)), NDArray[B]) # binds B=2 230 | ... y = parse(np.zeros((3,)), NDArray[B]) # raises! 
231 | --------------------------------------------------------------------------- 232 | ParseError: shape-(3,) doesn't match shape-type (B=2,) 233 | ``` 234 | 235 | #### Support for `Literal` dimensions: 236 | 237 | ```python 238 | from typing import Literal as L 239 | 240 | from phantom_tensors import parse 241 | from phantom_tensors.torch import Tensor 242 | 243 | import torch as tr 244 | 245 | parse(tr.zeros(1, 3), Tensor[L[1], L[3]]) # static + runtime: OK 246 | parse(tr.zeros(2, 3), Tensor[L[1], L[3]]) # Runtime: ParseError - mismatch at dim 0 247 | ``` 248 | 249 | #### Support for variadic shapes: 250 | 251 | In Python 3.11 you can write shape types like `Tensor[int, *Ts, int]`, where `*Ts` represents zero or more entries between two required dimensions. phantom-tensors supports this "unpack" dimension. In this README we opt for `typing_extensions.Unpack[Ts]` instead of `*Ts` for the sake of backwards compatibility. 252 | 253 | ```python 254 | from phantom_tensors import parse 255 | from phantom_tensors.torch import Tensor 256 | 257 | import torch as tr 258 | from typing_extensions import Unpack as U, TypeVarTuple 259 | 260 | Ts = TypeVarTuple("Ts") 261 | 262 | # U[Ts] represents an arbitrary number of entries 263 | parse(tr.ones(1, 3), Tensor[int, U[Ts], int]) # static + runtime: OK 264 | parse(tr.ones(1, 0, 0, 0, 3), Tensor[int, U[Ts], int]) # static + runtime: OK 265 | 266 | parse(tr.ones(1, ), Tensor[int, U[Ts], int]) # Runtime: Not enough dimensions 267 | ``` 268 | 269 | #### Support for [phantom types](https://github.com/antonagestam/phantom-types): 270 | 271 | Supports phantom type dimensions (i.e. `int` subclasses whose `isinstance` checks are customized via `__instancecheck__`): 272 | 273 | ```python 274 | from phantom_tensors import parse 275 | from phantom_tensors.torch import Tensor 276 | 277 | import torch as tr 278 | from phantom import Phantom 279 | 280 | class EvenOnly(int, Phantom, predicate=lambda x: x%2 == 0): ...
281 | 282 | parse(tr.ones(1, 0), Tensor[int, EvenOnly]) # static return type: Tensor[int, EvenOnly] 283 | parse(tr.ones(1, 2), Tensor[int, EvenOnly]) # static return type: Tensor[int, EvenOnly] 284 | parse(tr.ones(1, 4), Tensor[int, EvenOnly]) # static return type: Tensor[int, EvenOnly] 285 | 286 | parse(tr.ones(1, 3), Tensor[int, EvenOnly]) # runtime: ParseError (3 is not an even number) 287 | ``` 288 | 289 | 290 | 291 | ## Compatibility with Runtime Type Checkers 292 | 293 | `parse` is not the only way to perform runtime validation using phantom tensors – they work out of the box with 3rd party runtime type checkers like [beartype](https://github.com/beartype/beartype)! How is this possible? 294 | 295 | ...We do something tricky here! At runtime, `Tensor[A, B]` actually returns a [phantom type](https://github.com/antonagestam/phantom-types). This means that `isinstance(arr, NDArray[A, B])` is, at runtime, *actually* performing `isinstance(arr, PhantomNDArrayAB)`, which is dynamically generated and performs the type and shape checks. 296 | 297 | Thanks to the ability to bind dimensions within a specified context, all `beartype` needs to do is faithfully call `isinstance(...)` within said context, and the inputs and outputs of a phantom-tensor-annotated function get checked!
298 | 299 | ```python 300 | from typing import Any 301 | 302 | from beartype import beartype # type: ignore 303 | import pytest 304 | import torch as tr 305 | 306 | from phantom_tensors.alphabet import A, B, C 307 | from phantom_tensors.torch import Tensor 308 | from phantom_tensors import dim_binding_scope, parse 309 | 310 | # @dim_binding_scope: 311 | # ensures A, B, C consistent across all input/output tensor shapes 312 | # within scope of function 313 | @dim_binding_scope 314 | @beartype # <-- adds isinstance checks on inputs & outputs 315 | def matrix_multiply(x: Tensor[A, B], y: Tensor[B, C]) -> Tensor[A, C]: 316 | a, _ = x.shape 317 | _, c = y.shape 318 | return parse(tr.rand(a, c), Tensor[A, C]) 319 | 320 | @beartype 321 | def needs_vector(x: Tensor[Any]): ... 322 | 323 | x, y = parse( 324 | (tr.rand(3, 4), Tensor[A, B]), 325 | (tr.rand(4, 5), Tensor[B, C]), 326 | ) 327 | 328 | z = matrix_multiply(x, y) 329 | z # type revealed: Tensor[A, C] 330 | 331 | with pytest.raises(Exception): 332 | # beartype raises error: input Tensor[A, C] doesn't match Tensor[A] 333 | needs_vector(z) # <- pyright also raises an error! 334 | 335 | with pytest.raises(Exception): 336 | # beartype raises error: inputs Tensor[A, B], Tensor[A, B] don't match signature 337 | matrix_multiply(x, x) # <- pyright also raises an error! 
338 | ``` 339 | 340 | -------------------------------------------------------------------------------- /deps/requirements-pyright.txt: -------------------------------------------------------------------------------- 1 | pyright==1.1.374 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools >= 35.0.2", 4 | "wheel >= 0.29.0", 5 | "setuptools_scm[toml]==7.0.5", 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | 10 | [project] 11 | name = "phantom_tensors" 12 | dynamic = ["version"] 13 | description = "Tensor-like types – with variadic shapes – that support both static and runtime type checking, and convenient parsing." 14 | readme = "README.md" 15 | requires-python = ">=3.8" 16 | dependencies = ["typing-extensions >= 4.1.0"] 17 | license = { text = "MIT" } 18 | keywords = [ 19 | "machine learning", 20 | "research", 21 | "configuration", 22 | "scalable", 23 | "reproducible", 24 | "yaml", 25 | "Hydra", 26 | "dataclass", 27 | ] 28 | 29 | authors = [ 30 | { name = "Ryan Soklaski", email = "rsoklaski@gmail.com" }, 31 | { name = "Justin Goodwin", email = "jgoodwin@ll.mit.edu" }, 32 | ] 33 | maintainers = [{ name = "Justin Goodwin", email = "jgoodwin@ll.mit.edu" }] 34 | 35 | classifiers = [ 36 | "Development Status :: 4 - Beta", 37 | "License :: OSI Approved :: MIT License", 38 | "Operating System :: OS Independent", 39 | "Intended Audience :: Science/Research", 40 | "Programming Language :: Python :: 3.8", 41 | "Programming Language :: Python :: 3.9", 42 | "Programming Language :: Python :: 3.10", 43 | "Programming Language :: Python :: 3.11", 44 | "Programming Language :: Python :: 3.12", 45 | "Topic :: Scientific/Engineering", 46 | "Programming Language :: Python :: 3 :: Only", 47 | ] 48 | 49 | [project.optional-dependencies] 50 | test = ["beartype >= 0.10.4", "pytest >= 3.8", 
"hypothesis >= 6.28.0"] 51 | 52 | 53 | [project.urls] 54 | "Homepage" = "https://github.com/rsokl/phantom-tensors/" 55 | "Bug Reports" = "https://github.com/rsokl/phantom-tensors/issues" 56 | "Source" = "https://github.com/rsokl/phantom-tensors" 57 | 58 | 59 | [tool.setuptools_scm] 60 | write_to = "src/phantom_tensors/_version.py" 61 | version_scheme = "no-guess-dev" 62 | 63 | 64 | [tool.setuptools] 65 | package-dir = { "" = "src" } 66 | 67 | [tool.setuptools.packages.find] 68 | where = ["src"] 69 | exclude = ["tests*", "tests.*"] 70 | 71 | [tool.setuptools.package-data] 72 | phantom_tensors = ["py.typed"] 73 | 74 | 75 | [tool.isort] 76 | known_first_party = ["phantom_tensors", "tests"] 77 | profile = "black" 78 | combine_as_imports = true 79 | 80 | 81 | [tool.coverage.run] 82 | branch = true 83 | omit = ["tests/test_docs_typecheck.py"] 84 | 85 | [tool.coverage.report] 86 | omit = ["src/phantom_tensors/_version.py"] 87 | exclude_lines = [ 88 | 'pragma: no cover', 89 | 'def __repr__', 90 | 'raise NotImplementedError', 91 | 'class .*\bProtocol(\[.+\])?\):', 92 | '@(abc\.)?abstractmethod', 93 | '@(typing\.)?overload', 94 | 'except ImportError:', 95 | 'except ModuleNotFoundError:', 96 | 'if (typing\.)?TYPE_CHECKING:', 97 | 'if sys\.version_info', 98 | ] 99 | 100 | [tool.pytest.ini_options] 101 | xfail_strict = true 102 | 103 | 104 | [tool.pyright] 105 | include = ["src"] 106 | exclude = [ 107 | "**/node_modules", 108 | "**/__pycache__", 109 | "src/phantom_tensors/_version.py", 110 | "**/third_party", 111 | ] 112 | reportUnnecessaryTypeIgnoreComment = true 113 | reportUnnecessaryIsInstance = false 114 | 115 | 116 | [tool.codespell] 117 | skip = 'docs/build/*' 118 | 119 | 120 | [tool.tox] 121 | legacy_tox_ini = """ 122 | [tox] 123 | isolated_build = True 124 | envlist = py38, py39, py310, py311, py312 125 | 126 | [gh-actions] 127 | python = 128 | 3.8: py38 129 | 3.9: py39 130 | 3.10: py310 131 | 3.11: py311 132 | 3.12: py312 133 | 134 | [testenv] 135 | description = Runs 
test suite parallelized in the specified python environment and 136 | against the number of available processes (up to 4). 137 | Run `tox -e py39 -- -n 0` to run tests in python 3.9 with 138 | parallelization disabled. 139 | passenv = * 140 | extras = test 141 | deps = pytest-xdist 142 | commands = pytest tests/ {posargs: -n auto --maxprocesses=4} 143 | 144 | 145 | [testenv:coverage] 146 | description = Runs test suite and measures test-coverage. Fails if coverage is 147 | below 100 prcnt. Run `tox -e coverage -- -n 0` to disable parallelization. 148 | setenv = NUMBA_DISABLE_JIT=1 149 | usedevelop = true 150 | basepython = python3.10 151 | deps = {[testenv]deps} 152 | coverage[toml] 153 | pytest-cov 154 | beartype 155 | torch 156 | numpy 157 | phantom-types 158 | commands = pytest --cov-report term-missing --cov-config=pyproject.toml --cov-fail-under=100 --cov=phantom_tensors tests {posargs: -n auto --maxprocesses=4} 159 | 160 | 161 | [testenv:third-party] 162 | description = Runs test suite against optional 3rd party packages that phantom-tensors 163 | provides specialized support for. 164 | install_command = pip install --upgrade --upgrade-strategy eager {opts} {packages} 165 | basepython = python3.11 166 | deps = {[testenv]deps} 167 | beartype 168 | torch 169 | numpy 170 | phantom-types 171 | 172 | 173 | [testenv:pyright] 174 | description = Ensure that phantom-tensors's source code and test suite scan clean 175 | under pyright, and that phantom-tensors's public API has a 100 prcnt 176 | type-completeness score.
177 | usedevelop = true 178 | basepython = python3.11 179 | deps = 180 | --requirement deps/requirements-pyright.txt 181 | beartype 182 | torch 183 | numpy 184 | phantom-types 185 | 186 | commands = pyright tests/ src/ --pythonversion=3.8 -p pyrightconfig_py38.json 187 | pyright tests/ src/ --pythonversion=3.9 188 | pyright tests/ src/ --pythonversion=3.10 189 | pyright tests/ src/ --pythonversion=3.11 190 | pyright tests/ src/ --pythonversion=3.12 191 | pyright --ignoreexternal --verifytypes phantom_tensors 192 | 193 | 194 | [testenv:format] 195 | description = Applies auto-flake (e.g. remove unused imports), black, and isort 196 | in-place on source files and test suite. Running this can help fix a 197 | failing `enforce-format` run.
212 | skip_install=true 213 | basepython=python3.11 214 | deps=black 215 | isort 216 | flake8 217 | pytest 218 | codespell 219 | commands= 220 | black src/ tests/ --diff --check 221 | isort src/ tests/ --diff --check 222 | flake8 src/ tests/ 223 | codespell src/ docs/ 224 | """ 225 | -------------------------------------------------------------------------------- /pyrightconfig_py38.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | 2 | [flake8] 3 | extend-ignore = F811,D1,D205,D209,D213,D400,D401,D999,D202,E203,E501,W503,E721,F403,F405 4 | exclude = .git,__pycache__,docs/*,old,build,dis,tests/annotations/* -------------------------------------------------------------------------------- /src/phantom_tensors/__init__.py: -------------------------------------------------------------------------------- 1 | # pyright: strict 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from phantom_tensors._internals.dim_binding import dim_binding_scope 6 | 7 | from .parse import parse 8 | 9 | __all__ = ["parse", "dim_binding_scope"] 10 | 11 | 12 | if not TYPE_CHECKING: 13 | try: 14 | from ._version import version as __version__ 15 | except ImportError: 16 | __version__ = "unknown version" 17 | else: # pragma: no cover 18 | __version__: str 19 | -------------------------------------------------------------------------------- /src/phantom_tensors/_internals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rsokl/phantom-tensors/9569cb2fee80ae12b71f398c1bdcabdb15addf36/src/phantom_tensors/_internals/__init__.py -------------------------------------------------------------------------------- /src/phantom_tensors/_internals/dim_binding.py: 
-------------------------------------------------------------------------------- 1 | # pyright: strict 2 | from __future__ import annotations 3 | 4 | from collections import defaultdict 5 | from contextvars import ContextVar, Token 6 | from functools import wraps 7 | from typing import Any, Callable, Iterable, Optional, Tuple, Type, TypeVar, Union, cast 8 | 9 | from typing_extensions import TypeAlias 10 | 11 | import phantom_tensors._internals.utils as _utils 12 | from phantom_tensors._internals.utils import LiteralLike, NewTypeLike, UnpackLike 13 | 14 | ShapeDimType: TypeAlias = Union[ 15 | Type[int], 16 | Type[UnpackLike], 17 | # Literal[Type[Any]] -- can't actually express this 18 | # the following all bind as dimension symbols by-reference 19 | Type[TypeVar], 20 | NewTypeLike, 21 | LiteralLike, 22 | ] 23 | 24 | LiteralCheck: TypeAlias = Callable[[Any, Iterable[Any]], bool] 25 | 26 | F = TypeVar("F", bound=Callable[..., Any]) 27 | 28 | 29 | bindings: ContextVar[Optional[dict[Any, int]]] = ContextVar("bindings", default=None) 30 | 31 | 32 | class DimBindContext: 33 | __slots__ = ("_tokens", "_depth") 34 | 35 | def __init__(self) -> None: 36 | self._tokens: dict[int, Token[Optional[dict[Any, int]]]] = {} 37 | self._depth: int = 0 38 | 39 | def __enter__(self) -> None: 40 | b = bindings.get() 41 | self._depth += 1 42 | b = {} if b is None else b.copy() 43 | self._tokens[self._depth] = bindings.set(b) 44 | 45 | def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: 46 | if self._depth == 1: 47 | bindings.reset(self._tokens.pop(self._depth)) 48 | self._depth -= 1 49 | 50 | def __call__(self, func: F) -> F: 51 | @wraps(func) 52 | def wrapper(*args: Any, **kwargs: Any): 53 | with self: 54 | return func(*args, **kwargs) 55 | 56 | return cast(F, wrapper) 57 | 58 | 59 | dim_binding_scope: DimBindContext = DimBindContext() 60 | 61 | 62 | def check(shape_type: Tuple[ShapeDimType, ...], shape: Tuple[int, ...]) -> bool: 63 | # Don't need to check types / values 
of `shape` -- assumed 64 | # to be donwstream of `ndarray.shape` call and thus already 65 | # validatated 66 | # Maybe enable extra strict mode where we do that checking 67 | 68 | # E.g. Tensor[A, B, B, C] -> matches == {A: [0], B: [1, 2], C: [3]} 69 | bound_symbols: defaultdict[ShapeDimType, list[int]] = defaultdict(list) 70 | validated_symbols: defaultdict[ShapeDimType, list[int]] = defaultdict(list) 71 | 72 | # These can be cached globally -- are independent of match pattern 73 | # E.g. Tensor[Literal[1]] -> validators {Literal[1]: lambda x: x == 1} 74 | validators: dict[ShapeDimType, Callable[[Any], bool]] = {} 75 | 76 | var_field_ind: Optional[int] = None # contains *Ts 77 | 78 | # TODO: Add caching to validation process. 79 | # - Should this use weakrefs? 80 | for n, dim_symbol in enumerate(shape_type): 81 | if dim_symbol is Any or dim_symbol is int: 82 | # E.g. Tensor[int, int] or Tensor[Any] 83 | # --> no constraints on shape 84 | continue 85 | if _utils.is_typevar_unpack(dim_symbol): 86 | if var_field_ind is not None: 87 | raise TypeError( 88 | f"Type-shape {shape_type} specifies more than one TypeVarTuple" 89 | ) 90 | var_field_ind = n 91 | continue 92 | 93 | # if variadic tuple is present, need to use negative indexing to reference 94 | # location from the end of the tuple 95 | CURRENT_INDEX = n if var_field_ind is None else n - len(shape_type) 96 | 97 | # The following symbols bind to dimensions (by symbol-reference) 98 | # Some of them may also carry with them additional validation checks, 99 | # which need only be checked when the symbol is first bound 100 | _match_list = bound_symbols[dim_symbol] 101 | 102 | if _match_list: 103 | # We have already encountered symbol; do not need to validate or 104 | # extract validation function 105 | _match_list.append(CURRENT_INDEX) 106 | continue 107 | 108 | _validate_list = validated_symbols[dim_symbol] 109 | 110 | if _validate_list or dim_symbol in validators: 111 | del bound_symbols[dim_symbol] 112 | 
_validate_list.append(CURRENT_INDEX) 113 | continue 114 | 115 | if _utils.is_newtype(dim_symbol): 116 | _supertype = dim_symbol.__supertype__ 117 | if _supertype is not int: 118 | if not issubclass(_supertype, int): 119 | raise TypeError( 120 | f"Dimensions expressed by NewTypes must be associated with an " 121 | f"int or subclass of int. shape-type {shape_type} contains a " 122 | f"NewType of supertype {_supertype}" 123 | ) 124 | validators[dim_symbol] = lambda x, sp=_supertype: isinstance(x, sp) 125 | _match_list.append(CURRENT_INDEX) 126 | del _supertype 127 | elif isinstance(dim_symbol, TypeVar): 128 | _match_list.append(CURRENT_INDEX) 129 | elif _utils.is_literal(dim_symbol) and dim_symbol: 130 | _expected_literals = dim_symbol.__args__ 131 | literal_check: LiteralCheck = lambda x, y=_expected_literals: any( 132 | x == val for val in y 133 | ) 134 | validators[dim_symbol] = literal_check 135 | _validate_list.append(CURRENT_INDEX) 136 | del _expected_literals 137 | 138 | elif isinstance(dim_symbol, type) and issubclass(dim_symbol, int): 139 | validators[dim_symbol] = lambda x, type_=dim_symbol: isinstance(x, type_) 140 | _validate_list.append(CURRENT_INDEX) 141 | else: 142 | raise TypeError( 143 | f"Got shape-type {shape_type} with dim {dim_symbol}. Valid dimensions " 144 | f"are `type[int] | Unpack | TypeVar | NewType | Literal`" 145 | ) 146 | if not _match_list: 147 | del bound_symbols[dim_symbol] 148 | 149 | if not _validate_list: 150 | del validated_symbols[dim_symbol] 151 | 152 | if var_field_ind is None and len(shape_type) != len(shape): 153 | # E.g. type: Tensor[A, B, C] 154 | # vs. shape: (1, 1) or (1, 1, 1, 1) # must have exactly 3 dim 155 | return False 156 | 157 | if var_field_ind is not None and len(shape) < len(shape_type) - 1: 158 | # E.g. type: Tensor[A, *Ts, C] 159 | # vs. 
shape: (1,) # should have at least 2 dim 160 | return False 161 | 162 | _bindings = bindings.get() 163 | 164 | for symbol, indices in bound_symbols.items(): 165 | validation_fn = validators.get(symbol, None) 166 | 167 | if len(indices) == 1 and _bindings is None and validation_fn is None: 168 | continue 169 | 170 | actual_val = shape[indices[0]] 171 | 172 | if _bindings is None or symbol is Any or symbol is int: 173 | expected_val = actual_val 174 | else: 175 | if symbol in _bindings: 176 | expected_val = _bindings[symbol] 177 | else: 178 | if validation_fn is not None and not validation_fn(actual_val): 179 | return False 180 | _bindings[symbol] = actual_val 181 | expected_val = actual_val 182 | if not all(expected_val == shape[index] for index in indices): 183 | return False 184 | 185 | for symbol, indices in validated_symbols.items(): 186 | validation_fn = validators[symbol] 187 | if not all(validation_fn(shape[index]) for index in indices): 188 | return False 189 | 190 | return True 191 | -------------------------------------------------------------------------------- /src/phantom_tensors/_internals/parse.py: -------------------------------------------------------------------------------- 1 | # pyright: strict 2 | from __future__ import annotations 3 | 4 | from typing import Any, List, Tuple, Type, TypeVar, Union, cast, overload 5 | 6 | from typing_extensions import ClassVar, Protocol, TypeAlias, TypeVarTuple 7 | 8 | from phantom_tensors._internals import utils as _utils 9 | from phantom_tensors._internals.dim_binding import ( 10 | ShapeDimType, 11 | bindings, 12 | check, 13 | dim_binding_scope, 14 | ) 15 | from phantom_tensors.errors import ParseError 16 | 17 | __all__ = ["parse"] 18 | 19 | Ta = TypeVar("Ta", bound=Tuple[Any, Any]) 20 | Tb = TypeVar("Tb") 21 | Ts = TypeVarTuple("Ts") 22 | 23 | 24 | class _Generic(Protocol): 25 | __origin__: Type[Any] 26 | __args__: Tuple[ShapeDimType, ...] 
class _Phantom(Protocol):
    # Structural type for types exposing `__bound__` (the concrete type(s) a
    # tensor must be an instance of) together with a shape-type in `__args__`.
    __bound__: ClassVar[Union[Type[Any], Tuple[Type[Any], ...]]]
    __args__: Tuple[ShapeDimType, ...]


class HasShape(Protocol):
    """Any object exposing a `shape` attribute (e.g. ndarray, torch.Tensor)."""

    @property
    def shape(self) -> Any: ...


TupleInt: TypeAlias = Tuple[int, ...]

S1 = TypeVar("S1", bound=HasShape)
S2 = TypeVar("S2", bound=HasShape)
S3 = TypeVar("S3", bound=HasShape)
S4 = TypeVar("S4", bound=HasShape)
S5 = TypeVar("S5", bound=HasShape)
S6 = TypeVar("S6", bound=HasShape)

Ts1 = TypeVarTuple("Ts1")
Ts2 = TypeVarTuple("Ts2")
Ts3 = TypeVarTuple("Ts3")
Ts4 = TypeVarTuple("Ts4")
Ts5 = TypeVarTuple("Ts5")
Ts6 = TypeVarTuple("Ts6")

T1 = TypeVar("T1", bound=Tuple[Any, ...])
T2 = TypeVar("T2", bound=Tuple[Any, ...])
T3 = TypeVar("T3", bound=Tuple[Any, ...])
T4 = TypeVar("T4", bound=Tuple[Any, ...])
T5 = TypeVar("T5", bound=Tuple[Any, ...])
T6 = TypeVar("T6", bound=Tuple[Any, ...])


I1 = TypeVar("I1", bound=int)
I2 = TypeVar("I2", bound=int)
I3 = TypeVar("I3", bound=int)
I4 = TypeVar("I4", bound=int)
I5 = TypeVar("I5", bound=int)
I6 = TypeVar("I6", bound=int)


def _to_tuple(x: Ta | Tuple[Ta, ...]) -> Tuple[Ta, ...]:
    """Normalize a lone `(tensor, type)` pair into a 1-tuple of pairs.

    `x` is treated as a single pair when it has length 2 and its first item
    is not itself a tuple; otherwise it is returned unchanged.
    """
    if len(x) == 2 and not isinstance(x[0], tuple):
        return (x,)  # type: ignore
    return x


class Parser:
    """Checks tensors against shape-types and returns them as shape-typed.

    Subclass and override `get_shape_and_concrete_type` to support other
    varieties of generic tensor types.
    """

    def get_shape_and_concrete_type(
        self, type_: Union[_Phantom, _Generic, HasShape]
    ) -> Tuple[Tuple[ShapeDimType, ...], Union[type, Tuple[type, ...]]]:
        """
        Extracts the shape-type and the concrete base type(s) from a generic
        tensor type.

        Overwrite this method to add support for different varieties of
        generic tensor types.

        Parameters
        ----------
        type_ : Any
            The tensor-type that contains the shape information to
            be extracted. It must expose `__args__` along with either
            `__origin__` (checked first) or `__bound__`.

        Returns
        -------
        Tuple[Tuple[ShapeDimType, ...], type | Tuple[type, ...]]
            The shape-type followed by the concrete type(s) that a tensor
            must be an instance of.

        Raises
        ------
        TypeError
            If `type_` exposes neither `__origin__` nor `__bound__`.

        Examples
        --------
        >>> from phantom_tensors import parse
        >>> from phantom_tensors.numpy import NDArray
        >>> parse.get_shape_and_concrete_type(NDArray[int, int])
        ((<class 'int'>, <class 'int'>), <class 'numpy.ndarray'>)
        """
        if hasattr(type_, "__origin__"):
            type_ = cast(_Generic, type_)
            type_shape = type_.__args__
            base_type = type_.__origin__

        elif hasattr(type_, "__bound__"):
            type_ = cast(_Phantom, type_)
            type_shape = type_.__args__
            base_type = type_.__bound__
        else:
            # Previously `assert False`, which is stripped under `python -O`
            # and would then surface as an UnboundLocalError.
            raise TypeError(
                f"{type_!r} is not a supported shape-type: expected a type that "
                f"exposes `__args__` along with `__origin__` or `__bound__` "
                f"(e.g. `NDArray[A, B]` or `Tensor[A, B]`)"
            )
        return type_shape, base_type

    @overload
    def __call__(
        self,
        __a: Tuple[HasShape, Type[S1]],
        __b: Tuple[HasShape, Type[S2]],
        __c: Tuple[HasShape, Type[S3]],
        __d: Tuple[HasShape, Type[S4]],
        __e: Tuple[HasShape, Type[S5]],
        __f: Tuple[HasShape, Type[S6]],
    ) -> Tuple[S1, S2, S3, S4, S5, S6]: ...

    @overload
    def __call__(
        self,
        __a: Tuple[HasShape, Type[S1]],
        __b: Tuple[HasShape, Type[S2]],
        __c: Tuple[HasShape, Type[S3]],
        __d: Tuple[HasShape, Type[S4]],
        __e: Tuple[HasShape, Type[S5]],
    ) -> Tuple[S1, S2, S3, S4, S5]: ...

    @overload
    def __call__(
        self,
        __a: Tuple[HasShape, Type[S1]],
        __b: Tuple[HasShape, Type[S2]],
        __c: Tuple[HasShape, Type[S3]],
        __d: Tuple[HasShape, Type[S4]],
    ) -> Tuple[S1, S2, S3, S4]: ...

    @overload
    def __call__(
        self,
        __a: Tuple[HasShape, Type[S1]],
        __b: Tuple[HasShape, Type[S2]],
        __c: Tuple[HasShape, Type[S3]],
    ) -> Tuple[S1, S2, S3]: ...

    @overload
    def __call__(
        self,
        __a: HasShape,
        __b: Type[S1],
    ) -> S1: ...

    @overload
    def __call__(
        self,
        __a: Tuple[HasShape, Type[S1]],
        __b: Tuple[HasShape, Type[S2]],
    ) -> Tuple[S1, S2]: ...

    @overload
    def __call__(self, __a: Tuple[HasShape, Type[S1]]) -> S1: ...

    @overload
    def __call__(
        self,
        *tensor_type_pairs: Tuple[HasShape, Type[HasShape]] | HasShape | Type[HasShape],
    ) -> HasShape | Tuple[HasShape, ...]: ...

    @dim_binding_scope
    def __call__(
        self,
        *tensor_type_pairs: Tuple[HasShape, Type[HasShape]] | HasShape | Type[HasShape],
    ) -> HasShape | Tuple[HasShape, ...]:
        """Check each tensor against its shape-type within one binding scope.

        Accepts either `parse(tensor, TensorType)` or one or more
        `(tensor, TensorType)` pairs. All checks share a single
        `dim_binding_scope`, so a dimension symbol appearing in several
        shape-types must have the same size across all of the tensors.

        Returns
        -------
        HasShape | Tuple[HasShape, ...]
            The (unchanged) tensor when a single pair is given, otherwise a
            tuple of the tensors in the order provided.

        Raises
        ------
        ValueError
            If called with no arguments.
        ParseError
            If a tensor is not an instance of its concrete type, or its shape
            does not match its shape-type.
        TypeError
            If a type is not a supported shape-type.
        """
        if not tensor_type_pairs:
            raise ValueError(
                "parse requires at least one tensor/type pair, e.g. "
                "`parse(x, Tensor[A, B])` or `parse((x, Tensor[A, B]), ...)`"
            )

        # `parse(x, T)` -> `((x, T),)`; `parse((x, T), ...)` is left as-is.
        pairs = cast(
            Tuple[Tuple[HasShape, Type[HasShape]], ...],
            _to_tuple(tensor_type_pairs),  # type: ignore
        )

        out: List[HasShape] = []

        del tensor_type_pairs

        for tensor, type_ in pairs:
            type_shape, expected_type = self.get_shape_and_concrete_type(type_=type_)

            if not isinstance(tensor, expected_type):
                raise ParseError(f"Expected {expected_type}, got: {type(tensor)}")

            if not check(type_shape, tensor.shape):
                # Always set here: we are inside `dim_binding_scope`.
                _bindings = bindings.get()
                assert _bindings is not None
                type_str = ", ".join(
                    (
                        f"{getattr(p, '__name__', repr(p))}={_bindings.get(p, '?')}"
                        if not _utils.is_typevar_unpack(p)
                        else "[...]"
                    )
                    for p in type_shape
                )
                if len(type_shape) == 1:
                    # (A) -> (A,)
                    type_str += ","
                raise ParseError(
                    f"shape-{tuple(tensor.shape)} doesn't match shape-type ({type_str})"
                )
            out.append(tensor)
        if len(out) == 1:
            return out[0]
        return tuple(out)


parse: Parser = Parser()

# @overload
# def __call___ints(
#     __a: Tuple[int, Type[I1]],
#     __b: Tuple[int, Type[I2]],
#     __c: Tuple[int, Type[I3]],
#     __d: Tuple[int, Type[I4]],
#     __e: Tuple[int, Type[I5]],
#     __f: Tuple[int, Type[I6]],
# ) -> Tuple[I1, I2, I3, I4, I5, I6]:
#     ...


# @overload
# def __call___ints(
#     __a: Tuple[int, Type[I1]],
#     __b: Tuple[int, Type[I2]],
#     __c: Tuple[int, Type[I3]],
#     __d: Tuple[int, Type[I4]],
#     __e: Tuple[int, Type[I5]],
# ) -> Tuple[I1, I2, I3, I4, I5]:
#     ...


# @overload
# def __call___ints(
#     __a: Tuple[int, Type[I1]],
#     __b: Tuple[int, Type[I2]],
#     __c: Tuple[int, Type[I3]],
#     __d: Tuple[int, Type[I4]],
# ) -> Tuple[I1, I2, I3, I4]:
#     ...


# @overload
# def __call___ints(
#     __a: Tuple[int, Type[I1]],
#     __b: Tuple[int, Type[I2]],
#     __c: Tuple[int, Type[I3]],
# ) -> Tuple[I1, I2, I3]:
#     ...


# @overload
# def __call___ints(
#     __a: int,
#     __b: Type[I1],
# ) -> I1:
#     ...


# @overload
# def __call___ints(
#     __a: Tuple[int, Type[I1]],
#     __b: Tuple[int, Type[I2]],
# ) -> Tuple[I1, I2]:
#     ...


# @overload
# def __call___ints(__a: Tuple[int, Type[I1]]) -> I1:
#     ...


# @overload
# def __call___ints(
#     *tensor_type_pairs: Tuple[int, Type[int]] | int | Type[int]
# ) -> int | Tuple[int, ...]:
#     ...


# @dim_binding_scope
# def __call___ints(
#     *tensor_type_pairs: Tuple[int, Type[int]] | int | Type[int]
# ) -> int | Tuple[int, ...]:
#     ...

# @overload
# def __call___tuples(
#     __a: Tuple[TupleInt, Type[Tuple[U[Ts1]]]],
#     __b: Tuple[TupleInt, Type[Tuple[U[Ts2]]]],
#     __c: Tuple[TupleInt, Type[Tuple[U[Ts3]]]],
#     __d: Tuple[TupleInt, Type[Tuple[U[Ts4]]]],
#     __e: Tuple[TupleInt, Type[Tuple[U[Ts5]]]],
#     __f: Tuple[TupleInt, Type[Tuple[U[Ts6]]]],
# ) -> Tuple[
#     Tuple[U[Ts1]],
#     Tuple[U[Ts2]],
#     Tuple[U[Ts3]],
#     Tuple[U[Ts4]],
#     Tuple[U[Ts5]],
#     Tuple[U[Ts6]],
# ]:
#     ...


# @overload
# def __call___tuples(
#     __a: Tuple[TupleInt, Type[Tuple[U[Ts1]]]],
#     __b: Tuple[TupleInt, Type[Tuple[U[Ts2]]]],
#     __c: Tuple[TupleInt, Type[Tuple[U[Ts3]]]],
#     __d: Tuple[TupleInt, Type[Tuple[U[Ts4]]]],
#     __e: Tuple[TupleInt, Type[Tuple[U[Ts5]]]],
# ) -> Tuple[Tuple[U[Ts1]], Tuple[U[Ts2]], Tuple[U[Ts3]], Tuple[U[Ts4]], Tuple[U[Ts5]]]:
#     ...


# @overload
# def __call___tuples(
#     __a: Tuple[TupleInt, Type[Tuple[U[Ts1]]]],
#     __b: Tuple[TupleInt, Type[Tuple[U[Ts2]]]],
#     __c: Tuple[TupleInt, Type[Tuple[U[Ts3]]]],
#     __d: Tuple[TupleInt, Type[Tuple[U[Ts4]]]],
# ) -> Tuple[Tuple[U[Ts1]], Tuple[U[Ts2]], Tuple[U[Ts3]], Tuple[U[Ts4]]]:
#     ...


# @overload
# def __call___tuples(
#     __a: Tuple[TupleInt, Type[Tuple[U[Ts1]]]],
#     __b: Tuple[TupleInt, Type[Tuple[U[Ts2]]]],
#     __c: Tuple[TupleInt, Type[Tuple[U[Ts3]]]],
# ) -> Tuple[Tuple[U[Ts1]], Tuple[U[Ts2]], Tuple[U[Ts3]]]:
#     ...


# @overload
# def __call___tuples(
#     __a: TupleInt,
#     __b: T1,
# ) -> T1:
#     ...


# @overload
# def __call___tuples(
#     __a: TupleInt,
#     __b: Type[Tuple[U[Ts1]]],
# ) -> Tuple[U[Ts1]]:
#     ...
# @overload
# def __call___tuples(
#     __a: Tuple[TupleInt, Type[Tuple[U[Ts1]]]],
#     __b: Tuple[TupleInt, Type[Tuple[U[Ts2]]]],
# ) -> Tuple[Tuple[U[Ts1]], Tuple[U[Ts2]]]:
#     ...


# @overload
# def __call___tuples(__a: Tuple[TupleInt, Type[Tuple[U[Ts1]]]]) -> Tuple[U[Ts1]]:
#     ...


# @overload
# def __call___tuples(
#     *tensor_type_pairs: Tuple[TupleInt, Type[Tuple[Any, ...]]]
#     | TupleInt
#     | Type[Tuple[Any, ...]]
# ) -> TupleInt | Tuple[Any, ...]:
#     ...


# @dim_binding_scope
# def __call___tuples(
#     *tensor_type_pairs: Tuple[TupleInt, Type[Tuple[Any, ...]]]
#     | TupleInt
#     | Type[Tuple[Any, ...]]
# ) -> TupleInt | Tuple[Any, ...]:
#     ...

--------------------------------------------------------------------------------
/src/phantom_tensors/_internals/utils.py:
--------------------------------------------------------------------------------
# pyright: strict

import abc
from typing import Any, Literal, Tuple, Type

from typing_extensions import Protocol, TypeGuard, TypeVarTuple, Unpack

_Ts = TypeVarTuple("_Ts")

# Runtime classes of `Unpack[...]` and `Literal[...]` forms; used with
# isinstance() to recognise those forms inside shape-types.
UnpackType = type(Unpack[_Ts])  # type: ignore
LiteralType = type(Literal[1])


# Metaclass forwarding isinstance() to the class's own `__instancecheck__`.
# NOTE(review): for a class that does not define `__instancecheck__` itself,
# `self.__instancecheck__` resolves back to this method and recurses until
# RecursionError. This appears to duplicate `phantom_tensors.meta`'s class
# of the same name — confirm whether this copy is still used.
class CustomInstanceCheck(abc.ABCMeta):
    def __instancecheck__(self, instance: object) -> bool:
        return self.__instancecheck__(instance)


# Structural type for `typing.NewType(...)` results: a callable with a
# `__name__` and the wrapped `__supertype__`.
class NewTypeLike(Protocol):
    __name__: str
    __supertype__: Type[Any]

    def __call__(self, x: Any) -> int: ...


# Like `NewTypeLike`, but for NewTypes whose supertype is `int` or a subclass.
class NewTypeInt(Protocol):
    __name__: str
    __supertype__: Type[int]

    def __call__(self, x: Any) -> int: ...
31 | 32 | 33 | class UnpackLike(Protocol): 34 | _inst: Literal[True] 35 | _name: Literal[None] 36 | __origin__: Type[Any] = Unpack # type: ignore 37 | __args__: Tuple[TypeVarTuple] 38 | __parameters__: Tuple[TypeVarTuple] 39 | __module__: str 40 | 41 | 42 | class LiteralLike(Protocol): 43 | _inst: Literal[True] 44 | _name: Literal[None] 45 | __origin__: Type[Any] = Literal # type: ignore 46 | __args__: Tuple[Any, ...] 47 | __parameters__: Tuple[()] 48 | __module__: str 49 | 50 | 51 | class TupleGeneric(Protocol): 52 | __origin__: Type[Tuple[Any, ...]] 53 | __args__: Tuple[Type[Any], ...] 54 | 55 | 56 | def is_newtype(x: Any) -> TypeGuard[NewTypeLike]: 57 | return hasattr(x, "__supertype__") 58 | 59 | 60 | def is_newtype_int(x: Any) -> TypeGuard[NewTypeInt]: 61 | supertype = getattr(x, "__supertype__", None) 62 | if supertype is None: 63 | return False 64 | return issubclass(supertype, int) 65 | 66 | 67 | def is_typevar_unpack(x: Any) -> TypeGuard[UnpackLike]: 68 | return isinstance(x, UnpackType) 69 | 70 | 71 | def is_tuple_generic(x: Any) -> TypeGuard[TupleGeneric]: 72 | return getattr(x, "__origin__", None) is tuple 73 | 74 | 75 | def is_literal(x: Any) -> TypeGuard[LiteralLike]: 76 | return isinstance(x, LiteralType) 77 | -------------------------------------------------------------------------------- /src/phantom_tensors/alphabet.py: -------------------------------------------------------------------------------- 1 | # pyright: strict 2 | """Convenient definitions of 3 | 4 | = NewType('', int) 5 | 6 | to be used as descriptive axis labels in shape types""" 7 | 8 | from typing import NewType as _NewType 9 | 10 | __all__ = [ 11 | "A", 12 | "B", 13 | "C", 14 | "D", 15 | "E", 16 | "F", 17 | "G", 18 | "H", 19 | "I", 20 | "J", 21 | "K", 22 | "L", 23 | "M", 24 | "N", 25 | "O", 26 | "P", 27 | "Q", 28 | "R", 29 | "S", 30 | "T", 31 | "U", 32 | "V", 33 | "W", 34 | "X", 35 | "Y", 36 | "Z", 37 | ] 38 | 39 | A = _NewType("A", int) 40 | B = _NewType("B", int) 41 | C = 
_NewType("C", int) 42 | D = _NewType("D", int) 43 | E = _NewType("E", int) 44 | F = _NewType("F", int) 45 | G = _NewType("G", int) 46 | H = _NewType("H", int) 47 | I = _NewType("I", int) # noqa: E741 48 | J = _NewType("J", int) 49 | K = _NewType("K", int) 50 | L = _NewType("L", int) 51 | M = _NewType("M", int) 52 | N = _NewType("N", int) 53 | O = _NewType("O", int) # noqa: E741 54 | P = _NewType("P", int) 55 | Q = _NewType("Q", int) 56 | R = _NewType("R", int) 57 | S = _NewType("S", int) 58 | T = _NewType("T", int) 59 | U = _NewType("U", int) 60 | V = _NewType("V", int) 61 | W = _NewType("W", int) 62 | X = _NewType("X", int) 63 | Y = _NewType("Y", int) 64 | Z = _NewType("Z", int) 65 | -------------------------------------------------------------------------------- /src/phantom_tensors/array.py: -------------------------------------------------------------------------------- 1 | # pyright: strict 2 | from typing import Any, Tuple 3 | 4 | import typing_extensions as _te 5 | from typing_extensions import Protocol, runtime_checkable 6 | 7 | __all__ = ["SupportsArray"] 8 | 9 | 10 | Shape = _te.TypeVarTuple("Shape") 11 | 12 | 13 | @runtime_checkable 14 | class SupportsArray(Protocol[_te.Unpack[Shape]]): 15 | def __array__(self) -> Any: ... 16 | 17 | @property 18 | def shape(self) -> Tuple[_te.Unpack[Shape]]: ... 
19 | -------------------------------------------------------------------------------- /src/phantom_tensors/errors.py: -------------------------------------------------------------------------------- 1 | # pyright: strict 2 | class ParseError(TypeError): 3 | pass 4 | -------------------------------------------------------------------------------- /src/phantom_tensors/meta.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | __all__ = ["CustomInstanceCheck"] 4 | 5 | 6 | class CustomInstanceCheck(abc.ABCMeta): 7 | """Used to support custom runtime type checks.""" 8 | 9 | def __instancecheck__(self, instance: object) -> bool: 10 | return self.__instancecheck__(instance) 11 | -------------------------------------------------------------------------------- /src/phantom_tensors/numpy.py: -------------------------------------------------------------------------------- 1 | # pyright: strict 2 | try: 3 | from numpy import ndarray as _ndarray 4 | from numpy.typing import NDArray as _NDArray 5 | except ImportError: 6 | raise ImportError("You must install numpy in order to user `phantom_tensors.numpy`") 7 | 8 | from typing import TYPE_CHECKING, Any, Generic, Sequence, SupportsIndex, Tuple, Union 9 | 10 | import typing_extensions as _te 11 | 12 | from ._internals.dim_binding import check 13 | from .meta import CustomInstanceCheck 14 | 15 | __all__ = ["NDArray"] 16 | 17 | 18 | Shape = _te.TypeVarTuple("Shape") 19 | 20 | 21 | class NDArray(Generic[_te.Unpack[Shape]], _NDArray[Any]): 22 | 23 | if not TYPE_CHECKING: 24 | _cache = {} 25 | 26 | @classmethod 27 | def __class_getitem__(cls, key): 28 | if not isinstance(key, tuple): 29 | key = (key,) 30 | 31 | class PhantomNDArray( 32 | _ndarray, 33 | metaclass=CustomInstanceCheck, 34 | ): 35 | __origin__ = _ndarray 36 | # TODO: conform with ndarray[shape, dtype] 37 | __args__ = key 38 | 39 | @classmethod 40 | def __instancecheck__(cls, __instance: object) -> bool: 41 | if not 
isinstance(__instance, _ndarray): 42 | return False 43 | return check(key, __instance.shape) 44 | 45 | return PhantomNDArray 46 | 47 | @property 48 | def shape(self) -> Tuple[_te.Unpack[Shape]]: # type: ignore 49 | ... 50 | 51 | @shape.setter 52 | def shape(self, value: Union[SupportsIndex, Sequence[SupportsIndex]]) -> None: ... 53 | -------------------------------------------------------------------------------- /src/phantom_tensors/parse.py: -------------------------------------------------------------------------------- 1 | from ._internals.parse import Parser, parse 2 | 3 | __all__ = ["parse", "Parser"] 4 | -------------------------------------------------------------------------------- /src/phantom_tensors/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rsokl/phantom-tensors/9569cb2fee80ae12b71f398c1bdcabdb15addf36/src/phantom_tensors/py.typed -------------------------------------------------------------------------------- /src/phantom_tensors/torch.py: -------------------------------------------------------------------------------- 1 | # pyright: strict 2 | try: 3 | from torch import Tensor as _Tensor 4 | except ImportError: 5 | raise ImportError( 6 | "You must install pytorch in order to user `phantom_tensors.torch`" 7 | ) 8 | 9 | from typing import TYPE_CHECKING, Generic, Tuple 10 | 11 | import typing_extensions as _te 12 | 13 | from phantom_tensors._internals.dim_binding import check 14 | from phantom_tensors.meta import CustomInstanceCheck 15 | 16 | __all__ = ["Tensor"] 17 | 18 | 19 | Shape = _te.TypeVarTuple("Shape") 20 | 21 | 22 | class _NewMeta(CustomInstanceCheck, type(_Tensor)): ... 
class Tensor(Generic[_te.Unpack[Shape]], _Tensor):
    """Shape-typed annotation for `torch.Tensor`.

    Statically, `Tensor[A, B]` is a `torch.Tensor` whose `shape` is typed as
    `Tuple[A, B]`. At runtime, `Tensor[A, B]` returns a new `torch.Tensor`
    subclass whose isinstance check passes only for tensors whose shape
    satisfies `check((A, B), tensor.shape)`.
    """

    if not TYPE_CHECKING:
        # TODO: add caching
        # NOTE(review): every subscription builds a brand-new class, so
        # `Tensor[A] is Tensor[A]` is False.

        @classmethod
        def __class_getitem__(cls, key):
            # Normalize `Tensor[A]` (key == A) to the tuple form `(A,)`.
            if not isinstance(key, tuple):
                key = (key,)

            class PhantomTensor(
                _Tensor,
                metaclass=_NewMeta,
            ):
                # Generic-alias-like attributes read by the parser to recover
                # the concrete type and the shape-type.
                __origin__ = _Tensor
                __args__ = key

                @classmethod
                def __instancecheck__(cls, __instance: object) -> bool:
                    if not isinstance(__instance, _Tensor):
                        return False
                    return check(key, __instance.shape)

            return PhantomTensor

    @property
    def shape(self) -> Tuple[_te.Unpack[Shape]]:  # type: ignore
        ...

--------------------------------------------------------------------------------
/src/phantom_tensors/words.py:
--------------------------------------------------------------------------------
# pyright: strict
"""Convenient definitions of

    <word> = NewType('<word>', int)

to be used as descriptive axis names in shape types"""
from typing import NewType as _NewType

__all__ = [
    "Axis",
    "Batch",
    "Channel",
    "Embed",
    "Filter",
    "Height",
    "Kernel",
    "Logits",
    "Score",
    "Time",
    "Vocab",
    "Width",
    "Axis1",
    "Axis2",
    "Axis3",
    "Axis4",
    "Axis5",
    "Axis6",
    "Dim1",
    "Dim2",
    "Dim3",
    "Dim4",
    "Dim5",
    "Dim6",
]

Axis = _NewType("Axis", int)
Batch = _NewType("Batch", int)
Channel = _NewType("Channel", int)
Embed = _NewType("Embed", int)
Filter = _NewType("Filter", int)
Height = _NewType("Height", int)
Kernel = _NewType("Kernel", int)
Logits = _NewType("Logits", int)
Score = _NewType("Score", int)
Time = _NewType("Time", int)
Vocab = _NewType("Vocab", int)
Width = _NewType("Width", int)

# Generic, numbered axis / dimension names
Axis1 = _NewType("Axis1", int)
Axis2 = _NewType("Axis2", int)
Axis3 = _NewType("Axis3", int)
Axis4 = _NewType("Axis4", int)
Axis5 = _NewType("Axis5", int)
Axis6 = _NewType("Axis6", int)

Dim1 = _NewType("Dim1", int)
Dim2 = _NewType("Dim2", int)
Dim3 = _NewType("Dim3", int)
Dim4 = _NewType("Dim4", int)
Dim5 = _NewType("Dim5", int)
Dim6 = _NewType("Dim6", int)

--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rsokl/phantom-tensors/9569cb2fee80ae12b71f398c1bdcabdb15addf36/tests/__init__.py
--------------------------------------------------------------------------------
/tests/annotations.py:
--------------------------------------------------------------------------------
# pyright: strict
# Static-typing fixtures: these functions exercise the inferred types of
# `parse` (and the shape-typed tensor/array annotations) via `assert_type`.
# Calls marked `# type: ignore` are ones a type checker is expected to reject.
from __future__ import annotations

import sys
from typing import Tuple as tuple

import torch as tr
from typing_extensions import assert_type

from phantom_tensors import parse
from phantom_tensors.alphabet import A, B
from phantom_tensors.torch import Tensor


def parse_tensors(x: tr.Tensor):
    """`parse` infers shape-typed results for 1 through 6 (tensor, type) pairs."""
    assert_type(parse(x, Tensor[A]), Tensor[A])
    assert_type(parse((x, Tensor[A])), Tensor[A])
    assert_type(
        parse(
            (x, Tensor[A]),
            (x, Tensor[A, B]),
        ),
        tuple[Tensor[A], Tensor[A, B]],
    )
    assert_type(
        parse(
            (x, Tensor[A]),
            (x, Tensor[A, B]),
            (x, Tensor[B, A]),
        ),
        tuple[Tensor[A], Tensor[A, B], Tensor[B, A]],
    )

    assert_type(
        parse(
            (x, Tensor[A]),
            (x, Tensor[A, B]),
            (x, Tensor[B, A]),
            (x, Tensor[B, B]),
        ),
        tuple[Tensor[A], Tensor[A, B], Tensor[B, A], Tensor[B, B]],
    )

    assert_type(
        parse(
            (x, Tensor[A]),
            (x, Tensor[A, B]),
            (x, Tensor[B, A]),
            (x, Tensor[B, B]),
            (x, Tensor[B]),
        ),
        tuple[Tensor[A], Tensor[A, B], Tensor[B, A], Tensor[B, B], Tensor[B]],
    )

    assert_type(
        parse(
            (x, Tensor[A]),
            (x, Tensor[A, B]),
            (x, Tensor[B, A]),
            (x, Tensor[B, B]),
            (x, Tensor[B]),
            (x, Tensor[A]),
        ),
        tuple[
            Tensor[A], Tensor[A, B], Tensor[B, A], Tensor[B, B], Tensor[B], Tensor[A]
        ],
    )


def check_bad_tensor_parse(x: tr.Tensor):
    """Calls that a type checker should reject (each marked `# type: ignore`)."""
    parse(1, tuple[A])  # type: ignore
    parse((1,), A)  # type: ignore

    parse(x, A)  # type: ignore
    parse((x, A))  # type: ignore
    parse((x, A), (x, tuple[A]))  # type: ignore


def check_readme_blurb_one():
    # Presumably mirrors an example from the project README — confirm.
    import numpy as np

    from phantom_tensors import parse
    from phantom_tensors.numpy import NDArray

    # runtime: checks that shapes (2, 3) and (3, 2)
    # match (A, B) and (B, A) pattern
    if sys.version_info < (3, 8):
        return

    x, y = parse(
        (np.ones((2, 3)), NDArray[A, B]),
        (np.ones((3, 2)), NDArray[B, A]),
    )

    assert_type(x, NDArray[A, B])
    assert_type(y, NDArray[B, A])


def check_readme2():
    # A shape-typed NDArray is only accepted where its rank matches, but it
    # remains usable wherever a plain np.ndarray is expected.
    from typing import Any

    import numpy as np

    from phantom_tensors import parse
    from phantom_tensors.alphabet import A, B  # these are just NewType(..., int) types
    from phantom_tensors.numpy import NDArray

    def func_on_2d(x: NDArray[Any, Any]): ...

    def func_on_3d(x: NDArray[Any, Any, Any]): ...

    def func_on_any_arr(x: np.ndarray[Any, Any]): ...

    if sys.version_info < (3, 8):
        return

    # runtime: ensures shape of arr_3d matches (A, B, A) patterns
    arr_3d = parse(np.ones((3, 5, 3)), NDArray[A, B, A])

    func_on_2d(arr_3d)  # type: ignore
    func_on_3d(arr_3d)  # static type checker: OK
    func_on_any_arr(arr_3d)  # static type checker: OK


def check_readme3():
    # TypeVar dimensions combined with beartype and `dim_binding_scope`.
    from typing import TypeVar, cast

    import torch as tr
    from beartype import beartype

    from phantom_tensors import dim_binding_scope
    from phantom_tensors.alphabet import A, B, C
    from phantom_tensors.torch import Tensor

    T1 = TypeVar("T1")
    T2 = TypeVar("T2")
    T3 = TypeVar("T3")

    @dim_binding_scope
    @beartype  # <- adds runtime type checking to function's interfaces
    def buggy_matmul(x: Tensor[T1, T2], y: Tensor[T2, T3]) -> Tensor[T1, T3]:
        out = x @ x.T  # <- wrong operation!
        # Will return shape-(A, A) tensor, not (A, C)
        # (and we lie to the static type checker to try to get away with it)
        return cast(Tensor[T1, T3], out)

    x, y = parse(
        (tr.ones(3, 4), Tensor[A, B]),
        (tr.ones(4, 5), Tensor[B, C]),
    )
    assert_type(x, Tensor[A, B])
    assert_type(y, Tensor[B, C])

    # At runtime:
    # beartype raises and catches shape-mismatch of output.
    # Function should return shape-(A, C) but, at runtime, returns
    # shape-(A, A)
    z = buggy_matmul(x, y)  # beartype roars!

    assert_type(z, Tensor[A, C])


def check_readme4():
    # One `parse` call binds dimensions across numpy arrays and torch tensors.
    from typing import Any

    import numpy as np
    import torch as tr

    from phantom_tensors import parse
    from phantom_tensors.alphabet import A, B
    from phantom_tensors.numpy import NDArray
    from phantom_tensors.torch import Tensor

    if sys.version_info < (3, 8):
        return

    t1, arr, t2 = parse(
        # <- Runtime: enter dimension-binding context
        (tr.rand(9, 2, 9), Tensor[B, A, B]),  # <-binds A=2 & B=9
        (np.ones((2,)), NDArray[A]),  # <- checks A==2
        (tr.rand(9), Tensor[B]),  # <- checks B==9
    )  # <- Runtime: exit dimension-binding scope
    # Statically: casts t1, arr, t2 to shape-typed Tensors

    # static type checkers now see
    # t1: Tensor[B, A, B]
    # arr: NDArray[A]
    # t2: Tensor[B]

    w = parse(tr.rand(78), Tensor[A])
    # <- binds A=78 within this context

    assert_type(t1, Tensor[B, A, B])
    assert_type(arr, NDArray[A])
    assert_type(t2, Tensor[B])
    assert_type(w, Tensor[A])

    def vanilla_numpy(x: np.ndarray[Any, Any]): ...

    def vanilla_torch(x: tr.Tensor): ...

    vanilla_numpy(arr)  # type checker: OK
    vanilla_torch(t1)  # type checker: OK
    vanilla_torch(arr)  # type: ignore [type checker: Error!]


def check_phantom_example():
    # An int subclass (here a `phantom` predicate type) used as a dimension.
    from typing import Any

    import torch as tr

    from phantom import Phantom
    from phantom_tensors import parse
    from phantom_tensors.torch import Tensor

    class EvenOnly(int, Phantom[Any], predicate=lambda x: x % 2 == 0): ...

    assert_type(parse(tr.ones(1, 0), Tensor[int, EvenOnly]), Tensor[int, EvenOnly])
    assert_type(parse(tr.ones(1, 2), Tensor[int, EvenOnly]), Tensor[int, EvenOnly])
    assert_type(parse(tr.ones(1, 4), Tensor[int, EvenOnly]), Tensor[int, EvenOnly])

    parse(tr.ones(1, 3), Tensor[int, EvenOnly])  # runtime: ParseError


def check_beartype_example():
    # Shape-consistency checks on a function's inputs/outputs via beartype.
    from typing import Any

    import pytest
    from beartype import beartype
    from typing_extensions import assert_type

    from phantom_tensors import dim_binding_scope, parse
    from phantom_tensors.alphabet import A, B, C
    from phantom_tensors.torch import Tensor

    # @dim_binding_scope:
    #   ensures A, B, C consistent across all input/output tensor shapes
    #   within scope of function
    @dim_binding_scope
    @beartype  # <-- adds isinstance checks on inputs & outputs
    def matrix_multiply(x: Tensor[A, B], y: Tensor[B, C]) -> Tensor[A, C]:
        a, _ = x.shape
        _, c = y.shape
        return parse(tr.rand(a, c), Tensor[A, C])

    @beartype
    def needs_vector(x: Tensor[Any]): ...

    x, y = parse(
        (tr.rand(3, 4), Tensor[A, B]),
        (tr.rand(4, 5), Tensor[B, C]),
    )

    z = matrix_multiply(x, y)
    assert_type(z, Tensor[A, C])

    with pytest.raises(Exception):
        # beartype raises error: input Tensor[A, C] doesn't match Tensor[A]
        needs_vector(z)  # type: ignore

    with pytest.raises(Exception):
        # beartype raises error: inputs Tensor[A, B], Tensor[A, B] don't match signature
        matrix_multiply(x, x)  # type: ignore


# def check_parse_ints(x: int | Literal[1]):
#     assert_type(
#         parse_ints(x, A),
#         A,
#     )
#     assert_type(
#         parse_ints((x, A)),
#         A,
#     )
#     assert_type(
#         parse_ints(
#             (1, A),
#             (2, B),
#         ),
#         tuple[A, B],
#     )
#     assert_type(
#         parse_ints(
#             (1, A),
#             (2, B),
#             (2, B),
#         ),
#         tuple[A, B, B],
#     )
#     assert_type(
#         parse_ints(
#             (1, A),
#             (2, B),
#             (2, B),
#             (1, A),
#         ),
#         tuple[A, B, B, A],
#     )
#     assert_type(
#         parse_ints(
#             (1, A),
#             (2, B),
#             (2, B),
#             (1, A),
#             (2, B),
#         ),
#         tuple[A, B, B, A, B],
#     )
#     assert_type(
#         parse_ints((1, A), (2, B), (2, B), (1, A), (2, B), (1, A)),
#         tuple[A, B, B, A, B, A],
#     )


# def check_parse_tuples(x: int):

#     assert_type(
#         parse_tuples((x,), tuple[A]),
#         tuple[A],
#     )

#     assert_type(
#         parse_tuples(((x,), tuple[A])),
#         tuple[A],
#     )
#     assert_type(
#         parse_tuples(
#             ((x,), tuple[A]),
#             ((x, x), tuple[A, A]),
#         ),
#         tuple[
#             tuple[A],
#             tuple[A, A],
#         ],
#     )
#     assert_type(
#         parse_tuples(
#             ((x,), tuple[A]),
#             ((x, x), tuple[A, A]),
#             ((x, x), tuple[A, A, B]),
#         ),
#
tuple[ 343 | # tuple[A], 344 | # tuple[A, A], 345 | # tuple[A, A, B], 346 | # ], 347 | # ) 348 | # assert_type( 349 | # parse_tuples( 350 | # ((x,), tuple[A]), 351 | # ((x, x), tuple[A, A]), 352 | # ((x, x, x), tuple[A, A, B]), 353 | # ((x, x, x), tuple[A, A, B]), 354 | # ), 355 | # tuple[ 356 | # tuple[A], 357 | # tuple[A, A], 358 | # tuple[A, A, B], 359 | # tuple[A, A, B], 360 | # ], 361 | # ) 362 | # assert_type( 363 | # parse_tuples( 364 | # ((x,), tuple[A]), 365 | # ((x, x), tuple[A, A]), 366 | # ((x, x, x), tuple[A, A, B]), 367 | # ((x, x, x), tuple[A, A, B]), 368 | # ((x,), tuple[A]), 369 | # ), 370 | # tuple[tuple[A], tuple[A, A], tuple[A, A, B], tuple[A, A, B], tuple[A]], 371 | # ) 372 | # assert_type( 373 | # parse_tuples( 374 | # ((x,), tuple[A]), 375 | # ((x, x), tuple[A, A]), 376 | # ((x, x, x), tuple[A, A, B]), 377 | # ((x, x, x), tuple[A, A, B]), 378 | # ((x,), tuple[A]), 379 | # ((x,), tuple[B]), 380 | # ), 381 | # tuple[ 382 | # tuple[A], tuple[A, A], tuple[A, A, B], tuple[A, A, B], tuple[A], tuple[B] 383 | # ], 384 | # ) 385 | 386 | 387 | # def check_bad_int_parse(x: tr.Tensor, y: tuple[int, ...], z: int): 388 | # parse_ints(x, A) 389 | # parse_ints(y, A) 390 | # parse_ints((x, A)) 391 | # parse_ints((y, A)) 392 | 393 | # parse_ints(z, tuple[A]) 394 | # parse_ints(z, Tensor[A]) 395 | 396 | # def check_bad_tuple_parse(x: tr.Tensor, y: tuple[int, ...], z: int): 397 | # parse_tuples(x, tuple[A]) 398 | # parse_tuples(z, tuple[A]) 399 | 400 | # parse_tuples((x, tuple[A])) 401 | # parse_tuples((z, tuple[A])) 402 | 403 | # parse_ints(y, A) 404 | # parse_ints(y, Tensor[A]) 405 | -------------------------------------------------------------------------------- /tests/arrlike.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | 3 | 4 | class ImplementsArray: 5 | def __init__(self, shape: Tuple[int, ...]) -> None: 6 | assert all(i >= 0 and isinstance(i, int) for i in shape) 7 | self._shape 
= shape 8 | 9 | @property 10 | def shape(self) -> Tuple[int, ...]: 11 | return self._shape 12 | 13 | def __array__(self) -> Any: ... 14 | 15 | 16 | def arr(*shape: int) -> ImplementsArray: 17 | return ImplementsArray(shape) 18 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | 3 | # Skip collection of tests that require additional dependencies 4 | collect_ignore_glob = [] 5 | 6 | OPTIONAL_TEST_DEPENDENCIES = ( 7 | "numpy", 8 | "torch", 9 | "beartype", 10 | "phantom-types", 11 | ) 12 | 13 | _installed = {dist.metadata["Name"] for dist in importlib.metadata.distributions()} 14 | 15 | if any(_module_name not in _installed for _module_name in OPTIONAL_TEST_DEPENDENCIES): 16 | collect_ignore_glob.append("*third_party.py") 17 | -------------------------------------------------------------------------------- /tests/test_letters_and_words.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import phantom_tensors.alphabet as alphabet 4 | import phantom_tensors.words as words 5 | from phantom_tensors.alphabet import __all__ as all_letters 6 | from phantom_tensors.words import __all__ as all_words 7 | 8 | 9 | @pytest.mark.parametrize("name", all_letters) 10 | def test_shipped_letters_are_named_ints(name: str): 11 | Type = getattr(alphabet, name) 12 | assert Type.__name__ == name 13 | assert Type.__supertype__ is int 14 | 15 | 16 | @pytest.mark.parametrize("name", all_words) 17 | def test_shipped_words_are_named_ints(name: str): 18 | Type = getattr(words, name) 19 | assert Type.__name__ == name 20 | assert Type.__supertype__ is int 21 | -------------------------------------------------------------------------------- /tests/test_prototype.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing 
import Any, Literal as L, NewType, TypeVar 3 | 4 | import pytest 5 | from pytest import param 6 | from typing_extensions import TypeVarTuple, Unpack as U 7 | 8 | import phantom_tensors 9 | from phantom_tensors import dim_binding_scope, parse 10 | from phantom_tensors.array import SupportsArray as Array 11 | from phantom_tensors.errors import ParseError 12 | from tests.arrlike import arr 13 | 14 | T = TypeVar("T") 15 | Ts = TypeVarTuple("Ts") 16 | 17 | A = NewType("A", int) 18 | B = NewType("B", int) 19 | C = NewType("C", int) 20 | 21 | parse_xfail = pytest.mark.xfail(raises=ParseError) 22 | 23 | 24 | def test_version(): 25 | assert phantom_tensors.__version__ != "unknown" 26 | 27 | 28 | def test_parse_error_msg(): 29 | with pytest.raises( 30 | ParseError, 31 | match=re.escape("shape-(2, 1) doesn't match shape-type ([...], A=2, A=2)"), 32 | ): 33 | parse(arr(2, 1), Array[U[Ts], A, A]) # type: ignore 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "tensor_type_pairs", 38 | [ 39 | # (arr(), Array[()]), 40 | (arr(), Array[U[Ts]]), # type: ignore 41 | (arr(2), Array[A]), 42 | (arr(2), Array[int]), 43 | (arr(2), Array[Any]), 44 | (arr(2), Array[U[Ts]]), # type: ignore 45 | (arr(2), Array[U[Ts], A]), # type: ignore 46 | (arr(2), Array[A, U[Ts]]), # type: ignore 47 | (arr(2, 2), Array[A, A]), 48 | (arr(2, 2), Array[U[Ts], A, A]), # type: ignore 49 | (arr(2, 2), Array[A, U[Ts], A]), # type: ignore 50 | (arr(2, 2), Array[A, A, U[Ts]]), # type: ignore 51 | (arr(1, 3, 2, 2), Array[U[Ts], A, A]), # type: ignore 52 | (arr(2, 1, 3, 2), Array[A, U[Ts], A]), # type: ignore 53 | (arr(2, 2, 1, 3), Array[A, A, U[Ts]]), # type: ignore 54 | (arr(1, 2, 1, 3, 2), Array[A, B, U[Ts], B]), # type: ignore 55 | (arr(1, 2, 3), Array[Any, Any, Any]), 56 | (arr(1, 2, 3), Array[int, int, int]), 57 | (arr(2, 1, 2), Array[A, B, A]), 58 | (arr(2, 1, 3), Array[A, B, C]), 59 | ( 60 | (arr(5), Array[A]), 61 | (arr(5, 2), Array[A, B]), 62 | ), 63 | (arr(1), Array[L[1]]), 64 | (arr(3), Array[L[1, 2, 3]]), 
65 | (arr(1, 2), Array[L[1], L[2]]), 66 | (arr(1, 2, 1), Array[L[1], L[2], L[1]]), 67 | param((arr(), Array[int]), marks=parse_xfail), 68 | param((arr(), Array[int, U[Ts]]), marks=parse_xfail), # type: ignore 69 | param((arr(2), Array[int, int]), marks=parse_xfail), 70 | param((arr(2, 4), Array[A, A]), marks=parse_xfail), 71 | param((arr(2, 1, 1), Array[A, B, A]), marks=parse_xfail), 72 | param((arr(1, 1, 2), Array[A, A, A]), marks=parse_xfail), 73 | param((arr(2, 1, 1), Array[A, U[Ts], A]), marks=parse_xfail), # type: ignore 74 | param((arr(1), Array[A, B, C]), marks=parse_xfail), 75 | param(((arr(2, 4), Array[A, A]),), marks=parse_xfail), 76 | param( 77 | ( 78 | (arr(4), Array[A]), 79 | (arr(5), Array[A]), 80 | ), 81 | marks=parse_xfail, 82 | ), 83 | param((arr(3), Array[L[1]]), marks=parse_xfail), 84 | param((arr(2, 2), Array[L[1], L[2]]), marks=parse_xfail), 85 | param((arr(1, 1), Array[L[1], L[2]]), marks=parse_xfail), 86 | param( 87 | (arr(1, 1, 1), Array[L[1], L[2], L[1]]), 88 | marks=parse_xfail, 89 | ), 90 | ], 91 | ) 92 | def test_parse_consistent_types(tensor_type_pairs): 93 | parse(*tensor_type_pairs) 94 | 95 | 96 | def test_parse_in_and_out_of_binding_scope(): 97 | with dim_binding_scope: 98 | 99 | parse(arr(2), Array[A]) # binds A=2 100 | 101 | with pytest.raises(ParseError): 102 | parse(arr(3), Array[A]) 103 | 104 | parse(arr(2), Array[A]) 105 | 106 | parse(arr(2, 4), Array[A, B]) # binds B=4 107 | parse(arr(2, 9), Array[A, int]) 108 | 109 | with pytest.raises(ParseError): 110 | parse(arr(2), Array[B]) 111 | 112 | # no dims bound 113 | parse(arr(1, 3, 3, 1), Array[B, A, A, B]) # no dims bound 114 | parse(arr(1, 4, 4, 1), Array[B, A, A, B]) 115 | 116 | parse( 117 | (arr(9), Array[B]), 118 | (arr(9, 2, 2, 9), Array[B, A, A, B]), 119 | ) 120 | 121 | 122 | def test_parse_bind_multiple(): 123 | with dim_binding_scope: # enter dimension-binding scope 124 | parse( 125 | (arr(2), Array[A]), # <-binds A=2 126 | (arr(9), Array[B]), # <-binds B=9 127 | (arr(9, 
2, 9), Array[B, A, B]), # <-checks A & B 128 | ) 129 | 130 | with pytest.raises( 131 | ParseError, 132 | match=re.escape("shape-(78,) doesn't match shape-type (A=2,)"), 133 | ): 134 | # can't re-bind A within scope 135 | parse(arr(78), Array[A]) 136 | 137 | with pytest.raises( 138 | ParseError, 139 | match=re.escape("shape-(22,) doesn't match shape-type (B=9,)"), 140 | ): 141 | # can't re-bind B within scope 142 | parse(arr(22), Array[B]) 143 | 144 | parse(arr(2), Array[A]) 145 | parse(arr(9), Array[B]) 146 | 147 | # exit dimension-binding scope 148 | 149 | parse(arr(78, 22), Array[A, B]) # now ok 150 | 151 | 152 | AStr = NewType("AStr", str) 153 | 154 | 155 | @pytest.mark.parametrize( 156 | "bad_type", 157 | [ 158 | Array[AStr], 159 | Array[str], 160 | Array[int, str], 161 | Array[U[Ts], str], # type: ignore 162 | ], 163 | ) 164 | def test_bad_type_validation(bad_type): 165 | with pytest.raises(TypeError): 166 | parse(arr(1), bad_type) 167 | -------------------------------------------------------------------------------- /tests/test_third_party.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import NewType, TypeVar, cast 3 | 4 | import numpy as np 5 | import pytest 6 | import torch as tr 7 | from beartype import beartype 8 | from beartype.roar import ( 9 | BeartypeCallHintParamViolation, 10 | BeartypeCallHintReturnViolation, 11 | ) 12 | from typing_extensions import TypeVarTuple, Unpack as U 13 | 14 | from phantom import Phantom 15 | from phantom_tensors import dim_binding_scope, parse 16 | from phantom_tensors.alphabet import A, B, C 17 | from phantom_tensors.array import SupportsArray as Array 18 | from phantom_tensors.errors import ParseError 19 | from phantom_tensors.numpy import NDArray 20 | from phantom_tensors.torch import Tensor 21 | from tests.arrlike import arr 22 | 23 | T = TypeVar("T") 24 | Ts = TypeVarTuple("Ts") 25 | 26 | 27 | class One_to_Three(int, Phantom, predicate=lambda x: 0 < x 
< 4): ... 28 | 29 | 30 | class Ten_or_Eleven(int, Phantom, predicate=lambda x: 10 <= x <= 11): ... 31 | 32 | 33 | class EvenOnly(int, Phantom, predicate=lambda x: x % 2 == 0): ... 34 | 35 | 36 | NewOneToThree = NewType("NewOneToThree", One_to_Three) 37 | 38 | 39 | def test_NDArray(): 40 | assert issubclass(NDArray, np.ndarray) 41 | assert issubclass(NDArray[A], np.ndarray) 42 | 43 | parse(np.ones((2,)), NDArray[A]) 44 | with pytest.raises(ParseError): 45 | parse(np.ones((2, 3)), NDArray[A, A]) 46 | 47 | 48 | def test_Tensor(): 49 | assert issubclass(Tensor, tr.Tensor) 50 | assert issubclass(Tensor[A], tr.Tensor) 51 | 52 | parse(tr.ones((2,)), Tensor[A]) 53 | with pytest.raises(ParseError): 54 | parse(tr.ones((2, 3)), Tensor[A, A]) 55 | 56 | 57 | @pytest.mark.parametrize( 58 | "tensor_type_pairs", 59 | [ 60 | (tr.ones(2), NDArray[int]), # type mismatch 61 | (np.ones((2,)), Tensor[int]), # type mismatch 62 | (arr(10, 2, 10), Array[One_to_Three, int, Ten_or_Eleven]), 63 | (arr(10, 2, 10), Array[NewOneToThree, int, Ten_or_Eleven]), 64 | (arr(2, 2, 8), Array[NewOneToThree, int, Ten_or_Eleven]), 65 | (arr(0, 10), Array[One_to_Three, U[Ts], Ten_or_Eleven]), # type: ignore 66 | (arr(2, 2, 0), Array[One_to_Three, U[Ts], Ten_or_Eleven]), # type: ignore 67 | (arr(2, 0, 0, 0), Array[One_to_Three, U[Ts], Ten_or_Eleven]), # type: ignore 68 | (arr(0, 0, 2, 0), Array[U[Ts], One_to_Three, Ten_or_Eleven]), # type: ignore 69 | ], 70 | ) 71 | def test_parse_inconsistent_types(tensor_type_pairs): 72 | with pytest.raises(ParseError): 73 | parse(*tensor_type_pairs) 74 | 75 | 76 | def test_phantom_checks(): 77 | assert not isinstance(np.ones((2,)), Tensor[int]) # type: ignore 78 | assert not isinstance(tr.ones((2,)), NDArray[int]) # type: ignore 79 | 80 | assert not isinstance(tr.ones((2, 2)), Tensor[int]) # type: ignore 81 | assert not isinstance(np.ones((2, 2)), NDArray[int]) # type: ignore 82 | 83 | 84 | def test_type_var_with_beartype(): 85 | @dim_binding_scope 86 | @beartype 87 | def 
diag(sqr: Tensor[T, T]) -> Tensor[T]: 88 | return cast(Tensor[T], tr.diag(sqr)) 89 | 90 | non_sqr = parse(tr.ones(2, 3), Tensor[A, B]) 91 | with pytest.raises(BeartypeCallHintParamViolation): 92 | diag(non_sqr) # type: ignore 93 | 94 | 95 | def test_matmul_example(): 96 | @dim_binding_scope 97 | @beartype 98 | def matrix_multiply(x: Tensor[A, B], y: Tensor[B, C]) -> Tensor[A, C]: 99 | out = x @ x.T 100 | return cast(Tensor[A, C], out) 101 | 102 | x, y = parse( 103 | (tr.ones(3, 4), Tensor[A, B]), 104 | (tr.ones(4, 5), Tensor[B, C]), 105 | ) 106 | # x # type revealed: Tensor[A, B] 107 | # y # type revealed: Tensor[B, C] 108 | 109 | with pytest.raises(BeartypeCallHintReturnViolation): 110 | matrix_multiply(x, y) 111 | 112 | 113 | def test_runtime_checking_with_beartype(): 114 | @dim_binding_scope 115 | # ^ ensures A, B, C consistent across all input/output tensor shapes 116 | # within scope of function 117 | @beartype 118 | def matrix_multiply(x: Tensor[A, B], y: Tensor[B, C]) -> Tensor[A, C]: 119 | a, b = x.shape 120 | b, c = y.shape 121 | return cast(Tensor[A, C], tr.ones(a, c)) 122 | 123 | @beartype 124 | def needs_vector(x: Tensor[int]): ... 
125 | 126 | x, y = parse( 127 | (tr.ones(3, 4), Tensor[A, B]), 128 | (tr.ones(4, 5), Tensor[B, C]), 129 | ) 130 | # x # type revealed: Tensor[A, B] 131 | # y # type revealed: Tensor[B, C] 132 | 133 | z = matrix_multiply(x, y) 134 | # z # type revealed: Tensor[A, C] 135 | 136 | with pytest.raises(Exception): 137 | needs_vector(z) # type: ignore 138 | 139 | with pytest.raises(Exception): 140 | matrix_multiply(x, x) # type: ignore 141 | 142 | 143 | def test_catches_wrong_instance(): 144 | with pytest.raises( 145 | ParseError, 146 | match=re.escape( 147 | "Expected , got: " 148 | ), 149 | ): 150 | parse(tr.tensor(1), NDArray[A, B]) 151 | 152 | with pytest.raises( 153 | ParseError, 154 | match=re.escape( 155 | "Expected , got: " 156 | ), 157 | ): 158 | parse(np.array(1), Tensor[A]) 159 | 160 | 161 | def test_isinstance_works(): 162 | with dim_binding_scope: 163 | 164 | assert isinstance(tr.ones(2), Tensor[A]) # type: ignore 165 | assert not isinstance(tr.ones(3), Tensor[A]) # type: ignore 166 | assert isinstance(tr.ones(2), Tensor[A]) # type: ignore 167 | 168 | assert isinstance(tr.ones(2, 4), Tensor[A, B]) # type: ignore 169 | assert not isinstance(tr.ones(2), Tensor[B]) # type: ignore 170 | assert isinstance(tr.ones(4), Tensor[B]) # type: ignore 171 | assert isinstance(tr.ones(4, 2, 2, 4), Tensor[B, A, A, B]) # type: ignore 172 | 173 | assert isinstance(tr.ones(1, 3, 3, 1), Tensor[B, A, A, B]) # type: ignore 174 | assert isinstance(tr.ones(1, 4, 4, 1), Tensor[B, A, A, B]) # type: ignore 175 | 176 | 177 | @pytest.mark.parametrize( 178 | "tensor_type_pairs", 179 | [ 180 | (arr(2, 2, 10), Array[One_to_Three, int, Ten_or_Eleven]), 181 | (arr(2, 10), Array[One_to_Three, U[Ts], Ten_or_Eleven]), # type: ignore 182 | (arr(2, 2, 10), Array[One_to_Three, U[Ts], Ten_or_Eleven]), # type: ignore 183 | (arr(2, 0, 0, 10), Array[One_to_Three, U[Ts], Ten_or_Eleven]), # type: ignore 184 | (arr(2, 2, 10), Array[NewOneToThree, int, Ten_or_Eleven]), 185 | (arr(0, 0, 2, 11), Array[U[Ts], 
One_to_Three, Ten_or_Eleven]), # type: ignore 186 | ], 187 | ) 188 | def test_parse_consistent_types(tensor_type_pairs): 189 | parse(*tensor_type_pairs) 190 | 191 | 192 | def test_non_binding_subint_dims_pass(): 193 | parse(arr(2, 4, 6), Array[EvenOnly, EvenOnly, EvenOnly]) 194 | parse( 195 | (arr(2, 4), Array[EvenOnly, EvenOnly]), 196 | (arr(6, 8), Array[EvenOnly, EvenOnly]), 197 | ) 198 | 199 | 200 | def test_non_binding_subint_dims_validates(): 201 | 202 | with pytest.raises( 203 | ParseError, 204 | match=re.escape( 205 | r"shape-(2, 1) doesn't match shape-type (EvenOnly=?, EvenOnly=?)" 206 | ), 207 | ): 208 | parse(arr(2, 1), Array[EvenOnly, EvenOnly]) 209 | 210 | with pytest.raises( 211 | ParseError, 212 | match=re.escape( 213 | r"shape-(6, 3) doesn't match shape-type (EvenOnly=?, EvenOnly=?)" 214 | ), 215 | ): 216 | parse( 217 | (arr(2, 4), Array[EvenOnly, EvenOnly]), 218 | (arr(6, 3), Array[EvenOnly, EvenOnly]), 219 | ) 220 | 221 | 222 | def test_binding_validated_dims_validates(): 223 | 224 | with pytest.raises( 225 | ParseError, 226 | match=re.escape( 227 | r"shape-(1, 2, 3) doesn't match shape-type (NewOneToThree=1, NewOneToThree=1, NewOneToThree=1)" 228 | ), 229 | ): 230 | parse(arr(1, 2, 3), Array[NewOneToThree, NewOneToThree, NewOneToThree]) 231 | 232 | with pytest.raises( 233 | ParseError, 234 | match=re.escape( 235 | r"shape-(1, 2) doesn't match shape-type (NewOneToThree=1, NewOneToThree=1)" 236 | ), 237 | ): 238 | parse( 239 | (arr(1, 1), Array[NewOneToThree, NewOneToThree]), 240 | (arr(1, 2), Array[NewOneToThree, NewOneToThree]), 241 | ) 242 | 243 | 244 | @pytest.mark.parametrize("good_shape", [(1, 1), (2, 2), (3, 3)]) 245 | def test_binding_validated_dims_passes(good_shape): 246 | parse(arr(*good_shape), Array[NewOneToThree, NewOneToThree]) 247 | -------------------------------------------------------------------------------- /tests/test_type_properties.py: -------------------------------------------------------------------------------- 1 | from 
__future__ import annotations 2 | 3 | import sys 4 | from typing import Literal, Tuple 5 | 6 | import pytest 7 | from hypothesis import assume, given 8 | 9 | 10 | @pytest.mark.xfail( 11 | sys.version_info < (3, 9), reason="Tuple normalization introduced in 3.9" 12 | ) 13 | def test_literal_singlet_tuple_same_as_scalar(): 14 | assert Literal[(1,)] is Literal[1] # type: ignore 15 | 16 | 17 | @given(...) 18 | def test_literals_can_check_by_identity(a: Tuple[int, ...], b: Tuple[int, ...]): 19 | assume(len(a)) 20 | assume(len(b)) 21 | assume(a != b) 22 | assert Literal[a] is Literal[a] # type: ignore 23 | assert Literal[b] is Literal[b] # type: ignore 24 | assert Literal[a] is not Literal[b] # type: ignore 25 | 26 | 27 | def test_literal_hashes_consistently(): 28 | assert {Literal[1]: 1}[Literal[1]] == 1 29 | assert hash(Literal[1, 2]) == hash(Literal[1, 2]) 30 | 31 | 32 | @pytest.mark.parametrize("content", [1, 2, (1, 2)]) 33 | def test_literal_args_always_tuples(content): 34 | assert Literal[content].__args__ == (content,) if not isinstance(content, tuple) else content # type: ignore 35 | --------------------------------------------------------------------------------