├── chex ├── py.typed ├── _src │ ├── __init__.py │ ├── warnings_test.py │ ├── fake_set_n_cpu_devices_test.py │ ├── pytypes.py │ ├── warnings.py │ ├── restrict_backends.py │ ├── restrict_backends_test.py │ ├── asserts_internal_test.py │ ├── dimensions_test.py │ ├── dimensions.py │ ├── asserts_chexify.py │ ├── dataclass.py │ ├── fake.py │ ├── fake_test.py │ ├── asserts_internal.py │ └── variants.py ├── chex_test.py └── __init__.py ├── requirements ├── requirements-test.txt ├── requirements-docs.txt └── requirements.txt ├── MANIFEST.in ├── .gitignore ├── .readthedocs.yaml ├── docs ├── Makefile ├── index.rst ├── ext │ └── coverage_check.py ├── api.rst └── conf.py ├── .github └── workflows │ ├── ci.yml │ └── pypi-publish.yml ├── CONTRIBUTING.md ├── setup.py ├── test.sh ├── LICENSE └── README.md /chex/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements/requirements-test.txt: -------------------------------------------------------------------------------- 1 | cloudpickle==2.2.0 2 | dm-tree>=0.1.5 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include requirements/* 4 | include chex/py.typed 5 | -------------------------------------------------------------------------------- /requirements/requirements-docs.txt: -------------------------------------------------------------------------------- 1 | sphinx>=6.0.0 2 | sphinx-book-theme>=1.0.1 3 | sphinxcontrib-katex 4 | -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=0.9.0 2 | typing_extensions>=4.2.0 3 | jax>=0.4.16 4 | jaxlib>=0.1.37 5 | numpy>=1.24.1 6 | 
setuptools;python_version>="3.12" 7 | toolz>=0.9.0 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Building and releasing library: 2 | *.egg-info 3 | *.pyc 4 | *.so 5 | build/ 6 | dist/ 7 | venv/ 8 | docs/_build/ 9 | 10 | # Mac OS 11 | .DS_Store 12 | 13 | # Python tools 14 | .mypy_cache/ 15 | .pytype/ 16 | .ipynb_checkpoints 17 | 18 | # Editors 19 | .idea 20 | .vscode 21 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | version: 2 5 | 6 | build: 7 | os: ubuntu-22.04 8 | tools: 9 | python: "3.11" 10 | 11 | sphinx: 12 | builder: html 13 | configuration: docs/conf.py 14 | fail_on_warning: false 15 | 16 | python: 17 | install: 18 | - requirements: requirements/requirements-docs.txt 19 | - requirements: requirements/requirements.txt 20 | - method: setuptools 21 | path: . 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -W --keep-going 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /chex/_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | branches: ["master"] 6 | pull_request: 7 | branches: ["master"] 8 | schedule: 9 | - cron: '30 2 * * *' 10 | 11 | jobs: 12 | build-and-test: 13 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" 14 | runs-on: "${{ matrix.os }}" 15 | 16 | strategy: 17 | matrix: 18 | python-version: ["3.9", "3.10", "3.11"] 19 | os: [ubuntu-latest] 20 | 21 | steps: 22 | - uses: "actions/checkout@v2" 23 | - uses: "actions/setup-python@v4" 24 | with: 25 | python-version: "${{ matrix.python-version }}" 26 | cache: "pip" 27 | cache-dependency-path: '**/requirements*.txt' 28 | - name: Run CI tests 29 | run: bash test.sh 30 | shell: bash 31 | -------------------------------------------------------------------------------- 
/.github/workflows/pypi-publish.yml: -------------------------------------------------------------------------------- 1 | name: pypi 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Check consistency between the package version and release tag 21 | run: | 22 | RELEASE_VER=${GITHUB_REF#refs/*/} 23 | PACKAGE_VER="v`python setup.py --version`" 24 | if [ $RELEASE_VER != $PACKAGE_VER ] 25 | then 26 | echo "package ver. ($PACKAGE_VER) != release ver. ($RELEASE_VER)"; exit 1 27 | fi 28 | - name: Build and publish 29 | env: 30 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 31 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 32 | run: | 33 | python setup.py sdist bdist_wheel 34 | twine upload dist/* 35 | -------------------------------------------------------------------------------- /chex/chex_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for chex.""" 16 | 17 | from absl.testing import absltest 18 | import chex 19 | 20 | 21 | class ChexTest(absltest.TestCase): 22 | """Test chex can be imported correctly.""" 23 | 24 | def test_import(self): 25 | self.assertTrue(hasattr(chex, 'assert_devices_available')) 26 | 27 | 28 | if __name__ == '__main__': 29 | absltest.main() 30 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Testing 26 | 27 | Please make sure that your PR passes all tests by running `bash test.sh` on your 28 | local machine. Alternatively, you can run only the tests affected by your code 29 | changes, but you will need to select them manually.
30 | 31 | ## Community Guidelines 32 | 33 | This project follows [Google's Open Source Community 34 | Guidelines](https://opensource.google.com/conduct/). 35 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/deepmind/chex/tree/master/docs 2 | 3 | Chex 4 | ----- 5 | 6 | Chex is a library of utilities for helping to write reliable JAX code. 7 | 8 | This includes utils to help: 9 | 10 | * Instrument your code (e.g. assertions) 11 | * Debug (e.g. transforming pmaps into vmaps within a context manager). 12 | * Test JAX code across many variants (e.g. jitted vs non-jitted). 13 | 14 | Modules overview can be found `on GitHub <https://github.com/deepmind/chex#modules-overview>`_. 15 | 16 | Installation 17 | ------------ 18 | 19 | Chex can be installed with pip directly from GitHub, with the following command: 20 | 21 | ``pip install git+https://github.com/deepmind/chex.git`` 22 | 23 | or from PyPI: 24 | 25 | ``pip install chex`` 26 | 27 | .. toctree:: 28 | :caption: API Documentation 29 | :maxdepth: 2 30 | 31 | api 32 | 33 | Citing Chex 34 | ----------- 35 | 36 | This repository is part of the `DeepMind JAX Ecosystem <https://deepmind.com/blog/article/using-jax-to-accelerate-our-research>`_. 37 | 38 | To cite Chex, please use the `DeepMind JAX Ecosystem citation <https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt>`_. 39 | 40 | Contribute 41 | ---------- 42 | 43 | - `Issue tracker <https://github.com/deepmind/chex/issues>`_ 44 | - `Source code <https://github.com/deepmind/chex/tree/master>`_ 45 | 46 | Support 47 | ------- 48 | 49 | If you are having issues, please let us know by filing an issue on our 50 | `issue tracker <https://github.com/deepmind/chex/issues>`_. 51 | 52 | License 53 | ------- 54 | 55 | Chex is licensed under the Apache 2.0 License. 56 | 57 | 58 | Indices and Tables 59 | ================== 60 | 61 | * :ref:`genindex` 62 | -------------------------------------------------------------------------------- /chex/_src/warnings_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `warnings.py`.""" 16 | 17 | import functools 18 | 19 | from absl.testing import absltest 20 | 21 | from chex._src import warnings 22 | 23 | 24 | @functools.partial(warnings.warn_only_n_pos_args_in_future, n=1) 25 | def f(a, b, c): 26 | return a + b + c 27 | 28 | 29 | @functools.partial(warnings.warn_deprecated_function, replacement='h') 30 | def g(a, b, c): 31 | return a + b + c 32 | 33 | 34 | def h1(a, b, c): 35 | return a + b + c 36 | h2 = warnings.create_deprecated_function_alias(h1, 'path.h2', 'path.h1') 37 | 38 | 39 | class WarningsTest(absltest.TestCase): 40 | 41 | def test_warn_only_n_pos_args_in_future(self): 42 | with self.assertWarns(Warning): 43 | f(1, 2, 3) 44 | with self.assertWarns(Warning): 45 | f(1, 2, c=3) 46 | 47 | def test_warn_deprecated_function(self): 48 | with self.assertWarns(Warning): 49 | g(1, 2, 3) 50 | 51 | def test_create_deprecated_function_alias(self): 52 | with self.assertWarns(Warning): 53 | h2(1, 2, 3) 54 | 55 | 56 | if __name__ == '__main__': 57 | absltest.main() 58 | -------------------------------------------------------------------------------- /chex/_src/fake_set_n_cpu_devices_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Test for `set_n_cpu_devices` from `fake.py`. 16 | 17 | This test is isolated to ensure hermeticity because its execution changes 18 | XLA backend configuration. 19 | """ 20 | 21 | import unittest 22 | from absl.testing import absltest 23 | from chex._src import asserts 24 | from chex._src import fake 25 | 26 | 27 | class DevicesSetterTest(absltest.TestCase): 28 | 29 | def test_set_n_cpu_devices(self): 30 | try: 31 | # Should not initialize backends. 32 | fake.set_n_cpu_devices(4) 33 | except RuntimeError as set_cpu_exception: 34 | raise unittest.SkipTest( 35 | "set_n_cpu_devices: backend's already been initialized. " 36 | 'Run this test in isolation from others.') from set_cpu_exception 37 | 38 | # Hence, this one does not fail. 39 | fake.set_n_cpu_devices(6) 40 | 41 | # This assert initializes backends. 42 | asserts.assert_devices_available(6, 'cpu', backend='cpu') 43 | 44 | # Which means that next call must fail. 
45 | with self.assertRaisesRegex(RuntimeError, 46 | 'Attempted to set 8 devices, but 6 CPUs.+'): 47 | fake.set_n_cpu_devices(8) 48 | 49 | 50 | if __name__ == '__main__': 51 | absltest.main() 52 | -------------------------------------------------------------------------------- /chex/_src/pytypes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Type definitions to use for type annotations.""" 16 | 17 | from typing import Any, Iterable, Mapping, Sequence, Union 18 | 19 | import jax 20 | import numpy as np 21 | 22 | # Special types of arrays. 23 | ArrayNumpy = np.ndarray 24 | 25 | # For instance checking, use `isinstance(x, jax.Array)`. 26 | ArrayDevice = jax.Array 27 | 28 | # Types for backward compatibility. 29 | ArraySharded = jax.Array 30 | ArrayBatched = jax.Array 31 | 32 | # Generic array type. 33 | # Similar to `jax.typing.ArrayLike` but does not accept python scalar types. 34 | Array = Union[ 35 | ArrayDevice, 36 | ArrayBatched, 37 | ArraySharded, # JAX array type 38 | ArrayNumpy, # NumPy array type 39 | np.bool_, 40 | np.number, # NumPy scalar types 41 | ] 42 | 43 | # A tree of generic arrays. 
44 | ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']] 45 | ArrayDeviceTree = Union[ 46 | ArrayDevice, Iterable['ArrayDeviceTree'], Mapping[Any, 'ArrayDeviceTree'] 47 | ] 48 | ArrayNumpyTree = Union[ 49 | ArrayNumpy, Iterable['ArrayNumpyTree'], Mapping[Any, 'ArrayNumpyTree'] 50 | ] 51 | 52 | # Other types. 53 | Scalar = Union[float, int] 54 | Numeric = Union[Array, Scalar] 55 | Shape = Sequence[Union[int, Any]] 56 | PRNGKey = jax.Array 57 | PyTreeDef = jax.tree_util.PyTreeDef 58 | Device = jax.Device 59 | 60 | # TODO(iukemaev, jakevdp): upgrade minimum jax version & remove this condition. 61 | if hasattr(jax.typing, 'DTypeLike'): 62 | # jax version 0.4.19 or newer 63 | ArrayDType = jax.typing.DTypeLike # pylint:disable=invalid-name 64 | else: 65 | ArrayDType = Any # pylint:disable=invalid-name 66 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Install script for setuptools.""" 16 | 17 | import os 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 22 | 23 | 24 | def _get_version(): 25 | with open(os.path.join(_CURRENT_DIR, 'chex', '__init__.py')) as fp: 26 | for line in fp: 27 | if line.startswith('__version__') and '=' in line: 28 | version = line[line.find('=') + 1:].strip(' \'"\n') 29 | if version: 30 | return version 31 | raise ValueError('`__version__` not defined in `chex/__init__.py`') 32 | 33 | 34 | def _parse_requirements(path): 35 | with open(os.path.join(_CURRENT_DIR, path)) as f: 36 | return [ 37 | line.rstrip() 38 | for line in f 39 | if not (line.isspace() or line.startswith('#')) 40 | ] 41 | 42 | 43 | setup( 44 | name='chex', 45 | version=_get_version(), 46 | url='https://github.com/deepmind/chex', 47 | license='Apache 2.0', 48 | author='DeepMind', 49 | description=('Chex: Testing made fun, in JAX!'), 50 | long_description=open(os.path.join(_CURRENT_DIR, 'README.md')).read(), 51 | long_description_content_type='text/markdown', 52 | author_email='chex-dev@google.com', 53 | keywords='jax testing debugging python machine learning', 54 | packages=find_packages(exclude=['*_test.py']), 55 | install_requires=_parse_requirements( 56 | os.path.join(_CURRENT_DIR, 'requirements', 'requirements.txt')), 57 | tests_require=_parse_requirements( 58 | os.path.join(_CURRENT_DIR, 'requirements', 'requirements-test.txt')), 59 | zip_safe=False, # Required for full installation. 
60 | include_package_data=True, 61 | python_requires='>=3.9', 62 | classifiers=[ 63 | 'Development Status :: 5 - Production/Stable', 64 | 'Environment :: Console', 65 | 'Intended Audience :: Science/Research', 66 | 'Intended Audience :: Developers', 67 | 'License :: OSI Approved :: Apache Software License', 68 | 'Operating System :: OS Independent', 69 | 'Programming Language :: Python', 70 | 'Programming Language :: Python :: 3', 71 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 72 | 'Topic :: Software Development :: Testing :: Mocking', 73 | 'Topic :: Software Development :: Testing :: Unit', 74 | 'Topic :: Software Development :: Libraries :: Python Modules', 75 | ], 76 | ) 77 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # Runs CI tests on a local machine. 17 | set -xeuo pipefail 18 | 19 | # Install deps in a virtual env. 
20 | rm -rf _testing 21 | rm -rf .pytype 22 | mkdir -p _testing 23 | readonly VENV_DIR="$(mktemp -d -p `pwd`/_testing chex-env.XXXXXXXX)" 24 | # in the unlikely case in which there was something in that directory 25 | python3 -m venv "${VENV_DIR}" 26 | source "${VENV_DIR}/bin/activate" 27 | python --version 28 | 29 | # Install dependencies. 30 | pip install --upgrade pip setuptools wheel 31 | pip install flake8 pytest-xdist pylint pylint-exit 32 | pip install -r requirements/requirements.txt 33 | pip install -r requirements/requirements-test.txt 34 | 35 | # Lint with flake8. 36 | flake8 `find chex -name '*.py' | xargs` --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics 37 | 38 | # Lint with pylint. 39 | PYLINT_ARGS="-efail -wfail -cfail -rfail" 40 | # Download Google OSS config. 41 | wget -nd -v -t 3 -O .pylintrc https://google.github.io/styleguide/pylintrc 42 | # Append specific config lines. 43 | echo "signature-mutators=toolz.functoolz.curry" >> .pylintrc 44 | echo "disable=unnecessary-lambda-assignment,use-dict-literal" >> .pylintrc 45 | # Lint modules and tests separately. 46 | pylint --rcfile=.pylintrc `find chex -name '*.py' | grep -v 'test.py' | xargs` -d E1102|| pylint-exit $PYLINT_ARGS $? 47 | # Disable `protected-access` warnings for tests. 48 | pylint --rcfile=.pylintrc `find chex -name '*_test.py' | xargs` -d W0212,E1130,E1102 || pylint-exit $PYLINT_ARGS $? 49 | # Cleanup. 50 | rm .pylintrc 51 | 52 | # Build the package. 53 | python setup.py sdist 54 | pip wheel --verbose --no-deps --no-clean dist/chex*.tar.gz 55 | pip install chex*.whl 56 | 57 | # Check types with pytype. 58 | # Note: pytype does not support 3.11 as of 25.06.23 59 | # See https://github.com/google/pytype/issues/1308 60 | if [ `python -c 'import sys; print(sys.version_info.minor)'` -lt 11 ]; 61 | then 62 | pip install pytype 63 | pytype `find chex/_src -name "*py" | xargs` -k 64 | fi; 65 | 66 | # Run tests using pytest. 
67 | # Change directory to avoid importing the package from repo root. 68 | pip install -r requirements/requirements-test.txt 69 | cd _testing 70 | 71 | # Main tests. 72 | pytest -n "$(grep -c ^processor /proc/cpuinfo)" --pyargs chex -k "not fake_set_n_cpu_devices_test" 73 | 74 | # Isolate tests that use `chex.set_n_cpu_device()`. 75 | pytest -n "$(grep -c ^processor /proc/cpuinfo)" --pyargs chex -k "fake_set_n_cpu_devices_test" 76 | cd .. 77 | 78 | # Build Sphinx docs. 79 | 80 | pip install -r requirements/requirements-docs.txt 81 | cd docs 82 | make coverage_check 83 | make html 84 | cd .. 85 | 86 | # cleanup 87 | rm -rf _testing 88 | 89 | set +u 90 | deactivate 91 | echo "All tests passed. Congrats!" 92 | -------------------------------------------------------------------------------- /chex/_src/warnings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utilities to emit warnings.""" 16 | 17 | import functools 18 | import warnings 19 | 20 | 21 | def warn_only_n_pos_args_in_future(fun, n): 22 | """Warns if more than ``n`` positional arguments are passed to ``fun``. 23 | 24 | For instance: 25 | >>> @functools.partial(chex.warn_only_n_pos_args_in_future, n=1) 26 | ... 
def f(a, b, c=1): 27 | ... return a + b + c 28 | 29 | Will raise a DeprecationWarning if ``f`` is called with more than one 30 | positional argument (e.g. both f(1, 2, 3) and f(1, 2, c=3) raise a warning). 31 | 32 | Args: 33 | fun: the function to wrap. 34 | n: the number of positional arguments to allow. 35 | 36 | Returns: 37 | A wrapped function that emits a warning if more than `n` positional 38 | arguments are passed. 39 | """ 40 | 41 | @functools.wraps(fun) 42 | def wrapper(*args, **kwargs): 43 | if len(args) > n: 44 | warnings.warn( 45 | f'only the first {n} arguments can be passed positionally ' 46 | 'additional args will become keyword-only soon', 47 | DeprecationWarning, 48 | stacklevel=2 49 | ) 50 | return fun(*args, **kwargs) 51 | 52 | return wrapper 53 | 54 | 55 | warn_keyword_args_only_in_future = functools.partial( 56 | warn_only_n_pos_args_in_future, n=0 57 | ) 58 | 59 | 60 | def warn_deprecated_function(fun, replacement): 61 | """A decorator to mark a function definition as deprecated. 62 | 63 | Example usage: 64 | >>> @functools.partial(chex.warn_deprecated_function, replacement='g') 65 | ... def f(a, b): 66 | ... return a + b 67 | 68 | Args: 69 | fun: the deprecated function. 70 | replacement: the name of the function to be used instead. 71 | 72 | Returns: 73 | the wrapped function. 74 | """ 75 | 76 | @functools.wraps(fun) 77 | def new_fun(*args, **kwargs): 78 | warnings.warn( 79 | f'The function {fun.__name__} is deprecated, ' 80 | f'please use {replacement} instead.', 81 | category=DeprecationWarning, 82 | stacklevel=2) 83 | return fun(*args, **kwargs) 84 | return new_fun 85 | 86 | 87 | def create_deprecated_function_alias(fun, new_name, deprecated_alias): 88 | """Create a deprecated alias for a function. 89 | 90 | Example usage: 91 | >>> g = create_deprecated_function_alias(f, 'path.f', 'path.g') 92 | 93 | Args: 94 | fun: the deprecated function. 95 | new_name: the new name to use (you may include the path for clarity). 
96 | deprecated_alias: the old name (you may include the path for clarity). 97 | 98 | Returns: 99 | the wrapped function. 100 | """ 101 | 102 | @functools.wraps(fun) 103 | def new_fun(*args, **kwargs): 104 | warnings.warn( 105 | f'The function {deprecated_alias} is deprecated, ' 106 | f'please use {new_name} instead.', 107 | category=DeprecationWarning, 108 | stacklevel=2) 109 | return fun(*args, **kwargs) 110 | return new_fun 111 | -------------------------------------------------------------------------------- /docs/ext/coverage_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Asserts all public symbols are covered in the docs.""" 16 | 17 | import inspect 18 | import types 19 | from typing import Any, Mapping, Sequence, Tuple 20 | 21 | import chex as _module 22 | from sphinx import application 23 | from sphinx import builders 24 | from sphinx import errors 25 | 26 | 27 | def find_internal_python_modules( 28 | root_module: types.ModuleType,) -> Sequence[Tuple[str, types.ModuleType]]: 29 | """Returns `(name, module)` for all submodules under `root_module`.""" 30 | modules = set([(root_module.__name__, root_module)]) 31 | visited = set() 32 | to_visit = [root_module] 33 | 34 | while to_visit: 35 | mod = to_visit.pop() 36 | visited.add(mod) 37 | 38 | for name in dir(mod): 39 | obj = getattr(mod, name) 40 | if inspect.ismodule(obj) and obj not in visited: 41 | if obj.__name__.startswith(_module.__name__): 42 | if '_src' not in obj.__name__: 43 | to_visit.append(obj) 44 | modules.add((obj.__name__, obj)) 45 | 46 | return sorted(modules) 47 | 48 | 49 | def get_public_symbols() -> Sequence[Tuple[str, types.ModuleType]]: 50 | names = set() 51 | for module_name, module in find_internal_python_modules(_module): 52 | for name in module.__all__: 53 | names.add(module_name + '.' + name) 54 | return tuple(names) 55 | 56 | 57 | class CoverageCheck(builders.Builder): 58 | """Builder that checks all public symbols are included.""" 59 | 60 | name = 'coverage_check' 61 | 62 | def get_outdated_docs(self) -> str: 63 | return 'coverage_check' 64 | 65 | def write(self, *ignored: Any) -> None: 66 | pass 67 | 68 | def finish(self) -> None: 69 | documented_objects = frozenset(self.env.domaindata['py']['objects']) 70 | undocumented_objects = set(get_public_symbols()) - documented_objects 71 | 72 | # Exclude deprecated API symbols. 
73 | assertion_exceptions = ('assert_tree_all_close', 74 | 'assert_tree_all_equal_comparator', 75 | 'assert_tree_all_equal_shapes', 76 | 'assert_tree_all_equal_structs') 77 | undocumented_objects -= {'chex.' + s for s in assertion_exceptions} 78 | 79 | # Exclude pytypes. 80 | pytypes_exceptions = ( 81 | 'Array', 82 | 'ArrayBatched', 83 | 'Array', 84 | 'ArrayBatched', 85 | 'ArrayDevice', 86 | 'ArrayDeviceTree', 87 | 'ArrayDType', 88 | 'ArrayNumpy', 89 | 'ArrayNumpyTree', 90 | 'ArraySharded', 91 | 'ArrayTree', 92 | 'Device', 93 | 'Numeric', 94 | 'PRNGKey', 95 | 'PyTreeDef', 96 | 'Scalar', 97 | 'Shape', 98 | ) 99 | 100 | # Exclude public constants. 101 | pytypes_exceptions += ('ChexifyChecks',) 102 | 103 | undocumented_objects -= {'chex.' + s for s in pytypes_exceptions} 104 | 105 | if undocumented_objects: 106 | undocumented_objects = tuple(sorted(undocumented_objects)) 107 | raise errors.SphinxError( 108 | 'All public symbols must be included in our documentation, did you ' 109 | 'forget to add an entry to `api.rst`?\n' 110 | f'Undocumented symbols: {undocumented_objects}') 111 | 112 | 113 | def setup(app: application.Sphinx) -> Mapping[str, Any]: 114 | app.add_builder(CoverageCheck) 115 | return dict(version=_module.__version__, parallel_read_safe=True) 116 | -------------------------------------------------------------------------------- /chex/_src/restrict_backends.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A context manager that objects to JAX compilation for specified backends. 16 | 17 | This is useful, for example, when certain JAX code needs to run in an 18 | environment where an accelerator is present but reserved for other purposes. 19 | Typically one would use `jax.jit(..., backend='cpu')` to keep the code away 20 | from the accelerator, but it is hard to check by hand that this has been done 21 | without exception throughout an entire subsystem. Then, `restrict_backends()` 22 | can be used to detect any overlooked case and report it by raising an exception. 23 | 24 | Similarly, it can be useful for a system such as a learner to make sure that 25 | all required JAX programs have been assigned to their respective backends by 26 | the end of its first iteration; this helps to show that it will not later run 27 | into memory fragmentation problems. By entering a `restrict_backends()` context 28 | at the end of the first iteration, the system can detect any overlooked cases. 29 | """ 30 | import contextlib 31 | import functools 32 | from typing import Optional, Sequence 33 | 34 | # pylint: disable=g-import-not-at-top 35 | try: 36 | from jax._src import compiler 37 | except ImportError: 38 | # TODO(phawkins): remove this path after jax>=0.4.15 is the minimum version 39 | # required by chex. 
40 | from jax._src import dispatch as compiler # type: ignore 41 | # pylint: enable=g-import-not-at-top 42 | 43 | 44 | class RestrictedBackendError(RuntimeError): 45 | pass 46 | 47 | 48 | @contextlib.contextmanager 49 | def restrict_backends( 50 | *, 51 | allowed: Optional[Sequence[str]] = None, 52 | forbidden: Optional[Sequence[str]] = None): 53 | """Disallows JAX compilation for certain backends. 54 | 55 | Args: 56 | allowed: Names of backend platforms (e.g. 'cpu' or 'tpu') for which 57 | compilation is still to be permitted. 58 | forbidden: Names of backend platforms for which compilation is to be 59 | forbidden. 60 | 61 | Yields: 62 | None, in a context where compilation for forbidden platforms will raise 63 | a `RestrictedBackendError`. 64 | 65 | Raises: 66 | ValueError: if neither `allowed` nor `forbidden` is specified (i.e. they 67 | are both `None`), or if anything is both allowed and forbidden. 68 | """ 69 | allowed = tuple(allowed) if allowed is not None else None 70 | forbidden = tuple(forbidden) if forbidden is not None else None 71 | 72 | if allowed is None and forbidden is None: 73 | raise ValueError('No restrictions specified.') 74 | contradictions = set(allowed or ()) & set(forbidden or ()) 75 | if contradictions: 76 | raise ValueError( 77 | f"Backends {contradictions} can't be both allowed and forbidden.") 78 | 79 | def is_allowed(backend_platform): 80 | return ((backend_platform in allowed) if allowed is not None else 81 | (backend_platform not in forbidden)) 82 | 83 | inner_backend_compile = compiler.backend_compile 84 | 85 | @functools.wraps(inner_backend_compile) 86 | def wrapper(backend, *args, **kwargs): 87 | if not is_allowed(backend.platform): 88 | raise RestrictedBackendError( 89 | f'Compiling a JAX program for {backend.platform} is forbidden by ' 90 | f'restrict_backends().') 91 | return inner_backend_compile(backend, *args, **kwargs) 92 | 93 | try: 94 | compiler.backend_compile = wrapper 95 | yield 96 | finally: 97 | backend_compile = 
compiler.backend_compile 98 | assert backend_compile is wrapper, backend_compile 99 | compiler.backend_compile = inner_backend_compile 100 | -------------------------------------------------------------------------------- /chex/_src/restrict_backends_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `restrict_backends.py`.""" 16 | from absl.testing import absltest 17 | from chex._src import restrict_backends 18 | import jax 19 | import jax.numpy as jnp 20 | import numpy as np 21 | 22 | 23 | def compute_cube(side): 24 | return jnp.sum(jnp.ones((side, side)) * side) 25 | 26 | 27 | class RestrictBackendsTest(absltest.TestCase): 28 | 29 | # These tests need an accelerator of some sort, so that JAX can try to use it. 
30 | def setUp(self): 31 | super().setUp() 32 | 33 | try: 34 | jax.devices('gpu') 35 | gpu_backend_available = True 36 | except RuntimeError: 37 | gpu_backend_available = False 38 | 39 | try: 40 | jax.devices('tpu') 41 | tpu_backend_available = True 42 | except RuntimeError: 43 | tpu_backend_available = False 44 | 45 | if not tpu_backend_available or gpu_backend_available: 46 | self.skipTest('No known accelerator backends are available, so these ' 47 | 'tests will not test anything useful.') 48 | 49 | def test_detects_implicitly_forbidden_tpu_computation(self): 50 | with self.assertRaisesRegex(restrict_backends.RestrictedBackendError, 51 | r'forbidden by restrict_backends'): 52 | with restrict_backends.restrict_backends(allowed=['cpu']): 53 | compute_cube(3) 54 | # Make sure the restriction is no longer in place. 55 | np.testing.assert_array_equal(compute_cube(3), 27) 56 | 57 | def test_detects_explicitly_forbidden_tpu_computation(self): 58 | with self.assertRaisesRegex(restrict_backends.RestrictedBackendError, 59 | r'forbidden by restrict_backends'): 60 | with restrict_backends.restrict_backends(forbidden=['tpu', 'gpu']): 61 | compute_cube(2) 62 | # Make sure the restriction is no longer in place. 63 | np.testing.assert_array_equal(compute_cube(2), 8) 64 | 65 | def test_detects_implicitly_forbidden_cpu_computation(self): 66 | with self.assertRaisesRegex(restrict_backends.RestrictedBackendError, 67 | r'forbidden by restrict_backends'): 68 | with restrict_backends.restrict_backends(allowed=['tpu', 'gpu']): 69 | jax.jit(lambda: compute_cube(8), backend='cpu')() 70 | # Make sure the restriction is no longer in place. 
71 | np.testing.assert_array_equal(compute_cube(8), 512) 72 | 73 | def test_detects_explicitly_forbidden_cpu_computation(self): 74 | with self.assertRaisesRegex(restrict_backends.RestrictedBackendError, 75 | r'forbidden by restrict_backends'): 76 | with restrict_backends.restrict_backends(forbidden=['cpu']): 77 | jax.jit(lambda: compute_cube(9), backend='cpu')() 78 | # Make sure the restriction is no longer in place. 79 | np.testing.assert_array_equal(compute_cube(9), 729) 80 | 81 | def test_ignores_explicitly_allowed_cpu_computation(self): 82 | with restrict_backends.restrict_backends(allowed=['cpu']): 83 | c = jax.jit(lambda: compute_cube(4), backend='cpu')() 84 | np.testing.assert_array_equal(c, 64) 85 | 86 | def test_ignores_implicitly_allowed_cpu_computation(self): 87 | with restrict_backends.restrict_backends(forbidden=['tpu', 'gpu']): 88 | c = jax.jit(lambda: compute_cube(5), backend='cpu')() 89 | np.testing.assert_array_equal(c, 125) 90 | 91 | def test_ignores_explicitly_allowed_tpu_computation(self): 92 | with restrict_backends.restrict_backends(allowed=['tpu', 'gpu']): 93 | c = jax.jit(lambda: compute_cube(6))() 94 | np.testing.assert_array_equal(c, 216) 95 | 96 | def test_ignores_implicitly_allowed_tpu_computation(self): 97 | with restrict_backends.restrict_backends(forbidden=['cpu']): 98 | c = jax.jit(lambda: compute_cube(7))() 99 | np.testing.assert_array_equal(c, 343) 100 | 101 | def test_raises_if_no_restrictions_specified(self): 102 | with self.assertRaisesRegex(ValueError, r'No restrictions specified'): 103 | with restrict_backends.restrict_backends(): 104 | pass 105 | 106 | def test_raises_if_contradictory_restrictions_specified(self): 107 | with self.assertRaisesRegex(ValueError, r"can't be both"): 108 | with restrict_backends.restrict_backends( 109 | allowed=['cpu'], forbidden=['cpu']): 110 | pass 111 | 112 | 113 | if __name__ == '__main__': 114 | absltest.main() 115 | 
-------------------------------------------------------------------------------- /chex/_src/asserts_internal_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `asserts_internal.py`.""" 16 | 17 | import functools 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from chex._src import asserts_internal as ai 22 | from chex._src import variants 23 | import jax 24 | import jax.numpy as jnp 25 | 26 | 27 | class IsTraceableTest(variants.TestCase): 28 | 29 | @variants.variants(with_jit=True, with_pmap=True) 30 | def test_is_traceable(self): 31 | def dummy_wrapper(fn): 32 | 33 | @functools.wraps(fn) 34 | def fn_wrapped(fn, *args): 35 | return fn(args) 36 | 37 | return fn_wrapped 38 | 39 | fn = lambda x: x.sum() 40 | wrapped_fn = dummy_wrapper(fn) 41 | self.assertFalse(ai.is_traceable(fn)) 42 | self.assertFalse(ai.is_traceable(wrapped_fn)) 43 | 44 | var_fn = self.variant(fn) 45 | wrapped_var_f = dummy_wrapper(var_fn) 46 | var_wrapped_f = self.variant(wrapped_fn) 47 | self.assertTrue(ai.is_traceable(var_fn)) 48 | self.assertTrue(ai.is_traceable(wrapped_var_f)) 49 | self.assertTrue(ai.is_traceable(var_wrapped_f)) 50 | 51 | 52 | class 
ExceptionMessageFormatTest(variants.TestCase): 53 | 54 | @parameterized.product( 55 | include_default_msg=(False, True), 56 | include_custom_msg=(False, True), 57 | exc_type=(AssertionError, ValueError), 58 | ) 59 | def test_format(self, include_default_msg, include_custom_msg, exc_type): 60 | 61 | exc_msg = lambda x: f'{x} is non-positive.' 62 | 63 | @functools.partial(ai.chex_assertion, jittable_assert_fn=None) 64 | def assert_positive(x): 65 | if x <= 0: 66 | raise AssertionError(exc_msg(x)) 67 | 68 | @functools.partial(ai.chex_assertion, jittable_assert_fn=None) 69 | def assert_each_positive(*args): 70 | for x in args: 71 | assert_positive(x) 72 | 73 | # Pass. 74 | assert_positive(1) 75 | assert_each_positive(1, 2, 3) 76 | 77 | # Check the format of raised exceptions' messages. 78 | def expected_exc_msg(x, custom_msg): 79 | msg = exc_msg(x) if include_default_msg else '' 80 | msg = rf'{msg} \[{custom_msg}\]' if custom_msg else msg 81 | return msg 82 | 83 | # Run in a loop to generate different custom messages. 
84 | for i in range(3): 85 | custom_msg = f'failed at iter {i}' if include_custom_msg else '' 86 | 87 | with self.assertRaisesRegex( 88 | exc_type, ai.get_err_regex(expected_exc_msg(-1, custom_msg))): 89 | assert_positive( # pylint:disable=unexpected-keyword-arg 90 | -1, 91 | custom_message=custom_msg, 92 | include_default_message=include_default_msg, 93 | exception_type=exc_type) 94 | 95 | with self.assertRaisesRegex( 96 | exc_type, ai.get_err_regex(expected_exc_msg(-3, custom_msg))): 97 | assert_each_positive( # pylint:disable=unexpected-keyword-arg 98 | 1, 99 | -3, 100 | 2, 101 | custom_message=custom_msg, 102 | include_default_message=include_default_msg, 103 | exception_type=exc_type) 104 | 105 | 106 | class JitCompatibleTest(variants.TestCase): 107 | 108 | def test_api(self): 109 | 110 | def assert_fn(x): 111 | if x.shape != (2,): 112 | raise AssertionError(f'shape != (2,) {x.shape}!') 113 | 114 | for transform_fn in (jax.jit, jax.grad, jax.vmap): 115 | x_ok = jnp.ones((2,)) 116 | x_wrong = jnp.ones((3,)) 117 | is_vmap = transform_fn is jax.vmap 118 | if is_vmap: 119 | x_ok, x_wrong = (jnp.expand_dims(x, 0) for x in (x_ok, x_wrong)) 120 | 121 | # Jax-compatible. 122 | assert_compat_fn = ai.chex_assertion(assert_fn, jittable_assert_fn=None) 123 | 124 | def compat_fn(x, assertion=assert_compat_fn): 125 | assertion(x) 126 | return x.sum() 127 | 128 | if not is_vmap: 129 | compat_fn(x_ok) 130 | transform_fn(compat_fn)(x_ok) 131 | with self.assertRaisesRegex(AssertionError, 'shape !='): 132 | transform_fn(compat_fn)(x_wrong) 133 | 134 | # JAX-incompatible. 
135 | assert_incompat_fn = ai.chex_assertion( 136 | assert_fn, jittable_assert_fn=assert_fn) 137 | 138 | def incompat_fn(x, assertion=assert_incompat_fn): 139 | assertion(x) 140 | return x.sum() 141 | 142 | if not is_vmap: 143 | incompat_fn(x_ok) 144 | with self.assertRaisesRegex(RuntimeError, 145 | 'Value assertions can only be called from'): 146 | transform_fn(incompat_fn)(x_wrong) 147 | 148 | 149 | if __name__ == '__main__': 150 | jax.config.update('jax_numpy_rank_promotion', 'raise') 151 | absltest.main() 152 | -------------------------------------------------------------------------------- /chex/_src/dimensions_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for ``dimensions`` module.""" 16 | 17 | import doctest 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from chex._src import asserts 22 | from chex._src import dimensions 23 | import jax 24 | import numpy as np 25 | 26 | 27 | class _ChexModule: 28 | """Mock module for providing minimal context to docstring tests.""" 29 | assert_shape = asserts.assert_shape 30 | assert_rank = asserts.assert_rank 31 | Dimensions = dimensions.Dimensions # pylint: disable=invalid-name 32 | 33 | 34 | class DimensionsTest(parameterized.TestCase): 35 | 36 | def test_docstring_examples(self): 37 | doctest.run_docstring_examples( 38 | dimensions.Dimensions, 39 | globs={'chex': _ChexModule, 'jax': jax, 'jnp': jax.numpy}) 40 | 41 | @parameterized.named_parameters([ 42 | ('scalar', '', (), ()), 43 | ('vector', 'a', (7,), (7,)), 44 | ('list', 'ab', [7, 11], (7, 11)), 45 | ('numpy_array', 'abc', np.array([7, 11, 13]), (7, 11, 13)), 46 | ('case_sensitive', 'aA', (7, 11), (7, 11)), 47 | ]) 48 | def test_set_ok(self, k, v, shape): 49 | dims = dimensions.Dimensions(x=23, y=29) 50 | dims[k] = v 51 | asserts.assert_shape(np.empty((23, *shape, 29)), dims['x' + k + 'y']) 52 | 53 | def test_set_wildcard(self): 54 | dims = dimensions.Dimensions(x=23, y=29) 55 | dims['a_b__'] = (7, 11, 13, 17, 19) 56 | self.assertEqual(dims['xayb'], (23, 7, 29, 13)) 57 | with self.assertRaisesRegex(KeyError, r'\*'): 58 | dims['ab*'] = (7, 11, 13) 59 | 60 | def test_get_wildcard(self): 61 | dims = dimensions.Dimensions(x=23, y=29) 62 | self.assertEqual(dims['x*y**'], (23, None, 29, None, None)) 63 | asserts.assert_shape(np.empty((23, 1, 29, 2, 3)), dims['x*y**']) 64 | with self.assertRaisesRegex(KeyError, r'\_'): 65 | _ = dims['xy_'] 66 | 67 | def test_get_literals(self): 68 | dims = dimensions.Dimensions(x=23, y=29) 69 | self.assertEqual(dims['x1y23'], (23, 1, 29, 2, 3)) 70 | 71 | 
@parameterized.named_parameters([ 72 | ('scalar', 'a', 7, TypeError, r'value must be sized'), 73 | ('iterator', 'a', (x for x in [7]), TypeError, r'value must be sized'), 74 | ('len_mismatch', 'ab', (7, 11, 13), ValueError, r'different length'), 75 | ('non_integer_size', 'a', (7.001,), 76 | TypeError, r'cannot be interpreted as a python int'), 77 | ('bad_key_type', 13, (7,), TypeError, r'key must be a string'), 78 | ('bad_key_string', '@%^#', (7, 11, 13, 17), KeyError, r'\@'), 79 | ]) 80 | def test_set_exception(self, k, v, e, m): 81 | dims = dimensions.Dimensions(x=23, y=29) 82 | with self.assertRaisesRegex(e, m): 83 | dims[k] = v 84 | 85 | @parameterized.named_parameters([ 86 | ('bad_key_type', 13, TypeError, r'key must be a string'), 87 | ('bad_key_string', '@%^#', KeyError, r'\@'), 88 | ]) 89 | def test_get_exception(self, k, e, m): 90 | dims = dimensions.Dimensions(x=23, y=29) 91 | with self.assertRaisesRegex(e, m): 92 | _ = dims[k] 93 | 94 | @parameterized.named_parameters([ 95 | ('scalar', '', (), 1), 96 | ('nonscalar', 'ab', (3, 5), 15), 97 | ]) 98 | def test_size_ok(self, names, shape, expected_size): 99 | dims = dimensions.Dimensions(**dict(zip(names, shape))) 100 | self.assertEqual(dims.size(names), expected_size) 101 | 102 | @parameterized.named_parameters([ 103 | ('named', 'ab'), 104 | ('asterisk', 'a*'), 105 | ('zero', 'a0'), 106 | ('negative', 'ac'), 107 | ]) 108 | def test_size_fail_wildcard(self, names): 109 | dims = dimensions.Dimensions(a=3, b=None, c=-1) 110 | with self.assertRaisesRegex(ValueError, r'cannot take product of shape'): 111 | dims.size(names) 112 | 113 | @parameterized.named_parameters([ 114 | ('trivial_start', '(a)bc', (3, 5, 7)), 115 | ('trivial_mid', 'a(b)c', (3, 5, 7)), 116 | ('trivial_end', 'ab(c)', (3, 5, 7)), 117 | ('start', '(ab)cd', (15, 7, 11)), 118 | ('mid', 'a(bc)d', (3, 35, 11)), 119 | ('end', 'ab(cd)', (3, 5, 77)), 120 | ('multiple', '(ab)(cd)', (15, 77)), 121 | ('all', '(abc)', (105,)), 122 | ]) 123 | def 
test_flatten_ok(self, named_shape, expected_shape): 124 | dims = dimensions.Dimensions(a=3, b=5, c=7, d=11) 125 | self.assertEqual(dims[named_shape], expected_shape) 126 | 127 | @parameterized.named_parameters([ 128 | ('unmatched_open', '(ab', r'unmatched parentheses in named shape'), 129 | ('unmatched_closed', 'a)b', r'unmatched parentheses in named shape'), 130 | ('nested', '(a(bc))', r'nested parentheses are unsupported'), 131 | ('wildcard_named', 'a(bx)', r'cannot take product of shape'), 132 | ('wildcard_asterisk', '(a*)b', r'cannot take product of shape'), 133 | ('zero_sized_dim', '(a0)b', r'cannot take product of shape'), 134 | ('neg_sized_dim', '(ay)b', r'cannot take product of shape'), 135 | ('empty_start', '()ab', r'found empty parentheses in named shape'), 136 | ('empty_mid', 'a()b', r'found empty parentheses in named shape'), 137 | ('empty_end', 'ab()', r'found empty parentheses in named shape'), 138 | ('empty_solo', '()', r'found empty parentheses in named shape'), 139 | ]) 140 | def test_flatten_fail(self, named_shape, error_message): 141 | dims = dimensions.Dimensions(a=3, b=5, x=None, y=-1) 142 | with self.assertRaisesRegex(ValueError, error_message): 143 | _ = dims[named_shape] 144 | 145 | 146 | if __name__ == '__main__': 147 | absltest.main() 148 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | Assertions 2 | ========== 3 | 4 | .. currentmodule:: chex 5 | 6 | .. 
autosummary:: 7 | 8 | assert_axis_dimension 9 | assert_axis_dimension_comparator 10 | assert_axis_dimension_gt 11 | assert_axis_dimension_gteq 12 | assert_axis_dimension_lt 13 | assert_axis_dimension_lteq 14 | assert_devices_available 15 | assert_equal 16 | assert_equal_rank 17 | assert_equal_size 18 | assert_equal_shape 19 | assert_equal_shape_prefix 20 | assert_equal_shape_suffix 21 | assert_exactly_one_is_none 22 | assert_gpu_available 23 | assert_is_broadcastable 24 | assert_is_divisible 25 | assert_max_traces 26 | assert_not_both_none 27 | assert_numerical_grads 28 | assert_rank 29 | assert_scalar 30 | assert_scalar_in 31 | assert_scalar_negative 32 | assert_scalar_non_negative 33 | assert_scalar_positive 34 | assert_size 35 | assert_shape 36 | assert_tpu_available 37 | assert_tree_all_finite 38 | assert_tree_has_only_ndarrays 39 | assert_tree_is_on_device 40 | assert_tree_is_on_host 41 | assert_tree_is_sharded 42 | assert_tree_no_nones 43 | assert_tree_shape_prefix 44 | assert_tree_shape_suffix 45 | assert_trees_all_close 46 | assert_trees_all_close_ulp 47 | assert_trees_all_equal 48 | assert_trees_all_equal_comparator 49 | assert_trees_all_equal_dtypes 50 | assert_trees_all_equal_sizes 51 | assert_trees_all_equal_shapes 52 | assert_trees_all_equal_shapes_and_dtypes 53 | assert_trees_all_equal_structs 54 | assert_type 55 | chexify 56 | ChexifyChecks 57 | with_jittable_assertions 58 | block_until_chexify_assertions_complete 59 | Dimensions 60 | disable_asserts 61 | enable_asserts 62 | clear_trace_counter 63 | if_args_not_none 64 | 65 | 66 | Jax Assertions 67 | ~~~~~~~~~~~~~~ 68 | 69 | .. autofunction:: assert_max_traces 70 | .. autofunction:: assert_devices_available 71 | .. autofunction:: assert_gpu_available 72 | .. autofunction:: assert_tpu_available 73 | 74 | 75 | Value (Runtime) Assertions 76 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 77 | 78 | .. autofunction:: chexify 79 | .. autosummary:: ChexifyChecks 80 | .. autofunction:: with_jittable_assertions 81 | .. 
autofunction:: block_until_chexify_assertions_complete 82 | 83 | 84 | Tree Assertions 85 | ~~~~~~~~~~~~~~~ 86 | 87 | .. autofunction:: assert_tree_all_finite 88 | .. autofunction:: assert_tree_has_only_ndarrays 89 | .. autofunction:: assert_tree_is_on_device 90 | .. autofunction:: assert_tree_is_on_host 91 | .. autofunction:: assert_tree_is_sharded 92 | .. autofunction:: assert_tree_no_nones 93 | .. autofunction:: assert_tree_shape_prefix 94 | .. autofunction:: assert_tree_shape_suffix 95 | .. autofunction:: assert_trees_all_close 96 | .. autofunction:: assert_trees_all_close_ulp 97 | .. autofunction:: assert_trees_all_equal 98 | .. autofunction:: assert_trees_all_equal_comparator 99 | .. autofunction:: assert_trees_all_equal_dtypes 100 | .. autofunction:: assert_trees_all_equal_sizes 101 | .. autofunction:: assert_trees_all_equal_shapes 102 | .. autofunction:: assert_trees_all_equal_shapes_and_dtypes 103 | .. autofunction:: assert_trees_all_equal_structs 104 | 105 | 106 | Generic Assertions 107 | ~~~~~~~~~~~~~~~~~~ 108 | 109 | .. autofunction:: assert_axis_dimension 110 | .. autofunction:: assert_axis_dimension_comparator 111 | .. autofunction:: assert_axis_dimension_gt 112 | .. autofunction:: assert_axis_dimension_gteq 113 | .. autofunction:: assert_axis_dimension_lt 114 | .. autofunction:: assert_axis_dimension_lteq 115 | .. autofunction:: assert_equal 116 | .. autofunction:: assert_equal_rank 117 | .. autofunction:: assert_equal_size 118 | .. autofunction:: assert_equal_shape 119 | .. autofunction:: assert_equal_shape_prefix 120 | .. autofunction:: assert_equal_shape_suffix 121 | .. autofunction:: assert_exactly_one_is_none 122 | .. autofunction:: assert_is_broadcastable 123 | .. autofunction:: assert_is_divisible 124 | .. autofunction:: assert_not_both_none 125 | .. autofunction:: assert_numerical_grads 126 | .. autofunction:: assert_rank 127 | .. autofunction:: assert_scalar 128 | .. autofunction:: assert_scalar_in 129 | .. 
autofunction:: assert_scalar_negative 130 | .. autofunction:: assert_scalar_non_negative 131 | .. autofunction:: assert_scalar_positive 132 | .. autofunction:: assert_size 133 | .. autofunction:: assert_shape 134 | .. autofunction:: assert_type 135 | 136 | 137 | Shapes and Named Dimensions 138 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 139 | 140 | .. autoclass:: Dimensions 141 | 142 | 143 | Utils 144 | ~~~~~ 145 | 146 | .. autofunction:: disable_asserts 147 | .. autofunction:: enable_asserts 148 | .. autofunction:: clear_trace_counter 149 | .. autofunction:: if_args_not_none 150 | 151 | 152 | Warnings 153 | ========== 154 | 155 | .. currentmodule:: chex 156 | 157 | .. autofunction:: create_deprecated_function_alias 158 | .. autofunction:: warn_deprecated_function 159 | .. autofunction:: warn_keyword_args_only_in_future 160 | .. autofunction:: warn_only_n_pos_args_in_future 161 | 162 | 163 | Backend restriction 164 | =================== 165 | 166 | .. currentmodule:: chex 167 | 168 | .. autofunction:: restrict_backends 169 | 170 | 171 | Dataclasses 172 | =========== 173 | 174 | .. currentmodule:: chex 175 | 176 | .. autofunction:: dataclass 177 | .. autofunction:: mappable_dataclass 178 | .. autofunction:: register_dataclass_type_with_jax_tree_util 179 | 180 | 181 | Fakes 182 | ===== 183 | 184 | .. currentmodule:: chex 185 | 186 | .. autosummary:: 187 | 188 | fake_jit 189 | fake_pmap 190 | fake_pmap_and_jit 191 | set_n_cpu_devices 192 | 193 | Transformations 194 | ~~~~~~~~~~~~~~~ 195 | 196 | .. autofunction:: fake_jit 197 | .. autofunction:: fake_pmap 198 | .. autofunction:: fake_pmap_and_jit 199 | 200 | 201 | Devices 202 | ~~~~~~~ 203 | 204 | .. autofunction:: set_n_cpu_devices 205 | 206 | 207 | Pytypes 208 | ======= 209 | 210 | .. currentmodule:: chex 211 | 212 | .. 
autosummary:: 213 | 214 | Array 215 | ArrayBatched 216 | ArrayDevice 217 | ArrayDeviceTree 218 | ArrayDType 219 | ArrayNumpy 220 | ArrayNumpyTree 221 | ArraySharded 222 | ArrayTree 223 | Device 224 | Numeric 225 | PRNGKey 226 | PyTreeDef 227 | Scalar 228 | Shape 229 | 230 | 231 | 232 | Variants 233 | ======== 234 | .. currentmodule:: chex 235 | 236 | .. autoclass:: ChexVariantType 237 | .. autoclass:: TestCase 238 | .. autofunction:: variants 239 | .. autofunction:: all_variants 240 | .. autofunction:: params_product 241 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Configuration file for the Sphinx documentation builder.""" 16 | 17 | # This file only contains a selection of the most common options. For a full 18 | # list see the documentation: 19 | # http://www.sphinx-doc.org/en/master/config 20 | 21 | # -- Path setup -------------------------------------------------------------- 22 | 23 | # If extensions (or modules to document with autodoc) are in another directory, 24 | # add these directories to sys.path here. 
If the directory is relative to the 25 | # documentation root, use os.path.abspath to make it absolute, like shown here. 26 | 27 | # pylint: disable=g-bad-import-order 28 | # pylint: disable=g-import-not-at-top 29 | import inspect 30 | import os 31 | import sys 32 | 33 | 34 | def _add_annotations_import(path): 35 | """Appends a future annotations import to the file at the given path.""" 36 | with open(path) as f: 37 | contents = f.read() 38 | if contents.startswith('from __future__ import annotations'): 39 | # If we run sphinx multiple times then we will append the future import 40 | # multiple times too. 41 | return 42 | 43 | assert contents.startswith('#'), (path, contents.split('\n')[0]) 44 | with open(path, 'w') as f: 45 | # NOTE: This is subtle and not unit tested, we're prefixing the first line 46 | # in each Python file with this future import. It is important to prefix 47 | # not insert a newline such that source code locations are accurate (we link 48 | # to GitHub). The assertion above ensures that the first line in the file is 49 | # a comment so it is safe to prefix it. 
50 | f.write('from __future__ import annotations ') 51 | f.write(contents) 52 | 53 | 54 | def _recursive_add_annotations_import(): 55 | for path, _, files in os.walk('../chex/'): 56 | for file in files: 57 | if file.endswith('.py'): 58 | _add_annotations_import(os.path.abspath(os.path.join(path, file))) 59 | 60 | if 'READTHEDOCS' in os.environ: 61 | _recursive_add_annotations_import() 62 | 63 | sys.path.insert(0, os.path.abspath('../')) 64 | sys.path.append(os.path.abspath('ext')) 65 | 66 | import chex 67 | from sphinxcontrib import katex 68 | 69 | # -- Project information ----------------------------------------------------- 70 | 71 | project = 'Chex' 72 | copyright = '2021, DeepMind' # pylint: disable=redefined-builtin 73 | author = 'Chex Contributors' 74 | 75 | # -- General configuration --------------------------------------------------- 76 | 77 | master_doc = 'index' 78 | 79 | # Add any Sphinx extension module names here, as strings. They can be 80 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 81 | # ones. 82 | extensions = [ 83 | 'sphinx.ext.autodoc', 84 | 'sphinx.ext.autosummary', 85 | 'sphinx.ext.doctest', 86 | 'sphinx.ext.inheritance_diagram', 87 | 'sphinx.ext.intersphinx', 88 | 'sphinx.ext.linkcode', 89 | 'sphinx.ext.napoleon', 90 | 'sphinxcontrib.katex', 91 | 'coverage_check', 92 | ] 93 | 94 | # Add any paths that contain templates here, relative to this directory. 95 | templates_path = ['_templates'] 96 | 97 | # List of patterns, relative to source directory, that match files and 98 | # directories to ignore when looking for source files. 99 | # This pattern also affects html_static_path and html_extra_path. 
100 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 101 | 102 | # -- Options for autodoc ----------------------------------------------------- 103 | 104 | autodoc_default_options = { 105 | 'member-order': 'bysource', 106 | 'special-members': True, 107 | 'exclude-members': '__repr__, __str__, __weakref__', 108 | } 109 | 110 | # -- Options for HTML output ------------------------------------------------- 111 | 112 | # The theme to use for HTML and HTML Help pages. See the documentation for 113 | # a list of builtin themes. 114 | html_theme = 'sphinx_book_theme' 115 | 116 | html_theme_options = { 117 | 'show_toc_level': 2, 118 | 'repository_url': 'https://github.com/google-deepmind/chex', 119 | 'use_repository_button': True, # add a "link to repository" button 120 | 'navigation_with_keys': False, 121 | } 122 | 123 | # Add any paths that contain custom static files (such as style sheets) here, 124 | # relative to this directory. They are copied after the builtin static files, 125 | # so a file named "default.css" will overwrite the builtin "default.css". 
126 | 127 | html_static_path = [] 128 | 129 | # -- Options for katex ------------------------------------------------------ 130 | 131 | # See: https://sphinxcontrib-katex.readthedocs.io/en/0.4.1/macros.html 132 | latex_macros = r""" 133 | \def \d #1{\operatorname{#1}} 134 | """ 135 | 136 | # Translate LaTeX macros to KaTeX and add to options for HTML builder 137 | katex_macros = katex.latex_defs_to_katex_macros(latex_macros) 138 | katex_options = ( 139 | '{displayMode: true, fleqn: true, macros: {' + katex_macros + '}}' 140 | ) 141 | 142 | # Add LaTeX macros for LATEX builder 143 | latex_elements = {'preamble': latex_macros} 144 | 145 | # -- Source code links ------------------------------------------------------- 146 | 147 | 148 | def linkcode_resolve(domain, info): 149 | """Resolve a GitHub URL corresponding to Python object.""" 150 | if domain != 'py': 151 | return None 152 | 153 | try: 154 | mod = sys.modules[info['module']] 155 | except ImportError: 156 | return None 157 | 158 | obj = mod 159 | try: 160 | for attr in info['fullname'].split('.'): 161 | obj = getattr(obj, attr) 162 | except AttributeError: 163 | return None 164 | else: 165 | obj = inspect.unwrap(obj) 166 | 167 | try: 168 | filename = inspect.getsourcefile(obj) 169 | except TypeError: 170 | return None 171 | 172 | try: 173 | source, lineno = inspect.getsourcelines(obj) 174 | except OSError: 175 | return None 176 | 177 | # TODO(slebedev): support tags after we release an initial version. 
178 | return ( 179 | 'https://github.com/google-deepmind/chex/tree/main/chex/%s#L%d#L%d' 180 | % ( 181 | os.path.relpath(filename, start=os.path.dirname(chex.__file__)), 182 | lineno, 183 | lineno + len(source) - 1, 184 | ) 185 | ) 186 | 187 | 188 | # -- Intersphinx configuration ----------------------------------------------- 189 | 190 | intersphinx_mapping = { 191 | 'jax': ('https://jax.readthedocs.io/en/latest/', None), 192 | } 193 | 194 | source_suffix = ['.rst', '.md', '.ipynb'] 195 | -------------------------------------------------------------------------------- /chex/_src/dimensions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utilities to hold expected dimension sizes.""" 16 | 17 | import math 18 | import re 19 | from typing import Any, Collection, Dict, Optional, Sized, Tuple 20 | 21 | 22 | Shape = Tuple[Optional[int], ...] 23 | 24 | 25 | class Dimensions: 26 | """A lightweight utility that maps strings to shape tuples. 27 | 28 | The most basic usage is: 29 | 30 | .. code:: 31 | 32 | >>> dims = chex.Dimensions(B=3, T=5, N=7) # You can specify any letters. 33 | >>> dims['NBT'] 34 | (7, 3, 5) 35 | 36 | This is useful when dealing with many differently shaped arrays. 
For instance, 37 | let's check the shape of this array: 38 | 39 | .. code:: 40 | 41 | >>> x = jnp.array([[2, 0, 5, 6, 3], 42 | ... [5, 4, 4, 3, 3], 43 | ... [0, 0, 5, 2, 0]]) 44 | >>> chex.assert_shape(x, dims['BT']) 45 | 46 | The dimension sizes can be gotten directly, e.g. :code:`dims.N == 7`. This can 47 | be useful in many applications. For instance, let's one-hot encode our array. 48 | 49 | .. code:: 50 | 51 | >>> y = jax.nn.one_hot(x, dims.N) 52 | >>> chex.assert_shape(y, dims['BTN']) 53 | 54 | You can also store the shape of a given array in :code:`dims`, e.g. 55 | 56 | .. code:: 57 | 58 | >>> z = jnp.array([[0, 6, 0, 2], 59 | ... [4, 2, 2, 4]]) 60 | >>> dims['XY'] = z.shape 61 | >>> dims 62 | Dimensions(B=3, N=7, T=5, X=2, Y=4) 63 | 64 | You can access the flat size of a shape as 65 | 66 | .. code:: 67 | 68 | >>> dims.size('BT') # Same as prod(dims['BT']). 69 | 15 70 | 71 | Similarly, you can flatten axes together by wrapping them in parentheses: 72 | 73 | .. code:: 74 | 75 | >>> dims['(BT)N'] 76 | (15, 7) 77 | 78 | You can set a wildcard dimension, cf. :func:`chex.assert_shape`: 79 | 80 | .. code:: 81 | 82 | >>> dims.W = None 83 | >>> dims['BTW'] 84 | (3, 5, None) 85 | 86 | Or you can use the wildcard character `'*'` directly: 87 | 88 | .. code:: 89 | 90 | >>> dims['BT*'] 91 | (3, 5, None) 92 | 93 | Single digits are interpreted as literal integers. Note that this notation 94 | is limited to single-digit literals. 95 | 96 | .. code:: 97 | 98 | >>> dims['BT123'] 99 | (3, 5, 1, 2, 3) 100 | 101 | Support for single digits was mainly included to accommodate dummy axes 102 | introduced for consistent broadcasting. For instance, instead of using 103 | :func:`jnp.expand_dims ` you could do the following: 104 | 105 | .. code:: 106 | 107 | >>> w = y * x # Cannot broadcast (3, 5, 7) with (3, 5) 108 | Traceback (most recent call last): 109 | ... 
110 | ValueError: Incompatible shapes for broadcasting: ((3, 5, 7), (1, 3, 5)) 111 | >>> w = y * x.reshape(dims['BT1']) 112 | >>> chex.assert_shape(w, dims['BTN']) 113 | 114 | Sometimes you only care about some array dimensions but not all. You can use 115 | an underscore to ignore an axis, e.g. 116 | 117 | .. code:: 118 | 119 | >>> chex.assert_rank(y, 3) 120 | >>> dims['__M'] = y.shape # Skip the first two axes. 121 | 122 | Finally note that a single-character key returns a tuple of length one. 123 | 124 | .. code:: 125 | 126 | >>> dims['M'] 127 | (7,) 128 | """ 129 | # Tell static type checker not to worry about attribute errors. 130 | _HAS_DYNAMIC_ATTRIBUTES = True 131 | 132 | def __init__(self, **dim_sizes) -> None: 133 | for dim, size in dim_sizes.items(): 134 | self._setdim(dim, size) 135 | 136 | def size(self, key: str) -> int: 137 | """Returns the flat size of a given named shape, i.e. prod(shape).""" 138 | shape = self[key] 139 | if any(size is None or size <= 0 for size in shape): 140 | raise ValueError( 141 | f"cannot take product of shape '{key}' = {shape}, " 142 | 'because it contains non-positive sized dimensions' 143 | ) 144 | return math.prod(shape) 145 | 146 | def __getitem__(self, key: str) -> Shape: 147 | self._validate_key(key) 148 | shape = [] 149 | open_parentheses = False 150 | dims_to_flatten = '' 151 | for dim in key: 152 | # Signal to start accumulating `dims_to_flatten`. 153 | if dim == '(': 154 | if open_parentheses: 155 | raise ValueError(f"nested parentheses are unsupported; got: '{key}'") 156 | open_parentheses = True 157 | 158 | # Signal to collect accumulated `dims_to_flatten`. 159 | elif dim == ')': 160 | if not open_parentheses: 161 | raise ValueError(f"unmatched parentheses in named shape: '{key}'") 162 | if not dims_to_flatten: 163 | raise ValueError(f"found empty parentheses in named shape: '{key}'") 164 | shape.append(self.size(dims_to_flatten)) 165 | # Reset. 
166 | open_parentheses = False 167 | dims_to_flatten = '' 168 | 169 | # Accumulate `dims_to_flatten`. 170 | elif open_parentheses: 171 | dims_to_flatten += dim 172 | 173 | # The typical (non-flattening) case. 174 | else: 175 | shape.append(self._getdim(dim)) 176 | 177 | if open_parentheses: 178 | raise ValueError(f"unmatched parentheses in named shape: '{key}'") 179 | return tuple(shape) 180 | 181 | def __setitem__(self, key: str, value: Collection[Optional[int]]) -> None: 182 | self._validate_key(key) 183 | self._validate_value(value) 184 | if len(key) != len(value): 185 | raise ValueError( 186 | f'key string {repr(key)} and shape {tuple(value)} ' 187 | 'have different lengths') 188 | for dim, size in zip(key, value): 189 | self._setdim(dim, size) 190 | 191 | def __delitem__(self, key: str) -> None: 192 | self._validate_key(key) 193 | for dim in key: 194 | self._deldim(dim) 195 | 196 | def __repr__(self) -> str: 197 | args = ', '.join(f'{k}={v}' for k, v in sorted(self._asdict().items())) 198 | return f'{type(self).__name__}({args})' 199 | 200 | def _asdict(self) -> Dict[str, Optional[int]]: 201 | return {k: v for k, v in self.__dict__.items() 202 | if re.fullmatch(r'[a-zA-Z]', k)} 203 | 204 | def _getdim(self, dim: str) -> Optional[int]: 205 | if dim == '*': 206 | return None 207 | if re.fullmatch(r'[0-9]', dim): 208 | return int(dim) 209 | try: 210 | return getattr(self, dim) 211 | except AttributeError as e: 212 | raise KeyError(dim) from e 213 | 214 | def _setdim(self, dim: str, size: Optional[int]) -> None: 215 | if dim == '_': # Skip. 216 | return 217 | self._validate_dim(dim) 218 | setattr(self, dim, _optional_int(size)) 219 | 220 | def _deldim(self, dim: str) -> None: 221 | if dim == '_': # Skip. 
222 | return 223 | self._validate_dim(dim) 224 | try: 225 | return delattr(self, dim) 226 | except AttributeError as e: 227 | raise KeyError(dim) from e 228 | 229 | def _validate_key(self, key: Any) -> None: 230 | if not isinstance(key, str): 231 | raise TypeError(f'key must be a string; got: {type(key).__name__}') 232 | 233 | def _validate_value(self, value: Any) -> None: 234 | if not isinstance(value, Sized): 235 | raise TypeError( 236 | 'value must be sized, i.e. an object with a well-defined len(value); ' 237 | f'got object of type: {type(value).__name__}') 238 | 239 | def _validate_dim(self, dim: Any) -> None: 240 | if not isinstance(dim, str): 241 | raise TypeError( 242 | f'dimension name must be a string; got: {type(dim).__name__}') 243 | if not re.fullmatch(r'[a-zA-Z]', dim): 244 | raise KeyError( 245 | 'dimension names may only be contain letters (or \'_\' to skip); ' 246 | f'got dimension name: {repr(dim)}') 247 | 248 | 249 | def _optional_int(x: Any) -> Optional[int]: 250 | if x is None: 251 | return None 252 | try: 253 | i = int(x) 254 | if x == i: 255 | return i 256 | except ValueError: 257 | pass 258 | raise TypeError(f'object cannot be interpreted as a python int: {repr(x)}') 259 | -------------------------------------------------------------------------------- /chex/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Chex: Testing made fun, in JAX!"""

# NOTE: this module only re-exports names from `chex._src`; the import order
# below is kept as-is because importing the submodules may have side effects.

# Assertions, from `chex._src.asserts`.
from chex._src.asserts import assert_axis_dimension
from chex._src.asserts import assert_axis_dimension_comparator
from chex._src.asserts import assert_axis_dimension_gt
from chex._src.asserts import assert_axis_dimension_gteq
from chex._src.asserts import assert_axis_dimension_lt
from chex._src.asserts import assert_axis_dimension_lteq
from chex._src.asserts import assert_devices_available
from chex._src.asserts import assert_equal
from chex._src.asserts import assert_equal_rank
from chex._src.asserts import assert_equal_shape
from chex._src.asserts import assert_equal_shape_prefix
from chex._src.asserts import assert_equal_shape_suffix
from chex._src.asserts import assert_equal_size
from chex._src.asserts import assert_exactly_one_is_none
from chex._src.asserts import assert_gpu_available
from chex._src.asserts import assert_is_broadcastable
from chex._src.asserts import assert_is_divisible
from chex._src.asserts import assert_max_traces
from chex._src.asserts import assert_not_both_none
from chex._src.asserts import assert_numerical_grads
from chex._src.asserts import assert_rank
from chex._src.asserts import assert_scalar
from chex._src.asserts import assert_scalar_in
from chex._src.asserts import assert_scalar_negative
from chex._src.asserts import assert_scalar_non_negative
from chex._src.asserts import assert_scalar_positive
from chex._src.asserts import assert_shape
from chex._src.asserts import assert_size
from chex._src.asserts import assert_tpu_available
from chex._src.asserts import assert_tree_all_finite
from chex._src.asserts import assert_tree_has_only_ndarrays
from chex._src.asserts import assert_tree_is_on_device
from chex._src.asserts import assert_tree_is_on_host
from chex._src.asserts import assert_tree_is_sharded
from chex._src.asserts import assert_tree_no_nones
from chex._src.asserts import assert_tree_shape_prefix
from chex._src.asserts import assert_tree_shape_suffix
from chex._src.asserts import assert_trees_all_close
from chex._src.asserts import assert_trees_all_close_ulp
from chex._src.asserts import assert_trees_all_equal
from chex._src.asserts import assert_trees_all_equal_comparator
from chex._src.asserts import assert_trees_all_equal_dtypes
from chex._src.asserts import assert_trees_all_equal_shapes
from chex._src.asserts import assert_trees_all_equal_shapes_and_dtypes
from chex._src.asserts import assert_trees_all_equal_sizes
from chex._src.asserts import assert_trees_all_equal_structs
from chex._src.asserts import assert_type
from chex._src.asserts import clear_trace_counter
from chex._src.asserts import disable_asserts
from chex._src.asserts import enable_asserts
from chex._src.asserts import if_args_not_none
# Value (runtime) assertions inside transformed code, from
# `chex._src.asserts_chexify`.
from chex._src.asserts_chexify import block_until_chexify_assertions_complete
from chex._src.asserts_chexify import chexify
from chex._src.asserts_chexify import ChexifyChecks
from chex._src.asserts_chexify import with_jittable_assertions
# Dataclass utilities.
from chex._src.dataclass import dataclass
from chex._src.dataclass import mappable_dataclass
from chex._src.dataclass import register_dataclass_type_with_jax_tree_util
# Named-dimension helper.
from chex._src.dimensions import Dimensions
# Fakes, from `chex._src.fake`.
from chex._src.fake import fake_jit
from chex._src.fake import fake_pmap
from chex._src.fake import fake_pmap_and_jit
from chex._src.fake import set_n_cpu_devices
# Type aliases, from `chex._src.pytypes`.
from chex._src.pytypes import Array
from chex._src.pytypes import ArrayBatched
from chex._src.pytypes import ArrayDevice
from chex._src.pytypes import ArrayDeviceTree
from chex._src.pytypes import ArrayDType
from chex._src.pytypes import ArrayNumpy
from chex._src.pytypes import ArrayNumpyTree
from chex._src.pytypes import ArraySharded
from chex._src.pytypes import ArrayTree
from chex._src.pytypes import Device
from chex._src.pytypes import Numeric
from chex._src.pytypes import PRNGKey
from chex._src.pytypes import PyTreeDef
from chex._src.pytypes import Scalar
from chex._src.pytypes import Shape
from chex._src.restrict_backends import restrict_backends
# Test variants, from `chex._src.variants`.
from chex._src.variants import all_variants
from chex._src.variants import ChexVariantType
from chex._src.variants import params_product
from chex._src.variants import TestCase
from chex._src.variants import variants
# Deprecation-warning helpers, from `chex._src.warnings`.
from chex._src.warnings import create_deprecated_function_alias
from chex._src.warnings import warn_deprecated_function
from chex._src.warnings import warn_keyword_args_only_in_future
from chex._src.warnings import warn_only_n_pos_args_in_future


# Package version string.
__version__ = "0.1.86"

# The public API: the names bound by `from chex import *`. Every name imported
# above appears here exactly once.
__all__ = (
    "all_variants",
    "Array",
    "ArrayBatched",
    "ArrayDevice",
    "ArrayDeviceTree",
    "ArrayDType",
    "ArrayNumpy",
    "ArrayNumpyTree",
    "ArraySharded",
    "ArrayTree",
    "ChexifyChecks",
    "assert_axis_dimension",
    "assert_axis_dimension_comparator",
    "assert_axis_dimension_gt",
    "assert_axis_dimension_gteq",
    "assert_axis_dimension_lt",
    "assert_axis_dimension_lteq",
    "assert_devices_available",
    "assert_equal",
    "assert_equal_rank",
    "assert_equal_shape",
    "assert_equal_shape_prefix",
    "assert_equal_shape_suffix",
    "assert_equal_size",
    "assert_exactly_one_is_none",
    "assert_gpu_available",
    "assert_is_broadcastable",
    "assert_is_divisible",
    "assert_max_traces",
    "assert_not_both_none",
    "assert_numerical_grads",
    "assert_rank",
    "assert_scalar",
    "assert_scalar_in",
    "assert_scalar_negative",
    "assert_scalar_non_negative",
    "assert_scalar_positive",
    "assert_shape",
    "assert_size",
    "assert_tpu_available",
    "assert_tree_all_finite",
    "assert_tree_has_only_ndarrays",
    "assert_tree_is_on_device",
    "assert_tree_is_on_host",
    "assert_tree_is_sharded",
    "assert_tree_no_nones",
    "assert_tree_shape_prefix",
    "assert_tree_shape_suffix",
    "assert_trees_all_close",
    "assert_trees_all_close_ulp",
    "assert_trees_all_equal",
    "assert_trees_all_equal_comparator",
    "assert_trees_all_equal_dtypes",
    "assert_trees_all_equal_shapes",
    "assert_trees_all_equal_shapes_and_dtypes",
    "assert_trees_all_equal_sizes",
    "assert_trees_all_equal_structs",
    "assert_type",
    "block_until_chexify_assertions_complete",
    "chexify",
    "ChexVariantType",
    "clear_trace_counter",
    "create_deprecated_function_alias",
    "dataclass",
    "Device",
    "Dimensions",
    "disable_asserts",
    "enable_asserts",
    "fake_jit",
    "fake_pmap",
    "fake_pmap_and_jit",
    "if_args_not_none",
    "mappable_dataclass",
    "Numeric",
    "params_product",
    "PRNGKey",
    "PyTreeDef",
    "register_dataclass_type_with_jax_tree_util",
    "restrict_backends",
    "Scalar",
    "set_n_cpu_devices",
    "Shape",
    "TestCase",
    "variants",
    "warn_deprecated_function",
    "warn_keyword_args_only_in_future",
    "warn_only_n_pos_args_in_future",
    "with_jittable_assertions",
)


#  _________________________________________
# / Please don't use symbols in `_src` they \
# \ are not part of the Chex public API.    /
#  -----------------------------------------
#         \   ^__^
#          \  (oo)\_______
#             (__)\       )\/\
#                 ||----w |
#                 ||     ||
#

--------------------------------------------------------------------------------
/chex/_src/asserts_chexify.py:
--------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | # ============================================================================== 15 | """Chexification utilities.""" 16 | 17 | import atexit 18 | import collections 19 | from concurrent import futures 20 | import dataclasses 21 | import functools 22 | import re 23 | from typing import Any, Callable, FrozenSet 24 | 25 | from absl import logging 26 | from chex._src import asserts_internal as _ai 27 | import jax 28 | from jax.experimental import checkify 29 | 30 | 31 | @dataclasses.dataclass(frozen=True) 32 | class _ChexifyChecks: 33 | """A set of checks imported from checkify.""" 34 | 35 | user: FrozenSet[checkify.ErrorCategory] = checkify.user_checks 36 | nan: FrozenSet[checkify.ErrorCategory] = checkify.nan_checks 37 | index: FrozenSet[checkify.ErrorCategory] = checkify.index_checks 38 | div: FrozenSet[checkify.ErrorCategory] = checkify.div_checks 39 | float: FrozenSet[checkify.ErrorCategory] = checkify.float_checks 40 | automatic: FrozenSet[checkify.ErrorCategory] = checkify.automatic_checks 41 | all: FrozenSet[checkify.ErrorCategory] = checkify.all_checks 42 | 43 | 44 | _chexify_error_pattern = re.compile( 45 | re.escape(_ai.get_chexify_err_message('ANY', 'ANY')).replace('ANY', '.*') 46 | ) 47 | 48 | 49 | def _check_error(err: checkify.Error) -> None: 50 | """Checks the error and converts it to chex format.""" 51 | try: 52 | checkify.check_error(err) 53 | except ValueError as exc: 54 | msg = str(exc) 55 | if _chexify_error_pattern.match(msg): 56 | # Remove internal code pointers. 57 | internal_info_pos = msg.rfind('(check failed at') 58 | if internal_info_pos != -1: 59 | msg = msg[:internal_info_pos] 60 | raise AssertionError(msg) # pylint:disable=raise-missing-from 61 | else: 62 | raise 63 | 64 | 65 | def block_until_chexify_assertions_complete() -> None: 66 | """Waits until all asynchronous checks complete. 67 | 68 | See `chexify` for more detail. 
69 | """ 70 | for wait_fn in _ai.CHEXIFY_STORAGE.wait_fns: 71 | wait_fn() 72 | 73 | 74 | @atexit.register # to catch uninspected error stats 75 | def _check_if_hanging_assertions(): 76 | if _ai.CHEXIFY_STORAGE.wait_fns: 77 | logging.warning( 78 | '[Chex] Some of chexify assetion statuses were not inspected due to ' 79 | 'async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html).' 80 | ' Consider calling `chex.block_until_chexify_assertions_complete()` at ' 81 | 'the end of computations that rely on jitted chex assetions.') 82 | block_until_chexify_assertions_complete() 83 | 84 | 85 | # Public API. 86 | ChexifyChecks = _ChexifyChecks() 87 | 88 | 89 | def chexify( 90 | fn: Callable[..., Any], 91 | async_check: bool = True, 92 | errors: FrozenSet[checkify.ErrorCategory] = ChexifyChecks.user, 93 | ) -> Callable[..., Any]: 94 | """Wraps a transformed function `fn` to enable Chex value assertions. 95 | 96 | Chex value/runtime assertions access concrete values of tensors (e.g. 97 | `assert_tree_all_finite`) which are not available during JAX tracing, see 98 | https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html 99 | and 100 | https://jax.readthedocs.io/en/latest/_modules/jax/_src/errors.html#ConcretizationTypeError. 101 | 102 | This wrapper enables them in jitted/pmapped functions by performing a 103 | specifically designed JAX transformation 104 | https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#the-checkify-transformation 105 | and calling functionalised checks 106 | https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.checkify.check.html 107 | 108 | Example: 109 | 110 | .. 
code:: 111 | 112 | @chex.chexify 113 | @jax.jit 114 | def logp1_abs_safe(x: chex.Array) -> chex.Array: 115 | chex.assert_tree_all_finite(x) 116 | return jnp.log(jnp.abs(x) + 1) 117 | 118 | logp1_abs_safe(jnp.ones(2)) # OK 119 | logp1_abs_safe(jnp.array([jnp.nan, 3])) # FAILS 120 | logp1_abs_safe.wait_checks() 121 | 122 | Note 1: This wrapper allows identifying the first failed assertion in a jitted 123 | code by printing a pointer to the line where the failed assertion was invoked. 124 | For getting verbose messages (including concrete tensor values), an unjitted 125 | version of the code will need to be executed with the same input values. Chex 126 | does not currently provide tools to help with this. 127 | 128 | Note 2: This wrapper fully supports asynchronous executions 129 | (see https://jax.readthedocs.io/en/latest/async_dispatch.html). 130 | To block program execution until asynchronous checks for a _chexified_ 131 | function `fn` complete, call `fn.wait_checks()`. Similarly, 132 | `chex.block_until_chexify_assertions_complete()` will block program execution 133 | until _all_ asyncronous checks complete. 134 | 135 | Note 3: Chex automatically selects the backend for executing its assertions 136 | (i.e. CPU or device accelerator) depending on the program context. 137 | 138 | Note 4: Value assertions can have impact on the performance of a function, see 139 | https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#limitations 140 | 141 | Note 5: static assertions, such as `assert_shape` or 142 | `assert_trees_all_equal_dtypes`, can be called from a jitted function without 143 | `chexify` wrapper (since they do not access concrete values, only 144 | shapes and/or dtypes which are available during JAX tracing). 145 | 146 | More examples can be found at 147 | https://github.com/deepmind/chex/blob/master/chex/_src/asserts_chexify_test.py 148 | 149 | Args: 150 | fn: A transformed function to wrap. 
151 | async_check: Whether to check errors in the async dispatch mode. See 152 | https://jax.readthedocs.io/en/latest/async_dispatch.html. 153 | errors: A set of `checkify.ErrorCategory` values which defines the set of 154 | enabled checks. By default only explicit ``checks`` are enabled (`user`). 155 | You can also for example enable NaN and Div-by-0 errors by passing the 156 | `float` set, or for example combine multiple sets through set 157 | operations (`float | user`). 158 | 159 | Returns: 160 | A _chexified_ function, i.e. the one with enabled value assertions. 161 | The returned function has `wait_checks()` method that blocks the caller 162 | until all pending async checks complete. 163 | """ 164 | # Hardware/XLA failures can only happen on the C++ side. They are expected to 165 | # issue critical errors that will immediately crash the whole program. 166 | # Nevertheless, Chex sets its own timeout for every chexified XLA comp. to 167 | # ensure that a program never blocks on Chex side when running in async mode. 168 | async_timeout = 1800 # 30 minutes 169 | 170 | # Get function name. 171 | if isinstance(fn, functools.partial): 172 | func_name = fn.func.__name__ 173 | else: 174 | func_name = fn.__name__ 175 | 176 | if async_check: 177 | # Spawn a thread for processing blocking calls. 178 | thread_pool = futures.ThreadPoolExecutor(1, f'async_chex_{func_name}') 179 | # A deque for futures. 180 | async_check_futures = collections.deque() 181 | 182 | # Checkification. 183 | checkified_fn = checkify.checkify(fn, errors=errors) 184 | 185 | @functools.wraps(fn) 186 | def _chexified_fn(*args, **kwargs): 187 | if _ai.CHEXIFY_STORAGE.level: 188 | raise RuntimeError( 189 | 'Nested @chexify wrapping is disallowed. 
' 190 | 'Make sure that you only wrap the function at the outermost level.') 191 | 192 | if _ai.has_tracers((args, kwargs)): 193 | raise RuntimeError( 194 | '@chexify must be applied on top of all (p)jit/pmap transformations' 195 | ' (otherwise it will result in `UnexpectedTracerError`). If you have' 196 | ' functions that use value assertions, do not wrap them' 197 | ' individually -- just wrap the outermost function after' 198 | ' applying all your JAX transformations. See the example at ' 199 | 'https://github.com/google-deepmind/chex#static-and-value-aka-runtime-assertions' # pylint:disable=line-too-long 200 | ) 201 | 202 | if async_check: 203 | # Check completed calls. 204 | while async_check_futures and async_check_futures[0].done(): 205 | _check_error(async_check_futures.popleft().result(async_timeout)) 206 | 207 | # Run the checkified function. 208 | _ai.CHEXIFY_STORAGE.level += 1 209 | try: 210 | err, out = checkified_fn(*args, **kwargs) 211 | finally: 212 | _ai.CHEXIFY_STORAGE.level -= 1 213 | 214 | # Check errors. 215 | if async_check: 216 | # Blocking call is deferred to the thread. 217 | async_check_futures.append( 218 | thread_pool.submit(lambda: jax.device_get(err))) 219 | else: 220 | # Blocks until `fn`'s outputs are ready. 221 | _check_error(err) 222 | 223 | return out 224 | 225 | def _wait_checks(): 226 | if async_check: 227 | while async_check_futures: 228 | _check_error(async_check_futures.popleft().result(async_timeout)) 229 | 230 | # Add a barrier callback to the global storage. 231 | _ai.CHEXIFY_STORAGE.wait_fns.append(_wait_checks) 232 | 233 | # Add the callback to the chexified funtion's properties. 
234 | if not hasattr(_chexified_fn, 'wait_checks'): 235 | _chexified_fn.wait_checks = _wait_checks 236 | else: 237 | logging.warning( 238 | "Function %s already defines 'wait_checks' method; " 239 | 'Chex will not redefine it.', func_name) 240 | 241 | return _chexified_fn 242 | 243 | 244 | def with_jittable_assertions(fn: Callable[..., Any], 245 | async_check: bool = True) -> Callable[..., Any]: 246 | """An alias for `chexify` (see the docs).""" 247 | return chexify(fn, async_check) 248 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /chex/_src/dataclass.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ==============================================================================
"""JAX/dm-tree friendly dataclass implementation reusing Python dataclasses."""

import collections.abc
import dataclasses
import functools
import sys

from absl import logging
import jax
from typing_extensions import dataclass_transform  # pytype: disable=not-supported-yet


FrozenInstanceError = dataclasses.FrozenInstanceError
_RESERVED_DCLS_FIELD_NAMES = frozenset(("from_tuple", "replace", "to_tuple"))


def mappable_dataclass(cls):
  """Exposes dataclass as ``collections.abc.Mapping`` descendant.

  Allows to traverse dataclasses in methods from `dm-tree` library.

  NOTE: changes dataclasses constructor to dict-type
  (i.e. positional args aren't supported; however can use generators/iterables).

  Args:
    cls: A dataclass to mutate.

  Returns:
    Mutated dataclass implementing ``collections.abc.Mapping`` interface.

  Raises:
    ValueError: If ``cls`` is not a dataclass.
  """
  if not dataclasses.is_dataclass(cls):
    raise ValueError(f"Expected dataclass, got {cls} (change wrappers order?).")

  # Define methods for compatibility with `collections.abc.Mapping`.
  setattr(cls, "__getitem__", lambda self, x: self.__dict__[x])
  setattr(cls, "__len__", lambda self: len(self.__dict__))
  setattr(cls, "__iter__", lambda self: iter(self.__dict__))
  # Override the default `collections.abc.Mapping` method implementation for
  # cleaner visualization. Without this change x.keys() shows the full repr(x)
  # instead of only the dict_keys present. The same goes for values and items.
  setattr(cls, "keys", lambda self: self.__dict__.keys())
  setattr(cls, "values", lambda self: self.__dict__.values())
  setattr(cls, "items", lambda self: self.__dict__.items())

  # Update constructor.
  orig_init = cls.__init__
  all_fields = set(f.name for f in cls.__dataclass_fields__.values())
  # A frozenset: only used for membership tests below.
  init_fields = frozenset(
      f.name for f in cls.__dataclass_fields__.values() if f.init)

  @functools.wraps(orig_init)
  def new_init(self, *orig_args, **orig_kwargs):
    # Accept exactly what `dict(...)` accepts, minus mixing positional+keyword.
    if (orig_args and orig_kwargs) or len(orig_args) > 1:
      raise ValueError(
          "Mappable dataclass constructor doesn't support positional args. "
          "(it has the same constructor as python dict)")
    all_kwargs = dict(*orig_args, **orig_kwargs)
    unknown_kwargs = set(all_kwargs.keys()) - all_fields
    if unknown_kwargs:
      raise ValueError(f"__init__() got unexpected kwargs: {unknown_kwargs}.")

    # Pass only arguments corresponding to fields with `init=True`.
    valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields}
    orig_init(self, **valid_kwargs)

  cls.__init__ = new_init

  # Update base class to derive from Mapping
  dct = dict(cls.__dict__)
  if "__dict__" in dct:
    dct.pop("__dict__")  # Avoid self-references.

  # Remove object from the sequence of base classes. Deriving from both Mapping
  # and object will cause a failure to create a MRO for the updated class
  bases = tuple(b for b in cls.__bases__ if b != object)
  cls = type(cls.__name__, bases + (collections.abc.Mapping,), dct)
  return cls


@dataclass_transform()
def dataclass(
    cls=None,
    *,
    init=True,
    repr=True,  # pylint: disable=redefined-builtin
    eq=True,
    order=False,
    unsafe_hash=False,
    frozen=False,
    kw_only: bool = False,
    mappable_dataclass=True,  # pylint: disable=redefined-outer-name
):
  """JAX-friendly wrapper for :py:func:`dataclasses.dataclass`.

  This wrapper class registers new dataclasses with JAX so that tree utils
  operate correctly. Additionally a replace method is provided making it easy
  to operate on the class when made immutable (frozen=True).

  Args:
    cls: A class to decorate.
    init: See :py:func:`dataclasses.dataclass`.
    repr: See :py:func:`dataclasses.dataclass`.
    eq: See :py:func:`dataclasses.dataclass`.
    order: See :py:func:`dataclasses.dataclass`.
    unsafe_hash: See :py:func:`dataclasses.dataclass`.
    frozen: See :py:func:`dataclasses.dataclass`.
    kw_only: See :py:func:`dataclasses.dataclass`.
    mappable_dataclass: If True (the default), methods to make the class
      implement the :py:class:`collections.abc.Mapping` interface will be
      generated and the class will include :py:class:`collections.abc.Mapping`
      in its base classes.
      `True` is the default, because being an instance of `Mapping` makes
      `chex.dataclass` compatible with e.g. `jax.tree_util.tree_*` methods, the
      `tree` library, or methods related to tensorflow/python/utils/nest.py.
      As a side-effect, e.g. `np.testing.assert_array_equal` will only check
      the field names are equal and not the content. Use `chex.assert_tree_*`
      instead.

  Returns:
    A JAX-friendly dataclass.
  """
  def dcls(cls):
    # Make sure to create a separate _Dataclass instance for each `cls`.
    return _Dataclass(
        init, repr, eq, order, unsafe_hash, frozen, kw_only, mappable_dataclass
    )(cls)

  if cls is None:
    return dcls
  return dcls(cls)


class _Dataclass():
  """JAX-friendly wrapper for `dataclasses.dataclass`."""

  def __init__(
      self,
      init=True,
      repr=True,  # pylint: disable=redefined-builtin
      eq=True,
      order=False,
      unsafe_hash=False,
      frozen=False,
      kw_only=False,
      mappable_dataclass=True,  # pylint: disable=redefined-outer-name
  ):
    self.init = init
    self.repr = repr  # pylint: disable=redefined-builtin
    self.eq = eq
    self.order = order
    self.unsafe_hash = unsafe_hash
    self.frozen = frozen
    self.kw_only = kw_only
    self.mappable_dataclass = mappable_dataclass

  def __call__(self, cls):
    """Forwards class to dataclasses's wrapper and registers it with JAX.

    Args:
      cls: The class to turn into a chex dataclass.

    Returns:
      The generated dataclass, with `from_tuple`, `to_tuple`, `replace`,
      `__getstate__`, `__setstate__` and a registering `__init__` attached.

    Raises:
      TypeError: If a non-frozen dataclass inherits from a frozen one.
      ValueError: If a field uses one of the reserved method names.
    """

    # Remove once https://github.com/python/cpython/pull/24484 is merged.
    for base in cls.__bases__:
      if (dataclasses.is_dataclass(base) and
          getattr(base, "__dataclass_params__").frozen and not self.frozen):
        raise TypeError("cannot inherit non-frozen dataclass from a frozen one")

    # `kw_only` is only available starting from 3.10.
    version_dependent_args = {}
    if sys.version_info >= (3, 10):
      version_dependent_args = {"kw_only": self.kw_only}
    # pytype: disable=wrong-keyword-args
    dcls = dataclasses.dataclass(
        cls,
        init=self.init,
        repr=self.repr,
        eq=self.eq,
        order=self.order,
        unsafe_hash=self.unsafe_hash,
        frozen=self.frozen,
        **version_dependent_args,
    )
    # pytype: enable=wrong-keyword-args

    fields_names = set(f.name for f in dataclasses.fields(dcls))
    invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES)
    if invalid_fields:
      raise ValueError(f"The following dataclass fields are disallowed: "
                       f"{invalid_fields} ({dcls}).")

    if self.mappable_dataclass:
      dcls = mappable_dataclass(dcls)

    def _from_tuple(args):
      # Keyword construction works for both mappable (dict-like constructor)
      # and non-mappable (regular dataclass constructor) classes; passing a
      # bare `zip` object only worked for the mappable variant.
      return dcls(**dict(zip(dcls.__dataclass_fields__.keys(), args)))

    def _to_tuple(self):
      return tuple(getattr(self, k) for k in self.__dataclass_fields__.keys())

    def _replace(self, **kwargs):
      return dataclasses.replace(self, **kwargs)

    def _getstate(self):
      return self.__dict__

    # Register the dataclass at definition. As long as the dataclass is defined
    # outside __main__, this is sufficient to make JAX's PyTree registry
    # recognize the dataclass and the dataclass' custom PyTreeDef, especially
    # when unpickling either the dataclass object, its type, or its PyTreeDef,
    # in a different process, because the defining module will be imported.
    #
    # However, if the dataclass is defined in __main__, unpickling in a
    # subprocess does not trigger re-registration. Therefore we also need to
    # register when deserializing the object, or construction (e.g. when the
    # dataclass type is being unpickled). Unfortunately, there is not yet a way
    # to trigger re-registration when the treedef is unpickled as that's handled
    # by JAX.
    #
    # See internal dataclass_test for unit tests demonstrating the problems.
    # The registration below may result in pickling failures of the sort
    # _pickle.PicklingError: Can't pickle <...>:
    # it's not the same object as register_dataclass_type_with_jax_tree_util
    # for modules defined in __main__ so we disable registration in this case.
    if dcls.__module__ != "__main__":
      register_dataclass_type_with_jax_tree_util(dcls)

    # Patch __setstate__ to register the dataclass on deserialization.
    def _setstate(self, state):
      register_dataclass_type_with_jax_tree_util(dcls)
      self.__dict__.update(state)

    orig_init = dcls.__init__

    # Patch __init__ such that the dataclass is registered on creation if it is
    # not registered on deserialization.
    @functools.wraps(orig_init)
    def _init(self, *args, **kwargs):
      register_dataclass_type_with_jax_tree_util(dcls)
      return orig_init(self, *args, **kwargs)

    setattr(dcls, "from_tuple", _from_tuple)
    setattr(dcls, "to_tuple", _to_tuple)
    setattr(dcls, "replace", _replace)
    setattr(dcls, "__getstate__", _getstate)
    setattr(dcls, "__setstate__", _setstate)
    setattr(dcls, "__init__", _init)

    return dcls


def _dataclass_unflatten(dcls, keys, values):
  """Creates a chex dataclass from a flatten jax.tree_util representation.

  Args:
    dcls: The dataclass type to instantiate.
    keys: Field names (``str``) produced as aux data by the flatten functions.
    values: Leaves/children in the same order as ``keys``.

  Returns:
    A new ``dcls`` instance built without calling ``__init__``; its
    ``__post_init__`` (if any) is invoked explicitly.
  """
  dcls_object = dcls.__new__(dcls)
  attribute_dict = dict(zip(keys, values))
  # Looping over fields instead of keys & values preserves the field order.
  # Using dataclasses.fields fails because dataclass uids change after
  # serialisation (eg, with cloudpickle).
  for field in dcls.__dataclass_fields__.values():
    if field.name in attribute_dict:  # Filter pseudo-fields.
      object.__setattr__(dcls_object, field.name, attribute_dict[field.name])
  # Need to manual call post_init here as we have avoided calling __init__
  if getattr(dcls_object, "__post_init__", None):
    dcls_object.__post_init__()
  return dcls_object


def _flatten_without_path(dcls):
  """Splits a dataclass instance into ``(values, field_names)``.

  Args:
    dcls: A dataclass *instance* (despite the name).

  Returns:
    A pair of tuples ordered by field name: the attribute values (children)
    and the attribute names (hashable aux data for `_dataclass_unflatten`).
  """
  items = sorted(dcls.__dict__.items())
  return tuple(v for _, v in items), tuple(k for k, _ in items)


def _flatten_with_path(dcls):
  """Splits a dataclass instance into keyed children and field names.

  Args:
    dcls: A dataclass *instance* (despite the name).

  Returns:
    ``(path, keys)`` where ``path`` is a list of
    ``(jax.tree_util.GetAttrKey(name), value)`` pairs ordered by name, and
    ``keys`` is the tuple of field names. ``keys`` is the same aux data that
    `_flatten_without_path` returns, as JAX requires both flatten functions to
    agree, and it is what `_dataclass_unflatten` matches field names against.
  """
  items = sorted(dcls.__dict__.items())
  path = [(jax.tree_util.GetAttrKey(k), v) for k, v in items]
  keys = tuple(k for k, _ in items)
  return path, keys


@functools.cache
def register_dataclass_type_with_jax_tree_util(data_class):
  """Register an existing dataclass so JAX knows how to handle it.

  This means that functions in jax.tree_util operate over the fields
  of the dataclass. See
  https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees
  for further information.

  Registration is memoized per class; a ``ValueError`` from JAX (already
  registered) is logged and ignored.

  Args:
    data_class: A class created using dataclasses.dataclass. It must be
      constructable from keyword arguments corresponding to the members exposed
      in instance.__dict__.
  """
  unflatten = functools.partial(_dataclass_unflatten, data_class)
  try:
    jax.tree_util.register_pytree_with_keys(
        nodetype=data_class, flatten_with_keys=_flatten_with_path,
        flatten_func=_flatten_without_path, unflatten_func=unflatten)
  except ValueError:
    logging.info("%s is already registered as JAX PyTree node.", data_class)

# --------------------------------------------------------------------------------
# /chex/_src/fake.py:
# --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to patch JAX functions with faked implementations.

This module provides fake implementations of jax.jit and jax.pmap, which can be
patched over existing implementations for easier debugging.

See https://www.martinfowler.com/articles/mocksArentStubs.html
"""

import contextlib
import functools
import inspect
import os
import re
from typing import Any, Callable, Iterable, Optional, Union
from unittest import mock
from absl import flags
import jax
import jax.numpy as jnp


FLAGS = flags.FLAGS
flags.DEFINE_integer('chex_n_cpu_devices', 1,
                     'Number of CPU threads to use as devices in tests.')
flags.DEFINE_bool('chex_assert_multiple_cpu_devices', False,
                  'Whether to fail if a number of CPU devices is less than 2.')

# Matches the host-platform device-count flag (zero to two leading dashes) and
# captures its integer value when one is present.
_xla_device_count_flag_regexp = (
    r'[-]{0,2}xla_force_host_platform_device_count=(\d+)?(\s|$)')


def get_n_cpu_devices_from_xla_flags() -> int:
  """Parses number of CPUs from the XLA environment flags.

  The device-count flag may appear anywhere in ``XLA_FLAGS`` (other flags can
  precede it), so the whole string is searched, consistent with the global
  ``re.sub`` used by `set_n_cpu_devices`.

  Returns:
    The value of ``xla_force_host_platform_device_count``, or 1 when the flag
    is absent or has no numeric value (at least one CPU device is available).
  """
  m = re.search(_xla_device_count_flag_regexp, os.getenv('XLA_FLAGS', ''))
  # The value group is optional in the regexp, so it can be None on a match.
  if m is None or m.group(1) is None:
    return 1
  return int(m.group(1))


def set_n_cpu_devices(n: Optional[int] = None) -> None:
  """Forces XLA to use `n` CPU threads as host devices.

  This allows `jax.pmap` to be tested on a single-CPU platform.
  This utility only takes effect before XLA backends are initialized, i.e.
  before any JAX operation is executed (including `jax.devices()` etc.).
  See https://github.com/google/jax/issues/1408.

  Args:
    n: A required number of CPU devices (``FLAGS.chex_n_cpu_devices`` is used by
      default).

  Raises:
    RuntimeError: If the CPU backend was already initialized with a device
      count different from `n`.
  """
  n = n or FLAGS['chex_n_cpu_devices'].value

  n_devices = get_n_cpu_devices_from_xla_flags()
  cpu_backend = (jax.lib.xla_bridge._backends or {}).get('cpu', None)  # pylint: disable=protected-access
  if cpu_backend is not None and n_devices != n:
    raise RuntimeError(
        f'Attempted to set {n} devices, but {n_devices} CPUs already available:'
        ' ensure that `set_n_cpu_devices` is executed before any JAX operation.'
    )

  # Strip every existing device-count flag, then prepend the requested one.
  xla_flags = os.getenv('XLA_FLAGS', '')
  xla_flags = re.sub(_xla_device_count_flag_regexp, '', xla_flags)
  os.environ['XLA_FLAGS'] = ' '.join(
      [f'--xla_force_host_platform_device_count={n}'] + xla_flags.split())


def convert_to_varargs(sig, *args, **kwargs):
  """Converts varargs+kwargs function arguments into varargs only.

  Args:
    sig: An `inspect.Signature` to bind the call arguments against.
    *args: Positional arguments of the call.
    **kwargs: Keyword arguments of the call.

  Returns:
    The positional-argument tuple of the bound arguments.
  """
  bound_args = sig.bind(*args, **kwargs)
  return bound_args.args


def _ignore_axis_index_groups(fn):
  """Wrapper that forces axis_index_groups to be None.

  This is to avoid problems within fake_pmap where parallel operations are
  performed with vmap, rather than pmap. Parallel operations where
  `axis_index_groups` is not `None` are not currently supported under vmap.

  Args:
    fn: the function to wrap

  Returns:
    a wrapped function that forces any keyword argument named
    `axis_index_groups` to be None
  """
  @functools.wraps(fn)
  def _fake(*args, axis_index_groups=None, **kwargs):
    del axis_index_groups
    return fn(*args, axis_index_groups=None, **kwargs)
  return _fake


_fake_all_gather = _ignore_axis_index_groups(jax.lax.all_gather)
_fake_all_to_all = _ignore_axis_index_groups(jax.lax.all_to_all)
_fake_psum = _ignore_axis_index_groups(jax.lax.psum)
_fake_pmean = _ignore_axis_index_groups(jax.lax.pmean)
_fake_pmax = _ignore_axis_index_groups(jax.lax.pmax)
_fake_pmin = _ignore_axis_index_groups(jax.lax.pmin)
_fake_pswapaxes = _ignore_axis_index_groups(jax.lax.pswapaxes)


@functools.wraps(jax.pmap)
def _fake_pmap(fn,
               axis_name: Optional[Any] = None,
               *,
               in_axes=0,
               static_broadcasted_argnums: Union[int, Iterable[int]] = (),
               jit_result: bool = False,
               fake_parallel_axis: bool = False,
               **unused_kwargs):
  """Fake implementation of pmap using vmap."""

  if isinstance(static_broadcasted_argnums, int):
    static_broadcasted_argnums = (static_broadcasted_argnums,)
  else:
    # Materialize: the argnums are membership-tested repeatedly below, which
    # would silently misbehave for one-shot iterables such as generators.
    static_broadcasted_argnums = tuple(static_broadcasted_argnums)
  if static_broadcasted_argnums and isinstance(in_axes, dict):
    raise NotImplementedError(
        'static_broadcasted_argnums with dict in_axes not supported.')

  fn_signature = inspect.signature(
      fn,
      # Disable 'follow wrapped' because we want the exact signature of fn,
      # not the signature of any function it might wrap.
      follow_wrapped=False)

  @functools.wraps(fn)
  def wrapped_fn(*args, **kwargs):
    # Convert kwargs to varargs
    # This is a workaround for vmapped functions not working with kwargs
    call_args = convert_to_varargs(fn_signature, *args, **kwargs)

    if static_broadcasted_argnums:
      # Make sure vmap does not try to map over `static_broadcasted_argnums`.
      if isinstance(in_axes, int):
        vmap_in_axes = [in_axes] * len(call_args)
      else:
        vmap_in_axes = list(in_axes)
      for argnum in static_broadcasted_argnums:
        vmap_in_axes[argnum] = None

      # To protect the arguments from `static_broadcasted_argnums`,
      # from turning into tracers (because of vmap), we capture the original
      # `call_args` and replace the passed in tracers with original values.
      original_call_args = call_args

      # A function passed to vmap, that will simply replace the static args
      # with their original values.
      def fn_without_statics(*args):
        args_with_original_statics = [
            orig_arg if i in static_broadcasted_argnums else arg
            for i, (arg, orig_arg) in enumerate(zip(args, original_call_args))
        ]
        return fn(*args_with_original_statics)

      # Make sure to avoid turning static args into tracers: Some python objects
      # might not survive vmap. Just replace with an unused constant.
      call_args = [
          1 if i in static_broadcasted_argnums else arg
          for i, arg in enumerate(call_args)
      ]

    else:
      vmap_in_axes = in_axes
      fn_without_statics = fn

    vmapped_fn = jax.vmap(
        fn_without_statics, in_axes=vmap_in_axes, axis_name=axis_name
    )
    if jit_result:
      vmapped_fn = jax.jit(vmapped_fn)

    if fake_parallel_axis:
      call_args = jax.tree_util.tree_map(
          lambda x: jnp.expand_dims(x, axis=0), call_args)

    output = vmapped_fn(*call_args)

    if fake_parallel_axis:
      output = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), output)

    return output

  return wrapped_fn


# pylint:disable=unnecessary-dunder-call
class FakeContext(contextlib.ExitStack):
  """An `ExitStack` that can also be driven imperatively via start/stop."""

  def start(self):
    """Enters the context (equivalent to the start of a `with` block)."""
    self.__enter__()

  def stop(self):
    """Exits the context, unwinding every callback registered on the stack."""
    self.__exit__(None, None, None)


# pylint:enable=unnecessary-dunder-call


def fake_jit(enable_patching: bool = True) -> FakeContext:
  """Context manager for patching `jax.jit` with the identity function.

  This is intended to be used as a debugging tool to programmatically enable or
  disable JIT compilation. It is implemented with `jax.disable_jit`.

  Can be used either as a context managed scope:

  .. code-block:: python

    with chex.fake_jit():
      @jax.jit
      def foo(x):
        ...

  or by calling `start` and `stop`:

  .. code-block:: python

    fake_jit_context = chex.fake_jit()
    fake_jit_context.start()

    @jax.jit
    def foo(x):
      ...

    fake_jit_context.stop()

  Args:
    enable_patching: Whether to patch `jax.jit`.

  Returns:
    Context where `jax.jit` is patched with the identity function and jax is
    configured to avoid jitting internally whenever possible in functions
    such as `jax.lax.scan`, etc.
  """
  stack = FakeContext()
  stack.enter_context(jax.disable_jit(disable=enable_patching))
  return stack


def fake_pmap(
    enable_patching: bool = True,
    jit_result: bool = False,
    ignore_axis_index_groups: bool = False,
    fake_parallel_axis: bool = False,
) -> FakeContext:
  """Context manager for patching `jax.pmap` with `jax.vmap`.

  This is intended to be used as a debugging tool to programmatically replace
  pmap transformations with a non-parallel vmap transformation.

  Can be used either as a context managed scope:

  .. code-block:: python

    with chex.fake_pmap():
      @jax.pmap
      def foo(x):
        ...

  or by calling `start` and `stop`:

  .. code-block:: python

    fake_pmap_context = chex.fake_pmap()
    fake_pmap_context.start()
    @jax.pmap
    def foo(x):
      ...
    fake_pmap_context.stop()

  Args:
    enable_patching: Whether to patch `jax.pmap`.
    jit_result: Whether the transformed function should be jitted despite not
      being pmapped.
    ignore_axis_index_groups: Whether to force any parallel operation within the
      context to set `axis_index_groups` to be None. This is a compatibility
      option to allow users of the axis_index_groups parameter to run under the
      fake_pmap context. This feature is not currently supported in vmap, and
      will fail, so we force the parameter to be `None`.
      *Warning*: This will produce different results to running under `jax.pmap`
    fake_parallel_axis: If True, a leading axis of size 1 is added to every
      positional argument before the vmapped call and squeezed out of every
      output leaf afterwards, faking an extra parallel axis.

  Returns:
    Context where `jax.pmap` is patched with `jax.vmap`.
  """
  stack = FakeContext()
  if enable_patching:
    patched_pmap = functools.partial(
        _fake_pmap,
        jit_result=jit_result,
        fake_parallel_axis=fake_parallel_axis)

    stack.enter_context(mock.patch('jax.pmap', patched_pmap))

    # Otherwise the default collective implementations are left in place.
    if ignore_axis_index_groups:
      stack.enter_context(mock.patch('jax.lax.all_gather', _fake_all_gather))
      stack.enter_context(mock.patch('jax.lax.all_to_all', _fake_all_to_all))
      stack.enter_context(mock.patch('jax.lax.psum', _fake_psum))
      stack.enter_context(mock.patch('jax.lax.pmean', _fake_pmean))
      stack.enter_context(mock.patch('jax.lax.pmax', _fake_pmax))
      stack.enter_context(mock.patch('jax.lax.pmin', _fake_pmin))
      stack.enter_context(mock.patch('jax.lax.pswapaxes', _fake_pswapaxes))

  return stack


def fake_pmap_and_jit(enable_pmap_patching: bool = True,
                      enable_jit_patching: bool = True) -> FakeContext:
  """Context manager for patching `jax.jit` and `jax.pmap`.

  This is a convenience function, equivalent to nested `chex.fake_pmap` and
  `chex.fake_jit` contexts.

  Note that calling (the true implementation of) `jax.pmap` will compile the
  function, so faking `jax.jit` in this case will not stop the function from
  being compiled.

  Args:
    enable_pmap_patching: Whether to patch `jax.pmap`.
    enable_jit_patching: Whether to patch `jax.jit`.

  Returns:
    Context where jax.pmap and jax.jit are patched with jax.vmap and the
    identity function
  """
  stack = FakeContext()
  stack.enter_context(fake_pmap(enable_pmap_patching))
  stack.enter_context(fake_jit(enable_jit_patching))
  return stack


class OnCallOfTransformedFunction():
  """Injects a callback into any transformed function.

  A typical use-case is jax.jit or jax.pmap which is often hidden deep inside
  the code. This context manager allows to inject a callback function into
  functions which are transformed by the user-specified transformation.
  The callback will receive the transformed function and its arguments.

  The function can be useful to debug, profile and check the calls of any
  transformed function in a program

  For instance:

  with chex.OnCallOfTransformedFunction('jax.jit', print):
    [...]

  would print all calls to any function which was jit-compiled within this
  context.

  We can also automatically create profiles on the first call of all the
  jit compiled functions in the program:

  class profile_once():
    def __init__(self):
      self._first_call = True

    def __call__(self, fn, *args, **kwargs):
      if self._first_call:
        self._first_call = False
        print(profile_from_HLO(fn.lower(*args, **kwargs)))

  with chex.OnCallOfTransformedFunction('jax.jit', profile_once()):
    [...]
  """

  def __init__(self, fn_transformation: str, callback_fn: Callable[..., Any]):
    """Creates a new OnCallOfTransformedFunction context manager.

    Args:
      fn_transformation: identifier of the function transformation e.g.
        'jax.jit', 'jax.pmap', ...
      callback_fn: A callback function which receives the transformed function
        and its arguments on every call.
    """
    self._fn_transformation = fn_transformation
    self._callback_fn = callback_fn
    self._patch: Optional[mock._patch[Callable[[Any], Any]]] = None  # pylint: disable=unsubscriptable-object
    self._original_fn_transformation = None

  def __enter__(self):
    """Patches the transformation and returns this context manager."""

    def _new_fn_transformation(fn, *args, **kwargs):
      """Returns a transformed version of the given function."""
      transformed_fn = self._original_fn_transformation(fn, *args, **kwargs)

      @functools.wraps(transformed_fn)
      def _new_transformed_fn(*args, **kwargs):
        """Returns result of the returned function and calls the callback."""
        self._callback_fn(transformed_fn, *args, **kwargs)
        return transformed_fn(*args, **kwargs)

      return _new_transformed_fn

    self._patch = mock.patch(self._fn_transformation, _new_fn_transformation)
    # Capture the unpatched transformation before the patch is activated.
    self._original_fn_transformation, unused_local = self._patch.get_original()
    self._patch.start()
    return self

  def __exit__(self, *unused_args):
    self._patch.stop()

# --------------------------------------------------------------------------------
# /chex/_src/fake_test.py:
# --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | # ============================================================================== 15 | """Tests for `fake.py`.""" 16 | 17 | import dataclasses 18 | import functools 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | from chex._src import asserts 23 | from chex._src import fake 24 | from chex._src import pytypes 25 | import jax 26 | import jax.numpy as jnp 27 | 28 | ArrayBatched = pytypes.ArrayBatched 29 | ArraySharded = pytypes.ArraySharded 30 | 31 | 32 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 33 | def setUpModule(): 34 | fake.set_n_cpu_devices() 35 | 36 | 37 | def _assert_jitted(fn, fn_input, is_jitted): 38 | """Asserts that a function can be jitted or not. 39 | 40 | Args: 41 | fn: The function to be tested 42 | fn_input: Input to pass to the function 43 | is_jitted: Assert that the function can be jitted with jax.jit (True) or 44 | cannot be jitted (False), i.e. the fake jit is working correctly. 45 | """ 46 | asserts.clear_trace_counter() 47 | max_traces = 1 if is_jitted else 0 48 | wrapped_fn = jax.jit(asserts.assert_max_traces(fn, max_traces)) 49 | wrapped_fn(fn_input) 50 | 51 | 52 | def _assert_pmapped(fn, fn_input, is_pmapped, should_jit=False): 53 | """Asserts whether a function can be pmapped or not. 54 | 55 | Args: 56 | fn: The function to be tested 57 | fn_input: Input to pass to the function 58 | is_pmapped: Assert that the function can be pmapped with jax.pmap (True) or 59 | cannot be pmapped (False), i.e. the fake pmap is working correctly. 60 | should_jit: if True, asserts that the function is jitted, regardless of it 61 | being pmapped or not. 
62 | """ 63 | num_devices = len(jax.devices()) 64 | if should_jit: 65 | asserts.clear_trace_counter() 66 | fn = asserts.assert_max_traces(fn, n=1) 67 | wrapped_fn = jax.pmap(fn, axis_size=num_devices) 68 | 69 | fn_input = jnp.broadcast_to(fn_input, (num_devices,) + fn_input.shape) 70 | output = wrapped_fn(fn_input) 71 | 72 | # We test whether the function has been pmapped by inspecting the type of 73 | # the function output, if it is a sharded array type then the function has 74 | # been pmapped 75 | if is_pmapped: 76 | expected_type = jax.Array 77 | assert_message = f'Output is type {type(output)}, expected {expected_type}' 78 | assert isinstance(output, expected_type), assert_message 79 | else: 80 | expected_type = 'DeviceArray' 81 | assert_message = f'Output is type {type(output)}, expected {expected_type}' 82 | # ShardedDeviceArray is a subclass of DeviceArray. So, to enforce we have 83 | # a DeviceArray, we also check it's not a sharded one. 84 | assert (isinstance(output, jax.Array) and 85 | len(output.sharding.device_set) == 1), assert_message 86 | 87 | 88 | class PmapFakeTest(parameterized.TestCase): 89 | 90 | def test_assert_pmapped(self): 91 | def foo(x): 92 | return x * 2 93 | fn_input = jnp.ones((4,)) 94 | 95 | _assert_pmapped(foo, fn_input, True) 96 | # Since this test runs only on 1 device, having a test to check if the 97 | # output is sharded or not is not correct. With jax.Array, you can check 98 | # the `len(output.sharding.device_set)` to see if its sharded or not, but 99 | # here because of a single device it fails. 
100 | 101 | def test_assert_jitted(self): 102 | fn_input = jnp.ones((4,)) 103 | def foo(x): 104 | return x * 2 105 | 106 | _assert_jitted(foo, fn_input, True) 107 | with self.assertRaises(AssertionError): 108 | _assert_jitted(foo, fn_input, False) 109 | 110 | @parameterized.named_parameters([ 111 | ('plain_jit', {'enable_patching': True}, False), 112 | ('faked_jit', {'enable_patching': False}, True), 113 | ]) 114 | def test_fake_jit(self, fake_kwargs, is_jitted): 115 | fn_input = jnp.ones((4,)) 116 | def foo(x): 117 | return x * 2 118 | 119 | # Call with context manager 120 | with fake.fake_jit(**fake_kwargs): 121 | _assert_jitted(foo, fn_input, is_jitted) 122 | 123 | # Call with start/stop 124 | ctx = fake.fake_jit(**fake_kwargs) 125 | ctx.start() 126 | _assert_jitted(foo, fn_input, is_jitted) 127 | ctx.stop() 128 | 129 | @parameterized.named_parameters([ 130 | ('plain_pmap_but_jit', True, True), 131 | ('plain_pmap', True, False), 132 | ('faked_pmap_but_jit', False, True), 133 | ('faked_pmap', False, False), 134 | ]) 135 | def test_fake_pmap_(self, is_pmapped, jit_result): 136 | enable_patching = not is_pmapped 137 | 138 | fn_input = jnp.ones((4,)) 139 | def foo(x): 140 | return x * 2 141 | 142 | # Call with context manager 143 | with fake.fake_pmap(enable_patching=enable_patching, jit_result=jit_result): 144 | _assert_pmapped(foo, fn_input, is_pmapped, jit_result) 145 | 146 | # Call with start/stop 147 | ctx = fake.fake_pmap(enable_patching=enable_patching, jit_result=jit_result) 148 | ctx.start() 149 | _assert_pmapped(foo, fn_input, is_pmapped, jit_result) 150 | ctx.stop() 151 | 152 | def test_fake_pmap_axis_name(self): 153 | 154 | with fake.fake_pmap(): 155 | 156 | @functools.partial(jax.pmap, axis_name='i') 157 | @functools.partial(jax.pmap, axis_name='j') 158 | def f(_): 159 | return jax.lax.axis_index('i'), jax.lax.axis_index('j') 160 | x, y = f(jnp.zeros((4, 2))) 161 | 162 | self.assertEqual(x.tolist(), [[0, 0], [1, 1], [2, 2], [3, 3]]) 163 | 
self.assertEqual(y.tolist(), [[0, 1], [0, 1], [0, 1], [0, 1]]) 164 | 165 | @parameterized.named_parameters([ 166 | ('fake_nothing', { 167 | 'enable_pmap_patching': False, 168 | 'enable_jit_patching': False 169 | }, True, True), 170 | ('fake_pmap', { 171 | 'enable_pmap_patching': True, 172 | 'enable_jit_patching': False 173 | }, False, True), 174 | # Default pmap will implicitly compile the function 175 | ('fake_jit', { 176 | 'enable_pmap_patching': False, 177 | 'enable_jit_patching': True 178 | }, True, False), 179 | ('fake_both', { 180 | 'enable_pmap_patching': True, 181 | 'enable_jit_patching': True 182 | }, False, False), 183 | ]) 184 | def test_pmap_and_jit(self, fake_kwargs, is_pmapped, is_jitted): 185 | fn_input = jnp.ones((4,)) 186 | def foo(x): 187 | return x * 2 188 | 189 | # Call with context manager 190 | with fake.fake_pmap_and_jit(**fake_kwargs): 191 | _assert_pmapped(foo, fn_input, is_pmapped) 192 | _assert_jitted(foo, fn_input, is_jitted) 193 | 194 | # Call with start/stop 195 | ctx = fake.fake_pmap_and_jit(**fake_kwargs) 196 | ctx.start() 197 | _assert_pmapped(foo, fn_input, is_pmapped) 198 | _assert_jitted(foo, fn_input, is_jitted) 199 | ctx.stop() 200 | 201 | @parameterized.named_parameters([ 202 | ('fake_nothing', False, False), 203 | ('fake_pmap', True, False), 204 | ('fake_jit', False, True), 205 | ('fake_both', True, True), 206 | ]) 207 | def test_with_kwargs(self, fake_pmap, fake_jit): 208 | with fake.fake_pmap_and_jit(fake_pmap, fake_jit): 209 | num_devices = len(jax.devices()) 210 | 211 | @functools.partial(jax.pmap, axis_size=num_devices) 212 | @jax.jit 213 | def foo(x, y): 214 | return (x * 2) + y 215 | 216 | # pmap over all available devices 217 | inputs = jnp.array([1, 2]) 218 | inputs = jnp.broadcast_to(inputs, (num_devices,) + inputs.shape) 219 | expected = jnp.broadcast_to(jnp.array([3, 6]), (num_devices, 2)) 220 | 221 | asserts.assert_trees_all_close(foo(x=inputs, y=inputs), expected) 222 | 223 | @parameterized.named_parameters([ 
224 | ('fake_nothing', False, 1), 225 | ('fake_pmap', True, 1), 226 | ('fake_nothing_no_static_args', False, ()), 227 | ('fake_pmap_no_static_args', True, ()), 228 | ]) 229 | def test_with_static_broadcasted_argnums(self, fake_pmap, static_argnums): 230 | with fake.fake_pmap_and_jit(fake_pmap, enable_jit_patching=False): 231 | num_devices = len(jax.devices()) 232 | 233 | # Note: mode='bar' is intended to test that we correctly handle kwargs 234 | # with defaults for which we don't pass a value at call time. 235 | @functools.partial( 236 | jax.pmap, 237 | axis_size=num_devices, 238 | static_broadcasted_argnums=static_argnums, 239 | ) 240 | @functools.partial( 241 | jax.jit, 242 | static_argnums=static_argnums, 243 | ) 244 | def foo(x, multiplier, y, mode='bar'): 245 | if static_argnums == 1 or 1 in static_argnums: 246 | # Verify that the static arguments are not replaced with tracers. 247 | self.assertIsInstance(multiplier, int) 248 | 249 | if mode == 'bar': 250 | return (x * multiplier) + y 251 | else: 252 | return x 253 | 254 | # pmap over all available devices 255 | inputs = jnp.array([1, 2]) 256 | inputs = jnp.broadcast_to(inputs, (num_devices,) + inputs.shape) 257 | func = lambda: foo(inputs, 100, inputs) # Pass multiplier=100. 258 | 259 | if static_argnums == 1: # Should work. 260 | expected = jnp.broadcast_to(jnp.array([101, 202]), (num_devices, 2)) 261 | result = func() 262 | asserts.assert_trees_all_close(result, expected) 263 | else: # Should error. 264 | with self.assertRaises(ValueError): 265 | result = func() 266 | 267 | @parameterized.parameters(1, [1]) 268 | def test_pmap_with_complex_static_broadcasted_object(self, static_argnums): 269 | 270 | @dataclasses.dataclass 271 | class Multiplier: 272 | x: int 273 | y: int 274 | 275 | def foo(x, multiplier, y): 276 | if static_argnums == 1 or 1 in static_argnums: 277 | # Verify that the static arguments are not replaced with tracers. 
278 | self.assertIsInstance(multiplier, Multiplier) 279 | 280 | return x * multiplier.x + y * multiplier.y 281 | 282 | with fake.fake_pmap_and_jit(): 283 | num_devices = jax.device_count() 284 | 285 | # pmap over all available devices 286 | transformed_foo = jax.pmap( 287 | foo, 288 | axis_size=num_devices, 289 | static_broadcasted_argnums=static_argnums, 290 | ) 291 | x, y = jax.random.randint( 292 | jax.random.PRNGKey(27), (2, num_devices, 3, 5), 0, 10 293 | ) 294 | 295 | # Test 1. 296 | mult = Multiplier(x=2, y=7) 297 | asserts.assert_trees_all_equal( 298 | transformed_foo(x, mult, y), 299 | foo(x, mult, y), 300 | x * mult.x + y * mult.y, 301 | ) 302 | 303 | # Test 2. 304 | mult = Multiplier(x=72, y=21) 305 | asserts.assert_trees_all_equal( 306 | transformed_foo(x, mult, y), 307 | foo(x, mult, y), 308 | x * mult.x + y * mult.y, 309 | ) 310 | 311 | @parameterized.named_parameters([ 312 | ('fake_nothing', False, False), 313 | ('fake_pmap', True, False), 314 | ('fake_jit', False, True), 315 | ('fake_both', True, True), 316 | ]) 317 | def test_with_partial(self, fake_pmap, fake_jit): 318 | with fake.fake_pmap_and_jit(fake_pmap, fake_jit): 319 | num_devices = len(jax.devices()) 320 | 321 | # Testing a common use-case where non-parallel arguments are partially 322 | # applied before pmapping 323 | def foo(x, y, flag): 324 | return (x * 2) + y if flag else (x + y) 325 | foo = functools.partial(foo, flag=True) 326 | 327 | foo = jax.pmap(foo, axis_size=num_devices) 328 | foo = jax.jit(foo) 329 | 330 | # pmap over all available devices 331 | inputs = jnp.array([1, 2]) 332 | inputs = jnp.broadcast_to(inputs, (num_devices,) + inputs.shape) 333 | expected = jnp.broadcast_to(jnp.array([3, 6]), (num_devices, 2)) 334 | 335 | asserts.assert_trees_all_close(foo(inputs, inputs), expected) 336 | asserts.assert_trees_all_close(foo(x=inputs, y=inputs), expected) 337 | 338 | @parameterized.named_parameters([ 339 | ('fake_nothing', False, False), 340 | ('fake_pmap', True, False), 341 | 
('fake_jit', False, True), 342 | ('fake_both', True, True), 343 | ]) 344 | def test_with_default_params(self, fake_pmap, fake_jit): 345 | with fake.fake_pmap_and_jit(fake_pmap, fake_jit): 346 | num_devices = len(jax.devices()) 347 | 348 | # Default flag specified at definition time 349 | def foo(x, y, flag=True): 350 | return (x * 2) + y if flag else (x + y) 351 | 352 | default_foo = jax.pmap(foo, axis_size=num_devices) 353 | default_foo = jax.jit(default_foo) 354 | 355 | inputs = jnp.array([1, 2]) 356 | inputs = jnp.broadcast_to(inputs, (num_devices,) + inputs.shape) 357 | expected = jnp.broadcast_to(jnp.array([3, 6]), (num_devices, 2)) 358 | asserts.assert_trees_all_close(default_foo(inputs, inputs), expected) 359 | asserts.assert_trees_all_close(default_foo(x=inputs, y=inputs), expected) 360 | 361 | # Default overriden by partial to execute other branch 362 | overidden_foo = functools.partial(foo, flag=False) 363 | overidden_foo = jax.pmap(overidden_foo, axis_size=num_devices) 364 | overidden_foo = jax.jit(overidden_foo) 365 | 366 | expected = jnp.broadcast_to(jnp.array([2, 4]), (num_devices, 2)) 367 | asserts.assert_trees_all_close(overidden_foo(inputs, inputs), expected) 368 | asserts.assert_trees_all_close( 369 | overidden_foo(x=inputs, y=inputs), expected) 370 | 371 | def test_parallel_ops_equivalence(self): 372 | """Test equivalence between parallel operations using pmap and vmap.""" 373 | num_devices = len(jax.devices()) 374 | inputs = jax.random.uniform(shape=(num_devices, num_devices, 2), 375 | key=jax.random.PRNGKey(1)) 376 | 377 | def test_equivalence(fn): 378 | with fake.fake_pmap(enable_patching=False): 379 | outputs1 = jax.pmap(fn, axis_name='i', axis_size=num_devices)(inputs) 380 | with fake.fake_pmap(enable_patching=True): 381 | outputs2 = jax.pmap(fn, axis_name='i', axis_size=num_devices)(inputs) 382 | with fake.fake_pmap(enable_patching=True, jit_result=True): 383 | outputs3 = jax.pmap(fn, axis_name='i', axis_size=num_devices)(inputs) 384 | 
asserts.assert_trees_all_close(outputs1, outputs2, outputs3) 385 | 386 | parallel_ops_and_kwargs = [ 387 | (jax.lax.psum, {}), 388 | (jax.lax.pmax, {}), 389 | (jax.lax.pmin, {}), 390 | (jax.lax.pmean, {}), 391 | (jax.lax.all_gather, {}), 392 | (jax.lax.all_to_all, { 393 | 'split_axis': 0, 394 | 'concat_axis': 1 395 | }), 396 | (jax.lax.ppermute, { 397 | 'perm': [(x, (x + 1) % num_devices) for x in range(num_devices)] 398 | }), 399 | ] 400 | 401 | def fn(op, kwargs, x, y=2.0): 402 | return op(x * y, axis_name='i', **kwargs) 403 | partial_fn = functools.partial(fn, y=4.0) 404 | lambda_fn = lambda op, kwargs, x: fn(op, kwargs, x, y=5.0) 405 | 406 | for op, kwargs in parallel_ops_and_kwargs: 407 | test_equivalence(functools.partial(fn, op, kwargs)) 408 | test_equivalence(functools.partial(fn, op, kwargs, y=3.0)) 409 | test_equivalence(functools.partial(partial_fn, op, kwargs)) 410 | test_equivalence(functools.partial(lambda_fn, op, kwargs)) 411 | 412 | def test_fake_parallel_axis(self): 413 | inputs = jnp.ones(shape=(2, 2)) 414 | with fake.fake_pmap(fake_parallel_axis=False): 415 | @jax.pmap 416 | def no_fake_parallel_axis_fn(x): 417 | asserts.assert_shape(x, (2,)) 418 | return 2.0 * x 419 | 420 | outputs = no_fake_parallel_axis_fn(inputs) 421 | asserts.assert_trees_all_close(outputs, 2.0) 422 | 423 | with fake.fake_pmap(fake_parallel_axis=True): 424 | @jax.pmap 425 | def fake_parallel_axis_fn(x): 426 | asserts.assert_shape(x, (2, 2,)) 427 | return 2.0 * x 428 | 429 | outputs = fake_parallel_axis_fn(inputs) 430 | asserts.assert_trees_all_close(outputs, 2.0) 431 | 432 | 433 | class _Counter(): 434 | """Counts how often an instance is called.""" 435 | 436 | def __init__(self): 437 | self.count = 0 438 | 439 | def __call__(self, *unused_args, **unused_kwargs): 440 | self.count += 1 441 | 442 | 443 | class OnCallOfTransformedFunctionTest(parameterized.TestCase): 444 | 445 | def test_on_call_of_transformed_function(self): 446 | counter = _Counter() 447 | with 
fake.OnCallOfTransformedFunction('jax.jit', counter): 448 | jax.jit(jnp.sum)(jnp.zeros((10,))) 449 | jax.jit(jnp.max)(jnp.zeros((10,))) 450 | self.assertEqual(counter.count, 2) 451 | 452 | 453 | if __name__ == '__main__': 454 | absltest.main() 455 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chex 2 | 3 | ![CI status](https://github.com/deepmind/chex/workflows/ci/badge.svg) 4 | ![docs](https://readthedocs.org/projects/chex/badge/?version=latest) 5 | ![pypi](https://img.shields.io/pypi/v/chex) 6 | 7 | Chex is a library of utilities for helping to write reliable JAX code. 8 | 9 | This includes utils to help: 10 | 11 | * Instrument your code (e.g. assertions, warnings) 12 | * Debug (e.g. transforming `pmaps` in `vmaps` within a context manager). 13 | * Test JAX code across many `variants` (e.g. jitted vs non-jitted). 14 | 15 | ## Installation 16 | 17 | You can install the latest released version of Chex from PyPI via: 18 | 19 | ```sh 20 | pip install chex 21 | ``` 22 | 23 | or you can install the latest development version from GitHub: 24 | 25 | ```sh 26 | pip install git+https://github.com/deepmind/chex.git 27 | ``` 28 | 29 | ## Modules Overview 30 | 31 | ### Dataclass ([dataclass.py](https://github.com/deepmind/chex/blob/master/chex/_src/dataclass.py)) 32 | 33 | Dataclasses are a popular construct introduced by Python 3.7 to allow to 34 | easily specify typed data structures with minimal boilerplate code. They are 35 | not, however, compatible with JAX and 36 | [dm-tree](https://github.com/deepmind/tree) out of the box. 37 | 38 | In Chex we provide a JAX-friendly dataclass implementation reusing python [dataclasses](https://docs.python.org/3/library/dataclasses.html#module-dataclasses). 
39 | 40 | The Chex implementation of `dataclass` registers dataclasses as internal [_PyTree_ 41 | nodes](https://jax.readthedocs.io/en/latest/pytrees.html) to ensure 42 | compatibility with JAX data structures. 43 | 44 | In addition, we provide a class wrapper that exposes dataclasses as 45 | `collections.abc.Mapping` descendants, which allows processing them 46 | (e.g. (un-)flattening them) in `dm-tree` methods like usual Python dictionaries. 47 | See the [`@mappable_dataclass`](https://github.com/deepmind/chex/blob/master/chex/_src/dataclass.py#L27) 48 | docstring for more details. 49 | 50 | Example: 51 | 52 | ```python 53 | @chex.dataclass 54 | class Parameters: 55 | x: chex.ArrayDevice 56 | y: chex.ArrayDevice 57 | 58 | parameters = Parameters( 59 | x=jnp.ones((2, 2)), 60 | y=jnp.ones((1, 2)), 61 | ) 62 | 63 | # Dataclasses can be treated as JAX pytrees 64 | jax.tree_util.tree_map(lambda x: 2.0 * x, parameters) 65 | 66 | # and as mappings by dm-tree 67 | tree.flatten(parameters) 68 | ``` 69 | 70 | **NOTE**: Unlike standard Python 3.7 dataclasses, Chex 71 | dataclasses cannot be constructed using positional arguments. They support 72 | construction arguments provided in the same format as the Python dict 73 | constructor. Dataclasses can be converted to and from tuples with the 74 | `to_tuple` and `from_tuple` methods if necessary. 75 | 76 | ```python 77 | parameters = Parameters( 78 | jnp.ones((2, 2)), 79 | jnp.ones((1, 2)), 80 | ) 81 | # ValueError: Mappable dataclass constructor doesn't support positional args. 82 | ``` 83 | 84 | ### Assertions ([asserts.py](https://github.com/deepmind/chex/blob/master/chex/_src/asserts.py)) 85 | 86 | One limitation of PyType annotations for JAX is that they do not support the 87 | specification of `DeviceArray` ranks, shapes or dtypes. Chex includes a number 88 | of functions that allow flexible and concise specification of these properties. 89 | 90 | E.g.
suppose you want to ensure that all tensors `t1`, `t2`, `t3` have the same 91 | shape, and that tensors `t4`, `t5` have rank `2` and (`3` or `4`), respectively. 92 | 93 | ```python 94 | chex.assert_equal_shape([t1, t2, t3]) 95 | chex.assert_rank([t4, t5], [2, {3, 4}]) 96 | ``` 97 | 98 | More examples: 99 | 100 | ```python 101 | from chex import assert_shape, assert_rank, ... 102 | 103 | assert_shape(x, (2, 3)) # x has shape (2, 3) 104 | assert_shape([x, y], [(), (2,3)]) # x is scalar and y has shape (2, 3) 105 | 106 | assert_rank(x, 0) # x is scalar 107 | assert_rank([x, y], [0, 2]) # x is scalar and y is a rank-2 array 108 | assert_rank([x, y], {0, 2}) # x and y are scalar OR rank-2 arrays 109 | 110 | assert_type(x, int) # x has type `int` (x can be an array) 111 | assert_type([x, y], [int, float]) # x has type `int` and y has type `float` 112 | 113 | assert_equal_shape([x, y, z]) # x, y, and z have equal shapes 114 | 115 | assert_trees_all_close(tree_x, tree_y) # values and structure of trees match 116 | assert_tree_all_finite(tree_x) # all tree_x leaves are finite 117 | 118 | assert_devices_available(2, 'gpu') # 2 GPUs available 119 | assert_tpu_available() # at least 1 TPU available 120 | 121 | assert_numerical_grads(f, (x, y), j) # f^{(j)}(x, y) matches numerical grads 122 | ``` 123 | 124 | See `asserts.py` 125 | [documentation](https://chex.readthedocs.io/en/latest/api.html#assertions) to 126 | find all supported assertions. 127 | 128 | If you cannot find a specific assertion, please consider making a pull request 129 | or opening an issue on 130 | [the bug tracker](https://github.com/deepmind/chex/issues). 131 | 132 | #### Optional Arguments 133 | 134 | All Chex assertions support the following optional kwargs for manipulating the 135 | emitted exception messages: 136 | 137 | * `custom_message`: A string to include in the emitted exception messages.
138 | * `include_default_message`: Whether to include the default Chex message in 139 | the emitted exception messages. 140 | * `exception_type`: An exception type to use. `AssertionError` by default. 141 | 142 | For example, the following code: 143 | 144 | ```python 145 | dataset = load_dataset() 146 | params = init_params() 147 | for i in range(num_steps): 148 | params = update_params(params, dataset.sample()) 149 | chex.assert_tree_all_finite(params, 150 | custom_message=f'Failed at iteration {i}.', 151 | exception_type=ValueError) 152 | ``` 153 | 154 | will raise a `ValueError` that includes a step number when `params` get polluted 155 | with `NaN`s or `None`s. 156 | 157 | #### Static and Value (aka *Runtime*) Assertions 158 | 159 | Chex divides all assertions into 2 classes: ***static*** and ***value*** 160 | assertions. 161 | 162 | 1. ***static*** assertions use anything except concrete values of tensors. 163 | Examples: `assert_shape`, `assert_trees_all_equal_dtypes`, 164 | `assert_max_traces`. 165 | 166 | 2. ***value*** assertions require access to tensor values, which are not 167 | available during JAX tracing (see 168 | [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html)), 169 | thus such assertions need special treatment in *jitted* code. 170 | 171 | To enable value assertions in a jitted function, it can be decorated with the 172 | `chex.chexify()` wrapper. Example: 173 | 174 | ```python 175 | @chex.chexify 176 | @jax.jit 177 | def logp1_abs_safe(x: chex.Array) -> chex.Array: 178 | chex.assert_tree_all_finite(x) 179 | return jnp.log(jnp.abs(x) + 1) 180 | 181 | logp1_abs_safe(jnp.ones(2)) # OK 182 | logp1_abs_safe(jnp.array([jnp.nan, 3])) # FAILS (in async mode) 183 | 184 | # The error will be raised either at the next line OR at the next 185 | # `logp1_abs_safe` call. See the docs for more details on async mode. 186 | logp1_abs_safe.wait_checks() # Wait for the (async) computation to complete.
187 | ``` 188 | 189 | See 190 | [this docstring](https://chex.readthedocs.io/en/latest/api.html#chex.chexify) 191 | for more detail on `chex.chexify()`. 192 | 193 | #### JAX Tracing Assertions 194 | 195 | JAX re-traces a JIT'ted function every time the structure of the passed arguments 196 | changes. Often this behavior is inadvertent and leads to a significant 197 | performance drop which is hard to debug. The [@chex.assert_max_traces](https://github.com/deepmind/chex/blob/master/chex/_src/asserts.py#L44) 198 | decorator asserts that the function is not re-traced more than `n` times during 199 | program execution. 200 | 201 | The global trace counter can be cleared by calling 202 | `chex.clear_trace_counter()`. This function can be used to isolate unit tests relying 203 | on `@chex.assert_max_traces`. 204 | 205 | Examples: 206 | 207 | ```python 208 | @jax.jit 209 | @chex.assert_max_traces(n=1) 210 | def fn_sum_jitted(x, y): 211 | return x + y 212 | 213 | fn_sum_jitted(jnp.zeros(3), jnp.zeros(3)) # tracing for the 1st time - OK 214 | fn_sum_jitted(jnp.zeros([6, 7]), jnp.zeros([6, 7])) # AssertionError! 215 | ``` 216 | 217 | Can be used with `jax.pmap()` as well: 218 | 219 | ```python 220 | def fn_sub(x, y): 221 | return x - y 222 | 223 | fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10)) 224 | ``` 225 | 226 | See 227 | [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html) 228 | section for more information about tracing. 229 | 230 | ### Warnings ([warnings.py](https://github.com/deepmind/chex/blob/master/chex/_src/warnings.py)) 231 | 232 | In addition to hard assertions, Chex also offers utilities to add common 233 | warnings, such as specific types of deprecation warnings.
234 | 235 | ### Test variants ([variants.py](https://github.com/deepmind/chex/blob/master/chex/_src/variants.py)) 236 | 237 | JAX relies extensively on code transformation and compilation, meaning that it 238 | can be hard to ensure that code is properly tested. For instance, just testing a 239 | python function using JAX code will not cover the actual code path that is 240 | executed when jitted, and that path will also differ whether the code is jitted 241 | for CPU, GPU, or TPU. This has been a source of obscure and hard to catch bugs 242 | where XLA changes would lead to undesirable behaviours that however only 243 | manifest in one specific code transformation. 244 | 245 | Variants make it easy to ensure that unit tests cover different ‘variations’ of 246 | a function, by providing a simple decorator that can be used to repeat any test 247 | under all (or a subset) of the relevant code transformations. 248 | 249 | E.g. suppose you want to test the output of a function `fn` with or without jit. 250 | You can use `chex.variants` to run the test with both the jitted and non-jitted 251 | version of the function by simply decorating a test method with 252 | `@chex.variants`, and then using `self.variant(fn)` in place of `fn` in the body 253 | of the test. 254 | 255 | ```python 256 | def fn(x, y): 257 | return x + y 258 | ... 259 | 260 | class ExampleTest(chex.TestCase): 261 | 262 | @chex.variants(with_jit=True, without_jit=True) 263 | def test(self): 264 | var_fn = self.variant(fn) 265 | self.assertEqual(fn(1, 2), 3) 266 | self.assertEqual(var_fn(1, 2), fn(1, 2)) 267 | ``` 268 | 269 | If you define the function in the test method, you may also use `self.variant` 270 | as a decorator in the function definition. 
For example: 271 | 272 | ```python 273 | class ExampleTest(chex.TestCase): 274 | 275 | @chex.variants(with_jit=True, without_jit=True) 276 | def test(self): 277 | @self.variant 278 | def var_fn(x, y): 279 | return x + y 280 | 281 | self.assertEqual(var_fn(1, 2), 3) 282 | ``` 283 | 284 | Example of parameterized test: 285 | 286 | ```python 287 | from absl.testing import parameterized 288 | 289 | # Could also be: 290 | # `class ExampleParameterizedTest(chex.TestCase, parameterized.TestCase):` 291 | # `class ExampleParameterizedTest(chex.TestCase):` 292 | class ExampleParameterizedTest(parameterized.TestCase): 293 | 294 | @chex.variants(with_jit=True, without_jit=True) 295 | @parameterized.named_parameters( 296 | ('case_positive', 1, 2, 3), 297 | ('case_negative', -1, -2, -3), 298 | ) 299 | def test(self, arg_1, arg_2, expected): 300 | @self.variant 301 | def var_fn(x, y): 302 | return x + y 303 | 304 | self.assertEqual(var_fn(arg_1, arg_2), expected) 305 | ``` 306 | 307 | Chex currently supports the following variants: 308 | 309 | * `with_jit` -- applies `jax.jit()` transformation to the function. 310 | * `without_jit` -- uses the function as is, i.e. identity transformation. 311 | * `with_device` -- places all arguments (except specified in `ignore_argnums` 312 | argument) into device memory before applying the function. 313 | * `without_device` -- places all arguments in RAM before applying the function. 314 | * `with_pmap` -- applies `jax.pmap()` transformation to the function (see notes below). 315 | 316 | See documentation in [variants.py](https://github.com/deepmind/chex/blob/master/chex/_src/variants.py) for more details on the supported variants. 317 | More examples can be found in [variants_test.py](https://github.com/deepmind/chex/blob/master/chex/_src/variants_test.py). 
318 | 319 | ### Variants notes 320 | 321 | * Test classes that use `@chex.variants` must inherit from 322 | `chex.TestCase` (or any other base class that unrolls test generators 323 | within `TestCase`, e.g. `absl.testing.parameterized.TestCase`). 324 | 325 | * **[`jax.vmap`]** All variants can be applied to a vmapped function; 326 | please see an example in [variants_test.py](https://github.com/deepmind/chex/blob/master/chex/_src/variants_test.py) (`test_vmapped_fn_named_params` and 327 | `test_pmap_vmapped_fn`). 328 | 329 | * **[`@chex.all_variants`]** You can get all supported variants 330 | by using the decorator `@chex.all_variants`. 331 | 332 | * **[`with_pmap` variant]** `jax.pmap(fn)` 333 | ([doc](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap)) performs a 334 | parallel map of `fn` onto multiple devices. Since most tests run in a 335 | single-device environment (i.e. having access to a single CPU or GPU), in which 336 | case `jax.pmap` is functionally equivalent to `jax.jit`, the `with_pmap` variant is 337 | skipped by default (although it works fine with a single device). Below we 338 | describe a way to properly test `fn` if it is supposed to be used in 339 | multi-device environments (TPUs or multiple CPUs/GPUs). To disable skipping 340 | `with_pmap` variants in the case of a single device, add 341 | `--chex_skip_pmap_variant_if_single_device=false` to your test command.
Chex provides tools to globally replace `jax.jit` with a no-op 350 | transformation and `jax.pmap` with a (non-parallel) `jax.vmap`, in order to more 351 | easily debug code in a single-device context. 352 | 353 | For example, you can use Chex to fake `pmap` and have it replaced with a `vmap`. 354 | This can be achieved by wrapping your code with a context manager: 355 | 356 | ```python 357 | with chex.fake_pmap(): 358 | @jax.pmap 359 | def fn(inputs): 360 | ... 361 | 362 | # Function will be vmapped over inputs 363 | fn(inputs) 364 | ``` 365 | 366 | The same functionality can also be invoked with `start` and `stop`: 367 | 368 | ```python 369 | fake_pmap = chex.fake_pmap() 370 | fake_pmap.start() 371 | ... your jax code ... 372 | fake_pmap.stop() 373 | ``` 374 | 375 | In addition, you can fake a real multi-device test environment with a 376 | multi-threaded CPU. See section **Faking multi-device test environments** for 377 | more details. 378 | 379 | See documentation in [fake.py](https://github.com/deepmind/chex/blob/master/chex/_src/fake.py) and examples in [fake_test.py](https://github.com/deepmind/chex/blob/master/chex/_src/fake_test.py) for more details. 380 | 381 | ## Faking multi-device test environments 382 | 383 | In situations where you do not have easy access to multiple devices, you can 384 | still test parallel computation using single-device multi-threading. 385 | 386 | In particular, one can force XLA to use a single CPU's threads as separate 387 | devices, i.e. to fake a real multi-device environment with a multi-threaded one. 388 | These two options are theoretically equivalent from XLA perspective because they 389 | expose the same interface and use identical abstractions. 390 | 391 | Chex has a flag `chex_n_cpu_devices` that specifies a number of CPU threads to 392 | use as XLA devices. 
393 | 394 | To set up a multi-threaded XLA environment for `absl` tests, define a 395 | `setUpModule` function in your test module: 396 | 397 | ```python 398 | def setUpModule(): 399 | chex.set_n_cpu_devices() 400 | ``` 401 | 402 | Now you can launch your test with `python test.py --chex_n_cpu_devices=N` to run 403 | it in a multi-device regime. Note that **all** tests within a module will have 404 | access to `N` devices. 405 | 406 | More examples can be found in [variants_test.py](https://github.com/deepmind/chex/blob/master/chex/_src/variants_test.py), [fake_test.py](https://github.com/deepmind/chex/blob/master/chex/_src/fake_test.py) and [fake_set_n_cpu_devices_test.py](https://github.com/deepmind/chex/blob/master/chex/_src/fake_set_n_cpu_devices_test.py). 407 | 408 | ### Using named dimension sizes. 409 | 410 | Chex comes with a small utility that allows you to package a collection of 411 | dimension sizes into a single object. The basic idea is: 412 | 413 | ```python 414 | dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim) 415 | ... 416 | chex.assert_shape(arr, dims['BTE']) 417 | ``` 418 | 419 | String lookups are translated into integer tuples. For instance, let's say 420 | `batch_size == 3`, `sequence_len == 5` and `embedding_dim == 7`, then 421 | 422 | ```python 423 | dims['BTE'] == (3, 5, 7) 424 | dims['B'] == (3,) 425 | dims['TTBEE'] == (5, 5, 3, 7, 7) 426 | ... 427 | ``` 428 | 429 | You can also assign dimension sizes dynamically as follows: 430 | 431 | ```python 432 | dims['XY'] = some_matrix.shape 433 | dims.Z = 13 434 | ``` 435 | 436 | For more examples, see [chex.Dimensions](https://chex.readthedocs.io/en/latest/api.html#chex.Dimensions) 437 | documentation. 438 | 439 | ## Citing Chex 440 | 441 | This repository is part of the [DeepMind JAX Ecosystem]; to cite Chex, please use 442 | the [DeepMind JAX Ecosystem citation].
443 | 444 | [DeepMind JAX Ecosystem]: https://deepmind.com/blog/article/using-jax-to-accelerate-our-research "DeepMind JAX Ecosystem" 445 | [DeepMind JAX Ecosystem citation]: https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt "Citation" 446 | -------------------------------------------------------------------------------- /chex/_src/asserts_internal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Chex assertion internal utilities and symbols. 16 | 17 | [README!] 18 | 19 | We reserve the right to change the code in this module at any time without 20 | providing any guarantees of backward compatibility. For this reason, 21 | we strongly recommend that you avoid using this module directly at all costs! 22 | Instead, consider opening an issue on GitHub and describing your use case. 
23 | """ 24 | 25 | import collections 26 | import collections.abc 27 | import functools 28 | import re 29 | import threading 30 | import traceback 31 | from typing import Any, Sequence, Union, Callable, Hashable, List, Optional, Set, Tuple, Type 32 | 33 | from absl import logging 34 | from chex._src import pytypes 35 | import jax 36 | from jax.experimental import checkify 37 | import jax.numpy as jnp 38 | import numpy as np 39 | 40 | # Custom pytypes. 41 | TLeaf = Any 42 | TLeavesEqCmpFn = Callable[[TLeaf, TLeaf], bool] 43 | TLeavesEqCmpErrorFn = Callable[[TLeaf, TLeaf], str] 44 | 45 | # TODO(iukemaev): define a typing protocol for TChexAssertion. 46 | # Chex assertion signature: 47 | # (*args, 48 | # custom_message: Optional[str] = None, 49 | # custom_message_format_vars: Sequence[Any] = (), 50 | # include_default_message: bool = True, 51 | # exception_type: Type[Exception] = AssertionError, 52 | # **kwargs) 53 | TChexAssertion = Callable[..., None] 54 | TAssertFn = Callable[..., None] 55 | TJittableAssertFn = Callable[..., pytypes.Array] # a predicate function 56 | 57 | # Matchers. 58 | TDimMatcher = Optional[Union[int, Set[int], type(Ellipsis)]] 59 | TShapeMatcher = Sequence[TDimMatcher] 60 | 61 | 62 | class _ChexifyStorage(threading.local): 63 | """Thread-safe storage for internal variables used in @chexify.""" 64 | wait_fns = [] 65 | level = 0 66 | 67 | 68 | # Chex namespace variables. 69 | ERR_PREFIX = "[Chex] " 70 | TRACE_COUNTER = collections.Counter() 71 | DISABLE_ASSERTIONS = False 72 | 73 | # This variable is used for _chexify_ transformations, see `asserts_chexify.py`. 
74 | CHEXIFY_STORAGE = _ChexifyStorage() 75 | 76 | 77 | def assert_collection_of_arrays(inputs: Sequence[pytypes.Array]): 78 | """Checks if ``inputs`` is a collection of arrays.""" 79 | if not isinstance(inputs, collections.abc.Collection): 80 | raise ValueError(f"`inputs` is not a collection of arrays: {inputs}.") 81 | 82 | 83 | def jnp_to_np_array(arr: pytypes.Array) -> np.ndarray: 84 | """Converts `jnp.ndarray` to `np.ndarray`.""" 85 | if getattr(arr, "dtype", None) == jnp.bfloat16: 86 | # Numpy does not support `bfloat16`. 87 | arr = arr.astype(jnp.float32) 88 | return jax.device_get(arr) 89 | 90 | 91 | def deprecation_wrapper(new_fn, old_name, new_name): 92 | """Allows deprecated functions to continue running, with a warning logged.""" 93 | 94 | def inner_fn(*args, **kwargs): 95 | logging.warning( 96 | "chex.%s has been renamed to chex.%s, please update your code.", 97 | old_name, new_name) 98 | return new_fn(*args, **kwargs) 99 | 100 | return inner_fn 101 | 102 | 103 | def get_stacktrace_without_chex_internals() -> List[traceback.FrameSummary]: 104 | """Returns the latest non-chex frame from the call stack.""" 105 | stacktrace = list(traceback.extract_stack()) 106 | for i in reversed(range(len(stacktrace))): 107 | fname = stacktrace[i].filename 108 | if fname.find("/chex/") == -1 or fname.endswith("_test.py"): 109 | return stacktrace[:i+1] 110 | 111 | debug_info = "\n-----\n".join(traceback.format_stack()) 112 | raise RuntimeError( 113 | "get_stacktrace_without_chex_internals() failed. " 114 | "Please file a bug at https://github.com/deepmind/chex/issues and " 115 | "include the following debug info in it. " 116 | "Please make sure it does not include any private information! " 117 | f"Debug: '{debug_info}'.") 118 | 119 | 120 | def get_err_regex(message: str) -> str: 121 | """Constructs a regexp for the exception message. 122 | 123 | Args: 124 | message: an exception message. 
125 | 126 | Returns: 127 | Regexp that ensures the message follows the standard chex formatting. 128 | """ 129 | # (ERR_PREFIX + any symbols (incl. \n) + message) 130 | return f"{re.escape(ERR_PREFIX)}[\\s\\S]*{message}" 131 | 132 | 133 | def get_chexify_err_message(name: str, msg: str = "") -> str: 134 | """Constructs an error message for the chexify exception.""" 135 | return f"{ERR_PREFIX}chexify assertion '{name}' failed: {msg}" 136 | 137 | 138 | def _make_host_assertion(assert_fn: TAssertFn, 139 | name: Optional[str] = None) -> TChexAssertion: 140 | """Constructs a host assertion given `assert_fn`. 141 | 142 | This wrapper should only be applied to the assertions that are either 143 | a) never used in jitted code, or 144 | b) when used in jitted code they do not check/access tensor values (i.e. 145 | they do not introduce value-dependent python control flow, see 146 | https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError). 147 | 148 | Args: 149 | assert_fn: A function implementing the check. 150 | name: A name for assertion. 151 | 152 | Returns: 153 | A chex assertion. 154 | """ 155 | if name is None: 156 | name = assert_fn.__name__ 157 | 158 | def _assert_on_host(*args, 159 | custom_message: Optional[str] = None, 160 | custom_message_format_vars: Sequence[Any] = (), 161 | include_default_message: bool = True, 162 | exception_type: Type[Exception] = AssertionError, 163 | **kwargs) -> None: 164 | # Format error's stack trace to remove Chex' internal frames. 165 | assertion_exc = None 166 | value_exc = None 167 | try: 168 | assert_fn(*args, **kwargs) 169 | except AssertionError as e: 170 | assertion_exc = e 171 | except ValueError as e: 172 | value_exc = e 173 | finally: 174 | if value_exc is not None: 175 | raise ValueError(str(value_exc)) 176 | 177 | if assertion_exc is not None: 178 | # Format the exception message. 179 | error_msg = str(assertion_exc) 180 | 181 | # Include only the name of the outermost chex assertion. 
182 | if error_msg.startswith(ERR_PREFIX): 183 | error_msg = error_msg[error_msg.find("failed:") + len("failed:"):] 184 | 185 | # Whether to include the default error message. 186 | default_msg = (f"Assertion {name} failed: " 187 | if include_default_message else "") 188 | error_msg = f"{ERR_PREFIX}{default_msg}{error_msg}" 189 | 190 | # Whether to include a custom error message. 191 | if custom_message: 192 | if custom_message_format_vars: 193 | custom_message = custom_message.format(*custom_message_format_vars) 194 | error_msg = f"{error_msg} [{custom_message}]" 195 | 196 | raise exception_type(error_msg) 197 | 198 | return _assert_on_host 199 | 200 | 201 | def chex_assertion( 202 | assert_fn: TAssertFn, 203 | jittable_assert_fn: Optional[TJittableAssertFn], 204 | name: Optional[str] = None) -> TChexAssertion: 205 | """Wraps Chex assert functions to control their common behaviour. 206 | 207 | Extends the assertion to support the following optional auxiliary kwargs: 208 | custom_message: A string to include into the emitted exception messages. 209 | custom_message_format_vars: A list of variables to pass as arguments to 210 | `custom_message.format()`. 211 | include_default_message: Whether to include the default Chex message into 212 | the emitted exception messages. 213 | exception_type: An exception type to use. `AssertionError` by default. 214 | 215 | Args: 216 | assert_fn: A function implementing the check. 217 | jittable_assert_fn: An optional jittable version of `assert_fn` implementing 218 | a predicate (returning `True` only if assertion passes). 219 | Required for value assertions. 220 | name: A name for assertion. If not provided, use `assert_fn.__name__`. 221 | 222 | Returns: 223 | A Chex assertion (with auxiliary kwargs). 
224 | """ 225 | if name is None: 226 | name = assert_fn.__name__ 227 | 228 | host_assertion_fn = _make_host_assertion(assert_fn, name) 229 | 230 | @functools.wraps(assert_fn) 231 | def _chex_assert_fn(*args, 232 | custom_message: Optional[str] = None, 233 | custom_message_format_vars: Sequence[Any] = (), 234 | include_default_message: bool = True, 235 | exception_type: Type[Exception] = AssertionError, 236 | **kwargs) -> None: 237 | if DISABLE_ASSERTIONS: 238 | return 239 | if (jittable_assert_fn is not None and has_tracers((args, kwargs))): 240 | if not CHEXIFY_STORAGE.level: 241 | raise RuntimeError( 242 | "Value assertions can only be called from functions wrapped " 243 | "with `@chex.chexify`. See the docs.") 244 | 245 | # A wrapped to inject auxiliary debug info and `custom_message`. 246 | original_check = checkify.check 247 | 248 | def _check(pred, msg, *fmt_args, **fmt_kwargs): 249 | # Add chex info. 250 | msg = get_chexify_err_message(name, msg) 251 | 252 | # Add a custom message. 253 | if custom_message: 254 | msg += f" Custom message: {custom_message}." 255 | fmt_args = list(fmt_args) + list(custom_message_format_vars) 256 | 257 | # Add a traceback and a pointer to the callsite. 258 | stacktrace = get_stacktrace_without_chex_internals() 259 | msg += ( 260 | f" [failed at: {stacktrace[-1].filename}:{stacktrace[-1].lineno}]" 261 | ) 262 | 263 | # Call original `checkify.check()`. 264 | original_check(pred, msg, *fmt_args, **fmt_kwargs) 265 | 266 | # Mock during the assertion's execution time. 267 | checkify.check = _check 268 | pred = jittable_assert_fn(*args, **kwargs) # execute the assertion 269 | checkify.check = original_check # return the original implementation 270 | 271 | # A safeguard to ensure that the results of a check are not ignored. 272 | # In particular, this check fails when `pred` is False and no 273 | # `checkify.check` calls took place in `jittable_assert_fn`, which would 274 | # be a bug in the assertion's implementation. 
275 | checkify.check(pred, "assertion failed!") 276 | else: 277 | try: 278 | host_assertion_fn( 279 | *args, 280 | custom_message=custom_message, 281 | custom_message_format_vars=custom_message_format_vars, 282 | include_default_message=include_default_message, 283 | exception_type=exception_type, 284 | **kwargs) 285 | except jax.errors.ConcretizationTypeError as exc: 286 | msg = ("Chex assertion detected `ConcretizationTypeError`: it is very " 287 | "likely that it tried to access tensors' values during tracing. " 288 | "Make sure that you defined a jittable version of this chex " 289 | "assertion; if that does not help, please file a bug.") 290 | raise exc from RuntimeError(msg) 291 | 292 | # Override name. 293 | setattr(_chex_assert_fn, "__name__", name) 294 | return _chex_assert_fn 295 | 296 | 297 | def format_tree_path(path: Sequence[Any]) -> str: 298 | return "/".join(str(p) for p in path) 299 | 300 | 301 | def format_shape_matcher(shape: TShapeMatcher) -> str: 302 | return f"({', '.join('...' 
if d is Ellipsis else str(d) for d in shape)})" 303 | 304 | 305 | def num_devices_available(devtype: str, backend: Optional[str] = None) -> int: 306 | """Returns the number of available device of the given type.""" 307 | devtype = devtype.lower() 308 | supported_types = ("cpu", "gpu", "tpu") 309 | if devtype not in supported_types: 310 | raise ValueError( 311 | f"Unknown device type '{devtype}' (expected one of {supported_types}).") 312 | 313 | return sum(d.platform == devtype for d in jax.devices(backend)) 314 | 315 | 316 | def get_tracers(tree: pytypes.ArrayTree) -> Tuple[jax.core.Tracer]: 317 | """Returns a tuple with tracers from a tree.""" 318 | return tuple( 319 | x for x in jax.tree_util.tree_leaves(tree) 320 | if isinstance(x, jax.core.Tracer)) 321 | 322 | 323 | def has_tracers(tree: pytypes.ArrayTree) -> bool: 324 | """Checks whether a tree contains any tracers.""" 325 | return any( 326 | isinstance(x, jax.core.Tracer) for x in jax.tree_util.tree_leaves(tree)) 327 | 328 | 329 | def is_traceable(fn) -> bool: 330 | """Checks if function is traceable. 331 | 332 | JAX traces a function when it is wrapped with @jit, @pmap, or @vmap. 333 | In other words, this function checks whether `fn` is wrapped with any of 334 | the aforementioned JAX transformations. 335 | 336 | Args: 337 | fn: function to assert. 338 | 339 | Returns: 340 | Bool indicating whether fn is traceable. 341 | """ 342 | 343 | fn_string_tokens = ( 344 | ".reraise_with_filtered_traceback", # JIT in Python ver. >= 3.7 345 | "CompiledFunction", # C++ JIT in jaxlib 0.1.66 or newer. 346 | "pmap.", # Python pmap 347 | "PmapFunction", # C++ pmap in jaxlib 0.1.72 or newer. 348 | "vmap.", # vmap 349 | "_python_pjit", 350 | "_cpp_pjit", 351 | ) 352 | 353 | fn_type_tokens = ( 354 | "PmapFunction", 355 | "PjitFunction", 356 | ) 357 | 358 | # Un-wrap `fn` and check if any internal fn is jitted by pattern matching. 
359 | fn_ = fn 360 | while True: 361 | if any(t in str(fn_) for t in fn_string_tokens): 362 | return True 363 | 364 | if any(t in str(type(fn_)) for t in fn_type_tokens): 365 | return True 366 | 367 | if hasattr(fn_, "__wrapped__"): 368 | # Wrapper. 369 | fn_globals = getattr(fn_, "__globals__", {}) 370 | 371 | if fn_globals.get("__name__", None) == "jax.api": 372 | # Wrapper from `jax.api`. 373 | return True 374 | 375 | if "api_boundary" in fn_globals: 376 | # api_boundary is a JAX wrapper for traced functions. 377 | return True 378 | 379 | try: 380 | if isinstance(fn_, jax.lib.xla_extension.PjitFunction): 381 | return True 382 | except AttributeError: 383 | pass 384 | else: 385 | break 386 | 387 | fn_ = fn_.__wrapped__ 388 | return False 389 | 390 | 391 | def assert_leaves_all_eq_comparator( 392 | equality_comparator: TLeavesEqCmpFn, 393 | error_msg_fn: Callable[[TLeaf, TLeaf, str, int, int], 394 | str], path: Sequence[Any], *leaves: Sequence[TLeaf]): 395 | """Asserts all leaves are equal using custom comparator. Not jittable.""" 396 | path_str = format_tree_path(path) 397 | for i in range(1, len(leaves)): 398 | if not equality_comparator(leaves[0], leaves[i]): 399 | raise AssertionError(error_msg_fn(leaves[0], leaves[i], path_str, 0, i)) 400 | 401 | 402 | def assert_trees_all_eq_comparator_jittable( 403 | equality_comparator: TLeavesEqCmpFn, 404 | error_msg_template: str, 405 | *trees: Sequence[pytypes.ArrayTree]) -> pytypes.Array: 406 | """Asserts all trees are equal using custom comparator. JIT-friendly.""" 407 | 408 | if len(trees) < 2: 409 | raise ValueError( 410 | "Assertions over only one tree does not make sense. 
Maybe you wrote " 411 | "`assert_trees_xxx([a, b])` instead of `assert_trees_xxx(a, b)`, or " 412 | "forgot the `error_msg_fn` arg to `assert_trees_xxx`?") 413 | 414 | def _tree_error_msg_fn( 415 | path: Tuple[Union[int, str, Hashable]], i_1: int, i_2: int): 416 | if path: 417 | return ( 418 | f"Trees {i_1} and {i_2} differ in leaves '{path}':" 419 | f" {error_msg_template}" 420 | ) 421 | else: 422 | return f"Trees (arrays) {i_1} and {i_2} differ: {error_msg_template}." 423 | 424 | def _cmp_leaves(path, *leaves): 425 | verdict = jnp.array(True) 426 | for i in range(1, len(leaves)): 427 | check_res = equality_comparator(leaves[0], leaves[i]) 428 | checkify.check( 429 | pred=check_res, 430 | msg=_tree_error_msg_fn(path, 0, i), 431 | arr_1=leaves[0], 432 | arr_2=leaves[i], 433 | ) 434 | verdict = jnp.logical_and(verdict, check_res) 435 | return verdict 436 | 437 | # Trees are guaranteed to have the same structure. 438 | paths = [ 439 | convert_jax_path_to_dm_path(path) 440 | for path, _ in jax.tree_util.tree_flatten_with_path(trees[0])[0]] 441 | trees_leaves = [jax.tree_util.tree_leaves(tree) for tree in trees] 442 | 443 | verdict = jnp.array(True) 444 | for leaf_i, path in enumerate(paths): 445 | verdict = jnp.logical_and( 446 | verdict, _cmp_leaves(path, *[leaves[leaf_i] for leaves in trees_leaves]) 447 | ) 448 | 449 | return verdict 450 | 451 | 452 | JaxKeyType = Union[ 453 | int, 454 | str, 455 | Hashable, 456 | jax.tree_util.SequenceKey, 457 | jax.tree_util.DictKey, 458 | jax.tree_util.FlattenedIndexKey, 459 | jax.tree_util.GetAttrKey, 460 | ] 461 | 462 | 463 | def convert_jax_path_to_dm_path( 464 | jax_tree_path: Sequence[JaxKeyType], 465 | ) -> Tuple[Union[int, str, Hashable]]: 466 | """Converts a path from jax.tree_util to one from dm-tree.""" 467 | 468 | # pytype:disable=attribute-error 469 | def _convert_key_fn(key: JaxKeyType) -> Union[int, str, Hashable]: 470 | if isinstance(key, (str, int)): 471 | return key # int | str. 
472 | if isinstance(key, jax.tree_util.SequenceKey): 473 | return key.idx # int. 474 | if isinstance(key, jax.tree_util.DictKey): 475 | return key.key # Hashable 476 | if isinstance(key, jax.tree_util.FlattenedIndexKey): 477 | return key.key # int. 478 | if isinstance(key, jax.tree_util.GetAttrKey): 479 | return key.name # str. 480 | raise ValueError(f"Jax tree key '{key}' of type '{type(key)}' not valid.") 481 | # pytype:enable=attribute-error 482 | 483 | return tuple(_convert_key_fn(key) for key in jax_tree_path) 484 | -------------------------------------------------------------------------------- /chex/_src/variants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Chex variants utilities.""" 16 | 17 | import enum 18 | import functools 19 | import inspect 20 | import itertools 21 | from typing import Any, Sequence 22 | import unittest 23 | 24 | from absl import flags 25 | from absl.testing import parameterized 26 | from chex._src import fake 27 | from chex._src import pytypes 28 | import jax 29 | from jax import tree_util 30 | import jax.numpy as jnp 31 | import toolz 32 | 33 | FLAGS = flags.FLAGS 34 | flags.DEFINE_bool( 35 | "chex_skip_pmap_variant_if_single_device", True, 36 | "Whether to skip pmap variant if only one device is available.") 37 | 38 | 39 | # We choose to subclass instead of a simple alias, as Python doesn't allow 40 | # multiple inheritance from the same class, and users may want to subclass their 41 | # tests from both `chex.TestCase` and `parameterized.TestCase`. 42 | # 43 | # User is free to use any base class that supports generators unrolling 44 | # instead of `variants.TestCase` or `parameterized.TestCase`. If a base class 45 | # doesn't support this feature variant test fails with a corresponding error. 46 | class TestCase(parameterized.TestCase): 47 | """A class for Chex tests that use variants. 48 | 49 | See the docstring for ``chex.variants`` for more information. 50 | 51 | Note: ``chex.variants`` returns a generator producing one test per variant. 52 | Therefore, the used test class must support dynamic unrolling of these 53 | generators during module import. It is implemented (and battle-tested) in 54 | ``absl.parameterized.TestCase``, and here we subclass from it. 55 | """ 56 | 57 | def variant(self, *args, **kwargs): 58 | """Raises a RuntimeError if not overriden or redefined.""" 59 | raise RuntimeError( 60 | "self.variant is not defined: forgot to wrap a test in @chex.variants?") 61 | 62 | 63 | class ChexVariantType(enum.Enum): 64 | """An enumeration of available Chex variants. 
65 | 66 | Use ``self.variant.type`` to get type of the current test variant. 67 | See the docstring of ``chex.variants`` for more information. 68 | """ 69 | 70 | WITH_JIT = 1 71 | WITHOUT_JIT = 2 72 | WITH_DEVICE = 3 73 | WITHOUT_DEVICE = 4 74 | WITH_PMAP = 5 75 | 76 | def __str__(self) -> str: 77 | return "_" + self.name.lower() 78 | 79 | 80 | tree_map = tree_util.tree_map 81 | 82 | 83 | def params_product(*params_lists: Sequence[Sequence[Any]], 84 | named: bool = False) -> Sequence[Sequence[Any]]: 85 | """Generates a cartesian product of `params_lists`. 86 | 87 | See tests from ``variants_test.py`` for examples of usage. 88 | 89 | Args: 90 | *params_lists: A list of params combinations. 91 | named: Whether to generate test names (for 92 | `absl.parameterized.named_parameters(...)`). 93 | 94 | Returns: 95 | A cartesian product of `params_lists` combinations. 96 | """ 97 | 98 | def generate(): 99 | for combination in itertools.product(*params_lists): 100 | if named: 101 | name = "_".join(t[0] for t in combination) 102 | args_tuples = (t[1:] for t in combination) 103 | args = sum(args_tuples, ()) 104 | yield (name, *args) 105 | else: 106 | yield sum(combination, ()) 107 | 108 | return list(generate()) 109 | 110 | 111 | def count_num_calls(fn): 112 | """Counts the number of times the function was called.""" 113 | num_calls = 0 114 | 115 | @functools.wraps(fn) 116 | def fn_wrapped(*args, **kwargs): 117 | nonlocal num_calls 118 | num_calls += 1 119 | return fn(*args, **kwargs) 120 | 121 | return fn_wrapped, lambda: num_calls 122 | 123 | 124 | class VariantsTestCaseGenerator: 125 | """TestCase generator for chex variants. Supports sharding.""" 126 | 127 | def __init__(self, test_object, which_variants): 128 | self._which_variants = which_variants 129 | self._generated_names_freq = {} 130 | if hasattr(test_object, "__iter__"): 131 | # `test_object` is a generator (e.g. parameterised test). 
132 | self._test_methods = list(test_object) 133 | else: 134 | # `test_object` is a single test method. 135 | self._test_methods = [test_object] 136 | 137 | def add_variants(self, which_variants): 138 | """Merge variants.""" 139 | for var, incl in which_variants.items(): 140 | self._which_variants[var] = self._which_variants.get(var, False) or incl 141 | 142 | @property 143 | def __name__(self): 144 | msg = ("A test wrapper attempts to access __name__ of " 145 | "VariantsTestCaseGenerator. Usually, this happens when " 146 | "@parameterized wraps @variants.variants. Make sure that the " 147 | "@variants.variants wrapper is an outer one, i.e. nothing wraps it.") 148 | raise RuntimeError(msg) 149 | 150 | def __call__(self): 151 | msg = ("A test wrapper attempts to invoke __call__ of " 152 | "VariantsTestCaseGenerator: make sure that all `TestCase` instances " 153 | "that use variants inherit from `chex.TestCase`.") 154 | raise RuntimeError(msg) 155 | 156 | def _set_test_name(self, test_method, variant): 157 | """Set a name for the generated test.""" 158 | name = getattr(test_method, "__name__", "") 159 | params_repr = getattr(test_method, "__x_params_repr__", "") 160 | chex_suffix = f"{variant}" 161 | 162 | candidate_name = "_".join(filter(None, [name, params_repr, chex_suffix])) 163 | name_freq = self._generated_names_freq.get(candidate_name, 0) 164 | if name_freq: 165 | # Ensure that test names are unique. 166 | new_name = name + "_" + str(name_freq) 167 | unique_name = "_".join(filter(None, [new_name, params_repr, chex_suffix])) 168 | else: 169 | unique_name = candidate_name 170 | self._generated_names_freq[candidate_name] = name_freq + 1 171 | 172 | # Always use name for compatibility with `absl.testing.parameterized`. 
173 | setattr(test_method, "__name__", unique_name) 174 | setattr(test_method, "__x_params_repr__", "") 175 | setattr(test_method, "__x_use_name__", True) 176 | return test_method 177 | 178 | def _inner_iter(self, test_method): 179 | """Generate chex variants for a single test.""" 180 | 181 | def make_test(variant: ChexVariantType): 182 | 183 | @functools.wraps(test_method) 184 | def test(self, *args, **kwargs): 185 | # Skip pmap variant if only one device is available. 186 | 187 | if (variant is ChexVariantType.WITH_PMAP and 188 | FLAGS["chex_skip_pmap_variant_if_single_device"].value and 189 | jax.device_count() < 2): 190 | raise unittest.SkipTest( 191 | f"Only 1 device is available ({jax.devices()}).") 192 | 193 | # n_cpu_devices assert. 194 | if FLAGS["chex_assert_multiple_cpu_devices"].value: 195 | required_n_cpus = fake.get_n_cpu_devices_from_xla_flags() 196 | if required_n_cpus < 2: 197 | raise RuntimeError( 198 | f"Required number of CPU devices is {required_n_cpus} < 2." 199 | "Consider setting up your test module to use multiple CPU " 200 | " devices (see README.md) or disabling " 201 | "`chex_assert_multiple_cpu_devices` flag.") 202 | available_n_cpus = jax.device_count("cpu") 203 | if required_n_cpus != available_n_cpus: 204 | raise RuntimeError( 205 | "Number of available CPU devices is not equal to the required: " 206 | f"{available_n_cpus} != {required_n_cpus}") 207 | 208 | # Set up the variant. 209 | self.variant, num_calls = count_num_calls(_variant_decorators[variant]) 210 | self.variant.type = variant 211 | res = test_method(self, *args, **kwargs) 212 | if num_calls() == 0: 213 | raise RuntimeError( 214 | "Test is wrapped in @chex.variants, but never calls self.variant." 215 | " Consider debugging the test or removing @chex.variants wrapper." 
216 | f" (variant: {variant})") 217 | return res 218 | 219 | self._set_test_name(test, variant) 220 | return test 221 | 222 | selected_variants = [ 223 | var_name for var_name, is_included in self._which_variants.items() 224 | if is_included 225 | ] 226 | if not selected_variants: 227 | raise ValueError(f"No variants selected for test: {test_method}.") 228 | 229 | return (make_test(var_name) for var_name in selected_variants) 230 | 231 | def __iter__(self): 232 | """Generate chex variants for each test case.""" 233 | return itertools.chain(*(self._inner_iter(m) for m in self._test_methods)) 234 | 235 | 236 | @toolz.curry 237 | def _variants_fn(test_object, **which_variants) -> VariantsTestCaseGenerator: 238 | """Implements `variants` and `all_variants`.""" 239 | 240 | # Convert keys to enum entries. 241 | which_variants = { 242 | ChexVariantType[name.upper()]: var 243 | for name, var in which_variants.items() 244 | } 245 | if isinstance(test_object, VariantsTestCaseGenerator): 246 | # Merge variants for nested wrappers. 247 | test_object.add_variants(which_variants) 248 | else: 249 | test_object = VariantsTestCaseGenerator(test_object, which_variants) 250 | 251 | return test_object 252 | 253 | 254 | @toolz.curry 255 | # pylint: disable=redefined-outer-name 256 | def variants(test_method, 257 | with_jit: bool = False, 258 | without_jit: bool = False, 259 | with_device: bool = False, 260 | without_device: bool = False, 261 | with_pmap: bool = False) -> VariantsTestCaseGenerator: 262 | # pylint: enable=redefined-outer-name 263 | """Decorates a test to expose Chex variants. 264 | 265 | The decorated test has access to a decorator called ``self.variant``, which 266 | may be applied to functions to test different JAX behaviors. Consider: 267 | 268 | .. 
code-block:: python 269 | 270 | @chex.variants(with_jit=True, without_jit=True) 271 | def test(self): 272 | @self.variant 273 | def f(x, y): 274 | return x + y 275 | 276 | self.assertEqual(f(1, 2), 3) 277 | 278 | In this example, the function ``test`` will be called twice: once with `f` 279 | jitted (i.e. using `jax.jit`) and another where `f` is not jitted. 280 | 281 | Variants `with_jit=True` and `with_pmap=True` accept additional specific to 282 | them arguments. Example: 283 | 284 | .. code-block:: python 285 | 286 | @chex.variants(with_jit=True) 287 | def test(self): 288 | @self.variant(static_argnums=(1,)) 289 | def f(x, y): 290 | # `y` is not traced. 291 | return x + y 292 | 293 | self.assertEqual(f(1, 2), 3) 294 | 295 | Variant `with_pmap=True` also accepts `broadcast_args_to_devices` 296 | (whether to broadcast each input argument to all participating devices), 297 | `reduce_fn` (a function to apply to results of pmapped `fn`), and 298 | `n_devices` (number of devices to use in the `pmap` computation). 299 | See the docstring of `_with_pmap` for more details (including default values). 300 | 301 | If used with ``absl.testing.parameterized``, `@chex.variants` must wrap it: 302 | 303 | .. code-block:: python 304 | 305 | @chex.variants(with_jit=True, without_jit=True) 306 | @parameterized.named_parameters('test', *args) 307 | def test(self, *args): 308 | ... 309 | 310 | Tests that use this wrapper must be inherited from ``parameterized.TestCase``. 311 | For more examples see ``variants_test.py``. 312 | 313 | Args: 314 | test_method: A test method to decorate. 315 | with_jit: Whether to test with `jax.jit`. 316 | without_jit: Whether to test without `jax.jit`. Any jit compilation done 317 | within the test method will not be affected. 318 | with_device: Whether to test with args placed on device, using 319 | `jax.device_put`. 320 | without_device: Whether to test with args (explicitly) not placed on device, 321 | using `jax.device_get`. 
322 | with_pmap: Whether to test with `jax.pmap`, with computation duplicated 323 | across devices. 324 | 325 | Returns: 326 | A decorated ``test_method``. 327 | """ 328 | return _variants_fn( 329 | test_method, 330 | with_jit=with_jit, 331 | without_jit=without_jit, 332 | with_device=with_device, 333 | without_device=without_device, 334 | with_pmap=with_pmap) 335 | 336 | 337 | @toolz.curry 338 | # pylint: disable=redefined-outer-name 339 | def all_variants(test_method, 340 | with_jit: bool = True, 341 | without_jit: bool = True, 342 | with_device: bool = True, 343 | without_device: bool = True, 344 | with_pmap: bool = True) -> VariantsTestCaseGenerator: 345 | # pylint: enable=redefined-outer-name 346 | """Equivalent to ``chex.variants`` but with flipped defaults.""" 347 | return _variants_fn( 348 | test_method, 349 | with_jit=with_jit, 350 | without_jit=without_jit, 351 | with_device=with_device, 352 | without_device=without_device, 353 | with_pmap=with_pmap) 354 | 355 | 356 | def check_variant_arguments(variant_fn): 357 | """Raises `ValueError` if `variant_fn` got an unknown argument.""" 358 | 359 | @functools.wraps(variant_fn) 360 | def wrapper(*args, **kwargs): 361 | unknown_args = set(kwargs.keys()) - _valid_kwargs_keys 362 | if unknown_args: 363 | raise ValueError(f"Unknown arguments in `self.variant`: {unknown_args}.") 364 | return variant_fn(*args, **kwargs) 365 | 366 | return wrapper 367 | 368 | 369 | @toolz.curry 370 | @check_variant_arguments 371 | def _with_jit(fn, 372 | static_argnums=None, 373 | static_argnames=None, 374 | device=None, 375 | backend=None, 376 | **unused_kwargs): 377 | """Variant that applies `jax.jit` to fn.""" 378 | 379 | return jax.jit( 380 | fn, 381 | static_argnums=static_argnums, 382 | static_argnames=static_argnames, 383 | device=device, 384 | backend=backend) 385 | 386 | 387 | @toolz.curry 388 | @check_variant_arguments 389 | def _without_jit(fn, **unused_kwargs): 390 | """Variant that does not apply `jax.jit` to a fn 
(identity).""" 391 | 392 | @functools.wraps(fn) 393 | def wrapper(*args, **kwargs): 394 | return fn(*args, **kwargs) 395 | 396 | return wrapper 397 | 398 | 399 | @toolz.curry 400 | @check_variant_arguments 401 | def _with_device(fn, ignore_argnums=(), static_argnums=(), **unused_kwargs): 402 | """Variant that applies `jax.device_put` to the args of fn.""" 403 | 404 | if isinstance(ignore_argnums, int): 405 | ignore_argnums = (ignore_argnums,) 406 | if isinstance(static_argnums, int): 407 | static_argnums = (static_argnums,) 408 | 409 | @functools.wraps(fn) 410 | def wrapper(*args, **kwargs): 411 | 412 | def put(x): 413 | try: 414 | return jax.device_put(x) 415 | except TypeError: # not a valid JAX type 416 | return x 417 | 418 | device_args = [ 419 | arg if (idx in ignore_argnums or idx in static_argnums) else tree_map( 420 | put, arg) for idx, arg in enumerate(args) 421 | ] 422 | device_kwargs = tree_map(put, kwargs) 423 | return fn(*device_args, **device_kwargs) 424 | 425 | return wrapper 426 | 427 | 428 | @toolz.curry 429 | @check_variant_arguments 430 | def _without_device(fn, **unused_kwargs): 431 | """Variant that applies `jax.device_get` to the args of fn.""" 432 | 433 | @functools.wraps(fn) 434 | def wrapper(*args, **kwargs): 435 | 436 | def get(x): 437 | if isinstance(x, jax.Array): 438 | return jax.device_get(x) 439 | return x 440 | 441 | no_device_args = tree_map(get, args) 442 | no_device_kwargs = tree_map(get, kwargs) 443 | return fn(*no_device_args, **no_device_kwargs) 444 | 445 | return wrapper 446 | 447 | 448 | @toolz.curry 449 | @check_variant_arguments 450 | def _with_pmap(fn, 451 | broadcast_args_to_devices=True, 452 | reduce_fn="first_device_output", 453 | n_devices=None, 454 | axis_name="i", 455 | devices=None, 456 | in_axes=0, 457 | static_broadcasted_argnums=(), 458 | static_argnums=(), 459 | backend=None, 460 | **unused_kwargs): 461 | """Variant that applies `jax.pmap` to fn. 462 | 463 | Args: 464 | fn: A function to wrap. 
465 | broadcast_args_to_devices: Whether to broadcast `fn` args to pmap format 466 | (i.e. pmapped axes' sizes == a number of devices). 467 | reduce_fn: A function to apply to outputs of `fn`. 468 | n_devices: A number of devices to use (can specify a `backend` if required). 469 | axis_name: An argument for `pmap`. 470 | devices: An argument for `pmap`. 471 | in_axes: An argument for `pmap`. 472 | static_broadcasted_argnums: An argument for `pmap`. 473 | static_argnums: An alias of ``static_broadcasted_argnums``. 474 | backend: An argument for `pmap`. 475 | **unused_kwargs: Unused kwargs (e.g. related to other variants). 476 | 477 | Returns: 478 | Wrapped `fn` that accepts `args` and `kwargs` and returns a superposition of 479 | `reduce_fn` and `fn` applied to them. 480 | 481 | Raises: 482 | ValueError: If `broadcast_args_to_devices` used with `in_axes` or 483 | `static_broadcasted_argnums`; if number of available devices is less than 484 | required; if pmappable arg axes' sizes are not equal to the number of 485 | devices. 486 | SkipTest: If the flag ``chex_skip_pmap_variant_if_single_device`` is set and 487 | there is only one device available. 488 | """ 489 | if (FLAGS["chex_skip_pmap_variant_if_single_device"].value and 490 | jax.device_count() < 2): 491 | raise unittest.SkipTest(f"Only 1 device is available ({jax.devices()}).") 492 | 493 | if broadcast_args_to_devices and in_axes != 0: 494 | raise ValueError( 495 | "Do not use `broadcast_args_to_devices` when specifying `in_axes`.") 496 | 497 | # Set up a reduce function. 498 | if reduce_fn == "first_device_output": 499 | reduce_fn = lambda t: tree_map(lambda x: x[0], t) 500 | elif reduce_fn == "identity" or reduce_fn is None: # Identity. 
501 | reduce_fn = lambda t: t 502 | 503 | if not static_argnums and static_argnums != 0: 504 | static_argnums = static_broadcasted_argnums 505 | if isinstance(static_argnums, int): 506 | static_argnums = (static_argnums,) 507 | 508 | pmap_kwargs = dict( 509 | axis_name=axis_name, 510 | devices=devices, 511 | in_axes=in_axes, 512 | static_broadcasted_argnums=static_argnums, 513 | backend=backend) 514 | pmapped_fn = jax.pmap(fn, **pmap_kwargs) 515 | 516 | @functools.wraps(pmapped_fn) 517 | def wrapper(*args: pytypes.ArrayTree, **kwargs: pytypes.ArrayTree): 518 | if kwargs and (in_axes != 0 or static_argnums): 519 | raise ValueError("Do not use kwargs with `in_axes` or `static_argnums` " 520 | "in pmapped function.") 521 | devices_ = list(devices or jax.devices(backend)) 522 | n_devices_ = n_devices or len(devices_) 523 | devices_ = devices_[:n_devices_] 524 | if len(devices_) != n_devices_: 525 | raise ValueError("Number of available devices is less than required for " 526 | f"test ({len(devices_)} < {n_devices_})") 527 | 528 | bcast_fn = lambda x: jnp.broadcast_to(x, (n_devices_,) + jnp.array(x).shape) 529 | if broadcast_args_to_devices: 530 | args = [ 531 | tree_map(bcast_fn, arg) if idx not in static_argnums else arg 532 | for idx, arg in enumerate(args) 533 | ] 534 | kwargs = tree_map(bcast_fn, kwargs) 535 | else: 536 | # Pmappable axes size must be equal to number of devices. 
537 | in_axes_ = in_axes if isinstance(in_axes, 538 | (tuple, list)) else [in_axes] * len(args) 539 | is_pmappable_arg = [ 540 | idx not in static_argnums and in_axes_[idx] is not None 541 | for idx in range(len(args)) 542 | ] 543 | for is_pmappable_arg, arg in zip(is_pmappable_arg, args): 544 | if not is_pmappable_arg: 545 | continue 546 | if not all( 547 | x.shape[0] == n_devices_ for x in jax.tree_util.tree_leaves(arg)): 548 | shapes = tree_map(jnp.shape, arg) 549 | raise ValueError( 550 | f"Pmappable arg axes size must be equal to number of devices, " 551 | f"got: {shapes} (expected the first dim to be {n_devices_}). " 552 | "Consider setting `broadcast_args_to_devices=True`.") 553 | 554 | new_kwargs = dict( 555 | axis_name=axis_name, 556 | devices=devices_, 557 | in_axes=in_axes, 558 | static_broadcasted_argnums=static_argnums, 559 | backend=backend) 560 | 561 | # Re-compile fn if kwargs changed. 562 | nonlocal pmap_kwargs 563 | nonlocal pmapped_fn 564 | if new_kwargs != pmap_kwargs: 565 | pmap_kwargs = new_kwargs 566 | pmapped_fn = jax.pmap(fn, **pmap_kwargs) 567 | 568 | res = pmapped_fn(*args, **kwargs) 569 | return reduce_fn(res) 570 | 571 | return wrapper 572 | 573 | 574 | _variant_decorators = dict({ 575 | ChexVariantType.WITH_JIT: _with_jit, 576 | ChexVariantType.WITHOUT_JIT: _without_jit, 577 | ChexVariantType.WITH_DEVICE: _with_device, 578 | ChexVariantType.WITHOUT_DEVICE: _without_device, 579 | ChexVariantType.WITH_PMAP: _with_pmap, 580 | }) 581 | 582 | 583 | class Variant: 584 | """Variant class for typing and string representation.""" 585 | 586 | def __init__(self, name, fn): 587 | self._fn = fn 588 | self._name = name 589 | 590 | def __repr__(self): 591 | return self._name 592 | 593 | def __call__(self, *args, **kwargs): 594 | # Could apply decorators (currying, arg-checking) here 595 | return self._fn(*args, **kwargs) 596 | 597 | 598 | # Expose variant objects. 
599 | without_device = Variant("chex_without_device", _without_device) 600 | without_jit = Variant("chex_without_jit", _without_jit) 601 | with_device = Variant("chex_with_device", _with_device) 602 | with_jit = Variant("chex_with_jit", _with_jit) 603 | with_pmap = Variant("chex_with_pmap", _with_pmap) 604 | ALL_VARIANTS = (without_device, without_jit, with_device, with_jit, with_pmap) 605 | 606 | # Collect valid argument names from all variant decorators. 607 | _valid_kwargs_keys = set() 608 | for fn_ in _variant_decorators.values(): 609 | original_fn = fn_.func.__wrapped__ 610 | _valid_kwargs_keys.update(inspect.getfullargspec(original_fn).args) 611 | --------------------------------------------------------------------------------