├── chex
├── py.typed
├── _src
│ ├── __init__.py
│ ├── warnings_test.py
│ ├── fake_set_n_cpu_devices_test.py
│ ├── pytypes.py
│ ├── warnings.py
│ ├── restrict_backends.py
│ ├── restrict_backends_test.py
│ ├── asserts_internal_test.py
│ ├── dimensions_test.py
│ ├── dimensions.py
│ ├── asserts_chexify.py
│ ├── dataclass.py
│ ├── fake.py
│ ├── fake_test.py
│ ├── asserts_internal.py
│ └── variants.py
├── chex_test.py
└── __init__.py
├── requirements
├── requirements-test.txt
├── requirements-docs.txt
└── requirements.txt
├── MANIFEST.in
├── .gitignore
├── .readthedocs.yaml
├── docs
├── Makefile
├── index.rst
├── ext
│ └── coverage_check.py
├── api.rst
└── conf.py
├── .github
└── workflows
│ ├── ci.yml
│ └── pypi-publish.yml
├── CONTRIBUTING.md
├── setup.py
├── test.sh
├── LICENSE
└── README.md
/chex/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements/requirements-test.txt:
--------------------------------------------------------------------------------
1 | cloudpickle==2.2.0
2 | dm-tree>=0.1.5
3 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include LICENSE
3 | include requirements/*
4 | include chex/py.typed
5 |
--------------------------------------------------------------------------------
/requirements/requirements-docs.txt:
--------------------------------------------------------------------------------
1 | sphinx>=6.0.0
2 | sphinx-book-theme>=1.0.1
3 | sphinxcontrib-katex
4 |
--------------------------------------------------------------------------------
/requirements/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py>=0.9.0
2 | typing_extensions>=4.2.0
3 | jax>=0.4.16
4 | jaxlib>=0.1.37
5 | numpy>=1.24.1
6 | setuptools;python_version>="3.12"
7 | toolz>=0.9.0
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Building and releasing library:
2 | *.egg-info
3 | *.pyc
4 | *.so
5 | build/
6 | dist/
7 | venv/
8 | docs/_build/
9 |
10 | # Mac OS
11 | .DS_Store
12 |
13 | # Python tools
14 | .mypy_cache/
15 | .pytype/
16 | .ipynb_checkpoints
17 |
18 | # Editors
19 | .idea
20 | .vscode
21 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | version: 2
5 |
6 | build:
7 | os: ubuntu-22.04
8 | tools:
9 | python: "3.11"
10 |
11 | sphinx:
12 | builder: html
13 | configuration: docs/conf.py
14 | fail_on_warning: false
15 |
16 | python:
17 | install:
18 | - requirements: requirements/requirements-docs.txt
19 | - requirements: requirements/requirements.txt
20 | - method: setuptools
21 | path: .
22 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS = -W --keep-going
6 | SPHINXBUILD = sphinx-build
7 | SOURCEDIR = .
8 | BUILDDIR = _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
20 |
--------------------------------------------------------------------------------
/chex/_src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: ci
2 |
3 | on:
4 | push:
5 | branches: ["master"]
6 | pull_request:
7 | branches: ["master"]
8 | schedule:
9 | - cron: '30 2 * * *'
10 |
11 | jobs:
12 | build-and-test:
13 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
14 | runs-on: "${{ matrix.os }}"
15 |
16 | strategy:
17 | matrix:
18 | python-version: ["3.9", "3.10", "3.11"]
19 | os: [ubuntu-latest]
20 |
21 | steps:
22 | - uses: "actions/checkout@v2"
23 | - uses: "actions/setup-python@v4"
24 | with:
25 | python-version: "${{ matrix.python-version }}"
26 | cache: "pip"
27 | cache-dependency-path: '**/requirements*.txt'
28 | - name: Run CI tests
29 | run: bash test.sh
30 | shell: bash
31 |
--------------------------------------------------------------------------------
/.github/workflows/pypi-publish.yml:
--------------------------------------------------------------------------------
1 | name: pypi
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v4
12 | - name: Set up Python
13 | uses: actions/setup-python@v4
14 | with:
15 | python-version: '3.x'
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install --upgrade pip
19 | pip install setuptools wheel twine
20 | - name: Check consistency between the package version and release tag
21 | run: |
22 | RELEASE_VER=${GITHUB_REF#refs/*/}
23 | PACKAGE_VER="v`python setup.py --version`"
24 | if [ $RELEASE_VER != $PACKAGE_VER ]
25 | then
26 | echo "package ver. ($PACKAGE_VER) != release ver. ($RELEASE_VER)"; exit 1
27 | fi
28 | - name: Build and publish
29 | env:
30 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
31 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
32 | run: |
33 | python setup.py sdist bdist_wheel
34 | twine upload dist/*
35 |
--------------------------------------------------------------------------------
/chex/chex_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for chex."""
16 |
17 | from absl.testing import absltest
18 | import chex
19 |
20 |
class ChexTest(absltest.TestCase):
  """Smoke test: the top-level `chex` package imports and exposes its API."""

  def test_import(self):
    # Any public symbol would do; this one is re-exported from the internal
    # asserts module.
    exposes_symbol = hasattr(chex, 'assert_devices_available')
    self.assertTrue(exposes_symbol)


if __name__ == '__main__':
  absltest.main()
30 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Testing
26 |
27 | Please make sure that your PR passes all tests by running `bash test.sh` on your
28 | local machine. Also, you can run only tests that are affected by your code
29 | changes, but you will need to select them manually.
30 |
31 | ## Community Guidelines
32 |
33 | This project follows [Google's Open Source Community
34 | Guidelines](https://opensource.google.com/conduct/).
35 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/deepmind/chex/tree/master/docs
2 |
3 | Chex
4 | -----
5 |
6 | Chex is a library of utilities for helping to write reliable JAX code.
7 |
8 | This includes utils to help:
9 |
10 | * Instrument your code (e.g. assertions)
11 | * Debug (e.g. transforming pmaps into vmaps within a context manager).
12 | * Test JAX code across many variants (e.g. jitted vs non-jitted).
13 |
14 | Modules overview can be found `on GitHub <https://github.com/deepmind/chex#modules-overview>`_.
15 |
16 | Installation
17 | ------------
18 |
19 | Chex can be installed with pip directly from GitHub, with the following command:
20 |
21 | ``pip install git+https://github.com/deepmind/chex.git``
22 |
23 | or from PyPI:
24 |
25 | ``pip install chex``
26 |
27 | .. toctree::
28 | :caption: API Documentation
29 | :maxdepth: 2
30 |
31 | api
32 |
33 | Citing Chex
34 | -----------
35 |
36 | This repository is part of the `DeepMind JAX Ecosystem <https://deepmind.com/blog/article/using-jax-to-accelerate-our-research>`_.
37 |
38 | To cite Chex please use the `DeepMind JAX Ecosystem citation <https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt>`_.
39 |
40 | Contribute
41 | ----------
42 |
43 | - `Issue tracker <https://github.com/deepmind/chex/issues>`_
44 | - `Source code <https://github.com/deepmind/chex/tree/master>`_
45 |
46 | Support
47 | -------
48 |
49 | If you are having issues, please let us know by filing an issue on our
50 | `issue tracker <https://github.com/deepmind/chex/issues>`_.
51 |
52 | License
53 | -------
54 |
55 | Chex is licensed under the Apache 2.0 License.
56 |
57 |
58 | Indices and Tables
59 | ==================
60 |
61 | * :ref:`genindex`
62 |
--------------------------------------------------------------------------------
/chex/_src/warnings_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for `warnings.py`."""
16 |
17 | import functools
18 |
19 | from absl.testing import absltest
20 |
21 | from chex._src import warnings
22 |
23 |
@functools.partial(warnings.warn_only_n_pos_args_in_future, n=1)
def f(a, b, c):
  """Fixture: warns when called with more than one positional argument."""
  return a + b + c


@functools.partial(warnings.warn_deprecated_function, replacement='h')
def g(a, b, c):
  """Fixture: a deprecated function whose suggested replacement is `h`."""
  return a + b + c


def h1(a, b, c):
  """Fixture: the current (non-deprecated) implementation."""
  return a + b + c


# `h2` is the deprecated alias of `h1`. `create_deprecated_function_alias`
# takes the *new* name first and the *deprecated* alias second, so calling
# `h2` reports 'path.h2' as deprecated in favour of 'path.h1'.
h2 = warnings.create_deprecated_function_alias(
    h1, new_name='path.h1', deprecated_alias='path.h2')
37 |
38 |
class WarningsTest(absltest.TestCase):
  """Checks that each deprecation helper emits a DeprecationWarning."""

  def test_warn_only_n_pos_args_in_future(self):
    # `f` allows a single positional argument; both calls pass more.
    with self.assertWarns(DeprecationWarning):
      self.assertEqual(f(1, 2, 3), 6)
    with self.assertWarns(DeprecationWarning):
      self.assertEqual(f(1, 2, c=3), 6)

  def test_warn_only_n_pos_args_in_future_within_limit(self):
    # The stdlib `warnings` module is shadowed by `chex._src.warnings` here.
    import warnings as std_warnings  # pylint: disable=import-outside-toplevel
    with std_warnings.catch_warnings():
      # Promote warnings to errors so an unexpected warning fails the test.
      std_warnings.simplefilter('error')
      self.assertEqual(f(1, b=2, c=3), 6)

  def test_warn_deprecated_function(self):
    with self.assertWarns(DeprecationWarning):
      self.assertEqual(g(1, 2, 3), 6)

  def test_create_deprecated_function_alias(self):
    with self.assertWarns(DeprecationWarning):
      self.assertEqual(h2(1, 2, 3), 6)


if __name__ == '__main__':
  absltest.main()
58 |
--------------------------------------------------------------------------------
/chex/_src/fake_set_n_cpu_devices_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Test for `set_n_cpu_devices` from `fake.py`.
16 |
17 | This test is isolated to ensure hermeticity because its execution changes
18 | XLA backend configuration.
19 | """
20 |
21 | import unittest
22 | from absl.testing import absltest
23 | from chex._src import asserts
24 | from chex._src import fake
25 |
26 |
class DevicesSetterTest(absltest.TestCase):
  """Exercises `fake.set_n_cpu_devices` before and after backend init."""

  def test_set_n_cpu_devices(self):
    # The first call only succeeds while no JAX backend has been initialized;
    # if another test already initialized one, this test cannot run here.
    try:
      fake.set_n_cpu_devices(4)
    except RuntimeError as err:
      skip_reason = ("set_n_cpu_devices: backend's already been initialized. "
                     'Run this test in isolation from others.')
      raise unittest.SkipTest(skip_reason) from err

    # Backends are still uninitialized, so re-configuring is allowed.
    fake.set_n_cpu_devices(6)

    # Checking the available devices initializes the backends...
    asserts.assert_devices_available(6, 'cpu', backend='cpu')

    # ...after which changing the device count must be rejected.
    expected_error = 'Attempted to set 8 devices, but 6 CPUs.+'
    with self.assertRaisesRegex(RuntimeError, expected_error):
      fake.set_n_cpu_devices(8)


if __name__ == '__main__':
  absltest.main()
52 |
--------------------------------------------------------------------------------
/chex/_src/pytypes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Type definitions to use for type annotations."""
16 |
17 | from typing import Any, Iterable, Mapping, Sequence, Union
18 |
19 | import jax
20 | import numpy as np
21 |
# Special types of arrays.
# Host-side NumPy array type.
ArrayNumpy = np.ndarray

# For instance checking, use `isinstance(x, jax.Array)`.
ArrayDevice = jax.Array

# Types for backward compatibility.
# Both are plain aliases of `jax.Array`; they exist so that older annotations
# keep working and carry no extra meaning for type checkers.
ArraySharded = jax.Array
ArrayBatched = jax.Array

# Generic array type.
# Similar to `jax.typing.ArrayLike` but does not accept python scalar types.
# Note: the first three members are all `jax.Array`; `Union` collapses such
# duplicates into a single member.
Array = Union[
    ArrayDevice,
    ArrayBatched,
    ArraySharded,  # JAX array type
    ArrayNumpy,  # NumPy array type
    np.bool_,
    np.number,  # NumPy scalar types
]

# A tree of generic arrays.
# These are recursive aliases: a tree is a leaf, or an iterable / mapping whose
# values are themselves trees. String forward references allow self-reference.
ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
ArrayDeviceTree = Union[
    ArrayDevice, Iterable['ArrayDeviceTree'], Mapping[Any, 'ArrayDeviceTree']
]
ArrayNumpyTree = Union[
    ArrayNumpy, Iterable['ArrayNumpyTree'], Mapping[Any, 'ArrayNumpyTree']
]

# Other types.
# Python scalars only; NumPy scalars are already covered by `Array`.
Scalar = Union[float, int]
# Any array (JAX, NumPy, NumPy scalar) or Python scalar.
Numeric = Union[Array, Scalar]
# Dimensions are ints; `Any` admits non-int entries (presumably symbolic or
# unknown dimensions -- TODO confirm intended use).
Shape = Sequence[Union[int, Any]]
# PRNG keys are annotated as ordinary `jax.Array`s.
PRNGKey = jax.Array
PyTreeDef = jax.tree_util.PyTreeDef
Device = jax.Device

# TODO(iukemaev, jakevdp): upgrade minimum jax version & remove this condition.
# The requirements only pin jax>=0.4.16, so `jax.typing.DTypeLike` may be
# missing; in that case fall back to `Any`.
if hasattr(jax.typing, 'DTypeLike'):
  # jax version 0.4.19 or newer
  ArrayDType = jax.typing.DTypeLike  # pylint:disable=invalid-name
else:
  ArrayDType = Any  # pylint:disable=invalid-name
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Install script for setuptools."""
16 |
17 | import os
18 | from setuptools import find_packages
19 | from setuptools import setup
20 |
21 | _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
22 |
23 |
def _get_version():
  """Returns `__version__` as declared in `chex/__init__.py`.

  The file is scanned as text rather than importing `chex`, presumably so that
  `setup.py` works before the package's dependencies (e.g. jax) are installed.

  Returns:
    The version string with surrounding whitespace and quotes removed.

  Raises:
    ValueError: if no line assigns a non-empty `__version__`.
  """
  init_path = os.path.join(_CURRENT_DIR, 'chex', '__init__.py')
  with open(init_path, encoding='utf-8') as fp:
    for line in fp:
      if line.startswith('__version__') and '=' in line:
        # Generic `strip()` also removes '\r' (CRLF checkouts) and tabs; then
        # drop the quotes around the literal.
        version = line.split('=', 1)[1].strip().strip('\'"')
        if version:
          return version
  raise ValueError('`__version__` not defined in `chex/__init__.py`')
32 |
33 |
def _parse_requirements(path):
  """Returns the requirement specifiers listed in a pip requirements file.

  Blank lines and comment lines (including indented ones) are skipped.

  Args:
    path: path to the requirements file; relative paths are resolved against
      the directory containing this `setup.py`.

  Returns:
    A list of requirement strings with surrounding whitespace removed.
  """
  with open(os.path.join(_CURRENT_DIR, path), encoding='utf-8') as f:
    stripped_lines = (line.strip() for line in f)
    return [
        line for line in stripped_lines
        if line and not line.startswith('#')
    ]
41 |
42 |
# Read the README inside a `with` block so the file handle is closed promptly;
# the explicit encoding keeps the build independent of the platform locale.
_README_PATH = os.path.join(_CURRENT_DIR, 'README.md')
with open(_README_PATH, encoding='utf-8') as _readme_file:
  _LONG_DESCRIPTION = _readme_file.read()

setup(
    name='chex',
    version=_get_version(),
    url='https://github.com/deepmind/chex',
    license='Apache 2.0',
    author='DeepMind',
    description=('Chex: Testing made fun, in JAX!'),
    long_description=_LONG_DESCRIPTION,
    long_description_content_type='text/markdown',
    author_email='chex-dev@google.com',
    keywords='jax testing debugging python machine learning',
    # NOTE(review): `find_packages(exclude=...)` matches *package* names, not
    # module files, so '*_test.py' presumably excludes nothing and test
    # modules still ship in the distribution -- confirm before relying on it.
    packages=find_packages(exclude=['*_test.py']),
    install_requires=_parse_requirements(
        os.path.join(_CURRENT_DIR, 'requirements', 'requirements.txt')),
    tests_require=_parse_requirements(
        os.path.join(_CURRENT_DIR, 'requirements', 'requirements-test.txt')),
    zip_safe=False,  # Required for full installation.
    include_package_data=True,
    python_requires='>=3.9',
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Environment :: Console',
        'Intended Audience :: Science/Research',
        'Intended Audience :: Developers',
        'License :: OSI Approved :: Apache Software License',
        'Operating System :: OS Independent',
        'Programming Language :: Python',
        'Programming Language :: Python :: 3',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development :: Testing :: Mocking',
        'Topic :: Software Development :: Testing :: Unit',
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],
)
77 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
# Runs CI tests on a local machine.
set -xeuo pipefail

# Start from a clean slate.
rm -rf _testing
rm -rf .pytype
mkdir -p _testing

# Install deps in a virtual env created in a fresh, uniquely named directory.
# The assignment is kept separate from `readonly` so that a failing `mktemp`
# aborts the script under `set -e`; `readonly VAR="$(cmd)"` masks cmd's status.
VENV_DIR="$(mktemp -d -p "$(pwd)/_testing" chex-env.XXXXXXXX)"
readonly VENV_DIR
python3 -m venv "${VENV_DIR}"
source "${VENV_DIR}/bin/activate"
python --version

# Install dependencies.
pip install --upgrade pip setuptools wheel
pip install flake8 pytest-xdist pylint pylint-exit
pip install -r requirements/requirements.txt
pip install -r requirements/requirements-test.txt

# Lint with flake8.
flake8 `find chex -name '*.py' | xargs` --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics

# Lint with pylint.
PYLINT_ARGS="-efail -wfail -cfail -rfail"
# Download Google OSS config.
wget -nd -v -t 3 -O .pylintrc https://google.github.io/styleguide/pylintrc
# Append specific config lines.
echo "signature-mutators=toolz.functoolz.curry" >> .pylintrc
echo "disable=unnecessary-lambda-assignment,use-dict-literal" >> .pylintrc
# Lint modules and tests separately.
pylint --rcfile=.pylintrc `find chex -name '*.py' | grep -v 'test.py' | xargs` -d E1102 || pylint-exit $PYLINT_ARGS $?
# Disable `protected-access` warnings for tests.
pylint --rcfile=.pylintrc `find chex -name '*_test.py' | xargs` -d W0212,E1130,E1102 || pylint-exit $PYLINT_ARGS $?
# Cleanup.
rm .pylintrc

# Build the package.
python setup.py sdist
pip wheel --verbose --no-deps --no-clean dist/chex*.tar.gz
pip install chex*.whl

# Check types with pytype.
# Note: pytype does not support 3.11 as of 25.06.23
# See https://github.com/google/pytype/issues/1308
if [ `python -c 'import sys; print(sys.version_info.minor)'` -lt 11 ];
then
  pip install pytype
  pytype `find chex/_src -name "*py" | xargs` -k
fi;

# Run tests using pytest.
# One pytest worker per logical CPU. `os.cpu_count()` is portable, unlike
# counting entries in /proc/cpuinfo, which only exists on Linux.
N_JOBS="$(python -c 'import os; print(os.cpu_count() or 1)')"
readonly N_JOBS
# Change directory to avoid importing the package from repo root.
pip install -r requirements/requirements-test.txt
cd _testing

# Main tests.
pytest -n "${N_JOBS}" --pyargs chex -k "not fake_set_n_cpu_devices_test"

# Isolate tests that use `chex.set_n_cpu_device()`.
pytest -n "${N_JOBS}" --pyargs chex -k "fake_set_n_cpu_devices_test"
cd ..

# Build Sphinx docs.

pip install -r requirements/requirements-docs.txt
cd docs
make coverage_check
make html
cd ..

# cleanup
rm -rf _testing

set +u
deactivate
echo "All tests passed. Congrats!"
92 |
--------------------------------------------------------------------------------
/chex/_src/warnings.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utilities to emit warnings."""
16 |
17 | import functools
18 | import warnings
19 |
20 |
def warn_only_n_pos_args_in_future(fun, n):
  """Warns if more than ``n`` positional arguments are passed to ``fun``.

  For instance:
  >>> @functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
  ... def f(a, b, c=1):
  ...   return a + b + c

  Will emit a DeprecationWarning if ``f`` is called with more than one
  positional argument (e.g. both f(1, 2, 3) and f(1, 2, c=3) emit a warning).

  Args:
    fun: the function to wrap.
    n: the number of positional arguments to allow.

  Returns:
    A wrapped function that emits a warning if more than `n` positional
    arguments are passed, and otherwise behaves exactly like `fun`.
  """

  @functools.wraps(fun)
  def wrapper(*args, **kwargs):
    if len(args) > n:
      warnings.warn(
          f'only the first {n} arguments can be passed positionally; '
          'additional args will become keyword-only soon',
          DeprecationWarning,
          # Attribute the warning to the caller of `fun`, not this wrapper.
          stacklevel=2,
      )
    return fun(*args, **kwargs)

  return wrapper


# Variant that warns whenever *any* argument is passed positionally.
warn_keyword_args_only_in_future = functools.partial(
    warn_only_n_pos_args_in_future, n=0
)
58 |
59 |
def _deprecation_wrapper(fun, deprecated_name, replacement_name):
  """Returns `fun` wrapped so that every call emits a DeprecationWarning.

  Shared implementation of `warn_deprecated_function` and
  `create_deprecated_function_alias`.

  Args:
    fun: the function to wrap.
    deprecated_name: the name reported as deprecated in the warning.
    replacement_name: the name the warning recommends using instead.

  Returns:
    A wrapper carrying `fun`'s metadata (via `functools.wraps`) that warns and
    then forwards all arguments to `fun`.
  """

  @functools.wraps(fun)
  def new_fun(*args, **kwargs):
    warnings.warn(
        f'The function {deprecated_name} is deprecated, '
        f'please use {replacement_name} instead.',
        category=DeprecationWarning,
        # Attribute the warning to the caller of the deprecated function.
        stacklevel=2)
    return fun(*args, **kwargs)

  return new_fun


def warn_deprecated_function(fun, replacement):
  """A decorator to mark a function definition as deprecated.

  Example usage:
  >>> @functools.partial(chex.warn_deprecated_function, replacement='g')
  ... def f(a, b):
  ...   return a + b

  Args:
    fun: the deprecated function.
    replacement: the name of the function to be used instead.

  Returns:
    the wrapped function.
  """
  # Fall back to `repr` for callables without `__name__` (e.g.
  # `functools.partial` objects), which would otherwise fail when called.
  deprecated_name = getattr(fun, '__name__', repr(fun))
  return _deprecation_wrapper(fun, deprecated_name, replacement)


def create_deprecated_function_alias(fun, new_name, deprecated_alias):
  """Create a deprecated alias for a function.

  Example usage:
  >>> g = create_deprecated_function_alias(f, 'path.f', 'path.g')

  Args:
    fun: the function being aliased; calls to the returned alias are
      forwarded to it.
    new_name: the new name to use (you may include the path for clarity).
    deprecated_alias: the old name (you may include the path for clarity).

  Returns:
    the wrapped function.
  """
  return _deprecation_wrapper(fun, deprecated_alias, new_name)
111 |
--------------------------------------------------------------------------------
/docs/ext/coverage_check.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Asserts all public symbols are covered in the docs."""
16 |
17 | import inspect
18 | import types
19 | from typing import Any, Mapping, Sequence, Tuple
20 |
21 | import chex as _module
22 | from sphinx import application
23 | from sphinx import builders
24 | from sphinx import errors
25 |
26 |
def find_internal_python_modules(
    root_module: types.ModuleType,) -> Sequence[Tuple[str, types.ModuleType]]:
  """Returns `(name, module)` for all public submodules under `root_module`.

  Module attributes are followed transitively starting from `root_module`.
  Only modules whose dotted name is `root_module` itself or nested under it
  are kept; any module with `_src` in its name is treated as private and
  neither reported nor traversed.

  Args:
    root_module: the package to inspect.

  Returns:
    `(name, module)` pairs sorted by module name, including `root_module`.
  """
  root_name = root_module.__name__
  modules = {(root_name, root_module)}
  visited = set()
  to_visit = [root_module]

  while to_visit:
    mod = to_visit.pop()
    visited.add(mod)

    for name in dir(mod):
      obj = getattr(mod, name)
      if not inspect.ismodule(obj) or obj in visited:
        continue
      obj_name = obj.__name__
      # Match on a dotted prefix so that e.g. `chexfoo` is not mistaken for a
      # submodule of `chex`.
      if obj_name != root_name and not obj_name.startswith(root_name + '.'):
        continue
      if '_src' in obj_name:
        continue
      to_visit.append(obj)
      modules.add((obj_name, obj))

  # Sort by name only: module objects themselves are not orderable.
  return sorted(modules, key=lambda item: item[0])
47 |
48 |
def get_public_symbols() -> Sequence[str]:
  """Returns the fully-qualified names of all public `chex` symbols.

  A symbol is public if it is listed in the `__all__` of `chex` or of one of
  its public (non-`_src`) submodules.

  Returns:
    Sorted fully-qualified names, e.g. 'chex.assert_devices_available'.
  """
  names = {
      f'{module_name}.{name}'
      for module_name, module in find_internal_python_modules(_module)
      for name in module.__all__
  }
  # Sorting makes the result deterministic across runs.
  return tuple(sorted(names))
55 |
56 |
class CoverageCheck(builders.Builder):
  """Builder that checks all public symbols are included."""

  name = 'coverage_check'

  def get_outdated_docs(self) -> str:
    # Sphinx treats a string return value as a description of the build
    # targets rather than a list of outdated documents.
    return 'coverage_check'

  def write(self, *ignored: Any) -> None:
    # Nothing is written; the check runs in `finish`.
    pass

  def finish(self) -> None:
    """Raises `SphinxError` if any public symbol lacks a documentation entry."""
    documented_objects = frozenset(self.env.domaindata['py']['objects'])
    undocumented_objects = set(get_public_symbols()) - documented_objects

    # Exclude deprecated API symbols.
    assertion_exceptions = ('assert_tree_all_close',
                            'assert_tree_all_equal_comparator',
                            'assert_tree_all_equal_shapes',
                            'assert_tree_all_equal_structs')
    undocumented_objects -= {'chex.' + s for s in assertion_exceptions}

    # Exclude pytypes.
    pytypes_exceptions = (
        'Array',
        'ArrayBatched',
        'ArrayDevice',
        'ArrayDeviceTree',
        'ArrayDType',
        'ArrayNumpy',
        'ArrayNumpyTree',
        'ArraySharded',
        'ArrayTree',
        'Device',
        'Numeric',
        'PRNGKey',
        'PyTreeDef',
        'Scalar',
        'Shape',
    )

    # Exclude public constants.
    constant_exceptions = ('ChexifyChecks',)

    undocumented_objects -= {
        'chex.' + s for s in pytypes_exceptions + constant_exceptions
    }

    if undocumented_objects:
      missing_symbols = tuple(sorted(undocumented_objects))
      raise errors.SphinxError(
          'All public symbols must be included in our documentation, did you '
          'forget to add an entry to `api.rst`?\n'
          f'Undocumented symbols: {missing_symbols}')
111 |
112 |
def setup(app: application.Sphinx) -> Mapping[str, Any]:
  """Sphinx extension entry point: registers the `CoverageCheck` builder.

  Args:
    app: The Sphinx application loading this extension.

  Returns:
    Extension metadata carrying `_module.__version__` and declaring the
    extension safe for parallel reading.
  """
  app.add_builder(CoverageCheck)
  metadata = {
      'version': _module.__version__,
      'parallel_read_safe': True,
  }
  return metadata
116 |
--------------------------------------------------------------------------------
/chex/_src/restrict_backends.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A context manager that objects to JAX compilation for specified backends.
16 |
17 | This is useful, for example, when certain JAX code needs to run in an
18 | environment where an accelerator is present but reserved for other purposes.
19 | Typically one would use `jax.jit(..., backend='cpu')` to keep the code away
20 | from the accelerator, but it is hard to check by hand that this has been done
21 | without exception throughout an entire subsystem. Then, `restrict_backends()`
22 | can be used to detect any overlooked case and report it by raising an exception.
23 |
24 | Similarly, it can be useful for a system such as a learner to make sure that
25 | all required JAX programs have been assigned to their respective backends by
26 | the end of its first iteration; this helps to show that it will not later run
27 | into memory fragmentation problems. By entering a `restrict_backends()` context
28 | at the end of the first iteration, the system can detect any overlooked cases.
29 | """
30 | import contextlib
31 | import functools
32 | from typing import Optional, Sequence
33 |
34 | # pylint: disable=g-import-not-at-top
35 | try:
36 | from jax._src import compiler
37 | except ImportError:
38 | # TODO(phawkins): remove this path after jax>=0.4.15 is the minimum version
39 | # required by chex.
40 | from jax._src import dispatch as compiler # type: ignore
41 | # pylint: enable=g-import-not-at-top
42 |
43 |
class RestrictedBackendError(RuntimeError):
  """Raised when JAX compiles for a backend barred by `restrict_backends()`."""
46 |
47 |
@contextlib.contextmanager
def restrict_backends(
    *,
    allowed: Optional[Sequence[str]] = None,
    forbidden: Optional[Sequence[str]] = None):
  """Disallows JAX compilation for certain backends.

  Args:
    allowed: Names of backend platforms (e.g. 'cpu' or 'tpu') for which
      compilation is still to be permitted. A single platform name may also be
      passed as a bare string.
    forbidden: Names of backend platforms for which compilation is to be
      forbidden. A single platform name may also be passed as a bare string.

  Yields:
    None, in a context where compilation for forbidden platforms will raise
    a `RestrictedBackendError`.

  Raises:
    ValueError: if neither `allowed` nor `forbidden` is specified (i.e. they
      are both `None`), or if anything is both allowed and forbidden.
  """

  def as_platform_tuple(platforms):
    # A `str` is itself a `Sequence[str]`. Without this special case,
    # `allowed='cpu'` would become ('c', 'p', 'u') and match no platform.
    if platforms is None:
      return None
    if isinstance(platforms, str):
      return (platforms,)
    return tuple(platforms)

  allowed = as_platform_tuple(allowed)
  forbidden = as_platform_tuple(forbidden)

  if allowed is None and forbidden is None:
    raise ValueError('No restrictions specified.')
  contradictions = set(allowed or ()) & set(forbidden or ())
  if contradictions:
    raise ValueError(
        f"Backends {contradictions} can't be both allowed and forbidden.")

  def is_allowed(backend_platform):
    # An explicit allow-list takes precedence. Otherwise, anything that is
    # not forbidden is allowed.
    if allowed is not None:
      return backend_platform in allowed
    return backend_platform not in forbidden

  inner_backend_compile = compiler.backend_compile

  @functools.wraps(inner_backend_compile)
  def wrapper(backend, *args, **kwargs):
    if not is_allowed(backend.platform):
      raise RestrictedBackendError(
          f'Compiling a JAX program for {backend.platform} is forbidden by '
          f'restrict_backends().')
    return inner_backend_compile(backend, *args, **kwargs)

  try:
    compiler.backend_compile = wrapper
    yield
  finally:
    # Sanity check: the hook installed above must still be in place, i.e.
    # nothing else replaced `compiler.backend_compile` in the meantime.
    backend_compile = compiler.backend_compile
    assert backend_compile is wrapper, backend_compile
    compiler.backend_compile = inner_backend_compile
100 |
--------------------------------------------------------------------------------
/chex/_src/restrict_backends_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for `restrict_backends.py`."""
16 | from absl.testing import absltest
17 | from chex._src import restrict_backends
18 | import jax
19 | import jax.numpy as jnp
20 | import numpy as np
21 |
22 |
def compute_cube(side):
  """Returns `side ** 3` as the sum of a (side, side) array filled with `side`."""
  filled = jnp.ones((side, side)) * side
  return filled.sum()
25 |
26 |
class RestrictBackendsTest(absltest.TestCase):

  # These tests need an accelerator of some sort, so that JAX can try to use it.
  def setUp(self):
    super().setUp()

    def backend_available(platform):
      # `jax.devices(platform)` raises RuntimeError when the backend is absent.
      try:
        jax.devices(platform)
      except RuntimeError:
        return False
      return True

    gpu_backend_available = backend_available('gpu')
    tpu_backend_available = backend_available('tpu')

    # Skip only when *neither* accelerator is present.
    if not (tpu_backend_available or gpu_backend_available):
      self.skipTest('No known accelerator backends are available, so these '
                    'tests will not test anything useful.')

  def test_detects_implicitly_forbidden_tpu_computation(self):
    with self.assertRaisesRegex(restrict_backends.RestrictedBackendError,
                                r'forbidden by restrict_backends'):
      with restrict_backends.restrict_backends(allowed=['cpu']):
        compute_cube(3)
    # Make sure the restriction is no longer in place.
    np.testing.assert_array_equal(compute_cube(3), 27)

  def test_detects_explicitly_forbidden_tpu_computation(self):
    with self.assertRaisesRegex(restrict_backends.RestrictedBackendError,
                                r'forbidden by restrict_backends'):
      with restrict_backends.restrict_backends(forbidden=['tpu', 'gpu']):
        compute_cube(2)
    # Make sure the restriction is no longer in place.
    np.testing.assert_array_equal(compute_cube(2), 8)

  def test_detects_implicitly_forbidden_cpu_computation(self):
    with self.assertRaisesRegex(restrict_backends.RestrictedBackendError,
                                r'forbidden by restrict_backends'):
      with restrict_backends.restrict_backends(allowed=['tpu', 'gpu']):
        jax.jit(lambda: compute_cube(8), backend='cpu')()
    # Make sure the restriction is no longer in place.
    np.testing.assert_array_equal(compute_cube(8), 512)

  def test_detects_explicitly_forbidden_cpu_computation(self):
    with self.assertRaisesRegex(restrict_backends.RestrictedBackendError,
                                r'forbidden by restrict_backends'):
      with restrict_backends.restrict_backends(forbidden=['cpu']):
        jax.jit(lambda: compute_cube(9), backend='cpu')()
    # Make sure the restriction is no longer in place.
    np.testing.assert_array_equal(compute_cube(9), 729)

  def test_ignores_explicitly_allowed_cpu_computation(self):
    with restrict_backends.restrict_backends(allowed=['cpu']):
      c = jax.jit(lambda: compute_cube(4), backend='cpu')()
    np.testing.assert_array_equal(c, 64)

  def test_ignores_implicitly_allowed_cpu_computation(self):
    with restrict_backends.restrict_backends(forbidden=['tpu', 'gpu']):
      c = jax.jit(lambda: compute_cube(5), backend='cpu')()
    np.testing.assert_array_equal(c, 125)

  def test_ignores_explicitly_allowed_tpu_computation(self):
    with restrict_backends.restrict_backends(allowed=['tpu', 'gpu']):
      c = jax.jit(lambda: compute_cube(6))()
    np.testing.assert_array_equal(c, 216)

  def test_ignores_implicitly_allowed_tpu_computation(self):
    with restrict_backends.restrict_backends(forbidden=['cpu']):
      c = jax.jit(lambda: compute_cube(7))()
    np.testing.assert_array_equal(c, 343)

  def test_raises_if_no_restrictions_specified(self):
    with self.assertRaisesRegex(ValueError, r'No restrictions specified'):
      with restrict_backends.restrict_backends():
        pass

  def test_raises_if_contradictory_restrictions_specified(self):
    with self.assertRaisesRegex(ValueError, r"can't be both"):
      with restrict_backends.restrict_backends(
          allowed=['cpu'], forbidden=['cpu']):
        pass


if __name__ == '__main__':
  absltest.main()
115 |
--------------------------------------------------------------------------------
/chex/_src/asserts_internal_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for `asserts_internal.py`."""
16 |
17 | import functools
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | from chex._src import asserts_internal as ai
22 | from chex._src import variants
23 | import jax
24 | import jax.numpy as jnp
25 |
26 |
class IsTraceableTest(variants.TestCase):

  @variants.variants(with_jit=True, with_pmap=True)
  def test_is_traceable(self):
    """Checks `is_traceable` on plain, wrapped and variant-wrapped functions."""

    def dummy_wrapper(fn):
      # A plain Python decorator (no JAX transform) that forwards its
      # arguments unchanged to `fn`.

      @functools.wraps(fn)
      def fn_wrapped(*args):
        return fn(*args)

      return fn_wrapped

    fn = lambda x: x.sum()
    wrapped_fn = dummy_wrapper(fn)
    # Neither the bare function nor a plain wrapper around it is traceable.
    self.assertFalse(ai.is_traceable(fn))
    self.assertFalse(ai.is_traceable(wrapped_fn))

    # Applying the variant makes the function traceable, whether the plain
    # wrapper is applied before or after the variant.
    var_fn = self.variant(fn)
    wrapped_var_f = dummy_wrapper(var_fn)
    var_wrapped_f = self.variant(wrapped_fn)
    self.assertTrue(ai.is_traceable(var_fn))
    self.assertTrue(ai.is_traceable(wrapped_var_f))
    self.assertTrue(ai.is_traceable(var_wrapped_f))
50 |
51 |
class ExceptionMessageFormatTest(variants.TestCase):

  @parameterized.product(
      include_default_msg=(False, True),
      include_custom_msg=(False, True),
      exc_type=(AssertionError, ValueError),
  )
  def test_format(self, include_default_msg, include_custom_msg, exc_type):
    """Checks how default and custom messages combine in raised exceptions."""

    def exc_msg(x):
      return f'{x} is non-positive.'

    @functools.partial(ai.chex_assertion, jittable_assert_fn=None)
    def assert_positive(x):
      if x <= 0:
        raise AssertionError(exc_msg(x))

    @functools.partial(ai.chex_assertion, jittable_assert_fn=None)
    def assert_each_positive(*args):
      for x in args:
        assert_positive(x)

    # Valid inputs must pass silently.
    assert_positive(1)
    assert_each_positive(1, 2, 3)

    def expected_exc_msg(x, custom_msg):
      # Regex for the expected message: the default text (if enabled),
      # followed by the custom message in literal square brackets (if any).
      msg = exc_msg(x) if include_default_msg else ''
      if custom_msg:
        return rf'{msg} \[{custom_msg}\]'
      return msg

    # Several iterations, so each one uses a different custom message.
    for i in range(3):
      if include_custom_msg:
        custom_msg = f'failed at iter {i}'
      else:
        custom_msg = ''
      assertion_kwargs = dict(
          custom_message=custom_msg,
          include_default_message=include_default_msg,
          exception_type=exc_type)

      with self.assertRaisesRegex(
          exc_type, ai.get_err_regex(expected_exc_msg(-1, custom_msg))):
        assert_positive(-1, **assertion_kwargs)

      with self.assertRaisesRegex(
          exc_type, ai.get_err_regex(expected_exc_msg(-3, custom_msg))):
        assert_each_positive(1, -3, 2, **assertion_kwargs)
104 |
105 |
class JitCompatibleTest(variants.TestCase):

  def test_api(self):
    """Checks chex assertions with and without a jittable variant under JAX transforms."""

    def assert_fn(x):
      if x.shape != (2,):
        raise AssertionError(f'shape != (2,) {x.shape}!')

    def summing_fn(assertion):
      # Builds `x -> x.sum()` that first runs `assertion` on its input.
      def fn(x):
        assertion(x)
        return x.sum()
      return fn

    for transform_fn in (jax.jit, jax.grad, jax.vmap):
      is_vmap = transform_fn is jax.vmap
      x_ok = jnp.ones((2,))
      x_wrong = jnp.ones((3,))
      if is_vmap:
        # Add a leading size-1 axis for vmap to map over.
        x_ok = jnp.expand_dims(x_ok, 0)
        x_wrong = jnp.expand_dims(x_wrong, 0)

      # Jax-compatible (no `jittable_assert_fn`).
      compat_fn = summing_fn(
          ai.chex_assertion(assert_fn, jittable_assert_fn=None))
      if not is_vmap:
        compat_fn(x_ok)
      transform_fn(compat_fn)(x_ok)
      with self.assertRaisesRegex(AssertionError, 'shape !='):
        transform_fn(compat_fn)(x_wrong)

      # JAX-incompatible (`jittable_assert_fn` supplied).
      incompat_fn = summing_fn(
          ai.chex_assertion(assert_fn, jittable_assert_fn=assert_fn))
      if not is_vmap:
        incompat_fn(x_ok)
      with self.assertRaisesRegex(RuntimeError,
                                  'Value assertions can only be called from'):
        transform_fn(incompat_fn)(x_wrong)


if __name__ == '__main__':
  jax.config.update('jax_numpy_rank_promotion', 'raise')
  absltest.main()
152 |
--------------------------------------------------------------------------------
/chex/_src/dimensions_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for ``dimensions`` module."""
16 |
17 | import doctest
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | from chex._src import asserts
22 | from chex._src import dimensions
23 | import jax
24 | import numpy as np
25 |
26 |
class _ChexModule:
  """Stand-in for the `chex` package in the `Dimensions` docstring tests.

  Exposes only a minimal set of `chex` attributes to the doctest globals.
  """
  Dimensions = dimensions.Dimensions  # pylint: disable=invalid-name
  assert_rank = asserts.assert_rank
  assert_shape = asserts.assert_shape
32 |
33 |
class DimensionsTest(parameterized.TestCase):

  def test_docstring_examples(self):
    # `doctest.run_docstring_examples` only prints failures and returns
    # nothing, so failing examples could never fail this test. Drive the
    # finder/runner directly (same settings) and assert on the failure count.
    globs = {'chex': _ChexModule, 'jax': jax, 'jnp': jax.numpy}
    finder = doctest.DocTestFinder(recurse=False)
    runner = doctest.DocTestRunner(verbose=False)
    failed = 0
    for test in finder.find(dimensions.Dimensions, 'Dimensions', globs=globs):
      failed += runner.run(test).failed
    self.assertEqual(failed, 0)

  @parameterized.named_parameters([
      ('scalar', '', (), ()),
      ('vector', 'a', (7,), (7,)),
      ('list', 'ab', [7, 11], (7, 11)),
      ('numpy_array', 'abc', np.array([7, 11, 13]), (7, 11, 13)),
      ('case_sensitive', 'aA', (7, 11), (7, 11)),
  ])
  def test_set_ok(self, k, v, shape):
    dims = dimensions.Dimensions(x=23, y=29)
    dims[k] = v
    asserts.assert_shape(np.empty((23, *shape, 29)), dims['x' + k + 'y'])

  def test_set_wildcard(self):
    dims = dimensions.Dimensions(x=23, y=29)
    dims['a_b__'] = (7, 11, 13, 17, 19)
    self.assertEqual(dims['xayb'], (23, 7, 29, 13))
    with self.assertRaisesRegex(KeyError, r'\*'):
      dims['ab*'] = (7, 11, 13)

  def test_get_wildcard(self):
    dims = dimensions.Dimensions(x=23, y=29)
    self.assertEqual(dims['x*y**'], (23, None, 29, None, None))
    asserts.assert_shape(np.empty((23, 1, 29, 2, 3)), dims['x*y**'])
    with self.assertRaisesRegex(KeyError, r'\_'):
      _ = dims['xy_']

  def test_get_literals(self):
    dims = dimensions.Dimensions(x=23, y=29)
    self.assertEqual(dims['x1y23'], (23, 1, 29, 2, 3))

  @parameterized.named_parameters([
      ('scalar', 'a', 7, TypeError, r'value must be sized'),
      ('iterator', 'a', (x for x in [7]), TypeError, r'value must be sized'),
      ('len_mismatch', 'ab', (7, 11, 13), ValueError, r'different length'),
      ('non_integer_size', 'a', (7.001,),
       TypeError, r'cannot be interpreted as a python int'),
      ('bad_key_type', 13, (7,), TypeError, r'key must be a string'),
      ('bad_key_string', '@%^#', (7, 11, 13, 17), KeyError, r'\@'),
  ])
  def test_set_exception(self, k, v, e, m):
    dims = dimensions.Dimensions(x=23, y=29)
    with self.assertRaisesRegex(e, m):
      dims[k] = v

  @parameterized.named_parameters([
      ('bad_key_type', 13, TypeError, r'key must be a string'),
      ('bad_key_string', '@%^#', KeyError, r'\@'),
  ])
  def test_get_exception(self, k, e, m):
    dims = dimensions.Dimensions(x=23, y=29)
    with self.assertRaisesRegex(e, m):
      _ = dims[k]

  @parameterized.named_parameters([
      ('scalar', '', (), 1),
      ('nonscalar', 'ab', (3, 5), 15),
  ])
  def test_size_ok(self, names, shape, expected_size):
    dims = dimensions.Dimensions(**dict(zip(names, shape)))
    self.assertEqual(dims.size(names), expected_size)

  @parameterized.named_parameters([
      ('named', 'ab'),
      ('asterisk', 'a*'),
      ('zero', 'a0'),
      ('negative', 'ac'),
  ])
  def test_size_fail_wildcard(self, names):
    dims = dimensions.Dimensions(a=3, b=None, c=-1)
    with self.assertRaisesRegex(ValueError, r'cannot take product of shape'):
      dims.size(names)

  @parameterized.named_parameters([
      ('trivial_start', '(a)bc', (3, 5, 7)),
      ('trivial_mid', 'a(b)c', (3, 5, 7)),
      ('trivial_end', 'ab(c)', (3, 5, 7)),
      ('start', '(ab)cd', (15, 7, 11)),
      ('mid', 'a(bc)d', (3, 35, 11)),
      ('end', 'ab(cd)', (3, 5, 77)),
      ('multiple', '(ab)(cd)', (15, 77)),
      ('all', '(abc)', (105,)),
  ])
  def test_flatten_ok(self, named_shape, expected_shape):
    dims = dimensions.Dimensions(a=3, b=5, c=7, d=11)
    self.assertEqual(dims[named_shape], expected_shape)

  @parameterized.named_parameters([
      ('unmatched_open', '(ab', r'unmatched parentheses in named shape'),
      ('unmatched_closed', 'a)b', r'unmatched parentheses in named shape'),
      ('nested', '(a(bc))', r'nested parentheses are unsupported'),
      ('wildcard_named', 'a(bx)', r'cannot take product of shape'),
      ('wildcard_asterisk', '(a*)b', r'cannot take product of shape'),
      ('zero_sized_dim', '(a0)b', r'cannot take product of shape'),
      ('neg_sized_dim', '(ay)b', r'cannot take product of shape'),
      ('empty_start', '()ab', r'found empty parentheses in named shape'),
      ('empty_mid', 'a()b', r'found empty parentheses in named shape'),
      ('empty_end', 'ab()', r'found empty parentheses in named shape'),
      ('empty_solo', '()', r'found empty parentheses in named shape'),
  ])
  def test_flatten_fail(self, named_shape, error_message):
    dims = dimensions.Dimensions(a=3, b=5, x=None, y=-1)
    with self.assertRaisesRegex(ValueError, error_message):
      _ = dims[named_shape]


if __name__ == '__main__':
  absltest.main()
148 |
--------------------------------------------------------------------------------
/docs/api.rst:
--------------------------------------------------------------------------------
1 | Assertions
2 | ==========
3 |
4 | .. currentmodule:: chex
5 |
6 | .. autosummary::
7 |
8 | assert_axis_dimension
9 | assert_axis_dimension_comparator
10 | assert_axis_dimension_gt
11 | assert_axis_dimension_gteq
12 | assert_axis_dimension_lt
13 | assert_axis_dimension_lteq
14 | assert_devices_available
15 | assert_equal
16 | assert_equal_rank
17 | assert_equal_size
18 | assert_equal_shape
19 | assert_equal_shape_prefix
20 | assert_equal_shape_suffix
21 | assert_exactly_one_is_none
22 | assert_gpu_available
23 | assert_is_broadcastable
24 | assert_is_divisible
25 | assert_max_traces
26 | assert_not_both_none
27 | assert_numerical_grads
28 | assert_rank
29 | assert_scalar
30 | assert_scalar_in
31 | assert_scalar_negative
32 | assert_scalar_non_negative
33 | assert_scalar_positive
34 | assert_size
35 | assert_shape
36 | assert_tpu_available
37 | assert_tree_all_finite
38 | assert_tree_has_only_ndarrays
39 | assert_tree_is_on_device
40 | assert_tree_is_on_host
41 | assert_tree_is_sharded
42 | assert_tree_no_nones
43 | assert_tree_shape_prefix
44 | assert_tree_shape_suffix
45 | assert_trees_all_close
46 | assert_trees_all_close_ulp
47 | assert_trees_all_equal
48 | assert_trees_all_equal_comparator
49 | assert_trees_all_equal_dtypes
50 | assert_trees_all_equal_sizes
51 | assert_trees_all_equal_shapes
52 | assert_trees_all_equal_shapes_and_dtypes
53 | assert_trees_all_equal_structs
54 | assert_type
55 | chexify
56 | ChexifyChecks
57 | with_jittable_assertions
58 | block_until_chexify_assertions_complete
59 | Dimensions
60 | disable_asserts
61 | enable_asserts
62 | clear_trace_counter
63 | if_args_not_none
64 |
65 |
66 | JAX Assertions
67 | ~~~~~~~~~~~~~~
68 |
69 | .. autofunction:: assert_max_traces
70 | .. autofunction:: assert_devices_available
71 | .. autofunction:: assert_gpu_available
72 | .. autofunction:: assert_tpu_available
73 |
74 |
75 | Value (Runtime) Assertions
76 | ~~~~~~~~~~~~~~~~~~~~~~~~~~
77 |
78 | .. autofunction:: chexify
79 | .. autosummary:: ChexifyChecks
80 | .. autofunction:: with_jittable_assertions
81 | .. autofunction:: block_until_chexify_assertions_complete
82 |
83 |
84 | Tree Assertions
85 | ~~~~~~~~~~~~~~~
86 |
87 | .. autofunction:: assert_tree_all_finite
88 | .. autofunction:: assert_tree_has_only_ndarrays
89 | .. autofunction:: assert_tree_is_on_device
90 | .. autofunction:: assert_tree_is_on_host
91 | .. autofunction:: assert_tree_is_sharded
92 | .. autofunction:: assert_tree_no_nones
93 | .. autofunction:: assert_tree_shape_prefix
94 | .. autofunction:: assert_tree_shape_suffix
95 | .. autofunction:: assert_trees_all_close
96 | .. autofunction:: assert_trees_all_close_ulp
97 | .. autofunction:: assert_trees_all_equal
98 | .. autofunction:: assert_trees_all_equal_comparator
99 | .. autofunction:: assert_trees_all_equal_dtypes
100 | .. autofunction:: assert_trees_all_equal_sizes
101 | .. autofunction:: assert_trees_all_equal_shapes
102 | .. autofunction:: assert_trees_all_equal_shapes_and_dtypes
103 | .. autofunction:: assert_trees_all_equal_structs
104 |
105 |
106 | Generic Assertions
107 | ~~~~~~~~~~~~~~~~~~
108 |
109 | .. autofunction:: assert_axis_dimension
110 | .. autofunction:: assert_axis_dimension_comparator
111 | .. autofunction:: assert_axis_dimension_gt
112 | .. autofunction:: assert_axis_dimension_gteq
113 | .. autofunction:: assert_axis_dimension_lt
114 | .. autofunction:: assert_axis_dimension_lteq
115 | .. autofunction:: assert_equal
116 | .. autofunction:: assert_equal_rank
117 | .. autofunction:: assert_equal_size
118 | .. autofunction:: assert_equal_shape
119 | .. autofunction:: assert_equal_shape_prefix
120 | .. autofunction:: assert_equal_shape_suffix
121 | .. autofunction:: assert_exactly_one_is_none
122 | .. autofunction:: assert_is_broadcastable
123 | .. autofunction:: assert_is_divisible
124 | .. autofunction:: assert_not_both_none
125 | .. autofunction:: assert_numerical_grads
126 | .. autofunction:: assert_rank
127 | .. autofunction:: assert_scalar
128 | .. autofunction:: assert_scalar_in
129 | .. autofunction:: assert_scalar_negative
130 | .. autofunction:: assert_scalar_non_negative
131 | .. autofunction:: assert_scalar_positive
132 | .. autofunction:: assert_size
133 | .. autofunction:: assert_shape
134 | .. autofunction:: assert_type
135 |
136 |
137 | Shapes and Named Dimensions
138 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~
139 |
140 | .. autoclass:: Dimensions
141 |
142 |
143 | Utils
144 | ~~~~~
145 |
146 | .. autofunction:: disable_asserts
147 | .. autofunction:: enable_asserts
148 | .. autofunction:: clear_trace_counter
149 | .. autofunction:: if_args_not_none
150 |
151 |
152 | Warnings
153 | ========
154 |
155 | .. currentmodule:: chex
156 |
157 | .. autofunction:: create_deprecated_function_alias
158 | .. autofunction:: warn_deprecated_function
159 | .. autofunction:: warn_keyword_args_only_in_future
160 | .. autofunction:: warn_only_n_pos_args_in_future
161 |
162 |
163 | Backend restriction
164 | ===================
165 |
166 | .. currentmodule:: chex
167 |
168 | .. autofunction:: restrict_backends
169 |
170 |
171 | Dataclasses
172 | ===========
173 |
174 | .. currentmodule:: chex
175 |
176 | .. autofunction:: dataclass
177 | .. autofunction:: mappable_dataclass
178 | .. autofunction:: register_dataclass_type_with_jax_tree_util
179 |
180 |
181 | Fakes
182 | =====
183 |
184 | .. currentmodule:: chex
185 |
186 | .. autosummary::
187 |
188 | fake_jit
189 | fake_pmap
190 | fake_pmap_and_jit
191 | set_n_cpu_devices
192 |
193 | Transformations
194 | ~~~~~~~~~~~~~~~
195 |
196 | .. autofunction:: fake_jit
197 | .. autofunction:: fake_pmap
198 | .. autofunction:: fake_pmap_and_jit
199 |
200 |
201 | Devices
202 | ~~~~~~~
203 |
204 | .. autofunction:: set_n_cpu_devices
205 |
206 |
207 | Pytypes
208 | =======
209 |
210 | .. currentmodule:: chex
211 |
212 | .. autosummary::
213 |
214 | Array
215 | ArrayBatched
216 | ArrayDevice
217 | ArrayDeviceTree
218 | ArrayDType
219 | ArrayNumpy
220 | ArrayNumpyTree
221 | ArraySharded
222 | ArrayTree
223 | Device
224 | Numeric
225 | PRNGKey
226 | PyTreeDef
227 | Scalar
228 | Shape
229 |
230 |
231 |
232 | Variants
233 | ========
234 | .. currentmodule:: chex
235 |
236 | .. autoclass:: ChexVariantType
237 | .. autoclass:: TestCase
238 | .. autofunction:: variants
239 | .. autofunction:: all_variants
240 | .. autofunction:: params_product
241 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Configuration file for the Sphinx documentation builder."""
16 |
17 | # This file only contains a selection of the most common options. For a full
18 | # list see the documentation:
19 | # http://www.sphinx-doc.org/en/master/config
20 |
21 | # -- Path setup --------------------------------------------------------------
22 |
23 | # If extensions (or modules to document with autodoc) are in another directory,
24 | # add these directories to sys.path here. If the directory is relative to the
25 | # documentation root, use os.path.abspath to make it absolute, like shown here.
26 |
27 | # pylint: disable=g-bad-import-order
28 | # pylint: disable=g-import-not-at-top
29 | import inspect
30 | import os
31 | import sys
32 |
33 |
34 | def _add_annotations_import(path):
35 | """Appends a future annotations import to the file at the given path."""
36 | with open(path) as f:
37 | contents = f.read()
38 | if contents.startswith('from __future__ import annotations'):
39 | # If we run sphinx multiple times then we will append the future import
40 | # multiple times too.
41 | return
42 |
43 | assert contents.startswith('#'), (path, contents.split('\n')[0])
44 | with open(path, 'w') as f:
45 | # NOTE: This is subtle and not unit tested, we're prefixing the first line
46 | # in each Python file with this future import. It is important to prefix
47 | # not insert a newline such that source code locations are accurate (we link
48 | # to GitHub). The assertion above ensures that the first line in the file is
49 | # a comment so it is safe to prefix it.
50 | f.write('from __future__ import annotations ')
51 | f.write(contents)
52 |
53 |
def _recursive_add_annotations_import():
  """Runs `_add_annotations_import` on every `.py` file below `../chex/`."""
  for dirpath, _, filenames in os.walk('../chex/'):
    for filename in filter(lambda name: name.endswith('.py'), filenames):
      full_path = os.path.join(dirpath, filename)
      _add_annotations_import(os.path.abspath(full_path))
59 |
# On Read the Docs, prefix every chex source file with a future-annotations
# import (see `_add_annotations_import`) before `chex` is imported below.
if 'READTHEDOCS' in os.environ:
  _recursive_add_annotations_import()

# Put the repository root first on the path so `import chex` below resolves to
# the in-tree package, and append docs/ext so local extensions such as
# `coverage_check` can be loaded. Both paths are relative to the working
# directory, presumably docs/ -- TODO confirm for every build entry point.
sys.path.insert(0, os.path.abspath('../'))
sys.path.append(os.path.abspath('ext'))
65 |
66 | import chex
67 | from sphinxcontrib import katex
68 |
# -- Project information -----------------------------------------------------

project = 'Chex'
copyright = '2021, DeepMind'  # pylint: disable=redefined-builtin
author = 'Chex Contributors'

# -- General configuration ---------------------------------------------------

# Root document of the documentation tree.
master_doc = 'index'

# Sphinx extensions: Sphinx's own `sphinx.ext.*` modules first, then KaTeX math
# rendering and the project-local `coverage_check` extension (found via the
# `ext/` directory added to `sys.path`).
extensions = [
    *(f'sphinx.ext.{name}' for name in (
        'autodoc',
        'autosummary',
        'doctest',
        'inheritance_diagram',
        'intersphinx',
        'linkcode',
        'napoleon',
    )),
    'sphinxcontrib.katex',
    'coverage_check',
]

# Directories (relative to this one) holding custom templates.
templates_path = ['_templates']

# Files and directories, relative to the source directory, that Sphinx skips
# when collecting sources; also applies to html_static_path/html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# -- Options for autodoc -----------------------------------------------------

# Members are listed in source order, special (dunder) members are included,
# except for the ones named under 'exclude-members'.
autodoc_default_options = {
    'member-order': 'bysource',
    'special-members': True,
    'exclude-members': ', '.join(('__repr__', '__str__', '__weakref__')),
}

# -- Options for HTML output -------------------------------------------------

# HTML theme; see the Sphinx documentation for the built-in alternatives.
html_theme = 'sphinx_book_theme'

html_theme_options = dict(
    show_toc_level=2,
    repository_url='https://github.com/google-deepmind/chex',
    use_repository_button=True,  # Adds a "link to repository" button.
    navigation_with_keys=False,
)

# Directories with custom static files (e.g. style sheets), copied after the
# built-in ones; a file named "default.css" would override the built-in one.
# None are used.
html_static_path = []
128 |
# -- Options for katex ------------------------------------------------------

# LaTeX macros shared by the KaTeX (HTML) and LaTeX builders. `\d{x}`
# typesets `x` via `\operatorname`.
# See: https://sphinxcontrib-katex.readthedocs.io/en/0.4.1/macros.html
latex_macros = r"""
\def \d #1{\operatorname{#1}}
"""

# HTML builder: translate the LaTeX macro definitions into KaTeX macros and
# embed them in the KaTeX rendering options.
katex_macros = katex.latex_defs_to_katex_macros(latex_macros)
katex_options = (
    '{displayMode: true, fleqn: true, macros: {%s}}' % katex_macros
)

# LaTeX builder: put the same macro definitions into the preamble.
latex_elements = dict(preamble=latex_macros)
144 |
145 | # -- Source code links -------------------------------------------------------
146 |
147 |
def _github_source_url(relative_path, first_line, last_line):
  """Builds a GitHub URL highlighting lines `first_line..last_line` of a file.

  Args:
    relative_path: Path of the file relative to the `chex` package directory.
    first_line: 1-based number of the first highlighted line.
    last_line: 1-based number of the last highlighted line (inclusive).

  Returns:
    The URL. GitHub expects a line range anchor of the form `#L<a>-L<b>`.
  """
  # TODO(slebedev): support tags after we release an initial version.
  return (
      'https://github.com/google-deepmind/chex/tree/main/chex/'
      f'{relative_path}#L{first_line}-L{last_line}'
  )


def linkcode_resolve(domain, info):
  """Resolves the GitHub URL of a documented Python object.

  Called by `sphinx.ext.linkcode` for every object it documents.

  Args:
    domain: Sphinx domain of the object; only 'py' is supported.
    info: Mapping providing the object's 'module' name and dotted 'fullname'.

  Returns:
    A URL pointing at the object's source lines on GitHub, or None when the
    source cannot be located: unsupported domain, module not imported,
    missing attribute, no retrievable Python source, or source file outside
    the `chex` package.
  """
  if domain != 'py':
    return None

  # `sys.modules` is a dict: an unknown module is a missing key (KeyError),
  # never an ImportError, so look it up without raising.
  mod = sys.modules.get(info['module'])
  if mod is None:
    return None

  obj = mod
  try:
    for attr in info['fullname'].split('.'):
      obj = getattr(obj, attr)
  except AttributeError:
    return None
  # Look through `functools.wraps`-style decorators to the real definition.
  obj = inspect.unwrap(obj)

  try:
    filename = inspect.getsourcefile(obj)
    source, lineno = inspect.getsourcelines(obj)
  except (TypeError, OSError):
    # Built-ins, C extensions and objects whose source is unavailable.
    return None
  if filename is None:
    return None

  relative_path = os.path.relpath(
      filename, start=os.path.dirname(chex.__file__))
  if relative_path == os.pardir or relative_path.startswith(os.pardir + os.sep):
    # Defined outside the chex package: a `chex/..` URL would be broken.
    return None

  return _github_source_url(relative_path, lineno, lineno + len(source) - 1)
186 |
187 |
# -- Intersphinx configuration -----------------------------------------------

# External inventories used to resolve cross-references (JAX documentation).
intersphinx_mapping = dict(
    jax=('https://jax.readthedocs.io/en/latest/', None),
)

# File suffixes Sphinx treats as documentation sources.
# NOTE(review): no Markdown/notebook parser extension appears in `extensions`,
# so it is unclear how '.md' / '.ipynb' sources would be parsed -- confirm
# these suffixes are still needed.
source_suffix = '.rst .md .ipynb'.split()
195 |
--------------------------------------------------------------------------------
/chex/_src/dimensions.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utilities to hold expected dimension sizes."""
16 |
17 | import math
18 | import re
19 | from typing import Any, Collection, Dict, Optional, Sized, Tuple
20 |
21 |
# A (possibly partial) shape: a tuple of dimension sizes in which `None` stands
# for a wildcard dimension whose size is left unspecified.
Shape = Tuple[Optional[int], ...]
23 |
24 |
class Dimensions:
  """A lightweight utility that maps strings to shape tuples.

  In its simplest form, every letter is bound to a dimension size, and a
  string of letters is looked up as the corresponding shape tuple:

  .. code::

    >>> dims = chex.Dimensions(B=3, T=5, N=7)  # You can specify any letters.
    >>> dims['NBT']
    (7, 3, 5)

  This helps keep track of many differently shaped arrays. For instance, the
  shape of this array can be checked as follows:

  .. code::

    >>> x = jnp.array([[2, 0, 5, 6, 3],
    ...                [5, 4, 4, 3, 3],
    ...                [0, 0, 5, 2, 0]])
    >>> chex.assert_shape(x, dims['BT'])

  Individual dimension sizes are available as attributes, e.g.
  :code:`dims.N == 7`. For instance, to one-hot encode our array:

  .. code::

    >>> y = jax.nn.one_hot(x, dims.N)
    >>> chex.assert_shape(y, dims['BTN'])

  The shape of an existing array can also be stored in :code:`dims`, e.g.

  .. code::

    >>> z = jnp.array([[0, 6, 0, 2],
    ...                [4, 2, 2, 4]])
    >>> dims['XY'] = z.shape
    >>> dims
    Dimensions(B=3, N=7, T=5, X=2, Y=4)

  The flat size of a named shape is available as

  .. code::

    >>> dims.size('BT')  # Same as prod(dims['BT']).
    15

  Similarly, axes wrapped in parentheses are flattened together:

  .. code::

    >>> dims['(BT)N']
    (15, 7)

  A dimension may be set to the wildcard value :code:`None`, cf.
  :func:`chex.assert_shape`:

  .. code::

    >>> dims.W = None
    >>> dims['BTW']
    (3, 5, None)

  Alternatively, the wildcard character :code:`'*'` can be used directly:

  .. code::

    >>> dims['BT*']
    (3, 5, None)

  Single digits are interpreted as literal integers. Note that this notation
  is limited to single-digit literals.

  .. code::

    >>> dims['BT123']
    (3, 5, 1, 2, 3)

  Single-digit support mainly exists to accommodate dummy axes introduced for
  consistent broadcasting. For instance, instead of using
  :func:`jax.numpy.expand_dims` you could do the following:

  .. code::

    >>> w = y * x  # Cannot broadcast (3, 5, 7) with (3, 5)
    Traceback (most recent call last):
        ...
    ValueError: Incompatible shapes for broadcasting: ((3, 5, 7), (1, 3, 5))
    >>> w = y * x.reshape(dims['BT1'])
    >>> chex.assert_shape(w, dims['BTN'])

  When only some of an array's dimensions matter, an underscore ignores the
  corresponding axis, e.g.

  .. code::

    >>> chex.assert_rank(y, 3)
    >>> dims['__M'] = y.shape  # Skip the first two axes.

  Finally, note that a single-character key returns a tuple of length one.

  .. code::

    >>> dims['M']
    (7,)
  """
  # Signals static type checkers that attributes are set dynamically, so they
  # should not report attribute errors.
  _HAS_DYNAMIC_ATTRIBUTES = True

  def __init__(self, **dim_sizes) -> None:
    for name in dim_sizes:
      self._setdim(name, dim_sizes[name])

  def size(self, key: str) -> int:
    """Returns the flat size of a given named shape, i.e. prod(shape).

    Args:
      key: A named shape, interpreted exactly as by `__getitem__`.

    Raises:
      ValueError: if the shape contains a wildcard (None) or a dimension whose
        size is <= 0.
    """
    shape = self[key]
    has_bad_dim = any(dim is None or dim <= 0 for dim in shape)
    if has_bad_dim:
      raise ValueError(
          f"cannot take product of shape '{key}' = {shape}, "
          'because it contains non-positive sized dimensions'
      )
    return math.prod(shape)

  def __getitem__(self, key: str) -> Shape:
    """Returns the shape tuple named by `key`.

    Every character outside parentheses contributes one entry: a letter maps
    to its stored size, a digit to its literal value and '*' to None. A
    parenthesized group contributes the single flattened size of the group.

    Raises:
      TypeError: if `key` is not a string.
      KeyError: if a character is not a stored dimension, a digit, '*' or a
        parenthesis.
      ValueError: on nested, unmatched or empty parentheses, or when a group
        cannot be flattened (see `size`).
    """
    self._validate_key(key)
    shape = []
    group = None  # Characters seen inside an open '(' ... ')', else None.
    for char in key:
      if char == '(':
        if group is not None:
          raise ValueError(f"nested parentheses are unsupported; got: '{key}'")
        group = ''
      elif char == ')':
        if group is None:
          raise ValueError(f"unmatched parentheses in named shape: '{key}'")
        if not group:
          raise ValueError(f"found empty parentheses in named shape: '{key}'")
        shape.append(self.size(group))
        group = None
      elif group is not None:
        group += char
      else:
        shape.append(self._getdim(char))
    if group is not None:
      raise ValueError(f"unmatched parentheses in named shape: '{key}'")
    return tuple(shape)

  def __setitem__(self, key: str, value: Collection[Optional[int]]) -> None:
    """Binds each character of `key` to the matching entry of `value`.

    An underscore in `key` skips the corresponding entry of `value`.

    Raises:
      TypeError: if `key` is not a string, `value` has no length, or an entry
        is not an integral value.
      KeyError: if a character of `key` is neither a letter nor '_'.
      ValueError: if `key` and `value` differ in length.
    """
    self._validate_key(key)
    self._validate_value(value)
    if len(key) == len(value):
      for char, size in zip(key, value):
        self._setdim(char, size)
    else:
      raise ValueError(
          f'key string {key!r} and shape {tuple(value)} have different lengths')

  def __delitem__(self, key: str) -> None:
    """Removes the size stored for every letter of `key` ('_' is skipped)."""
    self._validate_key(key)
    for char in key:
      self._deldim(char)

  def __repr__(self) -> str:
    fields = sorted(self._asdict().items())
    joined = ', '.join(f'{name}={size}' for name, size in fields)
    return f'{type(self).__name__}({joined})'

  def _asdict(self) -> Dict[str, Optional[int]]:
    """Returns the stored dimensions, i.e. single-letter instance attributes."""
    result = {}
    for name, size in self.__dict__.items():
      if re.fullmatch(r'[a-zA-Z]', name):
        result[name] = size
    return result

  def _getdim(self, dim: str) -> Optional[int]:
    """Resolves one key character to a dimension size (or None for '*')."""
    if dim == '*':  # Wildcard.
      return None
    if re.fullmatch(r'[0-9]', dim):  # Single-digit literal.
      return int(dim)
    try:
      size = getattr(self, dim)
    except AttributeError as e:
      raise KeyError(dim) from e
    return size

  def _setdim(self, dim: str, size: Optional[int]) -> None:
    """Stores `size` under the letter `dim`; '_' is silently skipped."""
    if dim != '_':
      self._validate_dim(dim)
      setattr(self, dim, _optional_int(size))

  def _deldim(self, dim: str) -> None:
    """Deletes the size stored under the letter `dim`; '_' is skipped."""
    if dim != '_':
      self._validate_dim(dim)
      try:
        delattr(self, dim)
      except AttributeError as e:
        raise KeyError(dim) from e

  def _validate_key(self, key: Any) -> None:
    if isinstance(key, str):
      return
    raise TypeError(f'key must be a string; got: {type(key).__name__}')

  def _validate_value(self, value: Any) -> None:
    if isinstance(value, Sized):
      return
    raise TypeError(
        'value must be sized, i.e. an object with a well-defined len(value); '
        f'got object of type: {type(value).__name__}')

  def _validate_dim(self, dim: Any) -> None:
    if not isinstance(dim, str):
      raise TypeError(
          f'dimension name must be a string; got: {type(dim).__name__}')
    if re.fullmatch(r'[a-zA-Z]', dim) is None:
      raise KeyError(
          "dimension names may only be contain letters (or '_' to skip); "
          f'got dimension name: {dim!r}')
248 |
def _optional_int(x: Any) -> Optional[int]:
  """Converts `x` to a python int, passing `None` through unchanged.

  Args:
    x: `None`, or a value equal to its own `int(x)` conversion (e.g. `3` or
      `3.0`).

  Returns:
    `None` if `x` is `None`, otherwise `int(x)`.

  Raises:
    TypeError: if `x` cannot be interpreted as an int without loss, e.g.
      `2.5`, `'3'`, `float('nan')` or `float('inf')`. When the conversion
      itself failed, the original exception is attached as the cause.
  """
  if x is None:
    return None
  try:
    i = int(x)
    # Evaluated inside the `try`: for array-likes, the truth value of `x == i`
    # may itself raise ValueError.
    is_exact = bool(x == i)
  except (ValueError, OverflowError) as e:
    # int(float('nan')) -> ValueError, int(float('inf')) -> OverflowError:
    # report both with the same TypeError as any other lossy value.
    raise TypeError(
        f'object cannot be interpreted as a python int: {repr(x)}') from e
  if not is_exact:
    raise TypeError(f'object cannot be interpreted as a python int: {repr(x)}')
  return i
259 |
--------------------------------------------------------------------------------
/chex/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Chex: Testing made fun, in JAX!"""
16 |
17 | from chex._src.asserts import assert_axis_dimension
18 | from chex._src.asserts import assert_axis_dimension_comparator
19 | from chex._src.asserts import assert_axis_dimension_gt
20 | from chex._src.asserts import assert_axis_dimension_gteq
21 | from chex._src.asserts import assert_axis_dimension_lt
22 | from chex._src.asserts import assert_axis_dimension_lteq
23 | from chex._src.asserts import assert_devices_available
24 | from chex._src.asserts import assert_equal
25 | from chex._src.asserts import assert_equal_rank
26 | from chex._src.asserts import assert_equal_shape
27 | from chex._src.asserts import assert_equal_shape_prefix
28 | from chex._src.asserts import assert_equal_shape_suffix
29 | from chex._src.asserts import assert_equal_size
30 | from chex._src.asserts import assert_exactly_one_is_none
31 | from chex._src.asserts import assert_gpu_available
32 | from chex._src.asserts import assert_is_broadcastable
33 | from chex._src.asserts import assert_is_divisible
34 | from chex._src.asserts import assert_max_traces
35 | from chex._src.asserts import assert_not_both_none
36 | from chex._src.asserts import assert_numerical_grads
37 | from chex._src.asserts import assert_rank
38 | from chex._src.asserts import assert_scalar
39 | from chex._src.asserts import assert_scalar_in
40 | from chex._src.asserts import assert_scalar_negative
41 | from chex._src.asserts import assert_scalar_non_negative
42 | from chex._src.asserts import assert_scalar_positive
43 | from chex._src.asserts import assert_shape
44 | from chex._src.asserts import assert_size
45 | from chex._src.asserts import assert_tpu_available
46 | from chex._src.asserts import assert_tree_all_finite
47 | from chex._src.asserts import assert_tree_has_only_ndarrays
48 | from chex._src.asserts import assert_tree_is_on_device
49 | from chex._src.asserts import assert_tree_is_on_host
50 | from chex._src.asserts import assert_tree_is_sharded
51 | from chex._src.asserts import assert_tree_no_nones
52 | from chex._src.asserts import assert_tree_shape_prefix
53 | from chex._src.asserts import assert_tree_shape_suffix
54 | from chex._src.asserts import assert_trees_all_close
55 | from chex._src.asserts import assert_trees_all_close_ulp
56 | from chex._src.asserts import assert_trees_all_equal
57 | from chex._src.asserts import assert_trees_all_equal_comparator
58 | from chex._src.asserts import assert_trees_all_equal_dtypes
59 | from chex._src.asserts import assert_trees_all_equal_shapes
60 | from chex._src.asserts import assert_trees_all_equal_shapes_and_dtypes
61 | from chex._src.asserts import assert_trees_all_equal_sizes
62 | from chex._src.asserts import assert_trees_all_equal_structs
63 | from chex._src.asserts import assert_type
64 | from chex._src.asserts import clear_trace_counter
65 | from chex._src.asserts import disable_asserts
66 | from chex._src.asserts import enable_asserts
67 | from chex._src.asserts import if_args_not_none
68 | from chex._src.asserts_chexify import block_until_chexify_assertions_complete
69 | from chex._src.asserts_chexify import chexify
70 | from chex._src.asserts_chexify import ChexifyChecks
71 | from chex._src.asserts_chexify import with_jittable_assertions
72 | from chex._src.dataclass import dataclass
73 | from chex._src.dataclass import mappable_dataclass
74 | from chex._src.dataclass import register_dataclass_type_with_jax_tree_util
75 | from chex._src.dimensions import Dimensions
76 | from chex._src.fake import fake_jit
77 | from chex._src.fake import fake_pmap
78 | from chex._src.fake import fake_pmap_and_jit
79 | from chex._src.fake import set_n_cpu_devices
80 | from chex._src.pytypes import Array
81 | from chex._src.pytypes import ArrayBatched
82 | from chex._src.pytypes import ArrayDevice
83 | from chex._src.pytypes import ArrayDeviceTree
84 | from chex._src.pytypes import ArrayDType
85 | from chex._src.pytypes import ArrayNumpy
86 | from chex._src.pytypes import ArrayNumpyTree
87 | from chex._src.pytypes import ArraySharded
88 | from chex._src.pytypes import ArrayTree
89 | from chex._src.pytypes import Device
90 | from chex._src.pytypes import Numeric
91 | from chex._src.pytypes import PRNGKey
92 | from chex._src.pytypes import PyTreeDef
93 | from chex._src.pytypes import Scalar
94 | from chex._src.pytypes import Shape
95 | from chex._src.restrict_backends import restrict_backends
96 | from chex._src.variants import all_variants
97 | from chex._src.variants import ChexVariantType
98 | from chex._src.variants import params_product
99 | from chex._src.variants import TestCase
100 | from chex._src.variants import variants
101 | from chex._src.warnings import create_deprecated_function_alias
102 | from chex._src.warnings import warn_deprecated_function
103 | from chex._src.warnings import warn_keyword_args_only_in_future
104 | from chex._src.warnings import warn_only_n_pos_args_in_future
105 |
106 |
107 | __version__ = "0.1.86"
108 |
109 | __all__ = (
110 | "all_variants",
111 | "Array",
112 | "ArrayBatched",
113 | "ArrayDevice",
114 | "ArrayDeviceTree",
115 | "ArrayDType",
116 | "ArrayNumpy",
117 | "ArrayNumpyTree",
118 | "ArraySharded",
119 | "ArrayTree",
120 | "ChexifyChecks",
121 | "assert_axis_dimension",
122 | "assert_axis_dimension_comparator",
123 | "assert_axis_dimension_gt",
124 | "assert_axis_dimension_gteq",
125 | "assert_axis_dimension_lt",
126 | "assert_axis_dimension_lteq",
127 | "assert_devices_available",
128 | "assert_equal",
129 | "assert_equal_rank",
130 | "assert_equal_shape",
131 | "assert_equal_shape_prefix",
132 | "assert_equal_shape_suffix",
133 | "assert_equal_size",
134 | "assert_exactly_one_is_none",
135 | "assert_gpu_available",
136 | "assert_is_broadcastable",
137 | "assert_is_divisible",
138 | "assert_max_traces",
139 | "assert_not_both_none",
140 | "assert_numerical_grads",
141 | "assert_rank",
142 | "assert_scalar",
143 | "assert_scalar_in",
144 | "assert_scalar_negative",
145 | "assert_scalar_non_negative",
146 | "assert_scalar_positive",
147 | "assert_shape",
148 | "assert_size",
149 | "assert_tpu_available",
150 | "assert_tree_all_finite",
151 | "assert_tree_has_only_ndarrays",
152 | "assert_tree_is_on_device",
153 | "assert_tree_is_on_host",
154 | "assert_tree_is_sharded",
155 | "assert_tree_no_nones",
156 | "assert_tree_shape_prefix",
157 | "assert_tree_shape_suffix",
158 | "assert_trees_all_close",
159 | "assert_trees_all_close_ulp",
160 | "assert_trees_all_equal",
161 | "assert_trees_all_equal_comparator",
162 | "assert_trees_all_equal_dtypes",
163 | "assert_trees_all_equal_shapes",
164 | "assert_trees_all_equal_shapes_and_dtypes",
165 | "assert_trees_all_equal_sizes",
166 | "assert_trees_all_equal_structs",
167 | "assert_type",
168 | "block_until_chexify_assertions_complete",
169 | "chexify",
170 | "ChexVariantType",
171 | "clear_trace_counter",
172 | "create_deprecated_function_alias",
173 | "dataclass",
174 | "Device",
175 | "Dimensions",
176 | "disable_asserts",
177 | "enable_asserts",
178 | "fake_jit",
179 | "fake_pmap",
180 | "fake_pmap_and_jit",
181 | "if_args_not_none",
182 | "mappable_dataclass",
183 | "Numeric",
184 | "params_product",
185 | "PRNGKey",
186 | "PyTreeDef",
187 | "register_dataclass_type_with_jax_tree_util",
188 | "restrict_backends",
189 | "Scalar",
190 | "set_n_cpu_devices",
191 | "Shape",
192 | "TestCase",
193 | "variants",
194 | "warn_deprecated_function",
195 | "warn_keyword_args_only_in_future",
196 | "warn_only_n_pos_args_in_future",
197 | "with_jittable_assertions",
198 | )
199 |
200 |
201 | # _________________________________________
202 | # / Please don't use symbols in `_src` they \
203 | # \ are not part of the Chex public API. /
204 | # -----------------------------------------
205 | # \ ^__^
206 | # \ (oo)\_______
207 | # (__)\ )\/\
208 | # ||----w |
209 | # || ||
210 | #
211 |
--------------------------------------------------------------------------------
/chex/_src/asserts_chexify.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Chexification utilities."""
16 |
17 | import atexit
18 | import collections
19 | from concurrent import futures
20 | import dataclasses
21 | import functools
22 | import re
23 | from typing import Any, Callable, FrozenSet
24 |
25 | from absl import logging
26 | from chex._src import asserts_internal as _ai
27 | import jax
28 | from jax.experimental import checkify
29 |
30 |
@dataclasses.dataclass(frozen=True)
class _ChexifyChecks:
  """Named groups of `checkify` error categories accepted by `chexify`.

  Every field defaults to the identically named `checkify.<field>_checks` set;
  sets can be combined with set operators, e.g. `float | user`.
  """

  # Explicit checks only (the `chexify` default).
  user: FrozenSet[checkify.ErrorCategory] = dataclasses.field(
      default=checkify.user_checks)
  # NaN-producing operations.
  nan: FrozenSet[checkify.ErrorCategory] = dataclasses.field(
      default=checkify.nan_checks)
  # Out-of-bounds indexing.
  index: FrozenSet[checkify.ErrorCategory] = dataclasses.field(
      default=checkify.index_checks)
  # Division by zero.
  div: FrozenSet[checkify.ErrorCategory] = dataclasses.field(
      default=checkify.div_checks)
  # Floating-point checks.
  float: FrozenSet[checkify.ErrorCategory] = dataclasses.field(
      default=checkify.float_checks)
  # Every automatically inserted (non-user) check.
  automatic: FrozenSet[checkify.ErrorCategory] = dataclasses.field(
      default=checkify.automatic_checks)
  # All available checks.
  all: FrozenSet[checkify.ErrorCategory] = dataclasses.field(
      default=checkify.all_checks)
42 |
43 |
# Regex recognizing error messages produced by chex assertions under
# `chexify`: the message template is rendered with 'ANY' placeholders,
# regex-escaped, and each placeholder is then replaced by the `.*` wildcard.
_chexify_error_pattern = re.compile(
    re.escape(_ai.get_chexify_err_message('ANY', 'ANY')).replace('ANY', '.*')
)
47 |
48 |
def _check_error(err: checkify.Error) -> None:
  """Raises the error held by `err`, translating chex assertion failures.

  A `ValueError` whose message matches `_chexify_error_pattern` is re-raised
  as an `AssertionError`, with everything from the last '(check failed at'
  (an internal code pointer) stripped. Any other `ValueError` propagates
  unchanged.
  """
  try:
    checkify.check_error(err)
  except ValueError as exc:
    msg = str(exc)
    if not _chexify_error_pattern.match(msg):
      raise
    # Remove internal code pointers.
    cut = msg.rfind('(check failed at')
    if cut >= 0:
      msg = msg[:cut]
    raise AssertionError(msg)  # pylint:disable=raise-missing-from
63 |
64 |
def block_until_chexify_assertions_complete() -> None:
  """Waits until all asynchronous checks complete.

  Calls, in registration order, the wait callback that `chexify` registered
  for every wrapped function. A failed check raised by one callback
  propagates immediately. See `chexify` for more detail.
  """
  registered_waits = _ai.CHEXIFY_STORAGE.wait_fns
  for wait_checks in registered_waits:
    wait_checks()
72 |
73 |
@atexit.register  # to catch uninspected error stats
def _check_if_hanging_assertions():
  """At interpreter exit, reports chexify checks that were never inspected.

  Logs a warning and then blocks on every registered wait callback, so that
  pending assertion failures are still reported.
  """
  # NOTE(review): `chexify` appends a callback to `wait_fns` on every call and
  # nothing in this module removes it, so this condition appears to hold
  # whenever any function was chexified -- even if all of its checks were
  # already inspected. Confirm against `asserts_internal`.
  if _ai.CHEXIFY_STORAGE.wait_fns:
    logging.warning(
        '[Chex] Some of chexify assertion statuses were not inspected due to '
        'async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html).'
        ' Consider calling `chex.block_until_chexify_assertions_complete()` at '
        'the end of computations that rely on jitted chex assertions.')
    block_until_chexify_assertions_complete()
83 |
84 |
# Public API: the default `_ChexifyChecks` instance. Its fields (`user`, `nan`,
# `index`, `div`, `float`, `automatic`, `all`) are the check groups accepted by
# `chexify(errors=...)`, e.g. `ChexifyChecks.float | ChexifyChecks.user`.
ChexifyChecks = _ChexifyChecks()
87 |
88 |
def chexify(
    fn: Callable[..., Any],
    async_check: bool = True,
    errors: FrozenSet[checkify.ErrorCategory] = ChexifyChecks.user,
) -> Callable[..., Any]:
  """Wraps a transformed function `fn` to enable Chex value assertions.

  Chex value/runtime assertions access concrete values of tensors (e.g.
  `assert_tree_all_finite`) which are not available during JAX tracing, see
  https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
  and
  https://jax.readthedocs.io/en/latest/_modules/jax/_src/errors.html#ConcretizationTypeError.

  This wrapper enables them in jitted/pmapped functions by performing a
  specifically designed JAX transformation
  https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#the-checkify-transformation
  and calling functionalised checks
  https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.checkify.check.html

  Example:

  .. code::

    @chex.chexify
    @jax.jit
    def logp1_abs_safe(x: chex.Array) -> chex.Array:
      chex.assert_tree_all_finite(x)
      return jnp.log(jnp.abs(x) + 1)

    logp1_abs_safe(jnp.ones(2))  # OK
    logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS
    logp1_abs_safe.wait_checks()

  Note 1: This wrapper allows identifying the first failed assertion in a jitted
  code by printing a pointer to the line where the failed assertion was invoked.
  For getting verbose messages (including concrete tensor values), an unjitted
  version of the code will need to be executed with the same input values. Chex
  does not currently provide tools to help with this.

  Note 2: This wrapper fully supports asynchronous executions
  (see https://jax.readthedocs.io/en/latest/async_dispatch.html).
  To block program execution until asynchronous checks for a _chexified_
  function `fn` complete, call `fn.wait_checks()`. Similarly,
  `chex.block_until_chexify_assertions_complete()` will block program execution
  until _all_ asynchronous checks complete.

  Note 3: Chex automatically selects the backend for executing its assertions
  (i.e. CPU or device accelerator) depending on the program context.

  Note 4: Value assertions can have impact on the performance of a function, see
  https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#limitations

  Note 5: static assertions, such as `assert_shape` or
  `assert_trees_all_equal_dtypes`, can be called from a jitted function without
  `chexify` wrapper (since they do not access concrete values, only
  shapes and/or dtypes which are available during JAX tracing).

  More examples can be found at
  https://github.com/google-deepmind/chex/blob/main/chex/_src/asserts_chexify_test.py

  Args:
    fn: A transformed function to wrap. Any callable is accepted, including
      `functools.partial` objects and callables without a `__name__`.
    async_check: Whether to check errors in the async dispatch mode. See
      https://jax.readthedocs.io/en/latest/async_dispatch.html.
    errors: A set of `checkify.ErrorCategory` values which defines the set of
      enabled checks. By default only explicit ``checks`` are enabled (`user`).
      You can also for example enable NaN and Div-by-0 errors by passing the
      `float` set, or for example combine multiple sets through set
      operations (`float | user`).

  Returns:
    A _chexified_ function, i.e. the one with enabled value assertions.
    The returned function has `wait_checks()` method that blocks the caller
    until all pending async checks complete.
  """
  # Hardware/XLA failures can only happen on the C++ side. They are expected to
  # issue critical errors that will immediately crash the whole program.
  # Nevertheless, Chex sets its own timeout for every chexified XLA comp. to
  # ensure that a program never blocks on Chex side when running in async mode.
  async_timeout = 1800  # 30 minutes

  # Get function name (used to name the async-check thread and in warnings).
  # `functools.partial` objects carry no `__name__`, so look at the wrapped
  # function; any other callable lacking `__name__` (e.g. an instance of a
  # class defining `__call__`) falls back to its type name instead of raising
  # AttributeError.
  target = fn.func if isinstance(fn, functools.partial) else fn
  func_name = getattr(target, '__name__', type(target).__name__)

  if async_check:
    # Spawn a thread for processing blocking calls.
    thread_pool = futures.ThreadPoolExecutor(1, f'async_chex_{func_name}')
    # A deque for futures.
    async_check_futures = collections.deque()

  # Checkification.
  checkified_fn = checkify.checkify(fn, errors=errors)

  @functools.wraps(fn)
  def _chexified_fn(*args, **kwargs):
    if _ai.CHEXIFY_STORAGE.level:
      raise RuntimeError(
          'Nested @chexify wrapping is disallowed. '
          'Make sure that you only wrap the function at the outermost level.')

    if _ai.has_tracers((args, kwargs)):
      raise RuntimeError(
          '@chexify must be applied on top of all (p)jit/pmap transformations'
          ' (otherwise it will result in `UnexpectedTracerError`). If you have'
          ' functions that use value assertions, do not wrap them'
          ' individually -- just wrap the outermost function after'
          ' applying all your JAX transformations. See the example at '
          'https://github.com/google-deepmind/chex#static-and-value-aka-runtime-assertions'  # pylint:disable=line-too-long
      )

    if async_check:
      # Check completed calls.
      while async_check_futures and async_check_futures[0].done():
        _check_error(async_check_futures.popleft().result(async_timeout))

    # Run the checkified function.
    _ai.CHEXIFY_STORAGE.level += 1
    try:
      err, out = checkified_fn(*args, **kwargs)
    finally:
      _ai.CHEXIFY_STORAGE.level -= 1

    # Check errors.
    if async_check:
      # Blocking call is deferred to the thread.
      async_check_futures.append(
          thread_pool.submit(lambda: jax.device_get(err)))
    else:
      # Blocks until `fn`'s outputs are ready.
      _check_error(err)

    return out

  def _wait_checks():
    if async_check:
      while async_check_futures:
        _check_error(async_check_futures.popleft().result(async_timeout))

  # Add a barrier callback to the global storage.
  _ai.CHEXIFY_STORAGE.wait_fns.append(_wait_checks)

  # Add the callback to the chexified function's properties.
  if not hasattr(_chexified_fn, 'wait_checks'):
    _chexified_fn.wait_checks = _wait_checks
  else:
    logging.warning(
        "Function %s already defines 'wait_checks' method; "
        'Chex will not redefine it.', func_name)

  return _chexified_fn
242 |
243 |
def with_jittable_assertions(fn: Callable[..., Any],
                             async_check: bool = True) -> Callable[..., Any]:
  """An alias for `chexify` (see the docs).

  Only `fn` and `async_check` are forwarded; `chexify`'s `errors` argument
  keeps its default value.
  """
  return chexify(fn, async_check=async_check)
248 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/chex/_src/dataclass.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """JAX/dm-tree friendly dataclass implementation reusing Python dataclasses."""
16 |
17 | import collections
18 | import dataclasses
19 | import functools
20 | import sys
21 |
22 | from absl import logging
23 | import jax
24 | from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet
25 |
26 |
27 | FrozenInstanceError = dataclasses.FrozenInstanceError
28 | _RESERVED_DCLS_FIELD_NAMES = frozenset(("from_tuple", "replace", "to_tuple"))
29 |
30 |
def mappable_dataclass(cls):
  """Exposes dataclass as ``collections.abc.Mapping`` descendent.

  Allows to traverse dataclasses in methods from `dm-tree` library.

  NOTE: changes dataclasses constructor to dict-type
  (i.e. positional args aren't supported; however can use generators/iterables).

  Args:
    cls: A dataclass to mutate. Mapping methods and a dict-style ``__init__``
      are set on it in place.

  Returns:
    A new class, built from the (mutated) namespace of ``cls``, that has
    ``collections.abc.Mapping`` among its bases. Note that this is a different
    class object from ``cls``.

  Raises:
    ValueError: If ``cls`` is not a dataclass. The returned class's constructor
      also raises ValueError when given more than one positional argument, a
      positional argument together with keyword arguments, or names that are
      not dataclass fields.
  """
  if not dataclasses.is_dataclass(cls):
    raise ValueError(f"Expected dataclass, got {cls} (change wrappers order?).")

  # Define methods for compatibility with `collections.abc.Mapping`; the
  # mapping view is backed by the instance `__dict__`.
  setattr(cls, "__getitem__", lambda self, x: self.__dict__[x])
  setattr(cls, "__len__", lambda self: len(self.__dict__))
  setattr(cls, "__iter__", lambda self: iter(self.__dict__))
  # Override the default `collections.abc.Mapping` method implementation for
  # cleaner visualization. Without this change x.keys() shows the full repr(x)
  # instead of only the dict_keys present. The same goes for values and items.
  setattr(cls, "keys", lambda self: self.__dict__.keys())
  setattr(cls, "values", lambda self: self.__dict__.values())
  setattr(cls, "items", lambda self: self.__dict__.items())

  # Update constructor.
  orig_init = cls.__init__
  dataclass_fields = cls.__dataclass_fields__.values()
  all_fields = {f.name for f in dataclass_fields}
  init_fields = {f.name for f in dataclass_fields if f.init}

  @functools.wraps(orig_init)
  def new_init(self, *orig_args, **orig_kwargs):
    # Accepts what `dict(...)` accepts, except mixing a positional mapping or
    # iterable with keyword arguments.
    if (orig_args and orig_kwargs) or len(orig_args) > 1:
      raise ValueError(
          "Mappable dataclass constructor doesn't support positional args. "
          "(it has the same constructor as python dict)")
    all_kwargs = dict(*orig_args, **orig_kwargs)
    unknown_kwargs = set(all_kwargs.keys()) - all_fields
    if unknown_kwargs:
      raise ValueError(f"__init__() got unexpected kwargs: {unknown_kwargs}.")

    # Pass only arguments corresponding to fields with `init=True`.
    valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields}
    orig_init(self, **valid_kwargs)

  cls.__init__ = new_init

  # Update base class to derive from Mapping.
  dct = dict(cls.__dict__)
  dct.pop("__dict__", None)  # Avoid self-references.

  # Remove object from the sequence of base classes. Deriving from both Mapping
  # and object will cause a failure to create a MRO for the updated class.
  bases = tuple(b for b in cls.__bases__ if b != object)
  return type(cls.__name__, bases + (collections.abc.Mapping,), dct)
91 |
92 |
@dataclass_transform()
def dataclass(
    cls=None,
    *,
    init=True,
    repr=True,  # pylint: disable=redefined-builtin
    eq=True,
    order=False,
    unsafe_hash=False,
    frozen=False,
    kw_only: bool = False,
    mappable_dataclass=True,  # pylint: disable=redefined-outer-name
):
  """JAX-friendly drop-in for :py:func:`dataclasses.dataclass`.

  Decorated classes are registered with JAX so that tree utilities operate on
  their fields. They also gain ``replace`` (convenient for instances created
  with ``frozen=True``), ``to_tuple`` and ``from_tuple`` helpers.

  Usable bare (``@chex.dataclass``) or with options
  (``@chex.dataclass(frozen=True)``).

  Args:
    cls: The class to decorate, or ``None`` when the decorator is called with
      options only.
    init: See :py:func:`dataclasses.dataclass`.
    repr: See :py:func:`dataclasses.dataclass`.
    eq: See :py:func:`dataclasses.dataclass`.
    order: See :py:func:`dataclasses.dataclass`.
    unsafe_hash: See :py:func:`dataclasses.dataclass`.
    frozen: See :py:func:`dataclasses.dataclass`.
    kw_only: See :py:func:`dataclasses.dataclass`.
    mappable_dataclass: If True (the default), methods to make the class
      implement the :py:class:`collections.abc.Mapping` interface will be
      generated and the class will include :py:class:`collections.abc.Mapping`
      in its base classes.
      `True` is the default, because being an instance of `Mapping` makes
      `chex.dataclass` compatible with e.g. `jax.tree_util.tree_*` methods, the
      `tree` library, or methods related to tensorflow/python/utils/nest.py.
      As a side-effect, e.g. `np.testing.assert_array_equal` will only check
      the field names are equal and not the content. Use `chex.assert_tree_*`
      instead.

  Returns:
    The JAX-friendly dataclass or, when ``cls`` is ``None``, a decorator that
    produces one.
  """

  def _wrap(target_cls):
    # Build a fresh `_Dataclass` for every decorated class so that no wrapper
    # state is shared between classes.
    wrapper = _Dataclass(init, repr, eq, order, unsafe_hash, frozen, kw_only,
                         mappable_dataclass)
    return wrapper(target_cls)

  if cls is not None:
    return _wrap(cls)
  return _wrap
144 |
145 |
class _Dataclass():
  """JAX-friendly wrapper for `dataclasses.dataclass`.

  Holds the options of one `chex.dataclass` decoration. Calling the instance on
  a class turns it into a dataclass, optionally makes it a
  ``collections.abc.Mapping``, attaches the chex helper methods and registers
  it as a JAX pytree node.
  """

  def __init__(
      self,
      init=True,
      repr=True,  # pylint: disable=redefined-builtin
      eq=True,
      order=False,
      unsafe_hash=False,
      frozen=False,
      kw_only=False,
      mappable_dataclass=True,  # pylint: disable=redefined-outer-name
  ):
    self.init = init
    self.repr = repr  # pylint: disable=redefined-builtin
    self.eq = eq
    self.order = order
    self.unsafe_hash = unsafe_hash
    self.frozen = frozen
    self.kw_only = kw_only
    self.mappable_dataclass = mappable_dataclass

  def __call__(self, cls):
    """Forwards class to dataclasses's wrapper and registers it with JAX.

    Args:
      cls: The class to convert.

    Returns:
      The resulting dataclass. When ``mappable_dataclass`` is set, this is a
      new class object built from ``cls`` rather than ``cls`` itself.

    Raises:
      TypeError: If a non-frozen dataclass would inherit from a frozen one.
      ValueError: If a field name is one of ``_RESERVED_DCLS_FIELD_NAMES``.
    """

    # Remove once https://github.com/python/cpython/pull/24484 is merged.
    for base in cls.__bases__:
      if (dataclasses.is_dataclass(base) and
          getattr(base, "__dataclass_params__").frozen and not self.frozen):
        raise TypeError("cannot inherit non-frozen dataclass from a frozen one")

    # `kw_only` is only available starting from 3.10.
    version_dependent_args = {}
    if sys.version_info >= (3, 10):
      version_dependent_args = {"kw_only": self.kw_only}
    # pytype: disable=wrong-keyword-args
    dcls = dataclasses.dataclass(
        cls,
        init=self.init,
        repr=self.repr,
        eq=self.eq,
        order=self.order,
        unsafe_hash=self.unsafe_hash,
        frozen=self.frozen,
        **version_dependent_args,
    )
    # pytype: enable=wrong-keyword-args

    fields_names = set(f.name for f in dataclasses.fields(dcls))
    invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES)
    if invalid_fields:
      raise ValueError(f"The following dataclass fields are disallowed: "
                       f"{invalid_fields} ({dcls}).")

    if self.mappable_dataclass:
      dcls = mappable_dataclass(dcls)

    def _from_tuple(args):
      # Pair values with field names in declaration order and forward only the
      # fields accepted by `__init__`. Passing keyword arguments (instead of
      # the `zip` iterable that only the mappable, dict-style constructor
      # understands) keeps this working when `mappable_dataclass=False`.
      fields = dcls.__dataclass_fields__
      init_kwargs = {
          name: value
          for name, value in zip(fields.keys(), args)
          if fields[name].init
      }
      return dcls(**init_kwargs)

    def _to_tuple(self):
      return tuple(getattr(self, k) for k in self.__dataclass_fields__.keys())

    def _replace(self, **kwargs):
      return dataclasses.replace(self, **kwargs)

    def _getstate(self):
      return self.__dict__

    # Register the dataclass at definition. As long as the dataclass is defined
    # outside __main__, this is sufficient to make JAX's PyTree registry
    # recognize the dataclass and the dataclass' custom PyTreeDef, especially
    # when unpickling either the dataclass object, its type, or its PyTreeDef,
    # in a different process, because the defining module will be imported.
    #
    # However, if the dataclass is defined in __main__, unpickling in a
    # subprocess does not trigger re-registration. Therefore we also need to
    # register when deserializing the object, or construction (e.g. when the
    # dataclass type is being unpickled). Unfortunately, there is not yet a way
    # to trigger re-registration when the treedef is unpickled as that's handled
    # by JAX.
    #
    # See internal dataclass_test for unit tests demonstrating the problems.
    # For dataclasses defined in __main__, the registration below may result
    # in pickling failures of the form
    #   _pickle.PicklingError: Can't pickle <...>: it's not the same object as
    #   register_dataclass_type_with_jax_tree_util
    # so we skip registration at definition time in that case.
    if dcls.__module__ != "__main__":
      register_dataclass_type_with_jax_tree_util(dcls)

    # Patch __setstate__ to register the dataclass on deserialization.
    def _setstate(self, state):
      register_dataclass_type_with_jax_tree_util(dcls)
      self.__dict__.update(state)

    orig_init = dcls.__init__

    # Patch __init__ such that the dataclass is registered on creation if it is
    # not registered on deserialization.
    @functools.wraps(orig_init)
    def _init(self, *args, **kwargs):
      register_dataclass_type_with_jax_tree_util(dcls)
      return orig_init(self, *args, **kwargs)

    setattr(dcls, "from_tuple", _from_tuple)
    setattr(dcls, "to_tuple", _to_tuple)
    setattr(dcls, "replace", _replace)
    setattr(dcls, "__getstate__", _getstate)
    setattr(dcls, "__setstate__", _setstate)
    setattr(dcls, "__init__", _init)

    return dcls
260 |
261 |
262 | def _dataclass_unflatten(dcls, keys, values):
263 | """Creates a chex dataclass from a flatten jax.tree_util representation."""
264 | dcls_object = dcls.__new__(dcls)
265 | attribute_dict = dict(zip(keys, values))
266 | # Looping over fields instead of keys & values preserves the field order.
267 | # Using dataclasses.fields fails because dataclass uids change after
268 | # serialisation (eg, with cloudpickle).
269 | for field in dcls.__dataclass_fields__.values():
270 | if field.name in attribute_dict: # Filter pseudo-fields.
271 | object.__setattr__(dcls_object, field.name, attribute_dict[field.name])
272 | # Need to manual call post_init here as we have avoided calling __init__
273 | if getattr(dcls_object, "__post_init__", None):
274 | dcls_object.__post_init__()
275 | return dcls_object
276 |
277 |
278 | def _flatten_with_path(dcls):
279 | path = []
280 | keys = []
281 | for k, v in sorted(dcls.__dict__.items()):
282 | k = jax.tree_util.GetAttrKey(k)
283 | path.append((k, v))
284 | keys.append(k)
285 | return path, keys
286 |
287 |
@functools.cache
def register_dataclass_type_with_jax_tree_util(data_class):
  """Register an existing dataclass so JAX knows how to handle it.

  This means that functions in jax.tree_util operate over the fields
  of the dataclass. See
  https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees
  for further information.

  Children are the instance attributes from ``__dict__``, sorted by name. The
  auxiliary data is the matching tuple of attribute names. Instances are
  rebuilt by ``_dataclass_unflatten``, which bypasses ``__init__`` (it uses
  ``__new__`` and sets the fields directly) and then runs ``__post_init__`` if
  defined.

  The call is memoized with ``functools.cache``, so repeated calls for the same
  class are no-ops. If JAX reports the type as already registered
  (``ValueError``), this is logged rather than raised.

  Args:
    data_class: A class created using dataclasses.dataclass.
  """

  def flatten(instance):
    # Same ordering as `_flatten_with_path`: attributes sorted by name.
    # Built explicitly rather than via `jax.util.unzip2`, which is not part of
    # JAX's stable public API.
    items = sorted(instance.__dict__.items())
    values = tuple(value for _, value in items)
    names = tuple(name for name, _ in items)
    return values, names

  unflatten = functools.partial(_dataclass_unflatten, data_class)
  try:
    jax.tree_util.register_pytree_with_keys(
        nodetype=data_class, flatten_with_keys=_flatten_with_path,
        flatten_func=flatten, unflatten_func=unflatten)
  except ValueError:
    logging.info("%s is already registered as JAX PyTree node.", data_class)
310 |
--------------------------------------------------------------------------------
/chex/_src/fake.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utilities to patch JAX functions with faked implementations.
16 |
17 | This module provides fake implementations of jax.jit and jax.pmap, which can be
18 | patched over existing implementations for easier debugging.
19 |
20 | See https://www.martinfowler.com/articles/mocksArentStubs.html
21 | """
22 |
23 | import contextlib
24 | import functools
25 | import inspect
26 | import os
27 | import re
28 | from typing import Any, Callable, Iterable, Optional, Union
29 | from unittest import mock
30 | from absl import flags
31 | import jax
32 | import jax.numpy as jnp
33 |
34 |
FLAGS = flags.FLAGS

# Device count used by `set_n_cpu_devices` when it is called without `n`.
flags.DEFINE_integer(
    name='chex_n_cpu_devices',
    default=1,
    help='Number of CPU threads to use as devices in tests.')
flags.DEFINE_bool(
    name='chex_assert_multiple_cpu_devices',
    default=False,
    help='Whether to fail if a number of CPU devices is less than 2.')

# Matches an `xla_force_host_platform_device_count=<n>` XLA flag with zero to
# two leading dashes, followed by one whitespace character or the end of the
# string. Group 1 captures the digits of <n>, if any.
_xla_device_count_flag_regexp = (
    r'[-]{0,2}xla_force_host_platform_device_count=(\d+)?(\s|$)')
43 |
44 |
def get_n_cpu_devices_from_xla_flags() -> int:
  """Parses number of CPUs from the XLA environment flags.

  Scans the whitespace-separated entries of the ``XLA_FLAGS`` environment
  variable for ``--xla_force_host_platform_device_count=<n>`` and returns the
  first ``<n>`` found. The flag may appear anywhere in ``XLA_FLAGS``, not only
  as its first entry.

  Returns:
    The requested number of host CPU devices. Returns 1 when the flag is absent
    or carries no value.
  """
  for xla_flag in os.getenv('XLA_FLAGS', '').split():
    # `re.match` anchors at the start of each entry. The device-count flag is
    # therefore found wherever it sits in XLA_FLAGS, while longer flag names
    # that merely end with it are not mistaken for it.
    m = re.match(_xla_device_count_flag_regexp, xla_flag)
    if m and m.group(1):
      return int(m.group(1))

  # At least one CPU device must be available.
  return 1
52 |
53 |
def set_n_cpu_devices(n: Optional[int] = None) -> None:
  """Forces XLA to use `n` CPU threads as host devices.

  This allows `jax.pmap` to be tested on a single-CPU platform. It works by
  rewriting the ``XLA_FLAGS`` environment variable. That only takes effect if
  it happens before the XLA backends are initialized, i.e. before any JAX
  operation runs (including `jax.devices()` etc.).
  See https://github.com/google/jax/issues/1408.

  Args:
    n: A required number of CPU devices. When falsy (``None`` or 0),
      ``FLAGS.chex_n_cpu_devices`` is used.

  Raises:
    RuntimeError: If a CPU backend has already been initialized and the device
      count currently requested in ``XLA_FLAGS`` differs from `n`.
  """
  if not n:
    n = FLAGS['chex_n_cpu_devices'].value

  n_devices = get_n_cpu_devices_from_xla_flags()
  live_backends = jax.lib.xla_bridge._backends or {}  # pylint: disable=protected-access
  cpu_backend_initialized = live_backends.get('cpu', None) is not None
  if cpu_backend_initialized and n_devices != n:
    raise RuntimeError(
        f'Attempted to set {n} devices, but {n_devices} CPUs already available:'
        ' ensure that `set_n_cpu_devices` is executed before any JAX operation.'
    )

  # Drop any existing device-count flag, then put the new one first.
  remaining_flags = re.sub(
      _xla_device_count_flag_regexp, '', os.getenv('XLA_FLAGS', '')).split()
  device_count_flag = f'--xla_force_host_platform_device_count={n}'
  os.environ['XLA_FLAGS'] = ' '.join([device_count_flag, *remaining_flags])
83 |
84 |
def convert_to_varargs(sig, *args, **kwargs):
  """Converts varargs+kwargs function arguments into varargs only.

  Binds the call against ``sig`` and returns it as a positional tuple. For
  example, for ``def f(a, b)`` the call ``(1, b=2)`` becomes ``(1, 2)``.
  Parameters that are not passed are omitted rather than filled with their
  defaults, so the callee still applies its own defaults.

  Args:
    sig: An ``inspect.Signature`` of the target function.
    *args: Positional arguments of the call.
    **kwargs: Keyword arguments of the call.

  Returns:
    A tuple of positional arguments equivalent to the call.

  Raises:
    TypeError: If the arguments do not bind to ``sig``. Also raised if some
      arguments cannot be expressed positionally: keyword-only parameters,
      ``**kwargs`` entries, or a keyword argument that follows an omitted
      positional parameter.
  """
  bound_args = sig.bind(*args, **kwargs)
  # `BoundArguments.args` leaves these out. Dropping them silently would run
  # the function with default values instead of the ones the caller passed.
  leftover_kwargs = bound_args.kwargs
  if leftover_kwargs:
    raise TypeError(
        'Cannot convert keyword arguments '
        f'{sorted(leftover_kwargs)} to positional arguments; pass them '
        'positionally, supplying every preceding positional parameter.')
  return bound_args.args
89 |
90 |
91 | def _ignore_axis_index_groups(fn):
92 | """Wrapper that forces axis_index_groups to be None.
93 |
94 | This is to avoid problems within fake_pmap where parallel operations are
95 | performed with vmap, rather than pmap. Parallel operations where
96 | `axis_index_groups` is not `None` are not currently supported under vmap.
97 |
98 | Args:
99 | fn: the function to wrap
100 |
101 | Returns:
102 | a wrapped function that forces any keyword argument named
103 | `axis_index_groups` to be None
104 | """
105 | @functools.wraps(fn)
106 | def _fake(*args, axis_index_groups=None, **kwargs):
107 | del axis_index_groups
108 | return fn(*args, axis_index_groups=None, **kwargs)
109 | return _fake
110 |
111 |
# Cross-device collectives wrapped so that they always run with
# `axis_index_groups=None`; `fake_pmap(ignore_axis_index_groups=True)` patches
# them over the corresponding `jax.lax` functions.
(_fake_all_gather,
 _fake_all_to_all,
 _fake_psum,
 _fake_pmean,
 _fake_pmax,
 _fake_pmin,
 _fake_pswapaxes) = (
     _ignore_axis_index_groups(collective) for collective in (
         jax.lax.all_gather,
         jax.lax.all_to_all,
         jax.lax.psum,
         jax.lax.pmean,
         jax.lax.pmax,
         jax.lax.pmin,
         jax.lax.pswapaxes,
     ))
119 |
120 |
@functools.wraps(jax.pmap)
def _fake_pmap(fn,
               axis_name: Optional[Any] = None,
               *,
               in_axes=0,
               static_broadcasted_argnums: Union[int, Iterable[int]] = (),
               jit_result: bool = False,
               fake_parallel_axis: bool = False,
               **unused_kwargs):
  """Fake implementation of pmap using vmap.

  Note: `functools.wraps(jax.pmap)` replaces this docstring at runtime.

  Args:
    fn: The function to "pmap".
    axis_name: Forwarded to `jax.vmap` as its `axis_name`.
    in_axes: As for `jax.pmap`: an int, or a sequence with one entry per
      positional argument. A dict is not supported together with
      `static_broadcasted_argnums`.
    static_broadcasted_argnums: Indices of positional arguments that are
      passed to `fn` unchanged instead of being mapped over.
    jit_result: If True, the vmapped function is additionally `jax.jit`-ed.
    fake_parallel_axis: If True, a leading axis of size 1 is added to every
      argument leaf before the call and squeezed from every output leaf after.
    **unused_kwargs: Other `jax.pmap` options; ignored.

  Returns:
    A function that binds its arguments to `fn`'s signature, converts them to
    positional arguments and calls the vmapped `fn`.

  Raises:
    NotImplementedError: If `static_broadcasted_argnums` is combined with a
      dict `in_axes`.
  """

  if isinstance(static_broadcasted_argnums, int):
    static_broadcasted_argnums = (static_broadcasted_argnums,)
  else:
    # Materialize once: the argnums are membership-tested for every argument
    # on every call, which would silently exhaust a one-shot iterable.
    static_broadcasted_argnums = tuple(static_broadcasted_argnums)
  if static_broadcasted_argnums and isinstance(in_axes, dict):
    raise NotImplementedError(
        'static_broadcasted_argnums with dict in_axes not supported.')

  fn_signature = inspect.signature(
      fn,
      # Disable 'follow wrapped' because we want the exact signature of fn,
      # not the signature of any function it might wrap.
      follow_wrapped=False)

  @functools.wraps(fn)
  def wrapped_fn(*args, **kwargs):
    # Convert kwargs to varargs
    # This is a workaround for vmapped functions not working with kwargs
    call_args = convert_to_varargs(fn_signature, *args, **kwargs)

    if static_broadcasted_argnums:
      # Make sure vmap does not try to map over `static_broadcasted_argnums`.
      if isinstance(in_axes, int):
        vmap_in_axes = [in_axes] * len(call_args)
      else:
        vmap_in_axes = list(in_axes)
      for argnum in static_broadcasted_argnums:
        vmap_in_axes[argnum] = None

      # To protect the arguments from `static_broadcasted_argnums`,
      # from turning into tracers (because of vmap), we capture the original
      # `call_args` and replace the passed in tracers with original values.
      original_call_args = call_args

      # A function passed to vmap, that will simply replace the static args
      # with their original values.
      def fn_without_statics(*args):
        args_with_original_statics = [
            orig_arg if i in static_broadcasted_argnums else arg
            for i, (arg, orig_arg) in enumerate(zip(args, original_call_args))
        ]
        return fn(*args_with_original_statics)

      # Make sure to avoid turning static args into tracers: Some python objects
      # might not survive vmap. Just replace with an unused constant.
      call_args = [
          1 if i in static_broadcasted_argnums else arg
          for i, arg in enumerate(call_args)
      ]

    else:
      vmap_in_axes = in_axes
      fn_without_statics = fn

    vmapped_fn = jax.vmap(
        fn_without_statics, in_axes=vmap_in_axes, axis_name=axis_name
    )
    if jit_result:
      vmapped_fn = jax.jit(vmapped_fn)

    if fake_parallel_axis:
      call_args = jax.tree_util.tree_map(
          lambda x: jnp.expand_dims(x, axis=0), call_args)

    output = vmapped_fn(*call_args)

    if fake_parallel_axis:
      output = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), output)

    return output

  return wrapped_fn
202 |
203 |
# pylint:disable=unnecessary-dunder-call
class FakeContext(contextlib.ExitStack):
  """A `contextlib.ExitStack` that can also be driven imperatively.

  Besides ``with`` usage, the stack can be entered with `start` and unwound
  with `stop`.
  """

  def start(self):
    """Enters the stack, as a ``with`` statement would."""
    self.__enter__()

  def stop(self):
    """Unwinds the stack, exiting every context entered into it."""
    # `ExitStack.close()` is defined as `self.__exit__(None, None, None)`.
    self.close()


# pylint:enable=unnecessary-dunder-call
215 |
216 |
def fake_jit(enable_patching: bool = True) -> FakeContext:
  """Context manager that disables JIT compilation, as a debugging aid.

  This is intended to be used as a debugging tool to programmatically enable or
  disable JIT compilation.

  Can be used either as a context managed scope:

  .. code-block:: python

    with chex.fake_jit():
      @jax.jit
      def foo(x):
        ...

  or by calling `start` and `stop`:

  .. code-block:: python

    fake_jit_context = chex.fake_jit()
    fake_jit_context.start()

    @jax.jit
    def foo(x):
      ...

    fake_jit_context.stop()

  Args:
    enable_patching: Passed as ``disable`` to `jax.disable_jit`. When False,
      JIT compilation stays enabled.

  Returns:
    A `FakeContext` into which ``jax.disable_jit(disable=enable_patching)``
    has already been entered. JIT is therefore turned off as soon as this
    function returns, and stays off until the context is stopped or exited.
    While it is active, JAX also avoids jitting internally wherever possible,
    e.g. in `jax.lax.scan`.
  """
  no_jit = jax.disable_jit(disable=enable_patching)
  fake_context = FakeContext()
  fake_context.enter_context(no_jit)
  return fake_context
256 |
257 |
def fake_pmap(
    enable_patching: bool = True,
    jit_result: bool = False,
    ignore_axis_index_groups: bool = False,
    fake_parallel_axis: bool = False,
) -> FakeContext:
  """Context manager for patching `jax.pmap` with `jax.vmap`.

  This is intended to be used as a debugging tool to programmatically replace
  pmap transformations with a non-parallel vmap transformation.

  Can be used either as a context managed scope:

  .. code-block:: python

    with chex.fake_pmap():
      @jax.pmap
      def foo(x):
        ...

  or by calling `start` and `stop`:

  .. code-block:: python

    fake_pmap_context = chex.fake_pmap()
    fake_pmap_context.start()
    @jax.pmap
    def foo(x):
      ...
    fake_pmap_context.stop()

  Note that the patches are applied as soon as this function returns, because
  `contextlib.ExitStack.enter_context` enters immediately. Neither `start()`
  nor the `with` statement delays them. `stop()`, or leaving the `with` block,
  undoes them.

  Args:
    enable_patching: Whether to patch `jax.pmap`. If False, nothing is patched
      and the remaining options have no effect.
    jit_result: Whether the transformed function should be jitted despite not
      being pmapped.
    ignore_axis_index_groups: Whether to force any parallel operation within the
      context to set `axis_index_groups` to be None. This is a compatibility
      option to allow users of the axis_index_groups parameter to run under the
      fake_pmap context. This feature is not currently supported in vmap, and
      will fail, so we force the parameter to be `None`.
      *Warning*: This will produce different results to running under
      `jax.pmap`.
    fake_parallel_axis: Whether to fake the parallel axis. If True, the mapped
      function receives its arguments without the leading axis being split,
      as if mapped over an extra leading axis of size 1. That extra axis is
      then squeezed out of the outputs.

  Returns:
    Context where `jax.pmap` is patched with `jax.vmap`.
  """
  stack = FakeContext()
  if not enable_patching:
    return stack

  patched_pmap = functools.partial(
      _fake_pmap,
      jit_result=jit_result,
      fake_parallel_axis=fake_parallel_axis)
  stack.enter_context(mock.patch('jax.pmap', patched_pmap))

  if ignore_axis_index_groups:
    # Swap in collectives that force `axis_index_groups=None`, which vmap
    # cannot handle. Without this flag the JAX implementations are kept.
    collective_patches = (
        ('jax.lax.all_gather', _fake_all_gather),
        ('jax.lax.all_to_all', _fake_all_to_all),
        ('jax.lax.psum', _fake_psum),
        ('jax.lax.pmean', _fake_pmean),
        ('jax.lax.pmax', _fake_pmax),
        ('jax.lax.pmin', _fake_pmin),
        ('jax.lax.pswapaxes', _fake_pswapaxes),
    )
    for target, replacement in collective_patches:
      stack.enter_context(mock.patch(target, replacement))

  return stack
326 |
327 |
def fake_pmap_and_jit(enable_pmap_patching: bool = True,
                      enable_jit_patching: bool = True) -> FakeContext:
  """Context manager for patching `jax.pmap` and disabling `jax.jit`.

  This is a convenience function, equivalent to nested `chex.fake_pmap` and
  `chex.fake_jit` contexts, in that order.

  Note that calling (the true implementation of) `jax.pmap` will compile the
  function, so faking `jax.jit` in this case will not stop the function from
  being compiled.

  Args:
    enable_pmap_patching: Whether to patch `jax.pmap`.
    enable_jit_patching: Whether to patch `jax.jit`.

  Returns:
    Context where `jax.pmap` is patched with `jax.vmap` (if enabled) and JIT
    compilation is disabled as by `chex.fake_jit` (if enabled).
  """
  combined = FakeContext()
  pmap_context = fake_pmap(enable_patching=enable_pmap_patching)
  combined.enter_context(pmap_context)
  jit_context = fake_jit(enable_patching=enable_jit_patching)
  combined.enter_context(jit_context)
  return combined
351 |
352 |
class OnCallOfTransformedFunction():
  """Injects a callback into any transformed function.

  A typical use-case is jax.jit or jax.pmap which is often hidden deep inside
  the code. This context manager allows injecting a callback function into
  functions which are transformed by the user-specified transformation.
  The callback will receive the transformed function and its arguments, and
  is invoked before the transformed function on every call.

  The function can be useful to debug, profile and check the calls of any
  transformed function in a program.

  Functions that were transformed while the context was active keep invoking
  the callback even after the context has exited. Only new transformations
  stop being wrapped.

  For instance:

    with chex.OnCallOfTransformedFunction('jax.jit', print):
      [...]

  would print all calls to any function which was jit-compiled within this
  context.

  We can also automatically create profiles on the first call of all the
  jit compiled functions in the program:

    class profile_once():
      def __init__(self):
        self._first_call = True

      def __call__(self, fn, *args, **kwargs):
        if self._first_call:
          self._first_call = False
          print(profile_from_HLO(fn.lower(*args, **kwargs)))

    with chex.OnCallOfTransformedFunction('jax.jit', profile_once()):
      [...]
  """

  def __init__(self, fn_transformation: str, callback_fn: Callable[..., Any]):
    """Creates a new OnCallOfTransformedFunction context manager.

    Args:
      fn_transformation: identifier of the function transformation e.g.
        'jax.jit', 'jax.pmap', ...
      callback_fn: A callback function which receives the transformed function
        and its arguments on every call.
    """
    self._fn_transformation = fn_transformation
    self._callback_fn = callback_fn
    self._patch: mock._patch[Callable[[Any], Any]] = None  # pylint: disable=unsubscriptable-object
    self._original_fn_transformation = None

  def __enter__(self):
    """Patches `fn_transformation` and returns this context manager."""

    def _new_fn_transformation(fn, *args, **kwargs):
      """Returns a transformed version of the given function."""
      transformed_fn = self._original_fn_transformation(fn, *args, **kwargs)

      @functools.wraps(transformed_fn)
      def _new_transformed_fn(*args, **kwargs):
        """Returns result of the returned function and calls the callback."""
        self._callback_fn(transformed_fn, *args, **kwargs)
        return transformed_fn(*args, **kwargs)

      return _new_transformed_fn

    self._patch = mock.patch(self._fn_transformation, _new_fn_transformation)
    # Resolve the real transformation before the patch is active, so that the
    # wrapper above delegates to the unpatched implementation.
    self._original_fn_transformation, unused_local = self._patch.get_original()
    self._patch.start()
    return self

  def __exit__(self, *unused_args):
    """Restores the original transformation; exceptions are not suppressed."""
    self._patch.stop()
422 |
--------------------------------------------------------------------------------
/chex/_src/fake_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for `fake.py`."""
16 |
17 | import dataclasses
18 | import functools
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | from chex._src import asserts
23 | from chex._src import fake
24 | from chex._src import pytypes
25 | import jax
26 | import jax.numpy as jnp
27 |
# Aliases for chex array types. They are not referenced by the tests in this
# file.
ArrayBatched = pytypes.ArrayBatched
ArraySharded = pytypes.ArraySharded
30 |
31 |
# Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
def setUpModule():
  """Module-level fixture run once before any test in this file.

  Calls `fake.set_n_cpu_devices()` with its default arguments, so the tests
  below can pmap over multiple (fake) CPU devices.
  """
  fake.set_n_cpu_devices()
35 |
36 |
def _assert_jitted(fn, fn_input, is_jitted):
  """Checks, via trace counting, whether `jax.jit` actually traces `fn`.

  Args:
    fn: The function to be tested.
    fn_input: Input to pass to the function.
    is_jitted: If True, allow `jax.jit` to trace `fn` at most once. If False,
      allow no traces at all, i.e. check that the fake jit is working
      correctly.
  """
  asserts.clear_trace_counter()
  if is_jitted:
    trace_budget = 1
  else:
    trace_budget = 0
  counted_fn = asserts.assert_max_traces(fn, trace_budget)
  jax.jit(counted_fn)(fn_input)
50 |
51 |
def _assert_pmapped(fn, fn_input, is_pmapped, should_jit=False):
  """Asserts whether a function can be pmapped or not.

  Args:
    fn: The function to be tested.
    fn_input: Input to pass to the function. It is broadcast along a new
      leading axis of size `len(jax.devices())` before the call.
    is_pmapped: Assert that the function can be pmapped with jax.pmap (True) or
      cannot be pmapped (False), i.e. the fake pmap is working correctly.
    should_jit: if True, asserts that the function is traced at most once,
      regardless of it being pmapped or not.
  """
  num_devices = len(jax.devices())
  if should_jit:
    asserts.clear_trace_counter()
    fn = asserts.assert_max_traces(fn, n=1)
  wrapped_fn = jax.pmap(fn, axis_size=num_devices)

  fn_input = jnp.broadcast_to(fn_input, (num_devices,) + fn_input.shape)
  output = wrapped_fn(fn_input)

  # Real and faked pmap both return a `jax.Array`, so the type alone cannot
  # tell them apart. The pmapped branch only checks the type, while the
  # non-pmapped branch also requires the output to live on a single device.
  if is_pmapped:
    expected_type = jax.Array
    assert_message = f'Output is type {type(output)}, expected {expected_type}'
    assert isinstance(output, expected_type), assert_message
  else:
    assert_message = (f'Output is type {type(output)}, expected a jax.Array '
                      'placed on a single device')
    assert (isinstance(output, jax.Array) and
            len(output.sharding.device_set) == 1), assert_message
86 |
87 |
class PmapFakeTest(parameterized.TestCase):

  def test_assert_pmapped(self):
    def foo(x):
      return x * 2
    fn_input = jnp.ones((4,))

    _assert_pmapped(foo, fn_input, True)
    # Only the positive case is asserted. With jax.Array, pmapped and
    # non-pmapped outputs share a type, and `len(output.sharding.device_set)`
    # cannot distinguish them when only a single device is available.

  def test_assert_jitted(self):
    fn_input = jnp.ones((4,))
    def foo(x):
      return x * 2

    _assert_jitted(foo, fn_input, True)
    with self.assertRaises(AssertionError):
      _assert_jitted(foo, fn_input, False)

  @parameterized.named_parameters([
      # `enable_patching=True` disables jit, so `foo` is never traced.
      ('faked_jit', {'enable_patching': True}, False),
      ('plain_jit', {'enable_patching': False}, True),
  ])
  def test_fake_jit(self, fake_kwargs, is_jitted):
    fn_input = jnp.ones((4,))
    def foo(x):
      return x * 2

    # Call with context manager
    with fake.fake_jit(**fake_kwargs):
      _assert_jitted(foo, fn_input, is_jitted)

    # Call with start/stop
    ctx = fake.fake_jit(**fake_kwargs)
    ctx.start()
    _assert_jitted(foo, fn_input, is_jitted)
    ctx.stop()

  @parameterized.named_parameters([
      ('plain_pmap_but_jit', True, True),
      ('plain_pmap', True, False),
      ('faked_pmap_but_jit', False, True),
      ('faked_pmap', False, False),
  ])
  def test_fake_pmap_(self, is_pmapped, jit_result):
    enable_patching = not is_pmapped

    fn_input = jnp.ones((4,))
    def foo(x):
      return x * 2

    # Call with context manager
    with fake.fake_pmap(enable_patching=enable_patching, jit_result=jit_result):
      _assert_pmapped(foo, fn_input, is_pmapped, jit_result)

    # Call with start/stop
    ctx = fake.fake_pmap(enable_patching=enable_patching, jit_result=jit_result)
    ctx.start()
    _assert_pmapped(foo, fn_input, is_pmapped, jit_result)
    ctx.stop()

  def test_fake_pmap_axis_name(self):

    with fake.fake_pmap():

      @functools.partial(jax.pmap, axis_name='i')
      @functools.partial(jax.pmap, axis_name='j')
      def f(_):
        return jax.lax.axis_index('i'), jax.lax.axis_index('j')
      x, y = f(jnp.zeros((4, 2)))

    self.assertEqual(x.tolist(), [[0, 0], [1, 1], [2, 2], [3, 3]])
    self.assertEqual(y.tolist(), [[0, 1], [0, 1], [0, 1], [0, 1]])

  @parameterized.named_parameters([
      ('fake_nothing', {
          'enable_pmap_patching': False,
          'enable_jit_patching': False
      }, True, True),
      ('fake_pmap', {
          'enable_pmap_patching': True,
          'enable_jit_patching': False
      }, False, True),
      # Default pmap will implicitly compile the function
      ('fake_jit', {
          'enable_pmap_patching': False,
          'enable_jit_patching': True
      }, True, False),
      ('fake_both', {
          'enable_pmap_patching': True,
          'enable_jit_patching': True
      }, False, False),
  ])
  def test_pmap_and_jit(self, fake_kwargs, is_pmapped, is_jitted):
    fn_input = jnp.ones((4,))
    def foo(x):
      return x * 2

    # Call with context manager
    with fake.fake_pmap_and_jit(**fake_kwargs):
      _assert_pmapped(foo, fn_input, is_pmapped)
      _assert_jitted(foo, fn_input, is_jitted)

    # Call with start/stop
    ctx = fake.fake_pmap_and_jit(**fake_kwargs)
    ctx.start()
    _assert_pmapped(foo, fn_input, is_pmapped)
    _assert_jitted(foo, fn_input, is_jitted)
    ctx.stop()

  @parameterized.named_parameters([
      ('fake_nothing', False, False),
      ('fake_pmap', True, False),
      ('fake_jit', False, True),
      ('fake_both', True, True),
  ])
  def test_with_kwargs(self, fake_pmap, fake_jit):
    with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
      num_devices = len(jax.devices())

      @functools.partial(jax.pmap, axis_size=num_devices)
      @jax.jit
      def foo(x, y):
        return (x * 2) + y

      # pmap over all available devices
      inputs = jnp.array([1, 2])
      inputs = jnp.broadcast_to(inputs, (num_devices,) + inputs.shape)
      expected = jnp.broadcast_to(jnp.array([3, 6]), (num_devices, 2))

      asserts.assert_trees_all_close(foo(x=inputs, y=inputs), expected)

  @parameterized.named_parameters([
      ('fake_nothing', False, 1),
      ('fake_pmap', True, 1),
      ('fake_nothing_no_static_args', False, ()),
      ('fake_pmap_no_static_args', True, ()),
  ])
  def test_with_static_broadcasted_argnums(self, fake_pmap, static_argnums):
    with fake.fake_pmap_and_jit(fake_pmap, enable_jit_patching=False):
      num_devices = len(jax.devices())

      # Note: mode='bar' is intended to test that we correctly handle kwargs
      # with defaults for which we don't pass a value at call time.
      @functools.partial(
          jax.pmap,
          axis_size=num_devices,
          static_broadcasted_argnums=static_argnums,
      )
      @functools.partial(
          jax.jit,
          static_argnums=static_argnums,
      )
      def foo(x, multiplier, y, mode='bar'):
        if static_argnums == 1 or 1 in static_argnums:
          # Verify that the static arguments are not replaced with tracers.
          self.assertIsInstance(multiplier, int)

        if mode == 'bar':
          return (x * multiplier) + y
        else:
          return x

      # pmap over all available devices
      inputs = jnp.array([1, 2])
      inputs = jnp.broadcast_to(inputs, (num_devices,) + inputs.shape)
      func = lambda: foo(inputs, 100, inputs)  # Pass multiplier=100.

      if static_argnums == 1:  # Should work.
        expected = jnp.broadcast_to(jnp.array([101, 202]), (num_devices, 2))
        result = func()
        asserts.assert_trees_all_close(result, expected)
      else:  # Should error.
        with self.assertRaises(ValueError):
          func()

  # absl unpacks iterable test-case parameters into positional arguments, so
  # the list case must be wrapped in a tuple to reach the test as `[1]`.
  @parameterized.parameters(1, ([1],))
  def test_pmap_with_complex_static_broadcasted_object(self, static_argnums):

    @dataclasses.dataclass
    class Multiplier:
      x: int
      y: int

    def foo(x, multiplier, y):
      if static_argnums == 1 or 1 in static_argnums:
        # Verify that the static arguments are not replaced with tracers.
        self.assertIsInstance(multiplier, Multiplier)

      return x * multiplier.x + y * multiplier.y

    with fake.fake_pmap_and_jit():
      num_devices = jax.device_count()

      # pmap over all available devices
      transformed_foo = jax.pmap(
          foo,
          axis_size=num_devices,
          static_broadcasted_argnums=static_argnums,
      )
      x, y = jax.random.randint(
          jax.random.PRNGKey(27), (2, num_devices, 3, 5), 0, 10
      )

      # Test 1.
      mult = Multiplier(x=2, y=7)
      asserts.assert_trees_all_equal(
          transformed_foo(x, mult, y),
          foo(x, mult, y),
          x * mult.x + y * mult.y,
      )

      # Test 2.
      mult = Multiplier(x=72, y=21)
      asserts.assert_trees_all_equal(
          transformed_foo(x, mult, y),
          foo(x, mult, y),
          x * mult.x + y * mult.y,
      )

  @parameterized.named_parameters([
      ('fake_nothing', False, False),
      ('fake_pmap', True, False),
      ('fake_jit', False, True),
      ('fake_both', True, True),
  ])
  def test_with_partial(self, fake_pmap, fake_jit):
    with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
      num_devices = len(jax.devices())

      # Testing a common use-case where non-parallel arguments are partially
      # applied before pmapping
      def foo(x, y, flag):
        return (x * 2) + y if flag else (x + y)
      foo = functools.partial(foo, flag=True)

      foo = jax.pmap(foo, axis_size=num_devices)
      foo = jax.jit(foo)

      # pmap over all available devices
      inputs = jnp.array([1, 2])
      inputs = jnp.broadcast_to(inputs, (num_devices,) + inputs.shape)
      expected = jnp.broadcast_to(jnp.array([3, 6]), (num_devices, 2))

      asserts.assert_trees_all_close(foo(inputs, inputs), expected)
      asserts.assert_trees_all_close(foo(x=inputs, y=inputs), expected)

  @parameterized.named_parameters([
      ('fake_nothing', False, False),
      ('fake_pmap', True, False),
      ('fake_jit', False, True),
      ('fake_both', True, True),
  ])
  def test_with_default_params(self, fake_pmap, fake_jit):
    with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
      num_devices = len(jax.devices())

      # Default flag specified at definition time
      def foo(x, y, flag=True):
        return (x * 2) + y if flag else (x + y)

      default_foo = jax.pmap(foo, axis_size=num_devices)
      default_foo = jax.jit(default_foo)

      inputs = jnp.array([1, 2])
      inputs = jnp.broadcast_to(inputs, (num_devices,) + inputs.shape)
      expected = jnp.broadcast_to(jnp.array([3, 6]), (num_devices, 2))
      asserts.assert_trees_all_close(default_foo(inputs, inputs), expected)
      asserts.assert_trees_all_close(default_foo(x=inputs, y=inputs), expected)

      # Default overridden by partial to execute other branch
      overridden_foo = functools.partial(foo, flag=False)
      overridden_foo = jax.pmap(overridden_foo, axis_size=num_devices)
      overridden_foo = jax.jit(overridden_foo)

      expected = jnp.broadcast_to(jnp.array([2, 4]), (num_devices, 2))
      asserts.assert_trees_all_close(overridden_foo(inputs, inputs), expected)
      asserts.assert_trees_all_close(
          overridden_foo(x=inputs, y=inputs), expected)

  def test_parallel_ops_equivalence(self):
    """Test equivalence between parallel operations using pmap and vmap."""
    num_devices = len(jax.devices())
    inputs = jax.random.uniform(shape=(num_devices, num_devices, 2),
                                key=jax.random.PRNGKey(1))

    def test_equivalence(fn):
      with fake.fake_pmap(enable_patching=False):
        outputs1 = jax.pmap(fn, axis_name='i', axis_size=num_devices)(inputs)
      with fake.fake_pmap(enable_patching=True):
        outputs2 = jax.pmap(fn, axis_name='i', axis_size=num_devices)(inputs)
      with fake.fake_pmap(enable_patching=True, jit_result=True):
        outputs3 = jax.pmap(fn, axis_name='i', axis_size=num_devices)(inputs)
      asserts.assert_trees_all_close(outputs1, outputs2, outputs3)

    parallel_ops_and_kwargs = [
        (jax.lax.psum, {}),
        (jax.lax.pmax, {}),
        (jax.lax.pmin, {}),
        (jax.lax.pmean, {}),
        (jax.lax.all_gather, {}),
        (jax.lax.all_to_all, {
            'split_axis': 0,
            'concat_axis': 1
        }),
        (jax.lax.ppermute, {
            'perm': [(x, (x + 1) % num_devices) for x in range(num_devices)]
        }),
    ]

    def fn(op, kwargs, x, y=2.0):
      return op(x * y, axis_name='i', **kwargs)
    partial_fn = functools.partial(fn, y=4.0)
    lambda_fn = lambda op, kwargs, x: fn(op, kwargs, x, y=5.0)

    for op, kwargs in parallel_ops_and_kwargs:
      test_equivalence(functools.partial(fn, op, kwargs))
      test_equivalence(functools.partial(fn, op, kwargs, y=3.0))
      test_equivalence(functools.partial(partial_fn, op, kwargs))
      test_equivalence(functools.partial(lambda_fn, op, kwargs))

  def test_fake_parallel_axis(self):
    inputs = jnp.ones(shape=(2, 2))
    with fake.fake_pmap(fake_parallel_axis=False):
      @jax.pmap
      def no_fake_parallel_axis_fn(x):
        asserts.assert_shape(x, (2,))
        return 2.0 * x

      outputs = no_fake_parallel_axis_fn(inputs)
      asserts.assert_trees_all_close(outputs, 2.0)

    with fake.fake_pmap(fake_parallel_axis=True):
      @jax.pmap
      def fake_parallel_axis_fn(x):
        asserts.assert_shape(x, (2, 2,))
        return 2.0 * x

      outputs = fake_parallel_axis_fn(inputs)
      asserts.assert_trees_all_close(outputs, 2.0)
431 |
432 |
433 | class _Counter():
434 | """Counts how often an instance is called."""
435 |
436 | def __init__(self):
437 | self.count = 0
438 |
439 | def __call__(self, *unused_args, **unused_kwargs):
440 | self.count += 1
441 |
442 |
class OnCallOfTransformedFunctionTest(parameterized.TestCase):

  def test_on_call_of_transformed_function(self):
    call_counter = _Counter()
    with fake.OnCallOfTransformedFunction('jax.jit', call_counter):
      # One jitted call per reduction, so the callback must fire twice.
      for reduction in (jnp.sum, jnp.max):
        jax.jit(reduction)(jnp.zeros((10,)))
    self.assertEqual(call_counter.count, 2)
451 |
452 |
if __name__ == '__main__':
  # Run the absl test runner when this file is executed directly.
  absltest.main()
455 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Chex
2 |
3 | 
4 | 
5 | 
6 |
7 | Chex is a library of utilities for helping to write reliable JAX code.
8 |
9 | This includes utils to help:
10 |
11 | * Instrument your code (e.g. assertions, warnings)
12 | * Debug (e.g. transforming `pmaps` into `vmaps` within a context manager).
13 | * Test JAX code across many `variants` (e.g. jitted vs non-jitted).
14 |
15 | ## Installation
16 |
17 | You can install the latest released version of Chex from PyPI via:
18 |
19 | ```sh
20 | pip install chex
21 | ```
22 |
23 | or you can install the latest development version from GitHub:
24 |
25 | ```sh
26 | pip install git+https://github.com/deepmind/chex.git
27 | ```
28 |
29 | ## Modules Overview
30 |
31 | ### Dataclass ([dataclass.py](https://github.com/deepmind/chex/blob/master/chex/_src/dataclass.py))
32 |
33 | Dataclasses are a popular construct introduced in Python 3.7 that make it
34 | easy to specify typed data structures with minimal boilerplate code. They are
35 | not, however, compatible with JAX and
36 | [dm-tree](https://github.com/deepmind/tree) out of the box.
37 |
38 | In Chex we provide a JAX-friendly dataclass implementation reusing python [dataclasses](https://docs.python.org/3/library/dataclasses.html#module-dataclasses).
39 |
40 | Chex implementation of `dataclass` registers dataclasses as internal [_PyTree_
41 | nodes](https://jax.readthedocs.io/en/latest/pytrees.html) to ensure
42 | compatibility with JAX data structures.
43 |
44 | In addition, we provide a class wrapper that exposes dataclasses as
45 | `collections.abc.Mapping` descendants, which allows processing them
46 | (e.g. (un-)flatten) in `dm-tree` methods as usual Python dictionaries.
47 | See [`@mappable_dataclass`](https://github.com/deepmind/chex/blob/master/chex/_src/dataclass.py#L27)
48 | docstring for more details.
49 |
50 | Example:
51 |
52 | ```python
53 | @chex.dataclass
54 | class Parameters:
55 | x: chex.ArrayDevice
56 | y: chex.ArrayDevice
57 |
58 | parameters = Parameters(
59 | x=jnp.ones((2, 2)),
60 | y=jnp.ones((1, 2)),
61 | )
62 |
63 | # Dataclasses can be treated as JAX pytrees
64 | jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)
65 |
66 | # and as mappings by dm-tree
67 | tree.flatten(parameters)
68 | ```
69 |
70 | **NOTE**: Unlike standard Python 3.7 dataclasses, Chex
71 | dataclasses cannot be constructed using positional arguments. They support
72 | construction arguments provided in the same format as the Python dict
73 | constructor. Dataclasses can be converted to tuples with the `from_tuple` and
74 | `to_tuple` methods if necessary.
75 |
76 | ```python
77 | parameters = Parameters(
78 | jnp.ones((2, 2)),
79 | jnp.ones((1, 2)),
80 | )
81 | # ValueError: Mappable dataclass constructor doesn't support positional args.
82 | ```
83 |
84 | ### Assertions ([asserts.py](https://github.com/deepmind/chex/blob/master/chex/_src/asserts.py))
85 |
86 | One limitation of PyType annotations for JAX is that they do not support the
87 | specification of `jax.Array` ranks, shapes or dtypes. Chex includes a number
88 | of functions that allow flexible and concise specification of these properties.
89 |
90 | E.g. suppose you want to ensure that all tensors `t1`, `t2`, `t3` have the same
91 | shape, and that tensors `t4`, `t5` have rank `2` and (`3` or `4`), respectively.
92 |
93 | ```python
94 | chex.assert_equal_shape([t1, t2, t3])
95 | chex.assert_rank([t4, t5], [2, {3, 4}])
96 | ```
97 |
98 | More examples:
99 |
100 | ```python
101 | from chex import assert_shape, assert_rank, ...
102 |
103 | assert_shape(x, (2, 3)) # x has shape (2, 3)
104 | assert_shape([x, y], [(), (2,3)]) # x is scalar and y has shape (2, 3)
105 |
106 | assert_rank(x, 0) # x is scalar
107 | assert_rank([x, y], [0, 2]) # x is scalar and y is a rank-2 array
108 | assert_rank([x, y], {0, 2}) # x and y are scalar OR rank-2 arrays
109 |
110 | assert_type(x, int) # x has type `int` (x can be an array)
111 | assert_type([x, y], [int, float]) # x has type `int` and y has type `float`
112 |
113 | assert_equal_shape([x, y, z]) # x, y, and z have equal shapes
114 |
115 | assert_trees_all_close(tree_x, tree_y) # values and structure of trees match
116 | assert_tree_all_finite(tree_x) # all tree_x leaves are finite
117 |
118 | assert_devices_available(2, 'gpu') # 2 GPUs available
119 | assert_tpu_available() # at least 1 TPU available
120 |
121 | assert_numerical_grads(f, (x, y), j) # f^{(j)}(x, y) matches numerical grads
122 | ```
123 |
124 | See `asserts.py`
125 | [documentation](https://chex.readthedocs.io/en/latest/api.html#assertions) to
126 | find all supported assertions.
127 |
128 | If you cannot find a specific assertion, please consider making a pull request
129 | or opening an issue on
130 | [the bug tracker](https://github.com/deepmind/chex/issues).
131 |
132 | #### Optional Arguments
133 |
134 | All chex assertions support the following optional kwargs for manipulating the
135 | emitted exception messages:
136 |
137 | * `custom_message`: A string to include into the emitted exception messages.
138 | * `include_default_message`: Whether to include the default Chex message into
139 | the emitted exception messages.
140 | * `exception_type`: An exception type to use. `AssertionError` by default.
141 |
142 | For example, the following code:
143 |
144 | ```python
145 | dataset = load_dataset()
146 | params = init_params()
147 | for i in range(num_steps):
148 | params = update_params(params, dataset.sample())
149 | chex.assert_tree_all_finite(params,
150 | custom_message=f'Failed at iteration {i}.',
151 | exception_type=ValueError)
152 | ```
153 |
154 | will raise a `ValueError` that includes a step number when `params` get polluted
155 | with `NaNs` or `None`s.
156 |
157 | #### Static and Value (aka *Runtime*) Assertions
158 |
159 | Chex divides all assertions into 2 classes: ***static*** and ***value***
160 | assertions.
161 |
162 | 1. ***static*** assertions use anything except concrete values of tensors.
163 | Examples: `assert_shape`, `assert_trees_all_equal_dtypes`,
164 | `assert_max_traces`.
165 |
166 | 2. ***value*** assertions require access to tensor values, which are not
167 | available during JAX tracing (see
168 | [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html)),
169 | so such assertions need special treatment in *jitted* code.
170 |
171 | To enable value assertions in a jitted function, it can be decorated with
172 | `chex.chexify()` wrapper. Example:
173 |
174 | ```python
175 | @chex.chexify
176 | @jax.jit
177 | def logp1_abs_safe(x: chex.Array) -> chex.Array:
178 | chex.assert_tree_all_finite(x)
179 | return jnp.log(jnp.abs(x) + 1)
180 |
181 | logp1_abs_safe(jnp.ones(2)) # OK
182 | logp1_abs_safe(jnp.array([jnp.nan, 3])) # FAILS (in async mode)
183 |
184 | # The error will be raised either at the next line OR at the next
185 | # `logp1_abs_safe` call. See the docs for more detail on async mode.
186 | logp1_abs_safe.wait_checks() # Wait for the (async) computation to complete.
187 | ```
188 |
189 | See
190 | [this docstring](https://chex.readthedocs.io/en/latest/api.html#chex.chexify)
191 | for more detail on `chex.chexify()`.
192 |
193 | #### JAX Tracing Assertions
194 |
195 | JAX re-traces a JIT'ted function every time the structure of passed arguments
196 | changes. Often this behavior is inadvertent and leads to a significant
197 | performance drop which is hard to debug. [@chex.assert_max_traces](https://github.com/deepmind/chex/blob/master/chex/_src/asserts.py#L44)
198 | decorator asserts that the function is not re-traced more than `n` times during
199 | program execution.
200 |
201 | The global trace counter can be cleared by calling
202 | `chex.clear_trace_counter()`. This function can be used to isolate unit tests relying
203 | on `@chex.assert_max_traces`.
204 |
205 | Examples:
206 |
207 | ```python
208 | @jax.jit
209 | @chex.assert_max_traces(n=1)
210 | def fn_sum_jitted(x, y):
211 | return x + y
212 |
213 | fn_sum_jitted(jnp.zeros(3), jnp.zeros(3)) # tracing for the 1st time - OK
214 | fn_sum_jitted(jnp.zeros([6, 7]), jnp.zeros([6, 7])) # AssertionError!
215 | ```
216 |
217 | Can be used with `jax.pmap()` as well:
218 |
219 | ```python
220 | def fn_sub(x, y):
221 | return x - y
222 |
223 | fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))
224 | ```
225 |
226 | See
227 | [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html)
228 | section for more information about tracing.
229 |
230 | ### Warnings ([warnings.py](https://github.com/deepmind/chex/blob/master/chex/_src/warnings.py))
231 |
232 | In addition to hard assertions, Chex also offers utilities to add common
233 | warnings, such as specific types of deprecation warnings.
234 |
235 | ### Test variants ([variants.py](https://github.com/deepmind/chex/blob/master/chex/_src/variants.py))
236 |
237 | JAX relies extensively on code transformation and compilation, meaning that it
238 | can be hard to ensure that code is properly tested. For instance, just testing a
239 | python function using JAX code will not cover the actual code path that is
240 | executed when jitted, and that path will also differ depending on whether the code is jitted
241 | for CPU, GPU, or TPU. This has been a source of obscure and hard-to-catch bugs
242 | where XLA changes would lead to undesirable behaviours that only
243 | manifest in one specific code transformation.
244 |
245 | Variants make it easy to ensure that unit tests cover different ‘variations’ of
246 | a function, by providing a simple decorator that can be used to repeat any test
247 | under all (or a subset) of the relevant code transformations.
248 |
249 | E.g. suppose you want to test the output of a function `fn` with or without jit.
250 | You can use `chex.variants` to run the test with both the jitted and non-jitted
251 | version of the function by simply decorating a test method with
252 | `@chex.variants`, and then using `self.variant(fn)` in place of `fn` in the body
253 | of the test.
254 |
255 | ```python
256 | def fn(x, y):
257 | return x + y
258 | ...
259 |
260 | class ExampleTest(chex.TestCase):
261 |
262 | @chex.variants(with_jit=True, without_jit=True)
263 | def test(self):
264 | var_fn = self.variant(fn)
265 | self.assertEqual(fn(1, 2), 3)
266 | self.assertEqual(var_fn(1, 2), fn(1, 2))
267 | ```
268 |
269 | If you define the function in the test method, you may also use `self.variant`
270 | as a decorator in the function definition. For example:
271 |
272 | ```python
273 | class ExampleTest(chex.TestCase):
274 |
275 | @chex.variants(with_jit=True, without_jit=True)
276 | def test(self):
277 | @self.variant
278 | def var_fn(x, y):
279 | return x + y
280 |
281 | self.assertEqual(var_fn(1, 2), 3)
282 | ```
283 |
284 | Example of parameterized test:
285 |
286 | ```python
287 | from absl.testing import parameterized
288 |
289 | # Could also be:
290 | # `class ExampleParameterizedTest(chex.TestCase, parameterized.TestCase):`
291 | # `class ExampleParameterizedTest(chex.TestCase):`
292 | class ExampleParameterizedTest(parameterized.TestCase):
293 |
294 | @chex.variants(with_jit=True, without_jit=True)
295 | @parameterized.named_parameters(
296 | ('case_positive', 1, 2, 3),
297 | ('case_negative', -1, -2, -3),
298 | )
299 | def test(self, arg_1, arg_2, expected):
300 | @self.variant
301 | def var_fn(x, y):
302 | return x + y
303 |
304 | self.assertEqual(var_fn(arg_1, arg_2), expected)
305 | ```
306 |
307 | Chex currently supports the following variants:
308 |
309 | * `with_jit` -- applies `jax.jit()` transformation to the function.
310 | * `without_jit` -- uses the function as is, i.e. identity transformation.
311 | * `with_device` -- places all arguments (except those specified in the `ignore_argnums`
312 | argument) into device memory before applying the function.
313 | * `without_device` -- places all arguments in RAM before applying the function.
314 | * `with_pmap` -- applies `jax.pmap()` transformation to the function (see notes below).
315 |
316 | See documentation in [variants.py](https://github.com/deepmind/chex/blob/master/chex/_src/variants.py) for more details on the supported variants.
317 | More examples can be found in [variants_test.py](https://github.com/deepmind/chex/blob/master/chex/_src/variants_test.py).
318 |
319 | ### Variants notes
320 |
321 | * Test classes that use `@chex.variants` must inherit from
322 | `chex.TestCase` (or any other base class that unrolls test generators
323 | within `TestCase`, e.g. `absl.testing.parameterized.TestCase`).
324 |
325 | * **[`jax.vmap`]** All variants can be applied to a vmapped function;
326 | please see an example in [variants_test.py](https://github.com/deepmind/chex/blob/master/chex/_src/variants_test.py) (`test_vmapped_fn_named_params` and
327 | `test_pmap_vmapped_fn`).
328 |
329 | * **[`@chex.all_variants`]** You can get all supported variants
330 | by using the decorator `@chex.all_variants`.
331 |
332 | * **[`with_pmap` variant]** `jax.pmap(fn)`
333 | ([doc](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap)) performs
334 | parallel map of `fn` onto multiple devices. Since most tests run in a
335 | single-device environment (i.e. having access to a single CPU or GPU), in which
336 | case `jax.pmap` is functionally equivalent to `jax.jit`, the `with_pmap` variant is
337 | skipped by default (although it works fine with a single device). The section
338 | **Faking multi-device test environments** below describes how to properly test `fn` if it is supposed to be used in
339 | multi-device environments (TPUs or multiple CPUs/GPUs). To disable skipping
340 | `with_pmap` variants in case of a single device, add
341 | `--chex_skip_pmap_variant_if_single_device=false` to your test command.
342 |
343 | ### Fakes ([fake.py](https://github.com/deepmind/chex/blob/master/chex/_src/fake.py))
344 |
345 | Debugging in JAX is made more difficult by code transformations such as `jit`
346 | and `pmap`, which introduce optimizations that make code hard to inspect and
347 | trace. It can also be difficult to disable those transformations during
348 | debugging as they can be called at several places in the underlying
349 | code. Chex provides tools to globally replace `jax.jit` with a no-op
350 | transformation and `jax.pmap` with a (non-parallel) `jax.vmap`, in order to more
351 | easily debug code in a single-device context.
352 |
353 | For example, you can use Chex to fake `pmap` and have it replaced with a `vmap`.
354 | This can be achieved by wrapping your code with a context manager:
355 |
356 | ```python
357 | with chex.fake_pmap():
358 | @jax.pmap
359 | def fn(inputs):
360 | ...
361 |
362 | # Function will be vmapped over inputs
363 | fn(inputs)
364 | ```
365 |
366 | The same functionality can also be invoked with `start` and `stop`:
367 |
368 | ```python
369 | fake_pmap = chex.fake_pmap()
370 | fake_pmap.start()
371 | ... your jax code ...
372 | fake_pmap.stop()
373 | ```
374 |
375 | In addition, you can fake a real multi-device test environment with a
376 | multi-threaded CPU. See the section **Faking multi-device test environments** for
377 | more details.
378 |
379 | See documentation in [fake.py](https://github.com/deepmind/chex/blob/master/chex/_src/fake.py) and examples in [fake_test.py](https://github.com/deepmind/chex/blob/master/chex/_src/fake_test.py) for more details.
380 |
381 | ## Faking multi-device test environments
382 |
383 | In situations where you do not have easy access to multiple devices, you can
384 | still test parallel computation using single-device multi-threading.
385 |
386 | In particular, one can force XLA to use a single CPU's threads as separate
387 | devices, i.e. to fake a real multi-device environment with a multi-threaded one.
388 | These two options are theoretically equivalent from XLA's perspective because they
389 | expose the same interface and use identical abstractions.
390 |
391 | Chex has a flag `chex_n_cpu_devices` that specifies the number of CPU threads to
392 | use as XLA devices.
393 |
394 | To set up a multi-threaded XLA environment for `absl` tests, define a
395 | `setUpModule` function in your test module:
396 |
397 | ```python
398 | def setUpModule():
399 | chex.set_n_cpu_devices()
400 | ```
401 |
402 | Now you can launch your test with `python test.py --chex_n_cpu_devices=N` to run
403 | it in a multi-device regime. Note that **all** tests within a module will have
404 | access to `N` devices.
405 |
406 | More examples can be found in [variants_test.py](https://github.com/deepmind/chex/blob/master/chex/_src/variants_test.py), [fake_test.py](https://github.com/deepmind/chex/blob/master/chex/_src/fake_test.py) and [fake_set_n_cpu_devices_test.py](https://github.com/deepmind/chex/blob/master/chex/_src/fake_set_n_cpu_devices_test.py).
407 |
408 | ### Using named dimension sizes.
409 |
410 | Chex comes with a small utility that allows you to package a collection of
411 | dimension sizes into a single object. The basic idea is:
412 |
413 | ```python
414 | dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
415 | ...
416 | chex.assert_shape(arr, dims['BTE'])
417 | ```
418 |
419 | String lookups are translated into integer tuples. For instance, let's say
420 | `batch_size == 3`, `sequence_len == 5` and `embedding_dim == 7`; then
421 |
422 | ```python
423 | dims['BTE'] == (3, 5, 7)
424 | dims['B'] == (3,)
425 | dims['TTBEE'] == (5, 5, 3, 7, 7)
426 | ...
427 | ```
428 |
429 | You can also assign dimension sizes dynamically as follows:
430 |
431 | ```python
432 | dims['XY'] = some_matrix.shape
433 | dims.Z = 13
434 | ```
435 |
436 | For more examples, see [chex.Dimensions](https://chex.readthedocs.io/en/latest/api.html#chex.Dimensions)
437 | documentation.
438 |
439 | ## Citing Chex
440 |
441 | This repository is part of the [DeepMind JAX Ecosystem]; to cite Chex, please use
442 | the [DeepMind JAX Ecosystem citation].
443 |
444 | [DeepMind JAX Ecosystem]: https://deepmind.com/blog/article/using-jax-to-accelerate-our-research "DeepMind JAX Ecosystem"
445 | [DeepMind JAX Ecosystem citation]: https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt "Citation"
446 |
--------------------------------------------------------------------------------
/chex/_src/asserts_internal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Chex assertion internal utilities and symbols.
16 |
17 | [README!]
18 |
19 | We reserve the right to change the code in this module at any time without
20 | providing any guarantees of backward compatibility. For this reason,
21 | we strongly recommend that you avoid using this module directly at all costs!
22 | Instead, consider opening an issue on GitHub and describing your use case.
23 | """
24 |
25 | import collections
26 | import collections.abc
27 | import functools
28 | import re
29 | import threading
30 | import traceback
31 | from typing import Any, Sequence, Union, Callable, Hashable, List, Optional, Set, Tuple, Type
32 |
33 | from absl import logging
34 | from chex._src import pytypes
35 | import jax
36 | from jax.experimental import checkify
37 | import jax.numpy as jnp
38 | import numpy as np
39 |
# Custom pytypes shared by the assertion helpers in this module.
TLeaf = Any  # A single leaf of a pytree.
TLeavesEqCmpFn = Callable[[TLeaf, TLeaf], bool]  # Leaf-equality predicate.
TLeavesEqCmpErrorFn = Callable[[TLeaf, TLeaf], str]  # Leaf-mismatch message.

# TODO(iukemaev): define a typing protocol for TChexAssertion.
# Every Chex assertion is called as:
#   (*args,
#    custom_message: Optional[str] = None,
#    custom_message_format_vars: Sequence[Any] = (),
#    include_default_message: bool = True,
#    exception_type: Type[Exception] = AssertionError,
#    **kwargs)
TChexAssertion = Callable[..., None]
TAssertFn = Callable[..., None]
# The jittable counterpart of an assertion is a predicate returning an array.
TJittableAssertFn = Callable[..., pytypes.Array]

# Shape matchers: each dimension matcher is an int, a set of ints, an
# Ellipsis (`...`) or None; a shape matcher is a sequence of them.
TDimMatcher = Union[int, Set[int], type(...), None]
TShapeMatcher = Sequence[TDimMatcher]
60 |
61 |
class _ChexifyStorage(threading.local):
  """Thread-local storage for internal variables used in @chexify.

  State is created in `__init__`, which `threading.local` runs separately in
  every thread that first touches the instance. It must not live in class
  attributes: those are shared by all threads. A class-level `wait_fns = []`
  would therefore be one list appended to from every thread.
  """

  def __init__(self):
    # Callables registered by chexified functions; presumably populated and
    # drained in `asserts_chexify.py` (verify against that module).
    self.wait_fns: list = []
    # Non-zero while executing inside a `@chex.chexify`-wrapped function
    # (checked by `chex_assertion` before running value assertions).
    self.level: int = 0


# Chex namespace variables.
ERR_PREFIX = "[Chex] "
TRACE_COUNTER = collections.Counter()
DISABLE_ASSERTIONS = False

# This variable is used for _chexify_ transformations, see `asserts_chexify.py`.
CHEXIFY_STORAGE = _ChexifyStorage()
75 |
76 |
def assert_collection_of_arrays(inputs: Sequence[pytypes.Array]):
  """Raises `ValueError` unless ``inputs`` is a `collections.abc.Collection`.

  Only the container type is checked; the elements are not inspected.
  """
  is_collection = isinstance(inputs, collections.abc.Collection)
  if is_collection:
    return
  raise ValueError(f"`inputs` is not a collection of arrays: {inputs}.")
81 |
82 |
def jnp_to_np_array(arr: pytypes.Array) -> np.ndarray:
  """Transfers `arr` to the host with `jax.device_get`.

  `bfloat16` inputs are cast to `float32` first. Inputs without a `dtype`
  attribute are handed to `jax.device_get` unchanged.
  """
  is_bfloat16 = getattr(arr, "dtype", None) == jnp.bfloat16
  if not is_bfloat16:
    return jax.device_get(arr)
  # Numpy does not support `bfloat16`, so upcast before the transfer.
  return jax.device_get(arr.astype(jnp.float32))
89 |
90 |
def deprecation_wrapper(new_fn, old_name, new_name):
  """Builds a shim for a renamed function that warns on every call.

  Args:
    new_fn: the function under its new name; the shim forwards to it.
    old_name: the deprecated name, used only in the logged warning.
    new_name: the replacement name, used only in the logged warning.

  Returns:
    A function that logs a rename warning and then returns
    `new_fn(*args, **kwargs)`.
  """
  warning_fmt = "chex.%s has been renamed to chex.%s, please update your code."

  def inner_fn(*args, **kwargs):
    logging.warning(warning_fmt, old_name, new_name)
    return new_fn(*args, **kwargs)

  return inner_fn
101 |
102 |
def get_stacktrace_without_chex_internals() -> List[traceback.FrameSummary]:
  """Returns the call stack up to (and including) the latest non-chex frame.

  A frame counts as a chex internal if its filename contains a `chex`
  directory component, unless it is a test file (ends with `_test.py`).
  Backslashes are normalised to forward slashes first, so Windows paths are
  classified the same way as POSIX ones.

  Returns:
    The frames of `traceback.extract_stack()`, outermost first, truncated
    after the innermost frame that is not a chex internal.

  Raises:
    RuntimeError: if every frame on the stack is a chex internal.
  """
  # `extract_stack()` must be called directly here (not in a helper) so that
  # the innermost frame is this function itself.
  stacktrace = list(traceback.extract_stack())
  for i in reversed(range(len(stacktrace))):
    fname = stacktrace[i].filename.replace("\\", "/")
    if "/chex/" not in fname or fname.endswith("_test.py"):
      return stacktrace[:i + 1]

  debug_info = "\n-----\n".join(traceback.format_stack())
  raise RuntimeError(
      "get_stacktrace_without_chex_internals() failed. "
      "Please file a bug at https://github.com/deepmind/chex/issues and "
      "include the following debug info in it. "
      "Please make sure it does not include any private information! "
      f"Debug: '{debug_info}'.")
118 |
119 |
def get_err_regex(message: str) -> str:
  """Builds a regex matching a Chex-formatted exception message.

  The pattern is the escaped `ERR_PREFIX`, then any characters (newlines
  included), then `message`. `message` itself is not escaped, so it is
  interpreted as a regex.

  Args:
    message: the expected (regex) tail of the exception message.

  Returns:
    The regex string.
  """
  prefix_pattern = re.escape(ERR_PREFIX)
  any_chars = r"[\s\S]*"  # Matches anything, including newlines.
  return prefix_pattern + any_chars + message
131 |
132 |
def get_chexify_err_message(name: str, msg: str = "") -> str:
  """Formats the message of a failed chexify assertion `name`.

  The result is `ERR_PREFIX` followed by
  "chexify assertion '<name>' failed: <msg>".
  """
  return "{}chexify assertion '{}' failed: {}".format(ERR_PREFIX, name, msg)
136 |
137 |
def _make_host_assertion(assert_fn: TAssertFn,
                         name: Optional[str] = None) -> TChexAssertion:
  """Wraps `assert_fn` into a Chex assertion evaluated on the host.

  Only apply this to assertions that are either
    a) never called from jitted code, or
    b) never read array values when called from jitted code (i.e. they add no
       value-dependent Python control flow, see
       https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError).

  Behaviour of the returned function:
    * an `AssertionError` from `assert_fn` is re-raised as `exception_type`
      with a Chex-formatted message;
    * a `ValueError` from `assert_fn` is re-raised as a plain `ValueError`
      carrying the same message;
    * any other exception propagates unchanged.

  Args:
    assert_fn: A function implementing the check.
    name: A name for assertion; defaults to `assert_fn.__name__`.

  Returns:
    A chex assertion.
  """
  assertion_name = assert_fn.__name__ if name is None else name

  def _assert_on_host(*args,
                      custom_message: Optional[str] = None,
                      custom_message_format_vars: Sequence[Any] = (),
                      include_default_message: bool = True,
                      exception_type: Type[Exception] = AssertionError,
                      **kwargs) -> None:
    # Failures are re-raised as new exceptions, outside of the `except`
    # clauses, which drops `assert_fn`'s internal frames from the traceback.
    failure = None
    bad_value = None
    try:
      assert_fn(*args, **kwargs)
    except AssertionError as exc:
      failure = exc
    except ValueError as exc:
      bad_value = exc

    if bad_value is not None:
      raise ValueError(str(bad_value))
    if failure is None:
      return

    details = str(failure)
    # A nested Chex assertion already formatted its message: keep only the
    # text after its "failed:" so that just the outermost name is reported.
    if details.startswith(ERR_PREFIX):
      details = details[details.find("failed:") + len("failed:"):]

    if include_default_message:
      header = f"Assertion {assertion_name} failed: "
    else:
      header = ""
    error_msg = f"{ERR_PREFIX}{header}{details}"

    if custom_message:
      if custom_message_format_vars:
        custom_message = custom_message.format(*custom_message_format_vars)
      error_msg = f"{error_msg} [{custom_message}]"

    raise exception_type(error_msg)

  return _assert_on_host
199 |
200 |
def chex_assertion(
    assert_fn: TAssertFn,
    jittable_assert_fn: Optional[TJittableAssertFn],
    name: Optional[str] = None) -> TChexAssertion:
  """Wraps Chex assert functions to control their common behaviour.

  Extends the assertion to support the following optional auxiliary kwargs:
    custom_message: A string to include into the emitted exception messages.
    custom_message_format_vars: A list of variables to pass as arguments to
      `custom_message.format()`.
    include_default_message: Whether to include the default Chex message into
      the emitted exception messages.
    exception_type: An exception type to use. `AssertionError` by default.

  If `jittable_assert_fn` is given and any argument contains a JAX tracer, the
  assertion is evaluated through `checkify` instead of on the host. On that
  path only `custom_message` and `custom_message_format_vars` are used.

  Args:
    assert_fn: A function implementing the check.
    jittable_assert_fn: An optional jittable version of `assert_fn` implementing
      a predicate (returning `True` only if assertion passes).
      Required for value assertions.
    name: A name for assertion. If not provided, use `assert_fn.__name__`.

  Returns:
    A Chex assertion (with auxiliary kwargs).

  Raises (from the returned assertion):
    RuntimeError: if a value assertion is traced outside `@chex.chexify`.
  """
  if name is None:
    name = assert_fn.__name__

  host_assertion_fn = _make_host_assertion(assert_fn, name)

  @functools.wraps(assert_fn)
  def _chex_assert_fn(*args,
                      custom_message: Optional[str] = None,
                      custom_message_format_vars: Sequence[Any] = (),
                      include_default_message: bool = True,
                      exception_type: Type[Exception] = AssertionError,
                      **kwargs) -> None:
    if DISABLE_ASSERTIONS:
      return
    if (jittable_assert_fn is not None and has_tracers((args, kwargs))):
      if not CHEXIFY_STORAGE.level:
        raise RuntimeError(
            "Value assertions can only be called from functions wrapped "
            "with `@chex.chexify`. See the docs.")

      # A wrapper to inject auxiliary debug info and `custom_message`.
      original_check = checkify.check

      def _check(pred, msg, *fmt_args, **fmt_kwargs):
        # Add chex info.
        msg = get_chexify_err_message(name, msg)

        # Add a custom message.
        if custom_message:
          msg += f" Custom message: {custom_message}."
          fmt_args = list(fmt_args) + list(custom_message_format_vars)

        # Add a traceback and a pointer to the callsite.
        stacktrace = get_stacktrace_without_chex_internals()
        msg += (
            f" [failed at: {stacktrace[-1].filename}:{stacktrace[-1].lineno}]"
        )

        # Call original `checkify.check()`.
        original_check(pred, msg, *fmt_args, **fmt_kwargs)

      # Patch `checkify.check` only while the assertion executes. The
      # `finally` restores the original even if `jittable_assert_fn` raises.
      # Without it, `_check` would leak into every later `checkify.check` call.
      checkify.check = _check
      try:
        pred = jittable_assert_fn(*args, **kwargs)  # execute the assertion
      finally:
        checkify.check = original_check

      # A safeguard to ensure that the results of a check are not ignored.
      # In particular, this check fails when `pred` is False and no
      # `checkify.check` calls took place in `jittable_assert_fn`, which would
      # be a bug in the assertion's implementation.
      checkify.check(pred, "assertion failed!")
    else:
      try:
        host_assertion_fn(
            *args,
            custom_message=custom_message,
            custom_message_format_vars=custom_message_format_vars,
            include_default_message=include_default_message,
            exception_type=exception_type,
            **kwargs)
      except jax.errors.ConcretizationTypeError as exc:
        msg = ("Chex assertion detected `ConcretizationTypeError`: it is very "
               "likely that it tried to access tensors' values during tracing. "
               "Make sure that you defined a jittable version of this chex "
               "assertion; if that does not help, please file a bug.")
        raise exc from RuntimeError(msg)

  # Override name.
  _chex_assert_fn.__name__ = name
  return _chex_assert_fn
295 |
296 |
def format_tree_path(path: Sequence[Any]) -> str:
  """Joins the `str()` of each path key with '/', e.g. ('a', 0) -> 'a/0'."""
  return "/".join(map(str, path))
299 |
300 |
def format_shape_matcher(shape: TShapeMatcher) -> str:
  """Renders a shape matcher as a tuple-like string, with Ellipsis as '...'."""
  rendered = [("..." if dim is Ellipsis else str(dim)) for dim in shape]
  return "({})".format(", ".join(rendered))
303 |
304 |
def num_devices_available(devtype: str, backend: Optional[str] = None) -> int:
  """Counts devices of platform `devtype` among `jax.devices(backend)`.

  Args:
    devtype: one of 'cpu', 'gpu', 'tpu' (case-insensitive).
    backend: optional backend passed through to `jax.devices`.

  Returns:
    The number of devices whose `platform` equals the lower-cased `devtype`.

  Raises:
    ValueError: if `devtype` is not a supported device type.
  """
  supported_types = ("cpu", "gpu", "tpu")
  platform = devtype.lower()
  if platform not in supported_types:
    raise ValueError(
        f"Unknown device type '{platform}' (expected one of {supported_types}).")

  matching = [d for d in jax.devices(backend) if d.platform == platform]
  return len(matching)
314 |
315 |
def get_tracers(tree: pytypes.ArrayTree) -> Tuple[jax.core.Tracer, ...]:
  """Returns a tuple with all tracers among the leaves of `tree`.

  Leaves are visited in `jax.tree_util.tree_leaves` order. Non-tracer leaves
  are skipped, so the result may be empty or have any length (hence the
  variadic `Tuple[..., ...]` annotation).
  """
  return tuple(
      x for x in jax.tree_util.tree_leaves(tree)
      if isinstance(x, jax.core.Tracer))
321 |
322 |
def has_tracers(tree: pytypes.ArrayTree) -> bool:
  """Returns `True` as soon as a leaf of `tree` is a `jax.core.Tracer`."""
  for leaf in jax.tree_util.tree_leaves(tree):
    if isinstance(leaf, jax.core.Tracer):
      return True
  return False
327 |
328 |
def is_traceable(fn) -> bool:
  """Returns whether `fn`, or a function it wraps, is JAX-transformed.

  JAX traces functions wrapped with @jit, @pmap or @vmap. This walks the
  `__wrapped__` chain starting at `fn`. It returns `True` as soon as an
  object looks like the product of such a transformation, judged by
  substrings of its `str()` or type name, or by the module globals of a
  wrapper.

  Args:
    fn: function to inspect.

  Returns:
    Bool indicating whether fn is traceable.
  """
  # Substrings of `str(obj)` that identify transformed callables.
  repr_markers = (
      ".reraise_with_filtered_traceback",  # JIT in Python ver. >= 3.7
      "CompiledFunction",  # C++ JIT in jaxlib 0.1.66 or newer.
      "pmap.",  # Python pmap
      "PmapFunction",  # C++ pmap in jaxlib 0.1.72 or newer.
      "vmap.",  # vmap
      "_python_pjit",
      "_cpp_pjit",
  )
  # Substrings of `str(type(obj))` that identify transformed callables.
  type_markers = ("PmapFunction", "PjitFunction")

  candidate = fn
  while True:
    if any(marker in str(candidate) for marker in repr_markers):
      return True
    if any(marker in str(type(candidate)) for marker in type_markers):
      return True
    if not hasattr(candidate, "__wrapped__"):
      # Innermost function reached without finding any JAX marker.
      return False

    wrapper_globals = getattr(candidate, "__globals__", {})
    if wrapper_globals.get("__name__", None) == "jax.api":
      return True  # A wrapper defined in `jax.api`.
    if "api_boundary" in wrapper_globals:
      return True  # `api_boundary` is a JAX wrapper for traced functions.
    try:
      if isinstance(candidate, jax.lib.xla_extension.PjitFunction):
        return True
    except AttributeError:
      pass  # This attribute may be unavailable in some JAX versions.

    candidate = candidate.__wrapped__
389 |
390 |
def assert_leaves_all_eq_comparator(
    equality_comparator: TLeavesEqCmpFn,
    error_msg_fn: Callable[[TLeaf, TLeaf, str, int, int],
                           str], path: Sequence[Any], *leaves: Sequence[TLeaf]):
  """Asserts all leaves are equal to the first one. Not jittable.

  Raises:
    AssertionError: for the first leaf `i` (i >= 1) for which
      `equality_comparator(leaves[0], leaves[i])` is falsy. The message is
      `error_msg_fn(leaves[0], leaves[i], formatted_path, 0, i)`.
  """
  path_str = format_tree_path(path)
  for idx, leaf in enumerate(leaves[1:], start=1):
    if equality_comparator(leaves[0], leaf):
      continue
    raise AssertionError(error_msg_fn(leaves[0], leaf, path_str, 0, idx))
400 |
401 |
def assert_trees_all_eq_comparator_jittable(
    equality_comparator: TLeavesEqCmpFn,
    error_msg_template: str,
    *trees: Sequence[pytypes.ArrayTree]) -> pytypes.Array:
  """Asserts all trees are equal using custom comparator. JIT-friendly.

  Every leaf of `trees[1:]` is compared with the matching leaf of `trees[0]`
  via `equality_comparator`. Each comparison is reported through
  `checkify.check`, which receives the two leaves as the `arr_1` / `arr_2`
  format kwargs.

  Args:
    equality_comparator: jittable predicate comparing two leaves.
    error_msg_template: text appended to each per-leaf error message.
    *trees: trees to compare; assumed to share the same structure.

  Returns:
    A boolean `jnp` array that is `True` only if every comparison passed
    (assuming `equality_comparator` returns scalar booleans).

  Raises:
    ValueError: if fewer than two trees are given.
  """

  if len(trees) < 2:
    raise ValueError(
        "Assertions over only one tree does not make sense. Maybe you wrote "
        "`assert_trees_xxx([a, b])` instead of `assert_trees_xxx(a, b)`, or "
        "forgot the `error_msg_fn` arg to `assert_trees_xxx`?")

  def _escape_format_braces(text: str) -> str:
    # `checkify.check` formats its `msg` with `str.format`, so literal braces
    # coming from user-defined tree keys must be doubled.
    return text.replace("{", "{{").replace("}", "}}")

  def _tree_error_msg_fn(
      path: Tuple[Union[int, str, Hashable], ...], i_1: int, i_2: int):
    if path:
      path_str = _escape_format_braces(str(path))
      return (
          f"Trees {i_1} and {i_2} differ in leaves '{path_str}':"
          f" {error_msg_template}"
      )
    else:
      return f"Trees (arrays) {i_1} and {i_2} differ: {error_msg_template}."

  def _cmp_leaves(path, *leaves):
    verdict = jnp.array(True)
    for i in range(1, len(leaves)):
      check_res = equality_comparator(leaves[0], leaves[i])
      checkify.check(
          pred=check_res,
          msg=_tree_error_msg_fn(path, 0, i),
          arr_1=leaves[0],
          arr_2=leaves[i],
      )
      verdict = jnp.logical_and(verdict, check_res)
    return verdict

  # Trees are guaranteed to have the same structure.
  paths = [
      convert_jax_path_to_dm_path(path)
      for path, _ in jax.tree_util.tree_flatten_with_path(trees[0])[0]]
  trees_leaves = [jax.tree_util.tree_leaves(tree) for tree in trees]

  verdict = jnp.array(True)
  for leaf_i, path in enumerate(paths):
    verdict = jnp.logical_and(
        verdict, _cmp_leaves(path, *[leaves[leaf_i] for leaves in trees_leaves])
    )

  return verdict
450 |
451 |
# Entries that may appear in a `jax.tree_util` key path: plain keys or JAX's
# key wrapper classes.
JaxKeyType = Union[
    int,
    str,
    Hashable,
    jax.tree_util.SequenceKey,
    jax.tree_util.DictKey,
    jax.tree_util.FlattenedIndexKey,
    jax.tree_util.GetAttrKey,
]


def convert_jax_path_to_dm_path(
    jax_tree_path: Sequence[JaxKeyType],
) -> Tuple[Union[int, str, Hashable], ...]:
  """Converts a path from jax.tree_util to one from dm-tree.

  Each JAX key object is replaced by the plain value it wraps, so e.g.
  `(DictKey('a'), SequenceKey(0))` becomes `('a', 0)`. Plain `int` / `str`
  entries pass through unchanged.

  Args:
    jax_tree_path: a key path, e.g. as yielded by
      `jax.tree_util.tree_flatten_with_path`.

  Returns:
    A tuple (of any length) with one plain key per entry of `jax_tree_path`.

  Raises:
    ValueError: if an entry has an unsupported key type.
  """

  # pytype:disable=attribute-error
  def _convert_key_fn(key: JaxKeyType) -> Union[int, str, Hashable]:
    if isinstance(key, (str, int)):
      return key  # int | str.
    if isinstance(key, jax.tree_util.SequenceKey):
      return key.idx  # int.
    if isinstance(key, jax.tree_util.DictKey):
      return key.key  # Hashable
    if isinstance(key, jax.tree_util.FlattenedIndexKey):
      return key.key  # int.
    if isinstance(key, jax.tree_util.GetAttrKey):
      return key.name  # str.
    raise ValueError(f"Jax tree key '{key}' of type '{type(key)}' not valid.")
  # pytype:enable=attribute-error

  return tuple(_convert_key_fn(key) for key in jax_tree_path)
484 |
--------------------------------------------------------------------------------
/chex/_src/variants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Chex variants utilities."""
16 |
17 | import enum
18 | import functools
19 | import inspect
20 | import itertools
21 | from typing import Any, Sequence
22 | import unittest
23 |
24 | from absl import flags
25 | from absl.testing import parameterized
26 | from chex._src import fake
27 | from chex._src import pytypes
28 | import jax
29 | from jax import tree_util
30 | import jax.numpy as jnp
31 | import toolz
32 |
FLAGS = flags.FLAGS
# When True (the default), `with_pmap` variants are skipped on hosts that
# expose fewer than two devices.
flags.DEFINE_bool(
    name="chex_skip_pmap_variant_if_single_device",
    default=True,
    help="Whether to skip pmap variant if only one device is available.",
)
37 |
38 |
# `TestCase` subclasses `parameterized.TestCase` instead of aliasing it.
# Python does not allow listing the same base class twice, and users may want
# their tests to inherit from both `chex.TestCase` and `parameterized.TestCase`.
#
# Any base class that unrolls test generators may be used instead of
# `variants.TestCase` or `parameterized.TestCase`. If the base class lacks
# that feature, variant tests fail with a corresponding error.
class TestCase(parameterized.TestCase):
  """Base class for Chex tests that use variants.

  See the docstring of ``chex.variants`` for details.

  Note: ``chex.variants`` returns a generator yielding one test per variant,
  so the test class must unroll such generators while the module is
  imported. ``absl.parameterized.TestCase`` implements (and battle-tests)
  this, which is why this class derives from it.
  """

  def variant(self, *args, **kwargs):
    """Placeholder that raises until `@chex.variants` installs a variant."""
    del args, kwargs  # Unused.
    raise RuntimeError(
        "self.variant is not defined: forgot to wrap a test in @chex.variants?")
61 |
62 |
class ChexVariantType(enum.Enum):
  """Enumerates the Chex variants a test can run under.

  Inside a variant test, ``self.variant.type`` holds the current member; see
  ``chex.variants`` for details.
  """

  # `auto()` numbers members 1, 2, 3, ... in declaration order.
  WITH_JIT = enum.auto()
  WITHOUT_JIT = enum.auto()
  WITH_DEVICE = enum.auto()
  WITHOUT_DEVICE = enum.auto()
  WITH_PMAP = enum.auto()

  def __str__(self) -> str:
    # Used as the suffix of generated test names, e.g. "_with_jit".
    return f"_{self.name.lower()}"
78 |
79 |
tree_map = jax.tree_util.tree_map  # Module-level shorthand for tree_map.
81 |
82 |
def params_product(*params_lists: Sequence[Sequence[Any]],
                   named: bool = False) -> Sequence[Sequence[Any]]:
  """Generates a cartesian product of `params_lists`.

  Every element of every list in `params_lists` is a sequence of parameters.
  For each combination from `itertools.product(*params_lists)`, the chosen
  sequences are concatenated into one tuple. The sequences may be tuples,
  lists or any other iterables.

  See tests from ``variants_test.py`` for examples of usage.

  Args:
    *params_lists: A list of params combinations.
    named: Whether to generate test names (for
      `absl.parameterized.named_parameters(...)`). If set, the first entry of
      every chosen sequence is treated as a name. The names are joined with
      "_" and placed first in the resulting tuple, followed by the remaining
      entries.

  Returns:
    A list of tuples, one per element of the cartesian product.
  """
  # Linear-time concatenation; unlike `sum(..., ())` it also accepts
  # non-tuple sequences such as lists.
  flatten = itertools.chain.from_iterable

  def generate():
    for combination in itertools.product(*params_lists):
      if named:
        name = "_".join(t[0] for t in combination)
        yield (name, *flatten(t[1:] for t in combination))
      else:
        yield tuple(flatten(combination))

  return list(generate())
109 |
110 |
def count_num_calls(fn):
  """Wraps `fn` so that the number of its invocations can be queried.

  Returns:
    A pair `(wrapped_fn, get_count)`. `wrapped_fn` forwards all arguments to
    `fn`, counting the call before forwarding. `get_count()` returns how many
    times `wrapped_fn` has been called so far.
  """
  calls_made = [0]  # Mutable cell shared by both returned closures.

  @functools.wraps(fn)
  def fn_wrapped(*args, **kwargs):
    calls_made[0] += 1
    return fn(*args, **kwargs)

  return fn_wrapped, lambda: calls_made[0]
122 |
123 |
class VariantsTestCaseGenerator:
  """TestCase generator for chex variants. Supports sharding.

  Wraps one or more test methods together with a selection of variants.
  Iterating the generator yields one new test per (test method, selected
  variant) pair; inside each generated test the variant decorator is
  available as ``self.variant``.
  """

  def __init__(self, test_object, which_variants):
    """Creates the generator.

    Args:
      test_object: A single test method, or any object with ``__iter__``
        (e.g. a parameterized test generator), which is materialised into a
        list of test methods.
      which_variants: A dict mapping variant types to booleans saying whether
        each variant is selected. Stored by reference and updated in place by
        ``add_variants``.
    """
    self._which_variants = which_variants
    # Number of times each candidate test name has been generated so far;
    # used by `_set_test_name` to disambiguate repeated names.
    self._generated_names_freq = {}
    if hasattr(test_object, "__iter__"):
      # `test_object` is a generator (e.g. parameterised test).
      self._test_methods = list(test_object)
    else:
      # `test_object` is a single test method.
      self._test_methods = [test_object]

  def add_variants(self, which_variants):
    """Merge variants.

    A variant ends up selected if it was selected before OR in
    `which_variants`, so nested ``@variants`` wrappers accumulate their
    selections.
    """
    for var, incl in which_variants.items():
      self._which_variants[var] = self._which_variants.get(var, False) or incl

  @property
  def __name__(self):
    # Raised deliberately: wrappers that read `__name__` (e.g. `@parameterized`
    # applied outside `@variants`) are not supported.
    msg = ("A test wrapper attempts to access __name__ of "
           "VariantsTestCaseGenerator. Usually, this happens when "
           "@parameterized wraps @variants.variants. Make sure that the "
           "@variants.variants wrapper is an outer one, i.e. nothing wraps it.")
    raise RuntimeError(msg)

  def __call__(self):
    # The generator must be unrolled into individual tests, never called.
    msg = ("A test wrapper attempts to invoke __call__ of "
           "VariantsTestCaseGenerator: make sure that all `TestCase` instances "
           "that use variants inherit from `chex.TestCase`.")
    raise RuntimeError(msg)

  def _set_test_name(self, test_method, variant):
    """Set a name for the generated test.

    The name is ``<name>_<params_repr>_<variant>`` with empty parts skipped.
    If this candidate was already produced by this generator, ``_<count>`` is
    appended to ``<name>`` to disambiguate the repeat. The parameters repr is
    folded into ``__name__`` and then cleared, and ``__x_use_name__`` is set.

    Args:
      test_method: The generated test function (mutated in place).
      variant: The variant; its ``str`` form is used as the name suffix.

    Returns:
      `test_method` itself.
    """
    name = getattr(test_method, "__name__", "")
    params_repr = getattr(test_method, "__x_params_repr__", "")
    chex_suffix = f"{variant}"

    candidate_name = "_".join(filter(None, [name, params_repr, chex_suffix]))
    name_freq = self._generated_names_freq.get(candidate_name, 0)
    if name_freq:
      # Ensure that test names are unique.
      new_name = name + "_" + str(name_freq)
      unique_name = "_".join(filter(None, [new_name, params_repr, chex_suffix]))
    else:
      unique_name = candidate_name
    self._generated_names_freq[candidate_name] = name_freq + 1

    # Always use name for compatibility with `absl.testing.parameterized`.
    setattr(test_method, "__name__", unique_name)
    setattr(test_method, "__x_params_repr__", "")
    setattr(test_method, "__x_use_name__", True)
    return test_method

  def _inner_iter(self, test_method):
    """Generate chex variants for a single test.

    Args:
      test_method: The test method to specialise.

    Returns:
      A generator of new tests, one per selected variant.

    Raises:
      ValueError: If no variant is selected.
    """

    def make_test(variant: ChexVariantType):

      @functools.wraps(test_method)
      def test(self, *args, **kwargs):
        # `self` here is the running TestCase instance, not the generator.

        # Skip pmap variant if only one device is available.
        if (variant is ChexVariantType.WITH_PMAP and
            FLAGS["chex_skip_pmap_variant_if_single_device"].value and
            jax.device_count() < 2):
          raise unittest.SkipTest(
              f"Only 1 device is available ({jax.devices()}).")

        # n_cpu_devices assert.
        if FLAGS["chex_assert_multiple_cpu_devices"].value:
          required_n_cpus = fake.get_n_cpu_devices_from_xla_flags()
          if required_n_cpus < 2:
            raise RuntimeError(
                f"Required number of CPU devices is {required_n_cpus} < 2. "
                "Consider setting up your test module to use multiple CPU "
                "devices (see README.md) or disabling "
                "`chex_assert_multiple_cpu_devices` flag.")
          available_n_cpus = jax.device_count("cpu")
          if required_n_cpus != available_n_cpus:
            raise RuntimeError(
                "Number of available CPU devices is not equal to the required: "
                f"{available_n_cpus} != {required_n_cpus}")

        # Set up the variant; the call counter lets us detect tests that
        # never actually use `self.variant`.
        self.variant, num_calls = count_num_calls(_variant_decorators[variant])
        self.variant.type = variant
        res = test_method(self, *args, **kwargs)
        if num_calls() == 0:
          raise RuntimeError(
              "Test is wrapped in @chex.variants, but never calls self.variant."
              " Consider debugging the test or removing @chex.variants wrapper."
              f" (variant: {variant})")
        return res

      self._set_test_name(test, variant)
      return test

    selected_variants = [
        var_name for var_name, is_included in self._which_variants.items()
        if is_included
    ]
    if not selected_variants:
      raise ValueError(f"No variants selected for test: {test_method}.")

    return (make_test(var_name) for var_name in selected_variants)

  def __iter__(self):
    """Generate chex variants for each test case."""
    return itertools.chain(*(self._inner_iter(m) for m in self._test_methods))
234 |
235 |
@toolz.curry
def _variants_fn(test_object, **which_variants) -> VariantsTestCaseGenerator:
  """Shared implementation of `variants` and `all_variants`.

  Args:
    test_object: A test method, an iterable of test methods, or an existing
      `VariantsTestCaseGenerator` (when `@variants` wrappers are nested).
    **which_variants: Variant names (e.g. ``with_jit``), upper-cased and looked
      up in `ChexVariantType`, mapped to whether the variant is selected.

  Returns:
    ``test_object`` itself with the new selection merged in if it already is a
    `VariantsTestCaseGenerator`; otherwise a new generator wrapping it.
  """
  # Translate keyword names into `ChexVariantType` members.
  selection = {}
  for variant_name, is_selected in which_variants.items():
    selection[ChexVariantType[variant_name.upper()]] = is_selected

  if not isinstance(test_object, VariantsTestCaseGenerator):
    return VariantsTestCaseGenerator(test_object, selection)

  # Nested wrappers: fold this selection into the existing generator.
  test_object.add_variants(selection)
  return test_object
252 |
253 |
@toolz.curry
# pylint: disable=redefined-outer-name
def variants(test_method,
             with_jit: bool = False,
             without_jit: bool = False,
             with_device: bool = False,
             without_device: bool = False,
             with_pmap: bool = False) -> VariantsTestCaseGenerator:
  # pylint: enable=redefined-outer-name
  """Decorates a test to expose Chex variants.

  The decorated test gets access to a decorator, ``self.variant``, that can be
  applied to functions to exercise them under different JAX behaviours. For
  example:

  .. code-block:: python

    @chex.variants(with_jit=True, without_jit=True)
    def test(self):
      @self.variant
      def f(x, y):
        return x + y

      self.assertEqual(f(1, 2), 3)

  Here ``test`` runs twice: once with ``f`` jitted (via `jax.jit`) and once
  with ``f`` left un-jitted.

  The `with_jit` and `with_pmap` variants accept extra, variant-specific
  arguments. For example:

  .. code-block:: python

    @chex.variants(with_jit=True)
    def test(self):
      @self.variant(static_argnums=(1,))
      def f(x, y):
        # `y` is not traced.
        return x + y

      self.assertEqual(f(1, 2), 3)

  The `with_pmap` variant additionally accepts `broadcast_args_to_devices`
  (whether to broadcast each input argument to all participating devices),
  `reduce_fn` (a function applied to the results of the pmapped `fn`) and
  `n_devices` (the number of devices used by the `pmap` computation). See the
  docstring of `_with_pmap` for details, including default values.

  When combined with ``absl.testing.parameterized``, `@chex.variants` must be
  the outermost decorator:

  .. code-block:: python

    @chex.variants(with_jit=True, without_jit=True)
    @parameterized.named_parameters(
        ('case_a', 1),
        ('case_b', 2),
    )
    def test(self, x):
      ...

  Test classes that use this decorator must inherit from ``chex.TestCase``.
  See ``variants_test.py`` for more examples.

  Args:
    test_method: A test method to decorate.
    with_jit: Whether to test with `jax.jit`.
    without_jit: Whether to test without `jax.jit`. Any jit compilation done
      within the test method itself is not affected.
    with_device: Whether to test with args placed on device, using
      `jax.device_put`.
    without_device: Whether to test with args (explicitly) not placed on
      device, using `jax.device_get`.
    with_pmap: Whether to test with `jax.pmap`, with computation duplicated
      across devices.

  Returns:
    A decorated ``test_method``: a `VariantsTestCaseGenerator` yielding one
    test per selected variant.
  """
  selected = dict(
      with_jit=with_jit,
      without_jit=without_jit,
      with_device=with_device,
      without_device=without_device,
      with_pmap=with_pmap,
  )
  return _variants_fn(test_method, **selected)
335 |
336 |
@toolz.curry
# pylint: disable=redefined-outer-name
def all_variants(test_method,
                 with_jit: bool = True,
                 without_jit: bool = True,
                 with_device: bool = True,
                 without_device: bool = True,
                 with_pmap: bool = True) -> VariantsTestCaseGenerator:
  # pylint: enable=redefined-outer-name
  """Equivalent to ``chex.variants`` but with flipped defaults.

  Every variant is selected unless it is explicitly disabled by passing
  ``False`` for it.
  """
  selected = dict(
      with_jit=with_jit,
      without_jit=without_jit,
      with_device=with_device,
      without_device=without_device,
      with_pmap=with_pmap,
  )
  return _variants_fn(test_method, **selected)
354 |
355 |
def check_variant_arguments(variant_fn):
  """Raises `ValueError` if `variant_fn` got an unknown argument.

  Args:
    variant_fn: A variant implementation (e.g. `_with_jit`).

  Returns:
    A wrapper around `variant_fn` that rejects any keyword argument whose name
    is not in the module-level `_valid_kwargs_keys`, and otherwise forwards
    the call unchanged.
  """

  @functools.wraps(variant_fn)
  def checked(*args, **kwargs):
    unexpected = set(kwargs) - _valid_kwargs_keys
    if not unexpected:
      return variant_fn(*args, **kwargs)
    raise ValueError(f"Unknown arguments in `self.variant`: {unexpected}.")

  return checked
367 |
368 |
@toolz.curry
@check_variant_arguments
def _with_jit(fn,
              static_argnums=None,
              static_argnames=None,
              device=None,
              backend=None,
              **unused_kwargs):
  """Variant that applies `jax.jit` to fn.

  `static_argnums`, `static_argnames`, `device` and `backend` are forwarded to
  `jax.jit` as given; other keyword arguments (meant for other variants) are
  ignored.
  """
  jit_options = dict(
      static_argnums=static_argnums,
      static_argnames=static_argnames,
      device=device,
      backend=backend,
  )
  return jax.jit(fn, **jit_options)
385 |
386 |
@toolz.curry
@check_variant_arguments
def _without_jit(fn, **unused_kwargs):
  """Variant that does not apply `jax.jit` to a fn (identity).

  Returns a thin pass-through wrapper (not `fn` itself) that carries `fn`'s
  metadata via `functools.wraps`. Keyword arguments meant for other variants
  are ignored.
  """

  @functools.wraps(fn)
  def passthrough(*args, **kwargs):
    return fn(*args, **kwargs)

  return passthrough
397 |
398 |
@toolz.curry
@check_variant_arguments
def _with_device(fn, ignore_argnums=(), static_argnums=(), **unused_kwargs):
  """Variant that applies `jax.device_put` to the args of fn.

  Every leaf of the positional args (except those at `ignore_argnums` or
  `static_argnums`) and of all keyword args is passed through
  `jax.device_put`. Leaves that `jax.device_put` rejects with `TypeError` are
  passed unchanged.

  Args:
    fn: The function to wrap.
    ignore_argnums: An int or a collection of positional indices left as-is.
    static_argnums: Same as `ignore_argnums` (accepted so that jit-style
      arguments can be shared across variants).
    **unused_kwargs: Arguments for other variants; ignored.

  Returns:
    The wrapped function.
  """
  # Accept a bare int as shorthand for a single index.
  if isinstance(ignore_argnums, int):
    ignore_argnums = (ignore_argnums,)
  if isinstance(static_argnums, int):
    static_argnums = (static_argnums,)

  def to_device(leaf):
    try:
      return jax.device_put(leaf)
    except TypeError:  # not a valid JAX type
      return leaf

  def is_kept_as_is(position):
    return position in ignore_argnums or position in static_argnums

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    placed_args = []
    for position, value in enumerate(args):
      if is_kept_as_is(position):
        placed_args.append(value)
      else:
        placed_args.append(tree_map(to_device, value))
    return fn(*placed_args, **tree_map(to_device, kwargs))

  return wrapper
426 |
427 |
@toolz.curry
@check_variant_arguments
def _without_device(fn, **unused_kwargs):
  """Variant that applies `jax.device_get` to the args of fn.

  Every `jax.Array` leaf of the positional and keyword arguments is replaced
  by the result of `jax.device_get`; all other leaves are passed through.
  """

  def to_host(leaf):
    return jax.device_get(leaf) if isinstance(leaf, jax.Array) else leaf

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    host_args = tree_map(to_host, args)
    host_kwargs = tree_map(to_host, kwargs)
    return fn(*host_args, **host_kwargs)

  return wrapper
446 |
447 |
def _mapped_axis_size(x, axis):
  """Returns the size of `x` along `axis`, or None if `x` has no such axis."""
  shape = jnp.shape(x)
  if -len(shape) <= axis < len(shape):
    return shape[axis]
  return None


@toolz.curry
@check_variant_arguments
def _with_pmap(fn,
               broadcast_args_to_devices=True,
               reduce_fn="first_device_output",
               n_devices=None,
               axis_name="i",
               devices=None,
               in_axes=0,
               static_broadcasted_argnums=(),
               static_argnums=(),
               backend=None,
               **unused_kwargs):
  """Variant that applies `jax.pmap` to fn.

  Args:
    fn: A function to wrap.
    broadcast_args_to_devices: Whether to broadcast `fn` args to pmap format
      (i.e. prepend an axis whose size equals the number of devices). Must not
      be combined with a non-default `in_axes`.
    reduce_fn: A function to apply to outputs of `fn`. The string
      ``"first_device_output"`` (default) takes the first device's output for
      every leaf; ``"identity"`` or ``None`` returns the outputs unchanged.
    n_devices: A number of devices to use (can specify a `backend` if
      required). Defaults to all of `devices`, or all devices of `backend`.
    axis_name: An argument for `pmap`.
    devices: An argument for `pmap`.
    in_axes: An argument for `pmap`. With `broadcast_args_to_devices=False`,
      an integer entry (or a single integer applied to every arg) names the
      axis of the corresponding argument whose size is checked against the
      number of devices; pytree-valued entries are left to `pmap` to check.
    static_broadcasted_argnums: An argument for `pmap`.
    static_argnums: An alias of ``static_broadcasted_argnums``; takes
      precedence when set.
    backend: An argument for `pmap`.
    **unused_kwargs: Unused kwargs (e.g. related to other variants).

  Returns:
    Wrapped `fn` that accepts `args` and `kwargs` and returns a superposition of
    `reduce_fn` and `fn` applied to them.

  Raises:
    ValueError: If `broadcast_args_to_devices` is used with a non-default
      `in_axes`; if kwargs are passed together with `in_axes` or
      `static_argnums`; if the number of available devices is less than
      required; if pmappable arg axes' sizes are not equal to the number of
      devices.
    SkipTest: If the flag ``chex_skip_pmap_variant_if_single_device`` is set and
      there is only one device available.
  """
  if (FLAGS["chex_skip_pmap_variant_if_single_device"].value and
      jax.device_count() < 2):
    raise unittest.SkipTest(f"Only 1 device is available ({jax.devices()}).")

  if broadcast_args_to_devices and in_axes != 0:
    raise ValueError(
        "Do not use `broadcast_args_to_devices` when specifying `in_axes`.")

  # Set up a reduce function.
  if reduce_fn == "first_device_output":
    reduce_fn = lambda t: tree_map(lambda x: x[0], t)
  elif reduce_fn == "identity" or reduce_fn is None:  # Identity.
    reduce_fn = lambda t: t

  # `static_argnums` aliases `static_broadcasted_argnums`; the explicit `0`
  # check keeps a (falsy) `static_argnums=0` from being overridden.
  if not static_argnums and static_argnums != 0:
    static_argnums = static_broadcasted_argnums
  if isinstance(static_argnums, int):
    static_argnums = (static_argnums,)

  pmap_kwargs = dict(
      axis_name=axis_name,
      devices=devices,
      in_axes=in_axes,
      static_broadcasted_argnums=static_argnums,
      backend=backend)
  pmapped_fn = jax.pmap(fn, **pmap_kwargs)

  @functools.wraps(pmapped_fn)
  def wrapper(*args: pytypes.ArrayTree, **kwargs: pytypes.ArrayTree):
    if kwargs and (in_axes != 0 or static_argnums):
      raise ValueError("Do not use kwargs with `in_axes` or `static_argnums` "
                       "in pmapped function.")
    devices_ = list(devices or jax.devices(backend))
    n_devices_ = n_devices or len(devices_)
    devices_ = devices_[:n_devices_]
    if len(devices_) != n_devices_:
      raise ValueError("Number of available devices is less than required for "
                       f"test ({len(devices_)} < {n_devices_})")

    # `jnp.shape` reads the shape without materialising a copy of the leaf.
    bcast_fn = lambda x: jnp.broadcast_to(x, (n_devices_,) + jnp.shape(x))
    if broadcast_args_to_devices:
      args = [
          tree_map(bcast_fn, arg) if idx not in static_argnums else arg
          for idx, arg in enumerate(args)
      ]
      kwargs = tree_map(bcast_fn, kwargs)
    else:
      # Pmappable axes size must be equal to number of devices.
      in_axes_ = in_axes if isinstance(in_axes,
                                       (tuple, list)) else [in_axes] * len(args)
      for idx, arg in enumerate(args):
        axis = in_axes_[idx]
        if idx in static_argnums or axis is None:
          continue  # Not mapped over devices.
        if not isinstance(axis, int):
          # Per-leaf (pytree) `in_axes` specs are validated by `jax.pmap`.
          continue
        if not all(
            _mapped_axis_size(x, axis) == n_devices_
            for x in jax.tree_util.tree_leaves(arg)):
          shapes = tree_map(jnp.shape, arg)
          dim_desc = "the first dim" if axis == 0 else f"dim {axis}"
          raise ValueError(
              f"Pmappable arg axes size must be equal to number of devices, "
              f"got: {shapes} (expected {dim_desc} to be {n_devices_}). "
              "Consider setting `broadcast_args_to_devices=True`.")

    new_kwargs = dict(
        axis_name=axis_name,
        devices=devices_,
        in_axes=in_axes,
        static_broadcasted_argnums=static_argnums,
        backend=backend)

    # Re-compile fn if kwargs changed.
    nonlocal pmap_kwargs
    nonlocal pmapped_fn
    if new_kwargs != pmap_kwargs:
      pmap_kwargs = new_kwargs
      pmapped_fn = jax.pmap(fn, **pmap_kwargs)

    res = pmapped_fn(*args, **kwargs)
    return reduce_fn(res)

  return wrapper
572 |
573 |
# Maps each variant type to the (curried, argument-checked) decorator that
# implements it; generated tests install the selected one as `self.variant`.
_variant_decorators = {
    ChexVariantType.WITH_JIT: _with_jit,
    ChexVariantType.WITHOUT_JIT: _without_jit,
    ChexVariantType.WITH_DEVICE: _with_device,
    ChexVariantType.WITHOUT_DEVICE: _without_device,
    ChexVariantType.WITH_PMAP: _with_pmap,
}
581 |
582 |
class Variant:
  """A named variant decorator.

  Gives each exposed variant a readable ``repr`` (its name) and forwards calls
  to the underlying variant function unchanged.
  """

  def __init__(self, name, fn):
    self._fn = fn
    self._name = name

  def __repr__(self):
    return self._name

  def __call__(self, *args, **kwargs):
    # Pure forwarding; any currying or argument checking is expected to be
    # done by the wrapped function itself.
    target = self._fn
    return target(*args, **kwargs)
596 |
597 |
# Expose variant objects.
without_device = Variant("chex_without_device", _without_device)
without_jit = Variant("chex_without_jit", _without_jit)
with_device = Variant("chex_with_device", _with_device)
with_jit = Variant("chex_with_jit", _with_jit)
with_pmap = Variant("chex_with_pmap", _with_pmap)
ALL_VARIANTS = (without_device, without_jit, with_device, with_jit, with_pmap)

# Collect valid argument names from all variant decorators: the named
# positional parameters (`inspect.getfullargspec(...).args`) of each original
# variant function. Each decorator is a curried object (`.func`) around the
# `check_variant_arguments` wrapper, whose `__wrapped__` is the original
# function. A comprehension keeps loop temporaries out of the module namespace.
_valid_kwargs_keys = {
    arg_name
    for decorator in _variant_decorators.values()
    for arg_name in inspect.getfullargspec(decorator.func.__wrapped__).args
}
611 |
--------------------------------------------------------------------------------