├── autobound ├── example_bounds.png ├── __init__.py ├── jax │ ├── __init__.py │ ├── jaxpr_editor_test.py │ ├── jaxpr_editor.py │ ├── jax_bound_test.py │ └── jax_bound.py ├── types_test.py ├── graph_editor_test.py ├── test_utils.py ├── polynomials_test.py ├── graph_editor.py ├── polynomials.py ├── interval_arithmetic_test.py ├── enclosure_arithmetic_test.py ├── primitive_enclosures.py ├── elementwise_functions_test.py ├── elementwise_functions.py ├── interval_arithmetic.py └── enclosure_arithmetic.py ├── .gitignore ├── CONTRIBUTING.md ├── .github └── workflows │ └── ci-build.yaml ├── pyproject.toml ├── README.md ├── LICENSE └── .pylintrc /autobound/example_bounds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dpsanders/autobound/main/autobound/example_bounds.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | __pycache__/ 6 | .cache/ 7 | 8 | # Poetry, setuptools, PyPI distribution artifacts. 9 | /*.egg-info 10 | .eggs/ 11 | build/ 12 | dist/ 13 | poetry.lock 14 | 15 | # Tests 16 | .pytest_cache/ 17 | 18 | # Type checking 19 | .pytype/ 20 | 21 | # Other 22 | *.DS_Store 23 | 24 | # PyCharm 25 | .idea 26 | -------------------------------------------------------------------------------- /autobound/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | __version__ = "0.1.2" 16 | -------------------------------------------------------------------------------- /autobound/jax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Package for computing Taylor polynomial enclosures in JAX.""" 16 | 17 | from autobound.jax.jax_bound import taylor_bounds 18 | -------------------------------------------------------------------------------- /autobound/types_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | from autobound import test_utils 17 | from autobound import types 18 | 19 | 20 | class TestCase(test_utils.TestCase): 21 | 22 | def test_ndarray(self): 23 | for np_like in self.backends: 24 | self.assertIsInstance(np_like.eye(3), types.NDArray) 25 | 26 | def test_numpy_like(self): 27 | for np_like in self.backends: 28 | self.assertIsInstance(np_like, types.NumpyLike) 29 | 30 | 31 | if __name__ == '__main__': 32 | absltest.main() 33 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | <https://cla.developers.google.com/> to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 
30 | -------------------------------------------------------------------------------- /.github/workflows/ci-build.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: CI 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.9", "3.10"] 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install flake8 .[dev] 31 | - name: Lint with flake8 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 35 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 36 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 37 | - name: Test with pytest 38 | run: | 39 | pytest 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "autobound" 3 | description = "" 4 | readme = "README.md" 5 | requires-python = ">=3.7" 6 | license = {file = "LICENSE"} 7 | authors = [{name = "AutoBound authors", email="autobound@google.com"}] 8 | classifiers = [ 9 | "Programming Language :: Python :: 3", 10 | "Programming Language :: Python :: 3 :: Only", 11 | "License :: OSI Approved :: Apache Software License", 12 | "Intended Audience :: Science/Research", 13 | ] 14 | keywords = [] 15 | 16 | # pip dependencies of the project 17 | dependencies = [ 18 | "jax>=0.4.6" 19 | ] 20 | 21 | # This is set automatically by flit using `autobound.__version__` 22 | dynamic = ["version"] 23 | 24 | [project.urls] 25 | homepage = "https://github.com/google/autobound" 26 | repository = "https://github.com/google/autobound" 27 | # Other: `documentation`, `changelog` 28 | 29 | [project.optional-dependencies] 30 | # Development deps (unittest, linting, formatting,...) 31 | # Installed through `pip install .[dev]` 32 | dev = [ 33 | "absl-py>=1.3.0", 34 | "flax>=0.6.7", 35 | "pytest", 36 | "pytest-xdist", 37 | "pylint>=2.6.0", 38 | "pyink", 39 | "typing_extensions>=4.4.0", 40 | ] 41 | 42 | [tool.pyink] 43 | # Formatting configuration to follow Google style-guide 44 | pyink-indentation = 2 45 | pyink-use-majority-quotes = true 46 | 47 | [build-system] 48 | requires = ["flit_core >=3.5,<4"] 49 | build-backend = "flit_core.buildapi" 50 | -------------------------------------------------------------------------------- /autobound/graph_editor_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | from autobound import graph_editor 18 | from autobound.graph_editor import ComputationGraph, Operation 19 | 20 | 21 | class TestCase(parameterized.TestCase): 22 | 23 | @parameterized.parameters( 24 | ( 25 | Operation('foo', ['a', 'b'], ['x', 'y']), 26 | {'a': 'A', 'y': 'Y'}, 27 | Operation('foo', ['A', 'b'], ['x', 'Y']), 28 | ) 29 | ) 30 | def test_edge_subs(self, edge, mapping, expected): 31 | actual = edge.subs(mapping) 32 | self.assertEqual(expected, actual) 33 | 34 | @parameterized.parameters( 35 | ( 36 | ComputationGraph(['a', 'b'], ['z'], 37 | [Operation('foo', ['a'], ['y']), 38 | Operation('bar', ['y'], ['z'])]), 39 | {'a', 'b', 'y', 'z'} 40 | ), 41 | ) 42 | def test_intermediate_variables(self, graph, expected): 43 | actual = graph.intermediate_variables() 44 | self.assertSetEqual(expected, actual) 45 | 46 | @parameterized.parameters( 47 | ( 48 | [Operation('foo', ['x'], ['y'])], 49 | [Operation('foo', ['x'], ['y'])], 50 | lambda u, v: True, 51 | {'x': 'x', 'y': 'y'} 52 | ), 53 | ( 54 | [Operation('foo', ['x'], ['y'])], 55 | [Operation('foo', ['x'], ['y'])], 56 | lambda u, v: False, 57 | None 58 | ), 59 | ( 60 | [Operation('foo', ['x'], ['y'])], 61 | [Operation('foo', ['a'], ['b'])], 62 | lambda u, v: True, 63 | {'x': 'a', 'y': 'b'} 64 | ), 65 | ( 66 | [Operation('foo', ['x'], ['y'])], 67 | [Operation('bar', ['x'], ['y'])], 68 | lambda u, v: True, 69 | None 70 | ), 71 | ) 72 | def test_match(self, pattern, subject, can_bind, expected): 73 | actual = graph_editor.match(pattern, subject, can_bind) 74 | self.assertEqual(expected, actual) 75 | if actual is not None: 76 | self.assertTrue(all(can_bind(u, v) for u, v in actual.items())) 77 | self.assertEqual([e.subs(actual) for e in pattern], subject) 78 | 79 | @parameterized.parameters( 80 | ( 81 | [Operation('foo', ['x'], ['y'])], 82 | [Operation('bar', ['x'], ['y'])], 83 | ComputationGraph(['a'], ['b'], [Operation('foo', ['a'], ['b'])]), 84 | lambda u, v: True, 85 | ComputationGraph(['a'], ['b'], [Operation('bar', ['a'], ['b'])]), 86 | ), 87 | ) 88 | def test_replace(self, pattern, replacement, subject, can_bind, expected): 89 | actual = graph_editor.replace(pattern, replacement, subject, can_bind) 90 | self.assertEqual(expected, actual) 91 | 92 | 93 | if __name__ == '__main__': 94 | absltest.main() 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoBound: Automatically Bounding Functions 2 | 3 | ![Continuous integration](https://github.com/google/autobound/actions/workflows/ci-build.yaml/badge.svg) 4 | ![PyPI version](https://img.shields.io/pypi/v/autobound) 5 | 6 | AutoBound is a generalization of automatic differentiation. In addition to 7 | computing a Taylor polynomial approximation of a function, it computes upper 8 | and lower bounds that are guaranteed to hold over a user-specified 9 | _trust region_. 
10 | 11 | As an example, here are the quadratic upper and lower bounds AutoBound computes 12 | for the function `f(x) = 1.5*exp(3*x) - 25*(x**2)`, centered at `0.5`, and 13 | valid over the trust region `[0, 1]`. 14 | 15 |
16 | ![Example quadratic upper and lower bounds](autobound/example_bounds.png) 17 | 
18 | 19 | The code to compute the bounds shown in this plot looks like this (see [quickstart](https://colab.research.google.com/github/google/autobound/blob/main/autobound/notebooks/quickstart.ipynb)): 20 | 21 | ```python 22 | import autobound.jax as ab 23 | import jax.numpy as jnp 24 | 25 | f = lambda x: 1.5*jnp.exp(3*x) - 25*x**2 26 | x0 = .5 27 | trust_region = (0, 1) 28 | # Compute quadratic upper and lower bounds on f. 29 | bounds = ab.taylor_bounds(f, 2)(x0, trust_region) 30 | # bounds.upper(1) == 5.1283045 == f(1) 31 | # bounds.lower(0) == 1.5 == f(0) 32 | # bounds.coefficients == (0.47253323, -4.8324013, (-5.5549355, 28.287888)) 33 | ``` 34 | 35 | These bounds can be used for: 36 | 37 | * [Computing learning rates that are guaranteed to reduce a loss function](https://colab.research.google.com/github/google/autobound/blob/main/autobound/notebooks/safe_learning_rates.ipynb) 38 | * [Upper and lower bounding integrals](https://colab.research.google.com/github/google/autobound/blob/main/autobound/notebooks/bounding_integrals.ipynb) 39 | * Proving optimality guarantees in global optimization 40 | 41 | and more! 42 | 43 | Under the hood, AutoBound computes these bounds using an interval arithmetic 44 | variant of Taylor-mode automatic differentiation. Accordingly, the memory 45 | requirements are linear in the input dimension, and the method is only 46 | practical for functions with low-dimensional inputs. A reverse-mode algorithm 47 | that efficiently handles high-dimensional inputs is under development. 48 | 49 | A detailed description of the AutoBound algorithm can be found in 50 | [this paper](https://arxiv.org/abs/2212.11429). 51 | 52 | ## Installation 53 | 54 | Assuming you have [installed pip](https://pip.pypa.io/en/stable/installation/), you can install this package directly from GitHub with 55 | 56 | ```bash 57 | pip install git+https://github.com/google/autobound.git 58 | ``` 59 | 60 | or from PyPI with 61 | 62 | ```bash 63 | pip install autobound 64 | ``` 65 | 66 | You may need to [upgrade pip](https://pip.pypa.io/en/stable/installation/#upgrading-pip) before running these commands. 67 | 68 | ## Limitations 69 | 70 | The current code has a few limitations: 71 | 72 | * Only JAX-traceable functions can be automatically bounded. 73 | * Many JAX library functions are not yet supported. What _is_ 74 | supported is bounding the squared error loss of a multi-layer perceptron or convolutional neural network that uses the `jax.nn.sigmoid`, `jax.nn.softplus`, or `jax.nn.swish` activation functions. 75 | * To compute accurate bounds for deeper neural networks, you may need to use 76 | `float64` rather than `float32`. 77 | 78 | ## Citing AutoBound 79 | 80 | To cite this repository: 81 | 82 | ``` 83 | @article{autobound2022, 84 | title={Automatically Bounding the Taylor Remainder Series: Tighter Bounds and New Applications}, 85 | author={Streeter, Matthew and Dillon, Joshua V}, 86 | journal={arXiv preprint arXiv:2212.11429}, 87 | url = {http://github.com/google/autobound}, 88 | year={2022} 89 | } 90 | ``` 91 | 92 | *This is not an officially supported Google product.* 93 | -------------------------------------------------------------------------------- /autobound/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Base class for unit tests.""" 16 | 17 | import math 18 | from typing import List 19 | 20 | from absl.testing import absltest 21 | from autobound import types 22 | import numpy as np 23 | 24 | 25 | MAX_SIGMOID_DERIV = .25 26 | MIN_SIGMOID_SECOND_DERIV = -0.09622504486493762 27 | MAX_SIGMOID_SECOND_DERIV = 0.09622504486493762 28 | MIN_SIGMOID_THIRD_DERIV = -0.125 29 | MAX_SIGMOID_THIRD_DERIV = 0.04166666666666668 30 | 31 | 32 | def sigmoid(x: float) -> float: 33 | return 1/(1+math.exp(-x)) if x >= 0 else math.exp(x)/(1+math.exp(x)) 34 | 35 | 36 | def sigmoid_derivative(order: int, x: float) -> float: 37 | """Returns order `order` derivative of sigmoid at `x`.""" 38 | if order == 0: 39 | return sigmoid(x) 40 | elif order == 1: 41 | return sigmoid(x)*sigmoid(-x) 42 | elif order == 2: 43 | s = sigmoid(x) 44 | return s*sigmoid(-x)*(1-2*s) 45 | elif order == 3: 46 | s = sigmoid(x) 47 | sm = sigmoid(-x) 48 | return s*sm*((1-2*s)**2 - 2*s*sm) 49 | elif order == 4: 50 | s = sigmoid(x) 51 | sm = sigmoid(-x) 52 | return (s*sm*(1-2*s)*((1-2*s)**2 - 2*s*sm) + 53 | s*sm*(-4*(1-2*s)*s*sm - 2*s*sm*(1-2*s))) 54 | else: 55 | raise NotImplementedError(order) 56 | 57 | 58 | def softplus(x: float) -> float: 59 | # Avoid overflow for large positive x using: 60 | # log(1+exp(x)) == log(1+exp(-|x|)) + max(x, 0). 61 | return math.log1p(math.exp(-abs(x))) + max(x, 0.) 
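# Illustrative check of the overflow-safe formula above: softplus(0.0) == math.log(2.0) ≈ 0.6931, the same value as log(1 + exp(0)), while softplus(1000.0) returns 1000.0, where the naive math.log(1 + math.exp(1000.0)) would overflow.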
62 | 63 | 64 | def softplus_derivative(order: int, x: float) -> float: 65 | if order == 0: 66 | return softplus(x) 67 | else: 68 | return sigmoid_derivative(order-1, x) 69 | 70 | 71 | def swish(x: float) -> float: 72 | return x*sigmoid(x) 73 | 74 | 75 | def swish_derivative(order: int, x: float) -> float: 76 | return (order*sigmoid_derivative(order - 1, x) + 77 | x*sigmoid_derivative(order, x)) 78 | 79 | 80 | class TestCase(absltest.TestCase): 81 | """Base class for test cases.""" 82 | 83 | def assert_allclose_strict(self, expected, actual, **kwargs): 84 | """Like np.testing.assert_allclose, but requires same shape/dtype.""" 85 | np.testing.assert_allclose(actual, expected, **kwargs) 86 | self.assertEqual(np.asarray(expected).shape, np.asarray(actual).shape, 87 | (expected, actual)) 88 | self.assertEqual(np.asarray(expected).dtype, np.array(actual).dtype, 89 | (expected, actual)) 90 | 91 | def assert_enclosure_equal(self, expected, actual, **kwargs): 92 | self.assertLen(actual, len(expected)) 93 | for e, a in zip(expected, actual): 94 | self.assert_interval_equal(e, a, **kwargs) 95 | 96 | def assert_interval_equal(self, expected, actual, **kwargs): 97 | if isinstance(expected, tuple): 98 | e0, e1 = expected 99 | self.assertIsInstance(actual, tuple) 100 | self.assertLen(actual, 2) 101 | a0, a1 = actual 102 | self.assert_allclose_strict(e0, a0, **kwargs) 103 | self.assert_allclose_strict(e1, a1, **kwargs) 104 | else: 105 | self.assert_allclose_strict(expected, actual, **kwargs) 106 | 107 | @classmethod 108 | def setUpClass(cls): 109 | super().setUpClass() 110 | cls.backends = _get_backends() 111 | 112 | 113 | def _get_backends() -> List[types.NumpyLike]: 114 | """Returns list of NumpyLike back ends to test.""" 115 | backends = [np] 116 | 117 | try: 118 | from jax.config import config as jax_config 119 | import jax.numpy as jnp 120 | backends.append(jnp) 121 | jax_config.update('jax_enable_x64', True) 122 | except ModuleNotFoundError: 123 | pass 124 | 125 | try: 126 | import tensorflow.experimental.numpy as tnp 127 | tnp.experimental_enable_numpy_behavior() 128 | backends.append(tnp) 129 | except ModuleNotFoundError: 130 | pass 131 | 132 | return backends -------------------------------------------------------------------------------- /autobound/polynomials_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import operator 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from autobound import polynomials 20 | from autobound import test_utils 21 | import numpy as np 22 | 23 | 24 | class TestCase(parameterized.TestCase, test_utils.TestCase): 25 | 26 | @parameterized.parameters( 27 | ([], .5, operator.mul, operator.pow, 0), 28 | ([1], .5, operator.mul, operator.pow, 1), 29 | ([1, 10], .5, operator.mul, operator.pow, 6), 30 | ([1, 10, 100], .5, operator.mul, operator.pow, 31), 31 | ( 32 | (np.array([2, 3]), np.eye(2)), 33 | np.array([-7, 5]), 34 | lambda a, b: np.tensordot(a, b, np.ndim(b)), 35 | lambda a, b: np.tensordot(a, b, 0), 36 | np.array([-5, 8]), 37 | ), 38 | ) 39 | def test_eval_polynomial(self, coefficients, z, inner_product, outer_product, 40 | expected): 41 | actual = polynomials.eval_polynomial(coefficients, z, inner_product, 42 | outer_product) 43 | np.testing.assert_allclose(actual, expected) 44 | 45 | @parameterized.parameters( 46 | ((1,), 3, 1), 47 | ((1, 2), 3, 7), 48 | ((1, 2, 3), 3, 34), 49 | ((1, 2, (3, 4)), 3, (34, 43)), 50 | ( 51 | ( 52 | np.array([5, 7]), 53 | ), 54 | np.array([-1, -2]), 55 | np.array([5, 7]), 56 | ), 57 | ( 58 | ( 59 | np.array([5, 7]), 60 | np.array([3, 4]), 61 | ), 62 | np.array([-1, -2]), 63 | np.array([2, -1]), 64 | ), 65 | ( 66 | ( 67 | np.array([5, 7]), 68 | ( 69 | np.array([-3, -2]), 70 | np.array([3, 4]), 71 | ) 72 | ), 73 | np.array([-1, -2]), 74 | ( 75 | np.array([2, -1]), 76 | np.array([8, 11]), 77 | ), 78 | ), 79 | ( 80 | (0, 0, (0, 1)), 81 | (-1., 1.), 82 | (0., 1.), 83 | ), 84 | ( 85 | (0, 0, (-1, 1)), 86 | (-1., 1.), 87 | (-1., 1.), 88 | ), 89 | ) 90 | def test_eval_elementwise_taylor_enclosure(self, enclosure, x_minus_x0, 91 | expected): 92 | actual = polynomials.eval_elementwise_taylor_enclosure( 93 | enclosure, x_minus_x0, np) 94 | self.assert_interval_equal(expected, actual) 95 | 96 | @parameterized.parameters( 97 | ( 98 | (np.array([2, 3]), np.eye(2)), 99 | np.array([-7, 5]), 100 | np.array([-5., 8.]), 101 | ), 102 | ( 103 | (np.array([2, 3]), (np.eye(2), 2*np.eye(2))), 104 | np.array([-7, 5]), 105 | (np.array([-12., 8.]), np.array([-5., 13.])), 106 | ), 107 | ( 108 | (0, 0, (0, 1)), 109 | (-1., 1.), 110 | (0., 1.), 111 | ), 112 | ( 113 | (0, 0, (-1, 1)), 114 | (-1., 1.), 115 | (-1., 1.), 116 | ), 117 | ) 118 | def test_eval_taylor_enclosure(self, enclosure, x_minus_x0, expected): 119 | actual = polynomials.eval_taylor_enclosure(enclosure, x_minus_x0, np) 120 | self.assert_interval_equal(expected, actual) 121 | 122 | @parameterized.parameters( 123 | ( 124 | (2, 3), 125 | (5, 7), 126 | operator.add, 127 | 0, 128 | lambda c0, c1, i, j: c0*c1, 129 | (10, 29, 21), 130 | ) 131 | ) 132 | def test_arbitrary_bilinear( 133 | self, a, b, add, additive_identity, term_product_coefficient, expected): 134 | actual = polynomials.arbitrary_bilinear(a, b, add, additive_identity, 135 | term_product_coefficient) 136 | self.assertEqual(expected, actual) 137 | 138 | @parameterized.parameters( 139 | ( 140 | 1, 141 | 2, 142 | 3, 143 | [(1, 1, 0)] 144 | ), 145 | ( 146 | 2, 147 | 2, 148 | 3, 149 | [(0, 2, 0), (1, 0, 1)] 150 | ), 151 | ) 152 | def test_iter_partitions(self, n, m, k, expected_iterates): 153 | actual_iterates = list(polynomials._iter_partitions(n, m, k)) 154 | self.assertListEqual(expected_iterates, actual_iterates) 155 | 156 | @parameterized.parameters( 157 | ((2, 3, 5), 0, (1,)), 158 | ((2, 3, 5), 1, (2, 3, 5)), 159 | ((2, 3, 5), 2, (4, 12, 29, 30, 25)), 160 | ) 161 | def 
test_integer_power(self, a, exponent, expected): 162 | # TODO(mstreeter): test non-default values of the kwargs. 163 | actual = polynomials.integer_power(a, exponent) 164 | self.assertEqual(expected, actual) 165 | 166 | @parameterized.parameters( 167 | ([2, 0], 1), 168 | ([1, 1], 2), 169 | ([99, 1], 100), 170 | ([98, 2], 100*99 // 2), 171 | ([95, 2, 3], 100*99*98*97*96 // (3*2*2)), 172 | ) 173 | def test_multinomial_coefficient(self, ks, expected): 174 | actual = polynomials._multinomial_coefficient(ks) 175 | self.assertEqual(expected, actual) 176 | 177 | 178 | if __name__ == '__main__': 179 | absltest.main() 180 | -------------------------------------------------------------------------------- /autobound/graph_editor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A library for editing computation graphs. 16 | 17 | This library offers more limited functionality than dedicated pattern-matching 18 | libraries like matchpy (https://github.com/HPAC/matchpy). However, unlike 19 | matchpy and many other existing libraries, it works directly on computation 20 | graphs rather than on trees. This makes it easier to handle certain use cases, 21 | like editing Jaxprs. 22 | """ 23 | 24 | import dataclasses 25 | import itertools 26 | from typing import (Any, Callable, Dict, Hashable, Mapping, Optional, Sequence, 27 | Set) 28 | 29 | # An intermediate variable in a computation graph (e.g., representing a tensor). 30 | IntermediateVariable = Hashable 31 | 32 | 33 | @dataclasses.dataclass 34 | class Operation: 35 | """An operation (a.k.a. equation) in a computation graph.""" 36 | data: Any # for example, the operation type (e.g., 'MatMul') 37 | inputs: Sequence[IntermediateVariable] 38 | outputs: Sequence[IntermediateVariable] 39 | 40 | def subs(self, mapping: Mapping[IntermediateVariable, IntermediateVariable]): 41 | return Operation(self.data, 42 | [mapping.get(u, u) for u in self.inputs], 43 | [mapping.get(v, v) for v in self.outputs]) 44 | 45 | 46 | @dataclasses.dataclass 47 | class ComputationGraph: 48 | """A computation graph (e.g., representing a Jaxpr or a tf.Graph).""" 49 | inputs: Sequence[IntermediateVariable] 50 | outputs: Sequence[IntermediateVariable] 51 | operations: Sequence[Operation] 52 | data: Any = None # Any global data associated with the graph. 53 | 54 | def intermediate_variables(self) -> Set[IntermediateVariable]: 55 | intvars = set(self.inputs) 56 | intvars.update(self.outputs) 57 | for op in self.operations: 58 | intvars.update(op.inputs) 59 | intvars.update(op.outputs) 60 | return intvars 61 | 62 | 63 | def match( 64 | pattern: Sequence[Operation], 65 | subject: Sequence[Operation], 66 | can_bind: Callable[[IntermediateVariable, IntermediateVariable], bool] 67 | ) -> Optional[Dict[IntermediateVariable, IntermediateVariable]]: 68 | """Checks whether a pattern matches a subject, and returns mapping if so. 
69 | 70 | Args: 71 | pattern: a sequence of `Operations` 72 | subject: a sequence of `Operations` 73 | can_bind: a callable that, given as arguments a pattern 74 | `IntermediateVariable` `u` and a subject `IntermediateVariable` `v`, 75 | determines whether `u` can be mapped to `v`. (This could return `False`, 76 | for example, if `u` represents a constant and `v` represents a different 77 | constant.) 78 | 79 | Returns: 80 | A dict `m` representing a match, or `None` if no match was found. If the 81 | return value is not `None`, it maps from pattern `IntermediateVariable` to 82 | subject `IntermediateVariable`, and satisfies: 83 | ```python 84 | [e.subs(m) for e in pattern] == subject 85 | ``` 86 | """ 87 | if len(pattern) != len(subject): 88 | return None 89 | 90 | vertex_map = {} # from pattern vertex to subject vertex 91 | for p, s in zip(pattern, subject): 92 | if (len(p.inputs) != len(s.inputs) or len(p.outputs) != len(s.outputs)): 93 | return None 94 | if p.data != s.data: 95 | return None 96 | for u, v in itertools.chain(zip(p.inputs, s.inputs), 97 | zip(p.outputs, s.outputs)): 98 | if u in vertex_map: 99 | if vertex_map[u] != v: 100 | return None 101 | else: 102 | if can_bind(u, v): 103 | vertex_map[u] = v 104 | else: 105 | return None 106 | return vertex_map 107 | 108 | 109 | def replace( 110 | pattern: Sequence[Operation], 111 | replacement: Sequence[Operation], 112 | subject: ComputationGraph, 113 | can_bind: Callable[[IntermediateVariable, IntermediateVariable], bool] 114 | ) -> ComputationGraph: 115 | """Perform a search/replace on a ComputationGraph. 116 | 117 | This method greedily replaces occurrences of a given operation sequence 118 | `pattern` with an operation sequence `replacement`. 119 | 120 | Args: 121 | pattern: a sequence of `Operations` 122 | replacement: a sequence of `Operations` with which to replace occurrences of 123 | the pattern. 124 | subject: a `ComputationGraph` 125 | can_bind: a `Callable` with the same meaning as the corresponding argument 126 | to `match()`. 127 | 128 | Returns: 129 | a `ComputationGraph` with occurrences of `pattern` replaced by 130 | `replacement`. 131 | """ 132 | # Note: this could be made more efficient using the Knuth-Morris-Pratt 133 | # algorithm. 134 | k = len(pattern) 135 | output_operations = [] 136 | i = 0 137 | while i < len(subject.operations): 138 | subgraph = subject.operations[i:i+k] 139 | m = match(pattern, subgraph, can_bind) 140 | if m is not None: 141 | output_operations.extend([e.subs(m) for e in replacement]) 142 | i += k 143 | else: 144 | output_operations.append(subject.operations[i]) 145 | i += 1 146 | 147 | return ComputationGraph( 148 | subject.inputs, 149 | subject.outputs, 150 | output_operations, 151 | data=subject.data 152 | ) 153 | -------------------------------------------------------------------------------- /autobound/jax/jaxpr_editor_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | from autobound.jax import jaxpr_editor 18 | import jax 19 | import jax.numpy as jnp 20 | import numpy as np 21 | 22 | 23 | softplus_p = jax.core.Primitive('__autobound_softplus__') 24 | softplus_p.def_abstract_eval( 25 | lambda x: jax.abstract_arrays.ShapedArray(x.shape, x.dtype, weak_type=True)) 26 | 27 | 28 | class TestCase(parameterized.TestCase): 29 | 30 | @parameterized.named_parameters( 31 | ( 32 | 'exp', 33 | jax.make_jaxpr(jnp.exp)(0.).jaxpr, 34 | jax.make_jaxpr(jnp.exp)(0.).jaxpr, 35 | False, 36 | True 37 | ), 38 | ( 39 | 'exp_different_shapes', 40 | jax.make_jaxpr(jnp.exp)(0.).jaxpr, 41 | jax.make_jaxpr(jnp.exp)(np.array([0.])).jaxpr, 42 | False, 43 | False 44 | ), 45 | ( 46 | 'exp_different_shapes_ignore', 47 | jax.make_jaxpr(jnp.exp)(0.).jaxpr, 48 | jax.make_jaxpr(jnp.exp)(np.array([0.])).jaxpr, 49 | True, 50 | True 51 | ), 52 | ( 53 | 'exp_vs_log', 54 | jax.make_jaxpr(jnp.exp)(0.).jaxpr, 55 | jax.make_jaxpr(jnp.log)(0.).jaxpr, 56 | False, 57 | False 58 | ), 59 | ( 60 | 'exp_float_vs_exp_int', 61 | jax.make_jaxpr(jnp.exp)(0.).jaxpr, 62 | jax.make_jaxpr(jnp.exp)(0).jaxpr, 63 | False, 64 | False 65 | ), 66 | ( 67 | 'plus_one', 68 | jax.make_jaxpr(lambda x: x+1)(0.).jaxpr, 69 | jax.make_jaxpr(lambda x: x+1)(0.).jaxpr, 70 | False, 71 | True 72 | ), 73 | ( 74 | 'plus_one_vs_plus_two', 75 | jax.make_jaxpr(lambda x: x+1)(0.).jaxpr, 76 | jax.make_jaxpr(lambda x: x+2)(0.).jaxpr, 77 | False, 78 | False 79 | ), 80 | ( 81 | 'softplus_p', 82 | jax.make_jaxpr(softplus_p.bind)(0.).jaxpr, 83 | jax.make_jaxpr(softplus_p.bind)(0.).jaxpr, 84 | False, 85 | True 86 | ), 87 | ( 88 | 'jax_nn_sigmoid', 89 | jax.make_jaxpr(jax.nn.sigmoid)(0.).jaxpr, 90 | jax.make_jaxpr(jax.nn.sigmoid)(0.).jaxpr, 91 | False, 92 | True 93 | ), 94 | ( 95 | 'sigmoid_vs_softplus', 96 | jax.make_jaxpr(jax.nn.sigmoid)(0.).jaxpr, 97 | jax.make_jaxpr(jax.nn.softplus)(0.).jaxpr, 98 | False, 99 | False 100 | ), 101 | ( 102 | 'jax_nn_softplus', 103 | jax.make_jaxpr(jax.nn.softplus)(0.).jaxpr, 104 | jax.make_jaxpr(jax.nn.softplus)(0.).jaxpr, 105 | False, 106 | True 107 | ), 108 | ( 109 | 'jax_nn_swish', 110 | jax.make_jaxpr(jax.nn.swish)(0.).jaxpr, 111 | jax.make_jaxpr(jax.nn.swish)(0.).jaxpr, 112 | False, 113 | True 114 | ), 115 | ) 116 | def test_same_jaxpr_up_to_variable_renaming(self, j0, j1, ignore_shape, 117 | expected): 118 | actual = jaxpr_editor._same_jaxpr_up_to_variable_renaming(j0, j1, 119 | ignore_shape) 120 | self.assertEqual(expected, actual) 121 | 122 | def assert_jaxpr_equiv(self, expected, actual): 123 | self.assertTrue( 124 | jaxpr_editor._same_jaxpr_up_to_variable_renaming(expected, actual)) 125 | 126 | @parameterized.parameters( 127 | (jax.make_jaxpr(jnp.exp)(0.).jaxpr,), 128 | (jax.make_jaxpr(lambda x: 2*jnp.exp(x+1))(0.).jaxpr,), 129 | (jax.make_jaxpr(lambda x: (x * jnp.array([[1], [2]])).sum())(0.).jaxpr,), 130 | ) 131 | def test_graph_conversion(self, jaxpr): 132 | graph = jaxpr_editor._jaxpr_to_graph(jaxpr) 133 | converted_jaxpr = jaxpr_editor._graph_to_jaxpr(graph) 134 | self.assert_jaxpr_equiv(jaxpr, converted_jaxpr) 135 | 136 | @parameterized.named_parameters( 137 | ( 138 | 'replace_exp_with_log', 139 | jax.make_jaxpr(jnp.exp)(0.).jaxpr, 140 | jax.make_jaxpr(jnp.log)(0.).jaxpr, 141 | jax.make_jaxpr(jnp.exp)(0.).jaxpr, 142 | jax.make_jaxpr(jnp.log)(0.).jaxpr, 143 | ), 144 | ( 145 | 
'replace_exp_with_log_different_shapes', 146 | jax.make_jaxpr(jnp.exp)(0.).jaxpr, 147 | jax.make_jaxpr(jnp.log)(0.).jaxpr, 148 | jax.make_jaxpr(jnp.exp)(np.array([0.])).jaxpr, 149 | jax.make_jaxpr(jnp.log)(np.array([0.])).jaxpr, 150 | ), 151 | ( 152 | 'replace_exp_with_log_2', 153 | jax.make_jaxpr(jnp.exp)(0.).jaxpr, 154 | jax.make_jaxpr(jnp.log)(0.).jaxpr, 155 | jax.make_jaxpr(lambda x: 2*jnp.exp(x+1))(0.).jaxpr, 156 | jax.make_jaxpr(lambda x: 2*jnp.log(x+1))(0.).jaxpr, 157 | ), 158 | ( 159 | 'replace_exp_with_log_x_squared', 160 | jax.make_jaxpr(jnp.exp)(0.).jaxpr, 161 | jax.make_jaxpr(lambda x: jnp.log(x**2))(0.).jaxpr, 162 | jax.make_jaxpr(lambda x: 2*jnp.exp(x+1))(0.).jaxpr, 163 | jax.make_jaxpr(lambda x: 2*jnp.log((x+1)**2))(0.).jaxpr, 164 | ), 165 | ( 166 | 'replace_softplus_with_primitive', 167 | jax.make_jaxpr(jax.nn.softplus)(0.).jaxpr, 168 | jax.make_jaxpr(softplus_p.bind)(0.).jaxpr, 169 | jax.make_jaxpr(jax.nn.softplus)(0.).jaxpr, 170 | jax.make_jaxpr(softplus_p.bind)(0.).jaxpr, 171 | ), 172 | ( 173 | 'replace_softplus_with_primitive_different_shapes', 174 | jax.make_jaxpr(jax.nn.softplus)(0.).jaxpr, 175 | jax.make_jaxpr(softplus_p.bind)(0.).jaxpr, 176 | jax.make_jaxpr(jax.nn.softplus)(np.array([0.])).jaxpr, 177 | jax.make_jaxpr(softplus_p.bind)(np.array([0.])).jaxpr, 178 | ), 179 | ) 180 | def test_jaxpr_replace(self, pattern, replacement, subject, expected): 181 | actual = jaxpr_editor.replace(pattern, replacement, subject) 182 | self.assert_jaxpr_equiv(expected, actual) 183 | 184 | 185 | if __name__ == '__main__': 186 | absltest.main() 187 | -------------------------------------------------------------------------------- /autobound/jax/jaxpr_editor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library for editing JAX expressions (Jaxprs).""" 16 | 17 | import itertools 18 | import types 19 | from typing import Any, Tuple, Union 20 | 21 | from autobound import graph_editor 22 | import jax 23 | 24 | 25 | def replace(pattern: jax.core.Jaxpr, 26 | replacement: jax.core.Jaxpr, 27 | subject: jax.core.Jaxpr) -> jax.core.Jaxpr: 28 | """Replace occurrences of a pattern Jaxpr within a subject Jaxpr. 29 | 30 | This method returns a `Jaxpr` that is an edited version of `subject` 31 | in which occurrences of `pattern` have been replaced by `replacement`. 32 | Occurrences of `pattern` must be contiguous sequences of equations in 33 | `subject`. 34 | 35 | Example usage: 36 | 37 | ```python 38 | pattern_fun = jnp.exp 39 | replacement_fun = jnp.log 40 | subject_fun = lambda x: 2*jnp.exp(x+1) 41 | to_jaxpr = lambda f: jax.make_jaxpr(f)(0.).jaxpr 42 | 43 | replace( 44 | to_jaxpr(pattern_fun), 45 | to_jaxpr(replacement_fun), 46 | to_jaxpr(subject_fun), 47 | ) # returns Jaxpr for 2*jnp.log(x+1). 
48 | ``` 49 | 50 | Args: 51 | pattern: a `Jaxpr` 52 | replacement: a `Jaxpr` 53 | subject: a `Jaxpr` 54 | 55 | Returns: 56 | a `Jaxpr` in which occurrences of `pattern` have been replaced by 57 | `replacement`. 58 | """ 59 | if len(pattern.invars) != len(replacement.invars): 60 | raise ValueError() 61 | if len(pattern.outvars) != len(replacement.outvars): 62 | raise ValueError() 63 | if pattern.constvars or replacement.constvars: 64 | raise NotImplementedError() 65 | pattern_graph = _jaxpr_to_graph(pattern) 66 | # When creating the replacement graph, offset the variable counts to keep 67 | # them unique. 68 | max_count = max( 69 | v[1] if v[0] else -1 # pytype: disable=unsupported-operands 70 | for v in pattern_graph.intermediate_variables() 71 | ) 72 | replacement_graph = _jaxpr_to_graph(replacement, offset=max_count + 1) 73 | # In the replacement graph, replace inputs/outputs with those in the pattern 74 | # graph. 75 | intermediate_variable_map = {} 76 | for u, v in itertools.chain( 77 | zip(replacement_graph.inputs, pattern_graph.inputs), 78 | zip(replacement_graph.outputs, pattern_graph.outputs), 79 | ): 80 | intermediate_variable_map[u] = v 81 | replacement_operations = [ 82 | e.subs(intermediate_variable_map) for e in replacement_graph.operations 83 | ] 84 | hypergraph = graph_editor.replace( 85 | pattern_graph.operations, 86 | replacement_operations, 87 | _jaxpr_to_graph(subject), 88 | _can_bind, 89 | ) 90 | return _graph_to_jaxpr(hypergraph) 91 | 92 | 93 | class _EqnData: 94 | 95 | def __init__(self, primitive, params): 96 | self.primitive = primitive 97 | self.params = params 98 | 99 | def __eq__(self, other): 100 | return (isinstance(other, _EqnData) and 101 | self.primitive == other.primitive and 102 | _jaxpr_eqn_params_equiv(self.params, other.params)) 103 | 104 | def __hash__(self): 105 | return hash(self.primitive) 106 | 107 | 108 | # An intermediate variable is a tuple that either represents a jax.core.Var or 109 | # a jax.core.Literal. 110 | _IntermediateVariable = Union[ 111 | # If the first element of the tuple is True, then the tuple represents 112 | # a jax.core.Var, and is of the form (True, count, suffix, aval). 113 | Tuple[bool, int, str, jax.core.AbstractValue], 114 | # If the first element of the tuple is False, then the tuple represents 115 | # a jax.core.Literal, and is of the form (False, val, aval). 116 | Tuple[bool, Any, jax.core.AbstractValue] 117 | ] 118 | 119 | 120 | def _jaxpr_to_graph(jaxpr: jax.core.Jaxpr, 121 | offset: int = 0) -> graph_editor.ComputationGraph: 122 | """Returns a ComputationGraph that represents a Jaxpr. 123 | 124 | Args: 125 | jaxpr: a `Jaxpr` 126 | offset: an offset for the indices that appear in intermediate variables. 127 | This can be used to ensure uniqueness. 128 | 129 | Returns: 130 | a ComputationGraph that represents `jaxpr`. 
131 | """ 132 | 133 | def get_intermediate_variable( 134 | var_or_literal: Union[jax.core.Var, jax.core.Literal] 135 | ) -> _IntermediateVariable: 136 | if isinstance(var_or_literal, jax.core.Var): 137 | var = var_or_literal 138 | return (True, var.count + offset, var.suffix, var.aval) 139 | elif isinstance(var_or_literal, jax.core.Literal): 140 | literal = var_or_literal 141 | return (False, literal.val, literal.aval) 142 | else: 143 | raise NotImplementedError() 144 | 145 | operations = [] 146 | for eqn in jaxpr.eqns: 147 | data = _EqnData(eqn.primitive, eqn.params) 148 | edge = graph_editor.Operation( 149 | data, 150 | [get_intermediate_variable(v) for v in eqn.invars], 151 | [get_intermediate_variable(v) for v in eqn.outvars] 152 | ) 153 | operations.append(edge) 154 | 155 | data = [get_intermediate_variable(v) for v in jaxpr.constvars] 156 | return graph_editor.ComputationGraph( 157 | [get_intermediate_variable(v) for v in jaxpr.invars], 158 | [get_intermediate_variable(v) for v in jaxpr.outvars], 159 | operations, 160 | data=data 161 | ) 162 | 163 | 164 | def _graph_to_jaxpr(h: graph_editor.ComputationGraph) -> jax.core.Jaxpr: 165 | """Returns the Jaxpr represented by a ComputationGraph.""" 166 | count_to_var = {} 167 | 168 | def vertex_to_var_or_literal(vertex): 169 | if vertex[0]: 170 | _, count, suffix, aval = vertex 171 | if count not in count_to_var: 172 | count_to_var[count] = jax.core.Var(count, suffix, aval) 173 | return count_to_var[count] 174 | else: 175 | _, val, aval = vertex 176 | return jax.core.Literal(val, aval) 177 | 178 | eqns = [] 179 | for edge in h.operations: 180 | eqn = jax.core.new_jaxpr_eqn( 181 | invars=[vertex_to_var_or_literal(u) for u in edge.inputs], 182 | outvars=[vertex_to_var_or_literal(v) for v in edge.outputs], 183 | primitive=edge.data.primitive, 184 | params=edge.data.params, 185 | effects=set() 186 | ) 187 | eqns.append(eqn) 188 | 189 | invars = [vertex_to_var_or_literal(v) for v in h.inputs] 190 | outvars = [vertex_to_var_or_literal(v) for v in h.outputs] 191 | constvars = [vertex_to_var_or_literal(v) for v in h.data] 192 | return jax.core.Jaxpr(constvars, invars, outvars, eqns) 193 | 194 | 195 | def _can_bind(u, v): 196 | if u[0]: 197 | return v[0] 198 | else: 199 | return (not v[0]) and (u[1] == v[1]) 200 | 201 | 202 | # Set of Jaxpr equation params we ignore for matching purposes. 203 | _JAXPR_EQN_PARAMS_TO_IGNORE = frozenset(['weak_type']) 204 | 205 | 206 | def _jaxpr_eqn_params_equiv(p0, p1) -> bool: 207 | """Returns whether two Jaxpr equation params dicts are equivalent.""" 208 | if set(p0.keys()) != set(p1.keys()): 209 | return False 210 | for k0, v0 in p0.items(): 211 | if k0 in _JAXPR_EQN_PARAMS_TO_IGNORE: 212 | continue 213 | v1 = p1[k0] 214 | if isinstance(v0, jax.core.ClosedJaxpr): 215 | # TODO(mstreeter): this could incorrectly return True if v0.consts and 216 | # v1.consts are different. 
217 | if not _same_jaxpr_up_to_variable_renaming(v0.jaxpr, v1.jaxpr, 218 | ignore_shape=True): 219 | return False 220 | elif isinstance(v0, jax.core.Jaxpr): 221 | if not _same_jaxpr_up_to_variable_renaming(v0, v1, ignore_shape=True): 222 | return False 223 | elif isinstance(v0, types.FunctionType): 224 | if not isinstance(v1, types.FunctionType): 225 | return False 226 | elif v0 != v1: 227 | return False 228 | return True 229 | 230 | 231 | def _match_avals(a0, a1, ignore_shape): 232 | return a0.dtype == a1.dtype and (ignore_shape or (a0.shape == a1.shape)) 233 | 234 | 235 | def _same_jaxpr_up_to_variable_renaming(j0: jax.core.Jaxpr, 236 | j1: jax.core.Jaxpr, 237 | ignore_shape: bool = False) -> bool: 238 | """Return whether two Jaxprs are identical up to variable renaming.""" 239 | var_map = {} 240 | def check(v0, v1): 241 | if isinstance(v0, jax.core.Literal): 242 | return (isinstance(v1, jax.core.Literal) and v0.val == v1.val and 243 | v0.aval == v1.aval) 244 | elif isinstance(v0, jax.core.Var): 245 | if v0 not in var_map: 246 | if _match_avals(v0.aval, v1.aval, ignore_shape): 247 | var_map[v0] = v1 248 | return True 249 | else: 250 | return False 251 | else: 252 | return var_map[v0] == v1 253 | else: 254 | raise NotImplementedError(v0) 255 | 256 | if (len(j0.constvars) != len(j1.constvars) or 257 | len(j0.invars) != len(j1.invars) or 258 | len(j0.outvars) != len(j1.outvars) or 259 | len(j0.eqns) != len(j1.eqns)): 260 | return False 261 | 262 | for v0, v1 in itertools.chain(zip(j0.invars, j1.invars), 263 | zip(j0.constvars, j1.constvars), 264 | zip(j0.outvars, j1.outvars)): 265 | if not check(v0, v1): 266 | return False 267 | 268 | for eq0, eq1 in zip(j0.eqns, j1.eqns): 269 | if eq0.primitive != eq1.primitive: 270 | return False 271 | if not _jaxpr_eqn_params_equiv(eq0.params, eq1.params): 272 | return False 273 | if (len(eq0.invars) != len(eq1.invars) or 274 | len(eq0.outvars) != len(eq1.outvars)): 275 | return False 276 | for v0, v1 in zip(eq0.invars, eq1.invars): 277 | if not check(v0, v1): 278 | return False 279 | for v0, v1 in zip(eq0.outvars, eq1.outvars): 280 | if not check(v0, v1): 281 | return False 282 | 283 | return True -------------------------------------------------------------------------------- /autobound/polynomials.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Library for evaluating different types of polynomials.""" 16 | 17 | import math 18 | import operator 19 | from typing import Callable, Iterator, Sequence, Tuple, TypeVar, Union 20 | 21 | from autobound import interval_arithmetic 22 | from autobound import types 23 | 24 | Foo = TypeVar('Foo') # some arbitrary type, like NDArray or Interval 25 | FooLike = TypeVar('FooLike', bound=Foo) # pytype: disable=invalid-typevar 26 | 27 | 28 | def eval_polynomial( 29 | coefficients: Sequence[FooLike], 30 | z: FooLike, 31 | inner_product: Callable[[FooLike, FooLike], Foo], 32 | outer_power: Callable[[FooLike, int], Foo], 33 | add: Callable[[FooLike, FooLike], Foo] = operator.add, 34 | additive_identity: Foo = 0, 35 | multiplicative_identity: Foo = 1) -> Foo: 36 | """Returns the value of a polynomial at a specific point.""" 37 | running_sum = additive_identity 38 | z_to_the_i = multiplicative_identity 39 | for i, coefficient in enumerate(coefficients): 40 | if i > 0: 41 | z_to_the_i = outer_power(z, i) 42 | term = inner_product(coefficient, z_to_the_i) 43 | running_sum = add(running_sum, term) 44 | return running_sum 45 | 46 | 47 | def eval_elementwise_taylor_enclosure( 48 | enclosure: types.ElementwiseTaylorEnclosureLike, 49 | x_minus_x0: Union[types.NDArrayLike, types.IntervalLike], 50 | np_like: types.NumpyLike) -> Union[types.Interval, types.NDArray]: 51 | """Returns value of an ElementwiseTaylorEnclosure at x-x0.""" 52 | set_arithmetic = interval_arithmetic.IntervalArithmetic(np_like) 53 | return eval_polynomial(enclosure, 54 | set_arithmetic.as_interval_or_ndarray(x_minus_x0), 55 | set_arithmetic.multiply, 56 | set_arithmetic.power, 57 | set_arithmetic.add, 58 | np_like.array(0), 59 | np_like.array(1)) 60 | 61 | 62 | def eval_taylor_enclosure( 63 | enclosure: types.TaylorEnclosureLike, 64 | x_minus_x0: Union[types.NDArrayLike, types.IntervalLike], 65 | np_like: types.NumpyLike) -> Union[types.Interval, types.NDArray]: 66 | """Returns value of an TaylorEnclosure at x-x0.""" 67 | set_arithmetic = interval_arithmetic.IntervalArithmetic(np_like) 68 | inner_product = ( 69 | lambda a, b: set_arithmetic.tensordot(a, b, set_arithmetic.ndim(b))) 70 | return eval_polynomial(enclosure, 71 | set_arithmetic.as_interval_or_ndarray(x_minus_x0), 72 | inner_product, 73 | set_arithmetic.outer_power, 74 | set_arithmetic.add, 75 | np_like.array(0), 76 | np_like.array(1)) 77 | 78 | 79 | def arbitrary_bilinear( 80 | a: Sequence[FooLike], 81 | b: Sequence[FooLike], 82 | add: Callable[[FooLike, FooLike], Foo] = operator.add, 83 | additive_identity: Foo = 0, 84 | term_product_coefficient: Callable[[FooLike, FooLike, int, int], Foo] 85 | = lambda c0, c1, i, j: c0*c1, 86 | ) -> Tuple[Foo, ...]: 87 | """Applies an arbitrary bilinear operation to two polynomials. 88 | 89 | The arguments a and b give the coefficients of polynomials, defined in terms 90 | of some inner product and some exponentiation operator: 91 | 92 | P_a(z) = sum_{i=0}^{len(a)-1} . 93 | 94 | Similarly, the sequence b represents a polynomial P_b(z). 95 | 96 | Args: 97 | a: a polynomial (sequence of coefficients) 98 | b: a polynomial (sequence of coefficients) 99 | add: a function that returns the sum of two polynomial coefficients 100 | additive_identity: a addtive identity object 101 | term_product_coefficient: a callable that, given arguments c0, c1, i, j, 102 | returns d such that op(, ) = , where op 103 | is the underlying bilinear operation. 
104 | 105 | Returns: 106 | a polynomial Q (tuple of coefficients), such that for any z, 107 | op(P_a(z), P_b(z)) == Q(z) 108 | where op is the underlying bilinear operation. 109 | """ 110 | # By bilinearity, 111 | # op(sum_i <a[i], z**i>, sum_j <b[j], z**j>) 112 | # == sum_{ij} op(<a[i], z**i>, <b[j], z**j>) 113 | # == sum_{ij} <term_product_coefficient(a[i], b[j], i, j), z**(i+j)>. 114 | output_degree = len(a) + len(b) - 2 115 | output = [additive_identity] * (output_degree + 1) 116 | # If a and b have length n, this takes time O(n^2). If we ever care about 117 | # large n, we could consider implementing an O(n log n) algorithm using 118 | # Fourier transforms. 119 | for i, c0 in enumerate(a): 120 | for j, c1 in enumerate(b): 121 | c = term_product_coefficient(c0, c1, i, j) 122 | output[i+j] = add(output[i+j], c) 123 | return tuple(output) 124 | 125 | 126 | def integer_power( 127 | a: Sequence[FooLike], 128 | exponent: int, 129 | add: Callable[[FooLike, FooLike], Foo] = operator.add, 130 | additive_identity: Foo = 0, 131 | multiplicative_identity: Foo = 1, 132 | term_product_coefficient: Callable[[FooLike, FooLike, int, int], Foo] 133 | = lambda c0, c1, i, j: c0*c1, 134 | term_power_coefficient: Callable[[FooLike, int, int], Foo] 135 | = lambda c, i, j: c**j, 136 | scalar_product: Callable[[int, FooLike], Foo] = operator.mul 137 | ) -> Tuple[Foo, ...]: 138 | """Returns the coefficients of a polynomial raised to a power. 139 | 140 | The argument a gives the coefficients of a polynomial, defined in terms 141 | of some inner product and some exponentiation operator: 142 | 143 | P_a(z) = sum_{i=0}^{len(a)-1} <a[i], z**i>. 144 | 145 | Let op be some bilinear, associative, and commutative operation. We define: 146 | 147 | power(a, 0) == multiplicative_identity 148 | power(a, k) = op(a, power(a, k-1)). 149 | 150 | This code uses the functions provided as arguments to efficiently compute 151 | the coefficients of the polynomial power(a, exponent). 152 | 153 | When the coefficients of P_a are intervals, this efficient computation 154 | translates into tighter intervals in the returned coefficients. 155 | 156 | Args: 157 | a: a polynomial (sequence of coefficients) 158 | exponent: a non-negative integer exponent 159 | add: a function that returns the sum of two polynomial coefficients 160 | additive_identity: an additive identity object 161 | multiplicative_identity: a multiplicative identity object 162 | term_product_coefficient: a callable that, given arguments c0, c1, i, j, 163 | returns d such that op(<c0, z**i>, <c1, z**j>) = <d, z**(i+j)>, where op 164 | is the underlying bilinear operation. 165 | term_power_coefficient: given arguments c, i, and j, returns d such that: 166 | (c * z**i)**j == d * z**(i*j) 167 | scalar_product: a callable that, given as arguments a non-negative integer i 168 | and coefficient c, returns the result of adding c to itself i times. 169 | 170 | Returns: 171 | the coefficients of the polynomial P_a, raised to the exponent power. 172 | """ 173 | if exponent < 0: 174 | raise ValueError(exponent) 175 | elif exponent == 0: 176 | return (multiplicative_identity,) 177 | else: 178 | # To understand what this code is doing, it is helpful to consider the 179 | # special case where `a` is a sequence of floats, and all arguments have 180 | # their default values. Then, we just need to compute the coefficients 181 | # of the scalar polynomial: 182 | # 183 | # (a[0] + a[1]*z**1 + ... + a[k-1]*z**(k-1))**exponent 184 | # 185 | # where k = len(a). 
186 | # 187 | # Using the multinomial theorem, the result is a polynomial whose ith 188 | # coefficient is: 189 | # 190 | # sum_{p in Partitions(i, exponent, k)} 191 | # (exponent choose (p_0, p_1, ..., p_{k-1})) * 192 | # Prod_{j=0}^{k-1} a[j]**p_j 193 | # 194 | # where Partitions(i, exponent, k) is the set of length-k non-negative 195 | # integer tuples whose elements sum to `exponent`, and that furthermore 196 | # satisfy sum_{j=0}^{k-1} j*p_j == i. 197 | # 198 | # The code below uses a generalization of this idea that works for an 199 | # arbitrary commutative and associative bilinear operation (rather than 200 | # just scalar multiplication). In the general version, the product 201 | # series Prod_{j=0}^{k-1} a[j]**p_j is computed via appropriate calls to 202 | # term_product_coefficient(), term_power_coefficient() and scalar_product(). 203 | 204 | def get_coeff(i: int) -> Foo: 205 | c = additive_identity 206 | for p in _iter_partitions(i, exponent, len(a)): 207 | assert sum(p) == exponent 208 | assert sum(j*p_j for j, p_j in enumerate(p)) == i 209 | running_product = multiplicative_identity 210 | running_product_power = 0 211 | for j, p_j in enumerate(p): 212 | running_product = term_product_coefficient( 213 | running_product, 214 | term_power_coefficient(a[j], j, p_j), 215 | running_product_power, 216 | j*p_j 217 | ) 218 | running_product_power += j*p_j 219 | assert running_product_power == i 220 | term = scalar_product(_multinomial_coefficient(p), running_product) 221 | c = add(c, term) 222 | return c 223 | output_degree = (len(a) - 1) * exponent 224 | return tuple(get_coeff(i) for i in range(1 + output_degree)) 225 | 226 | 227 | def _iter_partitions( 228 | n: int, m: int, k: int) -> Iterator[Tuple[int, ...]]: 229 | """Yields length-k tuples with sum m and sum_{j=1}^k (j-1)*i_j == n.""" 230 | if n < 0: 231 | raise ValueError(n) 232 | if m < 0: 233 | raise ValueError(m) 234 | if k <= 0: 235 | raise ValueError(k) 236 | if k == 1: 237 | if n == 0: 238 | yield (m,) 239 | else: 240 | for z in range(min(m+1, n // (k-1) + 1)): 241 | for p in _iter_partitions(n - (k-1)*z, m - z, k - 1): 242 | yield p + (z,) 243 | 244 | 245 | def _multinomial_coefficient(ks: Sequence[int]) -> int: 246 | """Returns (n choose (ks[0], ks[1], ...)), where n = sum(ks).""" 247 | if not ks: 248 | raise ValueError(ks) 249 | elif len(ks) == 1: 250 | return 1 251 | else: 252 | return math.comb(sum(ks), ks[0]) * _multinomial_coefficient(ks[1:]) 253 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /autobound/interval_arithmetic_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import functools 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from autobound import interval_arithmetic 20 | from autobound import test_utils 21 | import numpy as np 22 | 23 | 24 | class TestCase(parameterized.TestCase, test_utils.TestCase): 25 | 26 | @parameterized.parameters( 27 | ((1, 2), (10, 20), (11, 22)), 28 | (1, (10, 20), (11, 21)), 29 | ((1, 2), 10, (11, 12)), 30 | (1, 10, 11), 31 | (10, np.array([1, 2]), np.array([11, 12])), 32 | (np.array([1, 2]), 10, np.array([11, 12])), 33 | ( 34 | (np.array([1, 2]), np.array([3, 4])), 35 | np.array([10, 20]), 36 | (np.array([11, 22]), np.array([13, 24])) 37 | ), 38 | ( 39 | np.array([10, 20]), 40 | (np.array([1, 2]), np.array([3, 4])), 41 | (np.array([11, 22]), np.array([13, 24])) 42 | ), 43 | ( 44 | (np.array([1, 2]), np.array([3, 4])), 45 | (np.array([10, 20]), np.array([30, 40])), 46 | (np.array([11, 22]), np.array([33, 44])) 47 | ), 48 | ) 49 | def test_add(self, a, b, expected): 50 | for np_like in self.backends: 51 | actual = interval_arithmetic.IntervalArithmetic(np_like).add(a, b) 52 | self.assert_interval_equal(expected, actual) 53 | 54 | @parameterized.parameters( 55 | # Test multiplication of scalar intervals via arbitary_bilinear(), using 56 | # all 9 combinations of signs for the interval end points. 57 | # 58 | # Test with both assume_product=True and assume_product=False. This only 59 | # changes the expected output in 1 of the 9 cases. 
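      # For instance, for the operands (-2, 3) and (-5, 7) below, the exact
      # product range is (-15, 21) (endpoint products 10, -14, -15, 21), which
      # assume_product=True recovers, while assume_product=False yields the
      # looser interval (-29, 31).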
60 | ((2, 3), (5, 7), lambda np_like: np_like.multiply, False, (10, 21)), 61 | ((2, 3), (5, 7), lambda np_like: np_like.multiply, True, (10, 21)), 62 | ((-3, -2), (5, 7), lambda np_like: np_like.multiply, False, (-21, -10)), 63 | ((-3, -2), (5, 7), lambda np_like: np_like.multiply, True, (-21, -10)), 64 | ((-2, 3), (5, 7), lambda np_like: np_like.multiply, False, (-14, 21)), 65 | ((-2, 3), (5, 7), lambda np_like: np_like.multiply, True, (-14, 21)), 66 | ((2, 3), (-7, -5), lambda np_like: np_like.multiply, False, (-21, -10)), 67 | ((2, 3), (-7, -5), lambda np_like: np_like.multiply, True, (-21, -10)), 68 | ((-3, -2), (-7, -5), lambda np_like: np_like.multiply, False, (10, 21)), 69 | ((-3, -2), (-7, -5), lambda np_like: np_like.multiply, True, (10, 21)), 70 | ((-2, 3), (-7, -5), lambda np_like: np_like.multiply, False, (-21, 14)), 71 | ((-2, 3), (-7, -5), lambda np_like: np_like.multiply, True, (-21, 14)), 72 | ((2, 3), (-5, 7), lambda np_like: np_like.multiply, False, (-15, 21)), 73 | ((2, 3), (-5, 7), lambda np_like: np_like.multiply, True, (-15, 21)), 74 | ((-3, -2), (-5, 7), lambda np_like: np_like.multiply, False, (-21, 15)), 75 | ((-3, -2), (-5, 7), lambda np_like: np_like.multiply, True, (-21, 15)), 76 | # This is the one case where setting assume_product=False yields a looser 77 | # interval. 78 | ((-2, 3), (-5, 7), lambda np_like: np_like.multiply, False, (-29, 31)), 79 | ((-2, 3), (-5, 7), lambda np_like: np_like.multiply, True, (-15, 21)), 80 | # Test other bilinear operations. 81 | ( 82 | (np.array([2, 3]), np.array([5, 7])), 83 | (np.array([11, 13]), np.array([17, 19])), 84 | lambda np_like: np_like.dot, 85 | False, 86 | (61, 218), 87 | ), 88 | ( 89 | (np.diag([2, 3]), np.diag([5, 7])), 90 | (np.array([11, 13]), np.array([17, 19])), 91 | lambda np_like: np_like.matmul, 92 | False, 93 | (np.array([22, 39]), np.array([85, 133])), 94 | ), 95 | ( 96 | np.ones((1, 3)), 97 | (np.zeros((3,)), np.ones((3,))), 98 | lambda np_like: functools.partial(np_like.tensordot, axes=1), 99 | False, 100 | (np.array([0.]), np.array([3.])) 101 | ), 102 | ( 103 | (np.zeros((1, 3)), np.ones((1, 3))), 104 | np.ones((3,)), 105 | lambda np_like: functools.partial(np_like.tensordot, axes=1), 106 | False, 107 | (np.array([0.]), np.array([3.])) 108 | ), 109 | ( 110 | (np.zeros((1, 3)), np.ones((1, 3))), 111 | (np.zeros((3,)), np.ones((3,))), 112 | lambda np_like: functools.partial(np_like.tensordot, axes=1), 113 | False, 114 | (np.array([0.]), np.array([3.])) 115 | ), 116 | ) 117 | def test_arbitrary_bilinear(self, a, b, get_bilinear, assume_product, 118 | expected): 119 | for np_like in self.backends: 120 | bilinear = get_bilinear(np_like) 121 | arithmetic = interval_arithmetic.IntervalArithmetic(np_like) 122 | actual = arithmetic.arbitrary_bilinear(a, b, bilinear, assume_product) 123 | self.assert_interval_equal(expected, actual) 124 | 125 | @parameterized.parameters( 126 | ([1, 2, 3], 1, [1, 2, 3]), 127 | ([1, 2, 3], 2, [[1, 0, 0], [0, 2, 0], [0, 0, 3]]), 128 | ) 129 | def test_generalized_diag_interval(self, a, n, expected): 130 | for np_like in self.backends: 131 | arithmetic = interval_arithmetic.IntervalArithmetic(np_like) 132 | actual = arithmetic._generalized_diag_interval(a, n) 133 | self.assert_interval_equal(expected, actual) 134 | 135 | @parameterized.parameters( 136 | (2, 5, 10), 137 | ((2, 3), 5, (10, 15)), 138 | (2, (5, 7), (10, 14)), 139 | (np.array([2, 3]), np.array([5, 7]), np.array([10, 21])), 140 | (np.array([2, 3]), (5, 7), (np.array([10, 15]), np.array([14, 21]))), 141 | ((5, 7), 
np.array([2, 3]), (np.array([10, 15]), np.array([14, 21]))), 142 | # Test all 9 valid combinations of signs for scalar intervals. 143 | ((2, 3), (5, 7), (10, 21)), 144 | ((-3, -2), (5, 7), (-21, -10)), 145 | ((-2, 3), (5, 7), (-14, 21)), 146 | ((2, 3), (-7, -5), (-21, -10)), 147 | ((-3, -2), (-7, -5), (10, 21)), 148 | ((-2, 3), (-7, -5), (-21, 14)), 149 | ((2, 3), (-5, 7), (-15, 21)), 150 | ((-3, -2), (-5, 7), (-21, 15)), 151 | ((-2, 3), (-5, 7), (-15, 21)), 152 | # Same 9 combinations, but as a single test. 153 | ( 154 | ( 155 | np.array([2, -3, -2, 2, -3, -2, 2, -3, -2]), 156 | np.array([3, -2, 3, 3, -2, 3, 3, -2, 3]) 157 | ), 158 | ( 159 | np.array([5, 5, 5, -7, -7, -7, -5, -5, -5]), 160 | np.array([7, 7, 7, -5, -5, -5, 7, 7, 7]) 161 | ), 162 | ( 163 | np.array([10, -21, -14, -21, 10, -21, -15, -21, -15]), 164 | np.array([21, -10, 21, -10, 21, 14, 21, 15, 21]) 165 | ), 166 | ), 167 | ) 168 | def test_multiply(self, a, b, expected): 169 | for np_like in self.backends: 170 | actual = interval_arithmetic.IntervalArithmetic(np_like).multiply(a, b) 171 | self.assert_interval_equal(expected, actual) 172 | 173 | @parameterized.parameters( 174 | ((1, 2), (-2, -1)), 175 | (np.array([-1, 2]), np.array([1, -2])), 176 | ( 177 | (np.array([-1, 2]), np.array([10, 20])), 178 | (np.array([-10, -20]), np.array([1, -2])), 179 | ), 180 | ) 181 | def test_negative(self, a, expected): 182 | for np_like in self.backends: 183 | actual = interval_arithmetic.IntervalArithmetic(np_like).negative(a) 184 | self.assert_interval_equal(expected, actual) 185 | 186 | @parameterized.parameters( 187 | ((-1., 1.), 2, 0, (0., 1.)), 188 | ( 189 | (np.array([-1, -2]), np.array([3, 4])), 190 | 2, 191 | 0, 192 | (np.array([[0, -6], [-6, 0]]), np.array([[9, 12], [12, 16]])) 193 | ), 194 | ([2., 3.], 2, 1, np.array([4., 9.])), 195 | ) 196 | def test_outer_power(self, a, exponent, batch_dims, expected): 197 | for np_like in self.backends: 198 | actual = interval_arithmetic.IntervalArithmetic(np_like).outer_power( 199 | a, exponent, batch_dims) 200 | self.assert_interval_equal(expected, actual) 201 | 202 | @parameterized.parameters( 203 | # ndarray times ndarray 204 | (np.array([2, 3]), np.array([5, 7]), 0, np.array([[10, 14], [15, 21]])), 205 | (np.array([2, 3]), np.array([5, 7]), 1, np.array([10, 21])), 206 | # ndarray times interval 207 | ( 208 | np.array([1, 2]), 209 | (np.array([3, 4]), np.array([5, 6])), 210 | 0, 211 | (np.array([[3, 4], [6, 8]]), np.array([[5, 6], [10, 12]])), 212 | ), 213 | # interval times ndarray 214 | ( 215 | (np.array([3, 4]), np.array([5, 6])), 216 | np.array([1, 2]), 217 | 0, 218 | (np.array([[3, 6], [4, 8]]), np.array([[5, 10], [6, 12]])), 219 | ), 220 | # interval times interval 221 | ( 222 | (np.array([-1, 2]), np.array([1, 2])), 223 | (np.array([3, 4]), np.array([5, 6])), 224 | 0, 225 | (np.array([[-5, -6], [6, 8]]), np.array([[5, 6], [10, 12]])), 226 | ), 227 | ) 228 | def test_outer_product(self, a, b, batch_dims, expected): 229 | for np_like in self.backends: 230 | actual = interval_arithmetic.IntervalArithmetic(np_like).outer_product( 231 | a, b, batch_dims) 232 | self.assert_interval_equal(expected, actual) 233 | 234 | @parameterized.parameters( 235 | ((-2., 3.), 2, (0., 9.)), 236 | ((2., 3.), 2, (4., 9.)), 237 | ((-3., -2.), 2, (4., 9.)), 238 | ((-2., 3.), 3, (-8., 27.)), 239 | ((2., 3.), 3, (8., 27.)), 240 | ((-3., -2.), 3, (-27., -8.)), 241 | ((4., 9.), .5, (2., 3.)), 242 | ( 243 | (np.array([-2., 2., -3.]), np.array([3., 3., -2.])), 244 | 2, 245 | (np.array([0., 4., 4.]), np.array([9., 
9., 9.])) 246 | ), 247 | ) 248 | def test_power(self, a, exponent, expected): 249 | for np_like in self.backends: 250 | actual = interval_arithmetic.IntervalArithmetic(np_like).power(a, 251 | exponent) 252 | self.assert_interval_equal(expected, actual) 253 | 254 | @parameterized.parameters( 255 | ((10, 20), (1, 2), (8, 19)), 256 | (10, (1, 2), (8, 9)), 257 | ((10, 20), 1, (9, 19)), 258 | (10, 1, 9), 259 | (10, np.array([1, 2]), np.array([9, 8])), 260 | (np.array([10, 20]), 1, np.array([9, 19])), 261 | ( 262 | (np.array([10, 20]), np.array([30, 40])), 263 | np.array([1, 2]), 264 | (np.array([9, 18]), np.array([29, 38])) 265 | ), 266 | ( 267 | np.array([10, 20]), 268 | (np.array([1, 2]), np.array([3, 4])), 269 | (np.array([7, 16]), np.array([9, 18])) 270 | ), 271 | ( 272 | (np.array([10, 20]), np.array([30, 40])), 273 | (np.array([1, 2]), np.array([3, 4])), 274 | (np.array([7, 16]), np.array([29, 38])) 275 | ), 276 | ) 277 | def test_subtract(self, a, b, expected): 278 | for np_like in self.backends: 279 | actual = interval_arithmetic.IntervalArithmetic(np_like).subtract(a, b) 280 | self.assert_interval_equal(expected, actual) 281 | 282 | @parameterized.parameters( 283 | (2, 3, 0, 6), 284 | (np.array([2, 3]), np.array([5, 7]), 0, np.array([[10, 14], [15, 21]])), 285 | (np.array([2, 3]), np.array([5, 7]), 1, 31), 286 | ((-.5, .5), (-.5, .5), 0, (-.25, .25)), 287 | ) 288 | def test_tensordot(self, a, b, axes, expected): 289 | for np_like in self.backends: 290 | actual = interval_arithmetic.IntervalArithmetic(np_like).tensordot( 291 | a, b, axes) 292 | self.assert_interval_equal(expected, actual) 293 | 294 | 295 | if __name__ == '__main__': 296 | absltest.main() 297 | -------------------------------------------------------------------------------- /autobound/enclosure_arithmetic_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | from autobound import enclosure_arithmetic 18 | from autobound import test_utils 19 | import numpy as np 20 | 21 | 22 | class TestCase(parameterized.TestCase, test_utils.TestCase): 23 | 24 | @parameterized.parameters( 25 | (100, (0, .25), (1,), (10,), (11,)), 26 | (100, (0, .25), (1, 2), (10,), (11, 2)), 27 | (100, (0, .25), (1, 2), (10, 20), (11, 22)), 28 | (100, (0, .25), (1,), (10, 20), (11, 20)), 29 | (100, (0, .25), (1, (2, 3),), (10, (20, 30),), (11, (22, 33),)), 30 | # Test a case where we truncate the sum. 
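      # With max_degree == 0, the degree-1 coefficient 2+20 == 22 is enclosed
      # over the trust region [0, .25] and folded into a single interval
      # coefficient: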
31 | # 1+10+(2+20)*.25 == 16.5 32 | (0, (0, .25), (1, 2), (10., 20.), ((11., 16.5),)), 33 | ( 34 | 100, 35 | (np.array([0, 0,]), np.array([1, 1])), 36 | (np.array([1, 2]), 3*np.eye(2)), 37 | (np.array([10, 20]), 30*np.eye(2)), 38 | (np.array([11, 22]), 33*np.eye(2)), 39 | ), 40 | ) 41 | def test_add(self, max_degree, trust_region, a, b, expected): 42 | for np_like in self.backends: 43 | arithmetic = enclosure_arithmetic.TaylorEnclosureArithmetic( 44 | max_degree, trust_region, np_like) 45 | actual = arithmetic.add(a, b) 46 | self.assert_enclosure_equal(expected, actual) 47 | 48 | @parameterized.parameters( 49 | ( 50 | 100, 51 | (0, .5), 52 | np, 53 | ([2],), 54 | ([3],), 55 | lambda u, v, p, q: np.einsum('i...,i...->i...', u, v), 56 | ([6],) 57 | ), 58 | ) 59 | def test_arbitrary_bilinear(self, max_degree, trust_region, np_like, a, b, 60 | pairwise_batched_bilinear, expected): 61 | arithmetic = enclosure_arithmetic.TaylorEnclosureArithmetic( 62 | max_degree, trust_region, np_like) 63 | actual = arithmetic.arbitrary_bilinear(a, b, pairwise_batched_bilinear) 64 | self.assert_enclosure_equal(expected, actual) 65 | 66 | @parameterized.parameters( 67 | ( 68 | 2, 69 | (0, .5), 70 | np, 71 | (2, 3), 72 | (5, 7), 73 | # 2+3(7z) = 2+21z 74 | (2, 21) 75 | ), 76 | ( 77 | 2, 78 | (0, .5), 79 | np, 80 | (3, 3, (2, 4)), 81 | (1, 1), 82 | # a(x) = 1 + 1(x-x0). 83 | # a(x)-a(x0) = x 84 | (3, 3, (2, 4)), 85 | ), 86 | ( 87 | 2, 88 | (-10, 10), 89 | np, 90 | (1, (-1, 1)), 91 | (-1, 1), 92 | (1, (-1, 1)), 93 | ), 94 | ( 95 | 2, 96 | (np.array([-10, -10, -10]), np.array([10, 10, 10])), 97 | np, 98 | (np.array([1, 0, 2]), (np.array([-1, -1, -1]), np.array([1, 1, 1]))), 99 | (np.array([-1, 0, 2]), np.eye(3)), 100 | (np.array([1, 0, 2]), (-np.eye(3), np.eye(3))), 101 | ), 102 | ( 103 | 2, 104 | (np.zeros((1,)), np.ones((1,))), 105 | np, 106 | (np.zeros((2, 3)), np.zeros((2, 3))), 107 | (np.zeros((2, 3)), np.zeros((2, 3, 1))), 108 | (np.zeros((2, 3)), np.zeros((2, 3, 1))), 109 | ), 110 | ) 111 | def test_compose_enclosures( 112 | self, max_degree, trust_region, np_like, scalar_enclosure, arg_enclosure, 113 | expected): 114 | arithmetic = enclosure_arithmetic.TaylorEnclosureArithmetic( 115 | max_degree, trust_region, np_like) 116 | actual = arithmetic.compose_enclosures(scalar_enclosure, arg_enclosure) 117 | self.assert_enclosure_equal(expected, actual) 118 | 119 | @parameterized.parameters( 120 | ((2,), (0, .5), 0, (2,)), 121 | ((2, 3), (0, .5), 1, (2, 3)), 122 | ((2, 3), (0, .5), 0, ((2., 3.5),)), 123 | ((11, 22), (0, .25), 0, ((11., 16.5),)), # 11+22/4 == 16.5 124 | # A multivariate linear enclosure, truncated to rank 0. 
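      # Upper endpoints: 2 + 20*.25 == 7 and 3 + 30*.5 == 18; the lower
      # endpoints are the constant terms 2 and 3 (at z == 0).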
125 | ( 126 | (np.array([2., 3.]), np.diag([20, 30])), 127 | (np.array([0, 0]), np.array([.25, .5])), 128 | 0, 129 | ((np.array([2., 3.]), np.array([7., 18.])),) 130 | ), 131 | ) 132 | def test_enclose_enclosure(self, enclosure, trust_region, max_degree, 133 | expected): 134 | for np_like in self.backends: 135 | actual = enclosure_arithmetic.enclose_enclosure( 136 | enclosure, trust_region, max_degree, np_like) 137 | self.assert_enclosure_equal(expected, actual) 138 | 139 | @parameterized.parameters( 140 | (np.array(0.), 3, None, np.zeros((1, 1, 1))), 141 | (np.zeros((3, 5)), 2, None, np.zeros((3, 5, 1, 1))), 142 | (np.zeros((3, 5)), 2, 1, np.zeros((3, 1, 1, 5))), 143 | ) 144 | def test_expand_multiple_dims(self, a, n, axis, expected): 145 | actual = enclosure_arithmetic.expand_multiple_dims(a, n, axis) 146 | self.assert_allclose_strict(expected, actual) 147 | 148 | @parameterized.parameters( 149 | (100, (0, .5), (2,), (3,), (6,)), 150 | (0, (0, .5), (2,), (3,), (6,)), 151 | (100, (0, .5), (2, 3), (5,), (10, 15)), 152 | (0, (0, .5), (2, 3), (5,), ((10., 17.5),)), 153 | (0, (0, .25), (2,), (np.ones((3,)),), (2*np.ones((3,)),)), 154 | (2, (-10, 10), (0,), ((-1, 1),), ((0, 0),)), 155 | # Multiplication of degree-0 enclosures should work like np.multiply, 156 | # and should support broadcasting. 157 | (0, (0, .5), (np.ones((2, 3)),), (5,), (5*np.ones((2, 3)),)), 158 | (0, (0, .5), (np.ones((2, 3)),), (np.array([20., 30., 50.]),), 159 | (np.array([[20., 30., 50.], [20., 30., 50.]]),)), 160 | (100, (0, .5), (2, 3), (0, 1), (0, 2, 3)), 161 | (100, (0, .5), (2, [3]), (0, 1), (0, np.array([2]), np.array([3]))), 162 | ( 163 | 100, 164 | (np.zeros((2,)), np.ones((2,))), 165 | (np.zeros((2,)), np.eye(2)), 166 | (np.zeros((2,)), np.eye(2)), 167 | (np.zeros((2,)), np.zeros((2, 2)), 168 | np.array([[[1., 0.], [0., 0.]], [[0., 0.], [0., 1.]]])) 169 | ), 170 | (100, (-13., 17.), ((-.5, .5),), ((-.5, .5),), ((-.25, .25),)), 171 | ) 172 | def test_multiply(self, max_degree, trust_region, a, b, expected): 173 | for np_like in self.backends: 174 | arithmetic = enclosure_arithmetic.TaylorEnclosureArithmetic( 175 | max_degree, trust_region, np_like) 176 | actual = arithmetic.multiply(a, b) 177 | self.assert_enclosure_equal(expected, actual) 178 | 179 | @parameterized.parameters( 180 | (100, (np.zeros((2,)), np.ones((2,))), (np.array([-1, 2]),), 181 | (np.array([1, -2]),)), 182 | (100, (0, .25), (1, 2), (-1, -2)), 183 | (0, (0, .25), (1, 2), ((-1.5, -1.),)), 184 | ) 185 | def test_negative(self, max_degree, trust_region, a, expected): 186 | for np_like in self.backends: 187 | arithmetic = enclosure_arithmetic.TaylorEnclosureArithmetic( 188 | max_degree, trust_region, np_like) 189 | actual = arithmetic.negative(a) 190 | self.assert_enclosure_equal(expected, actual) 191 | 192 | @parameterized.parameters( 193 | (2, 3, 0, 0, 6), 194 | (np.zeros((2, 2)), np.zeros((2, 2)), 1, 1, np.zeros((2, 2, 2))), 195 | (np.zeros((2, 3, 5)), np.zeros((1, 1, 7, 11)), 1, 2, 196 | np.zeros((2, 3, 5, 7, 11))), 197 | ) 198 | def test_pairwise_batched_multiply(self, u, v, p, q, expected): 199 | for np_like in self.backends: 200 | actual = enclosure_arithmetic._pairwise_batched_multiply(u, v, p, q, 201 | np_like) 202 | self.assert_allclose_strict(expected, actual) 203 | 204 | @parameterized.parameters( 205 | (100, (0, 1), (2,), 0, (1,)), 206 | (100, (0, 1), (2,), 3, (8,)), 207 | ( 208 | 2, 209 | (np.zeros((3,)), np.array([1, 2, 3])), 210 | (np.zeros((3,)), np.eye(3)), 211 | 4, 212 | (np.zeros((3,)), np.zeros((3, 3)), 213 | ( 214 | 
np.zeros((3, 3, 3)), 215 | np.array([[[1., 0., 0.], [0., 0., 0.], [0., 0., 0.]], 216 | [[0., 0., 0.], [0., 4., 0.], [0., 0., 0.]], 217 | [[0., 0., 0.], [0., 0., 0.], [0., 0., 9.]]]) 218 | ) 219 | ) 220 | ), 221 | ( 222 | 100, 223 | (np.zeros((3,)), np.ones((3,))), 224 | (.5*np.ones((3,)),), 225 | 0, 226 | (np.ones((3,)),) 227 | ), 228 | ) 229 | def test_power(self, max_degree, trust_region, enclosure, p, expected): 230 | for np_like in self.backends: 231 | arithmetic = enclosure_arithmetic.TaylorEnclosureArithmetic( 232 | max_degree, trust_region, np_like) 233 | actual = arithmetic.power(enclosure, p) 234 | self.assert_enclosure_equal(expected, actual) 235 | 236 | @parameterized.parameters( 237 | (100, (0, .5), (1, 2, 3), (10, 20), (-9, -18, 3)), 238 | (1, (0, .5), (1, 2, 3), (10, 20), (-9, (-18., -16.5))), 239 | ) 240 | def test_subtract(self, max_degree, trust_region, a, b, expected): 241 | for np_like in self.backends: 242 | arithmetic = enclosure_arithmetic.TaylorEnclosureArithmetic( 243 | max_degree, trust_region, np_like) 244 | actual = arithmetic.subtract(a, b) 245 | self.assert_enclosure_equal(expected, actual) 246 | 247 | @parameterized.named_parameters( 248 | ( 249 | 'scalar_ndarray', 250 | 2., 251 | 3., 252 | 5, 253 | 7, 254 | 0, 255 | 6. 256 | ), 257 | ( 258 | 'scalar_interval', 259 | (-2., 3.), 260 | (-5., 7.), 261 | 11, 262 | 13, 263 | 0, 264 | (-15., 21.) 265 | ), 266 | ( 267 | 'vector_ndarray', 268 | # [2, 3] * <[[5, 0, 0], [0, 0, 7]], z> 269 | # == <[[10, 0, 0], [0, 0, 21]], z > 270 | [2., 3.], 271 | [[5., 0., 0.], [0., 0., 7.]], 272 | 0, 273 | 1, 274 | 1, 275 | np.array([[10., 0., 0.], [0., 0., 21.]]), 276 | ) 277 | ) 278 | def test_elementwise_term_product_coefficient( 279 | self, c0, c1, i, j, x_ndim, expected): 280 | for np_like in self.backends: 281 | actual = enclosure_arithmetic._elementwise_term_product_coefficient( 282 | c0, c1, i, j, x_ndim, np_like) 283 | self.assert_interval_equal(expected, actual) 284 | 285 | @parameterized.named_parameters( 286 | ( 287 | 'scalar_ndarray', 288 | 2., 289 | 3, 290 | 5, 291 | 0, 292 | 32. 293 | ), 294 | ( 295 | 'scalar_interval', 296 | (-.5, .5), 297 | 3, 298 | 2, 299 | 0, 300 | (0., .25) 301 | ), 302 | ( 303 | 'vector_interval', 304 | ([-.5, 0.], [.5, 0.]), 305 | 1, 306 | 2, 307 | 1, 308 | (np.array([[0., 0.], [0., 0.]]), np.array([[0.25, 0.], [0., 0.]])) 309 | ), 310 | ( 311 | 'vector_ndarray_constant_term', 312 | [2., 3.], 313 | 0, 314 | 2, 315 | 5, 316 | np.array([4., 9.]), 317 | ), 318 | ) 319 | def test_elementwise_term_power_coefficient( 320 | self, c, i, exponent, x_ndim, expected): 321 | for np_like in self.backends: 322 | actual = enclosure_arithmetic._elementwise_term_power_coefficient( 323 | c, i, exponent, x_ndim, np_like) 324 | self.assert_interval_equal(expected, actual) 325 | 326 | @parameterized.parameters( 327 | (2, 3, 6), 328 | ( 329 | 5 * np.ones((2,)), 330 | 7 * np.ones((2, 3)), 331 | 35 * np.ones((2, 3)) 332 | ), 333 | ) 334 | def test_left_broadcasting_multiply(self, a, b, expected): 335 | for np_like in self.backends: 336 | actual = enclosure_arithmetic._left_broadcasting_multiply(a, b, np_like) 337 | self.assert_allclose_strict(expected, actual) 338 | 339 | 340 | if __name__ == '__main__': 341 | absltest.main() 342 | -------------------------------------------------------------------------------- /autobound/primitive_enclosures.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A library of Taylor enclosures for various primitive functions. 16 | 17 | For now this only supports elementwise functions, but in the future it will 18 | support other multivariate functions. 19 | """ 20 | 21 | import functools 22 | import math 23 | from typing import Callable, Sequence, Tuple 24 | 25 | from autobound import elementwise_functions 26 | # pylint: disable=g-multiple-import 27 | from autobound.types import ( 28 | Interval, NumpyLike, ElementwiseTaylorEnclosure, NDArray, NDArrayLike) 29 | 30 | 31 | def get_elementwise_taylor_enclosure( 32 | function_id: elementwise_functions.FunctionId, 33 | x0: NDArray, 34 | trust_region: Interval, 35 | degree: int, 36 | np_like: NumpyLike) -> ElementwiseTaylorEnclosure: 37 | """Returns ElementwiseTaylorEnclosure for function with given ID. 38 | 39 | Args: 40 | function_id: an `elementwise_functions.FunctionId` 41 | x0: reference point 42 | trust_region: trust region over which enclosure is valid 43 | degree: the degree of the returned `ElementwiseTaylorEnclosure` 44 | np_like: a `NumpyLike` backend 45 | 46 | Returns: 47 | an `ElementwiseTaylorEnclosure` for the elementwise function specified by 48 | `function_id`. 49 | """ 50 | f = elementwise_functions.get_function(function_id, np_like) 51 | deriv_id = function_id.derivative_id(degree) 52 | deriv_data = elementwise_functions.get_function_data(deriv_id) 53 | taylor_coefficients = functools.partial( 54 | elementwise_functions.get_taylor_polynomial_coefficients, 55 | function_id, x0=x0, np_like=np_like) 56 | if (deriv_data.monotonically_increasing or 57 | deriv_data.monotonically_decreasing): 58 | return sharp_enclosure_monotonic_derivative( 59 | x0, degree, trust_region, f, taylor_coefficients(degree), 60 | deriv_data.monotonically_increasing, np_like) 61 | elif degree == 2 and deriv_data.even_symmetric: 62 | return sharp_quadratic_enclosure_even_symmetric_hessian( 63 | x0, trust_region, f, taylor_coefficients(degree), np_like) 64 | else: 65 | # For indices where the derivative is monotonically decreasing or 66 | # monotonically increasing over the trust region, we return the sharp 67 | # enclosure. For other indices, we fall back to using the enclosure 68 | # based on the range of the derivative. 
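    # (For degree k, the fallback's final interval is the range of the kth
    # derivative divided by k!; see bounded_derivative_enclosure below.)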
69 | coeffs = taylor_coefficients(degree) 70 | enclosure_if_decreasing, enclosure_if_increasing = [ 71 | sharp_enclosure_monotonic_derivative( 72 | x0, degree, trust_region, f, coeffs, increasing, np_like) 73 | for increasing in [False, True] 74 | ] 75 | decreasing, increasing = deriv_data.monotone_over(trust_region, np_like) 76 | deriv_range = elementwise_functions.get_range(deriv_id, trust_region, 77 | np_like) 78 | fallback = bounded_derivative_enclosure(degree, coeffs[:-1], deriv_range) 79 | def endpoint(i: int): 80 | return np_like.where( 81 | decreasing, 82 | enclosure_if_decreasing[-1][i], 83 | np_like.where( 84 | increasing, 85 | enclosure_if_increasing[-1][i], 86 | fallback[-1][i] 87 | ) 88 | ) 89 | final_interval = (endpoint(0), endpoint(1)) 90 | return ElementwiseTaylorEnclosure( 91 | tuple(coeffs[:degree]) + (final_interval,)) 92 | 93 | 94 | abs_enclosure = functools.partial(get_elementwise_taylor_enclosure, 95 | elementwise_functions.ABS) 96 | exp_enclosure = functools.partial(get_elementwise_taylor_enclosure, 97 | elementwise_functions.EXP) 98 | log_enclosure = functools.partial(get_elementwise_taylor_enclosure, 99 | elementwise_functions.LOG) 100 | sigmoid_enclosure = functools.partial(get_elementwise_taylor_enclosure, 101 | elementwise_functions.SIGMOID) 102 | softplus_enclosure = functools.partial(get_elementwise_taylor_enclosure, 103 | elementwise_functions.SOFTPLUS) 104 | swish_enclosure = functools.partial(get_elementwise_taylor_enclosure, 105 | elementwise_functions.SWISH) 106 | 107 | 108 | # TODO(mstreeter): we could implement pow_enclosure in terms of 109 | # get_elementwise_taylor_enclosure if we allowed FunctionIds to have parameters 110 | # (in this case, the exponent). 111 | def pow_enclosure(exponent: float, 112 | x0: NDArray, 113 | trust_region: Interval, 114 | degree: int, 115 | np_like: NumpyLike) -> ElementwiseTaylorEnclosure: 116 | """Returns an ElementwiseTaylorEnclosure for x**exponent in terms of x-x0.""" 117 | # The kth derivative of x**p is p * (p-1) * ... * (p-k+1) * x0**(p-k) 118 | taylor_coefficients_at_x0 = [] 119 | c = 1. 120 | i_factorial = 1. 121 | for i in range(degree + 1): 122 | if i > 0: 123 | i_factorial *= i 124 | # Note: the next line can sometimes generate bogus RuntimeWarnings when 125 | # using Numpy. This seems to be a bug in Numpy, as even doing 126 | # np.array(2.)**-1 generates the same RuntimeWarning. 127 | taylor_coefficients_at_x0.append(c * x0**(exponent - i) / i_factorial) 128 | if i < degree: 129 | c *= exponent - i 130 | 131 | # Compute sharp enclosures for two cases: x-x0 > 0 (enc_pos below), and 132 | # x-x0 < 0 (enc_neg below). 133 | # 134 | # The kth derivative of x**p at x is c*x**(p-k), where c 135 | # is negative if k is odd, and positive if k is even. 136 | # For x > 0, c*x**(p-k) is decreasing if c is positive, and 137 | # increasing otherwise. 138 | # For x < 0 and even p-k, the situation is the same. 139 | # For x < 0 and odd p-k, the situation is reversed: c*x**(p-k) is increasing 140 | # if c is positive and decreasing otherwise.
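  #
  # For example, with exponent == 0.5, x0 == 1., degree == 1, and
  # trust_region == (0.25, 4.), the remainder ratios at the two endpoints are
  # 2/3 and 1/3, so the returned enclosure is (1., (1/3, 2/3)): sqrt(x) lies
  # in 1 + [1/3, 2/3]*(x - 1) for all x in [0.25, 4].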
141 | # pylint: disable=g-complex-comprehension 142 | enc_decreasing, enc_increasing = [ 143 | sharp_enclosure_monotonic_derivative( 144 | x0, degree, trust_region, lambda x: x**exponent, 145 | taylor_coefficients_at_x0, increasing, np_like 146 | ) 147 | for increasing in [False, True] 148 | ] 149 | enc_pos = enc_decreasing if c > 0 else enc_increasing 150 | if (exponent - degree) % 2 == 0: 151 | enc_neg = enc_pos 152 | else: 153 | enc_neg = enc_increasing if c > 0 else enc_decreasing 154 | 155 | def interval_endpoint(i): 156 | """Returns left (i == 0) or right (i == 1) endpoint of interval.""" 157 | a, b = trust_region 158 | endpoint_if_positive = enc_pos[-1][i] 159 | if int(exponent) != exponent: 160 | # If exponent is not an integer, then z**exponent is undefined for z < 0. 161 | # We return the interval (-inf, inf) in this case. 162 | endpoint_if_negative = -np_like.inf if i == 0 else np_like.inf 163 | endpoint_if_possibly_zero = endpoint_if_negative 164 | elif exponent < 0: 165 | endpoint_if_negative = enc_neg[-1][i] 166 | endpoint_if_possibly_zero = -np_like.inf if i == 0 else np_like.inf 167 | else: 168 | endpoint_if_negative = enc_neg[-1][i] 169 | endpoint_if_possibly_zero = functools.reduce( 170 | np_like.minimum if i == 0 else np_like.maximum, 171 | [endpoint_if_positive, endpoint_if_negative] 172 | ) 173 | return np_like.where( 174 | a >= 0, 175 | endpoint_if_positive, 176 | np_like.where( 177 | b <= 0, 178 | endpoint_if_negative, 179 | endpoint_if_possibly_zero 180 | ) 181 | ) 182 | 183 | interval_coefficient = tuple(interval_endpoint(i) for i in [0, 1]) 184 | return ElementwiseTaylorEnclosure( 185 | enc_decreasing[:-1] + (interval_coefficient,)) 186 | 187 | 188 | def bounded_derivative_enclosure( 189 | degree: int, 190 | taylor_coefficients_at_x0: Sequence[NDArray], 191 | derivative_bound: Tuple[NDArray, NDArray] 192 | ) -> ElementwiseTaylorEnclosure: 193 | if len(taylor_coefficients_at_x0) != degree: 194 | raise ValueError() 195 | degree_factorial = math.factorial(degree) 196 | final_interval = (derivative_bound[0] / degree_factorial, 197 | derivative_bound[1] / degree_factorial) 198 | return ElementwiseTaylorEnclosure( 199 | tuple(taylor_coefficients_at_x0[:degree]) + (final_interval,) 200 | ) 201 | 202 | 203 | def sharp_enclosure_monotonic_derivative( 204 | x0: NDArray, 205 | degree: int, 206 | trust_region: Interval, 207 | sigma: Callable[[NDArray], NDArray], 208 | taylor_coefficients_at_x0: Sequence[NDArray], 209 | increasing: bool, 210 | np_like: NumpyLike 211 | ) -> ElementwiseTaylorEnclosure: 212 | """Returns sharp degree-k enclosure assuming monotone k-th derivative. 213 | 214 | Args: 215 | x0: the center point for the Taylor enclosure 216 | degree: the degree of the enclosure to return 217 | trust_region: the trust region over which to compute an enclosure 218 | sigma: the function for which to compute a sharp polynomial enclosure 219 | taylor_coefficients_at_x0: the first (degree+1) coefficients of the 220 | Taylor series expansion of sigma at x0.
221 | increasing: whether the (degree)th derivative of sigma is increasing 222 | or decreasing 223 | np_like: a NumpyLike backend 224 | 225 | Returns: 226 | a sharp ElementwiseTaylorEnclosure for sigma 227 | """ 228 | if degree < 0: 229 | raise ValueError(degree) 230 | ratio = functools.partial(taylor_remainder_ratio, 231 | x0, degree, sigma, 232 | taylor_coefficients_at_x0, 233 | np_like=np_like) 234 | a, b = trust_region 235 | if increasing: 236 | final_interval = (ratio(a), ratio(b)) 237 | else: 238 | final_interval = (ratio(b), ratio(a)) 239 | return ElementwiseTaylorEnclosure( 240 | tuple(taylor_coefficients_at_x0[:degree]) + (final_interval,) 241 | ) 242 | 243 | 244 | def sharp_quadratic_enclosure_even_symmetric_hessian( 245 | x0: NDArray, 246 | trust_region: Interval, 247 | sigma: Callable[[NDArray], NDArray], 248 | taylor_coefficients_at_x0: Sequence[NDArray], 249 | np_like: NumpyLike 250 | ) -> ElementwiseTaylorEnclosure: 251 | """Returns sharp quadratic enclosure for function with even-symmetric Hessian. 252 | 253 | It's assumed that the Hessian is decreasing at z >= 0. 254 | 255 | Args: 256 | x0: the center point for the Taylor enclosure 257 | trust_region: the trust region over which to compute an enclosure 258 | sigma: an elementwise function for which to compute a Taylor enclosure 259 | taylor_coefficients_at_x0: the first two coefficients of the 260 | Taylor series expansion of sigma at x0. 261 | np_like: a Numpy-like back end. 262 | """ 263 | ratio = functools.partial(taylor_remainder_ratio, 264 | x0, 2, sigma, 265 | taylor_coefficients_at_x0, 266 | np_like=np_like) 267 | 268 | a, b = trust_region 269 | max_ratio = ratio(np_like.clip(-x0, a, b)) 270 | min_ratio = np_like.minimum(ratio(a), ratio(b)) 271 | final_interval = (min_ratio, max_ratio) 272 | return ElementwiseTaylorEnclosure( 273 | tuple(taylor_coefficients_at_x0[:2]) + (final_interval,) 274 | ) 275 | 276 | 277 | def taylor_remainder_ratio( 278 | x0: NDArray, 279 | degree: int, 280 | sigma: Callable[[NDArray], NDArray], 281 | taylor_coefficients_at_x0: Sequence[NDArray], 282 | x: NDArray, 283 | np_like: NumpyLike): 284 | """Returns R_{degree - 1}(x; sigma, x0) / (x - x0)**degree.""" 285 | if len(taylor_coefficients_at_x0) != degree + 1: 286 | raise ValueError(degree, taylor_coefficients_at_x0) 287 | # Let r_k denote the degree k Taylor series remainder. 288 | # 289 | # Letting k = degree, we want to return r_{k-1} / (x-x0)**k, but in a way that 290 | # is numerically stable when x-x0 is small (and that is well-defined when 291 | # x=x0). 292 | # 293 | # We do so using: 294 | # r_{k-1} / (x-x0)**k = (r_k + c_k*(x-x0)**k) / (x-x0)**k 295 | # = c_k + r_k / (x-x0)**k. 296 | r_k = sigma(x) - sum( 297 | c * (x - x0)**i for i, c in enumerate(taylor_coefficients_at_x0)) 298 | denom = (x-x0)**degree 299 | return ( 300 | taylor_coefficients_at_x0[degree] + 301 | # Return r_k * 1 / denom, capping the magnitude of 1 / denom at 1e12. 302 | # TODO(mstreeter): this results in an enclosure that's not strictly valid 303 | # when denom is very small. 304 | r_k * np_like.sign(denom) / (np_like.maximum(1e-12, np_like.abs(denom))) 305 | ) 306 | -------------------------------------------------------------------------------- /autobound/elementwise_functions_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import itertools 16 | import math 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from autobound import elementwise_functions 21 | from autobound import test_utils 22 | import jax 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | 27 | def get_jax_callable(function_id): 28 | jax_funs = { 29 | elementwise_functions.ABS.name: jnp.abs, 30 | elementwise_functions.EXP.name: jnp.exp, 31 | elementwise_functions.LOG.name: jnp.log, 32 | elementwise_functions.SIGMOID.name: jax.nn.sigmoid, 33 | elementwise_functions.SOFTPLUS.name: jax.nn.softplus, 34 | elementwise_functions.SWISH.name: jax.nn.swish, 35 | } 36 | if (function_id.name == elementwise_functions.ABS.name and 37 | function_id.derivative_order == 0): 38 | # Don't use jax.grad for the ABS function, because jax.grad(jnp.abs)(0.) 39 | # == 1, whereas our tests assume that every local minimum should have a 40 | # gradient of 0. 41 | return jnp.sign 42 | if function_id.name in jax_funs: 43 | f = jax_funs[function_id.name] 44 | for _ in range(function_id.derivative_order): 45 | f = jax.grad(f) 46 | return f 47 | else: 48 | raise NotImplementedError(function_id) 49 | 50 | 51 | class TestCase(test_utils.TestCase, parameterized.TestCase): 52 | 53 | @parameterized.parameters( 54 | (elementwise_functions.EXP,), 55 | (elementwise_functions.EXP.derivative_id(17),), 56 | (elementwise_functions.LOG,), 57 | (elementwise_functions.LOG.derivative_id(1),), 58 | (elementwise_functions.LOG.derivative_id(2),), 59 | (elementwise_functions.LOG.derivative_id(3),), 60 | (elementwise_functions.LOG.derivative_id(4),), 61 | (elementwise_functions.SIGMOID,), 62 | (elementwise_functions.SIGMOID.derivative_id(1),), 63 | (elementwise_functions.SIGMOID.derivative_id(2),), 64 | (elementwise_functions.SIGMOID.derivative_id(3),), 65 | (elementwise_functions.SIGMOID.derivative_id(4),), 66 | (elementwise_functions.SOFTPLUS,), 67 | (elementwise_functions.SOFTPLUS.derivative_id(1),), 68 | (elementwise_functions.SOFTPLUS.derivative_id(2),), 69 | (elementwise_functions.SOFTPLUS.derivative_id(3),), 70 | (elementwise_functions.SOFTPLUS.derivative_id(4),), 71 | (elementwise_functions.SOFTPLUS.derivative_id(5),), 72 | (elementwise_functions.SWISH,), 73 | (elementwise_functions.SWISH.derivative_id(1),), 74 | (elementwise_functions.SWISH.derivative_id(2),), 75 | (elementwise_functions.SWISH.derivative_id(3),), 76 | (elementwise_functions.SWISH.derivative_id(4),), 77 | ) 78 | def test_get_function(self, function_id): 79 | actual = elementwise_functions.get_function(function_id, jnp) 80 | expected = get_jax_callable(function_id) 81 | for test_x in [-1000., -1., -.5, 0., .5, 1., 1000.]: 82 | if function_id.name == 'log' and test_x <= 0.: 83 | continue 84 | np.testing.assert_allclose(expected(test_x), actual(test_x)) 85 | 86 | @parameterized.parameters( 87 | (elementwise_functions.EXP, 88 | elementwise_functions.FunctionData((), (), 89 | monotonically_increasing=True)), 90 | (elementwise_functions.EXP.derivative_id(17), 91 | elementwise_functions.FunctionData((), (), 92 | 
monotonically_increasing=True)), 93 | ) 94 | def test_get_function_data(self, function_id, expected): 95 | actual = elementwise_functions.get_function_data(function_id) 96 | if expected is not None: 97 | self.assertEqual(expected, actual) 98 | self.sanity_check_function_data(function_id, actual) 99 | 100 | def test_all_function_data(self): 101 | for function_id, function_data in ( 102 | elementwise_functions._FUNCTION_DATA.items()): 103 | self.sanity_check_function_data(function_id, function_data) 104 | 105 | @parameterized.parameters( 106 | (elementwise_functions.SIGMOID, (-1., 1.), 107 | (test_utils.sigmoid(-1.), test_utils.sigmoid(1.))), 108 | (elementwise_functions.SIGMOID, (-1e6, 1e6), (0., 1.)), 109 | (elementwise_functions.SIGMOID.derivative_id(1), 110 | (-2., 1.), 111 | (test_utils.sigmoid_derivative(1, -2.), test_utils.MAX_SIGMOID_DERIV)), 112 | (elementwise_functions.SIGMOID.derivative_id(1), 113 | (-2., -1.), 114 | (test_utils.sigmoid_derivative(1, -2.), 115 | test_utils.sigmoid_derivative(1, -1.))), 116 | (elementwise_functions.SIGMOID.derivative_id(1), (-1., 3.), 117 | (test_utils.sigmoid_derivative(1, 3.), test_utils.MAX_SIGMOID_DERIV)), 118 | (elementwise_functions.SIGMOID.derivative_id(1), (1., 3.), 119 | (test_utils.sigmoid_derivative(1, 3.), 120 | test_utils.sigmoid_derivative(1, 1.))), 121 | (elementwise_functions.SIGMOID.derivative_id(1), (-1e6, 1e6), (0., .25)), 122 | (elementwise_functions.SIGMOID.derivative_id(2), (-1e6, -4.), 123 | (0., test_utils.sigmoid_derivative(2, -4.))), 124 | (elementwise_functions.SIGMOID.derivative_id(2), (-4., -2.), 125 | (test_utils.sigmoid_derivative(2, -4.), 126 | test_utils.sigmoid_derivative(2, -2.))), 127 | (elementwise_functions.SIGMOID.derivative_id(2), (-2., -.5), 128 | (test_utils.sigmoid_derivative(2, -.5), 129 | test_utils.MAX_SIGMOID_SECOND_DERIV)), 130 | (elementwise_functions.SIGMOID.derivative_id(2), (-.5, .5), 131 | (test_utils.sigmoid_derivative(2, .5), 132 | test_utils.sigmoid_derivative(2, -.5))), 133 | (elementwise_functions.SIGMOID.derivative_id(2), (.5, 2.), 134 | (test_utils.MIN_SIGMOID_SECOND_DERIV, 135 | test_utils.sigmoid_derivative(2, .5))), 136 | (elementwise_functions.SIGMOID.derivative_id(2), (2., 5.), 137 | (test_utils.sigmoid_derivative(2, 2.), 138 | test_utils.sigmoid_derivative(2, 5.))), 139 | (elementwise_functions.SIGMOID.derivative_id(2), (5., 1e6), 140 | (test_utils.sigmoid_derivative(2, 5.), 0.)), 141 | (elementwise_functions.SIGMOID.derivative_id(3), (-1e6, -4.), 142 | (0., test_utils.sigmoid_derivative(3, -4.))), 143 | (elementwise_functions.SIGMOID.derivative_id(3), (-4., -3.), 144 | (test_utils.sigmoid_derivative(3, -4.), 145 | test_utils.sigmoid_derivative(3, -3.))), 146 | (elementwise_functions.SIGMOID.derivative_id(3), (-3., -1.), 147 | (test_utils.sigmoid_derivative(3, -1.), 148 | test_utils.MAX_SIGMOID_THIRD_DERIV)), 149 | (elementwise_functions.SIGMOID.derivative_id(3), (-1., .5), 150 | (test_utils.MIN_SIGMOID_THIRD_DERIV, 151 | test_utils.sigmoid_derivative(3, -1.))), 152 | (elementwise_functions.SIGMOID.derivative_id(3), (.5, 1.), 153 | (test_utils.sigmoid_derivative(3, .5), 154 | test_utils.sigmoid_derivative(3, 1.))), 155 | (elementwise_functions.SIGMOID.derivative_id(3), (1., 3.), 156 | (test_utils.sigmoid_derivative(3, 1.), 157 | test_utils.MAX_SIGMOID_THIRD_DERIV)), 158 | (elementwise_functions.SIGMOID.derivative_id(3), (3., 4.), 159 | (test_utils.sigmoid_derivative(3, 4.), 160 | test_utils.sigmoid_derivative(3, 3.))), 161 | (elementwise_functions.SIGMOID.derivative_id(3), (4., 1e6), 
162 | (0., test_utils.sigmoid_derivative(3, 4.))), 163 | (elementwise_functions.SOFTPLUS.derivative_id(1), (-1., 1.), 164 | (test_utils.sigmoid(-1.), test_utils.sigmoid(1.))), 165 | (elementwise_functions.SWISH, (1., 2.), 166 | (test_utils.swish(1.), test_utils.swish(2.))), 167 | ) 168 | def test_get_range(self, function_id, trust_region, expected): 169 | if callable(expected): 170 | expected = expected() 171 | for np_like in self.backends: 172 | actual = elementwise_functions.get_range(function_id, trust_region, 173 | np_like) 174 | self.assert_interval_equal(expected, actual) 175 | 176 | @parameterized.parameters( 177 | ( 178 | elementwise_functions.EXP, 179 | 0, 180 | 3.14, 181 | (math.exp(3.14),) 182 | ), 183 | ( 184 | elementwise_functions.EXP, 185 | 2, 186 | 3.14, 187 | (math.exp(3.14), math.exp(3.14), math.exp(3.14)/2) 188 | ), 189 | ( 190 | elementwise_functions.LOG, 191 | 0, 192 | 3.14, 193 | (math.log(3.14),) 194 | ), 195 | ( 196 | elementwise_functions.LOG, 197 | 2, 198 | 3.14, 199 | (math.log(3.14), 1/3.14, -.5/3.14**2) 200 | ), 201 | ) 202 | def test_get_taylor_polynomial_coefficients(self, function_id, degree, x0, 203 | expected): 204 | for np_like in self.backends: 205 | actual = elementwise_functions.get_taylor_polynomial_coefficients( 206 | function_id, degree, x0, np_like) 207 | self.assert_enclosure_equal(expected, actual) 208 | 209 | @parameterized.parameters( 210 | # Function with no local minima. 211 | # Because there are no local minima, the function must either be 212 | # monotonically decreasing or monotonically increasing. 213 | ( 214 | elementwise_functions.FunctionData((), (), 215 | monotonically_decreasing=True), 216 | ([-3., -2., -1., 0., 1., 2.], [-2., -1., 0., 1., 2., 3.]), 217 | ( 218 | [True] * 6, 219 | [False] * 6, 220 | ) 221 | ), 222 | ( 223 | elementwise_functions.FunctionData((), (), 224 | monotonically_increasing=True), 225 | ([-3., -2., -1., 0., 1., 2.], [-2., -1., 0., 1., 2., 3.]), 226 | ( 227 | [False] * 6, 228 | [True] * 6, 229 | ) 230 | ), 231 | # Function with local minimum at -1 and local maximum at 1. 232 | ( 233 | elementwise_functions.FunctionData((-1.,), (1.,)), 234 | ([-3., -2., -1., 0., 1., 2., -3.], [-2., -1., 0., 1., 2., 3., 3.]), 235 | ( 236 | [True, True, False, False, True, True, False], 237 | [False, False, True, True, False, False, False], 238 | ) 239 | ), 240 | # Function with local maximum at -1 and local minimum at 1. 241 | ( 242 | elementwise_functions.FunctionData((1.,), (-1.,)), 243 | ([-3., -2., -1., 0., 1., 2., -3.], [-2., -1., 0., 1., 2., 3., 3.]), 244 | ( 245 | [False, False, True, True, False, False, False], 246 | [True, True, False, False, True, True, False], 247 | ) 248 | ), 249 | ) 250 | def test_monotone_over(self, function_data, region, expected): 251 | expected_decreasing, expected_increasing = expected 252 | region = np.asarray(region) 253 | expected_decreasing = np.asarray(expected_decreasing) 254 | expected_increasing = np.asarray(expected_increasing) 255 | for np_like in self.backends: 256 | actual = function_data.monotone_over(region, np_like) 257 | self.assertIsInstance(actual, tuple) 258 | self.assertLen(actual, 2) 259 | actual_decreasing, actual_increasing = actual 260 | np.testing.assert_equal(actual_decreasing, expected_decreasing) 261 | np.testing.assert_equal(actual_increasing, expected_increasing) 262 | 263 | def sanity_check_function_data(self, function_id, function_data): 264 | # Make sure all local minima/maxima have 0 gradient. 
265 | f = get_jax_callable(function_id) 266 | grad = jax.grad(f) 267 | for x in itertools.chain(function_data.local_minima, 268 | function_data.local_maxima): 269 | np.testing.assert_allclose(grad(x), 0., rtol=0, atol=1e-12) 270 | # Make sure hessian is non-negative at local minima, and non-positive at 271 | # local maxima. 272 | hessian = jax.grad(grad) 273 | for x in function_data.local_minima: 274 | self.assertGreaterEqual(hessian(x), 0.) 275 | for x in function_data.local_maxima: 276 | self.assertLessEqual(hessian(x), 0.) 277 | 278 | if not function_data.local_minima and not function_data.local_maxima: 279 | self.assertTrue(function_data.monotonically_decreasing or 280 | function_data.monotonically_increasing) 281 | 282 | if (function_data.monotonically_decreasing or 283 | function_data.monotonically_increasing): 284 | self.assertEqual((), function_data.local_minima) 285 | self.assertEqual((), function_data.local_maxima) 286 | 287 | if function_data.monotonically_increasing: 288 | self.assertLessEqual(f(1.), f(2.)) 289 | if function_data.monotonically_decreasing: 290 | self.assertLessEqual(f(2.), f(1.)) 291 | 292 | 293 | if __name__ == '__main__': 294 | absltest.main() 295 | -------------------------------------------------------------------------------- /autobound/elementwise_functions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library that lists properties of various one-dimensional functions.""" 16 | 17 | import dataclasses 18 | import functools 19 | import math 20 | 21 | from typing import Callable, List, Optional, Sequence, Tuple 22 | # pylint: disable=g-multiple-import 23 | from autobound.types import (Interval, IntervalLike, NDArray, NDArrayLike, 24 | NumpyLike) 25 | 26 | 27 | @dataclasses.dataclass(eq=True, frozen=True) 28 | class FunctionId: 29 | """An identifier for a one-dimensional function.""" 30 | # Functions are identified by a unique name and a derivative order. For 31 | # example, the third derivative of the sigmoid function has name == 'sigmoid' 32 | # and derivative_order == 3. 33 | # 34 | # In some cases, multiple FunctionIDs refer to the same mathematical function. 35 | # For example, the sigmoid function can be represented by the FunctionId 36 | # with name == 'sigmoid' and derivative_order == 0, or by the FunctionId 37 | # with name == 'softplus' and derivative_order == 1. 38 | name: str 39 | derivative_order: int = 0 40 | # x_min and x_max specify the domain of the function. A value of None represents 41 | # +/- infinity.
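  # For example, the natural log (LOG below) has x_min == 0. and x_max == None,
  # since its domain is (0, infinity).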
42 | x_min: Optional[float] = None 43 | x_max: Optional[float] = None 44 | 45 | def derivative_id(self, order: int): 46 | """Returns `FunctionID` for order `order` derivative of this function.""" 47 | return FunctionId(self.name, self.derivative_order + order, 48 | x_min=self.x_min, x_max=self.x_max) 49 | 50 | 51 | @dataclasses.dataclass(eq=True, frozen=True) 52 | class FunctionData: 53 | """An object that lists properties of a one-dimensional function.""" 54 | # The lists of local minima and maxima include all minima/maxima over the 55 | # domain of the function, in ascending order. 56 | local_minima: Tuple[float, ...] 57 | local_maxima: Tuple[float, ...] 58 | monotonically_decreasing: bool = False 59 | monotonically_increasing: bool = False 60 | even_symmetric: bool = False # whether f(x) = f(-x) for all x. 61 | 62 | def monotone_over( 63 | self, 64 | region: IntervalLike, 65 | np_like: NumpyLike) -> Tuple[NDArray, NDArray]: 66 | """Returns ndarrays showing whether the function is monotone over `region`. 67 | 68 | Args: 69 | region: an `Interval` 70 | np_like: a `NumpyLike` back end. 71 | 72 | Returns: 73 | a pair of boolean `NDArray`s `(decreasing, increasing)`, where the 74 | elements of `decreasing` (resp `increasing`) indicate whether the 75 | function is monotonically decreasing (resp increasing) over the interval 76 | specified by the corresponding elements of `region`. 77 | """ 78 | x_min = np_like.asarray(region[0]) 79 | x_max = np_like.asarray(region[1]) 80 | 81 | sorted_extrema = sorted(self.local_minima + self.local_maxima) 82 | decreasing_conditions = [] 83 | increasing_conditions = [] 84 | for i, x in enumerate(sorted_extrema): 85 | is_minimum = x in self.local_minima 86 | if i == 0: 87 | if is_minimum: 88 | decreasing_conditions.append(x_max <= x) 89 | else: 90 | increasing_conditions.append(x_max <= x) 91 | else: 92 | prev_x = sorted_extrema[i-1] 93 | contained = np_like.logical_and(x_min >= prev_x, x_max <= x) 94 | if is_minimum: 95 | decreasing_conditions.append(contained) 96 | else: 97 | increasing_conditions.append(contained) 98 | if i == len(sorted_extrema) - 1: 99 | if is_minimum: 100 | increasing_conditions.append(x_min >= x) 101 | else: 102 | decreasing_conditions.append(x_min >= x) 103 | 104 | decreasing = functools.reduce( 105 | np_like.logical_or, decreasing_conditions, 106 | np_like.full(x_min.shape, self.monotonically_decreasing)) 107 | increasing = functools.reduce( 108 | np_like.logical_or, increasing_conditions, 109 | np_like.full(x_min.shape, self.monotonically_increasing)) 110 | return decreasing, increasing 111 | 112 | 113 | # FunctionIds for various elementwise functions. 114 | ABS = FunctionId('abs') 115 | EXP = FunctionId('exp') 116 | LOG = FunctionId('log', x_min=0.) 
117 | SIGMOID = FunctionId('sigmoid') 118 | SOFTPLUS = FunctionId('softplus') 119 | SWISH = FunctionId('swish') 120 | 121 | 122 | def get_function(function_id: FunctionId, 123 | np_like: NumpyLike) -> Callable[[NDArray], NDArray]: 124 | """Returns a callable version of a function with a given `FunctionId`.""" 125 | if function_id.name == ABS.name: 126 | if function_id.derivative_order == 0: 127 | return np_like.abs 128 | elif function_id.derivative_order == 1: 129 | return np_like.sign 130 | else: 131 | raise NotImplementedError(function_id.derivative_order) 132 | if function_id.name == EXP.name: 133 | return np_like.exp 134 | elif function_id.name == LOG.name: 135 | k = function_id.derivative_order 136 | if k == 0: 137 | return np_like.log 138 | else: 139 | sign = -1 if k % 2 == 0 else 1 140 | return lambda x: sign * math.factorial(k-1) * np_like.asarray(x)**-k 141 | elif function_id.name == SIGMOID.name: 142 | return functools.partial(_sigmoid_derivative, 143 | function_id.derivative_order, np_like=np_like) 144 | elif function_id.name == SOFTPLUS.name: 145 | return functools.partial(_softplus_derivative, 146 | function_id.derivative_order, np_like=np_like) 147 | elif function_id.name == SWISH.name: 148 | return functools.partial(_swish_derivative, 149 | function_id.derivative_order, np_like=np_like) 150 | else: 151 | raise NotImplementedError(function_id) 152 | 153 | 154 | def get_function_data(function_id: FunctionId) -> FunctionData: 155 | """Gets `FunctionData` given `FunctionId`.""" 156 | if function_id.name == EXP.name: 157 | return FunctionData((), (), monotonically_increasing=True) 158 | elif function_id.name == LOG.name: 159 | k = function_id.derivative_order 160 | return FunctionData( 161 | # The domain of the log function is (0, infinity). Over this domain, 162 | # none of the derivatives have any local extrema. 163 | (), 164 | (), 165 | monotonically_decreasing=(k%2 == 1), 166 | monotonically_increasing=(k%2 == 0) 167 | ) 168 | elif function_id.name == SIGMOID.name: 169 | # Sigmoid is 1st derivative of softplus, so kth derivative of sigmoid is 170 | # (k+1)st derivative of softplus. 171 | return get_function_data( 172 | SOFTPLUS.derivative_id(1 + function_id.derivative_order)) 173 | elif function_id in _FUNCTION_DATA: 174 | return _FUNCTION_DATA[function_id] 175 | else: 176 | raise NotImplementedError(function_id) 177 | 178 | 179 | def get_taylor_polynomial_coefficients( 180 | function_id: FunctionId, 181 | degree: int, 182 | x0: NDArray, 183 | np_like: NumpyLike) -> List[NDArray]: 184 | """Returns the Taylor polynomial coefficients for a given function at `x0`. 185 | 186 | Args: 187 | function_id: a `FunctionId` 188 | degree: the degree of the Taylor polynomial whose coefficients we return 189 | x0: the reference point 190 | np_like: a `NumpyLike` backend. 191 | 192 | Returns: 193 | a list of `NDArray`s of Taylor polynomial coefficients, of length 194 | `degree+1`. 
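
  Example (a minimal sketch; this assumes plain `numpy` is passed as the
  `np_like` backend, which is one valid choice but not required):

    import numpy as np
    # Degree-2 Taylor coefficients of exp at x0 = 0 are (1, 1, 1/2).
    coeffs = get_taylor_polynomial_coefficients(EXP, 2, np.asarray(0.), np)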
195 | """ 196 | coefficients = [] 197 | for i in range(degree + 1): 198 | f_deriv = get_function(function_id.derivative_id(i), np_like) 199 | coefficients.append(f_deriv(x0) / math.factorial(i)) 200 | return coefficients 201 | 202 | 203 | def maximum_value(f, 204 | x_min: NDArray, 205 | x_max: NDArray, 206 | local_maxima: Sequence[float], 207 | np_like: NumpyLike) -> NDArray: 208 | """Returns maximum value of `f` over `[x_min, x_max]`.""" 209 | if not local_maxima: 210 | return np_like.maximum(f(x_min), f(x_max)) 211 | sorted_maxima = list(sorted(local_maxima, key=f, reverse=True)) 212 | x = sorted_maxima[0] 213 | return np_like.where( 214 | np_like.logical_and(x_min <= x, x <= x_max), 215 | f(x), 216 | maximum_value(f, x_min, x_max, sorted_maxima[1:], np_like)) 217 | 218 | 219 | def minimum_value(f, 220 | x_min: NDArray, 221 | x_max: NDArray, 222 | local_minima: Sequence[float], 223 | np_like: NumpyLike) -> NDArray: 224 | """Returns minimum value of `f` over `[x_min, x_max]`.""" 225 | if not local_minima: 226 | return np_like.minimum(f(x_min), f(x_max)) 227 | sorted_minima = list(sorted(local_minima, key=f)) 228 | x = sorted_minima[0] 229 | return np_like.where( 230 | np_like.logical_and(x_min <= x, x <= x_max), 231 | f(x), 232 | minimum_value(f, x_min, x_max, sorted_minima[1:], np_like)) 233 | 234 | 235 | def _get_range(f, 236 | x_min: NDArray, 237 | x_max: NDArray, 238 | local_minima: Sequence[float], 239 | local_maxima: Sequence[float], 240 | np_like: NumpyLike) -> Tuple[NDArray, NDArray]: 241 | minval = minimum_value(f, x_min, x_max, local_minima, np_like) 242 | maxval = maximum_value(f, x_min, x_max, local_maxima, np_like) 243 | return (minval, maxval) 244 | 245 | 246 | def get_range(function_id: FunctionId, 247 | trust_region: Interval, 248 | np_like: NumpyLike) -> Interval: 249 | """Returns exact range of specified function over `trust_region`.""" 250 | f = get_function(function_id, np_like) 251 | function_data = get_function_data(function_id) 252 | return _get_range( 253 | f, 254 | trust_region[0], 255 | trust_region[1], 256 | function_data.local_minima, 257 | function_data.local_maxima, 258 | np_like 259 | ) 260 | 261 | 262 | def _sigmoid(x: NDArrayLike, np_like: NumpyLike) -> NDArray: 263 | return np_like.where( 264 | x >= 0, 265 | 1 / (1 + np_like.exp(-x)), 266 | np_like.exp(x) / (1 + np_like.exp(x)) 267 | ) 268 | 269 | 270 | def _sigmoid_derivative(order: int, x: NDArrayLike, 271 | np_like: NumpyLike) -> NDArray: 272 | """Returns the (elementwise) derivative of a specified order.""" 273 | # Note: we could make this work for arbitrary order using autodiff, but we 274 | # don't because this module is backend-agnostic, and we don't have a way to 275 | # do autodiff in a backend-agnostic way. 276 | s = _sigmoid(x, np_like) 277 | sm = _sigmoid(-x, np_like) 278 | if order == 0: 279 | return s 280 | elif order == 1: 281 | return s*sm 282 | elif order == 2: 283 | return s*sm*(1-2*s) 284 | elif order == 3: 285 | return s*sm*((1-2*s)**2 - 2*s*sm) 286 | elif order == 4: 287 | return (s*sm*(1-2*s)*((1-2*s)**2 - 2*s*sm) + 288 | s*sm*(-4*(1-2*s)*s*sm - 2*s*sm*(1-2*s))) 289 | else: 290 | raise NotImplementedError(order) 291 | 292 | 293 | def _softplus(x: NDArrayLike, np_like: NumpyLike) -> NDArray: 294 | # Avoid overflow for large positive x using: 295 | # log(1+exp(x)) == log(1+exp(-|x|)) + max(x, 0). 
296 | return np_like.log1p(np_like.exp(-np_like.abs(x))) + np_like.maximum(x, 0) 297 | 298 | 299 | def _softplus_derivative(order: int, x: NDArrayLike, 300 | np_like: NumpyLike) -> NDArray: 301 | if order == 0: 302 | return _softplus(x, np_like) 303 | else: 304 | return _sigmoid_derivative(order - 1, x, np_like) 305 | 306 | 307 | def _swish(x: NDArrayLike, np_like: NumpyLike) -> NDArray: 308 | return x*_sigmoid(x, np_like) 309 | 310 | 311 | def _swish_derivative(order: int, x: NDArrayLike, 312 | np_like: NumpyLike) -> NDArray: 313 | if order == 0: 314 | return _swish(x, np_like) 315 | else: 316 | # swish(x) = x*sigmoid(x) 317 | # swish'(x) = sigmoid(x) + x*sigmoid'(x). 318 | # Inductively, 319 | # swish^{(k)}(x) = k*sigmoid^({k-1})(x) + x*sigmoid^{(k)}(x). 320 | return (order*_sigmoid_derivative(order - 1, x, np_like) + 321 | x*_sigmoid_derivative(order, x, np_like)) 322 | 323 | 324 | # Dict from FunctionId to FunctionData. 325 | _FUNCTION_DATA = { 326 | # ABS 327 | ABS: FunctionData((0.,), (), even_symmetric=True), 328 | ABS.derivative_id(1): FunctionData((), (), monotonically_increasing=True), 329 | # SOFTPLUS 330 | SOFTPLUS: FunctionData((), (), monotonically_increasing=True), 331 | SOFTPLUS.derivative_id(1): 332 | FunctionData((), (), monotonically_increasing=True), 333 | SOFTPLUS.derivative_id(2): 334 | FunctionData((), (0.,), even_symmetric=True), 335 | SOFTPLUS.derivative_id(3): 336 | FunctionData((1.3169578969249405,), (-1.3169578969249423,)), 337 | SOFTPLUS.derivative_id(4): 338 | FunctionData((0.,), (-2.292431669561122, 2.2924316695611195)), 339 | SOFTPLUS.derivative_id(5): 340 | FunctionData((-0.8426329481295408, 3.1443184061547065), 341 | (-3.144318406154709, 0.8426329481295388)), 342 | # SWISH 343 | SWISH: FunctionData((-1.278464542761141,), ()), 344 | SWISH.derivative_id(1): 345 | FunctionData((-2.399357280515326,), (2.399357280515324,)), 346 | SWISH.derivative_id(2): 347 | FunctionData((-3.4358409935350243, 3.4358409935350225), (0.,)), 348 | SWISH.derivative_id(3): 349 | FunctionData((-4.429235100557346, 1.0319582417807385), 350 | (-1.0319582417807402, 4.429235100557342)), 351 | SWISH.derivative_id(4): 352 | FunctionData((0.,), (-1.8197756117249821, 1.8197756117249804)), 353 | SWISH.derivative_id(5): 354 | FunctionData((-0.7177419231466055, 2.5062894864026024), 355 | (-2.506289486402605, 0.7177419231466035)), 356 | } 357 | -------------------------------------------------------------------------------- /autobound/interval_arithmetic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """The IntervalArithmetic class.""" 16 | 17 | import functools 18 | from typing import Callable, Sequence, Union 19 | 20 | # pylint: disable=g-multiple-import 21 | from autobound.types import ( 22 | Interval, IntervalLike, NDArray, NDArrayLike, NumpyLike) 23 | 24 | 25 | class IntervalArithmetic: 26 | """Interval arithmetic on n-dimensional arrays via a numpy-like back end. 27 | 28 | Note: the intervals returned by methods in this class are only correct up to 29 | floating point roundoff error. 30 | """ 31 | 32 | def __init__(self, np_like: NumpyLike): 33 | self.np_like = np_like 34 | 35 | def add(self, 36 | a: Union[NDArrayLike, IntervalLike], 37 | b: Union[NDArrayLike, IntervalLike]) -> Union[NDArray, Interval]: 38 | """Returns the sum of two intervals.""" 39 | a_is_interval = isinstance(a, tuple) 40 | b_is_interval = isinstance(b, tuple) 41 | if a_is_interval and b_is_interval: 42 | return (self.np_like.add(a[0], b[0]), self.np_like.add(a[1], b[1])) 43 | elif a_is_interval: 44 | return (self.np_like.add(a[0], b), self.np_like.add(a[1], b)) 45 | elif b_is_interval: 46 | return (self.np_like.add(a, b[0]), self.np_like.add(a, b[1])) 47 | else: 48 | return self.np_like.add(a, b) 49 | 50 | def arbitrary_bilinear( 51 | self, 52 | a: Union[NDArrayLike, IntervalLike], 53 | b: Union[NDArrayLike, IntervalLike], 54 | bilinear: Callable[[NDArrayLike, NDArrayLike], NDArray], 55 | assume_product: bool = False 56 | ) -> Union[NDArray, Interval]: 57 | """Applies a bilinear operation (e.g., matmul, conv2d) to two intervals. 58 | 59 | Args: 60 | a: an NDArray-like or Interval-like object. 61 | b: an NDArray-like or Interval-like object. 62 | bilinear: a callable that takes two NDArray-like arguments, and returns 63 | the NDArray that results from applying some bilinear operation to them. 64 | assume_product: if True, we assume that each element of the NDArray 65 | returned by `bilinear` is a product of some element of `a` and some 66 | element of `b`, and use a rule that returns a tighter interval under 67 | this assumption. 68 | 69 | Returns: 70 | an NDArray or Interval representing the result of applying the bilinear 71 | operation to `a` and `b`. 72 | """ 73 | a_is_interval = isinstance(a, tuple) 74 | b_is_interval = isinstance(b, tuple) 75 | if not a_is_interval and not b_is_interval: 76 | return bilinear(a, b) 77 | 78 | if assume_product: 79 | def yield_endpoint_products(): 80 | a_endpoints = a if a_is_interval else (a,) 81 | b_endpoints = b if b_is_interval else (b,) 82 | for a_endpoint in a_endpoints: 83 | for b_endpoint in b_endpoints: 84 | yield bilinear(a_endpoint, b_endpoint) # pytype: disable=wrong-arg-types 85 | endpoint_products = list(yield_endpoint_products()) 86 | return ( 87 | functools.reduce(self.np_like.minimum, endpoint_products), 88 | functools.reduce(self.np_like.maximum, endpoint_products) 89 | ) 90 | else: 91 | # TODO(mstreeter): there are multiple methods that could be used here, 92 | # which make different tradeoffs between computation and the tightness of 93 | # the returned interval. 94 | # TODO(mstreeter): add reference to proof of correctness for the method 95 | # used here. 
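      # As a concrete sanity check (illustrative, not from a reference): take
      # bilinear to be elementwise multiplication, a = (-1, 2) an interval, and
      # b = 3 a plain scalar. Then b_pos = 3 and b_neg = 0 below, so the result
      # is (-1*3 + 2*0, 2*3 + (-1)*0) = (-3, 6), the exact range of the
      # product. (multiply() itself takes the assume_product branch above;
      # this is purely to illustrate the general formula.)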
96 | def positive_and_negative_parts(x): 97 | return (self.np_like.maximum(0, x), self.np_like.minimum(0, x)) 98 | 99 | if not b_is_interval: 100 | assert a_is_interval 101 | b_pos, b_neg = positive_and_negative_parts(b) 102 | min_vals = self.np_like.add(bilinear(a[0], b_pos), 103 | bilinear(a[1], b_neg)) 104 | max_vals = self.np_like.add(bilinear(a[1], b_pos), 105 | bilinear(a[0], b_neg)) 106 | return (min_vals, max_vals) 107 | elif not a_is_interval: 108 | a_pos, a_neg = positive_and_negative_parts(a) 109 | min_vals = self.np_like.add(bilinear(a_pos, b[0]), 110 | bilinear(a_neg, b[1])) 111 | max_vals = self.np_like.add(bilinear(a_pos, b[1]), 112 | bilinear(a_neg, b[0])) 113 | return (min_vals, max_vals) 114 | else: 115 | assert a_is_interval and b_is_interval 116 | u, v = a 117 | w, x = b 118 | u_pos, u_neg = positive_and_negative_parts(u) 119 | v_pos, v_neg = positive_and_negative_parts(v) 120 | w_pos, w_neg = positive_and_negative_parts(w) 121 | x_pos, x_neg = positive_and_negative_parts(x) 122 | min_pairs = [(u_pos, w_pos), (v_pos, w_neg), 123 | (u_neg, x_pos), (v_neg, x_neg)] 124 | min_vals = functools.reduce( 125 | self.np_like.add, 126 | [bilinear(x, y) for x, y in min_pairs] 127 | ) 128 | max_pairs = [(v_pos, x_pos), (v_neg, w_pos), 129 | (u_pos, x_neg), (u_neg, w_neg)] 130 | max_vals = functools.reduce( 131 | self.np_like.add, 132 | [bilinear(x, y) for x, y in max_pairs] 133 | ) 134 | return (min_vals, max_vals) 135 | 136 | def as_interval(self, a: IntervalLike) -> Interval: 137 | return tuple(self.np_like.asarray(c) for c in a) 138 | 139 | def as_interval_or_ndarray( 140 | self, 141 | a: Union[NDArrayLike, IntervalLike]) -> Union[NDArray, Interval]: 142 | if isinstance(a, tuple): 143 | return self.as_interval(a) 144 | else: 145 | return self.np_like.asarray(a) 146 | 147 | def multiply( 148 | self, 149 | a: Union[NDArrayLike, IntervalLike], 150 | b: Union[NDArrayLike, IntervalLike]) -> Union[NDArray, Interval]: 151 | """Returns the element-wise product of two intervals.""" 152 | return self.arbitrary_bilinear(a, b, self.np_like.multiply, True) 153 | 154 | def ndim(self, a: Union[IntervalLike, NDArrayLike]) -> int: 155 | return self.np_like.ndim(a[0] if isinstance(a, tuple) else a) 156 | 157 | def negative(self, 158 | a: Union[NDArrayLike, IntervalLike]) -> Union[NDArray, Interval]: 159 | """Returns element-wise negative of an interval.""" 160 | if isinstance(a, tuple): 161 | return (self.np_like.negative(a[1]), self.np_like.negative(a[0])) 162 | else: 163 | return self.np_like.negative(a) 164 | 165 | def outer_power(self, 166 | a: Union[NDArrayLike, IntervalLike], 167 | exponent: int, 168 | batch_dims: int = 0) -> Union[NDArray, Interval]: 169 | """Returns a repeated outer product.""" 170 | if exponent < 0: 171 | raise ValueError(exponent) 172 | elif exponent == 0: 173 | return self.np_like.asarray(1) 174 | elif self.ndim(a) == 0: 175 | return self.power(a, exponent) 176 | else: 177 | # For the off-diagonal elements of the output, the best we can do is 178 | # to repeatedly call self.outer_product(). For the diagonal elements, 179 | # we can use self.power() to get a tighter result. 180 | a = self.as_interval_or_ndarray(a) 181 | running_outer_product = a 182 | for _ in range(exponent - 1): 183 | running_outer_product = self.outer_product(running_outer_product, a, 184 | batch_dims) 185 | if batch_dims != 0: 186 | # TODO(mstreeter): adapt the code below to handle batch_dims != 0. 
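        # Illustrative example (not from the original comments): for the 1-d
        # interval a = ([-1.], [2.]) with exponent 2, the repeated outer
        # product alone would give the diagonal entry [-2., 4.], whereas the
        # power()-based correction below (applied only when batch_dims == 0)
        # tightens it to [0., 4.].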
187 |         return running_outer_product
188 |       try:
189 |         a_is_interval = isinstance(a, tuple)
190 |         eye = _generalized_diag_ndarray(
191 |             self.np_like.ones_like(a[0] if a_is_interval else a),
192 |             exponent, self.np_like)
193 |       except NotImplementedError:
194 |         return running_outer_product
195 |       diagonal_elements = self.power(a, exponent)
196 |       diag = self._generalized_diag_interval(diagonal_elements, exponent)
197 |       return self.add(
198 |           self.multiply(running_outer_product, 1 - eye),
199 |           diag
200 |       )
201 |
202 |   def outer_product(self,
203 |                     a: Union[NDArrayLike, IntervalLike],
204 |                     b: Union[NDArrayLike, IntervalLike],
205 |                     batch_dims: int = 0) -> Union[NDArray, Interval]:
206 |     """Interval variant of _ndarray_outer_product()."""
207 |     if batch_dims > self.ndim(a) or batch_dims > self.ndim(b):
208 |       raise ValueError((self.ndim(a), self.ndim(b), batch_dims))
209 |     product = functools.partial(_ndarray_outer_product,
210 |                                 batch_dims=batch_dims, np_like=self.np_like)
211 |     return self.arbitrary_bilinear(a, b, product, True)
212 |
213 |   def power(self, a: Union[NDArrayLike, IntervalLike],
214 |             exponent: float) -> Union[NDArray, Interval]:
215 |     """Returns a**exponent (element-wise)."""
216 |     a = self.as_interval_or_ndarray(a)
217 |     a_is_interval = isinstance(a, tuple)
218 |     if a_is_interval:
219 |       if exponent < 0:
220 |         raise NotImplementedError(exponent)
221 |       elif exponent == 0:
222 |         return self.np_like.ones_like(a[0])
223 |       else:
224 |         # For scalars u and v, with u <= v, and even K, the left end point of
225 |         # [u, v]**K is 0 if u <= 0 <= v, and is min{u**K, v**K} otherwise.
226 |         # If K is odd, the left end point is u**K. The expression for
227 |         # min_vals below handles all cases.
228 |         #
229 |         # The right end point is always max{u**K, v**K}, giving a simpler
230 |         # expression for max_vals.
231 |         contains_zero = self.np_like.logical_and(a[0] < 0, a[1] > 0)
232 |         pow0 = a[0]**exponent
233 |         pow1 = a[1]**exponent
234 |         min_vals = functools.reduce(self.np_like.minimum,
235 |                                     [pow0, pow1, (1-contains_zero)*pow0])
236 |         max_vals = self.np_like.maximum(pow0, pow1)
237 |         return (min_vals, max_vals)
238 |     else:
239 |       return self.np_like.power(a, exponent)
240 |
241 |   def shape(self, a: Union[IntervalLike, NDArrayLike]):
242 |     # Note: calling tuple(...) is necessary when self.np_like is
243 |     # tf.experimental.numpy.
244 | return tuple(self.np_like.shape(a[0] if isinstance(a, tuple) else a)) 245 | 246 | def subtract( 247 | self, 248 | a: Union[NDArrayLike, IntervalLike], 249 | b: Union[NDArrayLike, IntervalLike]) -> Union[NDArray, Interval]: 250 | """Returns the difference between two intervals.""" 251 | a_is_interval = isinstance(a, tuple) 252 | b_is_interval = isinstance(b, tuple) 253 | if a_is_interval and b_is_interval: 254 | return (self.np_like.subtract(a[0], b[1]), 255 | self.np_like.subtract(a[1], b[0])) 256 | elif a_is_interval: 257 | return (self.np_like.subtract(a[0], b), self.np_like.subtract(a[1], b)) 258 | elif b_is_interval: 259 | return (self.np_like.subtract(a, b[1]), self.np_like.subtract(a, b[0])) 260 | else: 261 | return self.np_like.subtract(a, b) 262 | 263 | def tensordot( 264 | self, 265 | a: Union[NDArrayLike, IntervalLike], 266 | b: Union[NDArrayLike, IntervalLike], 267 | axes) -> Union[NDArray, Interval]: 268 | """Like np.tensordot(), but for intervals.""" 269 | bilinear = functools.partial(self.np_like.tensordot, axes=axes) 270 | return self.arbitrary_bilinear(a, b, bilinear, axes == 0) 271 | 272 | def _generalized_diag_interval( 273 | self, 274 | a: Union[NDArrayLike, IntervalLike], n: int) -> Union[NDArray, Interval]: 275 | """Interval variant of _generalized_diag_ndarray.""" 276 | if isinstance(a, tuple): 277 | if len(a) != 2: 278 | raise ValueError() 279 | return (_generalized_diag_ndarray(a[0], n, self.np_like), 280 | _generalized_diag_ndarray(a[1], n, self.np_like)) 281 | else: 282 | return _generalized_diag_ndarray(a, n, self.np_like) 283 | 284 | 285 | def _generalized_diag_ndarray(a: NDArrayLike, n: int, 286 | np_like: NumpyLike) -> NDArray: 287 | """Returns NDArray of shape shape(a)*n, with a on diagonal.""" 288 | if n == 1: 289 | return np_like.asarray(a) 290 | elif n == 2: 291 | a = np_like.asarray(a) 292 | if a.ndim == 1: 293 | return np_like.diag(a) 294 | else: 295 | raise NotImplementedError(a.ndim) 296 | else: 297 | raise NotImplementedError(n) 298 | 299 | 300 | def _stringify(axes: Sequence[int]) -> str: 301 | """Helper for creating an einsum() equation string.""" 302 | offset = ord('a') 303 | return ''.join(chr(i + offset) for i in axes) 304 | 305 | 306 | def _ndarray_outer_product(a: NDArrayLike, 307 | b: NDArrayLike, 308 | batch_dims: int, 309 | np_like: NumpyLike) -> NDArray: 310 | """Returns an outer product with batch dimensions. 311 | 312 | Args: 313 | a: an NDArray-like object. 314 | b: an NDArray-like object. 315 | batch_dims: number of batch dimensions. 316 | np_like: a Numpy-like backend. 
317 | 318 | Returns: 319 | an NDArray c such that, for every tuple I that indexes the first 320 | batch_dims elements of a (and b), every tuple J that indexes the last 321 | a.ndim - batch_dims elements of a, and every tuple K that indexes the last 322 | a.ndim - batch_dims elements of b, we have: 323 | 324 | c[I + J + K] = a[I + J] * b[I + K] 325 | """ 326 | a = np_like.asarray(a) 327 | b = np_like.asarray(b) 328 | if batch_dims == 0: 329 | return np_like.tensordot(a, b, 0) 330 | else: 331 | a_axes = tuple(range(a.ndim)) 332 | b_non_batch_axes = tuple(range(a.ndim, a.ndim + b.ndim - batch_dims)) 333 | b_axes = tuple(range(batch_dims)) + b_non_batch_axes 334 | output_axes = a_axes + b_non_batch_axes 335 | eq = (_stringify(a_axes) + ',' + _stringify(b_axes) + 336 | '->' + _stringify(output_axes)) 337 | return np_like.einsum(eq, a, b) 338 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # This Pylint rcfile contains a best-effort configuration to uphold the 2 | # best-practices and style described in the Google Python style guide: 3 | # https://google.github.io/styleguide/pyguide.html 4 | # 5 | # Its canonical open-source location is: 6 | # https://google.github.io/styleguide/pylintrc 7 | 8 | [MASTER] 9 | 10 | # Add files or directories to the ignore list. They should be base names, not 11 | # paths. 12 | ignore=third_party 13 | 14 | # Add files or directories matching the regex patterns to the ignore list. The 15 | # regex matches against base names, not paths. 16 | ignore-patterns= 17 | 18 | # Pickle collected data for later comparisons. 19 | persistent=no 20 | 21 | # List of plugins (as comma separated values of python modules names) to load, 22 | # usually to register additional checkers. 23 | load-plugins= 24 | 25 | # Use multiple processes to speed up Pylint. 26 | jobs=4 27 | 28 | # Allow loading of arbitrary C extensions. Extensions are imported into the 29 | # active Python interpreter and may run arbitrary code. 30 | unsafe-load-any-extension=no 31 | 32 | # A comma-separated list of package or module names from where C extensions may 33 | # be loaded. Extensions are loading into the active Python interpreter and may 34 | # run arbitrary code. 35 | extension-pkg-allow-list= 36 | 37 | 38 | [MESSAGES CONTROL] 39 | 40 | # Only show warnings with the listed confidence levels. Leave empty to show 41 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 42 | confidence= 43 | 44 | # Enable the message, report, category or checker with the given id(s). You can 45 | # either give multiple identifier separated by comma (,) or put this option 46 | # multiple time (only on the command line, not in the configuration file where 47 | # it should appear only once). See also the "--disable" option for examples. 48 | #enable= 49 | 50 | # Disable the message, report, category or checker with the given id(s). You 51 | # can either give multiple identifiers separated by comma (,) or put this 52 | # option multiple times (only on the command line, not in the configuration 53 | # file where it should appear only once).You can also use "--disable=all" to 54 | # disable everything first and then reenable specific checks. For example, if 55 | # you want to run only the similarities checker, you can use "--disable=all 56 | # --enable=similarities". 
If you want to run only the classes checker, but have 57 | # no Warning level messages displayed, use"--disable=all --enable=classes 58 | # --disable=W" 59 | disable=abstract-method, 60 | apply-builtin, 61 | arguments-differ, 62 | attribute-defined-outside-init, 63 | backtick, 64 | bad-option-value, 65 | basestring-builtin, 66 | buffer-builtin, 67 | c-extension-no-member, 68 | consider-using-enumerate, 69 | cmp-builtin, 70 | cmp-method, 71 | coerce-builtin, 72 | coerce-method, 73 | delslice-method, 74 | div-method, 75 | duplicate-code, 76 | eq-without-hash, 77 | execfile-builtin, 78 | file-builtin, 79 | filter-builtin-not-iterating, 80 | fixme, 81 | getslice-method, 82 | global-statement, 83 | hex-method, 84 | idiv-method, 85 | implicit-str-concat-in-sequence, 86 | import-error, 87 | import-self, 88 | import-star-module-level, 89 | inconsistent-return-statements, 90 | input-builtin, 91 | intern-builtin, 92 | invalid-str-codec, 93 | locally-disabled, 94 | long-builtin, 95 | long-suffix, 96 | map-builtin-not-iterating, 97 | misplaced-comparison-constant, 98 | missing-function-docstring, 99 | metaclass-assignment, 100 | next-method-called, 101 | next-method-defined, 102 | no-absolute-import, 103 | no-else-break, 104 | no-else-continue, 105 | no-else-raise, 106 | no-else-return, 107 | no-init, # added 108 | no-member, 109 | no-name-in-module, 110 | no-self-use, 111 | nonzero-method, 112 | oct-method, 113 | old-division, 114 | old-ne-operator, 115 | old-octal-literal, 116 | old-raise-syntax, 117 | parameter-unpacking, 118 | print-statement, 119 | raising-string, 120 | range-builtin-not-iterating, 121 | raw_input-builtin, 122 | rdiv-method, 123 | reduce-builtin, 124 | relative-import, 125 | reload-builtin, 126 | round-builtin, 127 | setslice-method, 128 | signature-differs, 129 | standarderror-builtin, 130 | suppressed-message, 131 | sys-max-int, 132 | too-few-public-methods, 133 | too-many-ancestors, 134 | too-many-arguments, 135 | too-many-boolean-expressions, 136 | too-many-branches, 137 | too-many-instance-attributes, 138 | too-many-locals, 139 | too-many-nested-blocks, 140 | too-many-public-methods, 141 | too-many-return-statements, 142 | too-many-statements, 143 | trailing-newlines, 144 | unichr-builtin, 145 | unicode-builtin, 146 | unnecessary-pass, 147 | unpacking-in-except, 148 | useless-else-on-loop, 149 | useless-object-inheritance, 150 | useless-suppression, 151 | using-cmp-argument, 152 | wrong-import-order, 153 | xrange-builtin, 154 | zip-builtin-not-iterating, 155 | 156 | 157 | [REPORTS] 158 | 159 | # Set the output format. Available formats are text, parseable, colorized, msvs 160 | # (visual studio) and html. You can also give a reporter class, eg 161 | # mypackage.mymodule.MyReporterClass. 162 | output-format=text 163 | 164 | # Put messages in a separate file for each module / package specified on the 165 | # command line instead of printing them on stdout. Reports (if any) will be 166 | # written in a file name "pylint_global.[txt|html]". This option is deprecated 167 | # and it will be removed in Pylint 2.0. 168 | files-output=no 169 | 170 | # Tells whether to display a full report or only the messages 171 | reports=no 172 | 173 | # Python expression which should return a note less than 10 (10 is the highest 174 | # note). You have access to the variables errors warning, statement which 175 | # respectively contain the number of errors / warnings messages and the total 176 | # number of statements analyzed. This is used by the global evaluation report 177 | # (RP0004). 
178 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 179 | 180 | # Template used to display messages. This is a python new-style format string 181 | # used to format the message information. See doc for all details 182 | #msg-template= 183 | 184 | 185 | [BASIC] 186 | 187 | # Good variable names which should always be accepted, separated by a comma 188 | good-names=main,_ 189 | 190 | # Bad variable names which should always be refused, separated by a comma 191 | bad-names= 192 | 193 | # Colon-delimited sets of names that determine each other's naming style when 194 | # the name regexes allow several styles. 195 | name-group= 196 | 197 | # Include a hint for the correct naming format with invalid-name 198 | include-naming-hint=no 199 | 200 | # List of decorators that produce properties, such as abc.abstractproperty. Add 201 | # to this list to register other decorators that produce valid properties. 202 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl 203 | 204 | # Regular expression matching correct function names 205 | function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ 206 | 207 | # Regular expression matching correct variable names 208 | variable-rgx=^[a-z][a-z0-9_]*$ 209 | 210 | # Regular expression matching correct constant names 211 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 212 | 213 | # Regular expression matching correct attribute names 214 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 215 | 216 | # Regular expression matching correct argument names 217 | argument-rgx=^[a-z][a-z0-9_]*$ 218 | 219 | # Regular expression matching correct class attribute names 220 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 221 | 222 | # Regular expression matching correct inline iteration names 223 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 224 | 225 | # Regular expression matching correct class names 226 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 227 | 228 | # Regular expression matching correct module names 229 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ 230 | 231 | # Regular expression matching correct method names 232 | method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ 233 | 234 | # Regular expression which should only match function or class names that do 235 | # not require a docstring. 236 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ 237 | 238 | # Minimum line length for functions/classes that require docstrings, shorter 239 | # ones are exempt. 240 | docstring-min-length=10 241 | 242 | 243 | [TYPECHECK] 244 | 245 | # List of decorators that produce context managers, such as 246 | # contextlib.contextmanager. Add to this list to register other decorators that 247 | # produce valid context managers. 248 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 249 | 250 | # Tells whether missing members accessed in mixin class should be ignored. A 251 | # mixin class is detected if its name ends with "mixin" (case insensitive). 
252 | ignore-mixin-members=yes 253 | 254 | # List of module names for which member attributes should not be checked 255 | # (useful for modules/projects where namespaces are manipulated during runtime 256 | # and thus existing member attributes cannot be deduced by static analysis. It 257 | # supports qualified module names, as well as Unix pattern matching. 258 | ignored-modules= 259 | 260 | # List of class names for which member attributes should not be checked (useful 261 | # for classes with dynamically set attributes). This supports the use of 262 | # qualified names. 263 | ignored-classes=optparse.Values,thread._local,_thread._local 264 | 265 | # List of members which are set dynamically and missed by pylint inference 266 | # system, and so shouldn't trigger E1101 when accessed. Python regular 267 | # expressions are accepted. 268 | generated-members= 269 | 270 | 271 | [FORMAT] 272 | 273 | # Maximum number of characters on a single line. 274 | max-line-length=80 275 | 276 | # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt 277 | # lines made too long by directives to pytype. 278 | 279 | # Regexp for a line that is allowed to be longer than the limit. 280 | ignore-long-lines=(?x)( 281 | ^\s*(\#\ )??$| 282 | ^\s*(from\s+\S+\s+)?import\s+.+$) 283 | 284 | # Allow the body of an if to be on the same line as the test if there is no 285 | # else. 286 | single-line-if-stmt=yes 287 | 288 | # List of optional constructs for which whitespace checking is disabled. `dict- 289 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 290 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 291 | # `empty-line` allows space-only lines. 292 | no-space-check= 293 | 294 | # Maximum number of lines in a module 295 | max-module-lines=99999 296 | 297 | # String used as indentation unit. The internal Google style guide mandates 2 298 | # spaces. Google's externaly-published style guide says 4, consistent with 299 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google 300 | # projects (like TensorFlow). 301 | indent-string=' ' 302 | 303 | # Number of spaces of indent required inside a hanging or continued line. 304 | indent-after-paren=4 305 | 306 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 307 | expected-line-ending-format= 308 | 309 | 310 | [MISCELLANEOUS] 311 | 312 | # List of note tags to take in consideration, separated by a comma. 313 | notes=TODO 314 | 315 | 316 | [STRING] 317 | 318 | # This flag controls whether inconsistent-quotes generates a warning when the 319 | # character used as a quote delimiter is used inconsistently within a module. 320 | check-quote-consistency=yes 321 | 322 | 323 | [VARIABLES] 324 | 325 | # Tells whether we should check for unused import in __init__ files. 326 | init-import=no 327 | 328 | # A regular expression matching the name of dummy variables (i.e. expectedly 329 | # not used). 330 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 331 | 332 | # List of additional names supposed to be defined in builtins. Remember that 333 | # you should avoid to define new builtins when possible. 334 | additional-builtins= 335 | 336 | # List of strings which can identify a callback function by name. A callback 337 | # name must start or end with one of those strings. 338 | callbacks=cb_,_cb 339 | 340 | # List of qualified module names which can have objects that can redefine 341 | # builtins. 
342 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools 343 | 344 | 345 | [LOGGING] 346 | 347 | # Logging modules to check that the string format arguments are in logging 348 | # function parameter format 349 | logging-modules=logging,absl.logging,tensorflow.io.logging 350 | 351 | 352 | [SIMILARITIES] 353 | 354 | # Minimum lines number of a similarity. 355 | min-similarity-lines=4 356 | 357 | # Ignore comments when computing similarities. 358 | ignore-comments=yes 359 | 360 | # Ignore docstrings when computing similarities. 361 | ignore-docstrings=yes 362 | 363 | # Ignore imports when computing similarities. 364 | ignore-imports=no 365 | 366 | 367 | [SPELLING] 368 | 369 | # Spelling dictionary name. Available dictionaries: none. To make it working 370 | # install python-enchant package. 371 | spelling-dict= 372 | 373 | # List of comma separated words that should not be checked. 374 | spelling-ignore-words= 375 | 376 | # A path to a file that contains private dictionary; one word per line. 377 | spelling-private-dict-file= 378 | 379 | # Tells whether to store unknown words to indicated private dictionary in 380 | # --spelling-private-dict-file option instead of raising a message. 381 | spelling-store-unknown-words=no 382 | 383 | 384 | [IMPORTS] 385 | 386 | # Deprecated modules which should not be used, separated by a comma 387 | deprecated-modules=regsub, 388 | TERMIOS, 389 | Bastion, 390 | rexec, 391 | sets 392 | 393 | # Create a graph of every (i.e. internal and external) dependencies in the 394 | # given file (report RP0402 must not be disabled) 395 | import-graph= 396 | 397 | # Create a graph of external dependencies in the given file (report RP0402 must 398 | # not be disabled) 399 | ext-import-graph= 400 | 401 | # Create a graph of internal dependencies in the given file (report RP0402 must 402 | # not be disabled) 403 | int-import-graph= 404 | 405 | # Force import order to recognize a module as part of the standard 406 | # compatibility libraries. 407 | known-standard-library= 408 | 409 | # Force import order to recognize a module as part of a third party library. 410 | known-third-party=enchant, absl 411 | 412 | # Analyse import fallback blocks. This can be used to support both Python 2 and 413 | # 3 compatible code, which means that the block might have code that exists 414 | # only in one or another interpreter, leading to false positives when analysed. 415 | analyse-fallback-blocks=no 416 | 417 | 418 | [CLASSES] 419 | 420 | # List of method names used to declare (i.e. assign) instance attributes. 421 | defining-attr-methods=__init__, 422 | __new__, 423 | setUp 424 | 425 | # List of member names, which should be excluded from the protected access 426 | # warning. 427 | exclude-protected=_asdict, 428 | _fields, 429 | _replace, 430 | _source, 431 | _make 432 | 433 | # List of valid names for the first argument in a class method. 434 | valid-classmethod-first-arg=cls, 435 | class_ 436 | 437 | # List of valid names for the first argument in a metaclass class method. 438 | valid-metaclass-classmethod-first-arg=mcs 439 | 440 | 441 | [EXCEPTIONS] 442 | 443 | # Exceptions that will emit a warning when being caught. 
Defaults to 444 | # "Exception" 445 | overgeneral-exceptions=StandardError, 446 | Exception, 447 | BaseException 448 | -------------------------------------------------------------------------------- /autobound/jax/jax_bound_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from autobound import enclosure_arithmetic 20 | from autobound import primitive_enclosures 21 | from autobound import test_utils 22 | from autobound.jax import jax_bound 23 | from flax import linen as nn 24 | import jax 25 | import jax.numpy as jnp 26 | import numpy as np 27 | 28 | 29 | # Custom softplus primitive, for use in testing registration mechanism. 30 | my_softplus_p = jax.core.Primitive('my_softplus') 31 | my_softplus_p.def_abstract_eval( 32 | lambda x: jax.abstract_arrays.ShapedArray(x.shape, x.dtype)) 33 | 34 | 35 | def my_softplus(x): 36 | return my_softplus_p.bind(x) 37 | 38 | 39 | class TestCase(parameterized.TestCase, test_utils.TestCase): 40 | 41 | @parameterized.parameters( 42 | ( 43 | 2, 44 | (np.zeros((1,)), np.ones((1,))), 45 | (np.ones((1,)), np.ones((1, 1))), 46 | (np.ones((1, 13, 7)),), 47 | { 48 | 'dimension_numbers': (((0,), (0,)), ((), ())), 49 | 'precision': None, 50 | 'preferred_element_type': None 51 | }, 52 | (np.ones((13, 7)), np.ones((13, 7, 1))) 53 | ), 54 | ( 55 | 2, 56 | (np.zeros((2,)), np.ones((2,))), 57 | (np.ones((3, 5)),), 58 | (np.ones((5, 7)), np.ones((5, 7, 2))), 59 | { 60 | 'dimension_numbers': (((1,), (0,)), ((), ())), 61 | 'precision': None, 62 | 'preferred_element_type': None 63 | }, 64 | (5*np.ones((3, 7)), 5*np.ones((3, 7, 2))) 65 | ), 66 | ( 67 | 2, 68 | (np.zeros((2,)), np.ones((2,))), 69 | (np.ones((5, 7)), np.ones((5, 7, 2))), 70 | (np.ones((7, 11)),), 71 | { 72 | 'dimension_numbers': (((1,), (0,)), ((), ())), 73 | 'precision': None, 74 | 'preferred_element_type': None 75 | }, 76 | (7*np.ones((5, 11)), 7*np.ones((5, 11, 2))) 77 | ), 78 | # TODO(mstreeter): test an example with batch dimensions. 
79 | ) 80 | def test_dot_general_pushforward_fun(self, max_degree, trust_region, 81 | lhs, rhs, params, expected): 82 | arithmetic = enclosure_arithmetic.TaylorEnclosureArithmetic( 83 | max_degree, trust_region, jnp) 84 | fun = jax_bound._dot_general_pushforward_fun(arithmetic) 85 | actual = fun(jax_bound._IntermediateEnclosure(enclosure=lhs), 86 | jax_bound._IntermediateEnclosure(enclosure=rhs), 87 | **params) 88 | self.assert_enclosure_equal(expected, actual) 89 | 90 | @parameterized.named_parameters( 91 | ( 92 | 'identity_0', 93 | lambda x: x, 94 | 0, 95 | .1, 96 | (0., .25), 97 | False, 98 | ((0., .25),), 99 | ), 100 | ( 101 | 'identity_1', 102 | lambda x: x, 103 | 1, 104 | 3.14, 105 | (0, .25), 106 | False, 107 | (3.14, 1.), 108 | ), 109 | ( 110 | 'identity_2', 111 | lambda x: x, 112 | 2, 113 | 3.14, 114 | (0, .25), 115 | False, 116 | (3.14, 1.), 117 | ), 118 | ( 119 | 'addition', 120 | lambda x: 2 + x, 121 | 1, 122 | 3.14, 123 | (0, .25), 124 | False, 125 | (5.14, 1.), 126 | ), 127 | ( 128 | 'addition_prop', 129 | lambda x: 2 + x, 130 | 1, 131 | 3.14, 132 | (0, .25), 133 | True, 134 | (5.14, 1.), 135 | ), 136 | ( 137 | 'constant', 138 | lambda x: 2., 139 | 1, 140 | .5, 141 | (0, 1), 142 | False, 143 | (2.,), 144 | ), 145 | ( 146 | 'constant_prop_0', 147 | lambda x: jnp.exp(1), 148 | 1, 149 | .5, 150 | (0, 1), 151 | True, 152 | (math.e,), 153 | ), 154 | ( 155 | 'constant_prop_1', 156 | lambda x: jnp.eye(3) @ (jnp.zeros((3,)) * x), 157 | 1, 158 | .5, 159 | (0, 1), 160 | True, 161 | (np.zeros((3,)), np.zeros((3,))), 162 | ), 163 | ( 164 | 'multiplication', 165 | lambda x: 2*x, 166 | 1, 167 | 3.14, 168 | (0, .25), 169 | False, 170 | (6.28, 2.), 171 | ), 172 | ( 173 | 'abs', 174 | jnp.abs, 175 | 1, 176 | np.array([0., 1., 1., -1., -1.]), 177 | ( 178 | np.array([-1., 0., -1., -2., -2.]), 179 | np.array([2., 2., 2., 0., 1.]) 180 | ), 181 | False, 182 | (np.array([0., 1., 1., 1., 1.]), 183 | (np.diag([-1., 1., 0., -1., -1.]), np.diag([1., 1., 1., -1., 0.]))) 184 | ), 185 | ( 186 | 'exp', 187 | jnp.exp, 188 | 2, 189 | 1., 190 | (0, 2), 191 | False, 192 | (math.e, math.e, (1., math.exp(2) - 2*math.e)), 193 | ), 194 | ( 195 | 'log_exp_noprop', 196 | lambda x: jnp.log(jnp.exp(x)), 197 | 1, 198 | 1., 199 | (.5, 2), 200 | False, 201 | # Degree-1 enclosure for exp(x) at 1 over [.5, 2] is: 202 | # (e, (2*(e-e**.5), e**2 - e)). 203 | # Enclosing this degree-1 enclosure by an interval, using the fact 204 | # that x-x0 in [-.5, 1], gives: 205 | # (e - .5*(e**2 - e), e**2) 206 | # Degree-1 enclosure for log(y) at e over [e - .5*(e**2 - e), e**2] 207 | # is: 208 | # (1, [1/(e**2 - e), 2*(1 - log(1.5*e - .5*e**2)) / (e**2 - e)]) 209 | # Composing the two degree-1 enclosures gives the following expected 210 | # enclosure: 211 | (1., (2*(math.e-math.e**.5)/(math.e**2-math.e), 212 | 2*(1 - math.log(1.5*math.e - .5*math.e**2)))), 213 | # Expected value is approximately (1, (0.45797998, 3.91999056)). 214 | ), 215 | ( 216 | 'log_exp_prop', 217 | lambda x: jnp.log(jnp.exp(x)), 218 | 1, 219 | 1., 220 | (0.5, 2), 221 | True, 222 | # Degree-0 enclosure for exp(x) at 1 over [.5, 2] is [e**.5, e**2] 223 | # Degree-1 enclosure for log(y) at e over [e**.5, e**2] is 224 | # (1, (1/(e**2 - e) .5/(e - e**.5))) 225 | # Degree-1 enclosure for exp(x) at 1 over [.5, 2] is: 226 | # (e, (2*(e-e**.5), e**2 - e)). 
227 | # Composing the two degree-1 enclosures gives the following expected 228 | # enclosure: 229 | (1., (2*(math.e-math.e**.5)/(math.e**2-math.e), 230 | .5*(math.e**2-math.e)/(math.e-math.e**.5))), 231 | # Expected value is approximately (1, (0.45797998, 2.18350155)). 232 | ), 233 | ( 234 | 'log', 235 | jnp.log, 236 | 2, 237 | 2., 238 | (1, 3), 239 | False, 240 | (np.log(2), .5, (-np.log(2)+.5, np.log(3)-np.log(2)-.5)) 241 | ), 242 | ( 243 | 'eye', 244 | lambda x: jnp.eye(3), 245 | 0, 246 | 3.14, 247 | (0, .25), 248 | False, 249 | (np.eye(3),) 250 | ), 251 | ( 252 | 'matmul_a', 253 | lambda x: jnp.matmul(jnp.eye(3), x*jnp.ones((3,))), 254 | 1, 255 | lambda: jnp.array(1.), 256 | lambda: (jnp.array(0.), jnp.array(2.)), 257 | False, 258 | (np.ones((3,)), np.ones((3,))) 259 | ), 260 | ( 261 | 'matmul_b', 262 | lambda x: jnp.matmul(jnp.eye(3), x), 263 | 1, 264 | lambda: jnp.ones((3,)), 265 | lambda: (jnp.zeros((3,)), 2*jnp.ones((3,))), 266 | False, 267 | (np.ones((3,)), np.eye(3)) 268 | ), 269 | ( 270 | 'tensordot', 271 | lambda x: jnp.tensordot(x, jnp.array([[6., 8.]]), axes=1), 272 | 1, 273 | np.array([0.]), 274 | (np.array([0.]), np.array([1.])), 275 | False, 276 | (np.array([0., 0.]), np.array([[6.], [8.]])), 277 | ), 278 | ( 279 | 'negative', 280 | jnp.negative, 281 | 0, 282 | 3.14, 283 | (0., 5.), 284 | False, 285 | ((-5., 0.),) 286 | ), 287 | ( 288 | 'minus', 289 | lambda x: x - 1, 290 | 0, 291 | 3.14, 292 | (0., 5.), 293 | False, 294 | ((-1., 4.),) 295 | ), 296 | ( 297 | 'broadcast', 298 | lambda x: jax.lax.broadcast(x, [2]), 299 | 0, 300 | 3.14, 301 | (0., 5.), 302 | False, 303 | (([0., 0.], [5., 5.]),) 304 | ), 305 | ( 306 | 'square_a', 307 | lambda x: x**2, 308 | 0, 309 | 0, 310 | (-.5, .5), 311 | False, 312 | ((0., .25),), 313 | ), 314 | ( 315 | 'square_b', 316 | lambda x: x**2, 317 | 2, 318 | 0., 319 | (0, 1), 320 | False, 321 | (0., 0., 1.), 322 | ), 323 | ( 324 | 'square_c', 325 | lambda x: x**2, 326 | 2, 327 | 0.5, 328 | (0, 1), 329 | False, 330 | (0.25, 1., 1.), 331 | ), 332 | ( 333 | 'square_b_prop', 334 | lambda x: x**2, 335 | 2, 336 | 0.5, 337 | (0, 1), 338 | True, 339 | (0.25, 1., 1.), 340 | ), 341 | ( 342 | 'sqrt', 343 | lambda x: x**.5, 344 | 2, 345 | 2., 346 | (1, 3), 347 | False, 348 | (2**.5, .5 / 2**.5, 349 | (1 - (2**.5 - .5/2**.5), 3**.5 - (2**.5 + .5/2**.5))) 350 | ), 351 | ( 352 | 'multiply_b', 353 | lambda x: x * jnp.array([[1], [2]]), 354 | 2, 355 | 0.5, 356 | (0, 2), 357 | False, 358 | (np.array([[0.5], [1.]]), np.array([[1.], [2.]])) 359 | ), 360 | ( 361 | 'reshape', 362 | lambda x: (x * jnp.array([[1], [2]])).reshape(-1), 363 | 2, 364 | 0.5, 365 | (0, 2), 366 | False, 367 | (np.array([0.5, 1.]), np.array([1., 2.])) 368 | ), 369 | ( 370 | 'sum', 371 | lambda x: (x * jnp.array([[1], [2]])).sum(), 372 | 2, 373 | 0.5, 374 | (0, 2), 375 | False, 376 | (1.5, 3.) 
377 | ), 378 | ( 379 | 'transpose', 380 | lambda x: (x * jnp.array([[1], [2]])).transpose(), 381 | 2, 382 | 0.5, 383 | (0, 2), 384 | False, 385 | (np.array([[0.5, 1.]]), np.array([[1., 2.]])) 386 | ), 387 | ( 388 | 'squeeze', 389 | lambda x: (x * jnp.ones((4, 1, 2))).squeeze([1]), 390 | 2, 391 | 0.5, 392 | (0, 2), 393 | False, 394 | (.5 * np.ones((4, 2)), np.ones((4, 2))) 395 | ), 396 | ( 397 | 'avg_pool_a', 398 | lambda x: nn.avg_pool(x*jnp.array([[[1], [2], [3], [4]]]), (2,)), 399 | 2, 400 | 0.5, 401 | (0, 2), 402 | False, 403 | (.5 * np.array([[[1.5], [2.5], [3.5]]]), 404 | np.array([[[1.5], [2.5], [3.5]]])) 405 | ), 406 | ( 407 | 'avg_pool_b', 408 | lambda x: nn.avg_pool((x**2)*jnp.array([[[1], [2], [3], [4]]]), (2,)), 409 | 2, 410 | 3., 411 | (0, 5), 412 | False, 413 | ( 414 | 9 * np.array([[[1.5], [2.5], [3.5]]]), 415 | 6 * np.array([[[1.5], [2.5], [3.5]]]), 416 | np.array([[[1.5], [2.5], [3.5]]]) 417 | ) 418 | ), 419 | ( 420 | 'my_softplus', 421 | my_softplus, 422 | 2, 423 | 0., 424 | (-1., 1.), 425 | False, 426 | (math.log(2), .5, (math.log(1+math.exp(1)) - math.log(2) - .5, .125)), 427 | ), 428 | ( 429 | 'conv_general_dilated_a', 430 | # This returns x**2 * ones((1,1,1,1)). 431 | # pylint: disable=g-long-lambda 432 | lambda x: jax.lax.conv_general_dilated(x*jnp.ones((1, 1, 1, 1)), 433 | x*jnp.ones((1, 1, 1, 1)), 434 | [1, 1], 'VALID'), 435 | 2, 436 | 3., 437 | (0, 5), 438 | False, 439 | ( 440 | 9.*np.ones((1, 1, 1, 1)), 441 | 6.*np.ones((1, 1, 1, 1)), 442 | np.ones((1, 1, 1, 1)) 443 | ) 444 | ), 445 | ( 446 | 'conv_general_dilated_b', 447 | # This returns 5 * x**2 * ones((1,1,1,1)). 448 | # pylint: disable=g-long-lambda 449 | lambda x: jax.lax.conv_general_dilated(x*jnp.ones((1, 5, 1, 1)), 450 | x*jnp.ones((1, 5, 1, 1)), 451 | [1, 1], 'VALID'), 452 | 2, 453 | 3., 454 | (0, 5), 455 | False, 456 | ( 457 | 9.*5*np.ones((1, 1, 1, 1)), 458 | 6.*5*np.ones((1, 1, 1, 1)), 459 | 5*np.ones((1, 1, 1, 1)) 460 | ) 461 | ), 462 | ( 463 | 'conv_general_dilated_c', 464 | # This returns 2 * x**2 * ones((2,3,5,6)). 465 | # pylint: disable=g-long-lambda 466 | lambda x: jax.lax.conv_general_dilated(x*jnp.ones((2, 1, 5, 7)), 467 | x*jnp.ones((3, 1, 1, 2)), 468 | [1, 1], 'VALID'), 469 | 2, 470 | 3., 471 | (0, 5), 472 | False, 473 | ( 474 | 9.*2*np.ones((2, 3, 5, 6)), 475 | 6.*2*np.ones((2, 3, 5, 6)), 476 | 2*np.ones((2, 3, 5, 6)) 477 | ) 478 | ), 479 | # TODO(mstreeter): test convolutions where the input is not a scalar. 
480 | ( 481 | 'jax_nn_sigmoid', 482 | jax.nn.sigmoid, 483 | 0, 484 | 0., 485 | (-1., 1.), 486 | False, 487 | ((test_utils.sigmoid(-1.), test_utils.sigmoid(1.)),) 488 | ), 489 | ( 490 | 'jax_nn_softplus', 491 | jax.nn.softplus, 492 | 0, 493 | 0., 494 | (-1., 1.), 495 | False, 496 | ((test_utils.softplus(-1.), test_utils.softplus(1.)),) 497 | ), 498 | ( 499 | 'jax_nn_swish', 500 | jax.nn.swish, 501 | 0, 502 | 2., 503 | (1., 3.), 504 | False, 505 | ((test_utils.swish(1), test_utils.swish(3)),) 506 | ), 507 | ( 508 | 'jax_nn_sigmoid_ndarray', 509 | jax.nn.sigmoid, 510 | 0, 511 | np.array([0., 1.]), 512 | (np.array([-1., -2.]), np.array([1., 2.])), 513 | False, 514 | ( 515 | ( 516 | np.array([test_utils.sigmoid(-1.), test_utils.sigmoid(-2.)]), 517 | np.array([test_utils.sigmoid(1.), test_utils.sigmoid(2.)]), 518 | ), 519 | ) 520 | ), 521 | ( 522 | 'jax_nn_softplus_ndarray', 523 | jax.nn.softplus, 524 | 0, 525 | np.array([0., 1.]), 526 | (np.array([-1., -2.]), np.array([1., 2.])), 527 | False, 528 | ( 529 | ( 530 | np.array([test_utils.softplus(-1.), 531 | test_utils.softplus(-2.)]), 532 | np.array([test_utils.softplus(1.), test_utils.softplus(2.)]) 533 | ), 534 | ) 535 | ), 536 | ( 537 | 'jax_nn_swish_ndarray', 538 | jax.nn.swish, 539 | 0, 540 | np.array([2., 3.]), 541 | (np.array([1., 2.]), np.array([3., 4.])), 542 | False, 543 | ( 544 | ( 545 | np.array([test_utils.swish(1.), test_utils.swish(2.)]), 546 | np.array([test_utils.swish(3.), test_utils.swish(4.)]) 547 | ), 548 | ) 549 | ), 550 | ) 551 | def test_taylor_bounds( 552 | self, f, max_degree, test_x0, test_trust_region, propagate_trust_regions, 553 | expected_coefficients): 554 | if callable(test_x0): 555 | test_x0 = test_x0() 556 | if callable(test_trust_region): 557 | test_trust_region = test_trust_region() 558 | g = jax_bound.taylor_bounds(f, max_degree, propagate_trust_regions) 559 | actual_coefficients = g(test_x0, test_trust_region).coefficients 560 | self.assert_enclosure_equal(expected_coefficients, actual_coefficients, 561 | rtol=1e-6) 562 | 563 | @classmethod 564 | def setUpClass(cls): 565 | super().setUpClass() 566 | jax_bound.register_elementwise_primitive( 567 | my_softplus_p, 568 | primitive_enclosures.softplus_enclosure 569 | ) 570 | 571 | 572 | if __name__ == '__main__': 573 | absltest.main() 574 | -------------------------------------------------------------------------------- /autobound/enclosure_arithmetic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Code for performing arithmetic on Taylor enclosures.""" 16 | 17 | import functools 18 | from typing import Callable, Optional, Union 19 | 20 | from autobound import interval_arithmetic 21 | from autobound import polynomials 22 | from autobound import primitive_enclosures 23 | # pylint: disable=g-multiple-import 24 | from autobound.types import ( 25 | ElementwiseTaylorEnclosure, ElementwiseTaylorEnclosureLike, Interval, 26 | IntervalLike, NDArray, NDArrayLike, NumpyLike, TaylorEnclosure, 27 | TaylorEnclosureLike) 28 | 29 | 30 | class TaylorEnclosureArithmetic: 31 | """Taylor enclosure arithmetic via a numpy-like back end. 32 | 33 | Given a specified trust region, maximum polynomial degree D, and NumpyLike 34 | back end, this class allows you to perform arithmetic operations on 35 | TaylorEnclosures, and get back TaylorEnclosures of degree <= D. 36 | 37 | What does it mean to perform arithmetic operations on Taylor enclosures? 38 | Let op(arg_0(x), arg_1(x)) be some binary operation (such as addition or 39 | multiplication). Let enclosure_i be a Taylor enclosure for arg_i: 40 | 41 | arg_i(x) in enclosure_i(x) for x-x0 in trust_region. 42 | 43 | Then, op(enclosure_0, enclosure_1) is a Taylor enclosure f, such that: 44 | 45 | op(arg_0(x), arg_1(x)) in f(x) for x-x0 in trust_region. 46 | 47 | Example usage: 48 | import jax.numpy as jnp 49 | trust_region = (jnp.zeros((3,)), jnp.array([1, 2, 3])) 50 | arithmetic = TaylorEnclosureArithmetic(2, trust_region, jnp) 51 | arithmetic.power( 52 | TaylorEnclosure((0, 1), 53 | 4 54 | ) # ==> quadratic TaylorEnclosure of x**4, valid for x in trust_region 55 | """ 56 | 57 | def __init__(self, 58 | max_degree: int, 59 | trust_region: IntervalLike, 60 | np_like: NumpyLike): 61 | """Initializer. 62 | 63 | Args: 64 | max_degree: the maximum degree polynomial output by any of the functions 65 | in this class 66 | trust_region: an interval that contains x-x0 (not x!). 67 | np_like: a NumpyLike Module 68 | """ 69 | if np_like.shape(trust_region[0]) != np_like.shape(trust_region[1]): 70 | raise ValueError(trust_region) 71 | if not isinstance(max_degree, int): 72 | raise ValueError(max_degree) 73 | 74 | self.max_degree = max_degree 75 | self.set_arithmetic = interval_arithmetic.IntervalArithmetic(np_like) 76 | self.trust_region = self.set_arithmetic.as_interval(trust_region) 77 | self.np_like = np_like 78 | 79 | def add(self, 80 | a: TaylorEnclosureLike, 81 | b: TaylorEnclosureLike) -> TaylorEnclosure: 82 | """Returns the (possibly-truncated) sum of two TaylorEnclosures.""" 83 | a = as_taylor_enclosure(a, self.np_like) 84 | b = as_taylor_enclosure(b, self.np_like) 85 | 86 | def get_coefficient(i): 87 | if i < len(a) and i < len(b): 88 | return self.set_arithmetic.add(a[i], b[i]) 89 | elif i < len(a): 90 | return a[i] 91 | else: 92 | return b[i] 93 | 94 | sum_ab = tuple(get_coefficient(i) for i in range(max(len(a), len(b)))) 95 | return self._truncate_if_necessary(sum_ab) 96 | 97 | def arbitrary_bilinear( 98 | self, 99 | a: TaylorEnclosureLike, 100 | b: TaylorEnclosureLike, 101 | pairwise_batched_bilinear: Callable[[NDArray, NDArray, int, int], NDArray] 102 | ) -> TaylorEnclosure: 103 | """Applies an arbitrary bilinear operation to two TaylorEnclosures. 
104 | 
105 |     Args:
106 |       a: a TaylorEnclosure
107 |       b: a TaylorEnclosure
108 |       pairwise_batched_bilinear: a callable that takes parameters (u, v, p, q),
109 |         and returns the result of applying some underlying operation 'bilinear'
110 |         to various pairs of arguments, where the first argument is indexed by
111 |         the last p dimensions of u, and the second argument is indexed by the
112 |         last q dimensions of v.
113 | 
114 |         In the special case p=q=1,
115 | 
116 |           pairwise_batched_bilinear(u, v, 1, 1)[..., i, j]
117 | 
118 |             == bilinear(u[..., i], v[..., j]) for all i, j.
119 | 
120 |         In the general case, for all tuples I = (i_1, i_2, ..., i_p), and
121 |         J = (j_1, j_2, ..., j_q),
122 | 
123 |           pairwise_batched_bilinear(u, v, p, q)[(...,) + I + J]
124 | 
125 |             == bilinear(u[(...,) + I], v[(...,) + J]) .
126 | 
127 |     Returns:
128 |       a TaylorEnclosure c, such that for x-x0 in self.trust_region:
129 | 
130 |         arg_0 in a(x) and arg_1 in b(x) ==> bilinear(arg_0, arg_1) in c(x)
131 | 
132 |       where 'bilinear' is pairwise_batched_bilinear's underlying operation.
133 |     """
134 |     # Our goal is to compute a Taylor enclosure for bilinear(a, b), where:
135 |     #
136 |     #   a = sum_i <a[i], (x-x0)^i>
137 |     #   b = sum_j <b[j], (x-x0)^j>
138 |     #
139 |     # Here, we use ^ to denote self.set_arithmetic.outer_power, and we use
140 |     # <u, v> to denote self.set_arithmetic.tensordot(u, v, v.ndim).
141 |     #
142 |     # Using bilinearity, we can show:
143 |     #
144 |     #   bilinear(a, b)
145 |     #   == sum_{i, j} bilinear(<a[i], (x-x0)^i>, <b[j], (x-x0)^j>).
146 |     #
147 |     # Thus, it suffices to compute a Taylor enclosure for each term, and then
148 |     # sum those up.
149 |     #
150 |     # How do we compute a Taylor enclosure for each term? It can be shown that:
151 |     #
152 |     #   bilinear(<a[i], (x-x0)^i>, <b[j], (x-x0)^j>)
153 |     #   == <pairwise_batched_bilinear(a[i], b[j], i*x_ndim, j*x_ndim),
154 |     #       (x-x0)^(i+j)> .
155 |     #
156 |     # If i + j <= self.max_degree, we simply use this formula to return
157 |     # the coefficient for term (i, j). Otherwise, we make use of the trust
158 |     # region to enclose the left hand side in terms of (x-x0)^self.max_degree.
159 | 
160 |     x_ndim = self.np_like.ndim(self.trust_region[0])
161 | 
162 |     def get_term_enclosure_coefficient(
163 |         i: int, j: int) -> Union[NDArray, Interval]:
164 |       r"""Returns coefficient of enclosure for term coming from (a[i], b[j]).
165 | 
166 |       Args:
167 |         i: an index into a
168 |         j: an index into b
169 | 
170 |       Returns:
171 |         an NDArray or Interval c, such that:
172 | 
173 |           bilinear(inner(a[i], z^i), inner(b[j], z^j)) \subseteq inner(c, z^k)
174 | 
175 |         where k = min(i+j, self.max_degree).
176 |       """
177 |       if i + j <= self.max_degree:
178 |         c0, c1, c0_power, c1_power = (a[i], b[j], i, j)
179 |       else:
180 |         excess_degree = i + j - self.max_degree
181 |         # Distribute the excess degree between the two factors.
182 |         excess_i = min(i, excess_degree)
183 |         excess_j = excess_degree - excess_i
184 |         assert (i - excess_i) + (j - excess_j) == self.max_degree
185 |         c0 = self.set_arithmetic.tensordot(
186 |             a[i],
187 |             self.set_arithmetic.outer_power(self.trust_region, excess_i),
188 |             excess_i * x_ndim
189 |         )
190 |         if excess_j > 0:
191 |           c1 = self.set_arithmetic.tensordot(
192 |               b[j],
193 |               self.set_arithmetic.outer_power(self.trust_region, excess_j),
194 |               excess_j * x_ndim
195 |           )
196 |         else:
197 |           c1 = b[j]
198 |         c0_power = i - excess_i
199 |         c1_power = j - excess_j
200 | 
201 |       def bilinear(y, z):
202 |         return pairwise_batched_bilinear(y, z, c0_power*x_ndim,
203 |                                          c1_power*x_ndim)
204 | 
205 |       return self.set_arithmetic.arbitrary_bilinear(c0, c1, bilinear)
206 | 
207 |     # Sum up the coefficients from each term.
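    # For example, with self.max_degree == 2 and two degree-2 enclosures
    # (len(a) == len(b) == 3), output_degree is 2, and the (i, j) == (2, 2)
    # term (raw degree 4) has both excess trust-region factors contracted into
    # a[2] by get_term_enclosure_coefficient, so its coefficient is added to
    # product_coefficients[2], along with every other term having i + j >= 2.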
208 | output_degree = min(self.max_degree, len(a) + len(b) - 2) 209 | product_coefficients = [self.np_like.asarray(0)] * (output_degree + 1) 210 | for i in range(len(a)): 211 | for j in range(len(b)): 212 | term_coefficient = get_term_enclosure_coefficient(i, j) 213 | term_degree = min(i+j, self.max_degree) 214 | product_coefficients[term_degree] = self.set_arithmetic.add( 215 | product_coefficients[term_degree], 216 | term_coefficient, 217 | ) 218 | 219 | return TaylorEnclosure(tuple(product_coefficients)) 220 | 221 | def compose_enclosures( 222 | self, 223 | elementwise_enclosure: ElementwiseTaylorEnclosureLike, 224 | arg_enclosure: TaylorEnclosureLike) -> TaylorEnclosure: 225 | """Returns composition of two enclosures.""" 226 | # TODO(mstreeter): we can potentially do something more efficient than 227 | # computing each term separately and summing the results. 228 | if not elementwise_enclosure: 229 | raise ValueError() 230 | arg_diff_enclosure = (0,) + arg_enclosure[1:] 231 | output = TaylorEnclosure((self.np_like.array(0),)) 232 | 233 | def interval_left_broadcasting_multiply(a, b): 234 | bilinear = functools.partial(_left_broadcasting_multiply, 235 | np_like=self.np_like) 236 | return self.set_arithmetic.arbitrary_bilinear(a, b, bilinear, 237 | assume_product=True) 238 | 239 | for p, coefficient in enumerate(elementwise_enclosure): 240 | if p == 0: 241 | term = (coefficient,) 242 | else: 243 | poly = self.power(arg_diff_enclosure, p) 244 | term = tuple( 245 | # The special-casing when i < p ensures that the TaylorEnclosure 246 | # returned by this function will not contain trivial intervals 247 | # the form (x, x). 248 | # 249 | # The issue is that the first p elements of 'poly' are guaranteed 250 | # to be 0, but the expression below can express this as the interval 251 | # (0, 0). 252 | 0 if i < p else interval_left_broadcasting_multiply(coefficient, t) 253 | for i, t in enumerate(poly) 254 | ) 255 | output = self.add(output, term) 256 | assert output is not None 257 | return output 258 | 259 | def divide(self, 260 | a: TaylorEnclosureLike, 261 | b: TaylorEnclosureLike) -> TaylorEnclosure: 262 | return self.multiply(a, self.power(b, -1)) 263 | 264 | def get_elementwise_fun( 265 | self, 266 | get_elementwise_enclosure: Callable[ 267 | [NDArray, Interval, int, NumpyLike], 268 | ElementwiseTaylorEnclosure 269 | ]): 270 | """Returns elementwise function that inputs/output TaylorEnclosures.""" 271 | def fun( 272 | arg_enclosure: TaylorEnclosureLike, 273 | arg_trust_region: Optional[Union[NDArrayLike, IntervalLike]] = None 274 | ) -> TaylorEnclosure: 275 | if arg_trust_region is None: 276 | # If arg_trust_region is not provided derive it from arg_enclosure. 
277 | degree_0_enclosure = enclose_enclosure(arg_enclosure, self.trust_region, 278 | 0, self.np_like) 279 | arg_trust_region = degree_0_enclosure[0] 280 | arg_trust_region = self.set_arithmetic.as_interval_or_ndarray( 281 | arg_trust_region) 282 | if not isinstance(arg_trust_region, tuple): 283 | arg_trust_region = (arg_trust_region, arg_trust_region) 284 | x0 = arg_enclosure[0] 285 | if isinstance(x0, tuple): 286 | assert self.max_degree == 0 287 | assert len(x0) == 2 288 | x0 = x0[0] 289 | elementwise_enclosure = get_elementwise_enclosure( 290 | x0, 291 | arg_trust_region, 292 | self.max_degree, 293 | self.np_like) 294 | return self.compose_enclosures(elementwise_enclosure, arg_enclosure) 295 | return fun 296 | 297 | def multiply(self, 298 | a: TaylorEnclosureLike, 299 | b: TaylorEnclosureLike) -> TaylorEnclosure: 300 | """Returns elementwise product of two TaylorEnclosures.""" 301 | self._validate_taylor_enclosure(a) 302 | self._validate_taylor_enclosure(b) 303 | term_product_coefficient = functools.partial( 304 | _elementwise_term_product_coefficient, 305 | x_ndim=self.np_like.ndim(self.trust_region[0]), 306 | np_like=self.np_like) 307 | product = polynomials.arbitrary_bilinear( 308 | a, 309 | b, 310 | self.set_arithmetic.add, 311 | self.np_like.asarray(0), 312 | term_product_coefficient 313 | ) 314 | return self._truncate_if_necessary(product) 315 | 316 | def negative(self, a: TaylorEnclosureLike) -> TaylorEnclosure: 317 | return self._truncate_if_necessary( 318 | TaylorEnclosure(tuple(self.set_arithmetic.negative(c) for c in a))) 319 | 320 | def power(self, a: TaylorEnclosureLike, p: float) -> TaylorEnclosure: 321 | """Returns a TaylorEnclosure for a**p, of degree <= self.max_degree.""" 322 | self._validate_taylor_enclosure(a) 323 | if p >= 0 and p == int(p): 324 | x_ndim = self.np_like.ndim(self.trust_region[0]) 325 | term_product_coefficient = functools.partial( 326 | _elementwise_term_product_coefficient, x_ndim=x_ndim, 327 | np_like=self.np_like) 328 | term_power_coefficient = functools.partial( 329 | _elementwise_term_power_coefficient, x_ndim=x_ndim, 330 | np_like=self.np_like) 331 | multiplicative_identity = self.np_like.ones_like(self.trust_region[0]) 332 | result = polynomials.integer_power( # pytype: disable=wrong-arg-types 333 | a, 334 | p, 335 | self.set_arithmetic.add, 336 | self.np_like.asarray(0), 337 | multiplicative_identity, 338 | term_product_coefficient, 339 | term_power_coefficient, 340 | self.set_arithmetic.multiply 341 | ) 342 | return self._truncate_if_necessary(result) 343 | else: 344 | get_elementwise_enclosure = functools.partial( 345 | primitive_enclosures.pow_enclosure, p) 346 | return self.get_elementwise_fun(get_elementwise_enclosure)(a) 347 | 348 | def subtract(self, 349 | a: TaylorEnclosureLike, 350 | b: TaylorEnclosureLike) -> TaylorEnclosure: 351 | return self.add(a, self.negative(b)) 352 | 353 | def _truncate_if_necessary(self, a: TaylorEnclosureLike) -> TaylorEnclosure: 354 | return enclose_enclosure(a, self.trust_region, self.max_degree, 355 | self.np_like) 356 | 357 | def _validate_taylor_enclosure(self, a: TaylorEnclosureLike): 358 | x_shape = self.set_arithmetic.shape(self.trust_region) 359 | for i, coeff in enumerate(a): 360 | s = self.set_arithmetic.shape(coeff) 361 | if s[len(s)-i*len(x_shape):] != i*x_shape: 362 | raise ValueError(x_shape, i, s, coeff, a) 363 | 364 | 365 | def as_taylor_enclosure(a: TaylorEnclosureLike, 366 | np_like: NumpyLike) -> TaylorEnclosure: 367 | set_arithmetic = interval_arithmetic.IntervalArithmetic(np_like) 368 
|   return TaylorEnclosure(
369 |       tuple(set_arithmetic.as_interval_or_ndarray(c) for c in a))
370 | 
371 | 
372 | def enclose_enclosure(
373 |     enclosure: TaylorEnclosureLike,
374 |     trust_region: IntervalLike,
375 |     max_degree: int,
376 |     np_like: NumpyLike,
377 | ) -> TaylorEnclosure:
378 |   """Returns a (possibly) lower-degree enclosure of a given TaylorEnclosure."""
379 |   set_arithmetic = interval_arithmetic.IntervalArithmetic(np_like)
380 |   trust_region = set_arithmetic.as_interval(trust_region)
381 |   enclosure = as_taylor_enclosure(enclosure, np_like)
382 |   orig_degree = len(enclosure) - 1
383 |   if orig_degree <= max_degree:
384 |     return enclosure
385 |   else:
386 |     new_final_coefficient = polynomials.eval_taylor_enclosure(
387 |         enclosure[max_degree:], trust_region, set_arithmetic.np_like)
388 |     return TaylorEnclosure(enclosure[:max_degree] + (new_final_coefficient,))
389 | 
390 | 
391 | def expand_multiple_dims(a: NDArray, n: int, axis=None) -> NDArray:
392 |   """Like expand_dims, but adds n dims rather than just 1."""
393 |   if axis is None:
394 |     axis = a.ndim
395 |   colon = slice(None, None, None)
396 |   return a[(colon,) * axis + (None,) * n + (...,)]
397 | 
398 | 
399 | def map_over_enclosure(
400 |     a: TaylorEnclosure,
401 |     fun: Callable[[NDArray], NDArray]) -> TaylorEnclosure:
402 |   """Apply a function to each NDArray in a TaylorEnclosure."""
403 |   return TaylorEnclosure(
404 |       tuple(tuple(map(fun, c)) if isinstance(c, tuple) else fun(c) for c in a))
405 | 
406 | 
407 | def _elementwise_term_power_coefficient(
408 |     c: Union[NDArrayLike, IntervalLike],
409 |     i: int,
410 |     exponent: int,
411 |     x_ndim: int,
412 |     np_like: NumpyLike) -> Union[NDArray, Interval]:
413 |   """Returns d such that <c, z**i>^exponent == <d, z**(i*exponent)>.
414 | 
415 |   Args:
416 |     c: a coefficient
417 |     i: a non-negative integer
418 |     exponent: a non-negative integer
419 |     x_ndim: the number of dimensions in the independent variable
420 |     np_like: a Numpy-like backend
421 | 
422 |   Returns:
423 |     an NDArray or Interval d, such that
424 |       <c, z**i>^exponent == <d, z**(i*exponent)>,
425 |     where ** denotes outer product, and ^ denotes elementwise exponentiation.
426 |   """
427 |   set_arithmetic = interval_arithmetic.IntervalArithmetic(np_like)
428 |   batch_dims = set_arithmetic.ndim(c) - i*x_ndim
429 |   if batch_dims < 0:
430 |     raise ValueError((set_arithmetic.ndim(c), i, x_ndim))
431 |   return set_arithmetic.outer_power(c, exponent, batch_dims)
432 | 
433 | 
434 | def _elementwise_term_product_coefficient(
435 |     c0: Union[NDArrayLike, IntervalLike],
436 |     c1: Union[NDArrayLike, IntervalLike],
437 |     i: int,
438 |     j: int,
439 |     x_ndim: int,
440 |     np_like: NumpyLike) -> Union[NDArray, Interval]:
441 |   """Returns d such that <c0, z**i> * <c1, z**j> == <d, z**(i+j)>."""
442 |   def product(u, v):
443 |     return _pairwise_batched_multiply(u, v, i*x_ndim, j*x_ndim, np_like)
444 |   set_arithmetic = interval_arithmetic.IntervalArithmetic(np_like)
445 |   return set_arithmetic.arbitrary_bilinear(c0, c1, product, assume_product=True)
446 | 
447 | 
448 | def _pairwise_batched_multiply(
449 |     u: NDArrayLike,
450 |     v: NDArrayLike,
451 |     p: int,
452 |     q: int,
453 |     np_like: NumpyLike) -> NDArray:
454 |   """Batched version of multiply, for use as input to arbitrary_bilinear().
455 | 
456 |   See the docstring for TaylorEnclosureArithmetic.arbitrary_bilinear for
457 |   context.
458 | 459 | Args: 460 | u: an NDArray of dimension at least p 461 | v: an NDArray of dimension at least q 462 | p: a non-negative integer 463 | q: a non-negative integer 464 | np_like: a NumpyLike back end 465 | 466 | Returns: 467 | an NDArray 'output', such that for every pair of tuples 468 | I = (i_1, i_2, ..., i_p) and J = (j_1, j_2, ..., j_q), 469 | 470 | output[(...,) + I + J] = u[(...,) + I] * v[(...,) + J] . 471 | """ 472 | u = np_like.asarray(u) 473 | v = np_like.asarray(v) 474 | return expand_multiple_dims(u, q) * expand_multiple_dims(v, p, v.ndim-q) 475 | 476 | 477 | def _left_broadcasting_multiply(a: NDArrayLike, b: NDArrayLike, 478 | np_like: NumpyLike) -> NDArray: 479 | """Multiplies a and b, broadcasting over leftmost dimensions.""" 480 | a = np_like.asarray(a) 481 | b = np_like.asarray(b) 482 | if a.ndim > b.ndim: 483 | raise NotImplementedError() 484 | return expand_multiple_dims(a, b.ndim - a.ndim) * b 485 | -------------------------------------------------------------------------------- /autobound/jax/jax_bound.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The autobound Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Code for computing Taylor enclosures in JAX.""" 16 | 17 | import dataclasses 18 | import functools 19 | from typing import Callable, Dict, List, Optional, Set, Tuple, Union 20 | 21 | from autobound import enclosure_arithmetic 22 | from autobound import interval_arithmetic 23 | from autobound import polynomials 24 | from autobound import primitive_enclosures 25 | from autobound import types 26 | from autobound.jax import jaxpr_editor 27 | import jax 28 | import jax.numpy as jnp 29 | 30 | 31 | @dataclasses.dataclass 32 | class TaylorBounds: 33 | """Upper and lower bounds on a function, valid over a trust region.""" 34 | f: Callable[[jnp.array], jnp.array] 35 | x0: jnp.ndarray 36 | # Interval containing values of x (not x-x0) for which the bound on f(x) 37 | # holds. 
38 | x_trust_region: Tuple[jnp.array, jnp.array] 39 | coefficients: types.TaylorEnclosure 40 | 41 | def __call__(self, 42 | x: types.NDArrayLike) -> Union[types.NDArray, types.Interval]: 43 | x = jnp.asarray(x) 44 | return polynomials.eval_taylor_enclosure(self.coefficients, x-self.x0, jnp) 45 | 46 | def final_interval(self) -> Tuple[jnp.array, jnp.array]: 47 | """Returns final coefficient (as a trivial interval, if it is scalar).""" 48 | c = self.coefficients[-1] 49 | return c if isinstance(c, tuple) else (c, c) 50 | 51 | def lower(self, x): 52 | bound = self(x) 53 | return bound[0] if isinstance(bound, tuple) else bound 54 | 55 | def upper(self, x): 56 | bound = self(x) 57 | return bound[1] if isinstance(bound, tuple) else bound 58 | 59 | 60 | def taylor_bounds( 61 | f: Callable[[jnp.array], jnp.array], 62 | max_degree: int, 63 | propagate_trust_regions: bool = False, 64 | ) -> Callable[[jnp.array, Tuple[jnp.array, jnp.array]], TaylorBounds]: 65 | """Returns version of f that returns a TaylorBounds object. 66 | 67 | Args: 68 | f: a function that takes a jnp.array as input, and returns a jnp.array 69 | max_degree: the maximum degree TaylorEnclosure for the returned function 70 | to return 71 | propagate_trust_regions: if True, trust regions are propagated 72 | through the Jaxpr, rather than being computed from higher-degree 73 | enclosures. This results in tighter bounds at the cost of additional 74 | memory. 75 | 76 | Returns: 77 | a function that takes as input a jnp.array x0, and a trust region 78 | (min_vals, max_vals), and return a TaylorBounds object `bound` such that 79 | `bound.coefficients` is a TaylorEnclosure g of degree at most max_degree, 80 | such that: 81 | 82 | f(x) in g(x-x0) for all x with min_vals <= x <= max_vals 83 | """ 84 | if max_degree < 0: 85 | raise ValueError(max_degree) 86 | if max_degree == 0: 87 | propagate_trust_regions = False # avoid redundant computation 88 | 89 | jaxpr_factory = jax.make_jaxpr(f) 90 | def bound_fun(x0: jnp.array, 91 | x_trust_region: types.Interval) -> TaylorBounds: 92 | trust_region = interval_arithmetic.IntervalArithmetic(jnp).subtract( 93 | x_trust_region, x0) 94 | 95 | arithmetic = enclosure_arithmetic.TaylorEnclosureArithmetic( 96 | max_degree, trust_region, jnp) 97 | primitive_to_enclosure_fun = _pushforward_funs(arithmetic) 98 | 99 | degree_0_arithmetic = ( 100 | enclosure_arithmetic.TaylorEnclosureArithmetic(0, trust_region, jnp)) 101 | primitive_to_enclosure_fun0 = _pushforward_funs(degree_0_arithmetic) 102 | 103 | closed_jaxpr = jaxpr_factory(x0) 104 | jaxpr = _rewrite_jaxpr(closed_jaxpr.jaxpr) 105 | 106 | x0 = jnp.asarray(x0) 107 | if x0.ndim == 0: 108 | identity = jnp.asarray(1.) 
109 | elif x0.ndim == 1: 110 | identity = jnp.eye(x0.shape[0]) 111 | else: 112 | raise NotImplementedError(x0.ndim) 113 | x0_enclosure = types.TaylorEnclosure( 114 | (x0, identity) if max_degree > 0 else (x_trust_region,)) 115 | assert len(closed_jaxpr.consts) == len(jaxpr.constvars) 116 | var_to_intermediate = { 117 | var: _constant_intermediate_enclosure(val) 118 | for var, val in zip(jaxpr.constvars, closed_jaxpr.consts) 119 | } 120 | assert len(jaxpr.invars) == 1 121 | var_to_intermediate[jaxpr.invars[0]] = _IntermediateEnclosure( 122 | enclosure=x0_enclosure, 123 | trust_region=x_trust_region if propagate_trust_regions else None 124 | ) 125 | 126 | def get_intermediate( 127 | invar: Union[jax.core.Var, jax.core.Literal]) -> _IntermediateEnclosure: 128 | if isinstance(invar, jax.core.Var): 129 | return var_to_intermediate[invar] 130 | else: 131 | assert isinstance(invar, jax.core.Literal) 132 | return _constant_intermediate_enclosure(invar.val) 133 | 134 | for eqn in jaxpr.eqns: 135 | invar_intermediates = [get_intermediate(invar) for invar in eqn.invars] 136 | has_non_constant_invars = any(not intermediate.is_constant() 137 | for intermediate in invar_intermediates) 138 | if has_non_constant_invars: 139 | fun = primitive_to_enclosure_fun.get(eqn.primitive) 140 | if fun is None: 141 | raise NotImplementedError(eqn.primitive) 142 | outvar_enclosures = fun(*invar_intermediates, **eqn.params) 143 | if len(eqn.outvars) == 1: 144 | outvar_enclosures = (outvar_enclosures,) 145 | if propagate_trust_regions: 146 | fun0 = primitive_to_enclosure_fun0.get(eqn.primitive) 147 | assert fun0 is not None 148 | assert all(i.trust_region is not None for i in invar_intermediates) 149 | invar_degree_0_intermediates = [ 150 | _IntermediateEnclosure( 151 | enclosure=types.TaylorEnclosure((intermediate.trust_region,))) 152 | for intermediate in invar_intermediates 153 | ] 154 | outvar_degree_0_enclosures_a = fun0(*invar_degree_0_intermediates, 155 | **eqn.params) 156 | if len(eqn.outvars) == 1: 157 | outvar_degree_0_enclosures_a = [outvar_degree_0_enclosures_a] 158 | assert len(outvar_degree_0_enclosures_a) == len(outvar_enclosures) 159 | outvar_degree_0_enclosures_b = [ 160 | enclosure_arithmetic.enclose_enclosure(enclosure, trust_region, 161 | 0, jnp) 162 | for enclosure in outvar_enclosures 163 | ] 164 | outvar_trust_regions = [ 165 | _intersect_intervals(a[0], b[0]) 166 | for a, b in zip(outvar_degree_0_enclosures_a, 167 | outvar_degree_0_enclosures_b) 168 | ] 169 | for i, (a, b) in enumerate(outvar_trust_regions): 170 | # It should always be the case that the actual value of the ith 171 | # output of a function (y0 below) is inside the associated trust 172 | # region. But this invariant may not hold due to floating 173 | # point roundoff error, so we enforce it here. 174 | # 175 | # TODO(mstreeter): add a test case that fails if we remove this. 
176 | y0 = outvar_enclosures[i][0] 177 | outvar_trust_regions[i] = (jnp.minimum(y0, a), jnp.maximum(y0, b)) # pytype: disable=wrong-arg-types 178 | else: 179 | outvar_trust_regions = (None,) * len(outvar_enclosures) 180 | assert all(isinstance(v, tuple) for v in outvar_enclosures), ( 181 | eqn.primitive, fun, outvar_enclosures) 182 | outvar_intermediates = tuple( 183 | _IntermediateEnclosure(enclosure=e, trust_region=r) 184 | for r, e in zip(outvar_trust_regions, outvar_enclosures) 185 | ) 186 | else: 187 | invar_values = tuple(intermediate.constant_value() 188 | for intermediate in invar_intermediates) 189 | vals = eqn.primitive.bind(*invar_values, **eqn.params) 190 | if len(eqn.outvars) == 1: 191 | vals = (vals,) 192 | outvar_intermediates = [_constant_intermediate_enclosure(v) 193 | for v in vals] 194 | 195 | assert len(outvar_intermediates) == len(eqn.outvars), ( 196 | eqn.primitive, len(outvar_intermediates), len(eqn.outvars)) 197 | for var, intermediate in zip(eqn.outvars, outvar_intermediates): 198 | assert var not in var_to_intermediate 199 | assert isinstance(intermediate.enclosure, tuple), ( 200 | eqn.primitive, intermediate) 201 | _validate_taylor_enclosure(intermediate.enclosure, x0.shape) 202 | var_to_intermediate[var] = intermediate 203 | 204 | assert len(jaxpr.outvars) == 1 205 | output_intermediate = get_intermediate(jaxpr.outvars[0]) 206 | return TaylorBounds(f=f, x0=x0, x_trust_region=x_trust_region, 207 | coefficients=output_intermediate.enclosure) 208 | 209 | return bound_fun 210 | 211 | 212 | # Type for functions that generate enclosures for primitive elementwise 213 | # functions. The callable takes arguments x0, trust_region, degree, and 214 | # np_like, and returns an ElementwiseTaylorEnclosure (see examples in 215 | # primitive_enclosures.py). 216 | ElementwiseEnclosureGeneratingFunction = Callable[ 217 | [types.NDArray, types.Interval, int, types.NumpyLike], 218 | types.ElementwiseTaylorEnclosure 219 | ] 220 | 221 | 222 | # TODO(mstreeter): add a mechanism for supporting a new elementwise function 223 | # given its FunctionData. 224 | def register_elementwise_primitive( 225 | p: jax.core.Primitive, 226 | get_enclosure: ElementwiseEnclosureGeneratingFunction): 227 | """Register an enclosure-generating function for a user-defined primitive. 228 | 229 | Args: 230 | p: a jax.core.Primitive 231 | get_enclosure: an ElementwiseEnclosureGeneratingFunction for p. 232 | """ 233 | _ELEMENTWISE_PRIMITIVE_ENCLOSURES[p] = get_enclosure 234 | 235 | 236 | # 237 | # Private variables/functions. 238 | # 239 | 240 | 241 | _PRIMITIVE_NAMES = set() # type: Set[str] 242 | # Rewrite rules are callables that return (pattern Jaxpr, replacement Jaxpr) 243 | # pairs. We make them callables because the Jaxpr returned by jax.make_jaxpr 244 | # depends on how Jax is configured (in particular whether float64 is enabled), 245 | # and we need to use the Jaxprs that match whatever configuration is being 246 | # used when the rule is applied. 
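# For example, _register_elementwise_function (below) appends the rule
#
#   lambda: (jax.make_jaxpr(f)(0.).jaxpr, jax.make_jaxpr(p.bind)(0.).jaxpr)
#
# so that occurrences of the elementwise function f in a traced Jaxpr are
# rewritten to use the registered primitive p.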
247 | _JAXPR_REWRITE_RULES = [ 248 | ] # type: List[Callable[[], Tuple[jax.core.Jaxpr, jax.core.Jaxpr]]] 249 | 250 | 251 | def _register_elementwise_function( 252 | f: Callable[[jnp.array], jnp.array], 253 | get_enclosure: ElementwiseEnclosureGeneratingFunction 254 | ): 255 | """Register enclosure-generating function for elementwise Jax function.""" 256 | name = f'__autobound_{f.__name__}__' 257 | if name in _PRIMITIVE_NAMES: 258 | raise ValueError(f) 259 | _PRIMITIVE_NAMES.add(name) 260 | p = jax.core.Primitive(name) 261 | p.def_abstract_eval( 262 | lambda x: jax.abstract_arrays.ShapedArray(x.shape, x.dtype)) 263 | rule = lambda: (jax.make_jaxpr(f)(0.).jaxpr, jax.make_jaxpr(p.bind)(0.).jaxpr) 264 | _JAXPR_REWRITE_RULES.append(rule) 265 | register_elementwise_primitive(p, get_enclosure) 266 | 267 | 268 | # Dict from Jax Primitive to ElementwiseEnclosureGeneratingFunction. 269 | # TODO(mstreeter): support more elementwise functions. 270 | _ELEMENTWISE_PRIMITIVE_ENCLOSURES = { 271 | jax.lax.abs_p: primitive_enclosures.abs_enclosure, 272 | jax.lax.exp_p: primitive_enclosures.exp_enclosure, 273 | jax.lax.log_p: primitive_enclosures.log_enclosure, 274 | } 275 | _register_elementwise_function(jax.nn.sigmoid, 276 | primitive_enclosures.sigmoid_enclosure) 277 | _register_elementwise_function(jax.nn.softplus, 278 | primitive_enclosures.softplus_enclosure) 279 | _register_elementwise_function(jax.nn.swish, 280 | primitive_enclosures.swish_enclosure) 281 | 282 | 283 | # Set of primitives that can be applied separately to each coefficient of a 284 | # TaylorEnclosure. 285 | _PASS_THRU_PRIMITIVES = frozenset([ 286 | jax.lax.convert_element_type_p, 287 | jax.lax.reshape_p, 288 | jax.lax.reduce_sum_p, 289 | jax.lax.reduce_window_sum_p, 290 | jax.lax.squeeze_p, 291 | jax.lax.transpose_p, 292 | # TODO(mstreeter): add more of these 293 | ]) 294 | 295 | 296 | def _rewrite_jaxpr(jaxpr: jax.core.Jaxpr) -> jax.core.Jaxpr: 297 | """Rewrite a Jaxpr to make is suitable for use by taylor_bounds().""" 298 | for rule_generator in _JAXPR_REWRITE_RULES: 299 | pattern, replacement = rule_generator() 300 | jaxpr = jaxpr_editor.replace(pattern, replacement, jaxpr) 301 | return jaxpr 302 | 303 | 304 | @dataclasses.dataclass 305 | class _IntermediateEnclosure: 306 | """An enclosure for some intermediate variable in a Jaxpr.""" 307 | enclosure: types.TaylorEnclosure 308 | trust_region: Optional[types.Interval] = None 309 | 310 | def is_constant(self) -> bool: 311 | """Returns whether self.enclosure represents a constant value.""" 312 | return len(self.enclosure) == 1 and not isinstance(self.enclosure[0], tuple) 313 | 314 | def constant_value(self) -> types.NDArray: 315 | if not self.is_constant(): 316 | raise ValueError() 317 | else: 318 | return self.enclosure[0] # pytype: disable=bad-return-type 319 | 320 | 321 | def _broadcast_in_dim_pushforward_fun(intermediate, shape, 322 | broadcast_dimensions): 323 | """Enclosure-generating function for jax.lax.broadcast_in_dim.""" 324 | enclosure = intermediate.enclosure 325 | x0 = enclosure[0] 326 | if isinstance(x0, tuple): 327 | x0 = enclosure[0][0] 328 | x_shape = (() if len(enclosure) == 1 else enclosure[1].shape[x0.ndim:]) 329 | def broadcast_ndarray(a, i): 330 | return jax.lax.broadcast_in_dim(a, shape + i*x_shape, broadcast_dimensions) 331 | def broadcast_ndarray_or_interval(a, i): 332 | if isinstance(a, tuple): 333 | return tuple(broadcast_ndarray(x, i) for x in a) 334 | else: 335 | return broadcast_ndarray(a, i) 336 | return tuple( 337 | broadcast_ndarray_or_interval(coeff, 
i) 338 | for i, coeff in enumerate(enclosure) 339 | ) 340 | 341 | 342 | def _constant_intermediate_enclosure(val: types.NDArray): 343 | return _IntermediateEnclosure(enclosure=types.TaylorEnclosure((val,)), 344 | trust_region=(val, val)) 345 | 346 | 347 | def _conv_general_dilated_pushforward_fun(arithmetic): 348 | """Returns function that implements conv_general_dilated on enclosures.""" 349 | def fun(lhs_intermediate: _IntermediateEnclosure, 350 | rhs_intermediate: _IntermediateEnclosure, 351 | **params): 352 | def pairwise_batched_bilinear(a: jnp.array, b: jnp.array, 353 | p: int, q: int) -> jnp.array: 354 | def move_last_n_dims_to_front(x: jnp.array, n: int): 355 | if n == 0: 356 | return x 357 | perm = tuple(range(x.ndim - n, x.ndim)) + tuple(range(x.ndim - n)) 358 | transposed = jnp.transpose(x, axes=perm) 359 | return jnp.reshape(transposed, (-1,) + x.shape[:x.ndim-n]) 360 | 361 | a_reshaped = move_last_n_dims_to_front(a, p) 362 | b_reshaped = move_last_n_dims_to_front(b, q) 363 | 364 | c = jax.lax.conv_general_dilated_p.bind(a_reshaped, b_reshaped, **params) 365 | if p == 0 and q == 0: 366 | return c 367 | elif p == 0 or q == 0: 368 | raise NotImplementedError((p, q)) 369 | c_perm = tuple(range(2, c.ndim)) + (0, 1) 370 | c_transposed = jnp.transpose(c, axes=c_perm) 371 | return jnp.reshape(c_transposed, 372 | c.shape[2:] + a.shape[a.ndim-p:] + b.shape[b.ndim-q:]) 373 | 374 | return arithmetic.arbitrary_bilinear( 375 | lhs_intermediate.enclosure, 376 | rhs_intermediate.enclosure, 377 | pairwise_batched_bilinear) 378 | return fun 379 | 380 | 381 | def _dot_general_pushforward_fun(arithmetic): 382 | """Returns function that implements dot_general on enclosures.""" 383 | def fun(lhs_intermediate: _IntermediateEnclosure, 384 | rhs_intermediate: _IntermediateEnclosure, 385 | **params): 386 | a_contracting_dims = set(a # pylint: disable=g-complex-comprehension 387 | for t in params['dimension_numbers'] 388 | for a in t[0]) 389 | def pairwise_batched_bilinear(a: jnp.array, b: jnp.array, 390 | p: int, q: int) -> jnp.array: 391 | transposed_output = jax.lax.dot_general_p.bind(a, b, **params) 392 | p_start = a.ndim - p - len(a_contracting_dims) 393 | assert p_start >= 0 394 | n = transposed_output.ndim 395 | # Shift axes p_start through p_start+p to the right, so that they start 396 | # at position n-q. 
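      # For example, with p_start == 1, p == 1, q == 1, and n == 4, perm is
      # (0, 2, 1, 3).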
397 | assert p_start + p <= n-q, (p_start, p, q, n) 398 | perm = ( 399 | tuple(range(p_start)) + 400 | tuple(range(p_start+p, n-q)) + 401 | tuple(range(p_start, p_start + p)) + 402 | tuple(range(n-q, n)) 403 | ) 404 | assert len(set(perm)) == n, (p_start, p, q, n, perm) 405 | return jnp.transpose(transposed_output, axes=perm) 406 | return arithmetic.arbitrary_bilinear( 407 | lhs_intermediate.enclosure, 408 | rhs_intermediate.enclosure, 409 | pairwise_batched_bilinear 410 | ) 411 | return fun 412 | 413 | 414 | def _elementwise_pushforward_fun(arithmetic, get_enclosure): 415 | f = arithmetic.get_elementwise_fun(get_enclosure) 416 | def g(intermediate): 417 | return f(intermediate.enclosure, intermediate.trust_region) 418 | return g 419 | 420 | 421 | def _intersect_intervals( 422 | a: types.Interval, b: types.Interval) -> types.Interval: 423 | if not len(a) == len(b) == 2: 424 | raise ValueError() 425 | return (jnp.maximum(a[0], b[0]), jnp.minimum(a[1], b[1])) # pytype: disable=wrong-arg-types 426 | 427 | 428 | def _pass_thru_pushforward_fun(primitive): 429 | def fun(intermediate, **params): 430 | return enclosure_arithmetic.map_over_enclosure( 431 | intermediate.enclosure, 432 | functools.partial(primitive.bind, **params) 433 | ) 434 | return fun 435 | 436 | 437 | # A pushforward function for an underlying primitive with K inputs and N 438 | # outputs takes K _IntermediateEnclosure as arguments, plus kwargs for any 439 | # parameters the primitive has, and returns a tuple of N TaylorEnclosures (or 440 | # in the special case N=1, a single TaylorEnclosure rather than a tuple). 441 | PushforwardFunction = Callable[ 442 | ..., 443 | Union[types.TaylorEnclosure, Tuple[types.TaylorEnclosure, ...]] 444 | ] 445 | 446 | 447 | def _pushforward_funs( 448 | arithmetic: enclosure_arithmetic.TaylorEnclosureArithmetic 449 | ) -> Dict[jax.core.Primitive, PushforwardFunction]: 450 | """Returns dict from primitive to function that inputs/outputs enclosures.""" 451 | def pushforward_integer_pow(intermediate, y: int): 452 | return arithmetic.power(intermediate.enclosure, y) 453 | 454 | def pushforward_pow(intermediate_0, intermediate_1): 455 | if not intermediate_1.is_constant(): 456 | raise NotImplementedError() 457 | exponent = float(intermediate_1.constant_value()) 458 | return arithmetic.power(intermediate_0.enclosure, exponent) 459 | 460 | def wrap(f): 461 | def g(*args): 462 | return f(*[intermediate.enclosure for intermediate in args]) 463 | return g 464 | 465 | primitive_to_enclosure_fun = { 466 | jax.lax.add_p: wrap(arithmetic.add), 467 | jax.lax.div_p: wrap(arithmetic.divide), 468 | jax.lax.integer_pow_p: pushforward_integer_pow, 469 | jax.lax.mul_p: wrap(arithmetic.multiply), 470 | jax.lax.neg_p: wrap(arithmetic.negative), 471 | jax.lax.pow_p: pushforward_pow, 472 | jax.lax.sub_p: wrap(arithmetic.subtract), 473 | # TODO(mstreeter): handle all bilinear primitives in a uniform way. 
474 | jax.lax.dot_general_p: _dot_general_pushforward_fun(arithmetic), 475 | jax.lax.conv_general_dilated_p: _conv_general_dilated_pushforward_fun( 476 | arithmetic), 477 | jax.lax.broadcast_in_dim_p: _broadcast_in_dim_pushforward_fun, 478 | } 479 | primitive_to_enclosure_fun.update({ 480 | primitive: _elementwise_pushforward_fun(arithmetic, get_enclosure) 481 | for primitive, get_enclosure in _ELEMENTWISE_PRIMITIVE_ENCLOSURES.items() 482 | }) 483 | primitive_to_enclosure_fun.update({ 484 | primitive: _pass_thru_pushforward_fun(primitive) 485 | for primitive in _PASS_THRU_PRIMITIVES 486 | }) 487 | return primitive_to_enclosure_fun 488 | 489 | 490 | def _validate_taylor_enclosure(a: types.TaylorEnclosureLike, x_shape): 491 | set_arithmetic = interval_arithmetic.IntervalArithmetic(jnp) 492 | for i, coeff in enumerate(a): 493 | s = set_arithmetic.shape(coeff) 494 | if s[len(s)-i*len(x_shape):] != i*x_shape: 495 | raise ValueError(x_shape, i, s, coeff, a) 496 | --------------------------------------------------------------------------------
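A minimal end-to-end usage sketch (illustrative, not one of the repository files): it calls autobound.jax.taylor_bounds, whose signature and return type are documented in jax_bound.py above, to bound jnp.exp over a trust region. The choice of function, center point, and trust region is arbitrary.

import jax.numpy as jnp
from autobound.jax import taylor_bounds


def f(x):
  # exp is supported via the jax.lax.exp_p entry in
  # _ELEMENTWISE_PRIMITIVE_ENCLOSURES above.
  return jnp.exp(x)


# Degree-2 Taylor enclosure of f centered at x0 = 0.5, valid for 0 <= x <= 1.
bound_fun = taylor_bounds(f, max_degree=2)
bounds = bound_fun(0.5, (0., 1.))

# bounds.coefficients is a TaylorEnclosure whose final coefficient is an
# interval; bounds.lower(x) and bounds.upper(x) evaluate the resulting bounds
# at a point inside the trust region.
print(bounds.coefficients)
print(bounds.lower(0.8), bounds.upper(0.8))

This mirrors the call pattern exercised by test_taylor_bounds above.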