├── .coveragerc ├── .flake8 ├── .github └── workflows │ ├── build.yml │ ├── coverage.yml │ ├── lint.yml │ ├── mypy.yml │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── deprecated.md ├── jax_dataclasses ├── __init__.py ├── _copy_and_mutate.py ├── _dataclasses.py ├── _enforced_annotations.py ├── _get_type_hints.py ├── _jit.py └── py.typed ├── mypy.ini ├── setup.py └── tests ├── conftest.py ├── test_annotated_arrays.py ├── test_copy_and_mutate.py ├── test_cycle.py ├── test_dataclass.py ├── test_jit_ignore_py37.py ├── test_serialization.py ├── test_variadic_generic_py312.py └── test_vmap.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | # Have to re-enable the standard pragma 4 | pragma: no cover 5 | 6 | # Don't compute coverage for abstract methods, properties 7 | @abstract 8 | @abc\.abstract 9 | 10 | # or warnings 11 | warnings 12 | 13 | # or empty function bodies 14 | pass 15 | \.\.\. 16 | 17 | # or typing imports 18 | TYPE_CHECKING 19 | 20 | # or assert statements 21 | assert 22 | 23 | # or anything that's not implemented 24 | NotImplementedError() 25 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # E203: whitespace before : 3 | # E501: line too long ( characters) 4 | # W503: line break before binary operator 5 | ; ignore = E203,E501,D100,D101,D102,D103,W503 6 | ignore = E203,E501,W503 7 | per-file-ignores = __init__.py:F401,F403 8 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-22.04 12 | strategy: 13 | matrix: 14 | python-version: ["3.9", 
"3.10", "3.11", "3.12"] 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v1 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install -e ".[testing]" 26 | - name: Test with pytest 27 | run: | 28 | pytest 29 | -------------------------------------------------------------------------------- /.github/workflows/coverage.yml: -------------------------------------------------------------------------------- 1 | name: coverage 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | coverage: 11 | runs-on: ubuntu-22.04 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python 3.12 15 | uses: actions/setup-python@v1 16 | with: 17 | python-version: 3.12 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install -e ".[testing]" 22 | - name: Generate coverage report 23 | run: | 24 | pytest --cov=jax_dataclasses --cov-report=xml 25 | - name: Upload to Codecov 26 | uses: codecov/codecov-action@v1 27 | with: 28 | token: ${{ secrets.CODECOV_TOKEN }} 29 | file: ./coverage.xml 30 | flags: unittests 31 | name: codecov-umbrella 32 | fail_ci_if_error: true 33 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | black-check: 11 | runs-on: ubuntu-22.04 12 | steps: 13 | - uses: actions/checkout@v1 14 | - name: Black Code Formatter 15 | uses: lgeiger/black-action@master 16 | with: 17 | args: ". 
--check" 18 | -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: -------------------------------------------------------------------------------- 1 | name: mypy 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | mypy: 11 | runs-on: ubuntu-22.04 12 | strategy: 13 | matrix: 14 | python-version: ["3.12"] 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v1 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install -e . 26 | pip install mypy 27 | - name: Test with mypy 28 | run: | 29 | mypy . 30 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-22.04 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v1 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | *.egg-info 4 | __pycache__ 5 | .mypy_cache 6 | .dmypy.json 7 | .pytype 8 | .hypothesis 9 | .ipynb_checkpoints 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Brent Yi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## jax_dataclasses 2 | 3 | ![build](https://github.com/brentyi/jax_dataclasses/workflows/build/badge.svg) 4 | ![mypy](https://github.com/brentyi/jax_dataclasses/workflows/mypy/badge.svg?branch=main) 5 | ![lint](https://github.com/brentyi/jax_dataclasses/workflows/lint/badge.svg) 6 | [![codecov](https://codecov.io/gh/brentyi/jax_dataclasses/branch/main/graph/badge.svg?token=fFSx7CeKlW)](https://codecov.io/gh/brentyi/jax_dataclasses) 7 | 8 | 9 | 10 | - [Overview](#overview) 11 | - [Installation](#installation) 12 | - [Core interface](#core-interface) 13 | - [Static fields](#static-fields) 14 | - [Mutations](#mutations) 15 | - [Alternatives](#alternatives) 16 | - [Misc](#misc) 17 | 18 | 19 | 20 | ### Overview 21 | 22 | `jax_dataclasses` provides a simple wrapper around `dataclasses.dataclass` for use in 23 | JAX, which enables automatic support for: 24 | 25 | - [Pytree](https://jax.readthedocs.io/en/latest/pytrees.html) registration. This 26 | allows dataclasses to be used at API boundaries in JAX. 27 | - Serialization via `flax.serialization`. 28 | 29 | Distinguishing features include: 30 | 31 | - An annotation-based interface for marking static fields. 32 | - Improved ergonomics for "model surgery" in nested structures. 33 | 34 | ### Installation 35 | 36 | In Python >=3.7: 37 | 38 | ```bash 39 | pip install jax_dataclasses 40 | ``` 41 | 42 | We can then import: 43 | 44 | ```python 45 | import jax_dataclasses as jdc 46 | ``` 47 | 48 | ### Core interface 49 | 50 | `jax_dataclasses` is meant to provide a drop-in replacement for 51 | `dataclasses.dataclass`: `jdc.pytree_dataclass` has 52 | the same interface as `dataclasses.dataclass`, but also registers the target 53 | class as a pytree node.
54 | 55 | We also provide several aliases: 56 | `jdc.[field, fields, asdict, astuple, is_dataclass, replace]` are identical to 57 | their counterparts in the standard dataclasses library. 58 | 59 | ### Static fields 60 | 61 | To mark a field as static (in this context: constant at compile-time), we can 62 | wrap its type with `jdc.Static[]`: 63 | 64 | ```python 65 | @jdc.pytree_dataclass 66 | class A: 67 | a: jax.Array 68 | b: jdc.Static[bool] 69 | ``` 70 | 71 | In a pytree node, static fields will be treated as part of the treedef instead 72 | of as a child of the node; all fields that are not explicitly marked static 73 | should contain arrays or child nodes. 74 | 75 | Bonus: if you like `jdc.Static[]`, we also introduce 76 | `jdc.jit()`, which lets `jdc.Static[]` be used in function 77 | signatures, for example: 78 | 79 | ```python 80 | @jdc.jit 81 | def f(a: jax.Array, b: jdc.Static[bool]) -> jax.Array: 82 | ... 83 | ``` 84 | 85 | ### Mutations 86 | 87 | All dataclasses are automatically marked as frozen and thus immutable (even when 88 | no `frozen=` parameter is passed in). To make changes to nested structures 89 | easier, `jdc.copy_and_mutate()` (a) makes a copy of a 90 | pytree and (b) returns a context in which any of that copy's contained 91 | dataclasses are temporarily mutable: 92 | 93 | ```python 94 | import jax 95 | from jax import numpy as jnp 96 | import jax_dataclasses as jdc 97 | 98 | @jdc.pytree_dataclass 99 | class Node: 100 | child: jax.Array 101 | 102 | obj = Node(child=jnp.zeros(3)) 103 | 104 | with jdc.copy_and_mutate(obj) as obj_updated: 105 | # Make mutations to the dataclass. This is primarily useful for nested 106 | # dataclasses. 107 | # 108 | # Does input validation by default: if the treedef, leaf shapes, or dtypes 109 | # of `obj` and `obj_updated` don't match, an AssertionError will be raised. 110 | # This can be disabled with a `validate=False` argument.
111 | obj_updated.child = jnp.ones(3) 112 | 113 | print(obj) 114 | print(obj_updated) 115 | ``` 116 | 117 | ### Alternatives 118 | 119 | A few other solutions exist for automatically integrating dataclass-style 120 | objects into pytree structures. Great ones include: 121 | [`chex.dataclass`](https://github.com/deepmind/chex), 122 | [`flax.struct`](https://github.com/google/flax), and 123 | [`tjax.dataclass`](https://github.com/NeilGirdhar/tjax). These all influenced 124 | this library. 125 | 126 | The main differentiators of `jax_dataclasses` are: 127 | 128 | - **Static analysis support.** `tjax` has a custom mypy plugin to enable type 129 | checking, but isn't supported by other tools. `flax.struct` implements the 130 | [`dataclass_transform`](https://github.com/microsoft/pyright/blob/main/specs/dataclass_transforms.md) 131 | spec proposed by pyright, but isn't supported by other tools. Because 132 | `@jdc.pytree_dataclass` has the same API as `@dataclasses.dataclass`, it can 133 | include pytree registration behavior at runtime while being treated as the 134 | standard decorator during static analysis. This means that all static 135 | checkers, language servers, and autocomplete engines that support the standard 136 | `dataclasses` library should work out of the box with `jax_dataclasses`. 137 | 138 | - **Nested dataclasses.** Making replacements/modifications in deeply nested 139 | dataclasses can be really frustrating. The three alternatives all introduce a 140 | `.replace(self, ...)` method to dataclasses that's a bit more convenient than 141 | the traditional `dataclasses.replace(obj, ...)` API for shallow changes, but 142 | still becomes really cumbersome to use when dataclasses are nested. 143 | `jdc.copy_and_mutate()` is introduced to address this. 144 | 145 | - **Static field support.** Parameters that should not be traced in JAX should 146 | be marked as static. This is supported in `flax`, `tjax`, and 147 | `jax_dataclasses`, but not `chex`. 
148 | 149 | - **Serialization.** When working with `flax`, being able to serialize 150 | dataclasses is really handy. This is supported in `flax.struct` (naturally) 151 | and `jax_dataclasses`, but not `chex` or `tjax`. 152 | 153 | You can also eschew the dataclass-style interface entirely; 154 | [see how brax registers pytrees](https://github.com/google/brax/blob/730e05d4af58eada5b49a44e849107d76e386b9a/brax/pytree.py). 155 | This is a reasonable thing to prefer: it requires some floating strings and 156 | breaks things that I care about but you may not (like immutability and 157 | `__post_init__`), but gives more flexibility with custom `__init__` methods. 158 | 159 | ### Misc 160 | 161 | `jax_dataclasses` was originally written for and factored out of 162 | [jaxfg](http://github.com/brentyi/jaxfg), where 163 | [Nick Heppert](https://github.com/SuperN1ck) provided valuable feedback. 164 | -------------------------------------------------------------------------------- /deprecated.md: -------------------------------------------------------------------------------- 1 | # Deprecated features 2 | 3 | `jax_dataclasses` includes utilities for `__post_init__`-based runtime validation of 4 | shape and datatype annotations. This works as designed, but we no longer recommend 5 | using it. [jaxtyping](https://github.com/google/jaxtyping) is a reasonable 6 | alternative solution. 7 | 8 | ### Shape and data-type annotations 9 | 10 | Subclassing from `jdc.EnforcedAnnotationsMixin` 11 | enables automatic shape and data-type validation. Arrays contained within 12 | dataclasses are validated on instantiation and a **`.get_batch_axes()`** method 13 | is exposed for retrieving the leading batch axes shared by all contained arrays.
14 | 15 | We can start by importing the standard `Annotated` type: 16 | 17 | ```python 18 | # Python >=3.9 19 | from typing import Annotated 20 | 21 | # Backport 22 | from typing_extensions import Annotated 23 | ``` 24 | 25 | We can then add shape annotations: 26 | 27 | ```python 28 | @jdc.pytree_dataclass 29 | class MnistStruct(jdc.EnforcedAnnotationsMixin): 30 | image: Annotated[ 31 | jnp.ndarray, 32 | # Note that we can move the expected location of the batch axes by 33 | # shifting the ellipsis around. 34 | # 35 | # If the ellipsis is excluded, we assume batch axes at the start of the 36 | # shape. 37 | (..., 28, 28), 38 | ] 39 | label: Annotated[ 40 | jnp.ndarray, 41 | (..., 10), 42 | ] 43 | ``` 44 | 45 | Or data-type annotations: 46 | 47 | ```python 48 | image: Annotated[ 49 | jnp.ndarray, 50 | jnp.float32, 51 | ] 52 | label: Annotated[ 53 | jnp.ndarray, 54 | jnp.integer, 55 | ] 56 | ``` 57 | 58 | Or both (note that annotations are order-invariant): 59 | 60 | ```python 61 | image: Annotated[ 62 | jnp.ndarray, 63 | (..., 28, 28), 64 | jnp.float32, 65 | ] 66 | label: Annotated[ 67 | jnp.ndarray, 68 | (..., 10), 69 | jnp.integer, 70 | ] 71 | ``` 72 | 73 | Then, assuming we've constrained both the shape and data-type: 74 | 75 | ```python 76 | # OK 77 | struct = MnistStruct( 78 | image=onp.zeros((28, 28), dtype=onp.float32), 79 | label=onp.zeros((10,), dtype=onp.uint8), 80 | ) 81 | print(struct.get_batch_axes()) # Prints () 82 | 83 | # OK 84 | struct = MnistStruct( 85 | image=onp.zeros((32, 28, 28), dtype=onp.float32), 86 | label=onp.zeros((32, 10), dtype=onp.uint8), 87 | ) 88 | print(struct.get_batch_axes()) # Prints (32,) 89 | 90 | # AssertionError on instantiation because of type mismatch 91 | MnistStruct( 92 | image=onp.zeros((28, 28), dtype=onp.float32), 93 | label=onp.zeros((10,), dtype=onp.float32), # Not an integer type! 
94 | ) 95 | 96 | # AssertionError on instantiation because of shape mismatch 97 | MnistStruct( 98 | image=onp.zeros((28, 28), dtype=onp.float32), 99 | label=onp.zeros((5,), dtype=onp.uint8), 100 | ) 101 | 102 | # AssertionError on instantiation because of batch axis mismatch 103 | struct = MnistStruct( 104 | image=onp.zeros((64, 28, 28), dtype=onp.float32), 105 | label=onp.zeros((32, 10), dtype=onp.uint8), 106 | ) 107 | ``` 108 | -------------------------------------------------------------------------------- /jax_dataclasses/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, astuple, field, fields, is_dataclass, replace 2 | from typing import TYPE_CHECKING 3 | 4 | from ._copy_and_mutate import copy_and_mutate as copy_and_mutate 5 | 6 | if TYPE_CHECKING: 7 | # Treat our JAX field and dataclass functions as their counterparts from the 8 | # standard dataclasses library during static analysis 9 | # 10 | # Tools like via mypy, jedi, etc generally rely on a lot of special, hardcoded 11 | # behavior for the standard dataclasses library; this lets us take advantage of all 12 | # of it. 13 | # 14 | # Note that mypy will not follow aliases, so `from dataclasses import dataclass` is 15 | # preferred over `dataclass = dataclasses.dataclass`. 16 | # 17 | # Dataclass transforms serve a similar purpose, but are currently only supported in 18 | # pyright and pylance. 19 | # https://github.com/microsoft/pyright/blob/master/specs/dataclass_transforms.md 20 | # `static_field()` is deprecated, but not a lot of code to support, so leaving it 21 | # for now... 
22 | from dataclasses import dataclass as pytree_dataclass 23 | else: 24 | from ._dataclasses import pytree_dataclass # noqa 25 | from ._dataclasses import deprecated_static_field as static_field # noqa 26 | 27 | from ._dataclasses import Static 28 | from ._enforced_annotations import EnforcedAnnotationsMixin 29 | from ._jit import jit 30 | 31 | __all__ = [ 32 | "asdict", 33 | "astuple", 34 | "field", 35 | "fields", 36 | "is_dataclass", 37 | "replace", 38 | "copy_and_mutate", 39 | "pytree_dataclass", 40 | "Static", 41 | "EnforcedAnnotationsMixin", 42 | "jit", 43 | ] 44 | -------------------------------------------------------------------------------- /jax_dataclasses/_copy_and_mutate.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import dataclasses 3 | import enum 4 | from typing import Any, ContextManager, Set, TypeVar 5 | 6 | from jax import numpy as jnp 7 | from jax import tree_util 8 | from jax.tree_util import default_registry 9 | 10 | T = TypeVar("T") 11 | 12 | 13 | class _Mutability(enum.Enum): 14 | FROZEN = enum.auto() 15 | MUTABLE = enum.auto() 16 | MUTABLE_NO_VALIDATION = enum.auto() 17 | 18 | 19 | def _mark_mutable(obj: Any, mutable: _Mutability, visited: Set[Any]) -> None: 20 | """Recursively freeze or unfreeze dataclasses in a structure. 21 | Currently only supports tuples, lists, dictionaries, dataclasses.""" 22 | 23 | # Skip objects we've already visited. This avoids redundancies when there are 24 | # identical branches in our pytree, but will also help prevent infinite looping from 25 | # cycles. 
26 | if id(obj) in visited: 27 | return 28 | visited.add(id(obj)) 29 | 30 | if dataclasses.is_dataclass(obj): 31 | object.__setattr__(obj, "__mutability__", mutable) 32 | 33 | flattened = default_registry.flatten_one_level(obj) 34 | if flattened is None: 35 | return 36 | for child in flattened[0]: 37 | _mark_mutable(child, mutable, visited) 38 | 39 | 40 | def copy_and_mutate(obj: T, validate: bool = True) -> ContextManager[T]: 41 | """Context manager that copies a pytree and allows for temporary mutations to 42 | contained dataclasses. Optionally validates that treedefs, array shapes, and dtypes 43 | are not changed.""" 44 | 45 | # Inner function helps with static typing! 46 | def _replace_context(obj: T): 47 | # Make a copy of the input object. 48 | obj_copy = tree_util.tree_map(lambda leaf: leaf, obj) 49 | 50 | # Mark it as mutable. 51 | _mark_mutable( 52 | obj_copy, 53 | mutable=( 54 | _Mutability.MUTABLE if validate else _Mutability.MUTABLE_NO_VALIDATION 55 | ), 56 | visited=set(), 57 | ) 58 | 59 | # Yield. 60 | yield obj_copy 61 | 62 | # When done, mark as immutable again. 63 | _mark_mutable( 64 | obj_copy, 65 | mutable=_Mutability.FROZEN, 66 | visited=set(), 67 | ) 68 | 69 | return contextlib.contextmanager(_replace_context)(obj) 70 | 71 | 72 | def _unify_floats(dtype): 73 | if dtype == jnp.float64: 74 | return jnp.float32 75 | else: 76 | return dtype 77 | 78 | 79 | def _new_setattr(self, name: str, value: Any): 80 | if self.__mutability__ == _Mutability.MUTABLE: 81 | # Validate changes. 82 | current_value = getattr(self, name) 83 | 84 | # Make sure tree structure is unchanged. 85 | assert tree_util.tree_structure(value) == tree_util.tree_structure( 86 | current_value 87 | ), "Mismatched tree structure!" 88 | 89 | # Check leaf shapes. 
90 | new_shapes = tuple( 91 | leaf.shape if hasattr(leaf, "shape") else tuple() 92 | for leaf in tree_util.tree_leaves(value) 93 | ) 94 | cur_shapes = tuple( 95 | leaf.shape if hasattr(leaf, "shape") else tuple() 96 | for leaf in tree_util.tree_leaves(current_value) 97 | ) 98 | assert ( 99 | new_shapes == cur_shapes 100 | ), f"Shape error: new shapes {new_shapes} do not match original {cur_shapes}!" 101 | 102 | # Check leaf dtypes. 103 | new_dtypes = tuple( 104 | _unify_floats(leaf.dtype) if hasattr(leaf, "dtype") else type(leaf) 105 | for leaf in tree_util.tree_leaves(value) 106 | ) 107 | cur_dtypes = tuple( 108 | _unify_floats(leaf.dtype) if hasattr(leaf, "dtype") else type(leaf) 109 | for leaf in tree_util.tree_leaves(current_value) 110 | ) 111 | for new, cur in zip(new_dtypes, cur_dtypes): 112 | assert ( 113 | new == cur or new in (int, float) or cur in (int, float) 114 | ), f"Type error: new dtypes {new_dtypes} do not match original {cur_dtypes}!" 115 | 116 | object.__setattr__(self, name, value) 117 | 118 | elif self.__mutability__ == _Mutability.MUTABLE_NO_VALIDATION: 119 | # Make changes without validation. 120 | object.__setattr__(self, name, value) 121 | 122 | elif self.__mutability__ == _Mutability.FROZEN: 123 | raise dataclasses.FrozenInstanceError( 124 | "Dataclass registered as pytree is immutable!" 125 | ) 126 | 127 | else: 128 | assert False 129 | -------------------------------------------------------------------------------- /jax_dataclasses/_dataclasses.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | import functools 5 | from typing import Dict, List, Optional, Type, TypeVar 6 | 7 | from jax import tree_util 8 | from typing_extensions import Annotated, get_type_hints 9 | 10 | from ._get_type_hints import get_type_hints_partial 11 | 12 | try: 13 | # Attempt to import flax for serialization. 
The exception handling lets us drop 14 | # flax from our dependencies. 15 | from flax import serialization 16 | except ImportError: 17 | serialization = None # type: ignore 18 | 19 | from . import _copy_and_mutate 20 | 21 | T = TypeVar("T") 22 | 23 | 24 | JDC_STATIC_MARKER = "__jax_dataclasses_static_field__" 25 | 26 | 27 | # Stolen from here: https://github.com/google/jax/issues/10476 28 | InnerT = TypeVar("InnerT") 29 | Static = Annotated[InnerT, JDC_STATIC_MARKER] 30 | """Annotates a type as static in the sense of JAX; in a pytree, fields marked as such 31 | should be hashable and are treated as part of the treedef and not as a child node.""" 32 | 33 | 34 | def pytree_dataclass(cls: Optional[Type] = None, **kwargs): 35 | """Substitute for dataclasses.dataclass, which also registers dataclasses as 36 | PyTrees.""" 37 | 38 | def wrap(cls): 39 | return _register_pytree_dataclass(dataclasses.dataclass(cls, **kwargs)) 40 | 41 | if "frozen" in kwargs: 42 | assert kwargs["frozen"] is True, "Pytree dataclasses can only be frozen!" 43 | kwargs["frozen"] = True 44 | 45 | if cls is None: 46 | return wrap 47 | else: 48 | return wrap(cls) 49 | 50 | 51 | def deprecated_static_field(*args, **kwargs): 52 | """Deprecated, prefer `Static[]` on the type annotation instead.""" 53 | 54 | kwargs["metadata"] = kwargs.get("metadata", {}) 55 | kwargs["metadata"][JDC_STATIC_MARKER] = True 56 | 57 | return dataclasses.field(*args, **kwargs) 58 | 59 | 60 | @dataclasses.dataclass(frozen=True) 61 | class FieldInfo: 62 | child_node_field_names: List[str] 63 | static_field_names: List[str] 64 | 65 | 66 | def _register_pytree_dataclass(cls: Type[T]) -> Type[T]: 67 | """Register a dataclass as a flax-serializable pytree container.""" 68 | 69 | assert dataclasses.is_dataclass(cls) 70 | 71 | @functools.lru_cache(maxsize=1) 72 | def get_field_info() -> FieldInfo: 73 | # Determine which fields are static and part of the treedef, and which should be 74 | # registered as child nodes. 
75 | child_node_field_names: List[str] = [] 76 | static_field_names: List[str] = [] 77 | 78 | # We don't directly use field.type for postponed evaluation; we want to make sure 79 | # that our types are interpreted as proper types and not as (string) forward 80 | # references. 81 | # 82 | # Note that there are ocassionally situations where the @jdc.pytree_dataclass 83 | # decorator is called before a referenced type is defined; to suppress this error, 84 | # we resolve missing names to our subscriptible placeholder object. 85 | 86 | try: 87 | type_from_name = get_type_hints(cls, include_extras=True) # type: ignore 88 | except Exception: 89 | # Try again, but suppress errors from unresolvable forward 90 | # references. This should be rare. 91 | type_from_name = get_type_hints_partial(cls, include_extras=True) # type: ignore 92 | 93 | for field in dataclasses.fields(cls): 94 | if not field.init: 95 | continue 96 | 97 | field_type = type_from_name[field.name] 98 | 99 | # Two ways to mark a field as static: either via the Static[] type or 100 | # jdc.static_field(). 101 | if ( 102 | hasattr(field_type, "__metadata__") 103 | and JDC_STATIC_MARKER in field_type.__metadata__ 104 | ): 105 | static_field_names.append(field.name) 106 | continue 107 | if field.metadata.get(JDC_STATIC_MARKER, False): 108 | static_field_names.append(field.name) 109 | continue 110 | 111 | child_node_field_names.append(field.name) 112 | return FieldInfo(child_node_field_names, static_field_names) 113 | 114 | # Define flatten, unflatten operations: this simple converts our dataclass to a list 115 | # of fields. 
116 | def _flatten(obj): 117 | field_info = get_field_info() 118 | children = tuple(getattr(obj, key) for key in field_info.child_node_field_names) 119 | treedef = tuple(getattr(obj, key) for key in field_info.static_field_names) 120 | return children, treedef 121 | 122 | def _unflatten(treedef, children): 123 | field_info = get_field_info() 124 | return cls( 125 | **dict(zip(field_info.child_node_field_names, children)), 126 | **{key: tdef for key, tdef in zip(field_info.static_field_names, treedef)}, 127 | ) 128 | 129 | tree_util.register_pytree_node(cls, _flatten, _unflatten) 130 | 131 | # Serialization: this is mostly copied from `flax.struct.dataclass`. 132 | if serialization is not None: 133 | 134 | def _to_state_dict(x: T): 135 | field_info = get_field_info() 136 | state_dict = { 137 | name: serialization.to_state_dict(getattr(x, name)) 138 | for name in field_info.child_node_field_names 139 | } 140 | return state_dict 141 | 142 | def _from_state_dict(x: T, state: Dict): 143 | # Copy the state so we can pop the restored fields. 
144 | field_info = get_field_info() 145 | state = state.copy() 146 | updates = {} 147 | for name in field_info.child_node_field_names: 148 | if name not in state: 149 | raise ValueError( 150 | f"Missing field {name} in state dict while restoring" 151 | f" an instance of {cls.__name__}" 152 | ) 153 | value = getattr(x, name) 154 | value_state = state.pop(name) 155 | updates[name] = serialization.from_state_dict(value, value_state) 156 | if state: 157 | names = ",".join(state.keys()) 158 | raise ValueError( 159 | f'Unknown field(s) "{names}" in state dict while' 160 | f" restoring an instance of {cls.__name__}" 161 | ) 162 | 163 | return dataclasses.replace(x, **updates) # type: ignore 164 | 165 | serialization.register_serialization_state( 166 | cls, _to_state_dict, _from_state_dict 167 | ) 168 | 169 | # Custom frozen dataclass implementation 170 | cls.__mutability__ = _copy_and_mutate._Mutability.FROZEN # type: ignore 171 | cls.__setattr__ = _copy_and_mutate._new_setattr # type: ignore 172 | 173 | return cls # type: ignore 174 | -------------------------------------------------------------------------------- /jax_dataclasses/_enforced_annotations.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Any, List, Optional, Tuple 3 | 4 | from jax import numpy as jnp 5 | from typing_extensions import TypeGuard 6 | 7 | from ._get_type_hints import get_type_hints_partial 8 | 9 | ExpectedShape = Tuple[Any, ...] 10 | 11 | 12 | def _is_expected_shape(shape: Any) -> TypeGuard[ExpectedShape]: 13 | """Returns `True` is a tuple of integers + potentially an Ellipsis.""" 14 | return isinstance(shape, tuple) and all( 15 | map(lambda x: isinstance(x, int) or x is ..., shape) 16 | ) 17 | 18 | 19 | # Some dtype superclasses that result in a warning when we attempt a jnp.dtype() on them. 
20 | _dtype_set = {jnp.integer, jnp.signedinteger, jnp.floating, jnp.inexact} 21 | 22 | 23 | def _is_dtype(dtype: Any) -> bool: 24 | """Returns `True` is `dtype` is a valid datatype.""" 25 | 26 | if dtype in _dtype_set: 27 | return True 28 | try: 29 | jnp.dtype(dtype) 30 | return True 31 | except TypeError: 32 | return False 33 | 34 | 35 | class EnforcedAnnotationsMixin: 36 | """**Deprecated** mixin for dataclasses containing arrays that are 37 | annotated with expected shapes or types that can be checked at runtime. 38 | 39 | Runs input validation on instantiation and provides a single helper, 40 | `get_batch_axes()`, that returns common batch axes. 41 | 42 | First, we import `Annotated`: 43 | 44 | # Python <=3.8 45 | from typing_extensions import Annotated 46 | 47 | # Python 3.9 48 | from typing import Annotated 49 | 50 | Example of an annotated fields that must have shape (*, 50, 150, 3) and (*, 10), 51 | where the batch axes must be shared: 52 | 53 | image: Annotated[jnp.ndarray, (..., 50, 150, 3)] 54 | label: Annotated[jnp.ndarray, (..., 10,)] 55 | 56 | Fields that must be floats and integers respectively: 57 | 58 | image: Annotated[jnp.ndarray, jnp.floating] # or jnp.float32, jnp.float64, etc 59 | label: Annotated[jnp.ndarray, jnp.integer] # or jnp.uint8, jnp.uint32, etc 60 | 61 | Fields with both shape and type constraints: 62 | 63 | image: Annotated[jnp.ndarray, (..., 50, 150, 3), jnp.floating] 64 | label: Annotated[jnp.ndarray, (..., 10,), jnp.integer] 65 | 66 | Where the annotations are order-invariant and both optional. 67 | """ 68 | 69 | def __post_init__(self) -> None: 70 | """Validate after construction. 71 | 72 | We raise assertion errors in only two scenarios: 73 | - A field has a dtype that's not a subtype of the annotated dtype. 
74 | - A field has a shape that's not consistent with the annotated shape.""" 75 | 76 | assert dataclasses.is_dataclass(self) 77 | 78 | hint_from_name = get_type_hints_partial(type(self), include_extras=True) # type: ignore 79 | batch_axes: Optional[Tuple[int, ...]] = None 80 | 81 | # Batch axes for child/nested elements. 82 | child_batch_axes_list: List[Tuple[int, ...]] = [] 83 | 84 | # For each field... 85 | for field in dataclasses.fields(self): 86 | type_hint = hint_from_name[field.name] 87 | value = getattr(self, field.name) 88 | 89 | if isinstance(value, EnforcedAnnotationsMixin): 90 | child_batch_axes = getattr(value, "__batch_axes__") 91 | if child_batch_axes is not None: 92 | child_batch_axes_list.append(child_batch_axes) 93 | continue 94 | 95 | # Check for metadata from `typing.Annotated` value! Skip if no annotation. 96 | if not hasattr(type_hint, "__metadata__"): 97 | continue 98 | metadata: Tuple[Any, ...] = type_hint.__metadata__ 99 | assert ( 100 | len(metadata) <= 2 101 | ), "We expect <= 2 metadata items; only shape and dtype are expected." 102 | 103 | # Check data type. 104 | metadata_dtype = tuple(filter(_is_dtype, metadata)) 105 | if len(metadata_dtype) > 0 and hasattr(value, "dtype"): 106 | (dtype,) = metadata_dtype 107 | assert jnp.issubdtype( 108 | value.dtype, dtype 109 | ), f"Mismatched dtype, expected {dtype} but got {value.dtype}." 110 | 111 | # Shape checks. 112 | metadata_shape = tuple(filter(_is_expected_shape, metadata)) 113 | shape: Optional[Tuple[int, ...]] = None 114 | if isinstance(value, (int, float)): 115 | shape = () 116 | elif hasattr(value, "shape"): 117 | shape = value.shape 118 | if len(metadata_shape) > 0 and shape is not None: 119 | # Get expected shape, sans batch axes. 
120 | (expected_shape,) = metadata_shape 121 | field_batch_axes = _check_batch_axes(shape, expected_shape) 122 | if batch_axes is None: 123 | batch_axes = field_batch_axes 124 | else: 125 | assert ( 126 | batch_axes == field_batch_axes 127 | ), f"Batch axis mismatch: {batch_axes} and {field_batch_axes}." 128 | 129 | # Check child batch axes: any batch axes present in the parent should be present 130 | # in the children as well. 131 | if batch_axes is not None: 132 | for child_batch_axes in child_batch_axes_list: 133 | assert ( 134 | len(child_batch_axes) >= len(batch_axes) 135 | and child_batch_axes[: len(batch_axes)] == batch_axes 136 | ), ( 137 | f"Child batch axes {child_batch_axes} don't match parent axes" 138 | f" {batch_axes}." 139 | ) 140 | 141 | object.__setattr__(self, "__batch_axes__", batch_axes) 142 | 143 | def get_batch_axes(self) -> Tuple[int, ...]: 144 | """Return any leading batch axes (which should be shared across all contained 145 | arrays).""" 146 | 147 | batch_axes = getattr(self, "__batch_axes__") 148 | assert batch_axes is not None 149 | return batch_axes 150 | 151 | 152 | def _check_batch_axes( 153 | shape: Tuple[int, ...], 154 | expected_shape: ExpectedShape, 155 | ) -> Tuple[int, ...]: 156 | # By default, assume batch axes are at start of 157 | if ... not in expected_shape: 158 | expected_shape = (...,) + expected_shape 159 | 160 | assert expected_shape.count(...) == 1 161 | batch_index = expected_shape.index(...) 162 | 163 | # Actual shape should be expected shape prefixed by some batch axes. 164 | if len(expected_shape) > 1: 165 | expected_prefix = expected_shape[:batch_index] 166 | expected_suffix = expected_shape[batch_index + 1 :] 167 | 168 | prefix = shape[: len(expected_prefix)] 169 | suffix = shape[-len(expected_suffix) :] 170 | 171 | shape_error = ( 172 | "Shape did not match annotation: expected" 173 | f" ({','.join(map(str, expected_prefix + ('*',) + expected_suffix))}) but" 174 | f" got {shape}." 
175 | ) 176 | assert suffix == expected_suffix, shape_error 177 | assert prefix == expected_prefix, shape_error 178 | batch_axes = shape[len(expected_prefix) : -len(expected_suffix)] 179 | else: 180 | batch_axes = shape 181 | 182 | return batch_axes 183 | -------------------------------------------------------------------------------- /jax_dataclasses/_get_type_hints.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import collections 4 | import functools 5 | import sys 6 | import types 7 | from types import MethodDescriptorType, MethodWrapperType, WrapperDescriptorType 8 | from typing import Any, Dict, Type 9 | 10 | 11 | class _UnresolvableForwardReference: 12 | def __class_getitem__(cls, item) -> Type[_UnresolvableForwardReference]: 13 | """__getitem__ passthrough, for supporting generics.""" 14 | return _UnresolvableForwardReference 15 | 16 | 17 | _allowed_types = ( 18 | types.FunctionType, 19 | types.BuiltinFunctionType, 20 | types.MethodType, 21 | types.ModuleType, 22 | WrapperDescriptorType, 23 | MethodWrapperType, 24 | MethodDescriptorType, 25 | ) 26 | 27 | 28 | @functools.lru_cache(maxsize=128) 29 | def get_type_hints_partial(obj, include_extras=False) -> Dict[str, Any]: 30 | """Adapted from typing.get_type_hints(), but aimed at suppressing errors from not 31 | (yet) resolvable forward references. 32 | 33 | This function should only be used to search for fields that are annotated with 34 | `jdc.Static[]`. 35 | 36 | For example: 37 | 38 | @jdc.pytree_dataclass 39 | class A: 40 | x: B 41 | y: jdc.Static[bool] 42 | 43 | @jdc.pytree_dataclass 44 | class B: 45 | x: jnp.ndarray 46 | 47 | Note that the type annotations of `A` need to be parsed by the `pytree_dataclass` 48 | decorator in order to register the static field, but `B` is not yet defined when the 49 | decorator is run. 
We don't actually care about the details of the `B` annotation, so 50 | we replace it in our annotation dictionary with a dummy value. 51 | 52 | Differences: 53 | 1. `include_extras` must be True. 54 | 2. Only supports types. 55 | 3. Doesn't throw an error when a name is not found. Instead, replaces the value 56 | with `_UnresolvableForwardReference`. 57 | """ 58 | assert include_extras 59 | 60 | # Replace any unresolvable names with _UnresolvableForwardReference. 61 | base_globals: Dict[str, Any] = collections.defaultdict( 62 | lambda: _UnresolvableForwardReference 63 | ) 64 | base_globals.update(__builtins__) # type: ignore 65 | 66 | # Classes require a special treatment. 67 | if isinstance(obj, type): 68 | hints = {} 69 | for base in reversed(obj.__mro__): 70 | ann = base.__dict__.get("__annotations__", {}) 71 | if len(ann) == 0: 72 | continue 73 | 74 | base_globals.update(sys.modules[base.__module__].__dict__) 75 | 76 | for name, value in ann.items(): 77 | if value is None: 78 | value = type(None) 79 | if isinstance(value, str): 80 | # The * replace is a hack for variadic generic support. 81 | value = value.replace("*", "") 82 | value = eval(value, base_globals) 83 | hints[name] = value 84 | return hints 85 | 86 | nsobj = obj 87 | 88 | # Find globalns for the unwrapped object. 89 | while hasattr(nsobj, "__wrapped__"): 90 | nsobj = nsobj.__wrapped__ 91 | base_globals.update(getattr(nsobj, "__globals__", {})) 92 | 93 | hints = getattr(obj, "__annotations__", None) # type: ignore 94 | if hints is None: 95 | # Return empty annotations for something that _could_ have them. 96 | if isinstance(obj, _allowed_types): 97 | return {} 98 | else: 99 | raise TypeError( 100 | "{!r} is not a module, class, method, or function.".format(obj) 101 | ) 102 | hints = dict(hints) 103 | for name, value in hints.items(): 104 | if value is None: 105 | value = type(None) 106 | if isinstance(value, str): 107 | # The * replace is a hack for variadic generic support. 
108 | value = value.replace("*", "") 109 | value = eval(value, base_globals) 110 | hints[name] = value 111 | return hints 112 | -------------------------------------------------------------------------------- /jax_dataclasses/_jit.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Callable, Optional, Sequence, TypeVar, Union, cast, overload 3 | 4 | import jax 5 | from jaxlib import xla_client as xc 6 | 7 | from ._dataclasses import JDC_STATIC_MARKER 8 | from ._get_type_hints import get_type_hints_partial 9 | 10 | CallableType = TypeVar("CallableType", bound=Callable) 11 | 12 | 13 | @overload 14 | def jit( 15 | fun: CallableType, 16 | *, 17 | device: Optional[xc.Device] = None, 18 | backend: Optional[str] = None, 19 | donate_argnums: Union[int, Sequence[int]] = (), 20 | inline: bool = False, 21 | keep_unused: bool = False, 22 | abstracted_axes: Optional[Any] = None, 23 | ) -> CallableType: ... 24 | 25 | 26 | @overload 27 | def jit( 28 | fun: None = None, 29 | *, 30 | device: Optional[xc.Device] = None, 31 | backend: Optional[str] = None, 32 | donate_argnums: Union[int, Sequence[int]] = (), 33 | inline: bool = False, 34 | keep_unused: bool = False, 35 | abstracted_axes: Optional[Any] = None, 36 | ) -> Callable[[CallableType], CallableType]: ... 37 | 38 | 39 | def jit( 40 | fun: Optional[CallableType] = None, 41 | *, 42 | device: Optional[xc.Device] = None, 43 | backend: Optional[str] = None, 44 | donate_argnums: Union[int, Sequence[int]] = (), 45 | inline: bool = False, 46 | keep_unused: bool = False, 47 | abstracted_axes: Optional[Any] = None, 48 | ) -> Union[CallableType, Callable[[CallableType], CallableType]]: 49 | """Light wrapper around `jax.jit`, with usability and type checking improvements. 50 | 51 | Three differences: 52 | - We remove the `static_argnums` and `static_argnames` parameters. 
Instead, 53 | static arguments can be specified in type annotations with 54 | `jax_dataclasses.Static[]`. 55 | - Instead of `jax.stages.Wrapped`, the return callable type is annotated to 56 | match the input callable type. This will improve autocomplete and type 57 | checking in most situations. 58 | - Similar to `@dataclasses.dataclass`, return a decorator if `fun` isn't passed 59 | in. This is convenient for avoiding `@functools.partial()`. 60 | """ 61 | 62 | def wrap(fun: CallableType) -> CallableType: 63 | signature = inspect.signature(fun) 64 | 65 | # Mark any inputs annotated with jax_dataclasses.Static[] as static. 66 | static_argnums = [] 67 | static_argnames = [] 68 | hint_from_name = get_type_hints_partial(fun, include_extras=True) 69 | for i, param in enumerate(signature.parameters.values()): 70 | name = param.name 71 | if name not in hint_from_name: 72 | continue 73 | hint = hint_from_name[name] 74 | if hasattr(hint, "__metadata__") and JDC_STATIC_MARKER in hint.__metadata__: 75 | if param.kind is param.POSITIONAL_ONLY: 76 | static_argnums.append(i) 77 | else: 78 | static_argnames.append(name) 79 | 80 | return cast( 81 | CallableType, 82 | jax.jit( 83 | fun, 84 | static_argnums=static_argnums if len(static_argnums) > 0 else None, 85 | static_argnames=static_argnames if len(static_argnames) > 0 else None, 86 | device=device, 87 | backend=backend, 88 | donate_argnums=donate_argnums, 89 | inline=inline, 90 | keep_unused=keep_unused, 91 | abstracted_axes=abstracted_axes, 92 | ), 93 | ) 94 | 95 | if fun is None: 96 | return wrap 97 | else: 98 | return wrap(fun) 99 | -------------------------------------------------------------------------------- /jax_dataclasses/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brentyi/jax_dataclasses/6817a9de26d6ae875f5bf70f24a983cad5d8cb11/jax_dataclasses/py.typed -------------------------------------------------------------------------------- 
# ---- /mypy.ini ----
# [mypy]
# python_version = 3.12
# ignore_missing_imports = True

# ---- /setup.py ----
from setuptools import find_packages, setup

# Decode explicitly as UTF-8: without `encoding=`, `open()` falls back to the
# locale's preferred encoding, which may mis-decode or reject non-ASCII content.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()
setup(
    name="jax_dataclasses",
    version="1.6.2",
    description="Dataclasses + JAX",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="http://github.com/brentyi/jax_dataclasses",
    author="brentyi",
    author_email="brentyi@berkeley.edu",
    license="MIT",
    packages=find_packages(),
    package_data={"jax_dataclasses": ["py.typed"]},
    python_requires=">=3.9",
    install_requires=[
        "jax>=0.4.25",  # For `jax.tree_util.default_registry`
        "jaxlib",
        "typing_extensions",
    ],
    extras_require={
        "testing": [
            "flax",  # Used for serialization tests.
            "pytest",
            "pytest-cov",
        ]
    },
    classifiers=[
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
)

# ---- /tests/conftest.py ----
import sys

# pytest reads `collect_ignore_glob` from conftest.py: test modules matching these
# globs are not collected on the running interpreter.
collect_ignore_glob = []
# NOTE(review): unreachable given `python_requires=">=3.9"` in setup.py.
if sys.version_info.major == 3 and sys.version_info.minor == 7:
    collect_ignore_glob.append("*_ignore_py37.py")

# NOTE(review): this skips the `*_py312.py` tests on every version other than
# exactly 3.12, including newer interpreters that also support PEP 695 syntax;
# consider `sys.version_info < (3, 12)`.
if not (sys.version_info.major == 3 and sys.version_info.minor == 12):
    collect_ignore_glob.append("*_py312.py")

# ---- /tests/test_annotated_arrays.py ----
"""Tests for optional shape and type annotation features."""

import jax
import numpy as onp
import pytest
from jax import numpy as jnp
from typing_extensions import Annotated

import jax_dataclasses as jdc


# Both fields carry a shape (with leading batch axes) and a dtype constraint.
@jdc.pytree_dataclass
class MnistStruct(jdc.EnforcedAnnotationsMixin):
    image: Annotated[
        onp.ndarray,
        (..., 28, 28),
        jnp.floating,
    ]
    label: Annotated[
        onp.ndarray,
        (..., 10),
        jnp.integer,
    ]


# Each field carries only one kind of constraint.
@jdc.pytree_dataclass
class MnistStructPartial(jdc.EnforcedAnnotationsMixin):
    image_shape_only: Annotated[
        onp.ndarray,
        (28, 28),  # Ellipsis will be prepended automatically.
    ]
    label_dtype_only: Annotated[
        onp.ndarray,
        jnp.integer,
    ]


def test_valid() -> None:
    # No batch axes.
    data = MnistStruct(
        image=onp.zeros((28, 28), dtype=onp.float32),
        label=onp.zeros((10,), dtype=onp.uint8),
    )
    assert data.get_batch_axes() == ()

    # One shared batch axis.
    data = MnistStruct(
        image=onp.zeros((5, 28, 28), dtype=onp.float32),
        label=onp.zeros((5, 10), dtype=onp.uint8),
    )
    assert data.get_batch_axes() == (5,)

    # Two shared batch axes.
    data = MnistStruct(
        image=onp.zeros((5, 7, 28, 28), dtype=onp.float32),
        label=onp.zeros((5, 7, 10), dtype=onp.uint8),
    )
    assert data.get_batch_axes() == (5, 7)

    # Only the shape-annotated field contributes batch axes.
    data_partial = MnistStructPartial(
        image_shape_only=onp.zeros((7, 28, 28), dtype=onp.float32),
        label_dtype_only=onp.zeros((70), dtype=onp.int32),
    )
    assert data_partial.get_batch_axes() == (7,)


def test_shape_mismatch() -> None:
    with pytest.raises(AssertionError):
        MnistStruct(
            image=onp.zeros((7, 32, 32), dtype=onp.float32),
            label=onp.zeros((7, 10), dtype=onp.uint8),
        )

    with pytest.raises(AssertionError):
        MnistStructPartial(
            image_shape_only=onp.zeros((7, 32, 32), dtype=onp.float32),
            label_dtype_only=onp.zeros((7, 10), dtype=onp.uint8),
        )


def test_batch_axis_mismatch() -> None:
    # image has batch axes (5, 7) but label has (7,).
    with pytest.raises(AssertionError):
        MnistStruct(
            image=onp.zeros((5, 7, 28, 28), dtype=onp.float32),
            label=onp.zeros((7, 10), dtype=onp.uint8),
        )


def test_dtype_mismatch() -> None:
    # Integer image where a floating dtype is required.
    with pytest.raises(AssertionError):
        MnistStruct(
            image=onp.zeros((7, 28, 28), dtype=onp.uint8),
            label=onp.zeros((7, 10), dtype=onp.uint8),
        )

    # Float label where an integer dtype is required.
    with pytest.raises(AssertionError):
        MnistStructPartial(
            image_shape_only=onp.zeros((7, 28, 28), dtype=onp.float32),
            label_dtype_only=onp.zeros((7, 10), dtype=onp.float32),
        )


def test_nested() -> None:
    @jdc.pytree_dataclass
    class Parent(jdc.EnforcedAnnotationsMixin):
        x: Annotated[onp.ndarray, jnp.floating, ()]
        child: MnistStruct

    # OK
    assert Parent(
        x=onp.zeros((7,), dtype=onp.float32),
        child=MnistStruct(
            image=onp.zeros((7, 28, 28), dtype=onp.float32),
            label=onp.zeros((7, 10), dtype=onp.uint8),
        ),
    ).get_batch_axes() == (7,)

    # Batch axis mismatch
    with pytest.raises(AssertionError):
        Parent(
            x=onp.zeros((5,), dtype=onp.float32),
            child=MnistStruct(
                image=onp.zeros((7, 28, 28), dtype=onp.float32),
                label=onp.zeros((7, 10), dtype=onp.uint8),
            ),
        )

    # Dtype mismatch inside the child (float labels)
    with pytest.raises(AssertionError):
        Parent(
            x=onp.zeros((7,), dtype=onp.float32),
            child=MnistStruct(
                image=onp.zeros((7, 28, 28), dtype=onp.float32),
                label=onp.zeros((7, 10), dtype=onp.float32),
            ),
        )


def test_scalar() -> None:
    @jdc.pytree_dataclass
    class ScalarContainer(jdc.EnforcedAnnotationsMixin):
        scalar: Annotated[onp.ndarray, ()]  # () => scalar shape

    # Python floats are treated as shape ().
    assert ScalarContainer(scalar=5.0).get_batch_axes() == ()  # type: ignore
    assert ScalarContainer(scalar=onp.zeros((5,))).get_batch_axes() == (5,)


def test_grad() -> None:
    @jdc.pytree_dataclass
    class Vector3(jdc.EnforcedAnnotationsMixin):
        parameters: Annotated[onp.ndarray, (3,)]

    # Make sure we can compute gradients wrt annotated dataclasses.
    grad = jax.grad(lambda x: jnp.sum(x.parameters))(Vector3(onp.zeros(3)))
    onp.testing.assert_allclose(grad.parameters, onp.ones((3,)))


def test_unannotated() -> None:
    @jdc.pytree_dataclass
    class Test(jdc.EnforcedAnnotationsMixin):
        a: onp.ndarray

    # Without any shape annotation no batch axes are recorded.
    with pytest.raises(AssertionError):
        Test(onp.zeros((2, 1, 2, 3, 5, 7, 9))).get_batch_axes()


def test_middle_batch_axes() -> None:
    @jdc.pytree_dataclass
    class Test(jdc.EnforcedAnnotationsMixin):
        a: Annotated[onp.ndarray, (3, ..., 5, 7, 9)]

    test = Test(onp.zeros((3, 1, 2, 3, 5, 7, 9)))
    assert test.get_batch_axes() == (1, 2, 3)

    # Wrong prefix.
    with pytest.raises(AssertionError):
        Test(onp.zeros((2, 1, 2, 3, 5, 7, 9)))
    # Wrong suffix.
    with pytest.raises(AssertionError):
        Test(onp.zeros((3, 1, 2, 3, 5, 7)))


# This test currently breaks -- shape assertions on instantiation make it impossible to
# compute some more complex Jacobians.
#
# Some options for fixing: adding some way to temporarily disable validation, or moving
# away from validation on instantiation to validation only when `.get_batch_axes()` is
# called. Either should be fairly straightforward, but this is fairly niche, produces (in my
# opinion) unintuitive Pytree structures, and is easy to work around, so marking as a
# no-fix for now.
#
# def test_jacobians() -> None:
#     @jdc.pytree_dataclass
#     class Vector3(jdc.EnforcedAnnotationsMixin):
#         parameters: Annotated[onp.ndarray, (3,)]
#
#     @jdc.pytree_dataclass
#     class Vector4(jdc.EnforcedAnnotationsMixin):
#         parameters: Annotated[onp.ndarray, (4,)]
#
#     def vec4_from_vec3(vec3: Vector3) -> Vector4:
#         return Vector4(onp.zeros((4,)))
#
#     def vec3_from_vec4(vec4: Vector4) -> Vector3:
#         return Vector3(onp.zeros((3,)))
#
#     jac = jax.jacfwd(vec4_from_vec3)(Vector3(onp.zeros((3,))))
#     jac = jax.jacfwd(vec3_from_vec4)(Vector4(onp.zeros((4,))))

# ---- /tests/test_copy_and_mutate.py ----
import dataclasses
from typing import Any, Dict, List

import numpy as onp
import pytest

import jax_dataclasses as jdc


def test_copy_and_mutate() -> None:
    """`jdc.copy_and_mutate` yields a mutable copy; instances are frozen otherwise."""
    # frozen=True should do nothing
    @jdc.pytree_dataclass(frozen=True)
    class Foo:
        array: Any

    @jdc.pytree_dataclass
    class Bar:
        children: List[Foo]
        array: Any
        array_unchanged: onp.ndarray

    obj = Bar(
        children=[Foo(array=onp.zeros(3))],
        array=onp.ones(3),
        array_unchanged=onp.ones(3),
    )

    # Registered dataclasses are initially immutable
    with pytest.raises(dataclasses.FrozenInstanceError):
        obj.array = onp.zeros(3)

    # But we can use a context that copies a dataclass and temporarily makes the copy
    # mutable. Note that `as obj` rebinds `obj` to that copy:
    with jdc.copy_and_mutate(obj) as obj:
        # Updates can then very easily be applied!
        obj.array = onp.zeros(3)
        obj.children[0].array = onp.ones(3)  # type: ignore

        # Shapes can be validated (against the value being replaced)...
        with pytest.raises(AssertionError):
            obj.children[0].array = onp.ones(1)  # type: ignore

        # As well as dtypes
        with pytest.raises(AssertionError):
            obj.children[0].array = onp.ones(3, dtype=onp.int32)  # type: ignore

    # Validation can also be disabled
    with jdc.copy_and_mutate(obj, validate=False) as obj:
        obj.children[0].array = onp.ones(1)  # type: ignore
        obj.children[0].array = onp.ones(3)  # type: ignore

    # Outside of the `copy_and_mutate` context, the copied object becomes immutable again:
    with pytest.raises(dataclasses.FrozenInstanceError):
        obj.array = onp.zeros(3)
    with pytest.raises(dataclasses.FrozenInstanceError):
        obj.children[0].array = onp.ones(3)  # type: ignore

    onp.testing.assert_allclose(obj.array, onp.zeros(3))
    onp.testing.assert_allclose(obj.array_unchanged, onp.ones(3))
    onp.testing.assert_allclose(obj.children[0].array, onp.ones(3))


def test_copy_and_mutate_static() -> None:
    """Static (non-pytree) fields can also be replaced inside `copy_and_mutate`."""
    @dataclasses.dataclass
    class Inner:
        a: int
        b: int

    @jdc.pytree_dataclass
    class Foo:
        arrays: Dict[str, onp.ndarray]
        child: jdc.Static[Inner]

    obj = Foo(arrays={"x": onp.ones(3)}, child=Inner(1, 2))

    # Registered dataclasses are initially immutable
    with pytest.raises(dataclasses.FrozenInstanceError):
        obj.child = Inner(5, 6)

    assert obj.child == Inner(1, 2)

    # But can be copied and mutated in a special context
    with jdc.copy_and_mutate(obj) as obj_updated:
        obj_updated.child = Inner(5, 6)

    # The original is untouched; only the copy sees the update.
    assert obj.child == Inner(1, 2)
    assert obj_updated.child == Inner(5, 6)

# ---- /tests/test_cycle.py ----
from __future__ import annotations

from typing import Tuple

from jax import numpy as jnp

import jax_dataclasses as jdc


def test_cycle() -> None:
    """Mutating a pytree into a cycle is currently allowed (no detection)."""
    @jdc.pytree_dataclass
    class TreeNode:
        content: jnp.ndarray
        # Self-referencing forward reference; resolvable only lazily.
        children: Tuple[TreeNode, ...]

    a = TreeNode(content=jnp.zeros(3), children=())
    b = TreeNode(content=jnp.ones(3), children=(a,))

    # Not a cycle. OK!
    with jdc.copy_and_mutate(b, validate=False) as b:
        b.children = (a, a)

    # Cycle. Ideally this should raise an error, but the way duplicate nodes in pytrees
    # are possible makes robust cycle detection under linear space/time constraints a
    # hassle. So instead we currently do nothing.
    with jdc.copy_and_mutate(b, validate=False) as b:
        b.children[0].children = (b,)

# ---- /tests/test_dataclass.py ----
"""Tests for standard jdc.pytree_dataclass features. Initialization, flattening, unflattening,
static fields, etc.
"""

from __future__ import annotations

from typing import Generic, TypeVar

import jax
import numpy as onp
import pytest
from jax import tree_util

import jax_dataclasses as jdc


# Leaf-wise closeness check; `tree_map` raises ValueError when the two pytrees
# have different structures, which the static-field tests below rely on.
def _assert_pytree_allclose(x, y):
    tree_util.tree_map(
        lambda *arrays: onp.testing.assert_allclose(arrays[0], arrays[1]), x, y
    )


def test_init() -> None:
    @jdc.pytree_dataclass
    class A:
        field1: int
        field2: int

    # Keyword and positional construction are equivalent.
    assert A(field1=5, field2=3) == A(5, 3)

    with pytest.raises(TypeError):
        # Not enough arguments
        A(field1=5)  # type: ignore


def test_default_arg() -> None:
    @jdc.pytree_dataclass
    class A:
        field1: int
        field2: int = 3

    assert A(field1=5, field2=3) == A(5, 3) == A(field1=5) == A(5)


def test_flatten() -> None:
    @jdc.pytree_dataclass
    class A:
        field1: float
        field2: float

    # Passing an instance into a jitted function requires flattening it.
    @jax.jit
    def jitted_sum(obj: A) -> float:
        return obj.field1 + obj.field2

    _assert_pytree_allclose(jitted_sum(A(5.0, 3.0)), 8.0)


def test_unflatten() -> None:
    @jdc.pytree_dataclass
    class A:
        field1: float
        field2: float

    # Returning an instance from a jitted function requires unflattening it.
    @jax.jit
    def construct_A(a: float) -> A:
        return A(field1=a, field2=a * 2.0)

    _assert_pytree_allclose(A(1.0, 2.0), construct_A(1.0))


def test_static_field() -> None:
    @jdc.pytree_dataclass
    class A:
        field1: float
        field2: float
        field3: jdc.Static[bool]

    # `field3` is static, so it can drive Python control flow under jit.
    @jax.jit
    def jitted_op(obj: A) -> float:
        if obj.field3:
            return obj.field1 + obj.field2
        else:
            return obj.field1 - obj.field2

    with pytest.raises(ValueError):
        # Cannot map over pytrees with different treedefs
        _assert_pytree_allclose(A(1.0, 2.0, False), A(1.0, 2.0, True))

    _assert_pytree_allclose(jitted_op(A(5.0, 3.0, True)), 8.0)
    _assert_pytree_allclose(jitted_op(A(5.0, 3.0, False)), 2.0)


def test_static_field_deprecated() -> None:
    # Same as `test_static_field`, using the deprecated `jdc.static_field()` API.
    @jdc.pytree_dataclass
    class A:
        field1: float
        field2: float
        field3: bool = jdc.static_field()  # type: ignore

    @jax.jit
    def jitted_op(obj: A) -> float:
        if obj.field3:
            return obj.field1 + obj.field2
        else:
            return obj.field1 - obj.field2

    with pytest.raises(ValueError):
        # Cannot map over pytrees with different treedefs
        _assert_pytree_allclose(A(1.0, 2.0, False), A(1.0, 2.0, True))

    _assert_pytree_allclose(jitted_op(A(5.0, 3.0, True)), 8.0)
    _assert_pytree_allclose(jitted_op(A(5.0, 3.0, False)), 2.0)


def test_no_init() -> None:
    # A field with `init=False` is populated in `__post_init__`, including when the
    # instance is constructed inside a jitted function.
    @jdc.pytree_dataclass
    class A:
        field1: float
        field2: float = jdc.field()
        field3: jdc.Static[bool] = jdc.field(init=False)

        def __post_init__(self):
            object.__setattr__(self, "field3", False)

    @jax.jit
    def construct_A(a: float) -> A:
        return A(field1=a, field2=a * 2.0)

    assert construct_A(5.0).field3 is False


def test_static_field_forward_ref() -> None:
    # `Container` is referenced before it is defined; the static marker must still
    # be detected on `field3`.
    @jdc.pytree_dataclass
    class A:
        field1: float
        field2: float
        field3: jdc.Static[Container[bool]]

    T = TypeVar("T")

    @jdc.pytree_dataclass
    class Container(Generic[T]):
        x: T

    @jax.jit
    def jitted_op(obj: A) -> float:
        if obj.field3.x:
            return obj.field1 + obj.field2
        else:
            return obj.field1 - obj.field2

    with pytest.raises(ValueError):
        # Cannot map over pytrees with different treedefs
        _assert_pytree_allclose(
            A(1.0, 2.0, Container(False)), A(1.0, 2.0, Container(True))
        )

    _assert_pytree_allclose(jitted_op(A(5.0, 3.0, Container(True))), 8.0)
    _assert_pytree_allclose(jitted_op(A(5.0, 3.0, Container(False))), 2.0)
-------------------------------------------------------------------------------- /tests/test_jit_ignore_py37.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | 5 | import jax 6 | import pytest 7 | from jax import numpy as jnp 8 | 9 | import jax_dataclasses as jdc 10 | 11 | 12 | def test_jit_0(): 13 | def func(x: int, y: int) -> jax.Array: 14 | assert isinstance(x, int) 15 | assert isinstance(y, int) 16 | return jnp.full(shape=(x,), fill_value=y) 17 | 18 | assert func(3, 4).shape == (3,) 19 | 20 | 21 | def test_jit_1(): 22 | @jdc.jit 23 | def func(x: int, y: int) -> jax.Array: 24 | assert not isinstance(x, int) 25 | assert not isinstance(y, int) 26 | return jnp.full(shape=(x,), fill_value=y) 27 | 28 | with pytest.raises(TypeError): 29 | assert func(3, 4).shape == (3,) 30 | 31 | 32 | def test_jit_2(): 33 | @jdc.jit 34 | def func(x: jdc.Static[int], y: int) -> jax.Array: 35 | assert isinstance(x, int) 36 | assert not isinstance(y, int) 37 | return jnp.full(shape=(x,), fill_value=y) 38 | 39 | assert func(3, 4).shape == (3,) 40 | 41 | 42 | def test_jit_3(): 43 | @jdc.jit 44 | def func(x: jdc.Static[int], /, y: int) -> jax.Array: 45 | assert isinstance(x, int) 46 | assert not isinstance(y, int) 47 | return jnp.full(shape=(x,), fill_value=y) 48 | 49 | assert func(3, 4).shape == (3,) 50 | 51 | 52 | def test_jit_4(): 53 | @jdc.jit 54 | def func(*, x: jdc.Static[int], y: int) -> jax.Array: 55 | assert isinstance(x, int) 56 | assert not isinstance(y, int) 57 | return jnp.full(shape=(x,), fill_value=y) 58 | 59 | assert func(x=3, y=4).shape == (3,) 60 | 61 | 62 | def test_jit_5(): 63 | @jdc.jit 64 | def func(x: jdc.Static[int], y: int, /, *, z: jdc.Static[int]) -> jax.Array: 65 | assert isinstance(x, int) 66 | assert not isinstance(y, int) 67 | assert isinstance(z, int) 68 | return jnp.full(shape=(x + z,), fill_value=y) 69 | 70 | assert func(2, 4, z=1).shape == (3,) 71 | 72 | 
def test_jit_6():
    """A positional-or-keyword and a keyword-only `Static` parameter together."""

    @jdc.jit
    def fn(x: jdc.Static[int], y: int, *, z: jdc.Static[int]) -> jax.Array:
        assert isinstance(x, int)
        assert not isinstance(y, int)
        assert isinstance(z, int)
        return jnp.full(shape=(x + z,), fill_value=y)

    result = fn(2, 4, z=1)
    assert result.shape == (3,)


def test_jit_7():
    """Several positional-only `Static` parameters around a traced one."""

    @jdc.jit
    def fn(x: jdc.Static[int], y: int, z: jdc.Static[int], /) -> jax.Array:
        assert isinstance(x, int)
        assert not isinstance(y, int)
        assert isinstance(z, int)
        return jnp.full(shape=(x + z,), fill_value=y)

    result = fn(2, 4, 1)
    assert result.shape == (3,)


def test_jit_no_annotation():
    """Unannotated parameters are simply traced."""

    @jdc.jit
    def fn(x: jdc.Static[int], y, z: jdc.Static[int], /) -> jax.Array:
        assert isinstance(x, int)
        assert not isinstance(y, int)
        assert isinstance(z, int)
        return jnp.full(shape=(x + z,), fill_value=y)

    result = fn(2, 4, 1)
    assert result.shape == (3,)


def test_jit_donate_buffer():
    """Keyword arguments such as `donate_argnums` are forwarded to `jax.jit`."""

    @jdc.jit(donate_argnums=(1,))
    def fn(x: jdc.Static[int], y: int, z: jdc.Static[int], /) -> jax.Array:
        assert isinstance(x, int)
        assert not isinstance(y, int)
        assert isinstance(z, int)
        filled = jnp.full(shape=(x + z,), fill_value=y)

        # Shape matches `y`, so we should be able to reuse the donated buffer.
        return jnp.sum(filled)

    result = fn(2, 4, 1)
    assert result.shape == ()


def test_jit_forward_ref():
    """A `Static` annotation may name a class defined later in the module."""

    @jdc.jit
    def fn(xz: jdc.Static[SomeConfig], y: int, /) -> jax.Array:
        assert not isinstance(y, int)
        return jnp.full(shape=(xz.x + xz.z,), fill_value=y)

    result = fn(SomeConfig(2, 1), 4)
    assert result.shape == (3,)


def test_jit_lambda():
    """`jdc.jit` also accepts lambdas, which carry no annotations."""
    add = jdc.jit(lambda x, y: x + y)
    assert add(jnp.zeros(3), jnp.ones(3)).shape == (3,)


# Hashable (frozen) config, used as a static argument by `test_jit_forward_ref`.
@dataclasses.dataclass(frozen=True)
class SomeConfig:
    x: int
    z: int


# ---- /tests/test_serialization.py ----
"""Tests for serialization using `flax.serialization`."""

import flax
import numpy as onp
from jax import tree_util

import jax_dataclasses as jdc


def _assert_pytree_allclose(x, y) -> None:
    """Assert that two pytrees with matching structure have numerically close leaves."""

    def _leaves_close(*leaf_pair):
        onp.testing.assert_allclose(leaf_pair[0], leaf_pair[1])

    tree_util.tree_map(_leaves_close, x, y)


def test_serialization() -> None:
    """Round-trip a pytree dataclass (with a static field) through flax bytes."""

    @jdc.pytree_dataclass
    class A:
        field1: int
        field2: int
        field3: jdc.Static[bool]

    original = A(field1=5, field2=3, field3=True)
    encoded = flax.serialization.to_bytes(original)
    restored = flax.serialization.from_bytes(original, encoded)
    _assert_pytree_allclose(original, restored)


# ---- /tests/test_variadic_generic_py312.py ----
# mypy: ignore-errors
#
# PEP 695 generics aren't yet supported in mypy.
# Lazy annotations: `jdc.Static[tuple[*Ts]]` is stored as a string and resolved
# later by `get_type_hints_partial`, which strips `*` before evaluating it.
from __future__ import annotations

from typing import Generic, TypeVarTuple

import jax_dataclasses as jdc

Ts = TypeVarTuple("Ts")


# Variadic generic declared with `Generic[*Ts]` (pre-PEP 695 spelling).
@jdc.pytree_dataclass
class Args(Generic[*Ts]):
    @staticmethod
    @jdc.jit
    def make(args: jdc.Static[tuple[*Ts]]) -> tuple[Args[*Ts], tuple[*Ts]]:
        return Args(), args


def test0() -> None:
    # The static tuple is passed through the jitted function unchanged.
    assert Args.make((1, 2, 3))[1] == (1, 2, 3)


# Same as `Args`, using PEP 695 type-parameter syntax (Python 3.12+).
@jdc.pytree_dataclass
class Args2[*T]:
    @staticmethod
    @jdc.jit
    def make[*T_](args: jdc.Static[tuple[*T_]]) -> tuple[Args2[*T_], tuple[*T_]]:
        return Args2(), args


def test1() -> None:
    assert Args2.make((1, 2, 3))[1] == (1, 2, 3)

# ---- /tests/test_vmap.py ----
import jax
import pytest
from jax import numpy as jnp
from typing_extensions import Annotated

import jax_dataclasses as jdc


@jdc.pytree_dataclass
class Node(jdc.EnforcedAnnotationsMixin):
    a: Annotated[jnp.ndarray, (5,), jnp.floating]


def test_vmap():
    # Mapping over the only axis of `a` leaves a per-example shape of (), which is
    # expected to fail the (5,) annotation (presumably when the node is rebuilt
    # inside the mapped function — confirm against the unflatten path).
    with pytest.raises(AssertionError):
        jax.jit(jax.vmap(lambda *unused: None))(Node(jnp.zeros((5,))))

    # Mapping over a leading batch axis keeps the per-example shape (5,).
    jax.jit(jax.vmap(lambda *unused: None))(Node(jnp.zeros((5, 5))))
    # `in_axes=None` leaves the nodes unmapped; only the plain arrays are mapped.
    jax.jit(jax.vmap(lambda *unused: None, in_axes=(None, None, 0, 0)))(
        Node(jnp.zeros(5)),
        Node(jnp.zeros(5)),
        jnp.zeros((5, 100)),
        jnp.zeros((5, 100)),
    )