├── dltype ├── _lib │ ├── __init__.py │ ├── _constants.py │ ├── _dtypes.py │ ├── _dependency_utilities.py │ ├── _torch_tensors.py │ ├── _numpy_tensors.py │ ├── _errors.py │ ├── _symbolic_expressions.py │ ├── _tensor_type_base.py │ ├── _dltype_context.py │ ├── _parser.py │ ├── _universal_tensors.py │ └── _core.py ├── tests │ ├── __init__.py │ ├── interop_test.py │ ├── parser_test.py │ └── dltype_test.py └── __init__.py ├── .python-version ├── CODEOWNERS ├── .github ├── CODEOWNERS ├── pull_request_template.md └── workflows │ ├── publish.yaml │ └── pr.yaml ├── .vscode ├── extensions.json └── settings.json ├── setup.sh ├── CONTRIBUTING.md ├── .pre-commit-config.yaml ├── pyproject.toml ├── .gitignore ├── LICENSE ├── benchmark.py └── README.md /dltype/_lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.10 2 | -------------------------------------------------------------------------------- /dltype/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @dlangerm-stackav 2 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @dlangerm-stackav 2 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "charliermarsh.ruff", 4 | "ms-python.vscode-pylance" 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /dltype/_lib/_constants.py: -------------------------------------------------------------------------------- 1 | """Constants related to the dltype library.""" 2 | 3 | import typing 4 | 5 | # Constants 6 | PYDANTIC_INFO_KEY: typing.Final = "__dltype__" 7 | DEBUG_MODE: typing.Final = False 8 | MAX_ACCEPTABLE_EVALUATION_TIME_NS: typing.Final = int(5e9) # 5ms 9 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | - What is this PR doing and why? 4 | 5 | ## Testing 6 | 7 | Please select all that apply. 8 | 9 | - [ ] Existing unit tests 10 | - [ ] Unit tests added by this PR 11 | - [ ] Other (please explain) 12 | - [ ] This PR is not tested 13 | 14 | ## Test instructions 15 | 16 | - Please add relevant commands required to reproduce the testing described above 17 | - Please add relevant test outputs (screenshots, logs, etc.) 18 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script sets up the environment for the project 4 | 5 | curl -LsSf https://astral.sh/uv/install.sh | sh 6 | 7 | uv python install 8 | uv sync 9 | 10 | if ! 
command -v pre-commit >/dev/null 2>&1 11 | then 12 | echo "WARNING: pre-commit not found, please install it for a better dev experience" 13 | echo "pip install pre-commit" 14 | echo "pre-commit install --install-hooks" 15 | else 16 | pre-commit install --install-hooks 17 | fi 18 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | Contributions are welcome & instructions on how to set up the package for local development are below. 3 | 4 | ## Development Setup 5 | The development environment is managed via `uv`. 6 | Please run the setup script to install it and sync the project dependencies before developing. 7 | 8 | ```bash 9 | bash ./setup.sh 10 | ``` 11 | 12 | 13 | The unit tests can be run with: 14 | 15 | ```bash 16 | uv run pytest 17 | ``` 18 | 19 | and the benchmark with: 20 | 21 | ```bash 22 | uv run benchmark.py 23 | ``` 24 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[json]": { 3 | "editor.formatOnSave": false 4 | }, 5 | "[yaml]": { 6 | "editor.formatOnSave": false 7 | }, 8 | "cSpell.words": [ 9 | "dltype", 10 | "dltyped", 11 | "tracebackhide", 12 | "numpy", 13 | "onnx", 14 | "pytest", 15 | "setuptools", 16 | "pydantic", 17 | "pyright", 18 | "ndarray", 19 | "ndims", 20 | "multiaxis", 21 | "allclose", 22 | "popleft" 23 | ], 24 | "editor.formatOnSave": true, 25 | "ruff.interpreter": [ 26 | "${workspaceFolder}/.venv/bin/python" 27 | ], 28 | "ruff.path": [ 29 | "${workspaceFolder}/.venv/bin/ruff" 30 | ] 31 | } 32 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | release: 3 | types: 4 | - published 5 | 6 | name: release 7 | 8 | jobs: 9 | pypi-publish: 10 | name: upload release to PyPI 11 | runs-on: ubuntu-latest 12 | # Specifying a GitHub environment is optional, but strongly encouraged 13 | environment: pypi 14 | permissions: 15 | # IMPORTANT: this permission is mandatory for Trusted Publishing 16 | id-token: write 17 | steps: 18 | - name: Checkout repository 19 | uses: actions/checkout@v5 20 | 21 | - name: Install uv 22 | uses: astral-sh/setup-uv@v5 23 | with: 24 | version: 0.7.17 25 | 26 | - name: Install the project 27 | run: uv sync --locked --all-extras --dev 28 | 29 | - name: Run tests 30 | run: | 31 | uv run pytest 32 | 33 | - name: Run benchmarks 34 | run: | 35 | uv run benchmark.py 36 | 37 | - name: Build package distributions 38 | run: uv build 39 | 40 | - name: Publish package distributions to PyPI 41 | run: uv publish 42 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: no-commit-to-branch 6 | args: [--branch, main] 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-toml 11 | - id: mixed-line-ending 12 | args: [--fix, lf] 13 | - id: check-json 14 | - id: check-added-large-files 15 | - id: check-merge-conflict 16 | - id: debug-statements 17 | - id: name-tests-test 18 | - id: pretty-format-json 19 | args: 20 | - --autofix 21 | - 
--indent=2 22 | - repo: https://github.com/astral-sh/ruff-pre-commit 23 | # Ruff version. 24 | rev: v0.12.0 25 | hooks: 26 | - id: ruff-check 27 | types_or: [python, pyi] 28 | args: [--fix] 29 | - id: ruff-format 30 | types_or: [python, pyi] 31 | - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks 32 | rev: v2.14.0 33 | hooks: 34 | - id: pretty-format-yaml 35 | args: [--autofix, --indent, '2', --offset, '2', --line-width, '80'] 36 | - id: pretty-format-toml 37 | args: [--autofix, --indent, '2'] 38 | - repo: https://github.com/astral-sh/uv-pre-commit 39 | # uv version. 40 | rev: 0.7.15 41 | hooks: 42 | - id: uv-lock 43 | -------------------------------------------------------------------------------- /dltype/_lib/_dtypes.py: -------------------------------------------------------------------------------- 1 | """Supported datatypes.""" 2 | 3 | import typing 4 | 5 | from dltype._lib import ( 6 | _dependency_utilities as _deps, 7 | ) 8 | 9 | # NOTE: the order of these is important, pyright assumes the last branch is taken 10 | # so we get proper union type hint checking 11 | if _deps.is_numpy_available() and not _deps.is_torch_available(): 12 | import numpy as np 13 | import numpy.typing as npt 14 | 15 | DLtypeTensorT: typing.TypeAlias = npt.NDArray[typing.Any] # pyright: ignore[reportRedeclaration] 16 | DLtypeDtypeT: typing.TypeAlias = npt.DTypeLike # pyright: ignore[reportRedeclaration] 17 | SUPPORTED_TENSOR_TYPES: typing.Final = {np.ndarray} 18 | elif _deps.is_torch_available() and not _deps.is_numpy_available(): 19 | import torch 20 | 21 | DLtypeTensorT: typing.TypeAlias = torch.Tensor # pyright: ignore[reportRedeclaration] 22 | DLtypeDtypeT: typing.TypeAlias = torch.dtype # pyright: ignore[reportRedeclaration] 23 | SUPPORTED_TENSOR_TYPES: typing.Final = {torch.Tensor} # pyright: ignore[reportGeneralTypeIssues, reportConstantRedefinition] 24 | elif _deps.is_numpy_available() and _deps.is_torch_available(): 25 | import numpy as np 26 | import numpy.typing as npt 27 | import torch 28 | 29 | DLtypeTensorT: typing.TypeAlias = torch.Tensor | npt.NDArray[typing.Any] # pyright: ignore[reportRedeclaration] 30 | DLtypeDtypeT: typing.TypeAlias = torch.dtype | npt.DTypeLike # pyright: ignore[reportRedeclaration] 31 | SUPPORTED_TENSOR_TYPES: typing.Final = {torch.Tensor, np.ndarray} # pyright: ignore[reportGeneralTypeIssues, reportConstantRedefinition] 32 | else: 33 | _deps.raise_for_missing_dependency() 34 | -------------------------------------------------------------------------------- /.github/workflows/pr.yaml: -------------------------------------------------------------------------------- 1 | name: Enforce Unit Tests 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main # or master 7 | push: 8 | branches: 9 | - main # or master 10 | 11 | jobs: 12 | run-unit-tests: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@v5 17 | 18 | - name: Install uv 19 | uses: astral-sh/setup-uv@v5 20 | with: 21 | version: 0.7.17 22 | enable-cache: true 23 | 24 | - name: Run ruff check 25 | run: | 26 | uv run ruff format --check 27 | uv run ruff check 28 | 29 | - name: Run default unit tests 30 | run: uv run pytest 31 | 32 | - name: Run type checking 33 | run: uv run pyright --stats 34 | 35 | check-backwards-compatibility: 36 | strategy: 37 | matrix: 38 | numpy_version: [==1.22.0, ''] 39 | torch_version: [==1.11.0, ''] 40 | python_version: ['3.10', '3.11', '3.12', '3.13'] 41 | exclude: 42 | - numpy_version: ==1.22.0 43 | python_version: 
'3.12' 44 | - numpy_version: ==1.22.0 45 | python_version: '3.13' 46 | 47 | runs-on: ubuntu-latest 48 | steps: 49 | - name: Checkout repository 50 | uses: actions/checkout@v5 51 | 52 | - name: Install uv 53 | uses: astral-sh/setup-uv@v5 54 | with: 55 | version: 0.7.17 56 | python-version: ${{ matrix.python_version }} 57 | enable-cache: true 58 | 59 | - name: Run tests 60 | run: | 61 | uv add --frozen 'torch${{ matrix.torch_version }}' 'numpy${{ matrix.numpy_version }}' 62 | uv run --frozen pytest 63 | 64 | - name: Run benchmarks 65 | run: | 66 | uv run --frozen benchmark.py 67 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [dependency-groups] 2 | dev = [ 3 | "numpy>=2.2.6", 4 | "onnx>=1.18.0", 5 | "pytest>=8.4.1", 6 | "ruff>=0.12.0", 7 | "torch>=1.4.0", 8 | "setuptools>=60.0.0", 9 | "pyright>=1.1.407" 10 | ] 11 | 12 | [project] 13 | dependencies = ["pydantic>=2.0"] 14 | description = "An extremely lightweight typing library for torch tensors or numpy arrays. Supports runtime shape checking and data type validation." 15 | keywords = ["pytorch", "numpy", "shape check", "type check"] 16 | license = "Apache-2.0" 17 | license-files = ["LICENSE"] 18 | name = "dltype" 19 | readme = "README.md" 20 | requires-python = ">=3.10" 21 | version = "0.7.0" 22 | 23 | [project.optional-dependencies] 24 | numpy = ["numpy"] 25 | torch = ["torch>=1.11.0"] 26 | 27 | [project.urls] 28 | homepage = "https://github.com/stackav-oss/dltype" 29 | repository = "https://github.com/stackav-oss/dltype.git" 30 | 31 | [tool.pyright] 32 | exclude = [] 33 | ignore = [] 34 | include = ["dltype"] 35 | reportUnnecessaryTypeIgnoreComment = "error" 36 | typeCheckingMode = "strict" 37 | 38 | [tool.ruff] 39 | indent-width = 4 40 | line-length = 110 41 | target-version = "py310" 42 | 43 | [tool.ruff.format] 44 | docstring-code-format = true 45 | docstring-code-line-length = "dynamic" 46 | line-ending = "lf" 47 | 48 | [tool.ruff.lint] 49 | extend-select = [ 50 | "ANN", 51 | "ARG", 52 | "B", 53 | "BLE", 54 | "C4", 55 | "C90", 56 | "COM", 57 | "D", 58 | "E", 59 | "EM", 60 | "ERA", 61 | "F", 62 | "FA", 63 | "FBT", 64 | "G", 65 | "I", 66 | "ICN", 67 | "INP", 68 | "ISC", 69 | "LOG", 70 | "N", 71 | "NPY", 72 | "PERF", 73 | "PIE", 74 | "PLE", 75 | "PLR", 76 | "PT", 77 | "PTH", 78 | "Q", 79 | "RET", 80 | "RSE", 81 | "RUF", 82 | "SIM", 83 | "SLF", 84 | "T20", 85 | "TC", 86 | "TID", 87 | "TRY", 88 | "W" 89 | ] 90 | fixable = ["ALL"] 91 | ignore = ["D203", "D212", "D401", "E501", "PLR6301", "PLR0917"] 92 | select = ["E4", "E7", "E9", "F"] 93 | 94 | [tool.ruff.lint.per-file-ignores] 95 | "dltype/tests/*" = ["D", "ARG", "FBT", "PLR2004", "T201"] 96 | -------------------------------------------------------------------------------- /dltype/tests/interop_test.py: -------------------------------------------------------------------------------- 1 | """Test that dltype can operate with either numpy or torch installed.""" 2 | 3 | import sys 4 | from collections.abc import Iterator 5 | from importlib import reload 6 | from unittest.mock import patch 7 | 8 | import numpy as np 9 | import pytest 10 | import torch 11 | 12 | import dltype 13 | 14 | 15 | @pytest.fixture(autouse=True) 16 | def clear_cached_available_fns() -> None: 17 | """Clear cached functions to ensure fresh imports.""" 18 | # Clear the cache for the dependency utilities 19 | from dltype._lib._dependency_utilities import is_numpy_available, 
is_torch_available 20 | 21 | is_torch_available.cache_clear() 22 | is_numpy_available.cache_clear() 23 | 24 | 25 | @pytest.fixture(autouse=True) 26 | def reset_modules() -> Iterator[None]: 27 | # Store a copy of the initial sys.modules state 28 | initial_modules = sys.modules.copy() 29 | yield 30 | # Restore sys.modules to its initial state after the test 31 | sys.modules.clear() 32 | sys.modules.update(initial_modules) 33 | 34 | 35 | @pytest.fixture 36 | def mock_missing_numpy() -> Iterator[None]: 37 | """Mock numpy as missing.""" 38 | with patch("dltype._lib._dependency_utilities.np", None): 39 | yield 40 | 41 | 42 | @pytest.fixture 43 | def mock_missing_torch() -> Iterator[None]: 44 | """Mock torch as missing.""" 45 | with patch("dltype._lib._dependency_utilities.torch", None): 46 | yield 47 | 48 | 49 | def test_dltype_imports_without_torch(mock_missing_torch: None) -> None: 50 | """Test that dltype can be imported without torch.""" 51 | del sys.modules["torch"] 52 | 53 | with pytest.raises(ImportError): 54 | reload(torch) 55 | 56 | reloaded_dltype = reload(dltype) 57 | 58 | assert (np.bool_,) == reloaded_dltype.BoolTensor.DTYPES 59 | 60 | 61 | def test_dltype_imports_without_numpy(mock_missing_numpy: None) -> None: 62 | """Test that dltype can be imported without numpy.""" 63 | del sys.modules["numpy"] 64 | 65 | with pytest.raises(ImportError): 66 | reload(np) 67 | 68 | reloaded_dltype = reload(dltype) 69 | 70 | assert (torch.bool,) == reloaded_dltype.BoolTensor.DTYPES 71 | 72 | 73 | def test_dltype_imports_with_both() -> None: 74 | """Test that dltype can be imported with both torch and numpy.""" 75 | reloaded_dltype = reload(dltype) 76 | assert ( 77 | torch.bool, 78 | np.bool_, 79 | ) == reloaded_dltype.BoolTensor.DTYPES 80 | 81 | 82 | def test_dltype_asserts_import_error_with_neither( 83 | mock_missing_numpy: None, 84 | mock_missing_torch: None, 85 | ) -> None: 86 | """Test that dltype raises ImportError if neither torch nor numpy is available.""" 87 | 88 | with pytest.raises(ImportError, match="Neither torch nor numpy is available"): 89 | reload(dltype) 90 | -------------------------------------------------------------------------------- /dltype/_lib/_dependency_utilities.py: -------------------------------------------------------------------------------- 1 | """Utilities to handle optional dependencies in dltype.""" 2 | 3 | import typing 4 | from collections.abc import Callable 5 | from functools import cache 6 | 7 | Ret = typing.TypeVar("Ret") 8 | P = typing.ParamSpec("P") 9 | 10 | 11 | def _empty_wrapper(fn: Callable[P, Ret]) -> Callable[P, Ret]: 12 | """A no-op function used as a placeholder for optional dependencies.""" 13 | return fn 14 | 15 | 16 | # import these first and avoid runtime penalties if they are not available 17 | try: 18 | import torch 19 | 20 | # re-export for compatibility 21 | torch_jit_unused = torch.jit.unused # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] 22 | except ImportError: 23 | torch_jit_unused = _empty_wrapper 24 | torch = None 25 | 26 | 27 | try: 28 | import numpy as np 29 | except ImportError: 30 | np = None 31 | 32 | 33 | @cache 34 | def is_torch_available() -> bool: 35 | """Check if the torch library is available.""" 36 | return torch is not None 37 | 38 | 39 | @cache 40 | def is_numpy_available() -> bool: 41 | """Check if the numpy library is available.""" 42 | return np is not None 43 | 44 | 45 | @cache 46 | def is_np_float128_available() -> bool: 47 | float_128_available = False 48 | if is_numpy_available(): 49 | try: 
50 | # Check if float128 is available (may not be supported on all platforms) 51 | _ = np.float128 # pyright: ignore[reportOptionalMemberAccess] 52 | float_128_available = True 53 | except AttributeError: 54 | pass 55 | return float_128_available 56 | 57 | 58 | @cache 59 | def is_np_longdouble_available() -> bool: 60 | longdouble_available = False 61 | if is_numpy_available(): 62 | try: 63 | # Check if longdouble is available (may not be supported on all platforms) 64 | _ = np.longdouble # pyright: ignore[reportOptionalMemberAccess] 65 | longdouble_available = True 66 | except AttributeError: 67 | pass 68 | return longdouble_available 69 | 70 | 71 | def raise_for_missing_dependency() -> typing.NoReturn: 72 | """Raise an ImportError if neither torch nor numpy is available.""" 73 | if not is_torch_available() and not is_numpy_available(): 74 | msg = "Neither torch nor numpy is available. Please install one of them to use dltype." 75 | raise ImportError(msg) 76 | 77 | msg = "Improper use of raise_for_missing_dependency, should only be called when both dependencies are missing." 78 | raise AssertionError(msg) 79 | 80 | 81 | def is_torch_scripting() -> bool: 82 | """Check if the torch library is in scripting mode.""" 83 | if not is_torch_available(): 84 | return False 85 | 86 | return torch.jit.is_scripting() # pyright: ignore[reportOptionalMemberAccess, reportPrivateImportUsage] 87 | -------------------------------------------------------------------------------- /dltype/_lib/_torch_tensors.py: -------------------------------------------------------------------------------- 1 | """Torch tensors for DLType.""" 2 | 3 | import torch 4 | 5 | from dltype._lib._tensor_type_base import TensorTypeBase 6 | 7 | 8 | class UInt8Tensor(TensorTypeBase): 9 | """A class to represent an 8-bit unsigned integer tensor type.""" 10 | 11 | DTYPES = (torch.uint8,) 12 | 13 | 14 | class UInt16Tensor(TensorTypeBase): 15 | """A class to represent an 16-bit unsigned integer tensor type.""" 16 | 17 | DTYPES = (torch.uint16,) 18 | 19 | 20 | class UInt32Tensor(TensorTypeBase): 21 | """A class to represent an 32-bit unsigned integer tensor type.""" 22 | 23 | DTYPES = (torch.uint32,) 24 | 25 | 26 | class UInt64Tensor(TensorTypeBase): 27 | """A class to represent an 64-bit unsigned integer tensor type.""" 28 | 29 | DTYPES = (torch.uint64,) 30 | 31 | 32 | class Int8Tensor(TensorTypeBase): 33 | """A class to represent an 8-bit integer tensor type.""" 34 | 35 | DTYPES = (torch.int8,) 36 | 37 | 38 | class Int16Tensor(TensorTypeBase): 39 | """A class to represent a 16-bit integer tensor type.""" 40 | 41 | DTYPES = (torch.int16,) 42 | 43 | 44 | class Int32Tensor(TensorTypeBase): 45 | """A class to represent a 32-bit integer tensor type.""" 46 | 47 | DTYPES = (torch.int32,) 48 | 49 | 50 | class Int64Tensor(TensorTypeBase): 51 | """A class to represent a 64-bit integer tensor type.""" 52 | 53 | DTYPES = (torch.int64,) 54 | 55 | 56 | class SignedIntTensor(TensorTypeBase): 57 | """A class to represent any signed integer tensor type of any size (8 bit, 16 bit, 32 bit, and 64 bit).""" 58 | 59 | DTYPES = ( 60 | *Int8Tensor.DTYPES, 61 | *Int16Tensor.DTYPES, 62 | *Int32Tensor.DTYPES, 63 | *Int64Tensor.DTYPES, 64 | ) 65 | 66 | 67 | class UnsignedIntTensor(TensorTypeBase): 68 | """A class to represent any unsigned integer tensor type of any size (8 bit, 16 bit, 32 bit, and 64 bit).""" 69 | 70 | DTYPES = ( 71 | *UInt8Tensor.DTYPES, 72 | *UInt16Tensor.DTYPES, 73 | *UInt32Tensor.DTYPES, 74 | *UInt64Tensor.DTYPES, 75 | ) 76 | 77 | 78 | class 
IntTensor(TensorTypeBase): 79 | """A class to represent any integer (signed or unsigned) tensor type of any size (8, 16, 32, and 64).""" 80 | 81 | DTYPES = ( 82 | *UnsignedIntTensor.DTYPES, 83 | *SignedIntTensor.DTYPES, 84 | ) 85 | 86 | 87 | class BFloat16Tensor(TensorTypeBase): 88 | """A class to represent bfloat16.""" 89 | 90 | DTYPES = (torch.bfloat16,) 91 | 92 | 93 | class IEEE754HalfFloatTensor(TensorTypeBase): 94 | """A dtype for 16 bit half-precision floats that comply with the IEE 754 specification (excludes bf16).""" 95 | 96 | DTYPES = ( 97 | torch.half, 98 | torch.float16, 99 | ) 100 | 101 | 102 | class Float16Tensor(TensorTypeBase): 103 | """A class to represent any 16 bit floating point type (includes regular float16 as well as bfloat16).""" 104 | 105 | DTYPES = (*IEEE754HalfFloatTensor.DTYPES, *BFloat16Tensor.DTYPES) 106 | 107 | 108 | class Float32Tensor(TensorTypeBase): 109 | """A class to represent any 32 bit floating point type.""" 110 | 111 | DTYPES = ( 112 | torch.float, 113 | torch.float32, 114 | ) 115 | 116 | 117 | class DoubleTensor(TensorTypeBase): 118 | """A class to represent a double tensor type.""" 119 | 120 | DTYPES = (torch.double, torch.float64) 121 | 122 | 123 | Float64Tensor = DoubleTensor 124 | 125 | 126 | class FloatTensor(TensorTypeBase): 127 | """A class to represent any floating point tensor type of any size (16 bit, 32 bit, and 64 bit).""" 128 | 129 | DTYPES = ( 130 | *Float16Tensor.DTYPES, 131 | *Float32Tensor.DTYPES, 132 | *Float64Tensor.DTYPES, 133 | ) 134 | 135 | 136 | class BoolTensor(TensorTypeBase): 137 | """A class to represent a boolean tensor type.""" 138 | 139 | DTYPES = (torch.bool,) 140 | -------------------------------------------------------------------------------- /dltype/__init__.py: -------------------------------------------------------------------------------- 1 | """A fast, lightweight runtime type checker for torch tensors and numpy arrays.""" 2 | 3 | from dltype._lib._constants import ( 4 | DEBUG_MODE, 5 | MAX_ACCEPTABLE_EVALUATION_TIME_NS, 6 | ) 7 | from dltype._lib._core import ( 8 | DLTypeScopeProvider, 9 | dltyped, 10 | dltyped_dataclass, 11 | dltyped_namedtuple, 12 | ) 13 | from dltype._lib._dependency_utilities import ( 14 | is_numpy_available, 15 | is_torch_available, 16 | raise_for_missing_dependency, 17 | ) 18 | from dltype._lib._dtypes import SUPPORTED_TENSOR_TYPES 19 | from dltype._lib._errors import ( 20 | DLTypeDtypeError, 21 | DLTypeDuplicateError, 22 | DLTypeError, 23 | DLTypeInvalidReferenceError, 24 | DLTypeNDimsError, 25 | DLTypeScopeProviderError, 26 | DLTypeShapeError, 27 | DLTypeUnsupportedTensorTypeError, 28 | ) 29 | from dltype._lib._symbolic_expressions import ( 30 | AnonymousAxis, 31 | ConstantAxis, 32 | LiteralAxis, 33 | Max, 34 | Min, 35 | Shape, 36 | VariableAxis, 37 | ) 38 | from dltype._lib._tensor_type_base import ( 39 | TensorTypeBase, 40 | ) 41 | 42 | if is_torch_available() and is_numpy_available(): 43 | from dltype._lib._universal_tensors import ( 44 | BFloat16Tensor, 45 | BoolTensor, 46 | DoubleTensor, 47 | Float16Tensor, 48 | Float32Tensor, 49 | Float64Tensor, 50 | FloatTensor, 51 | IEEE754HalfFloatTensor, 52 | Int8Tensor, 53 | Int16Tensor, 54 | Int32Tensor, 55 | Int64Tensor, 56 | IntTensor, 57 | SignedIntTensor, 58 | UInt8Tensor, 59 | UInt16Tensor, 60 | UInt32Tensor, 61 | UInt64Tensor, 62 | UnsignedIntTensor, 63 | ) 64 | elif is_torch_available(): 65 | from dltype._lib._torch_tensors import ( 66 | BFloat16Tensor, 67 | BoolTensor, 68 | DoubleTensor, 69 | Float16Tensor, 70 | Float32Tensor, 71 | 
Float64Tensor, 72 | FloatTensor, 73 | IEEE754HalfFloatTensor, 74 | Int8Tensor, 75 | Int16Tensor, 76 | Int32Tensor, 77 | Int64Tensor, 78 | IntTensor, 79 | SignedIntTensor, 80 | UInt8Tensor, 81 | UInt16Tensor, 82 | UInt32Tensor, 83 | UInt64Tensor, 84 | UnsignedIntTensor, 85 | ) 86 | elif is_numpy_available(): 87 | from dltype._lib._numpy_tensors import ( 88 | BoolTensor, 89 | DoubleTensor, 90 | Float16Tensor, 91 | Float32Tensor, 92 | Float64Tensor, 93 | FloatTensor, 94 | IEEE754HalfFloatTensor, 95 | Int8Tensor, 96 | Int16Tensor, 97 | Int32Tensor, 98 | Int64Tensor, 99 | IntTensor, 100 | SignedIntTensor, 101 | UInt8Tensor, 102 | UInt16Tensor, 103 | UInt32Tensor, 104 | UInt64Tensor, 105 | UnsignedIntTensor, 106 | ) 107 | 108 | BFloat16Tensor = None 109 | else: 110 | raise_for_missing_dependency() 111 | 112 | 113 | __all__ = [ 114 | "DEBUG_MODE", 115 | "MAX_ACCEPTABLE_EVALUATION_TIME_NS", 116 | "SUPPORTED_TENSOR_TYPES", 117 | "AnonymousAxis", 118 | "BFloat16Tensor", 119 | "BFloat16Tensor", 120 | "BoolTensor", 121 | "ConstantAxis", 122 | "DLTypeDtypeError", 123 | "DLTypeDuplicateError", 124 | "DLTypeError", 125 | "DLTypeInvalidReferenceError", 126 | "DLTypeNDimsError", 127 | "DLTypeScopeProvider", 128 | "DLTypeScopeProviderError", 129 | "DLTypeShapeError", 130 | "DLTypeUnsupportedTensorTypeError", 131 | "DoubleTensor", 132 | "Float16Tensor", 133 | "Float32Tensor", 134 | "Float64Tensor", 135 | "FloatTensor", 136 | "IEEE754HalfFloatTensor", 137 | "Int8Tensor", 138 | "Int16Tensor", 139 | "Int32Tensor", 140 | "Int64Tensor", 141 | "IntTensor", 142 | "LiteralAxis", 143 | "Max", 144 | "Min", 145 | "Shape", 146 | "SignedIntTensor", 147 | "TensorTypeBase", 148 | "UInt8Tensor", 149 | "UInt16Tensor", 150 | "UInt32Tensor", 151 | "UInt64Tensor", 152 | "UnsignedIntTensor", 153 | "VariableAxis", 154 | "dltyped", 155 | "dltyped_dataclass", 156 | "dltyped_namedtuple", 157 | ] 158 | -------------------------------------------------------------------------------- /dltype/_lib/_numpy_tensors.py: -------------------------------------------------------------------------------- 1 | """Numpy-only tensor types for DLType.""" 2 | 3 | import numpy as np 4 | 5 | from dltype._lib._dependency_utilities import ( 6 | is_np_float128_available, 7 | is_np_longdouble_available, 8 | ) 9 | from dltype._lib._tensor_type_base import TensorTypeBase 10 | 11 | 12 | class UInt8Tensor(TensorTypeBase): 13 | """A class to represent an 8-bit unsigned integer tensor type.""" 14 | 15 | DTYPES = (np.uint8,) 16 | 17 | 18 | class UInt16Tensor(TensorTypeBase): 19 | """A class to represent an 16-bit unsigned integer tensor type.""" 20 | 21 | DTYPES = (np.uint16,) 22 | 23 | 24 | class UInt32Tensor(TensorTypeBase): 25 | """A class to represent an 32-bit unsigned integer tensor type.""" 26 | 27 | DTYPES = (np.uint32,) 28 | 29 | 30 | class UInt64Tensor(TensorTypeBase): 31 | """A class to represent an 64-bit unsigned integer tensor type.""" 32 | 33 | DTYPES = (np.uint64,) 34 | 35 | 36 | class Int8Tensor(TensorTypeBase): 37 | """A class to represent an 8-bit integer tensor type.""" 38 | 39 | DTYPES = (np.int8,) 40 | 41 | 42 | class Int16Tensor(TensorTypeBase): 43 | """A class to represent a 16-bit integer tensor type.""" 44 | 45 | DTYPES = (np.int16,) 46 | 47 | 48 | class Int32Tensor(TensorTypeBase): 49 | """A class to represent a 32-bit integer tensor type.""" 50 | 51 | DTYPES = (np.int32,) 52 | 53 | 54 | class Int64Tensor(TensorTypeBase): 55 | """A class to represent a 64-bit integer tensor type.""" 56 | 57 | DTYPES = (np.int64,) 58 | 59 | 60 | class 
SignedIntTensor(TensorTypeBase): 61 | """A class to represent any signed integer tensor type of any size (8 bit, 16 bit, 32 bit, and 64 bit).""" 62 | 63 | DTYPES = ( 64 | *Int8Tensor.DTYPES, 65 | *Int16Tensor.DTYPES, 66 | *Int32Tensor.DTYPES, 67 | *Int64Tensor.DTYPES, 68 | ) 69 | 70 | 71 | class UnsignedIntTensor(TensorTypeBase): 72 | """A class to represent any unsigned integer tensor type of any size (8 bit, 16 bit, 32 bit, and 64 bit).""" 73 | 74 | DTYPES = ( 75 | *UInt8Tensor.DTYPES, 76 | *UInt16Tensor.DTYPES, 77 | *UInt32Tensor.DTYPES, 78 | *UInt64Tensor.DTYPES, 79 | ) 80 | 81 | 82 | class IntTensor(TensorTypeBase): 83 | """A class to represent any integer tensor type of any size (8 bit, 16 bit, 32 bit, and 64 bit).""" 84 | 85 | DTYPES = ( 86 | *SignedIntTensor.DTYPES, 87 | *UnsignedIntTensor.DTYPES, 88 | ) 89 | 90 | 91 | class IEEE754HalfFloatTensor(TensorTypeBase): 92 | """A dtype for 16 bit half-precision floats that comply with the IEE 754 specification.""" 93 | 94 | DTYPES = (np.float16,) 95 | 96 | 97 | # Note that numpy does not support non IEEE754 compliant 16 bit floating types such as bfloat16 98 | Float16Tensor = IEEE754HalfFloatTensor 99 | 100 | 101 | class Float32Tensor(TensorTypeBase): 102 | """A class to represent any 32 bit floating point type.""" 103 | 104 | DTYPES = (np.float32,) 105 | 106 | 107 | class Float64Tensor(TensorTypeBase): 108 | """A class to represent a double tensor type.""" 109 | 110 | DTYPES = (np.float64,) 111 | 112 | 113 | DoubleTensor = Float64Tensor 114 | 115 | 116 | class Float128Tensor(TensorTypeBase): 117 | """A class to represent 128 bit floating point type.""" 118 | 119 | DTYPES = (np.float128,) if is_np_float128_available() else () 120 | 121 | 122 | class LongDoubleTensor(TensorTypeBase): 123 | """A class to represent long double floating point type.""" 124 | 125 | DTYPES = (np.longdouble,) if is_np_longdouble_available() else () 126 | 127 | 128 | class FloatTensor(TensorTypeBase): 129 | """A class to represent any floating point tensor type of any size (16 bit, 32 bit, 64 bit, and optionally 128 bit).""" 130 | 131 | # Build DTYPES list based on available types 132 | DTYPES = ( 133 | # Add standard floating point types (16, 32, 64 bits) 134 | *Float16Tensor.DTYPES, 135 | *Float32Tensor.DTYPES, 136 | *DoubleTensor.DTYPES, 137 | # Add optional 128-bit and longdouble types if available 138 | *Float128Tensor.DTYPES, 139 | *LongDoubleTensor.DTYPES, 140 | ) 141 | 142 | 143 | class BoolTensor(TensorTypeBase): 144 | """A class to represent a boolean tensor type.""" 145 | 146 | DTYPES = (np.bool_,) 147 | -------------------------------------------------------------------------------- /dltype/tests/parser_test.py: -------------------------------------------------------------------------------- 1 | # pyright: reportPrivateUsage=false 2 | """Tests for expression parsing.""" 3 | 4 | import pytest 5 | 6 | from dltype import AnonymousAxis, ConstantAxis, LiteralAxis, Max, Min, Shape, VariableAxis 7 | from dltype._lib import _parser 8 | 9 | 10 | @pytest.mark.parametrize( 11 | ("expression", "scope", "expected"), 12 | [ 13 | ("a", {"a": 1}, 1), 14 | ("a=1", {}, 1), 15 | ("1+2", {}, 3), 16 | ("1+2*3", {}, 7), 17 | ("3*3", {}, 9), 18 | ("3*3", {"x": 1}, 9), 19 | ("3*x", {"x": 3}, 9), 20 | ("x+y*x", {"x": 3, "y": 4}, 15), 21 | ("min(55,1-2)", {}, -1), 22 | ("max(55,1-2)", {}, 55), 23 | ("min(max(0,x),y)", {"x": 3, "y": 4}, 3), 24 | ("min(max(0,x),y)", {"x": -3, "y": 4}, 0), 25 | ("max(x-y,0)", {"x": -3, "y": -4}, 1), 26 | ("max(x-y,0)", {"x": 3, "y": 4}, 0), 27 | 
("max(x-y,0)", {"x": 3, "y": 2}, 1), 28 | ("3^2", {}, 9), 29 | ("3^2^2", {}, 81), 30 | ("min(3^x,max(3-y,99))", {"x": 2, "y": 4}, 9), 31 | ("max(3^x,min(3-y,99))", {"x": 2, "y": 4}, 9), 32 | ("min(3^x,3-y)", {"x": 2, "y": 4}, -1), 33 | ("variable_name_with_underscores", {"variable_name_with_underscores": 1}, 1), 34 | ("isqrt(5)", {}, 2), 35 | ("isqrt(16)", {}, 4), 36 | ("isqrt(x-y)", {"x": 20, "y": 5}, 3), 37 | ("min(isqrt(20),isqrt(16))", {}, 4), 38 | ("max(isqrt(20),isqrt(16))", {}, 4), 39 | ], 40 | ) 41 | def test_parse_expression( 42 | expression: str, 43 | scope: dict[str, int], 44 | expected: int, 45 | ) -> None: 46 | assert _parser.expression_from_string(expression).evaluate(scope) == expected 47 | 48 | 49 | @pytest.mark.parametrize( 50 | ("expression", "scope"), 51 | [ 52 | ("1 + 2", {}), 53 | ("a+2", {}), 54 | ("a=a-1", {}), 55 | ("a", {"b": 1}), 56 | ("min(1,2", {}), 57 | ("*batch", {}), 58 | ("3**2", {}), 59 | ("^", {}), 60 | ("isqrt(4, 5)", {}), 61 | ("isqrt()", {}), 62 | ("max(1)", {}), 63 | ("min()", {}), 64 | ], 65 | ) 66 | def test_parse_invalid_expression(expression: str, scope: dict[str, int]) -> None: 67 | with pytest.raises((SyntaxError, KeyError)): 68 | _parser.expression_from_string(expression).evaluate(scope) 69 | 70 | 71 | @pytest.mark.parametrize( 72 | ("expression", "expected"), 73 | [ 74 | (Shape[1 + 2], "3"), 75 | (Shape[Max(1, VariableAxis("imageh"))], "max(1,imageh)"), 76 | (Shape[Min(1, 2)], "1"), 77 | (Shape[ConstantAxis("RGB", 4)], "RGB=4"), 78 | (Shape[..., VariableAxis("c"), VariableAxis("h"), VariableAxis("w")], "... c h w"), 79 | ( 80 | Shape[AnonymousAxis("batch"), VariableAxis("c"), VariableAxis("h"), VariableAxis("w")], 81 | "*batch c h w", 82 | ), 83 | (Shape[LiteralAxis(4), VariableAxis("r")], "4 r"), 84 | (Shape[Min(4 + VariableAxis("image_w"), VariableAxis("imageh"))], "min(4+image_w,imageh)"), 85 | ], 86 | ) 87 | def test_parse_symbolic(expression: Shape, expected: str) -> None: 88 | assert str(expression) == expected 89 | 90 | 91 | def test_raises() -> None: 92 | rgb = ConstantAxis("RGB", 3) 93 | 94 | # NOTE: we disallow adding anything to constants because it isn't clear what the intent 95 | # would be. 96 | 97 | # For example, if we have rgb=3 we clearly have an axis that is meant for 98 | # rgb channels and would have a shape of 3. 99 | 100 | # however, what does rgb+1 mean? Do we want the axis of rgb to be 4 instead? in which case, is it even 101 | # referring to the same axis anymore? Or do we want a new axis for rgba, but how do we change the name in an addition operation? 
102 | 103 | with pytest.raises(TypeError): 104 | _ = rgb + 4 # pyright: ignore[reportUnknownVariableType, reportOperatorIssue] 105 | 106 | with pytest.raises(TypeError): 107 | _ = 4 + rgb # pyright: ignore[reportUnknownVariableType, reportOperatorIssue] 108 | 109 | # Similarly operating on anonymous axes doesn't make sense in general as it isn't clear what the intent would be 110 | with pytest.raises(TypeError): 111 | _ = AnonymousAxis("*batch") + 4 # pyright: ignore[reportUnknownVariableType, reportOperatorIssue] 112 | 113 | with pytest.raises(TypeError): 114 | _ = 4 + AnonymousAxis("*batch") # pyright: ignore[reportUnknownVariableType, reportOperatorIssue] 115 | -------------------------------------------------------------------------------- /dltype/_lib/_errors.py: -------------------------------------------------------------------------------- 1 | """Errors for the dltype library.""" 2 | 3 | import typing 4 | from abc import ABC, abstractmethod 5 | from collections import abc 6 | 7 | from dltype._lib._dtypes import SUPPORTED_TENSOR_TYPES, DLtypeDtypeT 8 | 9 | 10 | class DLTypeError(TypeError, ABC): 11 | """An error raised when a type assertion is hit.""" 12 | 13 | def __init__(self, error_ctx: str | None) -> None: 14 | self._ctx = error_ctx 15 | super().__init__() 16 | 17 | def set_context(self, error_ctx: str) -> None: 18 | self._ctx = error_ctx 19 | 20 | @abstractmethod 21 | def __str__(self) -> str: 22 | if self._ctx is not None: 23 | return f"[{self._ctx}] {self!s}" 24 | return super().__str__() 25 | 26 | 27 | class DLTypeUnsupportedTensorTypeError(DLTypeError): 28 | """An error raised when dltype is attempted to be used on an unsupported tensor type.""" 29 | 30 | def __init__(self, actual_type: type[typing.Any]) -> None: 31 | self._actual = actual_type 32 | 33 | def __str__(self) -> str: 34 | return f"Invalid tensor type, expected one of {SUPPORTED_TENSOR_TYPES}, actual={self._actual}" 35 | 36 | 37 | class DLTypeShapeError(DLTypeError): 38 | """An error raised when a shape assertion is hit.""" 39 | 40 | def __init__( 41 | self, 42 | index: int, 43 | expected_shape: int, 44 | actual: int, 45 | tensor_name: str, 46 | error_ctx: str | None = None, 47 | ) -> None: 48 | self._tensor_name = tensor_name or "anonymous" 49 | self._index = index 50 | self._expected = expected_shape 51 | self._actual = actual 52 | super().__init__(error_ctx=error_ctx) 53 | 54 | def __str__(self) -> str: 55 | return f"Invalid tensor shape, tensor={self._tensor_name} dim={self._index} expected={self._expected} actual={self._actual}" 56 | 57 | 58 | class DLTypeNDimsError(DLTypeError): 59 | """An error raised when a tensor does not have the expected number of dimensions.""" 60 | 61 | def __init__( 62 | self, 63 | expected: int, 64 | actual: int, 65 | tensor_name: str, 66 | error_ctx: str | None = None, 67 | ) -> None: 68 | self._tensor_name = tensor_name or "anonymous" 69 | self._expected = expected 70 | self._actual = actual 71 | super().__init__(error_ctx=error_ctx) 72 | 73 | def __str__(self) -> str: 74 | return f"Invalid number of dimensions, tensor={self._tensor_name} expected ndims={self._expected} actual={self._actual}" 75 | 76 | 77 | class DLTypeDtypeError(DLTypeError): 78 | """An error raised when a dtype assertion is hit.""" 79 | 80 | def __init__( 81 | self, 82 | tensor_name: str | None, 83 | expected: abc.Iterable[DLtypeDtypeT] | None, 84 | received: abc.Iterable[DLtypeDtypeT] | None, 85 | error_ctx: str | None = None, 86 | ) -> None: 87 | """Raise an error regarding an invalid shape.""" 88 | 
self._tensor_name = tensor_name or "anonymous" 89 | self._expected = expected or set() 90 | self._received = received or set() 91 | super().__init__(error_ctx=error_ctx) 92 | 93 | def __str__(self) -> str: 94 | expected = ",".join(sorted(map(str, self._expected))) 95 | received = ",".join(sorted(map(str, self._received))) 96 | return f"Invalid dtype, tensor={self._tensor_name} expected one of ({expected}) got={received}" 97 | 98 | 99 | class DLTypeDuplicateError(DLTypeError): 100 | """An error raised when a duplicate tensor name is hit.""" 101 | 102 | def __init__( 103 | self, 104 | tensor_name: str | None, 105 | error_ctx: str | None = None, 106 | ) -> None: 107 | self._tensor_name = tensor_name 108 | super().__init__(error_ctx=error_ctx) 109 | 110 | def __str__(self) -> str: 111 | return f"Invalid duplicate tensor, tensor={self._tensor_name}" 112 | 113 | 114 | class DLTypeInvalidReferenceError(DLTypeError): 115 | """An error raised when an invalid reference is hit.""" 116 | 117 | def __init__( 118 | self, 119 | tensor_name: str | None, 120 | missing_ref: str | None, 121 | current_context: dict[str, int] | None, 122 | error_ctx: str | None = None, 123 | ) -> None: 124 | self._tensor_name = tensor_name or "?" 125 | self._missing_ref = missing_ref or "?" 126 | self._context = current_context or {} 127 | super().__init__(error_ctx=error_ctx) 128 | 129 | def __str__(self) -> str: 130 | context = ", ".join(self._context.keys()) 131 | return f"Invalid axis referenced before assignment tensor={self._tensor_name} missing_ref={self._missing_ref} valid_refs={context}" 132 | 133 | 134 | class DLTypeScopeProviderError(DLTypeError): 135 | """An error raised when an invalid scope provider is hit.""" 136 | 137 | def __init__( 138 | self, 139 | bad_scope_provider: str, 140 | error_ctx: str | None = None, 141 | ) -> None: 142 | self._bad_scope_provider = bad_scope_provider 143 | super().__init__(error_ctx=error_ctx) 144 | 145 | def __str__(self) -> str: 146 | return f"Invalid scope provider {self._bad_scope_provider}, expected 'self' or a DLTypeScopeProvider" 147 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[codz] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | #poetry.toml 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 114 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control 115 | #pdm.lock 116 | #pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # pixi 121 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. 122 | #pixi.lock 123 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one 124 | # in the .venv directory. It is recommended not to include this directory in version control. 125 | .pixi 126 | 127 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .envrc 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | .dmypy.json 160 | dmypy.json 161 | 162 | # Pyre type checker 163 | .pyre/ 164 | 165 | # pytype static type analyzer 166 | .pytype/ 167 | 168 | # Cython debug symbols 169 | cython_debug/ 170 | 171 | # PyCharm 172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 174 | # and can be added to the global gitignore or merged into this file. For a more nuclear 175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 176 | #.idea/ 177 | 178 | # Abstra 179 | # Abstra is an AI-powered process automation framework. 180 | # Ignore directories containing user credentials, local state, and settings. 181 | # Learn more at https://abstra.io/docs 182 | .abstra/ 183 | 184 | # Visual Studio Code 185 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 186 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 187 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 188 | # you could uncomment the following to ignore the entire vscode folder 189 | # .vscode/ 190 | 191 | # Ruff stuff: 192 | .ruff_cache/ 193 | 194 | # PyPI configuration file 195 | .pypirc 196 | 197 | # Cursor 198 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 199 | # exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data 200 | # refer to https://docs.cursor.com/context/ignore-files 201 | .cursorignore 202 | .cursorindexingignore 203 | 204 | # Marimo 205 | marimo/_static/ 206 | marimo/_lsp/ 207 | __marimo__/ 208 | -------------------------------------------------------------------------------- /dltype/_lib/_symbolic_expressions.py: -------------------------------------------------------------------------------- 1 | """Allows specifying expressions as symbolic types rather than strings.""" 2 | 3 | from __future__ import annotations 4 | 5 | import typing 6 | from abc import ABC, abstractmethod 7 | from types import EllipsisType 8 | 9 | from typing_extensions import override 10 | 11 | 12 | class AxisOperationBase(ABC): 13 | def __init__( 14 | self, 15 | lhs: OperableAxis | ComputedAxis | int, 16 | rhs: OperableAxis | ComputedAxis | int, 17 | ) -> None: 18 | self._lhs = lhs if isinstance(lhs, OperableAxis | ComputedAxis) else LiteralAxis(lhs) 19 | self._rhs = rhs if isinstance(rhs, OperableAxis | ComputedAxis) else LiteralAxis(rhs) 20 | 21 | @abstractmethod 22 | def __str__(self) -> str: 23 | pass 24 | 25 | def __repr__(self) -> str: 26 | return self.__str__() 27 | 28 | 29 | class Add(AxisOperationBase): 30 | def __str__(self) -> str: 31 | if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): 32 | return f"{self._lhs.value + self._rhs.value}" 33 | return f"{self._lhs}+{self._rhs}" 34 | 35 | 36 | class Subtract(AxisOperationBase): 37 | def __str__(self) -> str: 38 | if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): 39 | return f"{self._lhs.value - self._rhs.value}" 40 | return f"{self._lhs}-{self._rhs}" 41 | 42 | 43 | class Divide(AxisOperationBase): 44 | def __str__(self) -> str: 45 | if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): 46 | return f"{self._lhs.value // self._rhs.value}" 47 | return f"{self._lhs}/{self._rhs}" 48 | 49 | 50 | class Multiply(AxisOperationBase): 51 | def __str__(self) -> str: 52 | if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): 53 | return f"{self._lhs.value * self._rhs.value}" 54 | return f"{self._lhs}*{self._rhs}" 55 | 56 | 57 | class Exp(AxisOperationBase): 58 | def __str__(self) -> str: 59 | if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): 60 | return f"{self._lhs.value**self._rhs.value}" 61 | return f"{self._lhs}^{self._rhs}" 62 | 63 | 64 | class Max(AxisOperationBase): 65 | def __str__(self) -> str: 66 | if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): 67 | return f"{max(self._lhs.value, self._rhs.value)}" 68 | return f"max({self._lhs},{self._rhs})" 69 | 70 | 71 | class Min(AxisOperationBase): 72 | def __str__(self) -> str: 73 | if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): 74 | return f"{min(self._lhs.value, self._rhs.value)}" 75 | return f"min({self._lhs},{self._rhs})" 76 | 77 | 78 | class OperableAxis(ABC): 79 | @abstractmethod 80 | def __str__(self) -> str: ... 
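    # Note: the arithmetic dunders below (__add__, __sub__, __mul__, __floordiv__,
    # __pow__ and their reflected forms) wrap the operation nodes defined above in
    # ComputedAxis objects, which render back to the string grammar exercised in
    # parser_test.py (e.g. "x+1", "h*w", "min(a,b)").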
81 | 82 | def __repr__(self) -> str: 83 | return self.__str__() 84 | 85 | def __resolve_expr_sides( 86 | self, 87 | other: OperableAxisT, 88 | *, 89 | reverse: bool = False, 90 | ) -> tuple[OperableAxisT | OperableAxis, OperableAxisT | OperableAxis]: 91 | lhs = other if reverse else self 92 | rhs = self if reverse else other 93 | 94 | return lhs, rhs 95 | 96 | def __add__(self, other: OperableAxisT) -> ComputedAxis: 97 | return ComputedAxis(Add(*self.__resolve_expr_sides(other))) 98 | 99 | def __radd__(self, other: OperableAxisT) -> ComputedAxis: 100 | return ComputedAxis(Add(*self.__resolve_expr_sides(other, reverse=True))) 101 | 102 | def __sub__(self, other: OperableAxisT) -> ComputedAxis: 103 | return ComputedAxis(Subtract(*self.__resolve_expr_sides(other))) 104 | 105 | def __rsub__(self, other: OperableAxisT) -> ComputedAxis: 106 | return ComputedAxis(Subtract(*self.__resolve_expr_sides(other, reverse=True))) 107 | 108 | def __mul__(self, other: OperableAxisT) -> ComputedAxis: 109 | return ComputedAxis(Multiply(*self.__resolve_expr_sides(other))) 110 | 111 | def __rmul__(self, other: OperableAxisT) -> ComputedAxis: 112 | return ComputedAxis(Multiply(*self.__resolve_expr_sides(other, reverse=True))) 113 | 114 | def __floordiv__(self, other: OperableAxisT) -> ComputedAxis: 115 | return ComputedAxis(Divide(*self.__resolve_expr_sides(other))) 116 | 117 | def __rfloordiv__(self, other: OperableAxisT) -> ComputedAxis: 118 | return ComputedAxis(Divide(*self.__resolve_expr_sides(other, reverse=True))) 119 | 120 | def __pow__(self, other: OperableAxisT) -> ComputedAxis: 121 | return ComputedAxis(Exp(*self.__resolve_expr_sides(other))) 122 | 123 | def __rpow__(self, other: OperableAxisT) -> ComputedAxis: 124 | return ComputedAxis(Exp(*self.__resolve_expr_sides(other, reverse=True))) 125 | 126 | 127 | class LiteralAxis(OperableAxis): 128 | def __init__(self, value: int) -> None: 129 | """Initialize an axis with a literal integer value.""" 130 | self._value = value 131 | 132 | @property 133 | def value(self) -> int: 134 | return self._value 135 | 136 | def __str__(self) -> str: 137 | return str(self._value) 138 | 139 | 140 | class VariableAxis(OperableAxis): 141 | def __init__(self, identifier: str) -> None: 142 | """Initialize an axis with an identifier string.""" 143 | self._identifier = identifier 144 | 145 | def __str__(self) -> str: 146 | return str(self._identifier) 147 | 148 | 149 | class ComputedAxis(OperableAxis): 150 | def __init__(self, computation: AxisOperationBase) -> None: 151 | self._computation = computation 152 | 153 | def __str__(self) -> str: 154 | return f"{self._computation}" 155 | 156 | def __repr__(self) -> str: 157 | return self.__str__() 158 | 159 | 160 | class NamedComputedAxis(ComputedAxis): 161 | def __init__(self, identifier: str, computation: AxisOperationBase) -> None: 162 | super().__init__(computation) 163 | self._identifier = identifier 164 | 165 | @property 166 | def identifier(self) -> str: 167 | return self._identifier 168 | 169 | @override 170 | def __str__(self) -> str: 171 | return f"{self._identifier}={self._computation}" 172 | 173 | 174 | OperableAxisT: typing.TypeAlias = LiteralAxis | VariableAxis | ComputedAxis | int 175 | 176 | 177 | class ConstantAxis: 178 | def __init__(self, identifier: str, value: int) -> None: 179 | """Initialize a symbol with an identifier string equal to a constant value i.e. 
batch=3.""" 180 | self._identifier = str(identifier) 181 | self._value = value 182 | 183 | @property 184 | def value(self) -> int: 185 | return self._value 186 | 187 | @property 188 | def identifier(self) -> str: 189 | return self._identifier 190 | 191 | def __str__(self) -> str: 192 | return f"{self._identifier}={self._value}" 193 | 194 | 195 | class AnonymousAxis: 196 | def __init__(self, maybe_name: str | EllipsisType) -> None: 197 | """Initialize an axis or set of axes with zero or more values, optionally give a name.""" 198 | self._identifier = maybe_name 199 | 200 | def __str__(self) -> str: 201 | return ("*" + str(self._identifier)) if isinstance(self._identifier, str) else "..." 202 | 203 | def __repr__(self) -> str: 204 | return self.__str__() 205 | 206 | 207 | AxisT: typing.TypeAlias = OperableAxisT | AnonymousAxis | ConstantAxis 208 | ExpressionComponentT = AxisT | EllipsisType 209 | 210 | 211 | class Shape: 212 | """The expression of tensor shape as a sequence of expressions.""" 213 | 214 | def __init__(self, symbols: tuple[ExpressionComponentT, ...] | ExpressionComponentT) -> None: 215 | _symbols = symbols if isinstance(symbols, tuple) else (symbols,) 216 | self._raveled_expressions = [ 217 | (AnonymousAxis(...) if isinstance(symbol, EllipsisType) else symbol) for symbol in _symbols 218 | ] 219 | 220 | def __str__(self) -> str: 221 | return " ".join(map(str, self._raveled_expressions)) 222 | 223 | def __repr__(self) -> str: 224 | return self.__str__() 225 | 226 | @classmethod 227 | def __class_getitem__(cls, args: tuple[ExpressionComponentT, ...]) -> Shape: 228 | return cls(args) 229 | -------------------------------------------------------------------------------- /dltype/_lib/_tensor_type_base.py: -------------------------------------------------------------------------------- 1 | """The base class for all dltype supported tensor annotations.""" 2 | 3 | from __future__ import annotations 4 | 5 | import typing 6 | 7 | from pydantic_core import core_schema 8 | from typing_extensions import override 9 | 10 | from dltype._lib import ( 11 | _constants, 12 | _dltype_context, 13 | _dtypes, 14 | _errors, 15 | _parser, 16 | _symbolic_expressions, 17 | ) 18 | from dltype._lib import ( 19 | _dependency_utilities as _deps, 20 | ) 21 | 22 | if typing.TYPE_CHECKING: 23 | from pydantic import GetCoreSchemaHandler, ValidationInfo 24 | 25 | if _deps.is_numpy_available(): 26 | import numpy as np 27 | import numpy.typing as npt 28 | 29 | 30 | def _resolve_numpy_dtype( 31 | np_array_t: type[npt.NDArray[typing.Any]], 32 | ) -> list[npt.DTypeLike]: 33 | """Resolve the numpy dtype of a numpy array.""" 34 | maybe_dtype_arg = typing.get_args(np_array_t)[1] 35 | maybe_dtype = typing.get_args(maybe_dtype_arg) 36 | 37 | # if the dtype is a union of types, we need to resolve it 38 | return [ 39 | typing.cast("npt.DTypeLike", dtype) 40 | for maybe_union in maybe_dtype 41 | for dtype in typing.get_args(maybe_union) or [maybe_union] 42 | ] 43 | 44 | 45 | class TensorTypeBase: 46 | """ 47 | A class to represent a tensor type. 48 | 49 | A tensor type is expected to validate the shape of any literal integers present in the type hint. 50 | It may also choose to validate the datatype of the tensor. 51 | """ 52 | 53 | DTYPES: typing.ClassVar[tuple[_dtypes.DLtypeDtypeT, ...]] = () 54 | """The torch dtypes that this tensor type asserts to contain. 
(empty for any dtype).""" 55 | 56 | def __init__(self, shape: str | None, *, optional: bool = False) -> None: 57 | """Create a new tensor type object.""" 58 | self.multiaxis_index: int | None = None 59 | self.anonymous_multiaxis: bool = False 60 | self.multiaxis_name: str | None = None 61 | self.optional = optional 62 | self.expected_shape = self._parse_shape_string(shape) 63 | # only include literal dimensions that aren't multiaxis 64 | self._literal_dims = tuple( 65 | (idx, dim.evaluate({})) 66 | for idx, dim in enumerate(self.expected_shape) 67 | if dim.is_literal and idx != self.multiaxis_index # pyright: ignore[reportUnnecessaryComparison] 68 | ) 69 | 70 | @override 71 | def __repr__(self) -> str: 72 | """Get the string representation of the tensor type.""" 73 | return f"{self.__class__.__name__}[{self.expected_shape}]" 74 | 75 | def _parse_shape_string( 76 | self, 77 | shape_string: str | None, 78 | ) -> tuple[_parser.DLTypeDimensionExpression, ...]: 79 | """Parse the shape string into a list of dimension expressions.""" 80 | if shape_string is None: 81 | return () 82 | 83 | split_shape = shape_string.split() 84 | 85 | if not split_shape: 86 | msg = f"Invalid shape {shape_string=}" 87 | raise SyntaxError(msg) 88 | 89 | # Process shape specification, looking for multiaxis modifiers 90 | processed_shapes: list[_parser.DLTypeDimensionExpression] = [] 91 | modifiers: dict[int, _parser.DLTypeModifier | None] = {} 92 | 93 | for i, dim_str in enumerate(split_shape): 94 | modifiers[i] = None 95 | for modifier in _parser.DLTypeModifier: 96 | if dim_str.startswith(modifier.value): 97 | modifiers[i] = modifier 98 | break 99 | 100 | this_dimension_modifier = modifiers[i] 101 | if this_dimension_modifier in { 102 | _parser.DLTypeModifier.NAMED_MULTIAXIS, 103 | _parser.DLTypeModifier.ANONYMOUS_MULTIAXIS, 104 | }: 105 | if self.multiaxis_index is not None: 106 | msg = f"Multiple multiaxis modifiers not allowed in {shape_string=}" 107 | raise SyntaxError(msg) 108 | 109 | self.multiaxis_index = i 110 | self.multiaxis_name = dim_str[len(this_dimension_modifier.value) :] 111 | self.anonymous_multiaxis = ( 112 | this_dimension_modifier == _parser.DLTypeModifier.ANONYMOUS_MULTIAXIS 113 | ) 114 | 115 | processed_shapes.append(_parser.expression_from_string(dim_str)) 116 | 117 | return tuple(processed_shapes) 118 | 119 | @classmethod 120 | def __class_getitem__(cls, shape_string: str | None | _symbolic_expressions.Shape) -> TensorTypeBase: 121 | """Get the type of the tensor.""" 122 | return cls(shape_string if isinstance(shape_string, str | None) else str(shape_string)) 123 | 124 | def __get_pydantic_core_schema__( 125 | self, 126 | source_type: type, 127 | handler: GetCoreSchemaHandler, 128 | ) -> core_schema.CoreSchema: 129 | """Get the Pydantic core schema for this type.""" 130 | 131 | def validate_tensor( 132 | tensor: _dtypes.DLtypeTensorT, 133 | info: ValidationInfo, 134 | ) -> _dtypes.DLtypeTensorT: 135 | """Validate the tensor.""" 136 | __tracebackhide__ = not _constants.DEBUG_MODE 137 | self.check(tensor, info.field_name or "anonymous") 138 | 139 | if _constants.PYDANTIC_INFO_KEY not in info.data: 140 | info.data[_constants.PYDANTIC_INFO_KEY] = _dltype_context.DLTypeContext() 141 | 142 | dl_context = typing.cast( 143 | "_dltype_context.DLTypeContext", 144 | info.data[_constants.PYDANTIC_INFO_KEY], 145 | ) 146 | dl_context.add(info.field_name or "_unknown_", (tensor,), (self,)) 147 | dl_context.assert_context() 148 | 149 | return tensor 150 | 151 | if _deps.is_numpy_available() and 
typing.get_origin(source_type) is np.ndarray: # pyright: ignore[reportPossiblyUnboundVariable] 152 | dtypes = _resolve_numpy_dtype(source_type) 153 | if self.DTYPES and any(dtype not in self.DTYPES for dtype in dtypes): 154 | raise _errors.DLTypeDtypeError( 155 | tensor_name=handler.field_name, 156 | expected=self.DTYPES, 157 | received=dtypes, 158 | ) 159 | # numpy arrays don't implement isinstance() because the type is actually a 160 | # parameterized generic alias and not a concrete type. We need to check the origin instead. 161 | # This is a bit of a hack, but we still get the correct type hint in 162 | # the end because we check against the dtype of the tensor first. 163 | source_type = np.ndarray # pyright: ignore[reportPossiblyUnboundVariable] 164 | 165 | return core_schema.with_info_after_validator_function( 166 | validate_tensor, 167 | schema=core_schema.is_instance_schema(source_type), 168 | field_name=handler.field_name, 169 | ) 170 | 171 | def check( 172 | self, 173 | tensor: _dtypes.DLtypeTensorT, 174 | tensor_name: str = "anonymous", 175 | ) -> None: 176 | """Check if the tensor matches this type.""" 177 | # Basic validation for multi-axis dimensions 178 | __tracebackhide__ = not _constants.DEBUG_MODE 179 | if self.multiaxis_index is not None: 180 | # Min required dimensions = expected shape length + extra dimensions - 1 (the multi-axis placeholder) 181 | min_required_dims = len(self.expected_shape) - 1 182 | if len(tensor.shape) < min_required_dims: 183 | raise _errors.DLTypeNDimsError( 184 | expected=min_required_dims, 185 | actual=tensor.ndim, 186 | tensor_name=tensor_name, 187 | ) 188 | 189 | # Standard case: exact dimension count match 190 | elif len(tensor.shape) != len(self.expected_shape): 191 | raise _errors.DLTypeNDimsError( 192 | expected=len(self.expected_shape), 193 | actual=tensor.ndim, 194 | tensor_name=tensor_name, 195 | ) 196 | 197 | if self.DTYPES and tensor.dtype not in self.DTYPES: 198 | raise _errors.DLTypeDtypeError( 199 | expected=self.DTYPES, 200 | received={tensor.dtype}, 201 | tensor_name=tensor_name, 202 | ) 203 | 204 | for idx, dim in self._literal_dims: 205 | # Adjust index if multiaxis exists and is before this dimension 206 | adjusted_idx = idx 207 | if self.multiaxis_index is not None and idx > self.multiaxis_index: 208 | # Adjust by the difference between actual and expected dimensions 209 | adjusted_idx += len(tensor.shape) - len(self.expected_shape) 210 | 211 | if tensor.shape[adjusted_idx] != dim: 212 | raise _errors.DLTypeShapeError( 213 | tensor_name=tensor_name, 214 | index=adjusted_idx, 215 | expected_shape=dim, 216 | actual=tensor.shape[adjusted_idx], 217 | ) 218 | -------------------------------------------------------------------------------- /dltype/_lib/_dltype_context.py: -------------------------------------------------------------------------------- 1 | """A module to assist with using Annotated[torch.Tensor] in type hints.""" 2 | 3 | from __future__ import annotations 4 | 5 | import logging 6 | import time 7 | import warnings 8 | from collections import deque 9 | from typing import Any, Final, NamedTuple, TypeAlias, cast 10 | 11 | from dltype._lib import _constants, _dtypes, _errors, _parser, _tensor_type_base 12 | 13 | _logger: Final = logging.getLogger(__name__) 14 | 15 | EvaluatedDimensionT: TypeAlias = dict[str, int] 16 | 17 | 18 | def _maybe_warn_runtime(runtime_ns: int) -> bool: 19 | return runtime_ns > _constants.MAX_ACCEPTABLE_EVALUATION_TIME_NS 20 | 21 | 22 | class _ConcreteType(NamedTuple): 23 | """A class containing 
a tensor name, a tensor value, and its type.""" 24 | 25 | arg_index: int 26 | tensor_arg_name_orig: str 27 | tensor: _dtypes.DLtypeTensorT 28 | dltype_annotation: _tensor_type_base.TensorTypeBase 29 | 30 | @property 31 | def tensor_arg_name(self) -> str: 32 | return ( 33 | f"{self.tensor_arg_name_orig}[{self.arg_index}]" 34 | if self.arg_index > 0 35 | else self.tensor_arg_name_orig 36 | ) 37 | 38 | def get_expected_shape( 39 | self, 40 | tensor: _dtypes.DLtypeTensorT, 41 | ) -> tuple[_parser.DLTypeDimensionExpression, ...]: 42 | """ 43 | Get the expected shape of the tensor. 44 | 45 | We handle multi-axis dimensions by replacing the multi-axis placeholder with the actual shape. 46 | """ 47 | expected_shape = list(self.dltype_annotation.expected_shape) 48 | 49 | if self.dltype_annotation.multiaxis_index is not None: 50 | _logger.debug( 51 | "Replacing multiaxis dimension %r with actual shape %r", 52 | self.dltype_annotation.multiaxis_index, 53 | tensor.shape, 54 | ) 55 | # for multiaxis dimensions, we replace the values in the 56 | # expected shape with the actual shape for every 57 | # dimension that is not the multiaxis dimension 58 | actual_shape = tensor.shape 59 | 60 | multi_axis_offset = len(actual_shape) - len(expected_shape) + 1 61 | 62 | expected_shape.pop(self.dltype_annotation.multiaxis_index) 63 | for i in range(multi_axis_offset): 64 | expected_shape.insert( 65 | self.dltype_annotation.multiaxis_index + i, 66 | _parser.DLTypeDimensionExpression.from_multiaxis_literal( 67 | f"{self.dltype_annotation.multiaxis_name}[{i}]", 68 | actual_shape[self.dltype_annotation.multiaxis_index + i], 69 | is_anonymous=self.dltype_annotation.anonymous_multiaxis, 70 | ), 71 | ) 72 | 73 | return tuple(expected_shape) 74 | 75 | 76 | class DLTypeContext: 77 | """ 78 | A class representing the current context for type hints. 79 | 80 | Keeps track of a simple mapping of names to expected shapes and types. 81 | 82 | This context can be evaluated at any time to check if the current actual shapes and types match expected. 83 | 84 | We evaluate in first-come-first-correct manner where the first tensor for a name is considered correct. 85 | """ 86 | 87 | def __init__(self) -> None: 88 | """Create a new DLTypeContext.""" 89 | self._hinted_tensors: deque[_ConcreteType] = deque() 90 | # mapping of dimension -> shape 91 | self.tensor_shape_map: EvaluatedDimensionT = {} 92 | # mapping of tensor name -> tensor type, used to check for duplicates 93 | self.registered_tensor_dtypes: dict[str, _dtypes.DLtypeDtypeT] = {} 94 | 95 | def add( 96 | self, 97 | name: str, 98 | tensor_values: tuple[Any, ...], 99 | dltype_annotation_tup: tuple[_tensor_type_base.TensorTypeBase | None, ...] 
| None, 100 | ) -> None: 101 | """Add a tensor to the context.""" 102 | if dltype_annotation_tup is None: 103 | return 104 | 105 | for idx, (dltype_annotation, tensor) in enumerate( 106 | zip( 107 | dltype_annotation_tup, 108 | tensor_values, 109 | strict=True, 110 | ), 111 | ): 112 | if dltype_annotation is None: 113 | continue 114 | if dltype_annotation.optional and tensor is None: 115 | # skip optional tensors 116 | return 117 | if not any(isinstance(tensor, T) for T in _dtypes.SUPPORTED_TENSOR_TYPES): 118 | raise _errors.DLTypeUnsupportedTensorTypeError( 119 | actual_type=cast("type[Any]", type(tensor)), 120 | ) 121 | self._hinted_tensors.append(_ConcreteType(idx, name, tensor, dltype_annotation)) 122 | 123 | def assert_context(self) -> None: 124 | """Considering the current context, check if all tensors match their expected types.""" 125 | __tracebackhide__ = not _constants.DEBUG_MODE 126 | 127 | start_t = time.perf_counter_ns() 128 | 129 | try: 130 | while self._hinted_tensors: 131 | tensor_context = self._hinted_tensors.popleft() 132 | # first check if the tensor could possibly have the right shape 133 | tensor_context.dltype_annotation.check( 134 | tensor_context.tensor, 135 | tensor_name=tensor_context.tensor_arg_name, 136 | ) 137 | 138 | if tensor_context.tensor_arg_name in self.registered_tensor_dtypes: 139 | raise _errors.DLTypeDuplicateError( 140 | tensor_name=tensor_context.tensor_arg_name, 141 | ) 142 | 143 | self.registered_tensor_dtypes[tensor_context.tensor_arg_name] = tensor_context.tensor.dtype 144 | expected_shape = tensor_context.get_expected_shape( 145 | tensor_context.tensor, 146 | ) 147 | self._assert_tensor_shape( 148 | tensor_context.tensor_arg_name, 149 | expected_shape, 150 | tensor_context.tensor, 151 | ) 152 | 153 | finally: 154 | end_t = time.perf_counter_ns() 155 | runtime_ns = end_t - start_t 156 | _logger.debug("Context evaluation took %d ns", runtime_ns) 157 | if _maybe_warn_runtime(runtime_ns): 158 | max_ms = _constants.MAX_ACCEPTABLE_EVALUATION_TIME_NS / 1e6 159 | warnings.warn( 160 | f"Type checking took longer than expected {(runtime_ns) / 1e6:.2f}ms > {max_ms:.2f}ms", 161 | UserWarning, 162 | stacklevel=2, 163 | ) 164 | 165 | def _assert_tensor_shape( 166 | self, 167 | tensor_arg_name: str, 168 | expected_shape: tuple[_parser.DLTypeDimensionExpression, ...], 169 | tensor: _dtypes.DLtypeTensorT, 170 | ) -> None: 171 | """Check if the tensor shape matches the expected shape.""" 172 | __tracebackhide__ = not _constants.DEBUG_MODE 173 | actual_shape = tuple(tensor.shape) 174 | 175 | for dim_idx, dimension_expression in enumerate(expected_shape): 176 | if dimension_expression.is_anonymous: 177 | # we don't need to check anonymous dimensions 178 | continue 179 | 180 | if ( 181 | dimension_expression.is_literal 182 | and dimension_expression.identifier not in self.tensor_shape_map 183 | ): 184 | # handled by the check method above 185 | _logger.debug( 186 | "Skipping literal dimension %r (%s)", 187 | dimension_expression, 188 | self.tensor_shape_map, 189 | ) 190 | continue 191 | 192 | if ( 193 | dimension_expression.is_identifier 194 | and dimension_expression.identifier not in self.tensor_shape_map 195 | ): 196 | _logger.debug( 197 | "establishing %r with %r", 198 | dimension_expression.identifier, 199 | actual_shape[dim_idx], 200 | ) 201 | self.tensor_shape_map[dimension_expression.identifier] = actual_shape[dim_idx] 202 | continue 203 | 204 | _logger.debug( 205 | "Checking dimension %r with scope=%s", 206 | dimension_expression, 207 | 
self.tensor_shape_map, 208 | ) 209 | 210 | try: 211 | expected_result = dimension_expression.evaluate(self.tensor_shape_map) 212 | except KeyError as e: 213 | missing_ref = e.args[0] 214 | raise _errors.DLTypeInvalidReferenceError( 215 | tensor_name=tensor_arg_name, 216 | missing_ref=missing_ref, 217 | current_context=self.tensor_shape_map, 218 | ) from e 219 | 220 | if expected_result != actual_shape[dim_idx]: 221 | raise _errors.DLTypeShapeError( 222 | tensor_name=tensor_arg_name, 223 | index=dim_idx, 224 | expected_shape=expected_result, 225 | actual=actual_shape[dim_idx], 226 | ) 227 | 228 | if dimension_expression.identifier not in self.tensor_shape_map: 229 | self.tensor_shape_map[dimension_expression.identifier] = actual_shape[dim_idx] 230 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 Stack AV Co. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | """Benchmark dltype vs. beartype vs. manual checking vs. 
baseline.""" 2 | 3 | from contextlib import suppress 4 | from enum import Enum, auto 5 | from inspect import signature 6 | from typing import Annotated, Final, NamedTuple 7 | 8 | import torch 9 | from torch.utils.benchmark import Measurement, Timer 10 | 11 | import dltype 12 | 13 | 14 | class BenchmarkMode(str, Enum): 15 | """What conditions to apply to the benchmark arguments.""" 16 | 17 | correct = auto() 18 | incorrect_shape = auto() 19 | incorrect_datatype = auto() 20 | incorrect_shape_and_datatype = auto() 21 | 22 | 23 | class SetupTensors(NamedTuple): 24 | """Collection of tensors that will become the function arguments for benchmark code.""" 25 | 26 | tensor_a: torch.Tensor | None 27 | tensor_b: torch.Tensor | None 28 | tensor_c: torch.Tensor | None 29 | 30 | 31 | def setup_code(mode: BenchmarkMode) -> SetupTensors: 32 | """Set up tensors for the benchmark code.""" 33 | match mode: 34 | case BenchmarkMode.correct: 35 | return SetupTensors( 36 | tensor_a=torch.rand(8, 2, 3, 4), 37 | tensor_b=torch.rand(8, 2, 3, 4), 38 | tensor_c=torch.rand(8, 2, 3, 4), 39 | ) 40 | case BenchmarkMode.incorrect_shape: 41 | return SetupTensors( 42 | tensor_a=torch.rand(8, 2, 3, 4), 43 | tensor_b=torch.rand(7, 2, 3, 4), 44 | tensor_c=torch.rand(8, 2, 3), 45 | ) 46 | case BenchmarkMode.incorrect_datatype: 47 | return SetupTensors( 48 | tensor_a=torch.rand(8, 2, 3, 4).int(), 49 | tensor_b=torch.rand(8, 2, 3, 4).int(), 50 | tensor_c=torch.rand(8, 2, 3, 4).int(), 51 | ) 52 | case BenchmarkMode.incorrect_shape_and_datatype: 53 | return SetupTensors( 54 | tensor_a=torch.rand(8, 2, 3, 4).int(), 55 | tensor_b=None, 56 | tensor_c=torch.rand(8, 2, 3).int(), 57 | ) 58 | 59 | 60 | class BenchmarkParams(NamedTuple): 61 | """Parameters for a benchmark run.""" 62 | 63 | mode: BenchmarkMode 64 | function_name: str 65 | function_args: tuple[str, ...] 
| None 66 | add_decorator: bool 67 | expected_error: type[Exception] | None 68 | 69 | 70 | class BenchmarkResult(NamedTuple): 71 | """Result of a benchmark run.""" 72 | 73 | params: BenchmarkParams 74 | measurement: Measurement 75 | 76 | 77 | class BenchmarkFunc: 78 | """A dltype benchmark function taking params and returning a result of that benchmark when called.""" 79 | 80 | def __init__(self, params: BenchmarkParams) -> None: 81 | """Create a new benchmark function.""" 82 | suppressed_prefix = ( 83 | f"with {suppress.__name__}({params.expected_error.__name__}): " if params.expected_error else "" 84 | ) 85 | tensor_args = ", ".join(SetupTensors._fields) 86 | maybe_decorated_function = ( 87 | f"dltype.dltyped()({params.function_name})" if params.add_decorator else f"{params.function_name}" 88 | ) 89 | bench = ( 90 | f"{maybe_decorated_function}({', '.join(params.function_args) if params.function_args else ''})" 91 | ) 92 | 93 | self._timer = Timer( 94 | setup=f"{tensor_args} = setup_code(BenchmarkMode.{params.mode.name})", 95 | stmt=f"{suppressed_prefix}{bench}", 96 | globals=globals() 97 | | ({params.expected_error.__name__: params.expected_error} if params.expected_error else {}), 98 | ) 99 | self._params = params 100 | 101 | def __call__(self) -> BenchmarkResult: 102 | """Run the benchmark and return the result.""" 103 | print(f"running bench={self._params.mode=} {self._params.function_name=}") # noqa: T201 104 | return BenchmarkResult(self._params, self._timer.adaptive_autorange()) 105 | 106 | 107 | def baseline( 108 | tensor_a: torch.Tensor, 109 | tensor_b: torch.Tensor, 110 | tensor_c: torch.Tensor, 111 | ) -> torch.Tensor: 112 | """A function that takes a tensor and returns a tensor.""" 113 | return (tensor_a * tensor_b + tensor_c).permute(2, 3, 0, 1) 114 | 115 | 116 | def dltype_function( 117 | tensor_a: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]], 118 | tensor_b: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]], 119 | tensor_c: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]], 120 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["h w b c"]]: 121 | """A function that takes a tensor and returns a tensor.""" 122 | return (tensor_a * tensor_b + tensor_c).permute(2, 3, 0, 1) 123 | 124 | 125 | @dltype.dltyped() 126 | def dltype_decorated( 127 | tensor_a: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]], 128 | tensor_b: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]], 129 | tensor_c: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]], 130 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["h w b c"]]: 131 | """A function that takes a tensor and returns a tensor.""" 132 | return (tensor_a * tensor_b + tensor_c).permute(2, 3, 0, 1) 133 | 134 | 135 | def manual_shape_check( 136 | tensor_a: torch.Tensor, 137 | tensor_b: torch.Tensor, 138 | tensor_c: torch.Tensor, 139 | ) -> torch.Tensor: 140 | """A function that takes a tensor and returns a tensor.""" 141 | if not all( 142 | isinstance(tensor, torch.Tensor) # pyright: ignore[reportUnnecessaryIsInstance] 143 | for tensor in (tensor_a, tensor_b, tensor_c) 144 | ): 145 | msg = "Tensors must have type=torch.Tensor." 146 | raise TypeError(msg) 147 | shapes = (tensor_a.shape, tensor_b.shape, tensor_c.shape) 148 | if not all(tensor.dtype == torch.float32 for tensor in (tensor_a, tensor_b, tensor_c)): 149 | msg = "Tensors must have dtype=torch.float32." 150 | raise TypeError(msg) 151 | if {len(shape) for shape in shapes} != {4}: 152 | msg = "Shapes must have the same number of dimensions=4." 
153 | raise TypeError(msg) 154 | if all(shape == shapes[0] for shape in shapes): 155 | return (tensor_a * tensor_b + tensor_c).permute(2, 3, 0, 1) 156 | msg = "Shapes must be equal." 157 | raise TypeError(msg) 158 | 159 | 160 | # a dltyped function with an expression equivalent to the jaxtyped_with_expression function 161 | @dltype.dltyped() 162 | def dltyped_with_expression( 163 | tensor_a: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]], 164 | tensor_b: Annotated[torch.Tensor, dltype.FloatTensor["b1 c1 h1 w1"]], 165 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["b max(c-1,0) w+h1"]]: 166 | """A function that takes a tensor and returns a tensor.""" 167 | # Use c-1 dimension as specified in return type 168 | reduced_c = max(tensor_a.shape[1] - 1, 0) 169 | 170 | w_plus_h = tensor_a.shape[3] + tensor_b.shape[2] 171 | 172 | return torch.zeros( 173 | tensor_a.shape[0], # b 174 | reduced_c, # c-1 175 | w_plus_h, # w+h1 176 | ) 177 | 178 | 179 | def expression_baseline( 180 | tensor_a: torch.Tensor, 181 | tensor_b: torch.Tensor, 182 | ) -> torch.Tensor: 183 | """A function that takes a tensor and returns a tensor.""" 184 | # Use c-1 dimension as specified in return type 185 | reduced_c = max(tensor_a.shape[1] - 1, 0) 186 | 187 | w_plus_h = tensor_a.shape[3] + tensor_b.shape[2] 188 | 189 | return torch.zeros( 190 | tensor_a.shape[0], # b 191 | reduced_c, # c-1 192 | w_plus_h, # w+h1 193 | ) 194 | 195 | 196 | def anonymous_axis_baseline( 197 | tensor_a: torch.Tensor, 198 | ) -> torch.Tensor: 199 | """A function that takes a tensor and cats it.""" 200 | return torch.stack([tensor_a, tensor_a], dim=1) 201 | 202 | 203 | @dltype.dltyped() 204 | def dltyped_anonymous_axis( 205 | tensor_a: Annotated[torch.Tensor, dltype.FloatTensor["*batch h w"]], 206 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["*batch 2 h w"]]: 207 | """A function that takes a tensor and adds a second dimension to it.""" 208 | return torch.stack([tensor_a, tensor_a], dim=1) 209 | 210 | 211 | if __name__ == "__main__": 212 | all_functions: Final = [ 213 | baseline, 214 | manual_shape_check, 215 | dltype_function, 216 | dltype_decorated, 217 | expression_baseline, 218 | dltyped_with_expression, 219 | anonymous_axis_baseline, 220 | dltyped_anonymous_axis, 221 | ] 222 | needs_decorator = frozenset({dltype_function}) 223 | error_override = { 224 | manual_shape_check.__name__: { 225 | BenchmarkMode.incorrect_shape: TypeError, 226 | BenchmarkMode.incorrect_datatype: TypeError, 227 | BenchmarkMode.incorrect_shape_and_datatype: TypeError, 228 | }, 229 | baseline.__name__: { 230 | BenchmarkMode.incorrect_shape: RuntimeError, 231 | BenchmarkMode.incorrect_datatype: RuntimeError, 232 | BenchmarkMode.incorrect_shape_and_datatype: TypeError, 233 | }, 234 | expression_baseline.__name__: { 235 | BenchmarkMode.incorrect_shape_and_datatype: AttributeError, 236 | }, 237 | anonymous_axis_baseline.__name__: { 238 | BenchmarkMode.incorrect_shape_and_datatype: TypeError, 239 | }, 240 | } 241 | 242 | all_benchmarks: dict[BenchmarkMode, dict[str, BenchmarkFunc]] = {} 243 | 244 | for mode in BenchmarkMode: 245 | for func in all_functions: 246 | expected_error = None 247 | match mode: 248 | case BenchmarkMode.correct: 249 | expected_error = None 250 | case BenchmarkMode.incorrect_shape: 251 | expected_error = dltype.DLTypeShapeError 252 | case BenchmarkMode.incorrect_datatype: 253 | expected_error = dltype.DLTypeDtypeError 254 | case BenchmarkMode.incorrect_shape_and_datatype: 255 | expected_error = dltype.DLTypeError 256 | 257 | if 
func.__name__ in error_override:
258 |                 expected_error = error_override[func.__name__].get(mode, expected_error)
259 | 
260 |             all_benchmarks.setdefault(mode, {})[func.__name__] = BenchmarkFunc(
261 |                 BenchmarkParams(
262 |                     mode=mode,
263 |                     function_name=func.__name__,
264 |                     function_args=tuple(map(str, signature(func).parameters.keys())),
265 |                     add_decorator=func in needs_decorator,
266 |                     expected_error=expected_error,
267 |                 ),
268 |             )
269 | 
270 |     summary_results: dict[BenchmarkMode, dict[str, BenchmarkResult]] = {}
271 | 
272 |     for mode, mode_runs in all_benchmarks.items():
273 |         for func_name, benchmark in mode_runs.items():
274 |             summary_results.setdefault(mode, {})[func_name] = benchmark()
275 | 
276 |     for results in summary_results.values():
277 |         for result in results.values():
278 |             print("-" * 10) # noqa: T201
279 |             print(f"Function: {result.params.function_name} Setup: {result.params.mode}") # noqa: T201
280 |             print(result.measurement) # noqa: T201
281 |             print("-" * 10) # noqa: T201
282 | 
283 |     max_func_length = max(len(func.__name__) + 3 for func in all_functions)
284 |     max_mode_length = max(len(mode.name) + 3 for mode in BenchmarkMode)
285 | 
286 |     print(f"{'Benchmark':<{max_func_length}}", end="") # noqa: T201
287 |     for mode in BenchmarkMode:
288 |         print(f"{mode.name:>{max_mode_length}} ", end="") # noqa: T201
289 | 
290 |     for func in all_functions:
291 |         print() # noqa: T201
292 |         print(f"{func.__name__:<{max_func_length}}", end="") # noqa: T201
293 |         for mode in BenchmarkMode:
294 |             result = summary_results[mode][func.__name__]
295 |             value = f"{result.measurement.mean * 1e6:.2f} uS"
296 |             print(f"{value:>{max_mode_length}}", end="") # noqa: T201
297 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DL Type (Deep Learning Type Library)
2 | 
3 | This typing library is intended to replace jaxtyping for runtime type checking of torch tensors and numpy arrays.
4 | 
5 | In particular, we support two capabilities that beartype/jaxtyping do not:
6 | 
7 | 1. Support for torch.jit.script/torch.compile/torch.jit.trace
8 | 2. Pydantic model type annotations for torch tensors.
9 | 
10 | ## Features
11 | 
12 | - Shape and Type Validation: Validate tensor shapes and types at runtime with symbolic dimension support.
13 | - Pydantic Integration: First-class support for tensor validation in Pydantic models.
14 | - Context-Aware Validation: Ensures consistency across multiple tensors in the same context.
15 | - ONNX/torch.compile Compatible: Works seamlessly with model export and compilation workflows.
16 | - Symbolic Dimensions: Support for named dimensions that enforce consistency.
17 | 
18 | ## Installation
19 | 
20 | Install dltype through pip:
21 | ```bash
22 | pip3 install dltype
23 | ```
24 | 
25 | > [!NOTE]
26 | > dltype does not depend explicitly on torch or numpy, but you must have at least one of them installed at import time; otherwise the import will fail.
27 | 
28 | ## Usage
29 | 
30 | Type hints are evaluated in a context in source-code order, so any references to dimension symbols must exist before an expression is evaluated.
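
For example, in the illustrative sketch below (the function and dimension names are arbitrary and not part of the library), `num_points` is bound by the first argument, so the expressions in the second argument and in the return annotation can refer to it; swapping the argument order would raise an invalid-reference error because `num_points` would not yet be bound when `num_points*2` is evaluated.

```python
from typing import Annotated

import torch

from dltype import FloatTensor, dltyped


@dltyped()
def duplicate_points(
    points: Annotated[torch.Tensor, FloatTensor["num_points 3"]],
    extra: Annotated[torch.Tensor, FloatTensor["num_points*2 3"]],
) -> Annotated[torch.Tensor, FloatTensor["num_points*3 3"]]:
    # shapes: (N, 3) and (2N, 3) concatenate to (3N, 3)
    return torch.cat([points, extra], dim=0)
```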
31 | 
32 | ## Supported syntax
33 | 
34 | DL Type supports four types of dimension specifications:
35 | 
36 | ### Scalars
37 | 
38 | Single-element tensors with no shape:
39 | 
40 | ```python
41 | IntTensor[None] # An integer tensor with a single value and no axes
42 | ```
43 | 
44 | ### Literal Dimensions
45 | 
46 | Simple integer dimensions with fixed sizes:
47 | 
48 | ```python
49 | FloatTensor["3 5"] # A tensor with shape (3, 5)
50 | FloatTensor["batch channels=3 height width"] # identifiers set to dimensions for documentation
51 | ```
52 | 
53 | ### Expressions
54 | 
55 | Mathematical expressions combining literals and symbols.
56 | 
57 | ```python
58 | FloatTensor["batch channels*2"] # If channels=64, shape would be (batch, 128)
59 | FloatTensor["batch-1"] # One less than the batch dimension
60 | FloatTensor["features/2"] # Half the features dimension
61 | ```
62 | 
63 | #### Supported Operators and Functions
64 | 
65 | > [!NOTE]
66 | > Expressions _must_ never have spaces.
67 | 
68 | ##### Operators
69 | 
70 | - `+` Addition
71 | - `-` Subtraction
72 | - `*` Multiplication
73 | - `/` Integer division
74 | - `^` Exponentiation
75 | 
76 | ##### Functions
77 | 
78 | - `min(a,b)` Minimum of two expressions
79 | - `max(a,b)` Maximum of two expressions
80 | 
81 | > [!WARNING]
82 | > While nested function calls like `min(max(a,b),c)` are supported,
83 | > combining function calls with other operators in the same expression
84 | > (e.g., `min(1,batch)+max(2,channels)`) is not supported to simplify parsing.
85 | 
86 | ### Symbolic Dimensions
87 | 
88 | 
89 | Named dimensions that ensure consistency across tensors:
90 | 
91 | ```python
92 | FloatTensor["batch channels"] # A tensor with two dimensions
93 | ```
94 | 
95 | ### Multi Dimensions
96 | 
97 | Named or anonymous dimension identifiers that may cover zero or more dimensions in the actual tensors.
98 | Only one multi-dimension identifier is allowed per type hint.
99 | 
100 | ```python
101 | FloatTensor["... channels h w"] # anonymous dimension will not be matched across tensors
102 | DoubleTensor["batch *channels features"] # named dimension which can be matched across tensors
103 | ```
104 | 
105 | ### Statically Defined Shapes
106 | 
107 | ```python
108 | Shape[AnonymousAxis("batch"), ConstantAxis("rgb", 3), VariableAxis("imgh"), VariableAxis("imgw")]
109 | Shape[..., VariableAxis("N"), 4]
110 | ```
111 | 
112 | ## Argument and return typing
113 | 
114 | ```python
115 | from typing import Annotated
116 | import torch
117 | from dltype import FloatTensor, dltyped
118 | 
119 | @dltyped()
120 | def add_tensors(
121 |     x: Annotated[torch.Tensor, FloatTensor["batch features"]],
122 |     y: Annotated[torch.Tensor, FloatTensor["batch features"]]
123 | ) -> Annotated[torch.Tensor, FloatTensor["batch features"]]:
124 |     return x + y
125 | ```
126 | 
127 | ## Pydantic model typing
128 | 
129 | ```python
130 | from typing import Annotated
131 | from pydantic import BaseModel
132 | import torch
133 | from dltype import FloatTensor, IntTensor
134 | 
135 | class ImageBatch(BaseModel):
136 |     # note the parentheses instead of brackets for pydantic models
137 |     images: Annotated[torch.Tensor, FloatTensor("batch 3 height width")]
138 |     labels: Annotated[torch.Tensor, IntTensor("batch")]
139 | 
140 | # All tensor validations happen automatically
141 | # Shape consistency is enforced across fields
142 | ```
143 | 
144 | ## NamedTuple typing
145 | 
146 | We expose `@dltyped_namedtuple()` for NamedTuples.
147 | `NamedTuples` are validated upon construction; beware that assignments or manipulations after construction are unchecked.
148 | 
149 | ```python
150 | @dltype.dltyped_namedtuple()
151 | class MyNamedTuple(NamedTuple):
152 |     tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]]
153 |     mask: Annotated[torch.Tensor, dltype.IntTensor["b h w"]]
154 |     other: int
155 | ```
156 | 
157 | ## @dataclass support
158 | 
159 | Similar to `NamedTuples` and pydantic `BaseModels`, `@dataclasses` may be decorated and validated.
160 | The normal caveats apply in that we only validate at construction and not on assignment.
161 | Therefore, we recommend using frozen `@dataclasses` when possible.
162 | 
163 | ```python
164 | from typing import Annotated
165 | import torch
166 | from dltype import FloatTensor, IntTensor, dltyped_dataclass
167 | 
168 | # order is important; we raise an error if dltyped_dataclass is applied below dataclass
169 | # this is because the @dataclass decorator applies a bunch of source code modifications that we don't want to have to hack around
170 | @dltyped_dataclass()
171 | @dataclass(frozen=True, slots=True)
172 | class MyDataclass:
173 |     images: Annotated[torch.Tensor, FloatTensor["batch 3 height width"]]
174 |     labels: Annotated[torch.Tensor, IntTensor["batch"]]
175 | ```
176 | 
177 | ## Optionals
178 | 
179 | We have no support for general unions of types, to prevent confusing behavior when using runtime shape checking.
180 | DLType only supports optional types (i.e. `Type | None`).
181 | To annotate a tensor as being optional, see the example below.
182 | 
183 | ```python
184 | @dltype.dltyped()
185 | def optional_tensor_func(tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] | None) -> torch.Tensor:
186 |     if tensor is None:
187 |         return torch.zeros(1, 3, 5, 5)
188 |     return tensor
189 | ```
190 | 
191 | ## Tuple returns
192 | 
193 | Tuples are a common way of passing multiple values to and returning multiple values from functions.
194 | DLType supports annotating tuples passed to and returned from functions.
195 | Tuples can mix annotated tensors with other types of objects.
196 | 
197 | ```python
198 | def returned_tuple_func() -> tuple[Annotated[torch.Tensor, dltype.UInt8Tensor["b rgb=3 h w"]], int]:
199 |     return torch.zeros(1, 3, 1080, 1920, dtype=torch.uint8), 8
200 | ```
201 | 
202 | ## Numpy and Tensor Mixing
203 | 
204 | ```python
205 | from typing import Annotated
206 | import torch
207 | import numpy as np
208 | from dltype import FloatTensor, dltyped
209 | 
210 | @dltyped()
211 | def transform_tensors(
212 |     points: Annotated[np.ndarray, FloatTensor["N 3"]],
213 |     transform: Annotated[torch.Tensor, FloatTensor["3 3"]]
214 | ) -> Annotated[torch.Tensor, FloatTensor["N 3"]]:
215 |     return torch.from_numpy(points) @ transform
216 | ```
217 | 
218 | ## Providing External Scope
219 | 
220 | There are situations in which a runtime variable may influence the expected shape of a tensor.
221 | To provide external scope to be used by dltype, you may implement the `DLTypeScopeProvider` protocol.
222 | There are two flavors of this, one for methods and one for free functions; both are shown below.
223 | Using external scope providers for free functions is not an encouraged use case, as it encourages keeping global state.
224 | Additionally, free functions are generally stateless, but this makes the type checking logic stateful and thus
225 | makes the execution of the function impure.
226 | We support this because there are certain scenarios where loading a configuration from a file and providing it as an expected dimension for some typed function may be useful and necessary.
227 | 
228 | ```python
229 | 
230 | # Using `self` as the DLTypeScopeProvider in an object (this is the primary use case)
231 | class MyModule(nn.Module):
232 |     # ... some implementation details
233 |     def __init__(self, config: MyConfig) -> None:
234 |         self.cfg = config
235 | 
236 |     # the DLTypeScopeProvider protocol requires this function to be specified.
237 |     def get_dltype_scope(self) -> dict[str, int]:
238 |         """Return the DLType scope which is simply a dictionary of 'axis-name' -> dimension size."""
239 |         return {"in_channel": self.cfg.in_channel}
240 | 
241 |     # "self" is a literal, not a string -- pyright will yell at you if this is wrong.
242 |     # The first argument of the decorated function will be checked to obey the protocol before calling `get_dltype_scope`.
243 |     @dltyped("self")
244 |     def forward(
245 |         self,
246 |         tensor_1: Annotated[torch.Tensor, FloatTensor["batch num_voxel_features z y x"]],
247 |         # NOTE: in_channel comes from the external scope and is used in the expression below to evaluate the 'channels' expected dimension
248 |         tensor_2: Annotated[torch.Tensor, FloatTensor["batch channels=in_channel-num_voxel_features z y x"]]
249 |     ) -> torch.Tensor:
250 |         ...
251 | 
252 | ## Using a scope provider for a free function
253 | 
254 | class MyProvider:
255 |     def get_dltype_scope(self) -> dict[str, int]:
256 |         # load some_value from a config file in the constructor
257 |         # or fetch it from a singleton
258 |         return {
259 |             "dim1": self.some_value
260 |         }
261 | 
262 | @dltyped(provider=MyProvider())
263 | def free_function(tensor: FloatTensor["batch dim1"]) -> None:
264 |     ...  # implementation details, dim1 provided by the external scope
265 | ```
266 | 
267 | ## Supported Types
268 | 
269 | - `FloatTensor`: For any precision floating point tensor. Is a superset of the following:
270 |   - `Float16Tensor`: For any 16 bit floating point type. Is a superset of the following:
271 |     - `IEEE754HalfFloatTensor`: For 16 bit floating point types that comply with the IEEE 754 half-precision specification (notably, does not include `bfloat16`). For numpy tensors `Float16Tensor` is equal to `IEEE754HalfFloatTensor`. Use if you need to forbid usage of `bfloat16` for some reason. Otherwise prefer the `Float16Tensor` type for usage with mixed precision codebases.
272 |     - `BFloat16Tensor`: For 16 bit floating point tensors following the [`bfloat16` format](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format). Is not IEEE 754 compliant and is not supported by NumPy. Use if you need to write code that is `bfloat16` specific, otherwise prefer `Float16Tensor` for usage with a mixed precision instruction scope (such as `torch.amp`).
273 |   - `Float32Tensor`: For single precision 32 bit floats.
274 |   - `Float64Tensor`: For double precision 64 bit floats. Aliases to `DoubleTensor`.
275 |   - Note that `np.float128` and `np.longdouble` will be considered as `FloatTensors` BUT do not exist as standalone types to be used by `dltype`, i.e. there is no `Float128Tensor` type. These types are not supported by torch, and only supported by numpy on certain platforms, thus we only "support" them insofar as they are considered floating point types.
- `IntTensor`: For integer tensors of any precision. Is a superset of the following:
276 |   - `Int8Tensor`
277 |   - `Int16Tensor`
278 |   - `Int32Tensor`
279 |   - `Int64Tensor`
280 | - `BoolTensor`: For boolean tensors
281 | - `TensorTypeBase`: Base class for any tensor which does not enforce any specific datatype; feel free to add custom validation logic by overriding the `check` method.
282 | 
283 | ## Limitations
284 | 
285 | - In the current implementation, _every_ call will be checked, which may or may not be slow depending on how big the context is (it shouldn't be that slow).
286 | - Pydantic default values are not checked.
287 | - Only symbolic, literal, and expression dimension specifiers are allowed; the f-string syntax from `jaxtyping` is not supported.
288 | - Only torch tensors and numpy arrays are supported for now.
289 | - Static checking is not supported, only runtime checks, though some errors will be caught statically by construction.
290 | - DLType does not support checking inside unbounded container types (i.e. `list[TensorTypeBase]`) for performance reasons.
291 | - DLType does not support unions, but does support optionals.
292 | 
--------------------------------------------------------------------------------
/dltype/_lib/_parser.py:
--------------------------------------------------------------------------------
1 | """A module for parsing dimension shape expressions for dltype."""
2 | 
3 | from __future__ import annotations
4 | 
5 | import enum
6 | import itertools
7 | import logging
8 | import math
9 | import re
10 | from typing import Final
11 | 
12 | from typing_extensions import override
13 | 
14 | _logger: Final = logging.getLogger(__name__)
15 | 
16 | 
17 | class DLTypeSpecifier(enum.Enum):
18 |     """An enum representing a way to specify a name for a dimension expression or literal."""
19 | 
20 |     EQUALS = "="
21 | 
22 | 
23 | class DLTypeModifier(enum.Enum):
24 |     """An enum representing a modifier that can be applied to a dimension expression."""
25 | 
26 |     ANONYMOUS_MULTIAXIS = "..."
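    # "..." matches any number of axes without binding them to a name; "*name" (below)
    # binds the matched axes to `name` so their sizes can be checked for consistency
    # across tensors.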
27 | NAMED_MULTIAXIS = "*" 28 | 29 | 30 | class _DLTypeOperator(enum.Enum): 31 | """An enum representing a mathematical operator for a dimension expression.""" 32 | 33 | ADD = "+" 34 | SUB = "-" 35 | MUL = "*" 36 | EXP = "^" 37 | DIV = "/" 38 | MIN = "min" 39 | MAX = "max" 40 | ISQRT = "isqrt" 41 | 42 | def evaluate_unary(self, a: int) -> int: 43 | """Evaluate the unary operator.""" 44 | if self is _DLTypeOperator.ISQRT: 45 | return math.isqrt(a) 46 | raise NotImplementedError(self) 47 | 48 | def evaluate(self, a: int, b: int) -> int: # noqa: PLR0911 49 | """Evaluate the operator.""" 50 | if self is _DLTypeOperator.ADD: 51 | return a + b 52 | if self is _DLTypeOperator.SUB: 53 | return a - b 54 | if self is _DLTypeOperator.MUL: 55 | return a * b 56 | if self is _DLTypeOperator.EXP: 57 | return int(a**b) 58 | if self is _DLTypeOperator.DIV: 59 | return a // b 60 | if self is _DLTypeOperator.MIN: 61 | return min(a, b) 62 | if self is _DLTypeOperator.MAX: 63 | return max(a, b) 64 | raise NotImplementedError(self) 65 | 66 | 67 | _op_precedence: Final = { 68 | _DLTypeOperator.ADD: 1, 69 | _DLTypeOperator.SUB: 1, 70 | _DLTypeOperator.MUL: 2, 71 | _DLTypeOperator.DIV: 2, 72 | _DLTypeOperator.EXP: 3, 73 | _DLTypeOperator.MIN: 0, 74 | _DLTypeOperator.MAX: 0, 75 | _DLTypeOperator.ISQRT: 0, 76 | } 77 | 78 | _unary_operators: Final = frozenset({_DLTypeOperator.ISQRT}) 79 | _binary_operators: Final = frozenset({_DLTypeOperator.MIN, _DLTypeOperator.MAX}) 80 | _functional_operators: Final = frozenset(_unary_operators.union(_binary_operators)) 81 | _valid_operators: frozenset[str] = frozenset( 82 | {op.value for op in _DLTypeOperator if op not in _functional_operators}, 83 | ) 84 | _valid_modifiers: frozenset[str] = frozenset({mod.value for mod in DLTypeModifier}) 85 | 86 | INFIX_EXPRESSION_SPLIT_RX: Final = re.compile( 87 | f"({'|'.join(map(re.escape, _valid_operators))})", 88 | ) 89 | VALID_EXPRESSION_RX: Final = re.compile( 90 | f"^[a-zA-Z0-9_{''.join(map(re.escape, _valid_operators.union(_valid_modifiers)))}]+$", 91 | ) 92 | VALID_IDENTIFIER_RX: Final = re.compile(r"^[a-zA-Z][a-zA-Z0-9\_]*$") 93 | 94 | 95 | class DLTypeDimensionExpression: 96 | """A class representing a dimension that depends on other dimensions.""" 97 | 98 | def __init__( 99 | self, 100 | identifier: str, 101 | postfix_expression: list[str | int | _DLTypeOperator], 102 | *, 103 | is_multiaxis_literal: bool = False, 104 | is_anonymous: bool = False, 105 | ) -> None: 106 | """Create a new dimension expression.""" 107 | self.identifier = identifier 108 | self.parsed_expression = postfix_expression 109 | # multiaxis literals cannot be evaluated until 110 | # the actual shape is known, so we don't consider this to be a true literal 111 | # for the purposes of evaluating the expression 112 | self.is_literal = not is_multiaxis_literal and all( 113 | isinstance(token, int) for token in postfix_expression 114 | ) 115 | self.is_identifier = is_multiaxis_literal or (postfix_expression == [identifier]) 116 | # this is an expression if it's not a literal value, if it's 117 | # an identifier that points to another dimension, or if it's an 118 | # identifier that doesn't just point to itself 119 | self.is_expression = not self.is_literal and ( 120 | len(postfix_expression) > 1 or self.identifier not in postfix_expression 121 | ) 122 | self.is_multiaxis_literal = is_multiaxis_literal 123 | self.is_anonymous = is_anonymous 124 | _logger.debug( 125 | "Created new %s dimension expression %r", 126 | "multiaxis" if self.is_multiaxis_literal else "", 127 
| self, 128 | ) 129 | 130 | # ensure we don't have any self-referential expressions 131 | if self.is_expression and self.identifier in self.parsed_expression: 132 | msg = f"Self-referential expression {self=}" 133 | raise SyntaxError(msg) 134 | 135 | @override 136 | def __repr__(self) -> str: 137 | """Get a string representation of the dimension expression.""" 138 | if self.is_anonymous: 139 | return f"Anonymous<{self.identifier}>" 140 | if self.is_literal: 141 | return f"Literal<{self.identifier}={self.parsed_expression}>" 142 | return f"Identifier<{self.identifier}={self.parsed_expression}>" 143 | 144 | @classmethod 145 | def from_multiaxis_literal( 146 | cls, 147 | identifier: str, 148 | literal: int, 149 | *, 150 | is_anonymous: bool = False, 151 | ) -> DLTypeDimensionExpression: 152 | """ 153 | Create a new dimension expression from a multi-axis literal. 154 | 155 | This is a special case where the expression is a single literal that is repeated across all axes. 156 | Anonymous axes are a special case where the actual value of the literal is irrelevant. 157 | """ 158 | return cls( 159 | identifier, 160 | [literal], 161 | is_multiaxis_literal=True, 162 | is_anonymous=is_anonymous, 163 | ) 164 | 165 | def evaluate(self, scope: dict[str, int]) -> int: 166 | """Evaluate the expression.""" 167 | _logger.debug("Evaluating expression %s with scope %s", self, scope) 168 | stack: list[int] = [] 169 | 170 | if self.is_anonymous: 171 | msg = "Cannot evaluate an anonymous axis" 172 | raise ValueError(msg) 173 | 174 | if self.identifier in scope: 175 | # if the identifier is in the scope, we return the value directly 176 | # however if we're an anonymous axis, we don't want to 177 | # return the value directly as the prior scoped value is irrelevant 178 | return scope[self.identifier] 179 | 180 | for token in self.parsed_expression: 181 | if isinstance(token, int): 182 | # literal integer 183 | stack.append(token) 184 | elif isinstance(token, str): 185 | # intentionally allow KeyError to be raised if the identifier is not in the scope 186 | stack.append(scope[token]) 187 | elif isinstance(token, _DLTypeOperator): # pyright: ignore[reportUnnecessaryIsInstance] 188 | b = stack.pop() 189 | if token in _unary_operators: 190 | stack.append(token.evaluate_unary(b)) 191 | continue 192 | a = stack.pop() 193 | stack.append(token.evaluate(a, b)) 194 | else: 195 | msg = f"Invalid token {token=}" 196 | raise TypeError(msg) 197 | 198 | if len(stack) != 1: 199 | msg = f"Invalid stack {stack=}" 200 | raise ValueError(msg) 201 | 202 | _logger.debug("Evaluated expression %r to %r", self, stack[0]) 203 | return stack[0] 204 | 205 | 206 | def _postfix_from_infix(identifier: str, expression: str) -> DLTypeDimensionExpression: 207 | """ 208 | Extract a postfix expression from an infix expression. 209 | 210 | Notably, this function does not handle function calls, which are handled separately. 
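
    Illustrative example (a sketch, not an exhaustive specification): the infix string
    "a*2+b" is split into ["a", "*", "2", "+", "b"] and converted by the shunting-yard
    loop below into the postfix token list ["a", 2, _DLTypeOperator.MUL, "b",
    _DLTypeOperator.ADD], which DLTypeDimensionExpression.evaluate later reduces with a
    small stack machine.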
211 | """ 212 | _logger.debug("Parsing infix expression %r", expression) 213 | 214 | # this is a modified expression, so we need to handle it differently 215 | if expression == DLTypeModifier.ANONYMOUS_MULTIAXIS.value: 216 | return DLTypeDimensionExpression(expression, [], is_anonymous=True) 217 | if expression.startswith(DLTypeModifier.NAMED_MULTIAXIS.value): 218 | stripped_expression = expression[len(DLTypeModifier.NAMED_MULTIAXIS.value) :] 219 | if not VALID_IDENTIFIER_RX.match(stripped_expression): 220 | msg = f"Invalid identifier {stripped_expression=}" 221 | raise SyntaxError(msg) 222 | return DLTypeDimensionExpression(stripped_expression, [stripped_expression]) 223 | 224 | split_expression = INFIX_EXPRESSION_SPLIT_RX.split(expression) 225 | 226 | # Convert infix to postfix using shunting yard algorithm 227 | stack: list[str | _DLTypeOperator] = [] 228 | postfix: list[str | int | _DLTypeOperator] = [] 229 | for token in split_expression: 230 | if token.isdigit(): 231 | postfix.append(int(token)) 232 | elif token in _valid_operators: 233 | current_op = _DLTypeOperator(token) 234 | 235 | # Pop operators with higher or equal precedence 236 | while ( 237 | stack 238 | and isinstance(stack[-1], _DLTypeOperator) 239 | and _op_precedence.get(stack[-1], 0) >= _op_precedence.get(current_op, 0) 240 | ): 241 | postfix.append(stack.pop()) 242 | 243 | stack.append(current_op) 244 | elif VALID_IDENTIFIER_RX.match(token): 245 | # It's a variable name 246 | postfix.append(token) 247 | else: 248 | msg = f"Invalid expression {expression=}" 249 | raise SyntaxError(msg) 250 | 251 | # Pop any remaining operators 252 | while stack: 253 | postfix.append(stack.pop()) 254 | 255 | _logger.debug("Parsed infix expression %r to postfix %r", expression, postfix) 256 | return DLTypeDimensionExpression(identifier, postfix) 257 | 258 | 259 | def _maybe_parse_functional_expression( 260 | identifier: str, 261 | expression: str, 262 | function: _DLTypeOperator, 263 | ) -> DLTypeDimensionExpression | None: 264 | """ 265 | Parse a function-like expression such as min(a,b) or max(x,y). 266 | 267 | Args: 268 | identifier: The identifier for the expression (e.g. 
the name of the dimension) 269 | expression: The expression to parse 270 | function: The function operator (_DLTypeOperator.MIN or _DLTypeOperator.MAX) 271 | 272 | Returns: 273 | A parsed dimension expression if the expression is a valid function call, None otherwise 274 | 275 | """ 276 | if not expression.startswith(f"{function.value}("): 277 | return None 278 | 279 | # Find balanced closing parenthesis 280 | # Strip function name and opening parenthesis 281 | content = expression[len(function.value) + 1 :] 282 | 283 | # Remove closing parenthesis 284 | content = content[:-1] 285 | 286 | # Find the comma that separates arguments (accounting for nesting) 287 | depth = 0 288 | balanced_content: list[str] = [] 289 | current_span = "" 290 | 291 | for char in content: 292 | current_span += char 293 | if char == "(": 294 | depth += 1 295 | elif char == ")": 296 | depth -= 1 297 | if depth < 0: 298 | msg = f"Unbalanced parentheses in function expression: {expression}" 299 | raise SyntaxError(msg) 300 | elif char == "," and depth == 0: 301 | balanced_content.append(current_span[:-1]) 302 | current_span = "" 303 | balanced_content.append(current_span) 304 | 305 | if function in _binary_operators and len(balanced_content) != 2: # noqa: PLR2004 306 | msg = f"Function {function.value} requires 2 arguments, got {len(balanced_content)} in {expression=}" 307 | raise SyntaxError(msg) 308 | if function in _unary_operators and len(balanced_content) != 1: 309 | msg = f"Function {function.value} requires 1 argument, got {len(balanced_content)} in {expression=}" 310 | raise SyntaxError(msg) 311 | 312 | expressions = [expression_from_string(exp) for exp in balanced_content] 313 | 314 | # Build postfix expression: [arg1 tokens, arg2 tokens, function] 315 | return DLTypeDimensionExpression( 316 | identifier, 317 | [*list(itertools.chain(*[exp.parsed_expression for exp in expressions])), function], 318 | ) 319 | 320 | 321 | def expression_from_string(expression: str) -> DLTypeDimensionExpression: 322 | """ 323 | Parse a dimension expression from a string and return a parsed expression. 324 | 325 | Examples: 326 | >>> expression_from_string("a+b") 327 | Identifier 328 | 329 | >>> expression_from_string("min(a,b)") 330 | Identifier 331 | 332 | # literals 333 | >>> expression_from_string("10") 334 | Literal<10=[10]> 335 | 336 | >>> expression_from_string("...") 337 | Anonymous<...> 338 | 339 | >>> expression_from_string("a*10") 340 | Identifier 341 | 342 | >>> expression_from_string("a*10+b") 343 | Identifier 344 | 345 | Args: 346 | identifier: The identifier for the expression (e.g. the name of the dimension) 347 | expression: The expression to parse 348 | 349 | Returns: 350 | A parsed dimension expression. 
351 | 352 | """ 353 | if not expression: 354 | msg = f"Empty expression {expression=}" 355 | raise SyntaxError(msg) 356 | 357 | # split the expression into the identifier and the expression if it has a specifier 358 | identifier = expression 359 | if DLTypeSpecifier.EQUALS.value in expression: 360 | identifier, expression = expression.split(DLTypeSpecifier.EQUALS.value) 361 | 362 | for function in _functional_operators: 363 | if result := _maybe_parse_functional_expression( 364 | identifier, 365 | expression, 366 | function, 367 | ): 368 | _logger.debug("Parsed function expression %r", result) 369 | return result 370 | 371 | if not VALID_EXPRESSION_RX.match(expression): 372 | msg = f"Invalid {expression=} {VALID_EXPRESSION_RX=}" 373 | raise SyntaxError(msg) 374 | 375 | # split the expression into tokens using the operators from the enum as delimiters 376 | return _postfix_from_infix(identifier, expression) 377 | -------------------------------------------------------------------------------- /dltype/_lib/_universal_tensors.py: -------------------------------------------------------------------------------- 1 | # pyright: reportPossiblyUnboundVariable=false 2 | """Tensor types work with either torch or numpy (maybe extended later).""" 3 | 4 | from dltype._lib import _tensor_type_base 5 | from dltype._lib._dependency_utilities import ( 6 | is_numpy_available, 7 | is_torch_available, 8 | raise_for_missing_dependency, 9 | ) 10 | 11 | if is_numpy_available(): 12 | from dltype._lib._numpy_tensors import ( 13 | BoolTensor as NumPyBoolTensor, 14 | ) 15 | from dltype._lib._numpy_tensors import ( 16 | Float16Tensor as NumPyFloat16Tensor, 17 | ) 18 | from dltype._lib._numpy_tensors import ( 19 | Float32Tensor as NumPyFloat32Tensor, 20 | ) 21 | from dltype._lib._numpy_tensors import ( 22 | Float64Tensor as NumPyFloat64Tensor, 23 | ) 24 | from dltype._lib._numpy_tensors import ( 25 | FloatTensor as NumpyFloatTensor, 26 | ) 27 | from dltype._lib._numpy_tensors import ( 28 | IEEE754HalfFloatTensor as NumPyIEEE754HalfFloatTensor, 29 | ) 30 | from dltype._lib._numpy_tensors import ( 31 | Int8Tensor as NumPyInt8Tensor, 32 | ) 33 | from dltype._lib._numpy_tensors import ( 34 | Int16Tensor as NumPyInt16Tensor, 35 | ) 36 | from dltype._lib._numpy_tensors import ( 37 | Int32Tensor as NumPyInt32Tensor, 38 | ) 39 | from dltype._lib._numpy_tensors import ( 40 | Int64Tensor as NumPyInt64Tensor, 41 | ) 42 | from dltype._lib._numpy_tensors import ( 43 | IntTensor as NumPyIntTensor, 44 | ) 45 | from dltype._lib._numpy_tensors import ( 46 | SignedIntTensor as NumPySignedIntTensor, 47 | ) 48 | from dltype._lib._numpy_tensors import ( 49 | UInt8Tensor as NumPyUInt8Tensor, 50 | ) 51 | from dltype._lib._numpy_tensors import ( 52 | UInt16Tensor as NumPyUInt16Tensor, 53 | ) 54 | from dltype._lib._numpy_tensors import ( 55 | UInt32Tensor as NumPyUInt32Tensor, 56 | ) 57 | from dltype._lib._numpy_tensors import ( 58 | UInt64Tensor as NumPyUInt64Tensor, 59 | ) 60 | from dltype._lib._numpy_tensors import ( 61 | UnsignedIntTensor as NumPyUnsignedIntTensor, 62 | ) 63 | 64 | if is_torch_available(): 65 | from dltype._lib._torch_tensors import ( 66 | BFloat16Tensor as TorchBFloat16Tensor, 67 | ) 68 | from dltype._lib._torch_tensors import ( 69 | BoolTensor as TorchBoolTensor, 70 | ) 71 | from dltype._lib._torch_tensors import ( 72 | Float16Tensor as TorchFloat16Tensor, 73 | ) 74 | from dltype._lib._torch_tensors import ( 75 | Float32Tensor as TorchFloat32Tensor, 76 | ) 77 | from dltype._lib._torch_tensors import ( 78 | Float64Tensor 
as TorchFloat64Tensor, 79 | ) 80 | from dltype._lib._torch_tensors import ( 81 | FloatTensor as TorchFloatTensor, 82 | ) 83 | from dltype._lib._torch_tensors import ( 84 | IEEE754HalfFloatTensor as TorchIEEE754HalfFloatTensor, 85 | ) 86 | from dltype._lib._torch_tensors import ( 87 | Int8Tensor as TorchInt8Tensor, 88 | ) 89 | from dltype._lib._torch_tensors import ( 90 | Int16Tensor as TorchInt16Tensor, 91 | ) 92 | from dltype._lib._torch_tensors import ( 93 | Int32Tensor as TorchInt32Tensor, 94 | ) 95 | from dltype._lib._torch_tensors import ( 96 | Int64Tensor as TorchInt64Tensor, 97 | ) 98 | from dltype._lib._torch_tensors import ( 99 | IntTensor as TorchIntTensor, 100 | ) 101 | from dltype._lib._torch_tensors import ( 102 | SignedIntTensor as TorchSignedIntTensor, 103 | ) 104 | from dltype._lib._torch_tensors import ( 105 | UInt8Tensor as TorchUInt8Tensor, 106 | ) 107 | from dltype._lib._torch_tensors import ( 108 | UInt16Tensor as TorchUInt16Tensor, 109 | ) 110 | from dltype._lib._torch_tensors import ( 111 | UInt32Tensor as TorchUInt32Tensor, 112 | ) 113 | from dltype._lib._torch_tensors import ( 114 | UInt64Tensor as TorchUInt64Tensor, 115 | ) 116 | from dltype._lib._torch_tensors import ( 117 | UnsignedIntTensor as TorchUnsignedIntTensor, 118 | ) 119 | 120 | 121 | class Int8Tensor(_tensor_type_base.TensorTypeBase): 122 | """A class to represent an 8-bit integer tensor type.""" 123 | 124 | DTYPES = ( 125 | (*TorchInt8Tensor.DTYPES, *NumPyInt8Tensor.DTYPES) 126 | if is_torch_available() and is_numpy_available() 127 | else ( 128 | TorchInt8Tensor.DTYPES 129 | if is_torch_available() 130 | else NumPyInt8Tensor.DTYPES 131 | if is_numpy_available() 132 | else raise_for_missing_dependency() 133 | ) 134 | ) 135 | 136 | 137 | class UInt8Tensor(_tensor_type_base.TensorTypeBase): 138 | """A class to represent an unsigned 8-bit integer tensor type.""" 139 | 140 | DTYPES = ( 141 | (*TorchUInt8Tensor.DTYPES, *NumPyUInt8Tensor.DTYPES) 142 | if is_torch_available() and is_numpy_available() 143 | else ( 144 | TorchUInt8Tensor.DTYPES 145 | if is_torch_available() 146 | else NumPyUInt8Tensor.DTYPES 147 | if is_numpy_available() 148 | else raise_for_missing_dependency() 149 | ) 150 | ) 151 | 152 | 153 | class Int16Tensor(_tensor_type_base.TensorTypeBase): 154 | """A class to represent a 16-bit integer tensor type.""" 155 | 156 | DTYPES = ( 157 | (*TorchInt16Tensor.DTYPES, *NumPyInt16Tensor.DTYPES) 158 | if is_torch_available() and is_numpy_available() 159 | else ( 160 | TorchInt16Tensor.DTYPES 161 | if is_torch_available() 162 | else NumPyInt16Tensor.DTYPES 163 | if is_numpy_available() 164 | else raise_for_missing_dependency() 165 | ) 166 | ) 167 | 168 | 169 | class UInt16Tensor(_tensor_type_base.TensorTypeBase): 170 | """A class to represent an unsigned 16-bit integer tensor type.""" 171 | 172 | DTYPES = ( 173 | (*TorchUInt16Tensor.DTYPES, *NumPyUInt16Tensor.DTYPES) 174 | if is_torch_available() and is_numpy_available() 175 | else ( 176 | TorchUInt16Tensor.DTYPES 177 | if is_torch_available() 178 | else NumPyUInt16Tensor.DTYPES 179 | if is_numpy_available() 180 | else raise_for_missing_dependency() 181 | ) 182 | ) 183 | 184 | 185 | class Int32Tensor(_tensor_type_base.TensorTypeBase): 186 | """A class to represent a 32-bit integer tensor type.""" 187 | 188 | DTYPES = ( 189 | (*TorchInt32Tensor.DTYPES, *NumPyInt32Tensor.DTYPES) 190 | if is_torch_available() and is_numpy_available() 191 | else ( 192 | TorchInt32Tensor.DTYPES 193 | if is_torch_available() 194 | else NumPyInt32Tensor.DTYPES 195 | if 
is_numpy_available() 196 | else raise_for_missing_dependency() 197 | ) 198 | ) 199 | 200 | 201 | class UInt32Tensor(_tensor_type_base.TensorTypeBase): 202 | """A class to represent an unsigned 32-bit integer tensor type.""" 203 | 204 | DTYPES = ( 205 | (*TorchUInt32Tensor.DTYPES, *NumPyUInt32Tensor.DTYPES) 206 | if is_torch_available() and is_numpy_available() 207 | else ( 208 | TorchUInt32Tensor.DTYPES 209 | if is_torch_available() 210 | else NumPyUInt32Tensor.DTYPES 211 | if is_numpy_available() 212 | else raise_for_missing_dependency() 213 | ) 214 | ) 215 | 216 | 217 | class Int64Tensor(_tensor_type_base.TensorTypeBase): 218 | """A class to represent a 64-bit integer tensor type.""" 219 | 220 | DTYPES = ( 221 | (*TorchInt64Tensor.DTYPES, *NumPyInt64Tensor.DTYPES) 222 | if is_torch_available() and is_numpy_available() 223 | else ( 224 | TorchInt64Tensor.DTYPES 225 | if is_torch_available() 226 | else NumPyInt64Tensor.DTYPES 227 | if is_numpy_available() 228 | else raise_for_missing_dependency() 229 | ) 230 | ) 231 | 232 | 233 | class UInt64Tensor(_tensor_type_base.TensorTypeBase): 234 | """A class to represent an unsigned 64-bit integer tensor type.""" 235 | 236 | DTYPES = ( 237 | (*TorchUInt64Tensor.DTYPES, *NumPyUInt64Tensor.DTYPES) 238 | if is_torch_available() and is_numpy_available() 239 | else ( 240 | TorchUInt64Tensor.DTYPES 241 | if is_torch_available() 242 | else NumPyUInt64Tensor.DTYPES 243 | if is_numpy_available() 244 | else raise_for_missing_dependency() 245 | ) 246 | ) 247 | 248 | 249 | class SignedIntTensor(_tensor_type_base.TensorTypeBase): 250 | """A class to represent a signed integer tensor of any precision (8 bit, 16 bit, 32 bit and 64 bit).""" 251 | 252 | DTYPES = ( 253 | (*TorchSignedIntTensor.DTYPES, *NumPySignedIntTensor.DTYPES) 254 | if is_torch_available() and is_numpy_available() 255 | else ( 256 | TorchSignedIntTensor.DTYPES 257 | if is_torch_available() 258 | else NumPySignedIntTensor.DTYPES 259 | if is_numpy_available() 260 | else raise_for_missing_dependency() 261 | ) 262 | ) 263 | 264 | 265 | class UnsignedIntTensor(_tensor_type_base.TensorTypeBase): 266 | """A class to represent an unsigned integer tensor of any precision (8 bit, 16 bit, 32 bit and 64 bit).""" 267 | 268 | DTYPES = ( 269 | (*TorchUnsignedIntTensor.DTYPES, *NumPyUnsignedIntTensor.DTYPES) 270 | if is_torch_available() and is_numpy_available() 271 | else ( 272 | TorchUnsignedIntTensor.DTYPES 273 | if is_torch_available() 274 | else NumPyUnsignedIntTensor.DTYPES 275 | if is_numpy_available() 276 | else raise_for_missing_dependency() 277 | ) 278 | ) 279 | 280 | 281 | class IntTensor(_tensor_type_base.TensorTypeBase): 282 | """A class to represent an integer tensor (signed or unsigned) of any precision (8, 16, 32, 64 bit).""" 283 | 284 | DTYPES = ( 285 | (*TorchIntTensor.DTYPES, *NumPyIntTensor.DTYPES) 286 | if is_torch_available() and is_numpy_available() 287 | else ( 288 | TorchIntTensor.DTYPES 289 | if is_torch_available() 290 | else NumPyIntTensor.DTYPES 291 | if is_numpy_available() 292 | else raise_for_missing_dependency() 293 | ) 294 | ) 295 | 296 | 297 | class IEEE754HalfFloatTensor(_tensor_type_base.TensorTypeBase): 298 | """ 299 | A class to represent half precision tensor types. Does not include special types such as bfloat16. 300 | 301 | Use this type if bfloat16 causes issues for some reason and you need to prohibit its use.
302 | """ 303 | 304 | DTYPES = ( 305 | (*TorchIEEE754HalfFloatTensor.DTYPES, *NumPyIEEE754HalfFloatTensor.DTYPES) 306 | if is_torch_available() and is_numpy_available() 307 | else ( 308 | TorchIEEE754HalfFloatTensor.DTYPES 309 | if is_torch_available() 310 | else NumPyIEEE754HalfFloatTensor.DTYPES 311 | if is_numpy_available() 312 | else raise_for_missing_dependency() 313 | ) 314 | ) 315 | 316 | 317 | class BFloat16Tensor(_tensor_type_base.TensorTypeBase): 318 | """A tensor that can only be bfloat16.""" 319 | 320 | DTYPES = TorchBFloat16Tensor.DTYPES if is_torch_available() else () 321 | 322 | 323 | class Float16Tensor(_tensor_type_base.TensorTypeBase): 324 | """A class to represent any 16-bit float tensor types (includes bfloat16).""" 325 | 326 | DTYPES = ( 327 | (*TorchFloat16Tensor.DTYPES, *NumPyFloat16Tensor.DTYPES) 328 | if is_torch_available() and is_numpy_available() 329 | else ( 330 | TorchFloat16Tensor.DTYPES 331 | if is_torch_available() 332 | else NumPyFloat16Tensor.DTYPES 333 | if is_numpy_available() 334 | else raise_for_missing_dependency() 335 | ) 336 | ) 337 | 338 | 339 | class Float32Tensor(_tensor_type_base.TensorTypeBase): 340 | """A class to represent a 32-bit float tensor type.""" 341 | 342 | DTYPES = ( 343 | (*TorchFloat32Tensor.DTYPES, *NumPyFloat32Tensor.DTYPES) 344 | if is_torch_available() and is_numpy_available() 345 | else ( 346 | TorchFloat32Tensor.DTYPES 347 | if is_torch_available() 348 | else NumPyFloat32Tensor.DTYPES 349 | if is_numpy_available() 350 | else raise_for_missing_dependency() 351 | ) 352 | ) 353 | 354 | 355 | class Float64Tensor(_tensor_type_base.TensorTypeBase): 356 | """A class to represent a double tensor type.""" 357 | 358 | DTYPES = ( 359 | (*TorchFloat64Tensor.DTYPES, *NumPyFloat64Tensor.DTYPES) 360 | if is_torch_available() and is_numpy_available() 361 | else ( 362 | TorchFloat64Tensor.DTYPES 363 | if is_torch_available() 364 | else NumPyFloat64Tensor.DTYPES 365 | if is_numpy_available() 366 | else raise_for_missing_dependency() 367 | ) 368 | ) 369 | 370 | 371 | DoubleTensor = Float64Tensor 372 | 373 | 374 | class FloatTensor(_tensor_type_base.TensorTypeBase): 375 | """ 376 | A class to represent the superset of any floating point type of any precision. 377 | 378 | This includes 16 bit, 32 bit, 64 bit, and optionally numpy's 128 bit if it is supported. 
379 | """ 380 | 381 | DTYPES = ( 382 | (*TorchFloatTensor.DTYPES, *NumpyFloatTensor.DTYPES) 383 | if is_torch_available() and is_numpy_available() 384 | else ( 385 | TorchFloatTensor.DTYPES 386 | if is_torch_available() 387 | else NumpyFloatTensor.DTYPES 388 | if is_numpy_available() 389 | else raise_for_missing_dependency() 390 | ) 391 | ) 392 | 393 | 394 | class BoolTensor(_tensor_type_base.TensorTypeBase): 395 | """A class to represent a boolean tensor type.""" 396 | 397 | DTYPES = ( 398 | (*TorchBoolTensor.DTYPES, *NumPyBoolTensor.DTYPES) 399 | if is_torch_available() and is_numpy_available() 400 | else ( 401 | TorchBoolTensor.DTYPES 402 | if is_torch_available() 403 | else NumPyBoolTensor.DTYPES 404 | if is_numpy_available() 405 | else raise_for_missing_dependency() 406 | ) 407 | ) 408 | 409 | 410 | __all__ = [ 411 | "BFloat16Tensor", 412 | "BoolTensor", 413 | "DoubleTensor", 414 | "Float16Tensor", 415 | "Float32Tensor", 416 | "Float64Tensor", 417 | "FloatTensor", 418 | "IEEE754HalfFloatTensor", 419 | "Int8Tensor", 420 | "Int16Tensor", 421 | "Int32Tensor", 422 | "Int64Tensor", 423 | "IntTensor", 424 | "SignedIntTensor", 425 | "UInt8Tensor", 426 | "UInt16Tensor", 427 | "UInt32Tensor", 428 | "UInt64Tensor", 429 | "UnsignedIntTensor", 430 | ] 431 | -------------------------------------------------------------------------------- /dltype/_lib/_core.py: -------------------------------------------------------------------------------- 1 | """A module to assist with using Annotated[torch.Tensor] in type hints.""" 2 | 3 | from __future__ import annotations 4 | 5 | import inspect 6 | import itertools 7 | import logging 8 | import warnings 9 | from functools import lru_cache, wraps 10 | from typing import ( 11 | TYPE_CHECKING, 12 | Annotated, 13 | Any, 14 | Final, 15 | Literal, 16 | NamedTuple, 17 | ParamSpec, 18 | Protocol, 19 | TypeVar, 20 | Union, 21 | cast, 22 | get_args, 23 | get_origin, 24 | get_type_hints, 25 | runtime_checkable, 26 | ) 27 | 28 | from dltype._lib import ( 29 | _constants, 30 | _dependency_utilities, 31 | _dltype_context, 32 | _dtypes, 33 | _errors, 34 | _tensor_type_base, 35 | ) 36 | 37 | if TYPE_CHECKING: 38 | from collections.abc import Callable 39 | 40 | 41 | _logger: Final = logging.getLogger(__name__) 42 | 43 | P = ParamSpec("P") 44 | R = TypeVar("R") 45 | 46 | 47 | class DLTypeAnnotation(NamedTuple): 48 | """A class representing a type annotation for a tensor.""" 49 | 50 | tensor_type_hint: type[_dtypes.DLtypeTensorT] | None 51 | dltype_annotation: _tensor_type_base.TensorTypeBase | None 52 | 53 | @classmethod 54 | def from_hint( 55 | cls, 56 | hint: type | None, 57 | *, 58 | optional: bool = False, 59 | ) -> tuple[DLTypeAnnotation | None, ...]: 60 | """Create a new _DLTypeAnnotation from a type hint.""" 61 | if hint is None: 62 | return (None,) 63 | 64 | _logger.debug("Creating DLType from hint %r", hint) 65 | n_expected_args = len(cls._fields) 66 | origin = get_origin(hint) 67 | args = get_args(hint) 68 | 69 | # Handle Optional[T] types (Union[T, None] or Union[T, NoneType]) 70 | if origin is Union: 71 | # Get the non-None type from the Union 72 | non_none_types = [t for t in args if t not in {type(None), None}] 73 | 74 | # Only support Optional[T], not general Union types 75 | if len(non_none_types) != 1: 76 | msg = f"Only Optional tensor types are supported, not general Union types. 
Got: {hint}" 77 | raise TypeError(msg) 78 | 79 | # Recursively process the non-None type with optional=True 80 | return cls.from_hint(non_none_types[0], optional=True) 81 | 82 | # tuple handling special case 83 | if origin is tuple: 84 | return tuple(itertools.chain(*[cls.from_hint(inner_hint) for inner_hint in args])) 85 | 86 | # Only process Annotated types 87 | if origin is not Annotated: 88 | return (None,) 89 | 90 | # Ensure the annotation is a TensorTypeBase 91 | if len(args) < n_expected_args or not isinstance( 92 | args[1], 93 | _tensor_type_base.TensorTypeBase, 94 | ): 95 | _logger.warning( 96 | "Invalid annotated dltype hint: %r", 97 | args[1:] if len(args) >= n_expected_args else None, 98 | ) 99 | return (None,) 100 | 101 | # Ensure the base type is a supported tensor type 102 | tensor_type, dltype_hint = args[0], args[1] 103 | if not any(T in tensor_type.mro() for T in _dtypes.SUPPORTED_TENSOR_TYPES): 104 | msg = f"Invalid base type=<{tensor_type}> in DLType hint, expected a subtype of {_dtypes.SUPPORTED_TENSOR_TYPES}" 105 | raise TypeError(msg) 106 | 107 | dltype_hint.optional = optional 108 | return (cls(tensor_type_hint=tensor_type, dltype_annotation=dltype_hint),) 109 | 110 | 111 | @lru_cache() 112 | def _resolve_types( 113 | annotations: tuple[DLTypeAnnotation | None, ...] | None, 114 | ) -> tuple[_tensor_type_base.TensorTypeBase | None, ...] | None: 115 | if annotations is None or all(ann is None for ann in annotations): 116 | return None 117 | return tuple((ann.dltype_annotation if ann is not None else None for ann in annotations)) 118 | 119 | 120 | @runtime_checkable 121 | class DLTypeScopeProvider(Protocol): 122 | """A protocol for classes that provide a scope for DLTypeDimensionExpression evaluation.""" 123 | 124 | def get_dltype_scope(self) -> _dltype_context.EvaluatedDimensionT: 125 | """Get the current scope of variables for the DLTypeDimensionExpression evaluation.""" 126 | ... 127 | 128 | 129 | def _maybe_get_type_hints( 130 | existing_hints: dict[str, tuple[DLTypeAnnotation | None, ...]] | None, 131 | func: Callable[P, R], 132 | ) -> dict[str, tuple[DLTypeAnnotation | None, ...]] | None: 133 | """Get the type hints for a function, or return an empty dict if not available.""" 134 | if existing_hints is not None: 135 | return existing_hints 136 | try: 137 | return { 138 | name: DLTypeAnnotation.from_hint(hint) 139 | for name, hint in get_type_hints(func, include_extras=True).items() 140 | } 141 | except NameError: 142 | return None 143 | 144 | 145 | @lru_cache() 146 | def _maybe_get_signature( 147 | existing: inspect.Signature | None, 148 | func: Callable[P, R], 149 | ) -> inspect.Signature | None: 150 | """Get the signature of a function, or return an empty signature if not available.""" 151 | if existing is not None: 152 | return existing 153 | try: 154 | return inspect.signature(func) 155 | except TypeError: 156 | return None 157 | 158 | 159 | def _resolve_value( 160 | value: Any, # noqa: ANN401 161 | type_hint: tuple[_tensor_type_base.TensorTypeBase | DLTypeAnnotation | None, ...], 162 | ) -> tuple[Any]: 163 | return cast("tuple[Any]", value) if len(type_hint) > 1 else (value,) 164 | 165 | 166 | def dltyped( # noqa: C901, PLR0915 167 | scope_provider: DLTypeScopeProvider | Literal["self"] | None = None, 168 | ) -> Callable[[Callable[P, R]], Callable[P, R]]: 169 | """ 170 | Apply type checking to the decorated function. 
171 | 172 | Args: 173 | scope_provider: An optional scope provider to use for type checking. If None, no scope provider is used; if 'self' 174 | is used, the first argument of the function is expected to be a DLTypeScopeProvider and the function must be a method. 175 | 176 | Returns: 177 | A wrapper function with type checking 178 | 179 | """ 180 | 181 | def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901, PLR0915 182 | if _dependency_utilities.is_torch_scripting(): 183 | # jit script doesn't support annotated type hints at all, we have no choice but to skip the type checking 184 | return func 185 | 186 | # Handle regular functions 187 | signature = _maybe_get_signature(None, func) 188 | # assume that if signature is None, we are dealing with a function with a forward reference, which is almost certainly a classmethod or staticmethod 189 | # we can't check the signature in this case, so we just assume it's a method for now to avoid raising a false positive error 190 | # if it _isn't_ a method but we specified "self", later on when we check if the scope provider is a DLTypeScopeProvider, we'll raise an error 191 | is_method = ( 192 | bool("self" in signature.parameters or "cls" in signature.parameters) if signature else True 193 | ) 194 | if scope_provider == "self" and not is_method: 195 | msg = "Scope provider types can only be used with methods." 196 | raise TypeError(msg) 197 | return_key = "return" 198 | dltype_hints = _maybe_get_type_hints(None, func) 199 | 200 | # if we added dltype to a method where it will have no effect, warn the user 201 | if dltype_hints is not None and all(all(vv is None for vv in v) for v in dltype_hints.values()): 202 | _logger.warning("dltype_hints=%r", dltype_hints) 203 | warnings.warn( 204 | "No DLType hints found, skipping type checking", 205 | UserWarning, 206 | stacklevel=2, 207 | ) 208 | return func 209 | 210 | @wraps(func) 211 | @_dependency_utilities.torch_jit_unused # pyright: ignore[reportUnknownMemberType] 212 | def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # noqa: C901, PLR0912 213 | __tracebackhide__ = not _constants.DEBUG_MODE 214 | nonlocal signature 215 | nonlocal dltype_hints 216 | 217 | dltype_hints = _maybe_get_type_hints(dltype_hints, func) 218 | signature = _maybe_get_signature(signature, func) 219 | if signature is None or dltype_hints is None: 220 | warnings.warn( 221 | "Unable to determine signature of dltyped function, type checking will be skipped.
(Inner classes with forward references are not supported.)", 222 | UserWarning, 223 | stacklevel=2, 224 | ) 225 | return func(*args, **kwargs) 226 | 227 | bound_args = signature.bind(*args, **kwargs) 228 | bound_args.apply_defaults() 229 | 230 | actual_args = bound_args.arguments 231 | 232 | ctx = _dltype_context.DLTypeContext() 233 | 234 | if scope_provider == "self" and isinstance( 235 | actual_args[str(scope_provider)], 236 | DLTypeScopeProvider, 237 | ): 238 | ctx.tensor_shape_map = actual_args[str(scope_provider)].get_dltype_scope() 239 | _logger.debug("Using self as scope provider %s", ctx.tensor_shape_map) 240 | elif scope_provider is not None and isinstance( 241 | scope_provider, 242 | DLTypeScopeProvider, 243 | ): 244 | ctx.tensor_shape_map = scope_provider.get_dltype_scope() 245 | _logger.debug("Using unbound scope provider %s", ctx.tensor_shape_map) 246 | elif scope_provider is not None: 247 | raise _errors.DLTypeScopeProviderError( 248 | bad_scope_provider=scope_provider, 249 | ) 250 | 251 | for name in dltype_hints: 252 | if name == return_key: 253 | # special handling of the return value, we don't want to evaluate the function before the arguments are checked 254 | continue 255 | 256 | if name in {"self", "cls"}: 257 | # if we have an argument called self or cls with a type hint, we would skip a parameter incorrectly 258 | # just disallow the behavior 259 | msg = f"Invalid argument {name=} is not a supported argument name for dltype." 260 | raise TypeError(msg) 261 | 262 | if maybe_annotation := dltype_hints.get(name): 263 | tensor = actual_args[name] 264 | ctx.add( 265 | name, 266 | _resolve_value(tensor, maybe_annotation), 267 | _resolve_types(maybe_annotation), 268 | ) 269 | elif any(isinstance(actual_args[name], T) for T in _dtypes.SUPPORTED_TENSOR_TYPES): 270 | warnings.warn( 271 | f"[argument={name}] is missing a DLType hint", 272 | UserWarning, 273 | stacklevel=2, 274 | ) 275 | else: 276 | _logger.debug("No DLType hint for %r", name) 277 | 278 | try: 279 | ctx.assert_context() 280 | retval = func(*args, **kwargs) 281 | if maybe_return_annotation := _resolve_types(dltype_hints.get(return_key)): 282 | ctx.add( 283 | return_key, 284 | _resolve_value(retval, maybe_return_annotation), 285 | maybe_return_annotation, 286 | ) 287 | ctx.assert_context() 288 | elif any(isinstance(retval, T) for T in _dtypes.SUPPORTED_TENSOR_TYPES): 289 | warnings.warn( 290 | f"[{return_key}] is missing a DLType hint", 291 | UserWarning, 292 | stacklevel=2, 293 | ) 294 | except _errors.DLTypeError as e: 295 | # include the full function signature in the error message 296 | e.set_context(f"{func.__name__}{signature}") 297 | raise 298 | return retval 299 | 300 | return wrapper 301 | 302 | return _inner_dltyped 303 | 304 | 305 | # Add this with your other TypeVar definitions 306 | NT = TypeVar("NT", bound=NamedTuple) 307 | 308 | 309 | def dltyped_namedtuple() -> Callable[[type[NT]], type[NT]]: 310 | """ 311 | Apply type checking to a NamedTuple class. 
312 | 313 | Returns: 314 | A modified NamedTuple class with type checking on construction 315 | 316 | """ 317 | 318 | def _inner_dltyped_namedtuple(cls: type[NT]) -> type[NT]: 319 | # NOTE: NamedTuple isn't actually a class, it's a factory function that returns a new class so we can't use issubclass here 320 | if not ( 321 | isinstance(cls, type) and hasattr(cls, "_fields") and issubclass(cls, tuple) # pyright: ignore[reportUnnecessaryIsInstance] 322 | ): 323 | msg = f"Expected a NamedTuple class, got {cls}" 324 | raise TypeError(msg) 325 | 326 | # Get the field annotations from the NamedTuple 327 | field_hints = get_type_hints(cls, include_extras=True) 328 | 329 | # Check for fields with DLType annotations 330 | dltype_fields: dict[str, tuple[DLTypeAnnotation | None, ...]] = {} 331 | for field_name in cls._fields: 332 | if field_name in field_hints: 333 | hint = field_hints[field_name] 334 | dltype_fields[field_name] = DLTypeAnnotation.from_hint(hint) 335 | 336 | # If no fields need validation, return the original class 337 | if not dltype_fields: 338 | return cls 339 | 340 | # Create a new __new__ method that validates on construction 341 | original_new = cls.__new__ 342 | 343 | def validated_new(cls_inner: type[NT], *args: Any, **kwargs: Any) -> NT: # noqa: ANN401 (these actually can be any type) 344 | """A new __new__ method that validates the fields upon construction.""" 345 | # First create the instance using the original __new__ 346 | instance = original_new(cls_inner, *args, **kwargs) 347 | 348 | # Then validate all fields with DLType annotations 349 | ctx = _dltype_context.DLTypeContext() 350 | for field_name, annotation in dltype_fields.items(): 351 | field_index = cls._fields.index(field_name) 352 | value = instance[field_index] 353 | ctx.add(field_name, _resolve_value(value, annotation), _resolve_types(annotation)) 354 | 355 | # Assert that all fields are valid 356 | try: 357 | ctx.assert_context() 358 | except _errors.DLTypeError as e: 359 | e.set_context(cls.__name__) 360 | raise 361 | 362 | return instance 363 | 364 | # Create the new class with our modified __new__ method 365 | return cast("type[NT]", type(cls.__name__, (cls,), {"__new__": validated_new})) 366 | 367 | return _inner_dltyped_namedtuple 368 | 369 | 370 | DataclassT = TypeVar("DataclassT") 371 | 372 | 373 | def dltyped_dataclass() -> Callable[[type[DataclassT]], type[DataclassT]]: 374 | """ 375 | Apply type checking to a dataclass. 376 | 377 | This will validate all fields with DLType annotations during object construction. 378 | Works with both regular and frozen dataclasses. 379 | 380 | Returns: 381 | A modified dataclass with type checking on initialization 382 | 383 | """ 384 | 385 | def _inner_dltyped_dataclass(cls: type[DataclassT]) -> type[DataclassT]: 386 | if _dependency_utilities.is_torch_scripting(): 387 | return cls 388 | 389 | # check that we are a dataclass, raise an error if not 390 | if not hasattr(cls, "__dataclass_fields__"): 391 | msg = f"Class {cls.__name__} is not a dataclass, apply @dataclass below dltyped_dataclass." 
392 | raise TypeError(msg) 393 | 394 | # Store original __init__ to ensure we run after the dataclass initialization 395 | original_init = cls.__init__ 396 | # Get field annotations 397 | field_hints = get_type_hints(cls, include_extras=True) 398 | dltype_hints = {name: DLTypeAnnotation.from_hint(hint) for name, hint in field_hints.items()} 399 | 400 | def new_init(self: DataclassT, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 401 | """A new __init__ method that validates the fields after initialization.""" 402 | # First call the original __init__ 403 | original_init(self, *args, **kwargs) 404 | 405 | # Validate fields with DLType annotations 406 | ctx = _dltype_context.DLTypeContext() 407 | 408 | for field_name in field_hints: 409 | annotation = dltype_hints.get(field_name) 410 | if annotation is not None: 411 | # Get the field value 412 | value = getattr(self, field_name, None) 413 | 414 | # Add to validation context 415 | ctx.add(field_name, _resolve_value(value, annotation), _resolve_types(annotation)) 416 | 417 | # Assert that all fields are valid 418 | try: 419 | ctx.assert_context() 420 | except _errors.DLTypeError as e: 421 | e.set_context(cls.__name__) 422 | raise 423 | 424 | # Replace the __init__ method 425 | cls.__init__ = new_init 426 | 427 | return cls 428 | 429 | return _inner_dltyped_dataclass 430 | -------------------------------------------------------------------------------- /dltype/tests/dltype_test.py: -------------------------------------------------------------------------------- 1 | # pyright: reportPrivateUsage=false, reportUnknownMemberType=false 2 | """Tests for common types used in deep learning.""" 3 | 4 | import re 5 | from collections.abc import Callable 6 | from dataclasses import dataclass 7 | from pathlib import Path 8 | from tempfile import NamedTemporaryFile 9 | from typing import Annotated, Final, NamedTuple, TypeAlias 10 | from unittest.mock import patch 11 | 12 | import numpy as np 13 | import numpy.typing as npt 14 | import pytest 15 | import torch 16 | from pydantic import BaseModel 17 | 18 | import dltype 19 | 20 | np_rand = np.random.RandomState(42).rand 21 | NPFloatArrayT: TypeAlias = npt.NDArray[np.float32 | np.float64] 22 | NPIntArrayT: TypeAlias = npt.NDArray[np.int32 | np.uint16 | np.uint32 | np.uint8] 23 | 24 | 25 | class _RaisesInfo(NamedTuple): 26 | exception_type: type[Exception] | None = None 27 | regex: str | None = None 28 | value: torch.Tensor | None = None 29 | 30 | 31 | @dltype.dltyped() 32 | def bad_function( 33 | tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["b c h w"]], 34 | ) -> Annotated[torch.Tensor, dltype.TensorTypeBase["h w b c"]]: 35 | """A function that takes a tensor and returns a tensor.""" 36 | return tensor 37 | 38 | 39 | @dltype.dltyped() 40 | def good_function( 41 | tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["b c h w"]], 42 | ) -> Annotated[torch.Tensor, dltype.TensorTypeBase["h w b c"]]: 43 | """A function that takes a tensor and returns a tensor.""" 44 | return tensor.permute(2, 3, 0, 1) 45 | 46 | 47 | @dltype.dltyped() 48 | def incomplete_annotated_function( 49 | tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["b c h w"]], 50 | ) -> torch.Tensor: 51 | """A function that takes a tensor and returns a tensor.""" 52 | return tensor 53 | 54 | 55 | @dltype.dltyped() 56 | def incomplete_return_function( 57 | tensor: torch.Tensor, 58 | ) -> Annotated[torch.Tensor, dltype.TensorTypeBase["h w b c"]]: 59 | """A function that takes a tensor and returns a tensor.""" 60 | return 
tensor.permute(2, 3, 0, 1) 61 | 62 | 63 | @dltype.dltyped() 64 | def inconsistent_shape_function( 65 | tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["b c h w"]], 66 | ) -> Annotated[torch.Tensor, dltype.TensorTypeBase["h w b"]]: 67 | """A function that takes a tensor and returns a tensor.""" 68 | if tensor.shape[0] == 1: 69 | return tensor.permute(2, 3, 0) 70 | return tensor.permute(2, 3, 0, 1) 71 | 72 | 73 | BAD_DIMENSION_RX: Final = r"Invalid number of dimensions" 74 | 75 | 76 | def bad_dimension_error( 77 | tensor_name: str, 78 | *, 79 | idx: int, 80 | expected: int, 81 | actual: int, 82 | ) -> str: 83 | return re.escape( 84 | f"Invalid tensor shape, tensor={tensor_name} dim={idx} expected={expected} actual={actual}", 85 | ) 86 | 87 | 88 | def bad_ndim_error(tensor_name: str, *, expected: int, actual: int) -> str: 89 | return re.escape( 90 | f"Invalid number of dimensions, tensor={tensor_name} expected ndims={expected} actual={actual}", 91 | ) 92 | 93 | 94 | @pytest.mark.parametrize( 95 | ("input_tensor", "func", "expected"), 96 | [ 97 | pytest.param( 98 | torch.ones(1, 1, 1, 1), 99 | bad_function, 100 | _RaisesInfo(value=torch.ones(1, 1, 1, 1)), 101 | id="bad_func trivial", 102 | ), 103 | pytest.param( 104 | torch.rand(1, 2, 3, 4), 105 | bad_function, 106 | _RaisesInfo( 107 | exception_type=dltype.DLTypeShapeError, 108 | regex=bad_dimension_error("return", expected=3, idx=0, actual=1), 109 | ), 110 | id="bad_func_4D", 111 | ), 112 | pytest.param( 113 | torch.rand(1, 2, 3), 114 | bad_function, 115 | _RaisesInfo( 116 | exception_type=dltype.DLTypeNDimsError, 117 | regex=bad_ndim_error("tensor", expected=4, actual=3), 118 | ), 119 | id="bad_func_3D", 120 | ), 121 | pytest.param( 122 | torch.rand(1, 2, 3, 4, 5), 123 | bad_function, 124 | _RaisesInfo( 125 | exception_type=dltype.DLTypeNDimsError, 126 | regex=bad_ndim_error("tensor", expected=4, actual=5), 127 | ), 128 | id="bad_func_5D", 129 | ), 130 | pytest.param( 131 | torch.ones(1, 2, 3, 4), 132 | good_function, 133 | _RaisesInfo(value=torch.ones(3, 4, 1, 2)), 134 | id="good_func_4D", 135 | ), 136 | pytest.param( 137 | torch.rand(1, 2, 3), 138 | good_function, 139 | _RaisesInfo( 140 | exception_type=dltype.DLTypeNDimsError, 141 | regex=bad_ndim_error("tensor", expected=4, actual=3), 142 | ), 143 | id="good_func_3D", 144 | ), 145 | pytest.param( 146 | torch.ones(1, 2, 3, 4), 147 | incomplete_annotated_function, 148 | _RaisesInfo(value=torch.ones(1, 2, 3, 4)), 149 | id="incomplete_annotated_4D", 150 | ), 151 | pytest.param( 152 | torch.ones(1, 2, 3, 4), 153 | incomplete_return_function, 154 | _RaisesInfo(value=torch.ones(3, 4, 1, 2)), 155 | id="incomplete_return_4D", 156 | ), 157 | pytest.param( 158 | torch.ones(1, 2, 3), 159 | incomplete_return_function, 160 | _RaisesInfo( 161 | exception_type=RuntimeError, 162 | regex=r"number of dimensions in the tensor input does not match*", 163 | ), 164 | id="invalid arg no type hint", 165 | ), 166 | ], 167 | ) 168 | def test_single_in_single_out( 169 | input_tensor: torch.Tensor, 170 | func: Callable[[torch.Tensor], torch.Tensor], 171 | expected: _RaisesInfo, 172 | ) -> None: 173 | # test both positional and keyword arguments 174 | if expected.exception_type is not None: 175 | with pytest.raises(expected.exception_type, match=expected.regex): 176 | func(input_tensor) 177 | with pytest.raises(expected.exception_type, match=expected.regex): 178 | func(tensor=input_tensor) # pyright: ignore[reportCallIssue] 179 | else: 180 | torch.testing.assert_close(func(input_tensor), expected.value) 181 | 
torch.testing.assert_close(func(tensor=input_tensor), expected.value) # pyright: ignore[reportCallIssue] 182 | 183 | 184 | class _TestBaseModel(BaseModel, frozen=True): 185 | tensor: Annotated[torch.Tensor, dltype.TensorTypeBase("b c h w")] 186 | tensor_2: Annotated[torch.Tensor, dltype.TensorTypeBase("b c h w")] 187 | 188 | 189 | class _TestBaseModel2(BaseModel, frozen=True): 190 | tensor: Annotated[torch.Tensor, dltype.TensorTypeBase("b h c w")] 191 | tensor_2: Annotated[torch.Tensor, dltype.TensorTypeBase("b a c d")] 192 | 193 | 194 | class _TestBaseModelWithNumpy(BaseModel, frozen=True): 195 | tensor: Annotated[NPFloatArrayT, dltype.TensorTypeBase("b c h w")] 196 | tensor_2: Annotated[NPFloatArrayT, dltype.FloatTensor("b c h w")] 197 | 198 | 199 | @pytest.mark.parametrize( 200 | ("tensor", "tensor_2", "model", "expected"), 201 | [ 202 | pytest.param( 203 | torch.rand(1, 2, 3, 4), 204 | torch.rand(1, 2, 3, 4), 205 | _TestBaseModel, 206 | _RaisesInfo(), 207 | id="good_tensors", 208 | ), 209 | pytest.param( 210 | torch.rand(1, 2, 3, 4), 211 | torch.rand(1, 2, 3), 212 | _TestBaseModel, 213 | _RaisesInfo( 214 | exception_type=dltype.DLTypeNDimsError, 215 | regex=bad_ndim_error("tensor_2", expected=4, actual=3), 216 | ), 217 | id="bad_tensors", 218 | ), 219 | pytest.param( 220 | torch.rand(1, 2, 3, 4), 221 | torch.rand(1, 2, 3, 5), 222 | _TestBaseModel, 223 | _RaisesInfo( 224 | exception_type=dltype.DLTypeShapeError, 225 | regex=bad_dimension_error("tensor_2", idx=3, expected=4, actual=5), 226 | ), 227 | id="bad_tensors_2", 228 | ), 229 | pytest.param( 230 | torch.rand(2, 2, 2, 2), 231 | torch.rand(2, 2, 2, 2), 232 | _TestBaseModel, 233 | _RaisesInfo(), 234 | id="good_tensors_2", 235 | ), 236 | pytest.param( 237 | torch.rand(1, 2, 3, 4), 238 | torch.rand(1, 2, 3, 4), 239 | _TestBaseModel2, 240 | _RaisesInfo(), 241 | id="good_tensors_3", 242 | ), 243 | pytest.param( 244 | torch.rand(1, 2, 3, 4), 245 | torch.rand(1, 5, 3, 7), 246 | _TestBaseModel2, 247 | _RaisesInfo(), 248 | id="bad_tensors_3", 249 | ), 250 | pytest.param( 251 | torch.rand(4, 3, 2, 1), 252 | torch.rand(1, 2, 3, 4), 253 | _TestBaseModel2, 254 | _RaisesInfo( 255 | exception_type=dltype.DLTypeShapeError, 256 | regex=bad_dimension_error("tensor_2", idx=0, expected=4, actual=1), 257 | ), 258 | id="bad_tensors_4", 259 | ), 260 | pytest.param( 261 | torch.rand(2, 2, 2, 2), 262 | torch.rand(2, 2, 1, 2), 263 | _TestBaseModel2, 264 | _RaisesInfo( 265 | exception_type=dltype.DLTypeShapeError, 266 | regex=bad_dimension_error("tensor_2", idx=2, expected=2, actual=1), 267 | ), 268 | id="good_tensors_4", 269 | ), 270 | pytest.param( 271 | np_rand(1, 2, 3, 4), 272 | np_rand(1, 2, 3, 4), 273 | _TestBaseModelWithNumpy, 274 | _RaisesInfo(), 275 | id="good_tensors_numpy", 276 | ), 277 | pytest.param( 278 | np_rand(1, 2, 3, 4), 279 | np_rand(1, 2, 3), 280 | _TestBaseModelWithNumpy, 281 | _RaisesInfo( 282 | exception_type=dltype.DLTypeNDimsError, 283 | regex=bad_ndim_error("tensor_2", expected=4, actual=3), 284 | ), 285 | id="bad_tensors_numpy", 286 | ), 287 | pytest.param( 288 | np_rand(1, 2, 3, 4), 289 | np_rand(1, 2, 3, 5), 290 | _TestBaseModelWithNumpy, 291 | _RaisesInfo( 292 | exception_type=dltype.DLTypeShapeError, 293 | regex=bad_dimension_error("tensor_2", idx=3, expected=4, actual=5), 294 | ), 295 | id="bad_tensors_numpy_2", 296 | ), 297 | pytest.param( 298 | np_rand(2, 2, 2, 2), 299 | np_rand(2, 2, 2, 2).astype(np.int32), 300 | _TestBaseModelWithNumpy, 301 | _RaisesInfo( 302 | exception_type=dltype.DLTypeDtypeError, 303 | regex=r"Invalid 
dtype.*got=int32", 304 | ), 305 | id="bad_dtype", 306 | ), 307 | ], 308 | ) 309 | def test_dltype_pydantic( 310 | tensor: torch.Tensor, 311 | tensor_2: torch.Tensor, 312 | model: type[BaseModel], 313 | expected: _RaisesInfo, 314 | ) -> None: 315 | if expected.exception_type: 316 | with pytest.raises(expected.exception_type, match=expected.regex): 317 | model(tensor=tensor, tensor_2=tensor_2) 318 | else: 319 | model(tensor=tensor, tensor_2=tensor) 320 | 321 | 322 | @dltype.dltyped() 323 | def numpy_function( 324 | tensor: Annotated[NPFloatArrayT, dltype.TensorTypeBase["b c h w"]], 325 | ) -> Annotated[torch.Tensor, dltype.TensorTypeBase["h w b c"]]: 326 | """A function that takes a tensor and returns a tensor.""" 327 | return torch.from_numpy(tensor).permute(2, 3, 0, 1) 328 | 329 | 330 | @pytest.mark.parametrize( 331 | ("tensor", "expected"), 332 | [ 333 | pytest.param( 334 | np.ones((1, 2, 3, 4), dtype=np.float32), 335 | _RaisesInfo(value=torch.ones(3, 4, 1, 2)), 336 | id="good_tensors", 337 | ), 338 | pytest.param( 339 | np.zeros((1, 2, 3)), 340 | _RaisesInfo( 341 | exception_type=dltype.DLTypeNDimsError, 342 | regex=bad_ndim_error("tensor", expected=4, actual=3), 343 | ), 344 | id="bad_tensors", 345 | ), 346 | pytest.param( 347 | np.zeros((1, 2, 3, 5, 6)), 348 | _RaisesInfo( 349 | exception_type=dltype.DLTypeNDimsError, 350 | regex=bad_ndim_error("tensor", expected=4, actual=5), 351 | ), 352 | id="bad_tensors_2", 353 | ), 354 | ], 355 | ) 356 | def test_numpy_mixed(tensor: NPFloatArrayT, expected: _RaisesInfo) -> None: 357 | if expected.exception_type: 358 | with pytest.raises(expected.exception_type, match=expected.regex): 359 | numpy_function(tensor) 360 | else: 361 | torch.testing.assert_close(numpy_function(tensor), expected.value) 362 | 363 | 364 | @pytest.mark.parametrize( 365 | ("tensor_type", "tensor", "expected"), 366 | [ 367 | pytest.param( 368 | dltype.IntTensor("b c h w"), 369 | torch.rand(1, 2, 3, 4), 370 | _RaisesInfo(exception_type=dltype.DLTypeDtypeError), 371 | id="int_tensor", 372 | ), 373 | pytest.param( 374 | dltype.IntTensor("b c h w"), 375 | torch.rand(1, 2, 3, 4).int(), 376 | _RaisesInfo(), 377 | id="int_tensor_2", 378 | ), 379 | pytest.param( 380 | dltype.FloatTensor("b c h w"), 381 | torch.rand(1, 2, 3, 4).int(), 382 | _RaisesInfo(exception_type=dltype.DLTypeDtypeError), 383 | id="float_tensor", 384 | ), 385 | pytest.param( 386 | dltype.FloatTensor("b c h w"), 387 | torch.rand(1, 2, 3, 4), 388 | _RaisesInfo(), 389 | id="float_tensor_2", 390 | ), 391 | pytest.param( 392 | dltype.FloatTensor("b c h w"), 393 | np_rand(1, 2, 3, 4).astype(np.double), 394 | _RaisesInfo(), 395 | id="float_tensor_3", 396 | ), 397 | pytest.param( 398 | dltype.IntTensor("b c h w"), 399 | np_rand(1, 2, 3, 4), 400 | _RaisesInfo(exception_type=dltype.DLTypeDtypeError), 401 | id="int_tensor_2", 402 | ), 403 | pytest.param( 404 | dltype.DoubleTensor("b c h w"), 405 | np_rand(1, 2, 3, 4).astype(np.float32), 406 | _RaisesInfo(exception_type=dltype.DLTypeDtypeError), 407 | id="double_tensor", 408 | ), 409 | pytest.param( 410 | dltype.DoubleTensor("b c h w"), 411 | np_rand(1, 2, 3, 4).astype(np.double), 412 | _RaisesInfo(), 413 | id="double_tensor_2", 414 | ), 415 | pytest.param( 416 | dltype.DoubleTensor("b c h w"), 417 | np_rand(1, 2, 3, 4).astype(np.float64), 418 | _RaisesInfo(), 419 | id="double_tensor_3", 420 | ), 421 | ], 422 | ) 423 | def test_types( 424 | tensor_type: dltype.TensorTypeBase, 425 | tensor: torch.Tensor | NPFloatArrayT, 426 | expected: _RaisesInfo, 427 | ) -> None: 428 | if 
expected.exception_type: 429 | with pytest.raises(expected.exception_type, match=expected.regex): 430 | tensor_type.check(tensor) 431 | else: 432 | tensor_type.check(tensor) 433 | 434 | 435 | @pytest.mark.parametrize( 436 | ("tensor_type", "tensor", "expected"), 437 | [ 438 | pytest.param( 439 | dltype.FloatTensor("1 2 3 4"), 440 | torch.rand(1, 2, 3, 4), 441 | _RaisesInfo(), 442 | id="int_tensor", 443 | ), 444 | pytest.param( 445 | dltype.FloatTensor("1 2 3 4"), 446 | np_rand(1, 2, 3, 4), 447 | _RaisesInfo(), 448 | id="int_tensor_2", 449 | ), 450 | pytest.param( 451 | dltype.FloatTensor("b c 3 4"), 452 | torch.rand(1, 2, 3, 4), 453 | _RaisesInfo(), 454 | id="mixed literal dims", 455 | ), 456 | pytest.param( 457 | dltype.FloatTensor("1 c h 4"), 458 | torch.rand(1, 2, 3, 4), 459 | _RaisesInfo(), 460 | id="dims before and after", 461 | ), 462 | pytest.param( 463 | dltype.FloatTensor("1 2 3 w"), 464 | torch.rand(2, 2, 3, 9), 465 | _RaisesInfo( 466 | exception_type=dltype.DLTypeShapeError, 467 | regex=bad_dimension_error("anonymous", idx=0, expected=1, actual=2), 468 | ), 469 | id="bad literal dim", 470 | ), 471 | pytest.param( 472 | dltype.FloatTensor("*batch c 2 3"), 473 | torch.rand(1, 2, 2, 3), 474 | _RaisesInfo(), 475 | id="wildcard dim", 476 | ), 477 | pytest.param( 478 | dltype.FloatTensor("*batch c 2 3"), 479 | torch.rand(3, 2, 3, 2), 480 | _RaisesInfo( 481 | exception_type=dltype.DLTypeShapeError, 482 | regex=bad_dimension_error("anonymous", idx=2, expected=2, actual=3), 483 | ), 484 | id="wildcard dim_2", 485 | ), 486 | ], 487 | ) 488 | def test_literal_shapes( 489 | tensor_type: dltype.TensorTypeBase, 490 | tensor: torch.Tensor | NPFloatArrayT, 491 | expected: _RaisesInfo, 492 | ) -> None: 493 | if expected.exception_type is not None: 494 | with pytest.raises(expected.exception_type, match=expected.regex): 495 | tensor_type.check(tensor) 496 | else: 497 | tensor_type.check(tensor) 498 | 499 | 500 | def test_onnx_export() -> None: 501 | class _DummyModule(torch.nn.Module): 502 | @dltype.dltyped() 503 | def forward( 504 | self, 505 | x: Annotated[torch.Tensor, dltype.FloatTensor("b c h w")], 506 | ) -> Annotated[torch.Tensor, dltype.FloatTensor("b c h w")]: 507 | return torch.multiply(x, 2) 508 | 509 | with NamedTemporaryFile() as f: 510 | torch.onnx.export( 511 | _DummyModule(), 512 | (torch.rand(1, 2, 3, 4),), 513 | f.name, 514 | input_names=["input"], 515 | output_names=["output"], 516 | ) 517 | 518 | assert Path(f.name).exists() 519 | assert Path(f.name).stat().st_size > 0 520 | 521 | with pytest.raises(TypeError): 522 | torch.onnx.export( 523 | _DummyModule(), 524 | (torch.rand(1, 2, 3),), 525 | f.name, 526 | input_names=["input"], 527 | output_names=["output"], 528 | ) 529 | 530 | 531 | def test_torch_compile() -> None: 532 | class _DummyModule(torch.nn.Module): 533 | @dltype.dltyped() 534 | def forward( 535 | self, 536 | x: Annotated[torch.Tensor, dltype.FloatTensor("b c h w")], 537 | ) -> Annotated[torch.Tensor, dltype.FloatTensor("b c h w")]: 538 | return torch.multiply(x, 2) 539 | 540 | _DummyModule().forward(torch.rand(1, 2, 3, 4)) 541 | 542 | with pytest.raises(TypeError): 543 | _DummyModule().forward(torch.rand(1, 2, 3)) 544 | 545 | module = torch.compile(_DummyModule()) 546 | 547 | module(torch.rand(1, 2, 3, 4)) 548 | 549 | with pytest.raises(TypeError): 550 | module(torch.rand(1, 2, 3)) 551 | 552 | torch.jit.trace(_DummyModule(), torch.rand(1, 2, 3, 4)) 553 | 554 | scripted_module = torch.jit.script(_DummyModule()) 555 | 556 | scripted_module(torch.rand(1, 2, 3, 4)) 557 
| 558 | with pytest.raises(TypeError): 559 | scripted_module(torch.rand(1, 2, 3)) 560 | 561 | 562 | @dltype.dltyped() 563 | def mixed_func( # noqa: PLR0913 564 | tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["b c h w"]], 565 | array: Annotated[NPFloatArrayT, dltype.TensorTypeBase["b c h w"]], 566 | number: int, 567 | other_tensor: torch.Tensor, 568 | other_array: NPFloatArrayT, 569 | other_number: float, 570 | other_annotated_tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["c c c"]], 571 | ) -> Annotated[torch.Tensor, dltype.TensorTypeBase["h w b c"]]: 572 | return tensor.permute(2, 3, 0, 1) 573 | 574 | 575 | def test_mixed_typing() -> None: 576 | mixed_func( 577 | torch.rand(1, 2, 3, 4), 578 | np_rand(1, 2, 3, 4), 579 | 1, 580 | torch.rand(1, 1, 1, 1), 581 | np_rand(1, 2, 3, 4), 582 | 1.0, 583 | torch.rand(2, 2, 2), 584 | ) 585 | 586 | with pytest.raises(TypeError): 587 | mixed_func( 588 | torch.rand(1, 2, 3, 4), 589 | np_rand(1, 2, 3, 4), 590 | 1, 591 | torch.rand(1, 1, 1, 1), 592 | np_rand(1, 2, 3, 4), 593 | 1.0, 594 | torch.rand(1, 2, 2), 595 | ) 596 | 597 | 598 | def test_bad_argument_name() -> None: 599 | @dltype.dltyped() 600 | def bad_function( 601 | self: Annotated[torch.Tensor, dltype.TensorTypeBase["b c h w"]], 602 | ) -> Annotated[torch.Tensor, dltype.TensorTypeBase["h w b c"]]: 603 | return self 604 | 605 | with pytest.raises(TypeError): 606 | bad_function(torch.rand(1, 2, 3, 4)) 607 | 608 | @dltype.dltyped() 609 | def other_bad_function( 610 | self: int, 611 | tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["b c h w"]], 612 | ) -> Annotated[torch.Tensor, dltype.TensorTypeBase["h w b c"]]: 613 | return tensor 614 | 615 | with pytest.raises(TypeError): 616 | other_bad_function(1, torch.rand(1, 2, 3, 4)) 617 | 618 | 619 | def test_bad_dimension_name() -> None: 620 | with pytest.raises(SyntaxError): 621 | 622 | def bad_function( # pyright: ignore[reportUnusedFunction] 623 | tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["b?"]], 624 | ) -> None: 625 | print(tensor) 626 | 627 | 628 | @dltype.dltyped() 629 | def func_with_expression( 630 | input_tensor: Annotated[torch.Tensor, dltype.FloatTensor["batch channels dim"]], 631 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["batch channels dim-1"]]: 632 | return input_tensor[..., :-1] 633 | 634 | 635 | @dltype.dltyped() 636 | def func_with_min( 637 | input_tensor: Annotated[torch.Tensor, dltype.FloatTensor["batch channels dim"]], 638 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["batch channels min(10,max(1,dim-1))"]]: 639 | if input_tensor.shape[2] == 1: 640 | return input_tensor[...] 
641 | if input_tensor.shape[2] > 10: 642 | return input_tensor[..., :10] 643 | return input_tensor[..., :-1] 644 | 645 | 646 | @pytest.mark.parametrize( 647 | ("tensor", "function", "expected"), 648 | [ 649 | pytest.param( 650 | torch.rand(1, 2, 3), 651 | func_with_expression, 652 | _RaisesInfo(), 653 | id="basic_expression", 654 | ), 655 | pytest.param( 656 | torch.rand(1, 2, 3, 4), 657 | func_with_expression, 658 | _RaisesInfo( 659 | exception_type=dltype.DLTypeNDimsError, 660 | regex=bad_ndim_error("input_tensor", expected=3, actual=4), 661 | ), 662 | id="bad_expression", 663 | ), 664 | pytest.param( 665 | torch.rand(1, 2, 3), 666 | func_with_min, 667 | _RaisesInfo(), 668 | id="min_expression", 669 | ), 670 | pytest.param( 671 | torch.rand(1, 2, 1), 672 | func_with_min, 673 | _RaisesInfo(), 674 | id="min_expression_2", 675 | ), 676 | pytest.param( 677 | torch.rand(1, 2, 11), 678 | func_with_min, 679 | _RaisesInfo(), 680 | id="min_expression_3", 681 | ), 682 | pytest.param( 683 | torch.rand(1, 2), 684 | func_with_min, 685 | _RaisesInfo( 686 | exception_type=dltype.DLTypeNDimsError, 687 | regex=bad_ndim_error("input_tensor", expected=3, actual=2), 688 | ), 689 | id="min_expression_4", 690 | ), 691 | ], 692 | ) 693 | def test_typing_expressions( 694 | tensor: torch.Tensor, 695 | function: Callable[[torch.Tensor], torch.Tensor], 696 | expected: _RaisesInfo, 697 | ) -> None: 698 | if expected.exception_type: 699 | with pytest.raises(expected.exception_type, match=expected.regex): 700 | function(tensor) 701 | else: 702 | function(tensor) 703 | 704 | 705 | def test_expression_syntax_errors() -> None: 706 | with pytest.raises(SyntaxError): 707 | 708 | @dltype.dltyped() 709 | def func_with_bad_expression( # pyright: ignore[reportUnusedFunction] 710 | _: Annotated[torch.Tensor, dltype.FloatTensor["batch channels dim+"]], 711 | ) -> None: 712 | return None 713 | 714 | with pytest.raises(SyntaxError): 715 | 716 | @dltype.dltyped() 717 | def func_with_bad_expression( # pyright: ignore[reportUnusedFunction] 718 | _: Annotated[torch.Tensor, dltype.FloatTensor["+ - * dim"]], 719 | ) -> None: 720 | return None 721 | 722 | with pytest.raises(SyntaxError): 723 | # don't allow multiple min/max calls 724 | @dltype.dltyped() 725 | def func_with_bad_expression( # pyright: ignore[reportUnusedFunction] 726 | _: Annotated[ 727 | torch.Tensor, 728 | dltype.FloatTensor["batch channels min(1,channels-1)+max(channels,dim)"], 729 | ], 730 | ) -> None: 731 | return None 732 | 733 | with pytest.raises(SyntaxError): 734 | # don't allow multiple operators in a row 735 | @dltype.dltyped() 736 | def func_with_bad_expression( # pyright: ignore[reportUnusedFunction] 737 | _: Annotated[torch.Tensor, dltype.FloatTensor["dim dim-*1"]], 738 | ) -> None: 739 | return None 740 | 741 | 742 | @dltype.dltyped() 743 | def func_with_axis_wildcard( 744 | input_tensor: Annotated[torch.Tensor, dltype.FloatTensor["*batch channels h w"]], 745 | _: Annotated[torch.Tensor, dltype.FloatTensor["*batch channels w h"]], 746 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["*batch channels*h*w"]]: 747 | # reshape along the last two dimensions but preserve the first N dimensions and the channel dimension 748 | return input_tensor.view(*input_tensor.shape[:-3], -1) 749 | 750 | 751 | # function with a wildcard in the middle of the dimensions 752 | @dltype.dltyped() 753 | def func_with_mid_tensor_wildcard( 754 | input_tensor: Annotated[torch.Tensor, dltype.FloatTensor["batch *channels h w"]], 755 | fail: bool = False, 756 | ) -> 
Annotated[torch.Tensor, dltype.FloatTensor["batch *channels h*w"]]: 757 | # reshape the tensor to preserve the batch dimension, retain the 758 | # channel dimensions, and flatten the spatial dimensions 759 | if fail: 760 | return input_tensor.view(input_tensor.shape[0] + 1, -1) 761 | return input_tensor.view(*input_tensor.shape[:-2], -1) 762 | 763 | 764 | # function with a wildcard in the middle of the dimensions 765 | @dltype.dltyped() 766 | def func_with_anon_wildcard( 767 | input_tensor: Annotated[torch.Tensor, dltype.FloatTensor["... h w"]], 768 | fail: bool = False, 769 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["N h*w"]]: 770 | # take any number of dimensions > 2 and preserve the last two dimensions but flatten the rest 771 | if fail: 772 | return input_tensor.view(-1, input_tensor.shape[-1], input_tensor.shape[-2]) 773 | return input_tensor.view(-1, input_tensor.shape[-2] * input_tensor.shape[-1]) 774 | 775 | 776 | @pytest.mark.parametrize( 777 | ("tensor", "maybe_tensor_b", "func", "expected"), 778 | [ 779 | pytest.param( 780 | torch.rand(1, 2, 3, 4, 5), 781 | torch.rand(1, 2, 3, 5, 4), 782 | func_with_axis_wildcard, 783 | _RaisesInfo(), 784 | id="wildcard_axis", 785 | ), 786 | pytest.param( 787 | torch.rand(1, 2, 3, 4), 788 | torch.rand(1, 2, 4, 3), 789 | func_with_axis_wildcard, 790 | _RaisesInfo(), 791 | id="wildcard_axis_2", 792 | ), 793 | pytest.param( 794 | torch.rand(1, 2, 3), 795 | torch.rand(1, 2, 3), 796 | func_with_axis_wildcard, 797 | _RaisesInfo( 798 | exception_type=dltype.DLTypeShapeError, 799 | regex=bad_dimension_error("_", idx=1, expected=3, actual=2), 800 | ), 801 | id="wildcard_axis_3", 802 | ), 803 | pytest.param( 804 | torch.rand(1, 2, 3), 805 | None, 806 | func_with_mid_tensor_wildcard, 807 | _RaisesInfo(), 808 | id="mid_tensor_wildcard", 809 | ), 810 | pytest.param( 811 | torch.rand(1, 2, 3, 4), 812 | None, 813 | func_with_mid_tensor_wildcard, 814 | _RaisesInfo(), 815 | id="mid_tensor_wildcard_2", 816 | ), 817 | pytest.param( 818 | torch.rand(1, 2, 3, 4, 5), 819 | None, 820 | func_with_mid_tensor_wildcard, 821 | _RaisesInfo(), 822 | id="mid_tensor_wildcard_2", 823 | ), 824 | pytest.param( 825 | torch.rand(1, 2), 826 | None, 827 | func_with_mid_tensor_wildcard, 828 | _RaisesInfo( 829 | exception_type=dltype.DLTypeNDimsError, 830 | regex=bad_ndim_error("input_tensor", expected=3, actual=2), 831 | ), 832 | id="mid-tensor not enough dims", 833 | ), 834 | pytest.param( 835 | torch.rand(1, 2, 3, 4), 836 | True, 837 | func_with_mid_tensor_wildcard, 838 | _RaisesInfo( 839 | exception_type=dltype.DLTypeShapeError, 840 | regex=bad_dimension_error("return", idx=0, actual=2, expected=1), 841 | ), 842 | id="mid-tensor bad function impl", 843 | ), 844 | pytest.param( 845 | torch.rand(1, 2, 3), 846 | None, 847 | func_with_anon_wildcard, 848 | _RaisesInfo(), 849 | id="anon_wildcard", 850 | ), 851 | pytest.param( 852 | torch.rand(1, 2, 3, 4, 5, 6), 853 | None, 854 | func_with_anon_wildcard, 855 | _RaisesInfo(), 856 | id="anon_wildcard_2", 857 | ), 858 | pytest.param( 859 | torch.rand(1, 2), 860 | None, 861 | func_with_anon_wildcard, 862 | _RaisesInfo(), 863 | id="anon_wildcard_3", 864 | ), 865 | pytest.param( 866 | torch.rand(1, 2, 3, 4), 867 | True, 868 | func_with_anon_wildcard, 869 | _RaisesInfo( 870 | exception_type=dltype.DLTypeNDimsError, 871 | regex=bad_ndim_error("return", expected=2, actual=3), 872 | ), 873 | id="anon wildcard fail", 874 | ), 875 | ], 876 | ) 877 | def test_multiaxis_support( 878 | tensor: torch.Tensor, 879 | maybe_tensor_b: torch.Tensor | 
bool | None, 880 | func: Callable[[torch.Tensor, torch.Tensor | bool | None], torch.Tensor], 881 | expected: _RaisesInfo, 882 | ) -> None: 883 | if expected.exception_type: 884 | with pytest.raises(expected.exception_type, match=expected.regex): 885 | func(tensor, maybe_tensor_b) 886 | else: 887 | func(tensor, maybe_tensor_b) 888 | 889 | 890 | # function with a wildcard in the middle of the dimensions 891 | @dltype.dltyped() 892 | def func_with_two_anon_wildcards( 893 | input_tensor: Annotated[torch.Tensor, dltype.FloatTensor["... h w"]], 894 | _: None = None, 895 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["... h w"]]: 896 | # concatenate the tensor with itself along the batch dimension 897 | return torch.cat([input_tensor, input_tensor], dim=0) 898 | 899 | 900 | @dltype.dltyped() 901 | def bad_func_with_two_named_wildcards( 902 | input_tensor: Annotated[torch.Tensor, dltype.FloatTensor["*batch h w"]], 903 | _: None = None, 904 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["*batch h w"]]: 905 | # concatenate the tensor with itself along the batch dimension, this should fail every time because the 906 | # batch dimension is different 907 | return torch.cat([input_tensor, input_tensor], dim=0) 908 | 909 | 910 | @dltype.dltyped() 911 | def func_with_named_wildcard_followed_by_literal( 912 | input_tensor: Annotated[torch.Tensor, dltype.FloatTensor["*batch 1 h w"]], 913 | _: None = None, 914 | ) -> None: 915 | # this should work because the batch dimension is the same 916 | return None 917 | 918 | 919 | def test_anonymous_wildcard_arg_and_return() -> None: 920 | func_with_two_anon_wildcards(torch.rand(1, 2, 3)) 921 | input_t = torch.rand(1, 2, 3, 4, 5, 6) 922 | input_shape = input_t.shape 923 | result = func_with_two_anon_wildcards(input_t) 924 | # test that anonymous dimensions aren't matched 925 | assert result.shape[0] != input_shape[0] 926 | 927 | with pytest.raises(TypeError): 928 | bad_func_with_two_named_wildcards(torch.rand(1, 2, 3)) 929 | 930 | with pytest.raises(TypeError): 931 | bad_func_with_two_named_wildcards(torch.rand(1, 2, 3, 4, 5, 6)) 932 | 933 | func_with_named_wildcard_followed_by_literal(torch.rand(1, 1, 1, 3, 4)) 934 | func_with_named_wildcard_followed_by_literal(torch.rand(4, 3, 2, 1, 4, 5)) 935 | func_with_named_wildcard_followed_by_literal(torch.rand(1, 2, 3)) 936 | 937 | with pytest.raises(TypeError): 938 | func_with_named_wildcard_followed_by_literal(torch.rand(4, 3, 2, 2, 4, 5)) 939 | 940 | with pytest.raises(TypeError): 941 | func_with_named_wildcard_followed_by_literal(torch.rand(3, 2, 1)) 942 | 943 | 944 | def test_multiaxis_syntax() -> None: 945 | # fail if we have multiple wildcards 946 | with pytest.raises(SyntaxError): 947 | dltype.FloatTensor("batch *channels *h w") 948 | 949 | with pytest.raises(SyntaxError): 950 | dltype.FloatTensor("... *channels h w") 951 | 952 | with pytest.raises(SyntaxError): 953 | dltype.FloatTensor("...... h w") 954 | 955 | with pytest.raises(SyntaxError): 956 | dltype.FloatTensor("batch *channels h w *channels") 957 | 958 | with pytest.raises(SyntaxError): 959 | dltype.FloatTensor("*... 
h w") 960 | 961 | with pytest.raises(SyntaxError): 962 | dltype.FloatTensor("...batch h w") 963 | 964 | 965 | def test_named_axis() -> None: 966 | dltype.FloatTensor("batch channels dim=1") 967 | dltype.FloatTensor("batch channels dim=min(batch,channels)") 968 | dltype.FloatTensor("batch channels dim=max(batch-1,channels)") 969 | 970 | dltype.FloatTensor("batch channels dim=channels-1") 971 | with pytest.raises(SyntaxError): 972 | dltype.FloatTensor("batch channels=channels-1") 973 | 974 | with pytest.raises(SyntaxError): 975 | dltype.FloatTensor("batch channels dim=min(dim,channels)") 976 | 977 | @dltype.dltyped() 978 | def func_with_misordered_identifier( 979 | tensor: Annotated[ 980 | torch.Tensor, 981 | dltype.FloatTensor["batch dim=channels channels=4"], 982 | ], 983 | ) -> None: 984 | return None 985 | 986 | with pytest.raises(dltype.DLTypeNDimsError): 987 | func_with_misordered_identifier(torch.rand(1, 2, 3, 4)) 988 | 989 | with pytest.raises(dltype.DLTypeInvalidReferenceError): 990 | func_with_misordered_identifier(torch.rand(1, 3, 4)) 991 | 992 | with pytest.raises(dltype.DLTypeInvalidReferenceError): 993 | func_with_misordered_identifier(torch.rand(1, 4, 4)) 994 | 995 | 996 | # mock max_acceptable_eval_time to zero to ensure we issue a warning if the context takes too long 997 | def test_warn_on_function_evaluation() -> None: 998 | with patch("dltype._lib._dltype_context._maybe_warn_runtime", return_value=True): 999 | 1000 | @dltype.dltyped() 1001 | def dummy_function( 1002 | tensor: Annotated[torch.Tensor, dltype.FloatTensor["batch channels h w"]], 1003 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["batch channels h w"]]: 1004 | return tensor 1005 | 1006 | with pytest.warns(UserWarning, match="Type checking took longer than expected"): 1007 | dummy_function(torch.rand(1, 2, 3, 4)) 1008 | 1009 | 1010 | def test_debug_mode_not_enabled() -> None: 1011 | if dltype.DEBUG_MODE: 1012 | pytest.fail("DEBUG_MODE should not be enabled by default") 1013 | 1014 | 1015 | def test_incompatible_tensor_type() -> None: 1016 | with pytest.raises(TypeError): 1017 | 1018 | @dltype.dltyped() 1019 | def bad_function( # pyright: ignore[reportUnusedFunction] 1020 | tensor: Annotated[list[int], dltype.IntTensor["b c h w"]], 1021 | ) -> list[int]: 1022 | return tensor 1023 | 1024 | 1025 | def test_dimension_name_with_underscores() -> None: 1026 | @dltype.dltyped() 1027 | def good_function( # pyright: ignore[reportUnusedFunction] 1028 | tensor: Annotated[ 1029 | torch.Tensor, 1030 | dltype.IntTensor["batch channels_in channels_out"], 1031 | ], 1032 | ) -> torch.Tensor: 1033 | return tensor 1034 | 1035 | 1036 | def test_dimension_with_external_scope() -> None: 1037 | class Provider: 1038 | def get_dltype_scope(self) -> dict[str, int]: 1039 | return {"channels_in": 3, "channels_out": 4} 1040 | 1041 | @dltype.dltyped(scope_provider="self") 1042 | def forward( 1043 | self, 1044 | tensor: Annotated[ 1045 | torch.Tensor, 1046 | dltype.FloatTensor["batch channels_in channels_out"], 1047 | ], 1048 | ) -> torch.Tensor: 1049 | return tensor 1050 | 1051 | @dltype.dltyped(scope_provider=Provider()) 1052 | def good_function( 1053 | tensor: Annotated[ 1054 | torch.Tensor, 1055 | dltype.IntTensor["batch channels_in channels_out"], 1056 | ], 1057 | ) -> torch.Tensor: 1058 | return tensor 1059 | 1060 | good_function(torch.ones(1, 3, 4).int()) 1061 | good_function(torch.ones(4, 3, 4).int()) 1062 | 1063 | with pytest.raises(dltype.DLTypeShapeError): 1064 | good_function(torch.ones(1, 3, 5).int()) 1065 | with 
pytest.raises(dltype.DLTypeShapeError): 1066 | good_function(torch.ones(1, 2, 4).int()) 1067 | 1068 | provider = Provider() 1069 | 1070 | provider.forward(torch.ones(1, 3, 4)) 1071 | provider.forward(torch.ones(4, 3, 4)) 1072 | 1073 | with pytest.raises(dltype.DLTypeShapeError): 1074 | provider.forward(torch.ones(1, 3, 5)) 1075 | with pytest.raises(dltype.DLTypeShapeError): 1076 | provider.forward(torch.ones(1, 2, 4)) 1077 | 1078 | 1079 | def test_optional_type_handling() -> None: 1080 | """Test that dltyped correctly handles Optional tensor types.""" 1081 | 1082 | # Test with a function with optional parameter 1083 | @dltype.dltyped() 1084 | def optional_tensor_func( 1085 | tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] | None, 1086 | ) -> torch.Tensor: 1087 | if tensor is None: 1088 | return torch.zeros(1, 3, 5, 5) 1089 | return tensor 1090 | 1091 | # Should work with None 1092 | result = optional_tensor_func(None) 1093 | assert result.shape == (1, 3, 5, 5) 1094 | 1095 | # Should work with correct tensor 1096 | input_tensor = torch.rand(2, 3, 4, 4) 1097 | torch.testing.assert_close(optional_tensor_func(input_tensor), input_tensor) 1098 | 1099 | # Should fail with incorrect shape 1100 | with pytest.raises(dltype.DLTypeNDimsError): 1101 | optional_tensor_func(torch.rand(2, 3, 4)) 1102 | 1103 | # Test with a function that returns an optional tensor 1104 | @dltype.dltyped() 1105 | def return_optional_tensor( 1106 | *, 1107 | return_none: bool, 1108 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] | None: 1109 | if return_none: 1110 | return None 1111 | return torch.rand(2, 3, 4, 4) 1112 | 1113 | # Should work with either None or tensor 1114 | assert return_optional_tensor(return_none=True) is None 1115 | assert return_optional_tensor(return_none=False) is not None 1116 | 1117 | # Test rejection of non-Optional unions 1118 | with pytest.raises(TypeError, match="Only Optional tensor types are supported"): 1119 | 1120 | @dltype.dltyped() 1121 | def union_tensor_func( # pyright: ignore[reportUnusedFunction] 1122 | tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] 1123 | | Annotated[torch.Tensor, dltype.IntTensor["b c"]], 1124 | ) -> None: 1125 | pass 1126 | 1127 | # Test optional with classes 1128 | class ModelWithOptional(torch.nn.Module): 1129 | @dltype.dltyped() 1130 | def forward( 1131 | self, 1132 | x: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] | None, 1133 | mask: Annotated[torch.Tensor, dltype.BoolTensor["b h w"]] | None = None, 1134 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] | None: 1135 | if x is None: 1136 | return None 1137 | if mask is not None: 1138 | return x * mask.unsqueeze(1) 1139 | return x 1140 | 1141 | model = ModelWithOptional() 1142 | 1143 | # Test with various input combinations 1144 | x = torch.rand(2, 3, 4, 4) 1145 | mask = torch.randint(0, 2, (2, 4, 4)).bool() 1146 | 1147 | torch.testing.assert_close(model(x), x) 1148 | torch.testing.assert_close(model(x, mask), x * mask.unsqueeze(1)) 1149 | assert model(None) is None 1150 | assert model(None, mask) is None 1151 | 1152 | # Should still validate shapes when tensors are provided 1153 | with pytest.raises(dltype.DLTypeNDimsError): 1154 | model(torch.rand(2, 3, 4), None) 1155 | 1156 | with pytest.raises(dltype.DLTypeShapeError): 1157 | model(x, torch.randint(0, 2, (2, 5, 5)).bool()) 1158 | 1159 | 1160 | def test_named_tuple_handling() -> None: 1161 | @dltype.dltyped_namedtuple() 1162 | class MyNamedTuple(NamedTuple): 1163 | tensor: 
Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] 1164 | mask: Annotated[torch.Tensor, dltype.IntTensor["b h w"]] 1165 | other: int 1166 | 1167 | MyNamedTuple(torch.rand(2, 3, 4, 4), torch.randint(0, 2, (2, 4, 4)), 1) 1168 | 1169 | with pytest.raises(dltype.DLTypeNDimsError): 1170 | MyNamedTuple(torch.rand(2, 3, 4), torch.randint(0, 2, (2, 4, 4)), 1) 1171 | 1172 | with pytest.raises(dltype.DLTypeDtypeError): 1173 | MyNamedTuple(torch.rand(2, 3, 4, 4), torch.randint(0, 2, (2, 4, 4)).bool(), 1) 1174 | 1175 | # test a named tuple with an optional field 1176 | 1177 | @dltype.dltyped_namedtuple() 1178 | class MyOptionalNamedTuple(NamedTuple): 1179 | tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] 1180 | mask: Annotated[torch.Tensor, dltype.IntTensor["b h w"]] | None 1181 | 1182 | MyOptionalNamedTuple(torch.rand(2, 3, 4, 4), None) 1183 | MyOptionalNamedTuple(torch.rand(2, 3, 4, 4), torch.randint(0, 2, (2, 4, 4))) 1184 | 1185 | with pytest.raises(dltype.DLTypeNDimsError): 1186 | MyOptionalNamedTuple(torch.rand(2, 3, 4), None) 1187 | 1188 | with pytest.raises(dltype.DLTypeDtypeError): 1189 | MyOptionalNamedTuple( 1190 | torch.rand(2, 3, 4, 4), 1191 | torch.randint(0, 2, (2, 4, 4)).bool(), 1192 | ) 1193 | 1194 | 1195 | def test_annotated_dataclass() -> None: 1196 | """Test that dltyped correctly handles Annotated dataclasses.""" 1197 | 1198 | # Test with a function with annotated dataclass 1199 | @dltype.dltyped_dataclass() 1200 | @dataclass(frozen=True, kw_only=True, slots=True) 1201 | class AnnotatedDataclass: 1202 | tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] 1203 | tensor_2: Annotated[torch.Tensor, dltype.IntTensor["b c h w"]] 1204 | other_thing: int 1205 | un_annotated_tensor: torch.Tensor 1206 | 1207 | AnnotatedDataclass( 1208 | tensor=torch.rand(1, 2, 3, 4), 1209 | tensor_2=torch.randint(0, 10, (1, 2, 3, 4)), 1210 | other_thing=5, 1211 | un_annotated_tensor=torch.rand(1, 2, 3, 4), 1212 | ) 1213 | 1214 | with pytest.raises(dltype.DLTypeShapeError): 1215 | AnnotatedDataclass( 1216 | tensor=torch.rand(1, 2, 3, 4), 1217 | tensor_2=torch.randint(0, 10, (1, 2, 3, 5)), 1218 | other_thing=5, 1219 | un_annotated_tensor=torch.rand(1, 2, 3, 4), 1220 | ) 1221 | 1222 | # test that the order of the decorator matters 1223 | 1224 | with pytest.raises( 1225 | TypeError, 1226 | match=r"Class AnnotatedDataclass2 is not a dataclass, apply @dataclass below dltyped_dataclass.", 1227 | ): 1228 | 1229 | @dataclass(frozen=True, kw_only=True, slots=True) 1230 | @dltype.dltyped_dataclass() 1231 | class AnnotatedDataclass2: # pyright: ignore[reportUnusedClass] 1232 | pass 1233 | 1234 | # test with optional fields 1235 | 1236 | @dltype.dltyped_dataclass() 1237 | @dataclass(frozen=True, kw_only=True, slots=True) 1238 | class OptionalAnnotatedDataclass: 1239 | tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] | None 1240 | tensor_2: Annotated[torch.Tensor, dltype.IntTensor["b c h w"]] | None = None 1241 | other_thing: int = 0 1242 | 1243 | OptionalAnnotatedDataclass(tensor=None) 1244 | OptionalAnnotatedDataclass(tensor=None, tensor_2=None) 1245 | OptionalAnnotatedDataclass(tensor=torch.rand(1, 2, 3, 4)) 1246 | OptionalAnnotatedDataclass( 1247 | tensor=torch.rand(1, 2, 3, 4), 1248 | tensor_2=torch.randint(0, 10, (1, 2, 3, 4)), 1249 | ) 1250 | 1251 | with pytest.raises(dltype.DLTypeNDimsError): 1252 | OptionalAnnotatedDataclass(tensor=torch.rand(1, 2, 3)) 1253 | 1254 | with pytest.raises(dltype.DLTypeDtypeError): 1255 | OptionalAnnotatedDataclass( 1256 | 
tensor=torch.rand(1, 2, 3, 4), 1257 | tensor_2=torch.rand(1, 2, 3, 4).bool(), 1258 | ) 1259 | 1260 | 1261 | def test_improper_base_model_construction() -> None: 1262 | """Test that improper construction of BaseModel raises an error.""" 1263 | with pytest.raises(dltype.DLTypeDtypeError, match=r"Invalid dtype"): 1264 | 1265 | class _BadModel(BaseModel): # pyright: ignore[reportUnusedClass] 1266 | tensor: Annotated[npt.NDArray[np.float32], dltype.IntTensor["b c h w"]] 1267 | 1268 | with pytest.raises(dltype.DLTypeDtypeError, match=r"Invalid dtype"): 1269 | 1270 | class _BadModel2(BaseModel): # pyright: ignore[reportUnusedClass] 1271 | tensor: Annotated[ 1272 | npt.NDArray[np.float32 | np.float64], 1273 | dltype.IntTensor["b c h w"], 1274 | ] 1275 | 1276 | class _GoodModel(BaseModel): # pyright: ignore[reportUnusedClass] 1277 | tensor: Annotated[npt.NDArray[np.int32], dltype.IntTensor["b c h w"]] 1278 | 1279 | class _GoodModel2(BaseModel): # pyright: ignore[reportUnusedClass] 1280 | tensor: Annotated[npt.NDArray[np.int8 | np.int32], dltype.IntTensor["b c h w"]] 1281 | 1282 | 1283 | class _MyClass: 1284 | def __init__(self, tensor: torch.Tensor) -> None: 1285 | self.tensor = tensor 1286 | 1287 | @classmethod 1288 | @dltype.dltyped() 1289 | def create( 1290 | cls, 1291 | tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]], 1292 | ) -> "_MyClass": 1293 | return cls(tensor) 1294 | 1295 | 1296 | def test_class_with_forward_reference() -> None: 1297 | """Test that a forward reference to the enclosing class resolves at module scope but only warns for inner classes.""" 1298 | _MyClass.create(torch.rand(1, 2, 3, 4)) 1299 | 1300 | with pytest.raises(dltype.DLTypeNDimsError): 1301 | _MyClass.create(torch.rand(1, 2, 3)) 1302 | 1303 | # inner classes do not evaluate the forward reference correctly, warn the user 1304 | 1305 | class _InnerClass: 1306 | def __init__(self, tensor: torch.Tensor) -> None: 1307 | self.tensor = tensor 1308 | 1309 | @classmethod 1310 | @dltype.dltyped() 1311 | def create( 1312 | cls, 1313 | tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]], 1314 | ) -> "_InnerClass": 1315 | return cls(tensor) 1316 | 1317 | with pytest.warns( 1318 | UserWarning, 1319 | match="Unable to determine signature of dltyped function, type checking will be skipped.", 1320 | ): 1321 | _InnerClass.create(torch.rand(1, 2, 3, 4)) 1322 | 1323 | 1324 | def test_warning_if_decorator_has_no_annotations_to_check() -> None: 1325 | with pytest.warns( 1326 | UserWarning, 1327 | match="No DLType hints found, skipping type checking", 1328 | ): 1329 | 1330 | @dltype.dltyped() 1331 | def no_annotations(tensor: torch.Tensor) -> torch.Tensor: # pyright: ignore[reportUnusedFunction] 1332 | return tensor 1333 | 1334 | # should warn if some tensors are untyped 1335 | @dltype.dltyped() 1336 | def some_annotations( 1337 | tensor: Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]], 1338 | ) -> torch.Tensor: 1339 | return tensor 1340 | 1341 | with pytest.warns( 1342 | UserWarning, 1343 | match=re.escape("[return] is missing a DLType hint"), 1344 | ): 1345 | some_annotations(torch.rand(1, 2, 3)) 1346 | 1347 | 1348 | def test_scalar() -> None: 1349 | """Test that dltyped correctly handles scalar types.""" 1350 | 1351 | @dltype.dltyped() 1352 | def scalar_func( 1353 | x: Annotated[torch.Tensor, dltype.FloatTensor[None]], 1354 | ) -> Annotated[torch.Tensor, dltype.FloatTensor[None]]: 1355 | return x 1356 | 1357 | # Should work with a scalar tensor 1358 | scalar_func(torch.tensor(3.14)) 1359 | 1360 | # Should fail with a non-scalar tensor 1361 | with
pytest.raises(dltype.DLTypeNDimsError): 1362 | scalar_func(torch.tensor([3.14])) 1363 | 1364 | with pytest.raises(SyntaxError, match="Invalid shape shape_string=''"): 1365 | Annotated[torch.Tensor, dltype.FloatTensor[""]] 1366 | 1367 | 1368 | def test_signed_vs_unsigned() -> None: 1369 | """Test signed vs unsigned errors are handled correctly.""" 1370 | 1371 | @dltype.dltyped() 1372 | def signed_vs_unsigned( 1373 | x: Annotated[NPIntArrayT, dltype.SignedIntTensor["x"]], 1374 | y: Annotated[NPIntArrayT, dltype.UnsignedIntTensor["x"]], 1375 | ) -> Annotated[torch.Tensor, dltype.IntTensor["x"]]: 1376 | return torch.from_numpy((x * y).astype(np.uint8)) 1377 | 1378 | # should work nominally 1379 | 1380 | np.testing.assert_allclose( 1381 | signed_vs_unsigned( 1382 | np.array([6], dtype=np.int32), # pyright: ignore[reportUnknownArgumentType] 1383 | np.array([8], dtype=np.uint32), 1384 | ).numpy(), 1385 | np.array([48], dtype=np.uint8), 1386 | ) 1387 | 1388 | # Should fail with a bad signed tensor 1389 | with pytest.raises(dltype.DLTypeDtypeError): 1390 | signed_vs_unsigned( 1391 | np.array([6], dtype=np.uint32), 1392 | np.array([8], dtype=np.uint32), 1393 | ) 1394 | 1395 | with pytest.raises(dltype.DLTypeDtypeError): 1396 | signed_vs_unsigned(np.array([6], dtype=np.int32), np.array([8], dtype=np.int32)) 1397 | 1398 | 1399 | def test_bit_widths() -> None: 1400 | """Test bit width errors are handled correctly.""" 1401 | 1402 | @dltype.dltyped() 1403 | def various_bit_widths( 1404 | x: Annotated[NPIntArrayT, dltype.UInt16Tensor["x"]], 1405 | y: Annotated[torch.Tensor, dltype.Int64Tensor["x"]], 1406 | ) -> Annotated[NPIntArrayT, dltype.UInt8Tensor["x"]]: 1407 | return (x + y.numpy()).astype(np.uint8) 1408 | 1409 | # should work nominally 1410 | 1411 | np.testing.assert_allclose( 1412 | various_bit_widths( 1413 | np.array([6], dtype=np.uint16), 1414 | torch.tensor([8], dtype=torch.int64), 1415 | ), 1416 | np.array([14], dtype=np.uint8), 1417 | ) 1418 | 1419 | # Should fail with a bad width on a numpy tensor 1420 | with pytest.raises(dltype.DLTypeDtypeError): 1421 | various_bit_widths( 1422 | np.array([6], dtype=np.uint32), 1423 | torch.tensor([8], dtype=torch.int64), 1424 | ) 1425 | 1426 | # Should fail with a bad width on a torch tensor 1427 | with pytest.raises(dltype.DLTypeDtypeError): 1428 | various_bit_widths( 1429 | np.array([6], dtype=np.uint16), 1430 | torch.tensor([8], dtype=torch.int32), 1431 | ) 1432 | 1433 | 1434 | def test_invalid_tensor_type_handling() -> None: 1435 | with pytest.raises(dltype.DLTypeUnsupportedTensorTypeError): 1436 | good_function([1, 2, 3]) # type: ignore (intentionally bypass static type checking) 1437 | 1438 | 1439 | ShapedTensorT: TypeAlias = Annotated[torch.Tensor, dltype.Float16Tensor["1 2 3"]] 1440 | 1441 | 1442 | def test_type_alias() -> None: 1443 | @dltype.dltyped() 1444 | def function(tensor: ShapedTensorT) -> None: 1445 | print(tensor) 1446 | 1447 | function(torch.empty((1, 2, 3), dtype=torch.float16)) 1448 | 1449 | with pytest.raises(dltype.DLTypeDtypeError): 1450 | function(torch.empty(1, 2, 3, dtype=torch.float32)) 1451 | 1452 | 1453 | ShapeT: TypeAlias = dltype.Shape[1, 2, ..., dltype.VariableAxis("last")] 1454 | RGB: Final = dltype.ConstantAxis("RGB", 3) 1455 | Batch: Final = dltype.AnonymousAxis("batch") 1456 | ImgH: Final = dltype.VariableAxis("ImgH") 1457 | ImgW: Final = dltype.VariableAxis("ImgW") 1458 | 1459 | ImgBatch: TypeAlias = dltype.Shape[Batch, RGB, ImgH, ImgW] 1460 | 1461 | 1462 | def test_shaped_tensor() -> None: 1463 | @dltype.dltyped() 
1464 | def func( 1465 | tensor: Annotated[torch.Tensor, dltype.FloatTensor[ShapeT]], 1466 | fail: bool = False, 1467 | ) -> Annotated[torch.Tensor, dltype.FloatTensor["1 2 ... last"]]: 1468 | return tensor if not fail else torch.empty(1, 2, 99) 1469 | 1470 | assert str(ImgBatch) == "*batch RGB=3 ImgH ImgW" 1471 | 1472 | @dltype.dltyped() 1473 | def func2( 1474 | arg: Annotated[torch.Tensor, dltype.UInt8Tensor[ImgBatch]], 1475 | ) -> Annotated[ 1476 | torch.Tensor, 1477 | dltype.UInt8Tensor[ 1478 | dltype.Shape[ 1479 | Batch, 1480 | ImgH * ImgW, 1481 | RGB, 1482 | ] 1483 | ], 1484 | ]: 1485 | if arg.ndim > 3: 1486 | return arg.view(*arg.shape[:-3], arg.shape[-2] * arg.shape[-1], 3) 1487 | 1488 | return arg.view(arg.shape[-2] * arg.shape[-1], 3) 1489 | 1490 | func(torch.empty(1, 2, 9)) 1491 | func(torch.empty(1, 2, 3, 9)) 1492 | 1493 | func2(torch.empty(3, 4, 5, dtype=torch.uint8)) 1494 | func2(torch.empty(1, 3, 4, 5, dtype=torch.uint8)) 1495 | func2(torch.empty(1, 2, 3, 4, 5, dtype=torch.uint8)) 1496 | 1497 | with pytest.raises(dltype.DLTypeShapeError): 1498 | func(torch.empty(0, 2, 8)) 1499 | 1500 | with pytest.raises(dltype.DLTypeShapeError): 1501 | func(torch.empty(1, 2, 4), True) 1502 | 1503 | class PydanticObj(BaseModel): 1504 | tensor_a: Annotated[torch.Tensor, dltype.FloatTensor[dltype.Shape[4, 4]]] 1505 | 1506 | PydanticObj(tensor_a=torch.zeros((4, 4))) 1507 | 1508 | with pytest.raises(dltype.DLTypeShapeError): 1509 | PydanticObj(tensor_a=torch.ones((3, 3))) 1510 | 1511 | 1512 | def test_return_tuple() -> None: 1513 | @dltype.dltyped() 1514 | def func( 1515 | arg: Annotated[torch.Tensor, dltype.FloatTensor["batch channels"]], 1516 | *, 1517 | fail: bool = False, 1518 | ) -> tuple[ 1519 | Annotated[torch.Tensor, dltype.FloatTensor["channels batch"]], 1520 | Annotated[torch.Tensor, dltype.FloatTensor["1 channels batch"]], 1521 | int, 1522 | ]: 1523 | return ( 1524 | (arg.permute(1, 0), arg.permute(1, 0).unsqueeze(0), 2) 1525 | if not fail 1526 | else (arg.permute(1, 0), arg.unsqueeze(0), 2) 1527 | ) 1528 | 1529 | func(torch.zeros(1, 3)) 1530 | 1531 | with pytest.raises(dltype.DLTypeShapeError): 1532 | func(torch.zeros(1, 3), fail=True) 1533 | 1534 | 1535 | def test_pass_tuple() -> None: 1536 | @dltype.dltyped() 1537 | def func( 1538 | arg: tuple[ 1539 | Annotated[torch.Tensor, dltype.FloatTensor["x x y"]], 1540 | Annotated[torch.Tensor, dltype.FloatTensor["y y x"]], 1541 | int, 1542 | ], 1543 | ) -> None: 1544 | pass 1545 | 1546 | func((torch.zeros(1, 1, 3), torch.zeros(3, 3, 1), 1)) 1547 | 1548 | with pytest.raises(dltype.DLTypeShapeError): 1549 | func((torch.zeros(1, 1, 3), torch.zeros(3, 2, 1), 1)) 1550 | --------------------------------------------------------------------------------
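The tests above double as usage documentation for `@dltype.dltyped()`: named axes are shared across every annotated tensor in a signature, and a mismatch is reported at call time. A minimal, self-contained sketch of that core pattern follows; it is illustrative only (not part of the test suite) and assumes only the `dltype` names already exercised in the tests above.

```python
# Illustrative sketch of the pattern exercised by the tests above; not part of the repository.
from typing import Annotated

import torch

import dltype


@dltype.dltyped()
def pairwise_difference(
    a: Annotated[torch.Tensor, dltype.FloatTensor["batch dim"]],
    b: Annotated[torch.Tensor, dltype.FloatTensor["batch dim"]],
) -> Annotated[torch.Tensor, dltype.FloatTensor["batch dim"]]:
    # "batch" and "dim" must agree across both arguments and the return value
    return a - b


pairwise_difference(torch.rand(4, 3), torch.rand(4, 3))  # shapes agree, passes
# pairwise_difference(torch.rand(4, 3), torch.rand(5, 3))  # mismatched "batch" would be expected to raise dltype.DLTypeShapeError
```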