├── .github └── workflows │ ├── pypi-publish.yml │ └── tests.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── distrax ├── __init__.py ├── _src │ ├── __init__.py │ ├── bijectors │ │ ├── __init__.py │ │ ├── bijector.py │ │ ├── bijector_from_tfp.py │ │ ├── bijector_from_tfp_test.py │ │ ├── bijector_test.py │ │ ├── block.py │ │ ├── block_test.py │ │ ├── chain.py │ │ ├── chain_test.py │ │ ├── diag_linear.py │ │ ├── diag_linear_test.py │ │ ├── diag_plus_low_rank_linear.py │ │ ├── diag_plus_low_rank_linear_test.py │ │ ├── gumbel_cdf.py │ │ ├── gumbel_cdf_test.py │ │ ├── inverse.py │ │ ├── inverse_test.py │ │ ├── lambda_bijector.py │ │ ├── lambda_bijector_test.py │ │ ├── linear.py │ │ ├── linear_test.py │ │ ├── lower_upper_triangular_affine.py │ │ ├── lower_upper_triangular_affine_test.py │ │ ├── masked_coupling.py │ │ ├── masked_coupling_test.py │ │ ├── rational_quadratic_spline.py │ │ ├── rational_quadratic_spline_float64_test.py │ │ ├── rational_quadratic_spline_test.py │ │ ├── scalar_affine.py │ │ ├── scalar_affine_test.py │ │ ├── shift.py │ │ ├── shift_test.py │ │ ├── sigmoid.py │ │ ├── sigmoid_test.py │ │ ├── split_coupling.py │ │ ├── split_coupling_test.py │ │ ├── tanh.py │ │ ├── tanh_test.py │ │ ├── tfp_compatible_bijector.py │ │ ├── tfp_compatible_bijector_test.py │ │ ├── triangular_linear.py │ │ ├── triangular_linear_test.py │ │ ├── unconstrained_affine.py │ │ └── unconstrained_affine_test.py │ ├── distributions │ │ ├── __init__.py │ │ ├── bernoulli.py │ │ ├── bernoulli_test.py │ │ ├── beta.py │ │ ├── beta_test.py │ │ ├── categorical.py │ │ ├── categorical_test.py │ │ ├── categorical_uniform.py │ │ ├── categorical_uniform_test.py │ │ ├── clipped.py │ │ ├── clipped_test.py │ │ ├── deterministic.py │ │ ├── deterministic_test.py │ │ ├── dirichlet.py │ │ ├── dirichlet_test.py │ │ ├── distribution.py │ │ ├── distribution_from_tfp.py │ │ ├── distribution_from_tfp_test.py │ │ ├── distribution_test.py │ │ ├── epsilon_greedy.py │ │ ├── epsilon_greedy_test.py │ │ ├── gamma.py │ │ ├── gamma_test.py │ │ ├── greedy.py │ │ ├── greedy_test.py │ │ ├── gumbel.py │ │ ├── gumbel_test.py │ │ ├── independent.py │ │ ├── independent_test.py │ │ ├── joint.py │ │ ├── joint_test.py │ │ ├── laplace.py │ │ ├── laplace_test.py │ │ ├── log_stddev_normal.py │ │ ├── log_stddev_normal_test.py │ │ ├── logistic.py │ │ ├── logistic_test.py │ │ ├── mixture_of_two.py │ │ ├── mixture_of_two_test.py │ │ ├── mixture_same_family.py │ │ ├── mixture_same_family_test.py │ │ ├── multinomial.py │ │ ├── multinomial_test.py │ │ ├── mvn_diag.py │ │ ├── mvn_diag_plus_low_rank.py │ │ ├── mvn_diag_plus_low_rank_test.py │ │ ├── mvn_diag_test.py │ │ ├── mvn_from_bijector.py │ │ ├── mvn_from_bijector_test.py │ │ ├── mvn_full_covariance.py │ │ ├── mvn_full_covariance_test.py │ │ ├── mvn_kl_test.py │ │ ├── mvn_tri.py │ │ ├── mvn_tri_test.py │ │ ├── normal.py │ │ ├── normal_float64_test.py │ │ ├── normal_test.py │ │ ├── one_hot_categorical.py │ │ ├── one_hot_categorical_test.py │ │ ├── quantized.py │ │ ├── quantized_test.py │ │ ├── softmax.py │ │ ├── softmax_test.py │ │ ├── straight_through.py │ │ ├── straight_through_test.py │ │ ├── tfp_compatible_distribution.py │ │ ├── tfp_compatible_distribution_test.py │ │ ├── transformed.py │ │ ├── transformed_test.py │ │ ├── uniform.py │ │ ├── uniform_float64_test.py │ │ ├── uniform_test.py │ │ ├── von_mises.py │ │ └── von_mises_test.py │ └── utils │ │ ├── __init__.py │ │ ├── conversion.py │ │ ├── conversion_test.py │ │ ├── equivalence.py │ │ ├── hmm.py │ │ ├── 
hmm_test.py │ │ ├── importance_sampling.py │ │ ├── importance_sampling_test.py │ │ ├── jittable.py │ │ ├── jittable_test.py │ │ ├── math.py │ │ ├── math_test.py │ │ ├── monte_carlo.py │ │ ├── monte_carlo_test.py │ │ ├── transformations.py │ │ └── transformations_test.py └── distrax_test.py ├── examples ├── flow.py ├── hmm.py └── vae.py ├── requirements ├── requirements-examples.txt ├── requirements-tests.txt └── requirements.txt ├── setup.py └── test.sh /.github/workflows/pypi-publish.yml: -------------------------------------------------------------------------------- 1 | name: pypi 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Check consistency between the package version and release tag 21 | run: | 22 | RELEASE_VER=${GITHUB_REF#refs/*/} 23 | PACKAGE_VER="v`python setup.py --version`" 24 | if [ $RELEASE_VER != $PACKAGE_VER ] 25 | then 26 | echo "package ver. ($PACKAGE_VER) != release ver. ($RELEASE_VER)"; exit 1 27 | fi 28 | - name: Build and publish 29 | env: 30 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 31 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 32 | run: | 33 | python setup.py sdist bdist_wheel 34 | twine upload dist/* 35 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: ["master"] 6 | pull_request: 7 | branches: ["master"] 8 | schedule: 9 | - cron: '30 3 * * *' 10 | 11 | jobs: 12 | build-and-test: 13 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" 14 | runs-on: "${{ matrix.os }}" 15 | 16 | strategy: 17 | matrix: 18 | python-version: ["3.9", "3.10", "3.11"] 19 | os: [ubuntu-latest] 20 | 21 | steps: 22 | - uses: "actions/checkout@v2" 23 | - uses: "actions/setup-python@v1" 24 | with: 25 | python-version: "${{ matrix.python-version }}" 26 | - name: Run CI tests 27 | run: bash test.sh 28 | shell: bash 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Building and releasing library: 2 | *.egg-info 3 | *.pyc 4 | *.so 5 | build/ 6 | dist/ 7 | venv/ 8 | 9 | # Mac OS 10 | .DS_Store 11 | 12 | # Python tools 13 | .mypy_cache/ 14 | .pytype/ 15 | .ipynb_checkpoints 16 | 17 | # Editors 18 | .idea 19 | .vscode 20 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 
13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Testing 26 | 27 | Please make sure that your PR passes all tests by running `bash test.sh` on your 28 | local machine. Also, you can run only tests that are affected by your code 29 | changes, but you will need to select them manually. 30 | 31 | ## Community Guidelines 32 | 33 | This project follows [Google's Open Source Community 34 | Guidelines](https://opensource.google.com/conduct/). 35 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include requirements/* 4 | -------------------------------------------------------------------------------- /distrax/_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/bijector_from_tfp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Distrax adapter for Bijectors from TensorFlow Probability.""" 16 | 17 | from typing import Callable, Tuple 18 | 19 | from distrax._src.bijectors import bijector as base 20 | import jax 21 | import jax.numpy as jnp 22 | from tensorflow_probability.substrates import jax as tfp 23 | 24 | 25 | tfb = tfp.bijectors 26 | 27 | Array = base.Array 28 | 29 | 30 | class BijectorFromTFP(base.Bijector): 31 | """Wrapper around a TFP bijector that turns it into a Distrax bijector. 32 | 33 | TFP bijectors and Distrax bijectors have similar but not identical semantics, 34 | which makes them not directly compatible. This wrapper guarantees that the 35 | wrapped TFP bijector fully satisfies the semantics of Distrax, which enables 36 | any TFP bijector to be used by Distrax without modification. 37 | """ 38 | 39 | def __init__(self, tfp_bijector: tfb.Bijector): 40 | """Initializes a BijectorFromTFP. 41 | 42 | Args: 43 | tfp_bijector: TFP bijector to convert to Distrax bijector. 44 | """ 45 | self._tfp_bijector = tfp_bijector 46 | super().__init__( 47 | event_ndims_in=tfp_bijector.forward_min_event_ndims, 48 | event_ndims_out=tfp_bijector.inverse_min_event_ndims, 49 | is_constant_jacobian=tfp_bijector.is_constant_jacobian) 50 | 51 | def __getattr__(self, name: str): 52 | return getattr(self._tfp_bijector, name) 53 | 54 | def forward(self, x: Array) -> Array: 55 | """Computes y = f(x).""" 56 | return self._tfp_bijector.forward(x) 57 | 58 | def inverse(self, y: Array) -> Array: 59 | """Computes x = f^{-1}(y).""" 60 | return self._tfp_bijector.inverse(y) 61 | 62 | def _ensure_batch_shape(self, 63 | logdet: Array, 64 | event_ndims_out: int, 65 | forward_fn: Callable[[Array], Array], 66 | x: Array) -> Array: 67 | """Broadcasts logdet to the batch shape as required.""" 68 | if self._tfp_bijector.is_constant_jacobian: 69 | # If the Jacobian is constant, TFP may return a log det that doesn't have 70 | # full batch shape, but is broadcastable to it. Distrax assumes that the 71 | # log det is always batch-shaped, so we broadcast.
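# Here `jax.eval_shape` traces `forward_fn` abstractly, so the output shape below is obtained without actually running the forward computation.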
72 | y_shape = jax.eval_shape(forward_fn, x).shape 73 | if event_ndims_out == 0: 74 | batch_shape = y_shape 75 | else: 76 | batch_shape = y_shape[:-event_ndims_out] 77 | logdet = jnp.broadcast_to(logdet, batch_shape) 78 | return logdet 79 | 80 | def forward_log_det_jacobian(self, x: Array) -> Array: 81 | """Computes log|det J(f)(x)|.""" 82 | logdet = self._tfp_bijector.forward_log_det_jacobian(x, self.event_ndims_in) 83 | logdet = self._ensure_batch_shape( 84 | logdet, self.event_ndims_out, self._tfp_bijector.forward, x) 85 | return logdet 86 | 87 | def inverse_log_det_jacobian(self, y: Array) -> Array: 88 | """Computes log|det J(f^{-1})(y)|.""" 89 | logdet = self._tfp_bijector.inverse_log_det_jacobian( 90 | y, self.event_ndims_out) 91 | logdet = self._ensure_batch_shape( 92 | logdet, self.event_ndims_in, self._tfp_bijector.inverse, y) 93 | return logdet 94 | 95 | def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: 96 | """Computes y = f(x) and log|det J(f)(x)|.""" 97 | y = self._tfp_bijector.forward(x) 98 | logdet = self._tfp_bijector.forward_log_det_jacobian(x, self.event_ndims_in) 99 | logdet = self._ensure_batch_shape( 100 | logdet, self.event_ndims_out, self._tfp_bijector.forward, x) 101 | return y, logdet 102 | 103 | def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: 104 | """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" 105 | x = self._tfp_bijector.inverse(y) 106 | logdet = self._tfp_bijector.inverse_log_det_jacobian( 107 | y, self.event_ndims_out) 108 | logdet = self._ensure_batch_shape( 109 | logdet, self.event_ndims_in, self._tfp_bijector.inverse, y) 110 | return x, logdet 111 | 112 | @property 113 | def name(self) -> str: 114 | """Name of the bijector.""" 115 | return self._tfp_bijector.name 116 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/bijector_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for `bijector.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | import chex 21 | from distrax._src.bijectors import bijector 22 | import jax 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | 27 | class DummyBijector(bijector.Bijector): 28 | 29 | def forward_and_log_det(self, x): 30 | super()._check_forward_input_shape(x) 31 | return x, jnp.zeros(x.shape[:-1], jnp.float_) 32 | 33 | def inverse_and_log_det(self, y): 34 | super()._check_inverse_input_shape(y) 35 | return y, jnp.zeros(y.shape[:-1], jnp.float_) 36 | 37 | 38 | class BijectorTest(parameterized.TestCase): 39 | 40 | @parameterized.named_parameters( 41 | ('negative ndims_in', -1, 1, False, False), 42 | ('negative ndims_out', 1, -1, False, False), 43 | ('non-consistent constant properties', 1, 1, True, False), 44 | ) 45 | def test_invalid_parameters(self, ndims_in, ndims_out, cnst_jac, cnst_logdet): 46 | with self.assertRaises(ValueError): 47 | DummyBijector(ndims_in, ndims_out, cnst_jac, cnst_logdet) 48 | 49 | @chex.all_variants 50 | @parameterized.parameters('forward', 'inverse') 51 | def test_invalid_inputs(self, method_str): 52 | bij = DummyBijector(1, 1, True, True) 53 | fn = self.variant(getattr(bij, method_str)) 54 | with self.assertRaises(ValueError): 55 | fn(jnp.zeros(())) 56 | 57 | def test_jittable(self): 58 | @jax.jit 59 | def forward(bij, x): 60 | return bij.forward(x) 61 | 62 | bij = DummyBijector(1, 1, True, True) 63 | x = jnp.zeros((4,)) 64 | np.testing.assert_allclose(forward(bij, x), x) 65 | 66 | 67 | if __name__ == '__main__': 68 | absltest.main() 69 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/block.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Wrapper to turn independent Bijectors into block Bijectors.""" 16 | 17 | from typing import Tuple 18 | 19 | from distrax._src.bijectors import bijector as base 20 | from distrax._src.utils import conversion 21 | from distrax._src.utils import math 22 | 23 | Array = base.Array 24 | BijectorLike = base.BijectorLike 25 | BijectorT = base.BijectorT 26 | 27 | 28 | class Block(base.Bijector): 29 | """A wrapper that promotes a bijector to a block bijector. 30 | 31 | A block bijector applies a bijector to a k-dimensional array of events, but 32 | considers that array of events to be a single event. In practical terms, this 33 | means that the log det Jacobian will be summed over its last k dimensions. 34 | 35 | For example, consider a scalar bijector (such as `Tanh`) that operates on 36 | scalar events. 
We may want to apply this bijector identically to a 4D array of 37 | shape [N, H, W, C] representing a sequence of N images. Doing so naively will 38 | produce a log det Jacobian of shape [N, H, W, C], because the scalar bijector 39 | will assume scalar events and so all 4 dimensions will be considered as batch. 40 | Promoting the scalar bijector to a "block scalar" that operates on the 3D 41 | arrays can be done by `Block(bijector, ndims=3)`. Then, applying the block 42 | bijector will produce a log det Jacobian of shape [N] as desired. 43 | 44 | In general, suppose `bijector` operates on n-dimensional events. Then, 45 | `Block(bijector, k)` will promote `bijector` to a block bijector that 46 | operates on (k + n)-dimensional events, summing the log det Jacobian over its 47 | last k dimensions. In practice, this means that the last k batch dimensions 48 | will be turned into event dimensions. 49 | """ 50 | 51 | def __init__(self, bijector: BijectorLike, ndims: int): 52 | """Initializes a Block. 53 | 54 | Args: 55 | bijector: the bijector to be promoted to a block bijector. It can be a 56 | distrax bijector, a TFP bijector, or a callable to be wrapped by 57 | `Lambda`. 58 | ndims: number of batch dimensions to promote to event dimensions. 59 | """ 60 | if ndims < 0: 61 | raise ValueError(f"`ndims` must be non-negative; got {ndims}.") 62 | self._bijector = conversion.as_bijector(bijector) 63 | self._ndims = ndims 64 | super().__init__( 65 | event_ndims_in=ndims + self._bijector.event_ndims_in, 66 | event_ndims_out=ndims + self._bijector.event_ndims_out, 67 | is_constant_jacobian=self._bijector.is_constant_jacobian, 68 | is_constant_log_det=self._bijector.is_constant_log_det) 69 | 70 | @property 71 | def bijector(self) -> BijectorT: 72 | """The base bijector, without promoting to a block bijector.""" 73 | return self._bijector 74 | 75 | @property 76 | def ndims(self) -> int: 77 | """The number of batch dimensions promoted to event dimensions.""" 78 | return self._ndims 79 | 80 | def forward(self, x: Array) -> Array: 81 | """Computes y = f(x).""" 82 | self._check_forward_input_shape(x) 83 | return self._bijector.forward(x) 84 | 85 | def inverse(self, y: Array) -> Array: 86 | """Computes x = f^{-1}(y).""" 87 | self._check_inverse_input_shape(y) 88 | return self._bijector.inverse(y) 89 | 90 | def forward_log_det_jacobian(self, x: Array) -> Array: 91 | """Computes log|det J(f)(x)|.""" 92 | self._check_forward_input_shape(x) 93 | log_det = self._bijector.forward_log_det_jacobian(x) 94 | return math.sum_last(log_det, self._ndims) 95 | 96 | def inverse_log_det_jacobian(self, y: Array) -> Array: 97 | """Computes log|det J(f^{-1})(y)|.""" 98 | self._check_inverse_input_shape(y) 99 | log_det = self._bijector.inverse_log_det_jacobian(y) 100 | return math.sum_last(log_det, self._ndims) 101 | 102 | def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: 103 | """Computes y = f(x) and log|det J(f)(x)|.""" 104 | self._check_forward_input_shape(x) 105 | y, log_det = self._bijector.forward_and_log_det(x) 106 | return y, math.sum_last(log_det, self._ndims) 107 | 108 | def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: 109 | """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" 110 | self._check_inverse_input_shape(y) 111 | x, log_det = self._bijector.inverse_and_log_det(y) 112 | return x, math.sum_last(log_det, self._ndims) 113 | 114 | @property 115 | def name(self) -> str: 116 | """Name of the bijector.""" 117 | return self.__class__.__name__ + self._bijector.name 118 | 119 | def
same_as(self, other: base.Bijector) -> bool: 120 | """Returns True if this bijector is guaranteed to be the same as `other`.""" 121 | if type(other) is Block: # pylint: disable=unidiomatic-typecheck 122 | return self.bijector.same_as(other.bijector) and self.ndims == other.ndims 123 | 124 | return False 125 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/block_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `block.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | import chex 21 | from distrax._src.bijectors import bijector as base 22 | from distrax._src.bijectors import block as block_bijector 23 | from distrax._src.bijectors import scalar_affine 24 | from distrax._src.utils import conversion 25 | import jax 26 | import jax.numpy as jnp 27 | import numpy as np 28 | from tensorflow_probability.substrates import jax as tfp 29 | 30 | 31 | tfd = tfp.distributions 32 | tfb = tfp.bijectors 33 | 34 | 35 | RTOL = 1e-6 36 | 37 | 38 | class BlockTest(parameterized.TestCase): 39 | 40 | def setUp(self): 41 | super().setUp() 42 | self.seed = jax.random.PRNGKey(1234) 43 | 44 | def test_properties(self): 45 | bijct = conversion.as_bijector(jnp.tanh) 46 | block = block_bijector.Block(bijct, 1) 47 | assert block.ndims == 1 48 | assert isinstance(block.bijector, base.Bijector) 49 | 50 | def test_invalid_properties(self): 51 | bijct = conversion.as_bijector(jnp.tanh) 52 | with self.assertRaises(ValueError): 53 | block_bijector.Block(bijct, -1) 54 | 55 | @chex.all_variants 56 | @parameterized.named_parameters( 57 | ('scale_0', lambda: tfb.Scale(2), 0), 58 | ('scale_1', lambda: tfb.Scale(2), 1), 59 | ('scale_2', lambda: tfb.Scale(2), 2), 60 | ('reshape_0', lambda: tfb.Reshape([120], [4, 5, 6]), 0), 61 | ('reshape_1', lambda: tfb.Reshape([120], [4, 5, 6]), 1), 62 | ('reshape_2', lambda: tfb.Reshape([120], [4, 5, 6]), 2), 63 | ) 64 | def test_against_tfp_semantics(self, tfp_bijector_fn, ndims): 65 | tfp_bijector = tfp_bijector_fn() 66 | x = jax.random.normal(self.seed, [2, 3, 4, 5, 6]) 67 | y = tfp_bijector(x) 68 | fwd_event_ndims = ndims + tfp_bijector.forward_min_event_ndims 69 | inv_event_ndims = ndims + tfp_bijector.inverse_min_event_ndims 70 | block = block_bijector.Block(tfp_bijector, ndims) 71 | np.testing.assert_allclose( 72 | tfp_bijector.forward_log_det_jacobian(x, fwd_event_ndims), 73 | self.variant(block.forward_log_det_jacobian)(x), atol=2e-5) 74 | np.testing.assert_allclose( 75 | tfp_bijector.inverse_log_det_jacobian(y, inv_event_ndims), 76 | self.variant(block.inverse_log_det_jacobian)(y), atol=2e-5) 77 | 78 | @chex.all_variants 79 | @parameterized.named_parameters( 80 | ('dx_tanh_0', 
lambda: jnp.tanh, 0), 81 | ('dx_tanh_1', lambda: jnp.tanh, 1), 82 | ('dx_tanh_2', lambda: jnp.tanh, 2), 83 | ('tfp_tanh_0', tfb.Tanh, 0), 84 | ('tfp_tanh_1', tfb.Tanh, 1), 85 | ('tfp_tanh_2', tfb.Tanh, 2), 86 | ) 87 | def test_forward_inverse_work_as_expected(self, bijector_fn, ndims): 88 | bijct = conversion.as_bijector(bijector_fn()) 89 | x = jax.random.normal(self.seed, [2, 3]) 90 | block = block_bijector.Block(bijct, ndims) 91 | np.testing.assert_array_equal( 92 | self.variant(bijct.forward)(x), 93 | self.variant(block.forward)(x)) 94 | np.testing.assert_array_equal( 95 | self.variant(bijct.inverse)(x), 96 | self.variant(block.inverse)(x)) 97 | np.testing.assert_allclose( 98 | self.variant(bijct.forward_and_log_det)(x)[0], 99 | self.variant(block.forward_and_log_det)(x)[0], atol=2e-7) 100 | np.testing.assert_array_equal( 101 | self.variant(bijct.inverse_and_log_det)(x)[0], 102 | self.variant(block.inverse_and_log_det)(x)[0]) 103 | 104 | @chex.all_variants 105 | @parameterized.named_parameters( 106 | ('dx_tanh_0', lambda: jnp.tanh, 0), 107 | ('dx_tanh_1', lambda: jnp.tanh, 1), 108 | ('dx_tanh_2', lambda: jnp.tanh, 2), 109 | ('tfp_tanh_0', tfb.Tanh, 0), 110 | ('tfp_tanh_1', tfb.Tanh, 1), 111 | ('tfp_tanh_2', tfb.Tanh, 2), 112 | ) 113 | def test_log_det_jacobian_works_as_expected(self, bijector_fn, ndims): 114 | bijct = conversion.as_bijector(bijector_fn()) 115 | x = jax.random.normal(self.seed, [2, 3]) 116 | block = block_bijector.Block(bijct, ndims) 117 | axes = tuple(range(-ndims, 0)) 118 | np.testing.assert_allclose( 119 | self.variant(bijct.forward_log_det_jacobian)(x).sum(axes), 120 | self.variant(block.forward_log_det_jacobian)(x), rtol=RTOL) 121 | np.testing.assert_allclose( 122 | self.variant(bijct.inverse_log_det_jacobian)(x).sum(axes), 123 | self.variant(block.inverse_log_det_jacobian)(x), rtol=RTOL) 124 | np.testing.assert_allclose( 125 | self.variant(bijct.forward_and_log_det)(x)[1].sum(axes), 126 | self.variant(block.forward_and_log_det)(x)[1], rtol=RTOL) 127 | np.testing.assert_allclose( 128 | self.variant(bijct.inverse_and_log_det)(x)[1].sum(axes), 129 | self.variant(block.inverse_and_log_det)(x)[1], rtol=RTOL) 130 | 131 | def test_raises_on_invalid_input_shape(self): 132 | bij = block_bijector.Block(lambda x: x, 1) 133 | for fn in [bij.forward, bij.inverse, 134 | bij.forward_log_det_jacobian, bij.inverse_log_det_jacobian, 135 | bij.forward_and_log_det, bij.inverse_and_log_det]: 136 | with self.assertRaises(ValueError): 137 | fn(jnp.array(0)) 138 | 139 | def test_jittable(self): 140 | @jax.jit 141 | def f(x, b): 142 | return b.forward(x) 143 | 144 | bijector = block_bijector.Block(scalar_affine.ScalarAffine(0), 1) 145 | x = np.zeros((2, 3)) 146 | f(x, bijector) 147 | 148 | 149 | if __name__ == '__main__': 150 | absltest.main() 151 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/chain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Chain Bijector for composing a sequence of Bijector transformations.""" 16 | 17 | from typing import List, Sequence, Tuple 18 | 19 | from distrax._src.bijectors import bijector as base 20 | from distrax._src.utils import conversion 21 | 22 | Array = base.Array 23 | BijectorLike = base.BijectorLike 24 | BijectorT = base.BijectorT 25 | 26 | 27 | class Chain(base.Bijector): 28 | """Composition of a sequence of bijectors into a single bijector. 29 | 30 | Bijectors are composable: if `f` and `g` are bijectors, then `g o f` is also 31 | a bijector. Given a sequence of bijectors `[f1, ..., fN]`, this class 32 | implements the bijector defined by `fN o ... o f1`. 33 | 34 | NOTE: the bijectors are applied in reverse order from the order they appear in 35 | the sequence. For example, consider the following code where `f` and `g` are 36 | two bijectors: 37 | ``` 38 | layers = [] 39 | layers.append(f) 40 | layers.append(g) 41 | bijector = distrax.Chain(layers) 42 | y = bijector.forward(x) 43 | ``` 44 | The above code will transform `x` by first applying `g`, then `f`, so that 45 | `y = f(g(x))`. 46 | """ 47 | 48 | def __init__(self, bijectors: Sequence[BijectorLike]): 49 | """Initializes a Chain bijector. 50 | 51 | Args: 52 | bijectors: a sequence of bijectors to be composed into one. Each bijector 53 | can be a distrax bijector, a TFP bijector, or a callable to be wrapped 54 | by `Lambda`. The sequence must contain at least one bijector. 55 | """ 56 | if not bijectors: 57 | raise ValueError("The sequence of bijectors cannot be empty.") 58 | self._bijectors = [conversion.as_bijector(b) for b in bijectors] 59 | 60 | # Check that neighboring bijectors in the chain have compatible dimensions 61 | for i, (outer, inner) in enumerate(zip(self._bijectors[:-1], 62 | self._bijectors[1:])): 63 | if outer.event_ndims_in != inner.event_ndims_out: 64 | raise ValueError( 65 | f"The chain of bijector event shapes are incompatible. 
Bijector " 66 | f"{i} ({outer.name}) expects events with {outer.event_ndims_in} " 67 | f"dimensions, while Bijector {i+1} ({inner.name}) produces events " 68 | f"with {inner.event_ndims_out} dimensions.") 69 | 70 | is_constant_jacobian = all(b.is_constant_jacobian for b in self._bijectors) 71 | is_constant_log_det = all(b.is_constant_log_det for b in self._bijectors) 72 | super().__init__( 73 | event_ndims_in=self._bijectors[-1].event_ndims_in, 74 | event_ndims_out=self._bijectors[0].event_ndims_out, 75 | is_constant_jacobian=is_constant_jacobian, 76 | is_constant_log_det=is_constant_log_det) 77 | 78 | @property 79 | def bijectors(self) -> List[BijectorT]: 80 | """The list of bijectors in the chain.""" 81 | return self._bijectors 82 | 83 | def forward(self, x: Array) -> Array: 84 | """Computes y = f(x).""" 85 | for bijector in reversed(self._bijectors): 86 | x = bijector.forward(x) 87 | return x 88 | 89 | def inverse(self, y: Array) -> Array: 90 | """Computes x = f^{-1}(y).""" 91 | for bijector in self._bijectors: 92 | y = bijector.inverse(y) 93 | return y 94 | 95 | def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: 96 | """Computes y = f(x) and log|det J(f)(x)|.""" 97 | x, log_det = self._bijectors[-1].forward_and_log_det(x) 98 | for bijector in reversed(self._bijectors[:-1]): 99 | x, ld = bijector.forward_and_log_det(x) 100 | log_det += ld 101 | return x, log_det 102 | 103 | def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: 104 | """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" 105 | y, log_det = self._bijectors[0].inverse_and_log_det(y) 106 | for bijector in self._bijectors[1:]: 107 | y, ld = bijector.inverse_and_log_det(y) 108 | log_det += ld 109 | return y, log_det 110 | 111 | def same_as(self, other: base.Bijector) -> bool: 112 | """Returns True if this bijector is guaranteed to be the same as `other`.""" 113 | if type(other) is Chain: # pylint: disable=unidiomatic-typecheck 114 | if len(self.bijectors) != len(other.bijectors): 115 | return False 116 | for bij1, bij2 in zip(self.bijectors, other.bijectors): 117 | if not bij1.same_as(bij2): 118 | return False 119 | return True 120 | elif len(self.bijectors) == 1: 121 | return self.bijectors[0].same_as(other) 122 | 123 | return False 124 | 125 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/diag_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Diagonal linear bijector.""" 16 | 17 | from typing import Tuple 18 | 19 | from distrax._src.bijectors import bijector as base 20 | from distrax._src.bijectors import block 21 | from distrax._src.bijectors import linear 22 | from distrax._src.bijectors import scalar_affine 23 | import jax.numpy as jnp 24 | 25 | Array = base.Array 26 | 27 | 28 | class DiagLinear(linear.Linear): 29 | """Linear bijector with a diagonal weight matrix. 30 | 31 | The bijector is defined as `f(x) = Ax` where `A` is a `DxD` diagonal matrix. 32 | Additional dimensions, if any, index batches. 33 | 34 | The Jacobian determinant is trivially computed by taking the product of the 35 | diagonal entries in `A`. The inverse transformation `x = f^{-1}(y)` is 36 | computed element-wise. 37 | 38 | The bijector is invertible if and only if the diagonal entries of `A` are all 39 | non-zero. It is the responsibility of the user to make sure that this is the 40 | case; the class will make no attempt to verify that the bijector is 41 | invertible. 42 | """ 43 | 44 | def __init__(self, diag: Array): 45 | """Initializes the bijector. 46 | 47 | Args: 48 | diag: a vector of length D, the diagonal of matrix `A`. Can also be a 49 | batch of such vectors. 50 | """ 51 | if diag.ndim < 1: 52 | raise ValueError("`diag` must have at least one dimension.") 53 | self._bijector = block.Block( 54 | scalar_affine.ScalarAffine(shift=0., scale=diag), ndims=1) 55 | super().__init__( 56 | event_dims=diag.shape[-1], 57 | batch_shape=diag.shape[:-1], 58 | dtype=diag.dtype) 59 | self._diag = diag 60 | self.forward = self._bijector.forward 61 | self.forward_log_det_jacobian = self._bijector.forward_log_det_jacobian 62 | self.inverse = self._bijector.inverse 63 | self.inverse_log_det_jacobian = self._bijector.inverse_log_det_jacobian 64 | self.inverse_and_log_det = self._bijector.inverse_and_log_det 65 | 66 | @property 67 | def diag(self) -> Array: 68 | """Vector of length D, the diagonal of matrix `A`.""" 69 | return self._diag 70 | 71 | @property 72 | def matrix(self) -> Array: 73 | """The full matrix `A`.""" 74 | return jnp.vectorize(jnp.diag, signature="(k)->(k,k)")(self.diag) 75 | 76 | def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: 77 | """Computes y = f(x) and log|det J(f)(x)|.""" 78 | return self._bijector.forward_and_log_det(x) 79 | 80 | def same_as(self, other: base.Bijector) -> bool: 81 | """Returns True if this bijector is guaranteed to be the same as `other`.""" 82 | if type(other) is DiagLinear: # pylint: disable=unidiomatic-typecheck 83 | return self.diag is other.diag 84 | return False 85 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/gumbel_cdf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """GumbelCDF bijector.""" 16 | 17 | from typing import Tuple 18 | 19 | from distrax._src.bijectors import bijector as base 20 | import jax.numpy as jnp 21 | 22 | Array = base.Array 23 | 24 | 25 | class GumbelCDF(base.Bijector): 26 | """A bijector that computes the Gumbel cumulative density function (CDF). 27 | 28 | The Gumbel CDF is given by `y = f(x) = exp(-exp(-x))` for a scalar input `x`. 29 | Its inverse is `x = -log(-log(y))`. The log-det Jacobian of the transformation 30 | is `log df/dx = -exp(-x) - x`. 31 | """ 32 | 33 | def __init__(self): 34 | """Initializes a GumbelCDF bijector.""" 35 | super().__init__(event_ndims_in=0) 36 | 37 | def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: 38 | """Computes y = f(x) and log|det J(f)(x)|.""" 39 | exp_neg_x = jnp.exp(-x) 40 | y = jnp.exp(-exp_neg_x) 41 | log_det = - x - exp_neg_x 42 | return y, log_det 43 | 44 | def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: 45 | """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" 46 | log_y = jnp.log(y) 47 | x = -jnp.log(-log_y) 48 | return x, x - log_y 49 | 50 | def same_as(self, other: base.Bijector) -> bool: 51 | """Returns True if this bijector is guaranteed to be the same as `other`.""" 52 | return type(other) is GumbelCDF # pylint: disable=unidiomatic-typecheck 53 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/inverse.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Wrapper for inverting a Distrax Bijector.""" 16 | 17 | from typing import Tuple 18 | 19 | from distrax._src.bijectors import bijector as base 20 | from distrax._src.utils import conversion 21 | 22 | Array = base.Array 23 | BijectorLike = base.BijectorLike 24 | BijectorT = base.BijectorT 25 | 26 | 27 | class Inverse(base.Bijector): 28 | """A bijector that inverts a given bijector. 29 | 30 | That is, if `bijector` implements the transformation `f`, `Inverse(bijector)` 31 | implements the inverse transformation `f^{-1}`. 32 | 33 | The inversion is performed by swapping the forward with the corresponding 34 | inverse methods of the given bijector. 35 | """ 36 | 37 | def __init__(self, bijector: BijectorLike): 38 | """Initializes an Inverse bijector. 39 | 40 | Args: 41 | bijector: the bijector to be inverted. It can be a distrax bijector, a TFP 42 | bijector, or a callable to be wrapped by `Lambda`. 
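For example, if `bijector` is a `Tanh` bijector, `Inverse(bijector)` implements `arctanh`: its `forward` method calls `Tanh`'s `inverse`, and its `inverse` method calls `Tanh`'s `forward`.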
43 | """ 44 | self._bijector = conversion.as_bijector(bijector) 45 | super().__init__( 46 | event_ndims_in=self._bijector.event_ndims_out, 47 | event_ndims_out=self._bijector.event_ndims_in, 48 | is_constant_jacobian=self._bijector.is_constant_jacobian, 49 | is_constant_log_det=self._bijector.is_constant_log_det) 50 | 51 | @property 52 | def bijector(self) -> BijectorT: 53 | """The base bijector that was the input to `Inverse`.""" 54 | return self._bijector 55 | 56 | def forward(self, x: Array) -> Array: 57 | """Computes y = f(x).""" 58 | return self._bijector.inverse(x) 59 | 60 | def inverse(self, y: Array) -> Array: 61 | """Computes x = f^{-1}(y).""" 62 | return self._bijector.forward(y) 63 | 64 | def forward_log_det_jacobian(self, x: Array) -> Array: 65 | """Computes log|det J(f)(x)|.""" 66 | return self._bijector.inverse_log_det_jacobian(x) 67 | 68 | def inverse_log_det_jacobian(self, y: Array) -> Array: 69 | """Computes log|det J(f^{-1})(y)|.""" 70 | return self._bijector.forward_log_det_jacobian(y) 71 | 72 | def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: 73 | """Computes y = f(x) and log|det J(f)(x)|.""" 74 | return self._bijector.inverse_and_log_det(x) 75 | 76 | def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: 77 | """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" 78 | return self._bijector.forward_and_log_det(y) 79 | 80 | @property 81 | def name(self) -> str: 82 | """Name of the bijector.""" 83 | return self.__class__.__name__ + self._bijector.name 84 | 85 | def same_as(self, other: base.Bijector) -> bool: 86 | """Returns True if this bijector is guaranteed to be the same as `other`.""" 87 | if type(other) is Inverse: # pylint: disable=unidiomatic-typecheck 88 | return self.bijector.same_as(other.bijector) 89 | return False 90 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/linear.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Linear bijector.""" 16 | 17 | import abc 18 | from typing import Sequence, Tuple 19 | 20 | from distrax._src.bijectors import bijector as base 21 | import jax.numpy as jnp 22 | 23 | Array = base.Array 24 | 25 | 26 | class Linear(base.Bijector, metaclass=abc.ABCMeta): 27 | """Base class for linear bijectors. 28 | 29 | This class provides a base class for bijectors defined as `f(x) = Ax`, 30 | where `A` is a `DxD` matrix and `x` is a `D`-dimensional vector. 31 | """ 32 | 33 | def __init__(self, 34 | event_dims: int, 35 | batch_shape: Sequence[int], 36 | dtype: jnp.dtype): 37 | """Initializes a `Linear` bijector. 38 | 39 | Args: 40 | event_dims: the dimensionality `D` of the event `x`. It is assumed that 41 | `x` is a vector of length `event_dims`. 
42 | batch_shape: the batch shape of the bijector. 43 | dtype: the data type of matrix `A`. 44 | """ 45 | super().__init__(event_ndims_in=1, is_constant_jacobian=True) 46 | self._event_dims = event_dims 47 | self._batch_shape = tuple(batch_shape) 48 | self._dtype = dtype 49 | 50 | @property 51 | def matrix(self) -> Array: 52 | """The matrix `A` of the transformation. 53 | 54 | To be optionally implemented in a subclass. 55 | 56 | Returns: 57 | An array of shape `batch_shape + (event_dims, event_dims)` and data type 58 | `dtype`. 59 | """ 60 | raise NotImplementedError( 61 | f"Linear bijector {self.name} does not implement `matrix`.") 62 | 63 | @property 64 | def event_dims(self) -> int: 65 | """The dimensionality `D` of the event `x`.""" 66 | return self._event_dims 67 | 68 | @property 69 | def batch_shape(self) -> Tuple[int, ...]: 70 | """The batch shape of the bijector.""" 71 | return self._batch_shape 72 | 73 | @property 74 | def dtype(self) -> jnp.dtype: 75 | """The data type of matrix `A`.""" 76 | return self._dtype 77 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/linear_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `linear.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from distrax._src.bijectors import linear 20 | import jax.numpy as jnp 21 | 22 | 23 | class MockLinear(linear.Linear): 24 | 25 | def forward_and_log_det(self, x): 26 | raise Exception # pylint:disable=broad-exception-raised 27 | 28 | 29 | class LinearTest(parameterized.TestCase): 30 | 31 | @parameterized.parameters( 32 | {'event_dims': 1, 'batch_shape': (), 'dtype': jnp.float16}, 33 | {'event_dims': 10, 'batch_shape': (2, 3), 'dtype': jnp.float32}) 34 | def test_properties(self, event_dims, batch_shape, dtype): 35 | bij = MockLinear(event_dims, batch_shape, dtype) 36 | self.assertEqual(bij.event_ndims_in, 1) 37 | self.assertEqual(bij.event_ndims_out, 1) 38 | self.assertTrue(bij.is_constant_jacobian) 39 | self.assertTrue(bij.is_constant_log_det) 40 | self.assertEqual(bij.event_dims, event_dims) 41 | self.assertEqual(bij.batch_shape, batch_shape) 42 | self.assertEqual(bij.dtype, dtype) 43 | with self.assertRaises(NotImplementedError): 44 | bij.matrix # pylint: disable=pointless-statement 45 | 46 | 47 | if __name__ == '__main__': 48 | absltest.main() 49 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/lower_upper_triangular_affine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """LU-decomposed affine bijector.""" 16 | 17 | from distrax._src.bijectors import bijector as base 18 | from distrax._src.bijectors import block 19 | from distrax._src.bijectors import chain 20 | from distrax._src.bijectors import shift 21 | from distrax._src.bijectors import triangular_linear 22 | from distrax._src.bijectors import unconstrained_affine 23 | import jax.numpy as jnp 24 | 25 | Array = base.Array 26 | 27 | 28 | class LowerUpperTriangularAffine(chain.Chain): 29 | """An affine bijector whose weight matrix is parameterized as A = LU. 30 | 31 | This bijector is defined as `f(x) = Ax + b` where: 32 | 33 | * A = LU is a DxD matrix. 34 | * L is a lower-triangular matrix with ones on the diagonal. 35 | * U is an upper-triangular matrix. 36 | 37 | The Jacobian determinant can be computed in O(D) as follows: 38 | 39 | log|det J(x)| = log|det A| = sum(log|diag(U)|) 40 | 41 | The inverse can be computed in O(D^2) by solving two triangular systems: 42 | 43 | * Lz = y - b 44 | * Ux = z 45 | 46 | The bijector is invertible if and only if all diagonal elements of U are 47 | non-zero. It is the responsibility of the user to make sure that this is the 48 | case; the class will make no attempt to verify that the bijector is 49 | invertible. 50 | 51 | L and U are parameterized using a square matrix M as follows: 52 | 53 | * The lower-triangular part of M (excluding the diagonal) becomes L. 54 | * The upper-triangular part of M (including the diagonal) becomes U. 55 | 56 | The parameterization is such that if M is the identity, LU is also the 57 | identity. Note however that M is not generally equal to LU. 58 | """ 59 | 60 | def __init__(self, matrix: Array, bias: Array): 61 | """Initializes a `LowerUpperTriangularAffine` bijector. 62 | 63 | Args: 64 | matrix: a square matrix parameterizing `L` and `U` as described in the 65 | class docstring. Can also be a batch of matrices. If `matrix` is the 66 | identity, `LU` is also the identity. Note however that `matrix` is 67 | generally not equal to the product `LU`. 68 | bias: the vector `b` in `LUx + b`. Can also be a batch of vectors. 69 | """ 70 | unconstrained_affine.check_affine_parameters(matrix, bias) 71 | self._upper = triangular_linear.TriangularLinear(matrix, is_lower=False) 72 | dim = matrix.shape[-1] 73 | lower = jnp.eye(dim) + jnp.tril(matrix, -1) # Replace diagonal with ones. 
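# `jnp.tril(matrix, -1)` keeps only the strictly lower-triangular part of `matrix`, and `jnp.eye(dim)` broadcasts against any leading batch dimensions, so every batch member gets a unit diagonal.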
74 | self._lower = triangular_linear.TriangularLinear(lower, is_lower=True) 75 | self._shift = block.Block(shift.Shift(bias), 1) 76 | self._bias = bias 77 | super().__init__([self._shift, self._lower, self._upper]) 78 | 79 | @property 80 | def lower(self) -> Array: 81 | """The lower triangular matrix `L` with ones in the diagonal.""" 82 | return self._lower.matrix 83 | 84 | @property 85 | def upper(self) -> Array: 86 | """The upper triangular matrix `U`.""" 87 | return self._upper.matrix 88 | 89 | @property 90 | def matrix(self) -> Array: 91 | """The matrix `A = LU` of the transformation.""" 92 | return self.lower @ self.upper 93 | 94 | @property 95 | def bias(self) -> Array: 96 | """The shift `b` of the transformation.""" 97 | return self._bias 98 | 99 | def same_as(self, other: base.Bijector) -> bool: 100 | """Returns True if this bijector is guaranteed to be the same as `other`.""" 101 | if type(other) is LowerUpperTriangularAffine: # pylint: disable=unidiomatic-typecheck 102 | return all(( 103 | self.lower is other.lower, 104 | self.upper is other.upper, 105 | self.bias is other.bias, 106 | )) 107 | return False 108 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/rational_quadratic_spline_float64_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `rational_quadratic_spline.py`. 16 | 17 | Float64 is enabled in these tests. We keep them separate from other tests to 18 | avoid interfering with types elsewhere. 19 | """ 20 | 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | 24 | import chex 25 | from distrax._src.bijectors import rational_quadratic_spline 26 | from jax import config as jax_config 27 | import jax.numpy as jnp 28 | 29 | 30 | def setUpModule(): 31 | jax_config.update('jax_enable_x64', True) 32 | 33 | 34 | class RationalQuadraticSplineFloat64Test(chex.TestCase): 35 | """Tests for rational quadratic spline that use float64.""" 36 | 37 | def _assert_dtypes(self, bij, x, dtype): 38 | """Asserts dtypes.""" 39 | # Sanity check to make sure float64 is enabled. 
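# With `jax_enable_x64` active, `jnp.zeros([])` defaults to float64; if x64 were not enabled it would be float32 and the assertion below would fail.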
40 | x_64 = jnp.zeros([]) 41 | self.assertEqual(jnp.float64, x_64.dtype) 42 | 43 | y, logd = self.variant(bij.forward_and_log_det)(x) 44 | self.assertEqual(dtype, y.dtype) 45 | self.assertEqual(dtype, logd.dtype) 46 | y, logd = self.variant(bij.inverse_and_log_det)(x) 47 | self.assertEqual(dtype, y.dtype) 48 | self.assertEqual(dtype, logd.dtype) 49 | 50 | @chex.all_variants 51 | @parameterized.product( 52 | dtypes=[(jnp.float32, jnp.float32, jnp.float32), 53 | (jnp.float32, jnp.float64, jnp.float64), 54 | (jnp.float64, jnp.float32, jnp.float64), 55 | (jnp.float64, jnp.float64, jnp.float64)], 56 | boundary_slopes=['unconstrained', 'lower_identity', 'upper_identity', 57 | 'identity', 'circular']) 58 | def test_dtypes(self, dtypes, boundary_slopes): 59 | x_dtype, params_dtype, result_dtype = dtypes 60 | x = jnp.zeros([3], x_dtype) 61 | self.assertEqual(x_dtype, x.dtype) 62 | spline = rational_quadratic_spline.RationalQuadraticSpline( 63 | jnp.zeros([25], params_dtype), 0., 1., boundary_slopes=boundary_slopes) 64 | self._assert_dtypes(spline, x, result_dtype) 65 | 66 | 67 | if __name__ == '__main__': 68 | absltest.main() 69 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/scalar_affine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Scalar affine bijector.""" 16 | 17 | from typing import Optional, Tuple, Union 18 | 19 | from distrax._src.bijectors import bijector as base 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | 24 | Array = base.Array 25 | Numeric = Union[Array, float] 26 | 27 | 28 | class ScalarAffine(base.Bijector): 29 | """An affine bijector that acts elementwise. 30 | 31 | The bijector is defined as follows: 32 | 33 | - Forward: `y = scale * x + shift` 34 | - Forward Jacobian determinant: `log|det J(x)| = log|scale|` 35 | - Inverse: `x = (y - shift) / scale` 36 | - Inverse Jacobian determinant: `log|det J(y)| = -log|scale|` 37 | 38 | where `scale` and `shift` are the bijector's parameters. 39 | """ 40 | 41 | def __init__(self, 42 | shift: Numeric, 43 | scale: Optional[Numeric] = None, 44 | log_scale: Optional[Numeric] = None): 45 | """Initializes a ScalarAffine bijector. 46 | 47 | Args: 48 | shift: the bijector's shift parameter. Can also be batched. 49 | scale: the bijector's scale parameter. Can also be batched. NOTE: `scale` 50 | must be non-zero, otherwise the bijector is not invertible. It is the 51 | user's responsibility to make sure `scale` is non-zero; the class will 52 | make no attempt to verify this. 53 | log_scale: the log of the scale parameter. Can also be batched. If 54 | specified, the bijector's scale is set equal to `exp(log_scale)`. Unlike 55 | `scale`, `log_scale` is an unconstrained parameter. 
NOTE: either `scale` 56 | or `log_scale` can be specified, but not both. If neither is specified, 57 | the bijector's scale will default to 1. 58 | 59 | Raises: 60 | ValueError: if both `scale` and `log_scale` are not None. 61 | """ 62 | super().__init__(event_ndims_in=0, is_constant_jacobian=True) 63 | self._shift = shift 64 | if scale is None and log_scale is None: 65 | self._scale = 1. 66 | self._inv_scale = 1. 67 | self._log_scale = 0. 68 | elif log_scale is None: 69 | self._scale = scale 70 | self._inv_scale = 1. / scale 71 | self._log_scale = jnp.log(jnp.abs(scale)) 72 | elif scale is None: 73 | self._scale = jnp.exp(log_scale) 74 | self._inv_scale = jnp.exp(jnp.negative(log_scale)) 75 | self._log_scale = log_scale 76 | else: 77 | raise ValueError( 78 | 'Only one of `scale` and `log_scale` can be specified, not both.') 79 | self._batch_shape = jax.lax.broadcast_shapes( 80 | jnp.shape(self._shift), jnp.shape(self._scale)) 81 | 82 | @property 83 | def shift(self) -> Numeric: 84 | """The bijector's shift.""" 85 | return self._shift 86 | 87 | @property 88 | def log_scale(self) -> Numeric: 89 | """The log of the bijector's scale.""" 90 | return self._log_scale 91 | 92 | @property 93 | def scale(self) -> Numeric: 94 | """The bijector's scale.""" 95 | assert self._scale is not None # By construction. 96 | return self._scale 97 | 98 | def forward(self, x: Array) -> Array: 99 | """Computes y = f(x).""" 100 | batch_shape = jax.lax.broadcast_shapes(self._batch_shape, x.shape) 101 | batched_scale = jnp.broadcast_to(self._scale, batch_shape) 102 | batched_shift = jnp.broadcast_to(self._shift, batch_shape) 103 | return batched_scale * x + batched_shift 104 | 105 | def forward_log_det_jacobian(self, x: Array) -> Array: 106 | """Computes log|det J(f)(x)|.""" 107 | batch_shape = jax.lax.broadcast_shapes(self._batch_shape, x.shape) 108 | return jnp.broadcast_to(self._log_scale, batch_shape) 109 | 110 | def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: 111 | """Computes y = f(x) and log|det J(f)(x)|.""" 112 | return self.forward(x), self.forward_log_det_jacobian(x) 113 | 114 | def inverse(self, y: Array) -> Array: 115 | """Computes x = f^{-1}(y).""" 116 | batch_shape = jax.lax.broadcast_shapes(self._batch_shape, y.shape) 117 | batched_inv_scale = jnp.broadcast_to(self._inv_scale, batch_shape) 118 | batched_shift = jnp.broadcast_to(self._shift, batch_shape) 119 | return batched_inv_scale * (y - batched_shift) 120 | 121 | def inverse_log_det_jacobian(self, y: Array) -> Array: 122 | """Computes log|det J(f^{-1})(y)|.""" 123 | batch_shape = jax.lax.broadcast_shapes(self._batch_shape, y.shape) 124 | return jnp.broadcast_to(jnp.negative(self._log_scale), batch_shape) 125 | 126 | def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: 127 | """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" 128 | return self.inverse(y), self.inverse_log_det_jacobian(y) 129 | 130 | def same_as(self, other: base.Bijector) -> bool: 131 | """Returns True if this bijector is guaranteed to be the same as `other`.""" 132 | if type(other) is ScalarAffine: # pylint: disable=unidiomatic-typecheck 133 | return all(( 134 | self.shift is other.shift, 135 | self.scale is other.scale, 136 | self.log_scale is other.log_scale, 137 | )) 138 | else: 139 | return False 140 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/shift.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Shift bijector.""" 16 | 17 | from typing import Tuple, Union 18 | 19 | from distrax._src.bijectors import bijector as base 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | Array = base.Array 24 | Numeric = Union[Array, float] 25 | 26 | 27 | class Shift(base.Bijector): 28 | """Bijector that translates its input elementwise. 29 | 30 | The bijector is defined as follows: 31 | 32 | - Forward: `y = x + shift` 33 | - Forward Jacobian determinant: `log|det J(x)| = 0` 34 | - Inverse: `x = y - shift` 35 | - Inverse Jacobian determinant: `log|det J(y)| = 0` 36 | 37 | where `shift` parameterizes the bijector. 38 | """ 39 | 40 | def __init__(self, shift: Numeric): 41 | """Initializes a `Shift` bijector. 42 | 43 | Args: 44 | shift: the bijector's shift parameter. Can also be batched. 45 | """ 46 | super().__init__(event_ndims_in=0, is_constant_jacobian=True) 47 | self._shift = shift 48 | self._batch_shape = jnp.shape(self._shift) 49 | 50 | @property 51 | def shift(self) -> Numeric: 52 | """The bijector's shift.""" 53 | return self._shift 54 | 55 | def forward(self, x: Array) -> Array: 56 | """Computes y = f(x).""" 57 | return x + self._shift 58 | 59 | def forward_log_det_jacobian(self, x: Array) -> Array: 60 | """Computes log|det J(f)(x)|.""" 61 | batch_shape = jax.lax.broadcast_shapes(self._batch_shape, x.shape) 62 | return jnp.zeros(batch_shape, dtype=x.dtype) 63 | 64 | def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: 65 | """Computes y = f(x) and log|det J(f)(x)|.""" 66 | return self.forward(x), self.forward_log_det_jacobian(x) 67 | 68 | def inverse(self, y: Array) -> Array: 69 | """Computes x = f^{-1}(y).""" 70 | return y - self._shift 71 | 72 | def inverse_log_det_jacobian(self, y: Array) -> Array: 73 | """Computes log|det J(f^{-1})(y)|.""" 74 | return self.forward_log_det_jacobian(y) 75 | 76 | def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: 77 | """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" 78 | return self.inverse(y), self.inverse_log_det_jacobian(y) 79 | 80 | def same_as(self, other: base.Bijector) -> bool: 81 | """Returns True if this bijector is guaranteed to be the same as `other`.""" 82 | if type(other) is Shift: # pylint: disable=unidiomatic-typecheck 83 | return self.shift is other.shift 84 | return False 85 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/shift_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
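# Illustrative usage sketch (not part of the original scalar_affine.py/shift.py
# sources): a minimal round trip through the two elementwise bijectors defined
# above, using only the constructors and methods shown in those files.
import jax.numpy as jnp

from distrax._src.bijectors.scalar_affine import ScalarAffine
from distrax._src.bijectors.shift import Shift

x = jnp.array([0.0, 1.0, 2.0])

# y = scale * x + shift, with log|det J(x)| = log|scale| broadcast to x's shape.
affine = ScalarAffine(shift=1.0, scale=2.0)
y, logdet = affine.forward_and_log_det(x)   # y = [1., 3., 5.], logdet = log(2) per element
x_back = affine.inverse(y)                  # recovers [0., 1., 2.]

# Shift is the scale = 1 special case, so its log-determinant is identically 0.
shifted = Shift(jnp.array([1.0, 2.0, 3.0])).forward(x)   # [1., 3., 5.]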
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `shift.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | import chex 21 | from distrax._src.bijectors.shift import Shift 22 | from distrax._src.bijectors.tanh import Tanh 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | 28 | class ShiftTest(parameterized.TestCase): 29 | 30 | def test_jacobian_is_constant_property(self): 31 | bijector = Shift(jnp.ones((4,))) 32 | self.assertTrue(bijector.is_constant_jacobian) 33 | self.assertTrue(bijector.is_constant_log_det) 34 | 35 | def test_properties(self): 36 | bijector = Shift(jnp.array([1., 2., 3.])) 37 | np.testing.assert_array_equal(bijector.shift, np.array([1., 2., 3.])) 38 | 39 | @chex.all_variants 40 | @parameterized.parameters( 41 | {'batch_shape': (), 'param_shape': ()}, 42 | {'batch_shape': (3,), 'param_shape': ()}, 43 | {'batch_shape': (), 'param_shape': (3,)}, 44 | {'batch_shape': (2, 3), 'param_shape': (2, 3)}, 45 | ) 46 | def test_forward_methods(self, batch_shape, param_shape): 47 | bijector = Shift(jnp.ones(param_shape)) 48 | prng = jax.random.PRNGKey(42) 49 | x = jax.random.normal(prng, batch_shape) 50 | output_shape = jnp.broadcast_shapes(batch_shape, param_shape) 51 | y1 = self.variant(bijector.forward)(x) 52 | logdet1 = self.variant(bijector.forward_log_det_jacobian)(x) 53 | y2, logdet2 = self.variant(bijector.forward_and_log_det)(x) 54 | self.assertEqual(y1.shape, output_shape) 55 | self.assertEqual(y2.shape, output_shape) 56 | self.assertEqual(logdet1.shape, output_shape) 57 | self.assertEqual(logdet2.shape, output_shape) 58 | np.testing.assert_allclose(y1, x + 1., 1e-6) 59 | np.testing.assert_allclose(y2, x + 1., 1e-6) 60 | np.testing.assert_allclose(logdet1, 0., 1e-6) 61 | np.testing.assert_allclose(logdet2, 0., 1e-6) 62 | 63 | @chex.all_variants 64 | @parameterized.parameters( 65 | {'batch_shape': (), 'param_shape': ()}, 66 | {'batch_shape': (3,), 'param_shape': ()}, 67 | {'batch_shape': (), 'param_shape': (3,)}, 68 | {'batch_shape': (2, 3), 'param_shape': (2, 3)}, 69 | ) 70 | def test_inverse_methods(self, batch_shape, param_shape): 71 | bijector = Shift(jnp.ones(param_shape)) 72 | prng = jax.random.PRNGKey(42) 73 | y = jax.random.normal(prng, batch_shape) 74 | output_shape = jnp.broadcast_shapes(batch_shape, param_shape) 75 | x1 = self.variant(bijector.inverse)(y) 76 | logdet1 = self.variant(bijector.inverse_log_det_jacobian)(y) 77 | x2, logdet2 = self.variant(bijector.inverse_and_log_det)(y) 78 | self.assertEqual(x1.shape, output_shape) 79 | self.assertEqual(x2.shape, output_shape) 80 | self.assertEqual(logdet1.shape, output_shape) 81 | self.assertEqual(logdet2.shape, output_shape) 82 | np.testing.assert_allclose(x1, y - 1., 1e-6) 83 | np.testing.assert_allclose(x2, y - 1., 1e-6) 84 | np.testing.assert_allclose(logdet1, 0., 1e-6) 85 | np.testing.assert_allclose(logdet2, 0., 1e-6) 86 | 87 | def test_jittable(self): 88 | @jax.jit 89 | def f(x, b): 90 | return b.forward(x) 91 | 92 | bij = 
Shift(jnp.ones((4,))) 93 | x = np.zeros((4,)) 94 | f(x, bij) 95 | 96 | def test_same_as_itself(self): 97 | bij = Shift(jnp.ones((4,))) 98 | self.assertTrue(bij.same_as(bij)) 99 | 100 | def test_not_same_as_others(self): 101 | bij = Shift(jnp.ones((4,))) 102 | other = Shift(jnp.zeros((4,))) 103 | self.assertFalse(bij.same_as(other)) 104 | self.assertFalse(bij.same_as(Tanh())) 105 | 106 | 107 | if __name__ == '__main__': 108 | absltest.main() 109 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/sigmoid.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Sigmoid bijector.""" 16 | 17 | from typing import Tuple 18 | 19 | from distrax._src.bijectors import bijector as base 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | 24 | Array = base.Array 25 | 26 | 27 | class Sigmoid(base.Bijector): 28 | """A bijector that computes the logistic sigmoid. 29 | 30 | The log-determinant implementation in this bijector is more numerically stable 31 | than relying on the automatic differentiation approach used by Lambda, so this 32 | bijector should be preferred over Lambda(jax.nn.sigmoid) where possible. See 33 | `tfp.bijectors.Sigmoid` for details. 34 | 35 | Note that the underlying implementation of `jax.nn.sigmoid` used by the 36 | `forward` function of this bijector does not support inputs of integer type. 37 | To invoke the forward function of this bijector on an argument of integer 38 | type, it should first be cast explicitly to a floating point type. 39 | 40 | When the absolute value of the input is large, `Sigmoid` becomes close to a 41 | constant, so that it is not possible to recover the input `x` from the output 42 | `y` within machine precision. In cases where it is needed to compute both the 43 | forward mapping and the backward mapping one after the other to recover the 44 | original input `x`, it is the user's responsibility to simplify the operation 45 | to avoid numerical issues; this is unlike the `tfp.bijectors.Sigmoid`. One 46 | example of such case is to use the bijector within a `Transformed` 47 | distribution and to obtain the log-probability of samples obtained from the 48 | distribution's `sample` method. For values of the samples for which it is not 49 | possible to apply the inverse bijector accurately, `log_prob` returns NaN. 50 | This can be avoided by using `sample_and_log_prob` instead of `sample` 51 | followed by `log_prob`. 
52 | """ 53 | 54 | def __init__(self): 55 | """Initializes a Sigmoid bijector.""" 56 | super().__init__(event_ndims_in=0) 57 | 58 | def forward_log_det_jacobian(self, x: Array) -> Array: 59 | """Computes log|det J(f)(x)|.""" 60 | # pylint:disable=invalid-unary-operand-type 61 | return -_more_stable_softplus(-x) - _more_stable_softplus(x) 62 | 63 | def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: 64 | """Computes y = f(x) and log|det J(f)(x)|.""" 65 | return _more_stable_sigmoid(x), self.forward_log_det_jacobian(x) 66 | 67 | def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: 68 | """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" 69 | x = jnp.log(y) - jnp.log1p(-y) 70 | return x, -self.forward_log_det_jacobian(x) 71 | 72 | def same_as(self, other: base.Bijector) -> bool: 73 | """Returns True if this bijector is guaranteed to be the same as `other`.""" 74 | return type(other) is Sigmoid # pylint: disable=unidiomatic-typecheck 75 | 76 | 77 | def _more_stable_sigmoid(x: Array) -> Array: 78 | """Where extremely negatively saturated, approximate sigmoid with exp(x).""" 79 | return jnp.where(x < -9, jnp.exp(x), jax.nn.sigmoid(x)) 80 | 81 | 82 | def _more_stable_softplus(x: Array) -> Array: 83 | """Where extremely saturated, approximate softplus with log1p(exp(x)).""" 84 | return jnp.where(x < -9, jnp.log1p(jnp.exp(x)), jax.nn.softplus(x)) 85 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/tanh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tanh bijector.""" 16 | 17 | from typing import Tuple 18 | 19 | from distrax._src.bijectors import bijector as base 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | 24 | Array = base.Array 25 | 26 | 27 | class Tanh(base.Bijector): 28 | """A bijector that computes the hyperbolic tangent. 29 | 30 | The log-determinant implementation in this bijector is more numerically stable 31 | than relying on the automatic differentiation approach used by Lambda, so this 32 | bijector should be preferred over Lambda(jnp.tanh) where possible. See 33 | `tfp.bijectors.Tanh` for details. 34 | 35 | When the absolute value of the input is large, `Tanh` becomes close to a 36 | constant, so that it is not possible to recover the input `x` from the output 37 | `y` within machine precision. In cases where it is needed to compute both the 38 | forward mapping and the backward mapping one after the other to recover the 39 | original input `x`, it is the user's responsibility to simplify the operation 40 | to avoid numerical issues; this is unlike the `tfp.bijectors.Tanh`. 
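# Illustrative sketch (not part of the original sources) of the caveat in the
# Sigmoid docstring above: `sample` followed by `log_prob` must invert the
# sigmoid and can return NaN for saturated samples, while `sample_and_log_prob`
# reuses the forward pass. It assumes the `Normal` and `Transformed` classes
# from the sibling `distributions` modules listed in this repository.
import jax
import jax.numpy as jnp

from distrax._src.bijectors.sigmoid import Sigmoid
from distrax._src.distributions.normal import Normal
from distrax._src.distributions.transformed import Transformed

key = jax.random.PRNGKey(0)
# A wide base distribution makes saturated outputs (y == 1.0 in float32) likely.
dist = Transformed(Normal(loc=0.0, scale=10.0), Sigmoid())

# Preferred: log-density computed from the forward pass, always finite here.
samples, log_prob = dist.sample_and_log_prob(seed=key, sample_shape=(1000,))

# Two-step alternative: may be NaN wherever the sample saturated to exactly 0 or 1.
maybe_nan = dist.log_prob(dist.sample(seed=key, sample_shape=(1000,)))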
One 41 | example of such case is to use the bijector within a `Transformed` 42 | distribution and to obtain the log-probability of samples obtained from the 43 | distribution's `sample` method. For values of the samples for which it is not 44 | possible to apply the inverse bijector accurately, `log_prob` returns NaN. 45 | This can be avoided by using `sample_and_log_prob` instead of `sample` 46 | followed by `log_prob`. 47 | """ 48 | 49 | def __init__(self): 50 | """Initializes a Tanh bijector.""" 51 | super().__init__(event_ndims_in=0) 52 | 53 | def forward_log_det_jacobian(self, x: Array) -> Array: 54 | """Computes log|det J(f)(x)|.""" 55 | return 2 * (jnp.log(2) - x - jax.nn.softplus(-2 * x)) 56 | 57 | def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: 58 | """Computes y = f(x) and log|det J(f)(x)|.""" 59 | return jnp.tanh(x), self.forward_log_det_jacobian(x) 60 | 61 | def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: 62 | """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" 63 | x = jnp.arctanh(y) 64 | return x, -self.forward_log_det_jacobian(x) 65 | 66 | def same_as(self, other: base.Bijector) -> bool: 67 | """Returns True if this bijector is guaranteed to be the same as `other`.""" 68 | return type(other) is Tanh # pylint: disable=unidiomatic-typecheck 69 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/triangular_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Triangular linear bijector.""" 16 | 17 | import functools 18 | from typing import Tuple 19 | 20 | from distrax._src.bijectors import bijector as base 21 | from distrax._src.bijectors import linear 22 | import jax 23 | import jax.numpy as jnp 24 | 25 | Array = base.Array 26 | 27 | 28 | def _triangular_logdet(matrix: Array) -> Array: 29 | """Computes the log absolute determinant of a triangular matrix.""" 30 | return jnp.sum(jnp.log(jnp.abs(jnp.diag(matrix)))) 31 | 32 | 33 | def _forward_unbatched(x: Array, matrix: Array) -> Array: 34 | return matrix @ x 35 | 36 | 37 | def _inverse_unbatched(y: Array, matrix: Array, is_lower: bool) -> Array: 38 | return jax.scipy.linalg.solve_triangular(matrix, y, lower=is_lower) 39 | 40 | 41 | class TriangularLinear(linear.Linear): 42 | """A linear bijector whose weight matrix is triangular. 43 | 44 | The bijector is defined as `f(x) = Ax` where `A` is a DxD triangular matrix. 45 | 46 | The Jacobian determinant can be computed in O(D) as follows: 47 | 48 | log|det J(x)| = log|det A| = sum(log|diag(A)|) 49 | 50 | The inverse is computed in O(D^2) by solving the triangular system `Ax = y`. 51 | 52 | The bijector is invertible if and only if all diagonal elements of `A` are 53 | non-zero. 
It is the responsibility of the user to make sure that this is the 54 | case; the class will make no attempt to verify that the bijector is 55 | invertible. 56 | """ 57 | 58 | def __init__(self, matrix: Array, is_lower: bool = True): 59 | """Initializes a `TriangularLinear` bijector. 60 | 61 | Args: 62 | matrix: a square matrix whose triangular part defines `A`. Can also be a 63 | batch of matrices. Whether `A` is the lower or upper triangular part of 64 | `matrix` is determined by `is_lower`. 65 | is_lower: if True, `A` is set to the lower triangular part of `matrix`. If 66 | False, `A` is set to the upper triangular part of `matrix`. 67 | """ 68 | if matrix.ndim < 2: 69 | raise ValueError(f"`matrix` must have at least 2 dimensions, got" 70 | f" {matrix.ndim}.") 71 | if matrix.shape[-2] != matrix.shape[-1]: 72 | raise ValueError(f"`matrix` must be square; instead, it has shape" 73 | f" {matrix.shape[-2:]}.") 74 | super().__init__( 75 | event_dims=matrix.shape[-1], 76 | batch_shape=matrix.shape[:-2], 77 | dtype=matrix.dtype) 78 | self._matrix = jnp.tril(matrix) if is_lower else jnp.triu(matrix) 79 | self._is_lower = is_lower 80 | triangular_logdet = jnp.vectorize(_triangular_logdet, signature="(m,m)->()") 81 | self._logdet = triangular_logdet(self._matrix) 82 | 83 | @property 84 | def matrix(self) -> Array: 85 | """The triangular matrix `A` of the transformation.""" 86 | return self._matrix 87 | 88 | @property 89 | def is_lower(self) -> bool: 90 | """True if `A` is lower triangular, False if upper triangular.""" 91 | return self._is_lower 92 | 93 | def forward(self, x: Array) -> Array: 94 | """Computes y = f(x).""" 95 | self._check_forward_input_shape(x) 96 | batched = jnp.vectorize(_forward_unbatched, signature="(m),(m,m)->(m)") 97 | return batched(x, self._matrix) 98 | 99 | def forward_log_det_jacobian(self, x: Array) -> Array: 100 | """Computes log|det J(f)(x)|.""" 101 | self._check_forward_input_shape(x) 102 | batch_shape = jax.lax.broadcast_shapes(self.batch_shape, x.shape[:-1]) 103 | return jnp.broadcast_to(self._logdet, batch_shape) 104 | 105 | def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: 106 | """Computes y = f(x) and log|det J(f)(x)|.""" 107 | return self.forward(x), self.forward_log_det_jacobian(x) 108 | 109 | def inverse(self, y: Array) -> Array: 110 | """Computes x = f^{-1}(y).""" 111 | self._check_inverse_input_shape(y) 112 | batched = jnp.vectorize( 113 | functools.partial(_inverse_unbatched, is_lower=self._is_lower), 114 | signature="(m),(m,m)->(m)") 115 | return batched(y, self._matrix) 116 | 117 | def inverse_log_det_jacobian(self, y: Array) -> Array: 118 | """Computes log|det J(f^{-1})(y)|.""" 119 | return -self.forward_log_det_jacobian(y) 120 | 121 | def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: 122 | """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" 123 | return self.inverse(y), self.inverse_log_det_jacobian(y) 124 | 125 | def same_as(self, other: base.Bijector) -> bool: 126 | """Returns True if this bijector is guaranteed to be the same as `other`.""" 127 | if type(other) is TriangularLinear: # pylint: disable=unidiomatic-typecheck 128 | return all(( 129 | self.matrix is other.matrix, 130 | self.is_lower is other.is_lower, 131 | )) 132 | return False 133 | -------------------------------------------------------------------------------- /distrax/_src/bijectors/unconstrained_affine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Unconstrained affine bijector.""" 16 | 17 | from typing import Tuple 18 | 19 | from distrax._src.bijectors import bijector as base 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | Array = base.Array 24 | 25 | 26 | def check_affine_parameters(matrix: Array, bias: Array) -> None: 27 | """Checks that `matrix` and `bias` have valid shapes. 28 | 29 | Args: 30 | matrix: a matrix, or a batch of matrices. 31 | bias: a vector, or a batch of vectors. 32 | 33 | Raises: 34 | ValueError: if the shapes of `matrix` and `bias` are invalid. 35 | """ 36 | if matrix.ndim < 2: 37 | raise ValueError(f"`matrix` must have at least 2 dimensions, got" 38 | f" {matrix.ndim}.") 39 | if bias.ndim < 1: 40 | raise ValueError("`bias` must have at least 1 dimension.") 41 | if matrix.shape[-2] != matrix.shape[-1]: 42 | raise ValueError(f"`matrix` must be square; instead, it has shape" 43 | f" {matrix.shape[-2:]}.") 44 | if matrix.shape[-1] != bias.shape[-1]: 45 | raise ValueError(f"`matrix` and `bias` have inconsistent shapes: `matrix`" 46 | f" is {matrix.shape[-2:]}, `bias` is {bias.shape[-1:]}.") 47 | 48 | 49 | class UnconstrainedAffine(base.Bijector): 50 | """An unconstrained affine bijection. 51 | 52 | This bijector is a linear-plus-bias transformation `f(x) = Ax + b`, where `A` 53 | is a `D x D` square matrix and `b` is a `D`-dimensional vector. 54 | 55 | The bijector is invertible if and only if `A` is an invertible matrix. It is 56 | the responsibility of the user to make sure that this is the case; the class 57 | will make no attempt to verify that the bijector is invertible. 58 | 59 | The Jacobian determinant is equal to `det(A)`. The inverse is computed by 60 | solving the linear system `Ax = y - b`. 61 | 62 | WARNING: Both the determinant and the inverse cost `O(D^3)` to compute. Thus, 63 | this bijector is recommended only for small `D`. 64 | """ 65 | 66 | def __init__(self, matrix: Array, bias: Array): 67 | """Initializes an `UnconstrainedAffine` bijector. 68 | 69 | Args: 70 | matrix: the matrix `A` in `Ax + b`. Must be square and invertible. Can 71 | also be a batch of matrices. 72 | bias: the vector `b` in `Ax + b`. Can also be a batch of vectors. 
73 | """ 74 | check_affine_parameters(matrix, bias) 75 | super().__init__(event_ndims_in=1, is_constant_jacobian=True) 76 | self._batch_shape = jnp.broadcast_shapes(matrix.shape[:-2], bias.shape[:-1]) 77 | self._matrix = matrix 78 | self._bias = bias 79 | self._logdet = jnp.linalg.slogdet(matrix)[1] 80 | 81 | @property 82 | def matrix(self) -> Array: 83 | """The matrix `A` of the transformation.""" 84 | return self._matrix 85 | 86 | @property 87 | def bias(self) -> Array: 88 | """The shift `b` of the transformation.""" 89 | return self._bias 90 | 91 | def forward(self, x: Array) -> Array: 92 | """Computes y = f(x).""" 93 | self._check_forward_input_shape(x) 94 | 95 | def unbatched(single_x, matrix, bias): 96 | return matrix @ single_x + bias 97 | 98 | batched = jnp.vectorize(unbatched, signature="(m),(m,m),(m)->(m)") 99 | return batched(x, self._matrix, self._bias) 100 | 101 | def forward_log_det_jacobian(self, x: Array) -> Array: 102 | """Computes log|det J(f)(x)|.""" 103 | self._check_forward_input_shape(x) 104 | batch_shape = jax.lax.broadcast_shapes(self._batch_shape, x.shape[:-1]) 105 | return jnp.broadcast_to(self._logdet, batch_shape) 106 | 107 | def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: 108 | """Computes y = f(x) and log|det J(f)(x)|.""" 109 | return self.forward(x), self.forward_log_det_jacobian(x) 110 | 111 | def inverse(self, y: Array) -> Array: 112 | """Computes x = f^{-1}(y).""" 113 | self._check_inverse_input_shape(y) 114 | 115 | def unbatched(single_y, matrix, bias): 116 | return jnp.linalg.solve(matrix, single_y - bias) 117 | 118 | batched = jnp.vectorize(unbatched, signature="(m),(m,m),(m)->(m)") 119 | return batched(y, self._matrix, self._bias) 120 | 121 | def inverse_log_det_jacobian(self, y: Array) -> Array: 122 | """Computes log|det J(f^{-1})(y)|.""" 123 | return -self.forward_log_det_jacobian(y) 124 | 125 | def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: 126 | """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" 127 | return self.inverse(y), self.inverse_log_det_jacobian(y) 128 | 129 | def same_as(self, other: base.Bijector) -> bool: 130 | """Returns True if this bijector is guaranteed to be the same as `other`.""" 131 | if type(other) is UnconstrainedAffine: # pylint: disable=unidiomatic-typecheck 132 | return all(( 133 | self.matrix is other.matrix, 134 | self.bias is other.bias, 135 | )) 136 | return False 137 | -------------------------------------------------------------------------------- /distrax/_src/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /distrax/_src/distributions/clipped.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Clipped distributions.""" 16 | 17 | from typing import Tuple 18 | 19 | import chex 20 | from distrax._src.distributions import distribution as base_distribution 21 | from distrax._src.distributions import logistic 22 | from distrax._src.distributions import normal 23 | from distrax._src.utils import conversion 24 | import jax.numpy as jnp 25 | 26 | 27 | Array = chex.Array 28 | PRNGKey = chex.PRNGKey 29 | Numeric = chex.Numeric 30 | DistributionLike = base_distribution.DistributionLike 31 | EventT = base_distribution.EventT 32 | 33 | 34 | class Clipped(base_distribution.Distribution): 35 | """A clipped distribution.""" 36 | 37 | def __init__( 38 | self, 39 | distribution: DistributionLike, 40 | minimum: Numeric, 41 | maximum: Numeric): 42 | """Wraps a distribution clipping samples out of `[minimum, maximum]`. 43 | 44 | The samples outside of `[minimum, maximum]` are clipped to the boundary. 45 | The log probability of samples outside of this range is `-inf`. 46 | 47 | Args: 48 | distribution: a Distrax / TFP distribution to be wrapped. 49 | minimum: can be a `scalar` or `vector`; if a vector, must have fewer dims 50 | than `distribution.batch_shape` and must be broadcastable to it. 51 | maximum: can be a `scalar` or `vector`; if a vector, must have fewer dims 52 | than `distribution.batch_shape` and must be broadcastable to it. 53 | """ 54 | super().__init__() 55 | if distribution.event_shape: 56 | raise ValueError('The wrapped distribution must have event shape ().') 57 | if (jnp.array(minimum).ndim > len(distribution.batch_shape) or 58 | jnp.array(maximum).ndim > len(distribution.batch_shape)): 59 | raise ValueError( 60 | 'The minimum and maximum clipping boundaries must be scalars or' 61 | 'vectors with fewer dimensions as the batch_shape of distribution:' 62 | 'i.e. 
we can broadcast min/max to batch_shape but not viceversa.') 63 | self._distribution = conversion.as_distribution(distribution) 64 | self._minimum = jnp.broadcast_to(minimum, self._distribution.batch_shape) 65 | self._maximum = jnp.broadcast_to(maximum, self._distribution.batch_shape) 66 | self._log_prob_minimum = self._distribution.log_cdf(minimum) 67 | self._log_prob_maximum = self._distribution.log_survival_function(maximum) 68 | 69 | def _sample_n(self, key: PRNGKey, n: int) -> Array: 70 | """See `Distribution._sample_n`.""" 71 | raw_sample = self._distribution.sample(seed=key, sample_shape=[n]) 72 | return jnp.clip(raw_sample, self._minimum, self._maximum) 73 | 74 | def _sample_n_and_log_prob(self, key: PRNGKey, n: int) -> Tuple[Array, Array]: 75 | """See `Distribution._sample_n_and_log_prob`.""" 76 | samples = self._sample_n(key, n) 77 | return samples, self.log_prob(samples) 78 | 79 | def log_prob(self, value: EventT) -> Array: 80 | """See `Distribution.log_prob`.""" 81 | # The log_prob can be used to compute expectations by explicitly integrating 82 | # over the discrete and continuous elements. 83 | # Info about mixed distributions: 84 | # http://www.randomservices.org/random/dist/Mixed.html 85 | log_prob = jnp.where( 86 | jnp.equal(value, self._minimum), 87 | self._log_prob_minimum, 88 | jnp.where(jnp.equal(value, self._maximum), 89 | self._log_prob_maximum, 90 | self._distribution.log_prob(value))) 91 | # Giving -inf log_prob outside the boundaries. 92 | return jnp.where( 93 | jnp.logical_or(value < self._minimum, value > self._maximum), 94 | -jnp.inf, 95 | log_prob) 96 | 97 | @property 98 | def minimum(self) -> Array: 99 | return self._minimum 100 | 101 | @property 102 | def maximum(self) -> Array: 103 | return self._maximum 104 | 105 | @property 106 | def distribution(self) -> DistributionLike: 107 | return self._distribution 108 | 109 | @property 110 | def event_shape(self) -> Tuple[int, ...]: 111 | return () 112 | 113 | @property 114 | def batch_shape(self) -> Tuple[int, ...]: 115 | return self._distribution.batch_shape 116 | 117 | def __getitem__(self, index) -> 'Clipped': 118 | """See `Distribution.__getitem__`.""" 119 | index = base_distribution.to_batch_shape_index(self.batch_shape, index) 120 | return Clipped( 121 | distribution=self.distribution[index], 122 | minimum=self.minimum[index], 123 | maximum=self.maximum[index]) 124 | 125 | 126 | class ClippedNormal(Clipped): 127 | """A clipped normal distribution.""" 128 | 129 | def __init__( 130 | self, loc: Numeric, scale: Numeric, minimum: Numeric, maximum: Numeric): 131 | distribution = normal.Normal(loc=loc, scale=scale) 132 | super().__init__(distribution, minimum=minimum, maximum=maximum) 133 | 134 | 135 | class ClippedLogistic(Clipped): 136 | """A clipped logistic distribution.""" 137 | 138 | def __init__( 139 | self, loc: Numeric, scale: Numeric, minimum: Numeric, maximum: Numeric): 140 | distribution = logistic.Logistic(loc=loc, scale=scale) 141 | super().__init__(distribution, minimum=minimum, maximum=maximum) 142 | -------------------------------------------------------------------------------- /distrax/_src/distributions/distribution_from_tfp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
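# Illustrative sketch (not part of the original clipped.py source) of the mixed
# discrete/continuous behaviour documented in Clipped.log_prob above: boundary
# values carry the probability mass of the clipped tails, interior values keep
# the base density, and values outside [minimum, maximum] get -inf.
import jax
import jax.numpy as jnp

from distrax._src.distributions.clipped import ClippedNormal

dist = ClippedNormal(loc=0.0, scale=1.0, minimum=-1.0, maximum=1.0)

key = jax.random.PRNGKey(0)
samples = dist.sample(seed=key, sample_shape=(5,))   # all samples lie in [-1, 1]

lp_boundary = dist.log_prob(jnp.array(-1.0))  # log CDF of the base Normal at -1
lp_interior = dist.log_prob(jnp.array(0.0))   # plain Normal log-density at 0
lp_outside = dist.log_prob(jnp.array(2.0))    # -inf, outside the support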
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Wrapper to adapt a TFP distribution.""" 16 | 17 | from typing import Tuple 18 | 19 | import chex 20 | from distrax._src.distributions import distribution 21 | import jax.numpy as jnp 22 | from tensorflow_probability.substrates import jax as tfp 23 | 24 | tfd = tfp.distributions 25 | 26 | Array = chex.Array 27 | PRNGKey = chex.PRNGKey 28 | DistributionT = distribution.DistributionT 29 | EventT = distribution.EventT 30 | 31 | 32 | def distribution_from_tfp(tfp_distribution: tfd.Distribution) -> DistributionT: 33 | """Create a Distrax distribution from a TFP distribution. 34 | 35 | Given a TFP distribution `tfp_distribution`, this method returns a 36 | distribution of a class that inherits from the class of `tfp_distribution`. 37 | The returned distribution behaves almost identically as the TFP distribution, 38 | except the common methods (`sample`, `variance`, etc.) are overwritten to 39 | return `jnp.ndarrays`. Moreover, the wrapped distribution also implements 40 | Distrax methods inherited from `Distribution`, such as `sample_and_log_prob`. 41 | 42 | Args: 43 | tfp_distribution: A TFP distribution. 44 | 45 | Returns: 46 | The wrapped distribution. 47 | """ 48 | 49 | class DistributionFromTFP( 50 | distribution.Distribution, tfp_distribution.__class__): 51 | """Class to wrap a TFP distribution. 52 | 53 | The wrapped class dynamically inherits from `tfp_distribution`, so that 54 | computations involving the KL remain valid. 
55 | """ 56 | 57 | def __init__(self): 58 | pass 59 | 60 | def __getattr__(self, name: str): 61 | return getattr(tfp_distribution, name) 62 | 63 | def sample(self, *a, **k): # pylint: disable=useless-super-delegation 64 | """See `Distribution.sample`.""" 65 | return super().sample(*a, **k) 66 | 67 | def _sample_n(self, key: PRNGKey, n: int): 68 | """See `Distribution._sample_n`.""" 69 | return jnp.asarray( 70 | tfp_distribution.sample(seed=key, sample_shape=(n,)), 71 | dtype=tfp_distribution.dtype) 72 | 73 | def log_prob(self, value: EventT) -> Array: 74 | """See `Distribution.log_prob`.""" 75 | return jnp.asarray(tfp_distribution.log_prob(value)) 76 | 77 | def prob(self, value: EventT) -> Array: 78 | """See `Distribution.prob`.""" 79 | return jnp.asarray(tfp_distribution.prob(value)) 80 | 81 | @property 82 | def event_shape(self) -> Tuple[int, ...]: 83 | """See `Distribution.event_shape`.""" 84 | return tuple(tfp_distribution.event_shape) 85 | 86 | @property 87 | def batch_shape(self) -> Tuple[int, ...]: 88 | """See `Distribution.batch_shape`.""" 89 | return tuple(tfp_distribution.batch_shape) 90 | 91 | @property 92 | def name(self) -> str: 93 | """See `Distribution.name`.""" 94 | return tfp_distribution.name 95 | 96 | @property 97 | def dtype(self) -> jnp.dtype: 98 | """See `Distribution.dtype`.""" 99 | return tfp_distribution.dtype 100 | 101 | def kl_divergence(self, other_dist, *args, **kwargs) -> Array: 102 | """See `Distribution.kl_divergence`.""" 103 | return jnp.asarray( 104 | tfd.kullback_leibler.kl_divergence(self, other_dist, *args, **kwargs)) 105 | 106 | def entropy(self) -> Array: 107 | """See `Distribution.entropy`.""" 108 | return jnp.asarray(tfp_distribution.entropy()) 109 | 110 | def log_cdf(self, value: EventT) -> Array: 111 | """See `Distribution.log_cdf`.""" 112 | return jnp.asarray(tfp_distribution.log_cdf(value)) 113 | 114 | def cdf(self, value: EventT) -> Array: 115 | """See `Distribution.cdf`.""" 116 | return jnp.asarray(tfp_distribution.cdf(value)) 117 | 118 | def mean(self) -> Array: 119 | """See `Distribution.mean`.""" 120 | return jnp.asarray(tfp_distribution.mean()) 121 | 122 | def median(self) -> Array: 123 | """See `Distribution.median`.""" 124 | return jnp.asarray(tfp_distribution.median()) 125 | 126 | def variance(self) -> Array: 127 | """See `Distribution.variance`.""" 128 | return jnp.asarray(tfp_distribution.variance()) 129 | 130 | def stddev(self) -> Array: 131 | """See `Distribution.stddev`.""" 132 | return jnp.asarray(tfp_distribution.stddev()) 133 | 134 | def mode(self) -> Array: 135 | """See `Distribution.mode`.""" 136 | return jnp.asarray(tfp_distribution.mode()) 137 | 138 | def __getitem__(self, index) -> DistributionT: 139 | """See `Distribution.__getitem__`.""" 140 | return distribution_from_tfp(tfp_distribution[index]) 141 | 142 | return DistributionFromTFP() 143 | -------------------------------------------------------------------------------- /distrax/_src/distributions/epsilon_greedy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
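# Illustrative sketch (not part of the original source) of wrapping a TFP
# distribution with `distribution_from_tfp` above: the wrapper keeps the TFP
# class in its MRO but returns jnp.ndarrays and exposes Distrax-only methods
# such as `sample_and_log_prob`; the (samples, log_prob) return order is
# assumed from the Distrax `Distribution` base class.
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

from distrax._src.distributions.distribution_from_tfp import distribution_from_tfp

tfd = tfp.distributions

wrapped = distribution_from_tfp(tfd.Normal(loc=0.0, scale=1.0))

key = jax.random.PRNGKey(0)
samples, log_prob = wrapped.sample_and_log_prob(seed=key, sample_shape=(3,))

# KL against a TFP Normal dispatches through the registered TFP KL machinery,
# because the wrapped class dynamically inherits from tfd.Normal.
kl = wrapped.kl_divergence(tfd.Normal(loc=1.0, scale=2.0))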
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Epsilon-Greedy distributions with respect to a set of preferences.""" 16 | 17 | from typing import Any, Union 18 | 19 | import chex 20 | from distrax._src.distributions import categorical 21 | from distrax._src.distributions import distribution 22 | import jax.numpy as jnp 23 | 24 | 25 | Array = chex.Array 26 | 27 | 28 | def _argmax_with_random_tie_breaking(preferences: Array) -> Array: 29 | """Compute probabilities greedily with respect to a set of preferences.""" 30 | optimal_actions = preferences == preferences.max(axis=-1, keepdims=True) 31 | return optimal_actions / optimal_actions.sum(axis=-1, keepdims=True) 32 | 33 | 34 | def _mix_probs_with_uniform(probs: Array, epsilon: float) -> Array: 35 | """Mix an arbitrary categorical distribution with a uniform distribution.""" 36 | num_actions = probs.shape[-1] 37 | uniform_probs = jnp.ones_like(probs) / num_actions 38 | return (1 - epsilon) * probs + epsilon * uniform_probs 39 | 40 | 41 | class EpsilonGreedy(categorical.Categorical): 42 | """A Categorical that is ε-greedy with respect to some preferences. 43 | 44 | Given a set of unnormalized preferences, the distribution is a mixture 45 | of the Greedy and Uniform distribution; with weight (1-ε) and ε, respectively. 46 | """ 47 | 48 | def __init__(self, 49 | preferences: Array, 50 | epsilon: float, 51 | dtype: Union[jnp.dtype, type[Any]] = int): 52 | """Initializes an EpsilonGreedy distribution. 53 | 54 | Args: 55 | preferences: Unnormalized preferences. 56 | epsilon: Mixing parameter ε. 57 | dtype: The type of event samples. 58 | """ 59 | self._preferences = jnp.asarray(preferences) 60 | self._epsilon = epsilon 61 | greedy_probs = _argmax_with_random_tie_breaking(self._preferences) 62 | probs = _mix_probs_with_uniform(greedy_probs, epsilon) 63 | super().__init__(probs=probs, dtype=dtype) 64 | 65 | @property 66 | def epsilon(self) -> float: 67 | """Mixing parameters of the distribution.""" 68 | return self._epsilon 69 | 70 | @property 71 | def preferences(self) -> Array: 72 | """Unnormalized preferences.""" 73 | return self._preferences 74 | 75 | def __getitem__(self, index) -> 'EpsilonGreedy': 76 | """See `Distribution.__getitem__`.""" 77 | index = distribution.to_batch_shape_index(self.batch_shape, index) 78 | return EpsilonGreedy( 79 | preferences=self.preferences[index], 80 | epsilon=self.epsilon, 81 | dtype=self.dtype) 82 | -------------------------------------------------------------------------------- /distrax/_src/distributions/epsilon_greedy_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
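# Illustrative sketch (not part of the original epsilon_greedy.py source) of how
# EpsilonGreedy above builds its probabilities: the greedy argmax distribution is
# mixed with a uniform distribution using weight epsilon. With two tied maxima
# and epsilon = 0.2 this is 0.8 * [0, .5, 0, .5] + 0.2 * [.25, .25, .25, .25].
import jax.numpy as jnp

from distrax._src.distributions.epsilon_greedy import EpsilonGreedy

preferences = jnp.array([0.0, 4.0, -1.0, 4.0])
dist = EpsilonGreedy(preferences=preferences, epsilon=0.2)

print(dist.probs)           # [0.05, 0.45, 0.05, 0.45]
print(dist.num_categories)  # 4, inherited from Categorical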
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `epsilon_greedy.py`.""" 16 | 17 | import functools 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | 22 | import chex 23 | from distrax._src.distributions import epsilon_greedy 24 | from distrax._src.utils import equivalence 25 | import jax.experimental 26 | import jax.numpy as jnp 27 | import numpy as np 28 | 29 | 30 | class EpsilonGreedyTest(equivalence.EquivalenceTest): 31 | 32 | def setUp(self): 33 | super().setUp() 34 | self._init_distr_cls(epsilon_greedy.EpsilonGreedy) 35 | self.epsilon = 0.2 36 | self.preferences = jnp.array([0., 4., -1., 4.]) 37 | 38 | def test_parameters_from_preferences(self): 39 | dist = self.distrax_cls(preferences=self.preferences, epsilon=self.epsilon) 40 | expected_probs = jnp.array([0.05, 0.45, 0.05, 0.45]) 41 | self.assertion_fn(rtol=2e-3)(dist.logits, jnp.log(expected_probs)) 42 | self.assertion_fn(rtol=2e-3)(dist.probs, expected_probs) 43 | 44 | def test_num_categories(self): 45 | dist = self.distrax_cls(preferences=self.preferences, epsilon=self.epsilon) 46 | np.testing.assert_equal(dist.num_categories, len(self.preferences)) 47 | 48 | @chex.all_variants 49 | @parameterized.named_parameters( 50 | ('int32', jnp.int32), 51 | ('int64', jnp.int64), 52 | ('float32', jnp.float32), 53 | ('float64', jnp.float64)) 54 | def test_sample_dtype(self, dtype): 55 | with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): 56 | dist = self.distrax_cls( 57 | preferences=self.preferences, epsilon=self.epsilon, dtype=dtype) 58 | samples = self.variant(dist.sample)(seed=self.key) 59 | self.assertEqual(samples.dtype, dist.dtype) 60 | chex.assert_type(samples, dtype) 61 | 62 | def test_jittable(self): 63 | super()._test_jittable( 64 | dist_args=(np.array([0., 4., -1., 4.]), 0.1), 65 | assertion_fn=functools.partial(np.testing.assert_allclose, rtol=1e-5)) 66 | 67 | @parameterized.named_parameters( 68 | ('single element', 2), 69 | ('range', slice(-1)), 70 | ('range_2', (slice(None), slice(-1))), 71 | ) 72 | def test_slice(self, slice_): 73 | preferences = np.abs(np.random.randn(3, 4, 5)) 74 | dtype = jnp.float32 75 | dist = self.distrax_cls(preferences, self.epsilon, dtype=dtype) 76 | dist_sliced = dist[slice_] 77 | self.assertIsInstance(dist_sliced, epsilon_greedy.EpsilonGreedy) 78 | self.assertion_fn(rtol=2e-3)(dist_sliced.preferences, preferences[slice_]) 79 | self.assertion_fn(rtol=2e-3)(dist_sliced.epsilon, self.epsilon) 80 | self.assertEqual(dist_sliced.dtype, dtype) 81 | 82 | def test_slice_ellipsis(self): 83 | preferences = np.abs(np.random.randn(3, 4, 5)) 84 | dtype = jnp.float32 85 | dist = self.distrax_cls(preferences, self.epsilon, dtype=dtype) 86 | dist_sliced = dist[..., -1] 87 | self.assertIsInstance(dist_sliced, epsilon_greedy.EpsilonGreedy) 88 | self.assertion_fn(rtol=2e-3)(dist_sliced.preferences, preferences[:, -1]) 89 | self.assertion_fn(rtol=2e-3)(dist_sliced.epsilon, self.epsilon) 90 | self.assertEqual(dist_sliced.dtype, dtype) 91 | 92 | 93 | if __name__ == '__main__': 
94 | absltest.main() 95 | -------------------------------------------------------------------------------- /distrax/_src/distributions/gamma.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Gamma distribution.""" 16 | 17 | from typing import Tuple, Union 18 | 19 | import chex 20 | from distrax._src.distributions import distribution 21 | from distrax._src.utils import conversion 22 | import jax 23 | import jax.numpy as jnp 24 | from tensorflow_probability.substrates import jax as tfp 25 | 26 | tfd = tfp.distributions 27 | 28 | Array = chex.Array 29 | Numeric = chex.Numeric 30 | PRNGKey = chex.PRNGKey 31 | EventT = distribution.EventT 32 | 33 | 34 | class Gamma(distribution.Distribution): 35 | """Gamma distribution with parameters `concentration` and `rate`.""" 36 | 37 | equiv_tfp_cls = tfd.Gamma 38 | 39 | def __init__(self, concentration: Numeric, rate: Numeric): 40 | """Initializes a Gamma distribution. 41 | 42 | Args: 43 | concentration: Concentration parameter of the distribution. 44 | rate: Inverse scale params of the distribution. 
45 | """ 46 | super().__init__() 47 | self._concentration = conversion.as_float_array(concentration) 48 | self._rate = conversion.as_float_array(rate) 49 | self._batch_shape = jax.lax.broadcast_shapes( 50 | self._concentration.shape, self._rate.shape) 51 | 52 | @property 53 | def event_shape(self) -> Tuple[int, ...]: 54 | """Shape of event of distribution samples.""" 55 | return () 56 | 57 | @property 58 | def batch_shape(self) -> Tuple[int, ...]: 59 | """Shape of batch of distribution samples.""" 60 | return self._batch_shape 61 | 62 | @property 63 | def concentration(self) -> Array: 64 | """Concentration of the distribution.""" 65 | return jnp.broadcast_to(self._concentration, self.batch_shape) 66 | 67 | @property 68 | def rate(self) -> Array: 69 | """Inverse scale of the distribution.""" 70 | return jnp.broadcast_to(self._rate, self.batch_shape) 71 | 72 | def _sample_from_std_gamma(self, key: PRNGKey, n: int) -> Array: 73 | out_shape = (n,) + self.batch_shape 74 | dtype = jnp.result_type(self._concentration, self._rate) 75 | return jax.random.gamma( 76 | key, a=self._concentration, shape=out_shape, dtype=dtype 77 | ) 78 | 79 | def _sample_n(self, key: PRNGKey, n: int) -> Array: 80 | """See `Distribution._sample_n`.""" 81 | rnd = self._sample_from_std_gamma(key, n) 82 | return rnd / self._rate 83 | 84 | def log_prob(self, value: EventT) -> Array: 85 | """See `Distribution.log_prob`.""" 86 | return ( 87 | self._concentration * jnp.log(self._rate) 88 | + (self._concentration - 1) * jnp.log(value) 89 | - self._rate * value 90 | - jax.lax.lgamma(self._concentration) 91 | ) 92 | 93 | def entropy(self) -> Array: 94 | """Calculates the Shannon entropy (in nats).""" 95 | log_rate = jnp.log(self._rate) 96 | return ( 97 | self._concentration 98 | - log_rate 99 | + jax.lax.lgamma(self._concentration) 100 | + (1.0 - self._concentration) * jax.lax.digamma(self._concentration) 101 | ) 102 | 103 | def cdf(self, value: EventT) -> Array: 104 | """See `Distribution.cdf`.""" 105 | return jax.lax.igamma(self._concentration, self._rate * value) 106 | 107 | def log_cdf(self, value: EventT) -> Array: 108 | """See `Distribution.log_cdf`.""" 109 | return jnp.log(self.cdf(value)) 110 | 111 | def mean(self) -> Array: 112 | """Calculates the mean.""" 113 | return self._concentration / self._rate 114 | 115 | def stddev(self) -> Array: 116 | """Calculates the standard deviation.""" 117 | return jnp.sqrt(self._concentration) / self._rate 118 | 119 | def variance(self) -> Array: 120 | """Calculates the variance.""" 121 | return self._concentration / jnp.square(self._rate) 122 | 123 | def mode(self) -> Array: 124 | """Calculates the mode.""" 125 | mode = (self._concentration - 1.0) / self._rate 126 | return jnp.where(self._concentration >= 1.0, mode, jnp.nan) 127 | 128 | def __getitem__(self, index) -> 'Gamma': 129 | """See `Distribution.__getitem__`.""" 130 | index = distribution.to_batch_shape_index(self.batch_shape, index) 131 | return Gamma( 132 | concentration=self.concentration[index], rate=self.rate[index]) 133 | 134 | 135 | def _kl_divergence_gamma_gamma( 136 | dist1: Union[Gamma, tfd.Gamma], 137 | dist2: Union[Gamma, tfd.Gamma], 138 | *unused_args, 139 | **unused_kwargs, 140 | ) -> Array: 141 | """Batched KL divergence KL(dist1 || dist2) between two Gamma distributions. 142 | 143 | Args: 144 | dist1: A Gamma distribution. 145 | dist2: A Gamma distribution. 146 | 147 | Returns: 148 | Batchwise `KL(dist1 || dist2)`. 
149 | """ 150 | t1 = dist2.concentration * (jnp.log(dist1.rate) - jnp.log(dist2.rate)) 151 | t2 = jax.lax.lgamma(dist2.concentration) - jax.lax.lgamma(dist1.concentration) 152 | t3 = (dist1.concentration - dist2.concentration) * jax.lax.digamma( 153 | dist1.concentration) 154 | t4 = (dist2.rate - dist1.rate) * (dist1.concentration / dist1.rate) 155 | return t1 + t2 + t3 + t4 156 | 157 | 158 | # Register the KL functions with TFP. 159 | tfd.RegisterKL(Gamma, Gamma)(_kl_divergence_gamma_gamma) 160 | tfd.RegisterKL(Gamma, Gamma.equiv_tfp_cls)(_kl_divergence_gamma_gamma) 161 | tfd.RegisterKL(Gamma.equiv_tfp_cls, Gamma)(_kl_divergence_gamma_gamma) 162 | -------------------------------------------------------------------------------- /distrax/_src/distributions/greedy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Greedy distributions with respect to a set of preferences.""" 16 | 17 | from typing import Any, Union 18 | 19 | import chex 20 | from distrax._src.distributions import categorical 21 | from distrax._src.distributions import distribution 22 | import jax.numpy as jnp 23 | 24 | 25 | Array = chex.Array 26 | 27 | 28 | def _argmax_with_random_tie_breaking(preferences: Array) -> Array: 29 | """Compute probabilities greedily with respect to a set of preferences.""" 30 | optimal_actions = preferences == preferences.max(axis=-1, keepdims=True) 31 | return optimal_actions / optimal_actions.sum(axis=-1, keepdims=True) 32 | 33 | 34 | class Greedy(categorical.Categorical): 35 | """A Categorical distribution that is greedy with respect to some preferences. 36 | 37 | Given a set of unnormalized preferences, the probability mass is distributed 38 | equally among all indices `i` such that `preferences[i] = max(preferences)`, 39 | all other indices will be assigned a probability of zero. 40 | """ 41 | 42 | def __init__( 43 | self, preferences: Array, dtype: Union[jnp.dtype, type[Any]] = int 44 | ): 45 | """Initializes a Greedy distribution. 46 | 47 | Args: 48 | preferences: Unnormalized preferences. 49 | dtype: The type of event samples. 
50 | """ 51 | self._preferences = jnp.asarray(preferences) 52 | probs = _argmax_with_random_tie_breaking(self._preferences) 53 | super().__init__(probs=probs, dtype=dtype) 54 | 55 | @property 56 | def preferences(self) -> Array: 57 | """Unnormalized preferences.""" 58 | return self._preferences 59 | 60 | def __getitem__(self, index) -> 'Greedy': 61 | """See `Distribution.__getitem__`.""" 62 | index = distribution.to_batch_shape_index(self.batch_shape, index) 63 | return Greedy( 64 | preferences=self.preferences[index], dtype=self.dtype) 65 | -------------------------------------------------------------------------------- /distrax/_src/distributions/greedy_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `greedy.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | import chex 21 | from distrax._src.distributions import greedy 22 | from distrax._src.utils import equivalence 23 | import jax.experimental 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | 28 | class GreedyTest(equivalence.EquivalenceTest): 29 | 30 | def setUp(self): 31 | super().setUp() 32 | self._init_distr_cls(greedy.Greedy) 33 | self.preferences = jnp.array([0., 4., -1., 4.]) 34 | 35 | def test_parameters_from_preferences(self): 36 | dist = self.distrax_cls(preferences=self.preferences) 37 | expected_probs = jnp.array([0., 0.5, 0., 0.5]) 38 | self.assertion_fn(rtol=2e-3)(dist.logits, jnp.log(expected_probs)) 39 | self.assertion_fn(rtol=2e-3)(dist.probs, expected_probs) 40 | 41 | def test_num_categories(self): 42 | dist = self.distrax_cls(preferences=self.preferences) 43 | np.testing.assert_equal(dist.num_categories, len(self.preferences)) 44 | 45 | @chex.all_variants 46 | @parameterized.named_parameters( 47 | ('int32', jnp.int32), 48 | ('int64', jnp.int64), 49 | ('float32', jnp.float32), 50 | ('float64', jnp.float64)) 51 | def test_sample_dtype(self, dtype): 52 | with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): 53 | dist = self.distrax_cls(preferences=self.preferences, dtype=dtype) 54 | samples = self.variant(dist.sample)(seed=self.key) 55 | self.assertEqual(samples.dtype, dist.dtype) 56 | chex.assert_type(samples, dtype) 57 | 58 | def test_jittable(self): 59 | super()._test_jittable((np.array([0., 4., -1., 4.]),)) 60 | 61 | @parameterized.named_parameters( 62 | ('single element', 2), 63 | ('range', slice(-1)), 64 | ('range_2', (slice(None), slice(-1))), 65 | ) 66 | def test_slice(self, slice_): 67 | preferences = np.abs(np.random.randn(3, 4, 5)) 68 | dtype = jnp.float32 69 | dist = self.distrax_cls(preferences, dtype=dtype) 70 | dist_sliced = dist[slice_] 71 | self.assertIsInstance(dist_sliced, greedy.Greedy) 72 | 
self.assertion_fn(rtol=2e-3)(dist_sliced.preferences, preferences[slice_]) 73 | self.assertEqual(dist_sliced.dtype, dtype) 74 | 75 | def test_slice_ellipsis(self): 76 | preferences = np.abs(np.random.randn(3, 4, 5)) 77 | dtype = jnp.float32 78 | dist = self.distrax_cls(preferences, dtype=dtype) 79 | dist_sliced = dist[..., -1] 80 | self.assertIsInstance(dist_sliced, greedy.Greedy) 81 | self.assertion_fn(rtol=2e-3)(dist_sliced.preferences, preferences[:, -1]) 82 | self.assertEqual(dist_sliced.dtype, dtype) 83 | 84 | 85 | if __name__ == '__main__': 86 | absltest.main() 87 | -------------------------------------------------------------------------------- /distrax/_src/distributions/gumbel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Gumbel distribution.""" 16 | 17 | import math 18 | from typing import Tuple, Union 19 | 20 | import chex 21 | from distrax._src.distributions import distribution 22 | from distrax._src.utils import conversion 23 | import jax 24 | import jax.numpy as jnp 25 | from tensorflow_probability.substrates import jax as tfp 26 | 27 | tfd = tfp.distributions 28 | 29 | Array = chex.Array 30 | Numeric = chex.Numeric 31 | PRNGKey = chex.PRNGKey 32 | EventT = distribution.EventT 33 | 34 | 35 | class Gumbel(distribution.Distribution): 36 | """Gumbel distribution with location `loc` and `scale` parameters.""" 37 | 38 | equiv_tfp_cls = tfd.Gumbel 39 | 40 | def __init__(self, loc: Numeric, scale: Numeric): 41 | """Initializes a Gumbel distribution. 42 | 43 | Args: 44 | loc: Mean of the distribution. 45 | scale: Spread of the distribution. 
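
    Example (an illustrative sketch; the numeric values are arbitrary):

      dist = Gumbel(loc=0., scale=1.)
      samples = dist.sample(seed=jax.random.PRNGKey(0), sample_shape=(5,))
      log_probs = dist.log_prob(samples)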
46 | """ 47 | super().__init__() 48 | self._loc = conversion.as_float_array(loc) 49 | self._scale = conversion.as_float_array(scale) 50 | self._batch_shape = jax.lax.broadcast_shapes( 51 | self._loc.shape, self._scale.shape) 52 | 53 | @property 54 | def event_shape(self) -> Tuple[int, ...]: 55 | """Shape of event of distribution samples.""" 56 | return () 57 | 58 | @property 59 | def batch_shape(self) -> Tuple[int, ...]: 60 | """Shape of batch of distribution samples.""" 61 | return self._batch_shape 62 | 63 | @property 64 | def loc(self) -> Array: 65 | """Mean of the distribution.""" 66 | return jnp.broadcast_to(self._loc, self.batch_shape) 67 | 68 | @property 69 | def scale(self) -> Array: 70 | """Scale of the distribution.""" 71 | return jnp.broadcast_to(self._scale, self.batch_shape) 72 | 73 | def _standardize(self, value: Array) -> Array: 74 | """Standardizes the input `value` in location and scale.""" 75 | return (value - self._loc) / self._scale 76 | 77 | def log_prob(self, value: EventT) -> Array: 78 | """See `Distribution.log_prob`.""" 79 | z = self._standardize(value) 80 | return -(z + jnp.exp(-z)) - jnp.log(self._scale) 81 | 82 | def _sample_from_std_gumbel(self, key: PRNGKey, n: int) -> Array: 83 | out_shape = (n,) + self.batch_shape 84 | dtype = jnp.result_type(self._loc, self._scale) 85 | return jax.random.gumbel(key, shape=out_shape, dtype=dtype) 86 | 87 | def _sample_n(self, key: PRNGKey, n: int) -> Array: 88 | """See `Distribution._sample_n`.""" 89 | rnd = self._sample_from_std_gumbel(key, n) 90 | return self._scale * rnd + self._loc 91 | 92 | def _sample_n_and_log_prob(self, key: PRNGKey, n: int) -> Tuple[Array, Array]: 93 | """See `Distribution._sample_n_and_log_prob`.""" 94 | rnd = self._sample_from_std_gumbel(key, n) 95 | samples = self._scale * rnd + self._loc 96 | log_prob = -(rnd + jnp.exp(-rnd)) - jnp.log(self._scale) 97 | return samples, log_prob 98 | 99 | def entropy(self) -> Array: 100 | """Calculates the Shannon entropy (in nats).""" 101 | return jnp.log(self._scale) + 1. + jnp.euler_gamma 102 | 103 | def log_cdf(self, value: EventT) -> Array: 104 | """See `Distribution.log_cdf`.""" 105 | z = self._standardize(value) 106 | return -jnp.exp(-z) 107 | 108 | def mean(self) -> Array: 109 | """Calculates the mean.""" 110 | return self._loc + self._scale * jnp.euler_gamma 111 | 112 | def stddev(self) -> Array: 113 | """Calculates the standard deviation.""" 114 | return self._scale * jnp.ones_like(self._loc) * jnp.pi / math.sqrt(6.) 115 | 116 | def variance(self) -> Array: 117 | """Calculates the variance.""" 118 | return jnp.square(self._scale * jnp.ones_like(self._loc) * jnp.pi) / 6. 119 | 120 | def mode(self) -> Array: 121 | """Calculates the mode.""" 122 | return self.loc 123 | 124 | def median(self) -> Array: 125 | """Calculates the median.""" 126 | return self._loc - self._scale * math.log(math.log(2.)) 127 | 128 | def __getitem__(self, index) -> 'Gumbel': 129 | """See `Distribution.__getitem__`.""" 130 | index = distribution.to_batch_shape_index(self.batch_shape, index) 131 | return Gumbel(loc=self.loc[index], scale=self.scale[index]) 132 | 133 | 134 | def _kl_divergence_gumbel_gumbel( 135 | dist1: Union[Gumbel, tfd.Gumbel], 136 | dist2: Union[Gumbel, tfd.Gumbel], 137 | *unused_args, **unused_kwargs, 138 | ) -> Array: 139 | """Batched KL divergence KL(dist1 || dist2) between two Gumbel distributions. 140 | 141 | Args: 142 | dist1: A Gumbel distribution. 143 | dist2: A Gumbel distribution. 144 | 145 | Returns: 146 | Batchwise `KL(dist1 || dist2)`. 
147 | """ 148 | return (jnp.log(dist2.scale) - jnp.log(dist1.scale) + jnp.euler_gamma * 149 | (dist1.scale / dist2.scale - 1.) + 150 | jnp.expm1((dist2.loc - dist1.loc) / dist2.scale + 151 | jax.lax.lgamma(dist1.scale / dist2.scale + 1.)) + 152 | (dist1.loc - dist2.loc) / dist2.scale) 153 | 154 | 155 | # Register the KL functions with TFP. 156 | tfd.RegisterKL(Gumbel, Gumbel)(_kl_divergence_gumbel_gumbel) 157 | tfd.RegisterKL(Gumbel, Gumbel.equiv_tfp_cls)(_kl_divergence_gumbel_gumbel) 158 | tfd.RegisterKL(Gumbel.equiv_tfp_cls, Gumbel)(_kl_divergence_gumbel_gumbel) 159 | -------------------------------------------------------------------------------- /distrax/_src/distributions/log_stddev_normal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """LogStddevNormal distribution.""" 16 | 17 | import math 18 | from typing import Optional 19 | 20 | import chex 21 | from distrax._src.distributions import distribution 22 | from distrax._src.distributions import normal 23 | from distrax._src.utils import conversion 24 | import jax 25 | import jax.numpy as jnp 26 | from tensorflow_probability.substrates import jax as tfp 27 | 28 | tfd = tfp.distributions 29 | 30 | Array = chex.Array 31 | Numeric = chex.Numeric 32 | 33 | 34 | class LogStddevNormal(normal.Normal): 35 | """Normal distribution with `log_scale` parameter. 36 | 37 | The `LogStddevNormal` has three parameters: `loc`, `log_scale`, and 38 | (optionally) `max_scale`. The distribution is a univariate normal 39 | distribution with mean equal to `loc` and scale parameter (i.e., stddev) equal 40 | to `exp(log_scale)` if `max_scale` is None. If `max_scale` is not None, a soft 41 | thresholding is applied to obtain the scale parameter of the normal, so that 42 | its log is given by `log(max_scale) - softplus(log(max_scale) - log_scale)`. 43 | """ 44 | 45 | def __init__(self, 46 | loc: Numeric, 47 | log_scale: Numeric, 48 | max_scale: Optional[float] = None): 49 | """Initializes a LogStddevNormal distribution. 50 | 51 | Args: 52 | loc: Mean of the distribution. 53 | log_scale: Log of the distribution's scale (before the soft thresholding 54 | applied when `max_scale` is not None). 55 | max_scale: Maximum value of the scale that this distribution will saturate 56 | at. This parameter can be useful for numerical stability. It is not a 57 | hard maximum; rather, we compute `log(scale)` as per the formula: 58 | `log(max_scale) - softplus(log(max_scale) - log_scale)`. 
59 | """ 60 | self._max_scale = max_scale 61 | if max_scale is not None: 62 | max_log_scale = math.log(max_scale) 63 | self._log_scale = max_log_scale - jax.nn.softplus( 64 | max_log_scale - conversion.as_float_array(log_scale)) 65 | else: 66 | self._log_scale = conversion.as_float_array(log_scale) 67 | scale = jnp.exp(self._log_scale) 68 | super().__init__(loc, scale) 69 | 70 | @property 71 | def log_scale(self) -> Array: 72 | """The log standard deviation (after thresholding, if applicable).""" 73 | return jnp.broadcast_to(self._log_scale, self.batch_shape) 74 | 75 | def __getitem__(self, index) -> 'LogStddevNormal': 76 | """See `Distribution.__getitem__`.""" 77 | index = distribution.to_batch_shape_index(self.batch_shape, index) 78 | return LogStddevNormal( 79 | loc=self.loc[index], 80 | log_scale=self.log_scale[index], 81 | max_scale=self._max_scale) 82 | 83 | 84 | def _kl_logstddevnormal_logstddevnormal( 85 | dist1: LogStddevNormal, dist2: LogStddevNormal, 86 | *unused_args, **unused_kwargs) -> Array: 87 | """Calculates the batched KL divergence between two LogStddevNormal's. 88 | 89 | Args: 90 | dist1: A LogStddevNormal distribution. 91 | dist2: A LogStddevNormal distribution. 92 | 93 | Returns: 94 | Batchwise KL(dist1 || dist2). 95 | """ 96 | # KL[N(u_a, s_a^2) || N(u_b, s_b^2)] between two Gaussians: 97 | # (s_a^2 + (u_a - u_b)^2)/(2*s_b^2) + log(s_b) - log(s_a) - 1/2. 98 | variance1 = jnp.square(dist1.scale) 99 | variance2 = jnp.square(dist2.scale) 100 | return ((variance1 + jnp.square(dist1.loc - dist2.loc)) / (2.0 * variance2) + 101 | dist2.log_scale - dist1.log_scale - 0.5) 102 | 103 | 104 | # Register the KL function. 105 | tfd.RegisterKL(LogStddevNormal, LogStddevNormal)( 106 | _kl_logstddevnormal_logstddevnormal) 107 | -------------------------------------------------------------------------------- /distrax/_src/distributions/logistic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Logistic distribution.""" 16 | 17 | from typing import Tuple 18 | 19 | import chex 20 | from distrax._src.distributions import distribution 21 | from distrax._src.utils import conversion 22 | import jax 23 | import jax.numpy as jnp 24 | from tensorflow_probability.substrates import jax as tfp 25 | 26 | tfd = tfp.distributions 27 | 28 | Array = chex.Array 29 | Numeric = chex.Numeric 30 | PRNGKey = chex.PRNGKey 31 | EventT = distribution.EventT 32 | 33 | 34 | class Logistic(distribution.Distribution): 35 | """The Logistic distribution with location `loc` and `scale` parameters.""" 36 | 37 | equiv_tfp_cls = tfd.Logistic 38 | 39 | def __init__(self, loc: Numeric, scale: Numeric) -> None: 40 | """Initializes a Logistic distribution. 41 | 42 | Args: 43 | loc: Mean of the distribution. 
44 | scale: Spread of the distribution. 45 | """ 46 | super().__init__() 47 | self._loc = conversion.as_float_array(loc) 48 | self._scale = conversion.as_float_array(scale) 49 | self._batch_shape = jax.lax.broadcast_shapes( 50 | self._loc.shape, self._scale.shape) 51 | 52 | @property 53 | def event_shape(self) -> Tuple[int, ...]: 54 | """Shape of event of distribution samples.""" 55 | return () 56 | 57 | @property 58 | def batch_shape(self) -> Tuple[int, ...]: 59 | """Shape of batch of distribution samples.""" 60 | return self._batch_shape 61 | 62 | @property 63 | def loc(self) -> Array: 64 | """Mean of the distribution.""" 65 | return jnp.broadcast_to(self._loc, self.batch_shape) 66 | 67 | @property 68 | def scale(self) -> Array: 69 | """Spread of the distribution.""" 70 | return jnp.broadcast_to(self._scale, self.batch_shape) 71 | 72 | def _standardize(self, x: Array) -> Array: 73 | return (x - self.loc) / self.scale 74 | 75 | def _sample_n(self, key: PRNGKey, n: int) -> Array: 76 | """See `Distribution._sample_n`.""" 77 | out_shape = (n,) + self.batch_shape 78 | dtype = jnp.result_type(self._loc, self._scale) 79 | uniform = jax.random.uniform( 80 | key, 81 | shape=out_shape, 82 | dtype=dtype, 83 | minval=jnp.finfo(dtype).tiny, 84 | maxval=1.) 85 | rnd = jnp.log(uniform) - jnp.log1p(-uniform) 86 | return self._scale * rnd + self._loc 87 | 88 | def log_prob(self, value: EventT) -> Array: 89 | """See `Distribution.log_prob`.""" 90 | z = self._standardize(value) 91 | return -z - 2. * jax.nn.softplus(-z) - jnp.log(self._scale) 92 | 93 | def entropy(self) -> Array: 94 | """Calculates the Shannon entropy (in Nats).""" 95 | return 2. + jnp.broadcast_to(jnp.log(self._scale), self.batch_shape) 96 | 97 | def cdf(self, value: EventT) -> Array: 98 | """See `Distribution.cdf`.""" 99 | return jax.nn.sigmoid(self._standardize(value)) 100 | 101 | def log_cdf(self, value: EventT) -> Array: 102 | """See `Distribution.log_cdf`.""" 103 | return -jax.nn.softplus(-self._standardize(value)) 104 | 105 | def survival_function(self, value: EventT) -> Array: 106 | """See `Distribution.survival_function`.""" 107 | return jax.nn.sigmoid(-self._standardize(value)) 108 | 109 | def log_survival_function(self, value: EventT) -> Array: 110 | """See `Distribution.log_survival_function`.""" 111 | return -jax.nn.softplus(self._standardize(value)) 112 | 113 | def mean(self) -> Array: 114 | """Calculates the mean.""" 115 | return self.loc 116 | 117 | def variance(self) -> Array: 118 | """Calculates the variance.""" 119 | return jnp.square(self.scale * jnp.pi) / 3. 120 | 121 | def stddev(self) -> Array: 122 | """Calculates the standard deviation.""" 123 | return self.scale * jnp.pi / jnp.sqrt(3.) 124 | 125 | def mode(self) -> Array: 126 | """Calculates the mode.""" 127 | return self.mean() 128 | 129 | def median(self) -> Array: 130 | """Calculates the median.""" 131 | return self.mean() 132 | 133 | def __getitem__(self, index) -> 'Logistic': 134 | """See `Distribution.__getitem__`.""" 135 | index = distribution.to_batch_shape_index(self.batch_shape, index) 136 | return Logistic(loc=self.loc[index], scale=self.scale[index]) 137 | -------------------------------------------------------------------------------- /distrax/_src/distributions/mixture_of_two.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A simple mixture of two (possibly heterogeneous) distributions."""
16 | 
17 | from typing import Tuple
18 | 
19 | import chex
20 | from distrax._src.distributions import distribution as base_distribution
21 | from distrax._src.utils import conversion
22 | import jax
23 | import jax.numpy as jnp
24 | from tensorflow_probability.substrates import jax as tfp
25 | 
26 | 
27 | tfd = tfp.distributions
28 | Array = chex.Array
29 | Numeric = chex.Numeric
30 | PRNGKey = chex.PRNGKey
31 | DistributionLike = base_distribution.DistributionLike
32 | EventT = base_distribution.EventT
33 | 
34 | 
35 | class MixtureOfTwo(base_distribution.Distribution):
36 | """A mixture of two distributions."""
37 | 
38 | def __init__(
39 | self,
40 | prob_a: Numeric,
41 | component_a: DistributionLike,
42 | component_b: DistributionLike):
43 | """Creates a mixture of two distributions.
44 | 
45 | Unlike `MixtureSameFamily`, the component distributions are allowed to
46 | belong to different families.
47 | 
48 | Args:
49 | prob_a: a scalar weight for `component_a`, given as a float or a
50 | rank-0 array.
51 | component_a: the first component distribution.
52 | component_b: the second component distribution.
53 | """
54 | super().__init__()
55 | # Validate inputs.
56 | chex.assert_rank(prob_a, 0)
57 | if component_a.event_shape != component_b.event_shape:
58 | raise ValueError(
59 | f'The component distributions must have the same event shape, but '
60 | f'{component_a.event_shape} != {component_b.event_shape}.')
61 | if component_a.batch_shape != component_b.batch_shape:
62 | raise ValueError(
63 | f'The component distributions must have the same batch shape, but '
64 | f'{component_a.batch_shape} != {component_b.batch_shape}.')
65 | if component_a.dtype != component_b.dtype:
66 | raise ValueError(
67 | 'The component distributions must have the same dtype, but'
68 | f' {component_a.dtype} != {component_b.dtype}.')
69 | # Store args.
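    # Note: `conversion.as_distribution` wraps each component into the Distrax
    # `Distribution` interface; since the components are typed as
    # `DistributionLike`, TFP distributions are accepted here as well.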
70 | self._prob_a = prob_a 71 | self._component_a = conversion.as_distribution(component_a) 72 | self._component_b = conversion.as_distribution(component_b) 73 | 74 | def _sample_n(self, key: PRNGKey, n: int) -> Array: 75 | """See `Distribution._sample_n`.""" 76 | key, key_a, key_b, mask_key = jax.random.split(key, num=4) 77 | mask_from_a = jax.random.bernoulli(mask_key, p=self._prob_a, shape=[n]) 78 | sample_a = self._component_a.sample(seed=key_a, sample_shape=n) 79 | sample_b = self._component_b.sample(seed=key_b, sample_shape=n) 80 | mask_from_a = jnp.expand_dims(mask_from_a, tuple(range(1, sample_a.ndim))) 81 | return jnp.where(mask_from_a, sample_a, sample_b) 82 | 83 | def log_prob(self, value: EventT) -> Array: 84 | """See `Distribution.log_prob`.""" 85 | logp1 = jnp.log(self._prob_a) + self._component_a.log_prob(value) 86 | logp2 = jnp.log(1 - self._prob_a) + self._component_b.log_prob(value) 87 | return jnp.logaddexp(logp1, logp2) 88 | 89 | @property 90 | def event_shape(self) -> Tuple[int, ...]: 91 | return self._component_a.event_shape 92 | 93 | @property 94 | def batch_shape(self) -> Tuple[int, ...]: 95 | return self._component_a.batch_shape 96 | 97 | @property 98 | def prob_a(self) -> Numeric: 99 | return self._prob_a 100 | 101 | @property 102 | def prob_b(self) -> Numeric: 103 | return 1. - self._prob_a 104 | 105 | def __getitem__(self, index) -> 'MixtureOfTwo': 106 | """See `Distribution.__getitem__`.""" 107 | index = base_distribution.to_batch_shape_index(self.batch_shape, index) 108 | return MixtureOfTwo( 109 | prob_a=self.prob_a, 110 | component_a=self._component_a[index], 111 | component_b=self._component_b[index]) 112 | -------------------------------------------------------------------------------- /distrax/_src/distributions/mvn_diag.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | """MultivariateNormalDiag distribution."""
16 | 
17 | from typing import Optional
18 | 
19 | import chex
20 | from distrax._src.bijectors.diag_linear import DiagLinear
21 | from distrax._src.distributions import distribution
22 | from distrax._src.distributions.mvn_from_bijector import MultivariateNormalFromBijector
23 | from distrax._src.utils import conversion
24 | import jax
25 | import jax.numpy as jnp
26 | from tensorflow_probability.substrates import jax as tfp
27 | 
28 | tfd = tfp.distributions
29 | 
30 | Array = chex.Array
31 | EventT = distribution.EventT
32 | 
33 | 
34 | def _check_parameters(
35 | loc: Optional[Array], scale_diag: Optional[Array]) -> None:
36 | """Checks that the `loc` and `scale_diag` parameters are correct."""
37 | chex.assert_not_both_none(loc, scale_diag)
38 | if scale_diag is not None and not scale_diag.shape:
39 | raise ValueError('If provided, argument `scale_diag` must have at least '
40 | '1 dimension.')
41 | if loc is not None and not loc.shape:
42 | raise ValueError('If provided, argument `loc` must have at least '
43 | '1 dimension.')
44 | if loc is not None and scale_diag is not None and (
45 | loc.shape[-1] != scale_diag.shape[-1]):
46 | raise ValueError(f'The last dimension of arguments `loc` and '
47 | f'`scale_diag` must coincide, but {loc.shape[-1]} != '
48 | f'{scale_diag.shape[-1]}.')
49 | 
50 | 
51 | class MultivariateNormalDiag(MultivariateNormalFromBijector):
52 | """Multivariate normal distribution on `R^k` with diagonal covariance."""
53 | 
54 | equiv_tfp_cls = tfd.MultivariateNormalDiag
55 | 
56 | def __init__(self,
57 | loc: Optional[Array] = None,
58 | scale_diag: Optional[Array] = None):
59 | """Initializes a MultivariateNormalDiag distribution.
60 | 
61 | Args:
62 | loc: Mean vector of the distribution. Can also be a batch of vectors. If
63 | not specified, it defaults to zeros. At least one of `loc` and
64 | `scale_diag` must be specified.
65 | scale_diag: Vector of standard deviations. Can also be a batch of vectors.
66 | If not specified, it defaults to ones. At least one of `loc` and
67 | `scale_diag` must be specified.
68 | """
69 | _check_parameters(loc, scale_diag)
70 | 
71 | if scale_diag is None:
72 | loc = conversion.as_float_array(loc)
73 | scale_diag = jnp.ones(loc.shape[-1], loc.dtype)
74 | elif loc is None:
75 | scale_diag = conversion.as_float_array(scale_diag)
76 | loc = jnp.zeros(scale_diag.shape[-1], scale_diag.dtype)
77 | else:
78 | loc = conversion.as_float_array(loc)
79 | scale_diag = conversion.as_float_array(scale_diag)
80 | 
81 | # Add leading dimensions to the parameters to match the batch shape. This
82 | # prevents automatic rank promotion.
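    # Illustrative example: with loc.shape == (4,) and scale_diag.shape ==
    # (3, 4), the broadcast shape is (3, 4), so loc is expanded below to
    # shape (1, 4) while scale_diag is left as (3, 4).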
83 | broadcasted_shapes = jnp.broadcast_shapes(loc.shape, scale_diag.shape) 84 | loc = jnp.expand_dims( 85 | loc, axis=list(range(len(broadcasted_shapes) - loc.ndim))) 86 | scale_diag = jnp.expand_dims( 87 | scale_diag, axis=list(range(len(broadcasted_shapes) - scale_diag.ndim))) 88 | 89 | bias = jnp.zeros_like(loc, shape=loc.shape[-1:]) 90 | bias = jnp.expand_dims( 91 | bias, axis=list(range(len(broadcasted_shapes) - bias.ndim))) 92 | scale = DiagLinear(scale_diag) 93 | super().__init__(loc=loc, scale=scale) 94 | self._scale_diag = scale_diag 95 | 96 | @property 97 | def scale_diag(self) -> Array: 98 | """Scale of the distribution.""" 99 | return jnp.broadcast_to( 100 | self._scale_diag, self.batch_shape + self.event_shape) 101 | 102 | def _standardize(self, value: Array) -> Array: 103 | return (value - self._loc) / self._scale_diag 104 | 105 | def cdf(self, value: EventT) -> Array: 106 | """See `Distribution.cdf`.""" 107 | return jnp.prod(jax.scipy.special.ndtr(self._standardize(value)), axis=-1) 108 | 109 | def log_cdf(self, value: EventT) -> Array: 110 | """See `Distribution.log_cdf`.""" 111 | return jnp.sum( 112 | jax.scipy.special.log_ndtr(self._standardize(value)), axis=-1) 113 | 114 | def __getitem__(self, index) -> 'MultivariateNormalDiag': 115 | """See `Distribution.__getitem__`.""" 116 | index = distribution.to_batch_shape_index(self.batch_shape, index) 117 | return MultivariateNormalDiag( 118 | loc=self.loc[index], scale_diag=self.scale_diag[index]) 119 | -------------------------------------------------------------------------------- /distrax/_src/distributions/mvn_full_covariance.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """MultivariateNormalFullCovariance distribution.""" 16 | 17 | from typing import Optional 18 | 19 | import chex 20 | from distrax._src.distributions import distribution 21 | from distrax._src.distributions.mvn_tri import MultivariateNormalTri 22 | from distrax._src.utils import conversion 23 | import jax.numpy as jnp 24 | from tensorflow_probability.substrates import jax as tfp 25 | 26 | tfd = tfp.distributions 27 | Array = chex.Array 28 | 29 | 30 | def _check_parameters( 31 | loc: Optional[Array], covariance_matrix: Optional[Array]) -> None: 32 | """Checks that the inputs are correct.""" 33 | 34 | if loc is None and covariance_matrix is None: 35 | raise ValueError( 36 | 'At least one of `loc` and `covariance_matrix` must be specified.') 37 | 38 | if loc is not None and loc.ndim < 1: 39 | raise ValueError('The parameter `loc` must have at least one dimension.') 40 | 41 | if covariance_matrix is not None and covariance_matrix.ndim < 2: 42 | raise ValueError( 43 | f'The `covariance_matrix` must have at least two dimensions, but ' 44 | f'`covariance_matrix.shape = {covariance_matrix.shape}`.') 45 | 46 | if covariance_matrix is not None and ( 47 | covariance_matrix.shape[-1] != covariance_matrix.shape[-2]): 48 | raise ValueError( 49 | f'The `covariance_matrix` must be a (batched) square matrix, but ' 50 | f'`covariance_matrix.shape = {covariance_matrix.shape}`.') 51 | 52 | if loc is not None: 53 | num_dims = loc.shape[-1] 54 | if covariance_matrix is not None and ( 55 | covariance_matrix.shape[-1] != num_dims): 56 | raise ValueError( 57 | f'Shapes are not compatible: `loc.shape = {loc.shape}` and ' 58 | f'`covariance_matrix.shape = {covariance_matrix.shape}`.') 59 | 60 | 61 | class MultivariateNormalFullCovariance(MultivariateNormalTri): 62 | """Multivariate normal distribution on `R^k`. 63 | 64 | The `MultivariateNormalFullCovariance` distribution is parameterized by a 65 | `k`-length location (mean) vector `b` and a covariance matrix `C` of size 66 | `k x k` that must be positive definite and symmetric. 67 | 68 | This class makes no attempt to verify that `C` is positive definite or 69 | symmetric. It is the responsibility of the user to make sure that it is the 70 | case. 71 | """ 72 | 73 | equiv_tfp_cls = tfd.MultivariateNormalFullCovariance 74 | 75 | def __init__(self, 76 | loc: Optional[Array] = None, 77 | covariance_matrix: Optional[Array] = None): 78 | """Initializes a MultivariateNormalFullCovariance distribution. 79 | 80 | Args: 81 | loc: Mean vector of the distribution of shape `k` (can also be a batch of 82 | such vectors). If not specified, it defaults to zeros. 83 | covariance_matrix: The covariance matrix `C`. It must be a `k x k` matrix 84 | (additional dimensions index batches). If not specified, it defaults to 85 | the identity. 
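
    Example (an illustrative sketch; the arrays below are arbitrary):

      cov = jnp.array([[2.0, 0.5], [0.5, 1.0]])  # Symmetric positive definite.
      dist = MultivariateNormalFullCovariance(
          loc=jnp.zeros(2), covariance_matrix=cov)
      # dist.event_shape == (2,) and dist.covariance_matrix recovers `cov`.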
86 | """ 87 | loc = None if loc is None else conversion.as_float_array(loc) 88 | covariance_matrix = None if covariance_matrix is None else ( 89 | conversion.as_float_array(covariance_matrix)) 90 | _check_parameters(loc, covariance_matrix) 91 | 92 | num_dims = None 93 | if loc is not None: 94 | num_dims = loc.shape[-1] 95 | elif covariance_matrix is not None: 96 | num_dims = covariance_matrix.shape[-1] 97 | 98 | dtype = jnp.result_type( 99 | *[x for x in [loc, covariance_matrix] if x is not None]) 100 | 101 | if loc is None: 102 | assert num_dims is not None 103 | loc = jnp.zeros((num_dims,), dtype=dtype) 104 | 105 | if covariance_matrix is None: 106 | self._covariance_matrix = jnp.eye(num_dims, dtype=dtype) 107 | scale_tril = None 108 | else: 109 | self._covariance_matrix = covariance_matrix 110 | scale_tril = jnp.linalg.cholesky(covariance_matrix) 111 | 112 | super().__init__(loc=loc, scale_tri=scale_tril) 113 | 114 | @property 115 | def covariance_matrix(self) -> Array: 116 | """Covariance matrix `C`.""" 117 | return jnp.broadcast_to( 118 | self._covariance_matrix, 119 | self.batch_shape + self.event_shape + self.event_shape) 120 | 121 | def covariance(self) -> Array: 122 | """Covariance matrix `C`.""" 123 | return self.covariance_matrix 124 | 125 | def variance(self) -> Array: 126 | """Calculates the variance of all one-dimensional marginals.""" 127 | return jnp.vectorize(jnp.diag, signature='(k,k)->(k)')( 128 | self.covariance_matrix) 129 | 130 | def __getitem__(self, index) -> 'MultivariateNormalFullCovariance': 131 | """See `Distribution.__getitem__`.""" 132 | index = distribution.to_batch_shape_index(self.batch_shape, index) 133 | return MultivariateNormalFullCovariance( 134 | loc=self.loc[index], 135 | covariance_matrix=self.covariance_matrix[index]) 136 | -------------------------------------------------------------------------------- /distrax/_src/distributions/mvn_kl_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for `kl_divergence` across different types of MultivariateNormal.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | import chex 21 | 22 | from distrax._src.distributions.mvn_diag import MultivariateNormalDiag 23 | from distrax._src.distributions.mvn_diag_plus_low_rank import MultivariateNormalDiagPlusLowRank 24 | from distrax._src.distributions.mvn_full_covariance import MultivariateNormalFullCovariance 25 | from distrax._src.distributions.mvn_tri import MultivariateNormalTri 26 | 27 | import numpy as np 28 | 29 | 30 | def _get_dist_params(dist, batch_shape, dim, rng): 31 | """Generates random parameters depending on the distribution type.""" 32 | if dist is MultivariateNormalDiag: 33 | distrax_dist_params = { 34 | 'scale_diag': rng.normal(size=batch_shape + (dim,)), 35 | } 36 | tfp_dist_params = distrax_dist_params 37 | elif dist is MultivariateNormalDiagPlusLowRank: 38 | scale_diag = rng.normal(size=batch_shape + (dim,)) 39 | scale_u_matrix = 0.2 * rng.normal(size=batch_shape + (dim, 2)) 40 | scale_perturb_diag = rng.normal(size=batch_shape + (2,)) 41 | scale_v_matrix = scale_u_matrix * np.expand_dims( 42 | scale_perturb_diag, axis=-2) 43 | distrax_dist_params = { 44 | 'scale_diag': scale_diag, 45 | 'scale_u_matrix': scale_u_matrix, 46 | 'scale_v_matrix': scale_v_matrix, 47 | } 48 | tfp_dist_params = { 49 | 'scale_diag': scale_diag, 50 | 'scale_perturb_factor': scale_u_matrix, 51 | 'scale_perturb_diag': scale_perturb_diag, 52 | } 53 | elif dist is MultivariateNormalTri: 54 | scale_tril = rng.normal(size=batch_shape + (dim, dim)) 55 | distrax_dist_params = { 56 | 'scale_tri': scale_tril, 57 | 'is_lower': True, 58 | } 59 | tfp_dist_params = { 60 | 'scale_tril': scale_tril, 61 | } 62 | elif dist is MultivariateNormalFullCovariance: 63 | matrix = rng.normal(size=batch_shape + (dim, dim)) 64 | matrix_t = np.vectorize(np.transpose, signature='(k,k)->(k,k)')(matrix) 65 | distrax_dist_params = { 66 | 'covariance_matrix': np.matmul(matrix, matrix_t), 67 | } 68 | tfp_dist_params = distrax_dist_params 69 | else: 70 | raise ValueError(f'Unsupported distribution type: {dist}') 71 | loc = rng.normal(size=batch_shape + (dim,)) 72 | distrax_dist_params.update({'loc': loc}) 73 | tfp_dist_params.update({'loc': loc}) 74 | return distrax_dist_params, tfp_dist_params 75 | 76 | 77 | class MultivariateNormalKLTest(parameterized.TestCase): 78 | 79 | @chex.all_variants(with_pmap=False) 80 | @parameterized.named_parameters( 81 | ('Diag vs DiagPlusLowRank', 82 | MultivariateNormalDiag, MultivariateNormalDiagPlusLowRank), 83 | ('Diag vs FullCovariance', 84 | MultivariateNormalDiag, MultivariateNormalFullCovariance), 85 | ('Diag vs Tri', 86 | MultivariateNormalDiag, MultivariateNormalTri), 87 | ('DiagPlusLowRank vs FullCovariance', 88 | MultivariateNormalDiagPlusLowRank, MultivariateNormalFullCovariance), 89 | ('DiagPlusLowRank vs Tri', 90 | MultivariateNormalDiagPlusLowRank, MultivariateNormalTri), 91 | ('Tri vs FullCovariance', 92 | MultivariateNormalTri, MultivariateNormalFullCovariance), 93 | ) 94 | def test_two_distributions(self, dist1_type, dist2_type): 95 | rng = np.random.default_rng(42) 96 | 97 | distrax_dist1_params, tfp_dist1_params = _get_dist_params( 98 | dist1_type, batch_shape=(8, 1), dim=5, rng=rng) 99 | distrax_dist2_params, tfp_dist2_params = _get_dist_params( 100 | dist2_type, batch_shape=(6,), dim=5, rng=rng) 101 | 102 | dist1_distrax = 
dist1_type(**distrax_dist1_params) 103 | dist1_tfp = dist1_type.equiv_tfp_cls(**tfp_dist1_params) 104 | dist2_distrax = dist2_type(**distrax_dist2_params) 105 | dist2_tfp = dist2_type.equiv_tfp_cls(**tfp_dist2_params) 106 | 107 | for method in ['kl_divergence', 'cross_entropy']: 108 | expected_result1 = getattr(dist1_tfp, method)(dist2_tfp) 109 | expected_result2 = getattr(dist2_tfp, method)(dist1_tfp) 110 | for mode in ['distrax_to_distrax', 'distrax_to_tfp', 'tfp_to_distrax']: 111 | with self.subTest(method=method, mode=mode): 112 | if mode == 'distrax_to_distrax': 113 | result1 = self.variant(getattr(dist1_distrax, method))( 114 | dist2_distrax) 115 | result2 = self.variant(getattr(dist2_distrax, method))( 116 | dist1_distrax) 117 | elif mode == 'distrax_to_tfp': 118 | result1 = self.variant(getattr(dist1_distrax, method))(dist2_tfp) 119 | result2 = self.variant(getattr(dist2_distrax, method))(dist1_tfp) 120 | elif mode == 'tfp_to_distrax': 121 | result1 = self.variant(getattr(dist1_tfp, method))(dist2_distrax) 122 | result2 = self.variant(getattr(dist2_tfp, method))(dist1_distrax) 123 | else: 124 | raise ValueError(f'Unsupported mode: {mode}') 125 | np.testing.assert_allclose(result1, expected_result1, rtol=0.03) 126 | np.testing.assert_allclose(result2, expected_result2, rtol=0.03) 127 | 128 | 129 | if __name__ == '__main__': 130 | absltest.main() 131 | -------------------------------------------------------------------------------- /distrax/_src/distributions/mvn_tri.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """MultivariateNormalTri distribution.""" 16 | 17 | from typing import Optional 18 | 19 | import chex 20 | from distrax._src.bijectors.diag_linear import DiagLinear 21 | from distrax._src.bijectors.triangular_linear import TriangularLinear 22 | from distrax._src.distributions import distribution 23 | from distrax._src.distributions.mvn_from_bijector import MultivariateNormalFromBijector 24 | from distrax._src.utils import conversion 25 | import jax.numpy as jnp 26 | from tensorflow_probability.substrates import jax as tfp 27 | 28 | tfd = tfp.distributions 29 | 30 | Array = chex.Array 31 | 32 | 33 | def _check_parameters( 34 | loc: Optional[Array], scale_tri: Optional[Array]) -> None: 35 | """Checks that the inputs are correct.""" 36 | 37 | if loc is None and scale_tri is None: 38 | raise ValueError( 39 | 'At least one of `loc` and `scale_tri` must be specified.') 40 | 41 | if loc is not None and loc.ndim < 1: 42 | raise ValueError('The parameter `loc` must have at least one dimension.') 43 | 44 | if scale_tri is not None and scale_tri.ndim < 2: 45 | raise ValueError( 46 | f'The parameter `scale_tri` must have at least two dimensions, but ' 47 | f'`scale_tri.shape = {scale_tri.shape}`.') 48 | 49 | if scale_tri is not None and scale_tri.shape[-1] != scale_tri.shape[-2]: 50 | raise ValueError( 51 | f'The parameter `scale_tri` must be a (batched) square matrix, but ' 52 | f'`scale_tri.shape = {scale_tri.shape}`.') 53 | 54 | if loc is not None: 55 | num_dims = loc.shape[-1] 56 | if scale_tri is not None and scale_tri.shape[-1] != num_dims: 57 | raise ValueError( 58 | f'Shapes are not compatible: `loc.shape = {loc.shape}` and ' 59 | f'`scale_tri.shape = {scale_tri.shape}`.') 60 | 61 | 62 | class MultivariateNormalTri(MultivariateNormalFromBijector): 63 | """Multivariate normal distribution on `R^k`. 64 | 65 | The `MultivariateNormalTri` distribution is parameterized by a `k`-length 66 | location (mean) vector `b` and a (lower or upper) triangular scale matrix `S` 67 | of size `k x k`. The covariance matrix is `C = S @ S.T`. 68 | """ 69 | 70 | equiv_tfp_cls = tfd.MultivariateNormalTriL 71 | 72 | def __init__(self, 73 | loc: Optional[Array] = None, 74 | scale_tri: Optional[Array] = None, 75 | is_lower: bool = True): 76 | """Initializes a MultivariateNormalTri distribution. 77 | 78 | Args: 79 | loc: Mean vector of the distribution of shape `k` (can also be a batch of 80 | such vectors). If not specified, it defaults to zeros. 81 | scale_tri: The scale matrix `S`. It must be a `k x k` triangular matrix 82 | (additional dimensions index batches). If `scale_tri` is not triangular, 83 | the entries above or below the main diagonal will be ignored. The 84 | parameter `is_lower` specifies if `scale_tri` is lower or upper 85 | triangular. It is the responsibility of the user to make sure that 86 | `scale_tri` only contains non-zero elements in its diagonal; this class 87 | makes no attempt to verify that. If `scale_tri` is not specified, it 88 | defaults to the identity. 89 | is_lower: Indicates if `scale_tri` is lower (if True) or upper (if False) 90 | triangular. 
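
    Example (an illustrative sketch; the matrix below is arbitrary):

      scale = jnp.array([[1.0, 0.0], [0.3, 0.5]])  # Lower triangular.
      dist = MultivariateNormalTri(loc=jnp.zeros(2), scale_tri=scale)
      # The covariance matrix is scale @ scale.T.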
91 | """ 92 | loc = None if loc is None else conversion.as_float_array(loc) 93 | scale_tri = None if scale_tri is None else conversion.as_float_array( 94 | scale_tri) 95 | _check_parameters(loc, scale_tri) 96 | 97 | num_dims = None 98 | if loc is not None: 99 | num_dims = loc.shape[-1] 100 | elif scale_tri is not None: 101 | num_dims = scale_tri.shape[-1] 102 | 103 | dtype = jnp.result_type(*[x for x in [loc, scale_tri] if x is not None]) 104 | 105 | if loc is None: 106 | assert num_dims is not None 107 | loc = jnp.zeros((num_dims,), dtype=dtype) 108 | 109 | if scale_tri is None: 110 | self._scale_tri = jnp.eye(num_dims, dtype=dtype) 111 | scale = DiagLinear(diag=jnp.ones(loc.shape[-1:], dtype=dtype)) 112 | else: 113 | tri_fn = jnp.tril if is_lower else jnp.triu 114 | self._scale_tri = tri_fn(scale_tri) 115 | scale = TriangularLinear(matrix=self._scale_tri, is_lower=is_lower) 116 | self._is_lower = is_lower 117 | 118 | super().__init__(loc=loc, scale=scale) 119 | 120 | @property 121 | def scale_tri(self) -> Array: 122 | """Triangular scale matrix `S`.""" 123 | return jnp.broadcast_to( 124 | self._scale_tri, 125 | self.batch_shape + self.event_shape + self.event_shape) 126 | 127 | @property 128 | def is_lower(self) -> bool: 129 | """Whether the `scale_tri` matrix is lower triangular.""" 130 | return self._is_lower 131 | 132 | def __getitem__(self, index) -> 'MultivariateNormalTri': 133 | """See `Distribution.__getitem__`.""" 134 | index = distribution.to_batch_shape_index(self.batch_shape, index) 135 | return MultivariateNormalTri( 136 | loc=self.loc[index], 137 | scale_tri=self.scale_tri[index], 138 | is_lower=self.is_lower) 139 | -------------------------------------------------------------------------------- /distrax/_src/distributions/normal_float64_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `normal.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | import chex 21 | from distrax._src.distributions import normal 22 | import jax 23 | from jax import config as jax_config 24 | import jax.numpy as jnp 25 | 26 | 27 | def setUpModule(): 28 | jax_config.update('jax_enable_x64', True) 29 | 30 | 31 | class NormalFloat64Test(chex.TestCase): 32 | 33 | def _assert_dtypes(self, dist, dtype): 34 | """Asserts dist methods' outputs' datatypes.""" 35 | # Sanity check to make sure float64 is enabled. 
36 | x_64 = jnp.zeros([]) 37 | self.assertEqual(jnp.float64, x_64.dtype) 38 | 39 | key = jax.random.PRNGKey(1729) 40 | z, log_prob = self.variant( 41 | lambda: dist.sample_and_log_prob(seed=key, sample_shape=[3]))() 42 | z2 = self.variant( 43 | lambda: dist.sample(seed=key, sample_shape=[3]))() 44 | self.assertEqual(dtype, z.dtype) 45 | self.assertEqual(dtype, z2.dtype) 46 | self.assertEqual(dtype, log_prob.dtype) 47 | self.assertEqual(dtype, self.variant(dist.log_prob)(z).dtype) 48 | self.assertEqual(dtype, self.variant(dist.prob)(z).dtype) 49 | self.assertEqual(dtype, self.variant(dist.cdf)(z).dtype) 50 | self.assertEqual(dtype, self.variant(dist.log_cdf)(z).dtype) 51 | self.assertEqual(dtype, self.variant(dist.entropy)().dtype) 52 | self.assertEqual(dtype, self.variant(dist.mean)().dtype) 53 | self.assertEqual(dtype, self.variant(dist.mode)().dtype) 54 | self.assertEqual(dtype, self.variant(dist.median)().dtype) 55 | self.assertEqual(dtype, self.variant(dist.stddev)().dtype) 56 | self.assertEqual(dtype, self.variant(dist.variance)().dtype) 57 | self.assertEqual(dtype, dist.loc.dtype) 58 | self.assertEqual(dtype, dist.scale.dtype) 59 | self.assertEqual(dtype, dist.dtype) 60 | 61 | @chex.all_variants 62 | @parameterized.named_parameters( 63 | ('float32', jnp.float32), 64 | ('float64', jnp.float64)) 65 | def test_dtype(self, dtype): 66 | dist = normal.Normal(loc=jnp.zeros([], dtype), scale=jnp.ones([], dtype)) 67 | self._assert_dtypes(dist, dtype) 68 | 69 | 70 | if __name__ == '__main__': 71 | absltest.main() 72 | -------------------------------------------------------------------------------- /distrax/_src/distributions/one_hot_categorical.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """OneHotCategorical distribution.""" 16 | 17 | from typing import Any, Optional, Tuple, Union 18 | 19 | import chex 20 | from distrax._src.distributions import categorical 21 | from distrax._src.distributions import distribution 22 | from distrax._src.utils import math 23 | import jax 24 | import jax.numpy as jnp 25 | from tensorflow_probability.substrates import jax as tfp 26 | 27 | 28 | tfd = tfp.distributions 29 | 30 | Array = chex.Array 31 | PRNGKey = chex.PRNGKey 32 | EventT = distribution.EventT 33 | 34 | 35 | class OneHotCategorical(categorical.Categorical): 36 | """OneHotCategorical distribution.""" 37 | 38 | equiv_tfp_cls = tfd.OneHotCategorical 39 | 40 | def __init__(self, 41 | logits: Optional[Array] = None, 42 | probs: Optional[Array] = None, 43 | dtype: Union[jnp.dtype, type[Any]] = int): 44 | """Initializes a OneHotCategorical distribution. 45 | 46 | Args: 47 | logits: Logit transform of the probability of each category. Only one 48 | of `logits` or `probs` can be specified. 
49 | probs: Probability of each category. Only one of `logits` or `probs` can 50 | be specified. 51 | dtype: The type of event samples. 52 | """ 53 | super().__init__(logits=logits, probs=probs, dtype=dtype) 54 | 55 | @property 56 | def event_shape(self) -> Tuple[int, ...]: 57 | """Shape of event of distribution samples.""" 58 | return (self.num_categories,) 59 | 60 | def _sample_n(self, key: PRNGKey, n: int) -> Array: 61 | """See `Distribution._sample_n`.""" 62 | new_shape = (n,) + self.logits.shape[:-1] 63 | is_valid = jnp.logical_and( 64 | jnp.all(jnp.isfinite(self.probs), axis=-1, keepdims=True), 65 | jnp.all(self.probs >= 0, axis=-1, keepdims=True)) 66 | draws = jax.random.categorical( 67 | key=key, logits=self.logits, axis=-1, shape=new_shape) 68 | draws_one_hot = jax.nn.one_hot( 69 | draws, num_classes=self.num_categories).astype(self._dtype) 70 | return jnp.where(is_valid, draws_one_hot, jnp.ones_like(draws_one_hot) * -1) 71 | 72 | def log_prob(self, value: EventT) -> Array: 73 | """See `Distribution.log_prob`.""" 74 | return jnp.sum(math.multiply_no_nan(self.logits, value), axis=-1) 75 | 76 | def prob(self, value: EventT) -> Array: 77 | """See `Distribution.prob`.""" 78 | return jnp.sum(math.multiply_no_nan(self.probs, value), axis=-1) 79 | 80 | def mode(self) -> Array: 81 | """Calculates the mode.""" 82 | preferences = self._probs if self._logits is None else self._logits 83 | assert preferences is not None 84 | greedy_index = jnp.argmax(preferences, axis=-1) 85 | return jax.nn.one_hot(greedy_index, self.num_categories).astype(self._dtype) 86 | 87 | def cdf(self, value: EventT) -> Array: 88 | """See `Distribution.cdf`.""" 89 | return jnp.sum(math.multiply_no_nan( 90 | jnp.cumsum(self.probs, axis=-1), value), axis=-1) 91 | 92 | def __getitem__(self, index) -> 'OneHotCategorical': 93 | """See `Distribution.__getitem__`.""" 94 | index = distribution.to_batch_shape_index(self.batch_shape, index) 95 | if self._logits is not None: 96 | return OneHotCategorical(logits=self.logits[index], dtype=self._dtype) 97 | return OneHotCategorical(probs=self.probs[index], dtype=self._dtype) 98 | -------------------------------------------------------------------------------- /distrax/_src/distributions/softmax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Softmax distribution.""" 16 | 17 | from typing import Any, Union 18 | 19 | import chex 20 | from distrax._src.distributions import categorical 21 | from distrax._src.distributions import distribution 22 | import jax.numpy as jnp 23 | 24 | 25 | Array = chex.Array 26 | 27 | 28 | class Softmax(categorical.Categorical): 29 | """Categorical implementing a softmax over logits, with given temperature. 
30 | 
31 | Given a set of logits, the probability mass is distributed such that each
32 | index `i` has probability `exp(logits[i]/τ) / Σ_j exp(logits[j]/τ)`, where τ
33 | is a scalar `temperature` parameter: for τ→0 the distribution becomes
34 | fully greedy, and for τ→∞ it becomes fully uniform.
35 | """
36 | 
37 | def __init__(self,
38 | logits: Array,
39 | temperature: float = 1.,
40 | dtype: Union[jnp.dtype, type[Any]] = int):
41 | """Initializes a Softmax distribution.
42 | 
43 | Args:
44 | logits: Logit transform of the probability of each category.
45 | temperature: Softmax temperature τ.
46 | dtype: The type of event samples.
47 | """
48 | self._temperature = temperature
49 | self._unscaled_logits = logits
50 | scaled_logits = logits / temperature
51 | super().__init__(logits=scaled_logits, dtype=dtype)
52 | 
53 | @property
54 | def temperature(self) -> float:
55 | """The softmax temperature parameter."""
56 | return self._temperature
57 | 
58 | @property
59 | def unscaled_logits(self) -> Array:
60 | """The logits of the distribution before the temperature scaling."""
61 | return self._unscaled_logits
62 | 
63 | def __getitem__(self, index) -> 'Softmax':
64 | """See `Distribution.__getitem__`."""
65 | index = distribution.to_batch_shape_index(self.batch_shape, index)
66 | return Softmax(
67 | logits=self.unscaled_logits[index],
68 | temperature=self.temperature,
69 | dtype=self.dtype)
70 | -------------------------------------------------------------------------------- /distrax/_src/distributions/softmax_test.py: --------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for `softmax.py`."""
16 | 
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | 
20 | import chex
21 | from distrax._src.distributions import softmax
22 | from distrax._src.utils import equivalence
23 | from distrax._src.utils import math
24 | import jax
25 | import jax.numpy as jnp
26 | import numpy as np
27 | 
28 | 
29 | class SoftmaxUnitTemperatureTest(equivalence.EquivalenceTest):
30 | 
31 | def setUp(self):
32 | super().setUp()
33 | self._init_distr_cls(softmax.Softmax)
34 | self.temperature = 1.
35 | self.probs = jnp.array([0.2, 0.4, 0.1, 0.3]) 36 | self.logits = jnp.log(self.probs) 37 | 38 | def test_num_categories(self): 39 | dist = self.distrax_cls(logits=self.logits) 40 | np.testing.assert_equal(dist.num_categories, len(self.logits)) 41 | 42 | def test_parameters(self): 43 | dist = self.distrax_cls(logits=self.logits) 44 | self.assertion_fn(rtol=2e-3)(dist.logits, self.logits) 45 | self.assertion_fn(rtol=2e-3)(dist.probs, self.probs) 46 | 47 | 48 | class SoftmaxTest(equivalence.EquivalenceTest): 49 | 50 | def setUp(self): 51 | super().setUp() 52 | self._init_distr_cls(softmax.Softmax) 53 | self.temperature = 10. 54 | self.logits = jnp.array([2., 4., 1., 3.]) 55 | self.probs = jax.nn.softmax(self.logits / self.temperature) 56 | 57 | def test_num_categories(self): 58 | dist = self.distrax_cls(logits=self.logits, temperature=self.temperature) 59 | np.testing.assert_equal(dist.num_categories, len(self.logits)) 60 | 61 | def test_parameters(self): 62 | dist = self.distrax_cls(logits=self.logits, temperature=self.temperature) 63 | self.assertion_fn(rtol=2e-3)( 64 | dist.logits, math.normalize(logits=self.logits / self.temperature)) 65 | self.assertion_fn(rtol=2e-3)(dist.probs, self.probs) 66 | 67 | @chex.all_variants 68 | @parameterized.named_parameters( 69 | ('int32', jnp.int32), 70 | ('int64', jnp.int64), 71 | ('float32', jnp.float32), 72 | ('float64', jnp.float64)) 73 | def test_sample_dtype(self, dtype): 74 | with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): 75 | dist = self.distrax_cls( 76 | logits=self.logits, temperature=self.temperature, dtype=dtype) 77 | samples = self.variant(dist.sample)(seed=self.key) 78 | self.assertEqual(samples.dtype, dist.dtype) 79 | chex.assert_type(samples, dtype) 80 | 81 | def test_jittable(self): 82 | super()._test_jittable((np.array([2., 4., 1., 3.]),)) 83 | 84 | @parameterized.named_parameters( 85 | ('single element', 2), 86 | ('range', slice(-1)), 87 | ('range_2', (slice(None), slice(-1))), 88 | ) 89 | def test_slice(self, slice_): 90 | logits = jnp.array(np.random.randn(3, 4, 5)) 91 | temperature = 0.8 92 | scaled_logits = logits / temperature 93 | dist = self.distrax_cls(logits=logits, temperature=temperature) 94 | self.assertIsInstance(dist[slice_], self.distrax_cls) 95 | self.assertion_fn(rtol=2e-3)(dist[slice_].temperature, temperature) 96 | self.assertion_fn(rtol=2e-3)( 97 | jax.nn.softmax(dist[slice_].logits, axis=-1), 98 | jax.nn.softmax(scaled_logits[slice_], axis=-1)) 99 | 100 | def test_slice_ellipsis(self): 101 | logits = jnp.array(np.random.randn(3, 4, 5)) 102 | temperature = 0.8 103 | scaled_logits = logits / temperature 104 | dist = self.distrax_cls(logits=logits, temperature=temperature) 105 | dist_sliced = dist[..., -1] 106 | self.assertIsInstance(dist_sliced, self.distrax_cls) 107 | self.assertion_fn(rtol=2e-3)(dist_sliced.temperature, temperature) 108 | self.assertion_fn(rtol=2e-3)( 109 | jax.nn.softmax(dist_sliced.logits, axis=-1), 110 | jax.nn.softmax(scaled_logits[:, -1], axis=-1)) 111 | 112 | 113 | if __name__ == '__main__': 114 | absltest.main() 115 | -------------------------------------------------------------------------------- /distrax/_src/distributions/straight_through.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Straight-through gradient sampling distribution.""" 16 | from distrax._src.distributions import categorical 17 | from distrax._src.distributions import distribution 18 | import jax 19 | 20 | 21 | def straight_through_wrapper( # pylint: disable=invalid-name 22 | Distribution, 23 | ) -> distribution.DistributionLike: 24 | """Wrap a distribution to use straight-through gradient for samples.""" 25 | 26 | def sample(self, seed, sample_shape=()): # pylint: disable=g-doc-args 27 | """Sampling with straight-through biased gradient estimator. 28 | 29 | Sample a value from the distribution, but backpropagate through the 30 | underlying probability to compute the gradient. 31 | 32 | References: 33 | [1] Yoshua Bengio, Nicholas Léonard, Aaron Courville, Estimating or 34 | Propagating Gradients Through Stochastic Neurons for Conditional 35 | Computation, https://arxiv.org/abs/1308.3432 36 | 37 | Args: 38 | seed: a random seed. 39 | sample_shape: the shape of the required sample. 40 | 41 | Returns: 42 | A sample with straight-through gradient. 43 | """ 44 | # pylint: disable=protected-access 45 | obj = Distribution(probs=self._probs, logits=self._logits) 46 | assert isinstance(obj, categorical.Categorical) 47 | sample = obj.sample(seed=seed, sample_shape=sample_shape) 48 | probs = obj.probs 49 | padded_probs = _pad(probs, sample.shape) 50 | 51 | # Keep sample unchanged, but add gradient through probs. 52 | sample += padded_probs - jax.lax.stop_gradient(padded_probs) 53 | return sample 54 | 55 | def _pad(probs, shape): 56 | """Grow probs to have the same number of dimensions as shape.""" 57 | while len(probs.shape) < len(shape): 58 | probs = probs[None] 59 | return probs 60 | 61 | parent_name = Distribution.__name__ 62 | # Return a new object, overriding sample. 63 | return type('StraightThrough' + parent_name, (Distribution,), 64 | {'sample': sample}) 65 | -------------------------------------------------------------------------------- /distrax/_src/distributions/uniform.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
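# A minimal usage sketch for `straight_through_wrapper` from
# straight_through.py above, assuming `OneHotCategorical` (a `Categorical`
# subclass) as the wrapped class. The helper below is illustrative only, is
# not part of the original module, and is never called at import time.
def _straight_through_example():
  import jax
  import jax.numpy as jnp
  from distrax._src.distributions import one_hot_categorical
  from distrax._src.distributions import straight_through

  # Wrap the class so that `sample` carries a gradient through `probs`.
  st_cls = straight_through.straight_through_wrapper(
      one_hot_categorical.OneHotCategorical)

  def loss(logits):
    sample = st_cls(logits=logits).sample(seed=jax.random.PRNGKey(0))
    return jnp.sum(sample * jnp.arange(3.))

  # A plain categorical sample would give zero gradients with respect to the
  # logits; the straight-through estimator instead backpropagates through the
  # underlying probabilities.
  return jax.grad(loss)(jnp.zeros(3))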
14 | # ============================================================================== 15 | """Uniform distribution.""" 16 | 17 | import math 18 | from typing import Tuple, Union 19 | 20 | import chex 21 | from distrax._src.distributions import distribution 22 | from distrax._src.utils import conversion 23 | import jax 24 | import jax.numpy as jnp 25 | from tensorflow_probability.substrates import jax as tfp 26 | 27 | tfd = tfp.distributions 28 | 29 | Array = chex.Array 30 | Numeric = chex.Numeric 31 | PRNGKey = chex.PRNGKey 32 | EventT = distribution.EventT 33 | 34 | 35 | class Uniform(distribution.Distribution): 36 | """Uniform distribution with `low` and `high` parameters.""" 37 | 38 | equiv_tfp_cls = tfd.Uniform 39 | 40 | def __init__(self, low: Numeric = 0., high: Numeric = 1.): 41 | """Initializes a Uniform distribution. 42 | 43 | Args: 44 | low: Lower bound. 45 | high: Upper bound. 46 | """ 47 | super().__init__() 48 | self._low = conversion.as_float_array(low) 49 | self._high = conversion.as_float_array(high) 50 | self._batch_shape = jax.lax.broadcast_shapes( 51 | self._low.shape, self._high.shape) 52 | 53 | @property 54 | def event_shape(self) -> Tuple[int, ...]: 55 | """Shape of the events.""" 56 | return () 57 | 58 | @property 59 | def low(self) -> Array: 60 | """Lower bound.""" 61 | return jnp.broadcast_to(self._low, self.batch_shape) 62 | 63 | @property 64 | def high(self) -> Array: 65 | """Upper bound.""" 66 | return jnp.broadcast_to(self._high, self.batch_shape) 67 | 68 | @property 69 | def range(self) -> Array: 70 | return self.high - self.low 71 | 72 | @property 73 | def batch_shape(self) -> Tuple[int, ...]: 74 | return self._batch_shape 75 | 76 | def _sample_n(self, key: PRNGKey, n: int) -> Array: 77 | """See `Distribution._sample_n`.""" 78 | new_shape = (n,) + self.batch_shape 79 | uniform = jax.random.uniform( 80 | key=key, shape=new_shape, dtype=self.range.dtype, minval=0., maxval=1.) 81 | low = jnp.expand_dims(self._low, range(uniform.ndim - self._low.ndim)) 82 | range_ = jnp.expand_dims(self.range, range(uniform.ndim - self.range.ndim)) 83 | return low + range_ * uniform 84 | 85 | def _sample_n_and_log_prob(self, key: PRNGKey, n: int) -> Tuple[Array, Array]: 86 | """See `Distribution._sample_n_and_log_prob`.""" 87 | samples = self._sample_n(key, n) 88 | log_prob = -jnp.log(self.range) 89 | log_prob = jnp.repeat(log_prob[None], n, axis=0) 90 | return samples, log_prob 91 | 92 | def log_prob(self, value: EventT) -> Array: 93 | """See `Distribution.log_prob`.""" 94 | return jnp.log(self.prob(value)) 95 | 96 | def prob(self, value: EventT) -> Array: 97 | """See `Distribution.prob`.""" 98 | return jnp.where( 99 | jnp.logical_or(value < self.low, value > self.high), 100 | jnp.zeros_like(value), 101 | jnp.ones_like(value) / self.range) 102 | 103 | def entropy(self) -> Array: 104 | """Calculates the entropy.""" 105 | return jnp.log(self.range) 106 | 107 | def mean(self) -> Array: 108 | """Calculates the mean.""" 109 | return (self.low + self.high) / 2. 110 | 111 | def variance(self) -> Array: 112 | """Calculates the variance.""" 113 | return jnp.square(self.range) / 12. 114 | 115 | def stddev(self) -> Array: 116 | """Calculates the standard deviation.""" 117 | return self.range / math.sqrt(12.) 
118 | 119 | def median(self) -> Array: 120 | """Calculates the median.""" 121 | return self.mean() 122 | 123 | def cdf(self, value: EventT) -> Array: 124 | """See `Distribution.cdf`.""" 125 | ones = jnp.ones_like(self.range) 126 | zeros = jnp.zeros_like(ones) 127 | result_if_not_big = jnp.where( 128 | value < self.low, zeros, (value - self.low) / self.range) 129 | return jnp.where(value > self.high, ones, result_if_not_big) 130 | 131 | def log_cdf(self, value: EventT) -> Array: 132 | """See `Distribution.log_cdf`.""" 133 | return jnp.log(self.cdf(value)) 134 | 135 | def __getitem__(self, index) -> 'Uniform': 136 | """See `Distribution.__getitem__`.""" 137 | index = distribution.to_batch_shape_index(self.batch_shape, index) 138 | return Uniform(low=self.low[index], high=self.high[index]) 139 | 140 | 141 | def _kl_divergence_uniform_uniform( 142 | dist1: Union[Uniform, tfd.Uniform], 143 | dist2: Union[Uniform, tfd.Uniform], 144 | *unused_args, **unused_kwargs, 145 | ) -> Array: 146 | """Obtain the KL divergence `KL(dist1 || dist2)` between two Uniforms. 147 | 148 | Note that the KL divergence is infinite if the support of `dist1` is not a 149 | subset of the support of `dist2`. 150 | 151 | Args: 152 | dist1: A Uniform distribution. 153 | dist2: A Uniform distribution. 154 | 155 | Returns: 156 | Batchwise `KL(dist1 || dist2)`. 157 | """ 158 | return jnp.where( 159 | jnp.logical_and(dist2.low <= dist1.low, dist1.high <= dist2.high), 160 | jnp.log(dist2.high - dist2.low) - jnp.log(dist1.high - dist1.low), 161 | jnp.inf) 162 | 163 | 164 | # Register the KL functions with TFP. 165 | tfd.RegisterKL(Uniform, Uniform)(_kl_divergence_uniform_uniform) 166 | tfd.RegisterKL(Uniform, Uniform.equiv_tfp_cls)(_kl_divergence_uniform_uniform) 167 | tfd.RegisterKL(Uniform.equiv_tfp_cls, Uniform)(_kl_divergence_uniform_uniform) 168 | -------------------------------------------------------------------------------- /distrax/_src/distributions/uniform_float64_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `uniform.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | import chex 21 | from distrax._src.distributions import uniform 22 | import jax 23 | from jax import config as jax_config 24 | import jax.numpy as jnp 25 | 26 | 27 | def setUpModule(): 28 | jax_config.update('jax_enable_x64', True) 29 | 30 | 31 | class UniformFloat64Test(chex.TestCase): 32 | 33 | def _assert_dtypes(self, dist, dtype): 34 | """Asserts dist methods' outputs' datatypes.""" 35 | # Sanity check to make sure float64 is enabled. 
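    # (With `jax_enable_x64` active, `jnp.zeros([])` defaults to float64.)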
36 | x_64 = jnp.zeros([]) 37 | self.assertEqual(jnp.float64, x_64.dtype) 38 | 39 | key = jax.random.PRNGKey(1729) 40 | z, log_prob = self.variant( 41 | lambda: dist.sample_and_log_prob(seed=key, sample_shape=[3]))() 42 | z2 = self.variant( 43 | lambda: dist.sample(seed=key, sample_shape=[3]))() 44 | self.assertEqual(dtype, z.dtype) 45 | self.assertEqual(dtype, z2.dtype) 46 | self.assertEqual(dtype, log_prob.dtype) 47 | self.assertEqual(dtype, self.variant(dist.log_prob)(z).dtype) 48 | self.assertEqual(dtype, self.variant(dist.prob)(z).dtype) 49 | self.assertEqual(dtype, self.variant(dist.log_cdf)(z).dtype) 50 | self.assertEqual(dtype, self.variant(dist.cdf)(z).dtype) 51 | self.assertEqual(dtype, self.variant(dist.entropy)().dtype) 52 | self.assertEqual(dtype, self.variant(dist.mean)().dtype) 53 | self.assertEqual(dtype, self.variant(dist.median)().dtype) 54 | self.assertEqual(dtype, self.variant(dist.stddev)().dtype) 55 | self.assertEqual(dtype, self.variant(dist.variance)().dtype) 56 | self.assertEqual(dtype, dist.low.dtype) 57 | self.assertEqual(dtype, dist.high.dtype) 58 | self.assertEqual(dtype, dist.range.dtype) 59 | self.assertEqual(dtype, dist.dtype) 60 | 61 | @chex.all_variants 62 | @parameterized.named_parameters( 63 | ('float32', jnp.float32), 64 | ('float64', jnp.float64)) 65 | def test_dtype(self, dtype): 66 | dist = uniform.Uniform(low=jnp.zeros([], dtype), high=jnp.ones([], dtype)) 67 | self._assert_dtypes(dist, dtype) 68 | 69 | 70 | if __name__ == '__main__': 71 | absltest.main() 72 | -------------------------------------------------------------------------------- /distrax/_src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /distrax/_src/utils/conversion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
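# A minimal usage sketch for the conversion helpers defined later in this
# module (`to_tfp`, `as_bijector`, `as_distribution`, `as_float_array`). The
# helper below is illustrative only, is not part of the original module, and
# is never called at import time.
def _conversion_example():
  import jax.numpy as jnp
  from distrax._src.distributions import normal
  from distrax._src.utils import conversion
  from tensorflow_probability.substrates import jax as tfp

  # TFP objects are wrapped into Distrax equivalents, and vice versa.
  dx_dist = conversion.as_distribution(tfp.distributions.Normal(0., 1.))
  tfp_dist = conversion.to_tfp(normal.Normal(0., 1.))

  # `jnp.tanh` is special-cased to the explicit `Tanh` bijector; other
  # callables would be wrapped by `Lambda`.
  tanh_bijector = conversion.as_bijector(jnp.tanh)

  # Integer input is promoted to a floating-point array.
  x = conversion.as_float_array(3)
  return dx_dist, tfp_dist, tanh_bijector, x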
14 | # ============================================================================== 15 | """Utility functions for conversion between different types.""" 16 | 17 | from typing import Optional, Union 18 | 19 | import chex 20 | from distrax._src.bijectors import bijector 21 | from distrax._src.bijectors import bijector_from_tfp 22 | from distrax._src.bijectors import lambda_bijector 23 | from distrax._src.bijectors import sigmoid 24 | from distrax._src.bijectors import tanh 25 | from distrax._src.bijectors import tfp_compatible_bijector 26 | from distrax._src.distributions import distribution 27 | from distrax._src.distributions import distribution_from_tfp 28 | from distrax._src.distributions import tfp_compatible_distribution 29 | import jax 30 | import jax.numpy as jnp 31 | import numpy as np 32 | from tensorflow_probability.substrates import jax as tfp 33 | 34 | tfb = tfp.bijectors 35 | tfd = tfp.distributions 36 | 37 | Array = chex.Array 38 | Numeric = chex.Numeric 39 | BijectorLike = bijector.BijectorLike 40 | DistributionLike = distribution.DistributionLike 41 | 42 | 43 | def to_tfp(obj: Union[bijector.Bijector, tfb.Bijector, 44 | distribution.Distribution, tfd.Distribution], 45 | name: Optional[str] = None): 46 | """Converts a distribution or bijector to a TFP-compatible equivalent object. 47 | 48 | The returned object is not necessarily of type `tfb.Bijector` or 49 | `tfd.Distribution`; rather, it is a Distrax object that implements TFP 50 | functionality so that it can be used in TFP. 51 | 52 | If the input is already of TFP type, it is returned unchanged. 53 | 54 | Args: 55 | obj: The distribution or bijector to be converted to TFP. 56 | name: The name of the resulting object. 57 | 58 | Returns: 59 | A TFP-compatible equivalent distribution or bijector. 60 | """ 61 | if isinstance(obj, (tfb.Bijector, tfd.Distribution)): 62 | return obj 63 | elif isinstance(obj, bijector.Bijector): 64 | return tfp_compatible_bijector.tfp_compatible_bijector(obj, name) 65 | elif isinstance(obj, distribution.Distribution): 66 | return tfp_compatible_distribution.tfp_compatible_distribution(obj, name) 67 | else: 68 | raise TypeError( 69 | f"`to_tfp` can only convert objects of type: `distrax.Bijector`," 70 | f" `tfb.Bijector`, `distrax.Distribution`, `tfd.Distribution`. Got type" 71 | f" `{type(obj)}`.") 72 | 73 | 74 | def as_bijector(obj: BijectorLike) -> bijector.BijectorT: 75 | """Converts a bijector-like object to a Distrax bijector. 76 | 77 | Bijector-like objects are: Distrax bijectors, TFP bijectors, and callables. 78 | Distrax bijectors are returned unchanged. TFP bijectors are converted to a 79 | Distrax equivalent. Callables are wrapped by `distrax.Lambda`, with a few 80 | exceptions where an explicit implementation already exists and is returned. 81 | 82 | Args: 83 | obj: The bijector-like object to be converted. 84 | 85 | Returns: 86 | A Distrax bijector. 87 | """ 88 | if isinstance(obj, bijector.Bijector): 89 | return obj 90 | elif isinstance(obj, tfb.Bijector): 91 | return bijector_from_tfp.BijectorFromTFP(obj) 92 | elif obj is jax.nn.sigmoid: 93 | return sigmoid.Sigmoid() 94 | elif obj is jnp.tanh: 95 | return tanh.Tanh() 96 | elif callable(obj): 97 | return lambda_bijector.Lambda(obj) 98 | else: 99 | raise TypeError( 100 | f"A bijector-like object can be a `distrax.Bijector`, a `tfb.Bijector`," 101 | f" or a callable. 
Got type `{type(obj)}`.") 102 | 103 | 104 | def as_distribution(obj: DistributionLike) -> distribution.DistributionT: 105 | """Converts a distribution-like object to a Distrax distribution. 106 | 107 | Distribution-like objects are: Distrax distributions and TFP distributions. 108 | Distrax distributions are returned unchanged. TFP distributions are converted 109 | to a Distrax equivalent. 110 | 111 | Args: 112 | obj: A distribution-like object to be converted. 113 | 114 | Returns: 115 | A Distrax distribution. 116 | """ 117 | if isinstance(obj, distribution.Distribution): 118 | return obj 119 | elif isinstance(obj, tfd.Distribution): 120 | return distribution_from_tfp.distribution_from_tfp(obj) 121 | else: 122 | raise TypeError( 123 | f"A distribution-like object can be a `distrax.Distribution` or a" 124 | f" `tfd.Distribution`. Got type `{type(obj)}`.") 125 | 126 | 127 | def as_float_array(x: Numeric) -> Array: 128 | """Converts input to an array with floating-point dtype. 129 | 130 | If the input is already an array with floating-point dtype, it is returned 131 | unchanged. 132 | 133 | Args: 134 | x: input to convert. 135 | 136 | Returns: 137 | An array with floating-point dtype. 138 | """ 139 | if not isinstance(x, (jax.Array, np.ndarray)): 140 | x = jnp.asarray(x) 141 | 142 | if jnp.issubdtype(x.dtype, jnp.floating): 143 | return x 144 | elif jnp.issubdtype(x.dtype, jnp.integer): 145 | return x.astype(jnp.float_) 146 | else: 147 | raise ValueError( 148 | f"Expected either floating or integer dtype, got {x.dtype}.") 149 | -------------------------------------------------------------------------------- /distrax/_src/utils/importance_sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Importance sampling.""" 16 | 17 | import chex 18 | from distrax._src.distributions import distribution 19 | import jax.numpy as jnp 20 | 21 | 22 | Array = chex.Array 23 | DistributionLike = distribution.DistributionLike 24 | 25 | 26 | def importance_sampling_ratios( 27 | target_dist: DistributionLike, 28 | sampling_dist: DistributionLike, 29 | event: Array 30 | ) -> Array: 31 | """Compute importance sampling ratios given target and sampling distributions. 32 | 33 | Args: 34 | target_dist: Target probability distribution. 35 | sampling_dist: Sampling probability distribution. 36 | event: Samples. 37 | 38 | Returns: 39 | Importance sampling ratios. 
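    These equal `target_dist.prob(event) / sampling_dist.prob(event)`,
    computed in log space for numerical stability.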
40 | """ 41 | log_pi_a = target_dist.log_prob(event) 42 | log_mu_a = sampling_dist.log_prob(event) 43 | rho = jnp.exp(log_pi_a - log_mu_a) 44 | return rho 45 | -------------------------------------------------------------------------------- /distrax/_src/utils/importance_sampling_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `importance_sampling.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | import chex 21 | from distrax._src.distributions import categorical 22 | from distrax._src.utils import importance_sampling 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | 28 | class ImportanceSamplingTest(parameterized.TestCase): 29 | 30 | @chex.all_variants(with_pmap=False) 31 | def test_importance_sampling_ratios_on_policy(self): 32 | key = jax.random.PRNGKey(42) 33 | probs = jnp.array([0.4, 0.2, 0.1, 0.3]) 34 | dist = categorical.Categorical(probs=probs) 35 | event = dist.sample(seed=key, sample_shape=()) 36 | 37 | ratios_fn = self.variant( 38 | importance_sampling.importance_sampling_ratios) 39 | rhos = ratios_fn(target_dist=dist, sampling_dist=dist, event=event) 40 | 41 | expected_rhos = jnp.ones_like(rhos) 42 | np.testing.assert_array_almost_equal(rhos, expected_rhos) 43 | 44 | @chex.all_variants(with_pmap=False) 45 | def test_importance_sampling_ratios_off_policy(self): 46 | """Tests for a full batch.""" 47 | pi_logits = np.array([[0.2, 0.8], [0.6, 0.4]], dtype=np.float32) 48 | pi = categorical.Categorical(logits=pi_logits) 49 | mu_logits = np.array([[0.8, 0.2], [0.6, 0.4]], dtype=np.float32) 50 | mu = categorical.Categorical(logits=mu_logits) 51 | events = np.array([1, 0], dtype=np.int32) 52 | 53 | ratios_fn = self.variant( 54 | importance_sampling.importance_sampling_ratios) 55 | rhos = ratios_fn(pi, mu, events) 56 | 57 | expected_rhos = np.array( 58 | [pi.probs[0][1] / mu.probs[0][1], pi.probs[1][0] / mu.probs[1][0]], 59 | dtype=np.float32) 60 | np.testing.assert_allclose(expected_rhos, rhos, atol=1e-4) 61 | 62 | 63 | if __name__ == '__main__': 64 | absltest.main() 65 | -------------------------------------------------------------------------------- /distrax/_src/utils/jittable.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Abstract class for Jittable objects.""" 16 | 17 | import abc 18 | import jax 19 | 20 | 21 | class Jittable(metaclass=abc.ABCMeta): 22 | """ABC that can be passed as an arg to a jitted fn, with readable state.""" 23 | 24 | def __new__(cls, *args, **kwargs): 25 | # Discard the parameters to this function because the constructor is not 26 | # called during serialization: its `__dict__` gets repopulated directly. 27 | del args, kwargs 28 | try: 29 | registered_cls = jax.tree_util.register_pytree_node_class(cls) 30 | except ValueError: 31 | registered_cls = cls # Already registered. 32 | return object.__new__(registered_cls) 33 | 34 | def tree_flatten(self): 35 | leaves, treedef = jax.tree_util.tree_flatten(self.__dict__) 36 | switch = list(map(_is_jax_data, leaves)) 37 | children = [leaf if s else None for leaf, s in zip(leaves, switch)] 38 | metadata = [None if s else leaf for leaf, s in zip(leaves, switch)] 39 | return children, (metadata, switch, treedef) 40 | 41 | @classmethod 42 | def tree_unflatten(cls, aux_data, children): 43 | metadata, switch, treedef = aux_data 44 | leaves = [j if s else p for j, p, s in zip(children, metadata, switch)] 45 | obj = object.__new__(cls) 46 | obj.__dict__ = jax.tree_util.tree_unflatten(treedef, leaves) 47 | return obj 48 | 49 | 50 | def _is_jax_data(x): 51 | """Check whether `x` is an instance of a JAX-compatible type.""" 52 | # If it's a tracer, then it's already been converted by JAX. 53 | if isinstance(x, jax.core.Tracer): 54 | return True 55 | 56 | # `jax.vmap` replaces vmappable leaves with `object()` during serialization. 57 | if type(x) is object: # pylint: disable=unidiomatic-typecheck 58 | return True 59 | 60 | # Primitive types (e.g. shape tuples) are treated as metadata for Distrax. 61 | if isinstance(x, (bool, int, float)) or x is None: 62 | return False 63 | 64 | # Return True if JAX considers `x` a valid JAX type. 65 | return jax.core.valid_jaxtype(x) 66 | -------------------------------------------------------------------------------- /distrax/_src/utils/jittable_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for `jittable.py`.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | from distrax._src.utils import jittable 21 | import jax 22 | import jax.numpy as jnp 23 | import numpy as np 24 | 25 | 26 | class DummyJittable(jittable.Jittable): 27 | 28 | def __init__(self, params): 29 | self.name = 'dummy' # Non-JAX property, cannot be traced. 30 | self.data = {'params': params} # Tree property, must be traced recursively. 31 | 32 | 33 | class JittableTest(parameterized.TestCase): 34 | 35 | def test_jittable(self): 36 | @jax.jit 37 | def get_params(obj): 38 | return obj.data['params'] 39 | obj = DummyJittable(jnp.ones((5,))) 40 | np.testing.assert_array_equal(get_params(obj), obj.data['params']) 41 | 42 | def test_vmappable(self): 43 | def do_sum(obj): 44 | return obj.data['params'].sum() 45 | obj = DummyJittable(jnp.array([[1, 2, 3], [4, 5, 6]])) 46 | 47 | with self.subTest('no vmap'): 48 | np.testing.assert_array_equal(do_sum(obj), obj.data['params'].sum()) 49 | 50 | with self.subTest('in_axes=0'): 51 | np.testing.assert_array_equal( 52 | jax.vmap(do_sum, in_axes=0)(obj), obj.data['params'].sum(axis=1)) 53 | 54 | with self.subTest('in_axes=1'): 55 | np.testing.assert_array_equal( 56 | jax.vmap(do_sum, in_axes=1)(obj), obj.data['params'].sum(axis=0)) 57 | 58 | def test_traceable(self): 59 | @jax.jit 60 | def inner_fn(obj): 61 | obj.data['params'] *= 3 # Modification after passing to jitted fn. 62 | return obj.data['params'].sum() 63 | 64 | def loss_fn(params): 65 | obj = DummyJittable(params) 66 | obj.data['params'] *= 2 # Modification before passing to jitted fn. 67 | return inner_fn(obj) 68 | 69 | with self.subTest('numpy'): 70 | params = np.ones((5,)) 71 | # Both modifications will be traced if data tree is correctly traversed. 72 | grad_expected = params * 2 * 3 73 | grad = jax.grad(loss_fn)(params) 74 | np.testing.assert_array_equal(grad, grad_expected) 75 | 76 | with self.subTest('jax.numpy'): 77 | params = jnp.ones((5,)) 78 | # Both modifications will be traced if data tree is correctly traversed. 79 | grad_expected = params * 2 * 3 80 | grad = jax.grad(loss_fn)(params) 81 | np.testing.assert_array_equal(grad, grad_expected) 82 | 83 | def test_different_jittables_to_compiled_function(self): 84 | @jax.jit 85 | def add_one_to_params(obj): 86 | obj.data['params'] = obj.data['params'] + 1 87 | return obj 88 | 89 | with self.subTest('numpy'): 90 | add_one_to_params(DummyJittable(np.zeros((5,)))) 91 | add_one_to_params(DummyJittable(np.ones((5,)))) 92 | 93 | with self.subTest('jax.numpy'): 94 | add_one_to_params(DummyJittable(jnp.zeros((5,)))) 95 | add_one_to_params(DummyJittable(jnp.ones((5,)))) 96 | 97 | def test_modifying_object_data_does_not_leak_tracers(self): 98 | @jax.jit 99 | def add_one_to_params(obj): 100 | obj.data['params'] = obj.data['params'] + 1 101 | return obj 102 | 103 | dummy = DummyJittable(jnp.ones((5,))) 104 | dummy_out = add_one_to_params(dummy) 105 | dummy_out.data['params'] -= 1 106 | 107 | def test_metadata_modification_statements_are_removed_by_compilation(self): 108 | @jax.jit 109 | def add_char_to_name(obj): 110 | obj.name += '_x' 111 | return obj 112 | 113 | dummy = DummyJittable(jnp.ones((5,))) 114 | dummy_out = add_char_to_name(dummy) 115 | dummy_out = add_char_to_name(dummy) # `name` change has been compiled out. 
116 | dummy_out.name += 'y' 117 | self.assertEqual(dummy_out.name, 'dummy_xy') 118 | 119 | 120 | if __name__ == '__main__': 121 | absltest.main() 122 | -------------------------------------------------------------------------------- /distrax/_src/utils/math_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `math.py`.""" 16 | 17 | from absl.testing import absltest 18 | 19 | from distrax._src.utils import math 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | import scipy.special 24 | 25 | 26 | class MathTest(absltest.TestCase): 27 | 28 | def test_multiply_no_nan(self): 29 | zero = jnp.zeros(()) 30 | nan = zero / zero 31 | self.assertTrue(jnp.isnan(math.multiply_no_nan(zero, nan))) 32 | self.assertFalse(jnp.isnan(math.multiply_no_nan(nan, zero))) 33 | 34 | def test_multiply_no_nan_grads(self): 35 | x = -jnp.inf 36 | y = 0. 37 | self.assertEqual(math.multiply_no_nan(x, y), 0.) 38 | grad_fn = jax.grad( 39 | lambda inputs: math.multiply_no_nan(inputs[0], inputs[1])) 40 | np.testing.assert_allclose(grad_fn((x, y)), (y, x), rtol=1e-3) 41 | 42 | def test_power_no_nan(self): 43 | zero = jnp.zeros(()) 44 | nan = zero / zero 45 | self.assertTrue(jnp.isnan(math.power_no_nan(zero, nan))) 46 | self.assertFalse(jnp.isnan(math.power_no_nan(nan, zero))) 47 | 48 | def test_power_no_nan_grads(self): 49 | x = np.exp(1.) 50 | y = 0. 51 | self.assertEqual(math.power_no_nan(x, y), 1.) 
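    # At (x, y) = (e, 0): d/dx x**y = y * x**(y - 1) = 0 and
    # d/dy x**y = x**y * log(x) = 1, which is what the assertion below checks.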
52 | grad_fn = jax.grad( 53 | lambda inputs: math.power_no_nan(inputs[0], inputs[1])) 54 | np.testing.assert_allclose(grad_fn((x, y)), (0., 1.), rtol=1e-3) 55 | 56 | def test_normalize_probs(self): 57 | pre_normalised_probs = jnp.array([0.4, 0.4, 0., 0.2]) 58 | unnormalised_probs = jnp.array([4., 4., 0., 2.]) 59 | expected_probs = jnp.array([0.4, 0.4, 0., 0.2]) 60 | np.testing.assert_array_almost_equal( 61 | math.normalize(probs=pre_normalised_probs), expected_probs) 62 | np.testing.assert_array_almost_equal( 63 | math.normalize(probs=unnormalised_probs), expected_probs) 64 | 65 | def test_normalize_logits(self): 66 | unnormalised_logits = jnp.array([1., -1., 3.]) 67 | expected_logits = jax.nn.log_softmax(unnormalised_logits, axis=-1) 68 | np.testing.assert_array_almost_equal( 69 | math.normalize(logits=unnormalised_logits), expected_logits) 70 | np.testing.assert_array_almost_equal( 71 | math.normalize(logits=expected_logits), expected_logits) 72 | 73 | def test_sum_last(self): 74 | x = jax.random.normal(jax.random.PRNGKey(42), (2, 3, 4)) 75 | np.testing.assert_array_equal(math.sum_last(x, 0), x) 76 | np.testing.assert_array_equal(math.sum_last(x, 1), x.sum(-1)) 77 | np.testing.assert_array_equal(math.sum_last(x, 2), x.sum((-2, -1))) 78 | np.testing.assert_array_equal(math.sum_last(x, 3), x.sum()) 79 | 80 | def test_log_expbig_minus_expsmall(self): 81 | small = jax.random.normal(jax.random.PRNGKey(42), (2, 3, 4)) 82 | big = small + jax.random.uniform(jax.random.PRNGKey(43), (2, 3, 4)) 83 | expected_result = np.log(np.exp(big) - np.exp(small)) 84 | np.testing.assert_allclose( 85 | math.log_expbig_minus_expsmall(big, small), expected_result, atol=1e-4) 86 | 87 | def test_log_beta(self): 88 | a = jnp.abs(jax.random.normal(jax.random.PRNGKey(42), (2, 3, 4))) 89 | b = jnp.abs(jax.random.normal(jax.random.PRNGKey(43), (3, 4))) 90 | expected_result = scipy.special.betaln(a, b) 91 | np.testing.assert_allclose(math.log_beta(a, b), expected_result, atol=2e-4) 92 | 93 | def test_log_beta_bivariate(self): 94 | a = jnp.abs(jax.random.normal(jax.random.PRNGKey(42), (4, 3, 2))) 95 | expected_result = scipy.special.betaln(a[..., 0], a[..., 1]) 96 | np.testing.assert_allclose( 97 | math.log_beta_multivariate(a), expected_result, atol=2e-4) 98 | 99 | def test_log_beta_multivariate(self): 100 | a = jnp.abs(jax.random.normal(jax.random.PRNGKey(42), (2, 3, 4))) 101 | expected_result = (jnp.sum(scipy.special.gammaln(a), axis=-1) 102 | - scipy.special.gammaln(jnp.sum(a, axis=-1))) 103 | np.testing.assert_allclose( 104 | math.log_beta_multivariate(a), expected_result, atol=1e-3) 105 | 106 | if __name__ == '__main__': 107 | absltest.main() 108 | -------------------------------------------------------------------------------- /distrax/_src/utils/monte_carlo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
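# Note on the DiCE-style estimator implemented by `mc_estimate_kl` below: with
# proposal `q` and samples `x_i ~ q` (gradients stopped through the samples and
# through `log q(x_i)`), the divergence KL(a || b) is estimated as
#   mean_i[ -exp(log p_a(x_i) - log q(x_i)) * (log p_b(x_i) - log p_a(x_i)) ],
# whose value is an importance-weighted Monte Carlo estimate of the KL and
# whose gradient with respect to the parameters of `distribution_a` remains
# unbiased without requiring reparameterized samples.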
14 | # ============================================================================== 15 | """Monte-Carlo estimation of the KL divergence.""" 16 | 17 | from typing import Optional 18 | 19 | import chex 20 | from distrax._src.distributions.distribution import DistributionLike 21 | from distrax._src.utils import conversion 22 | import jax 23 | import jax.numpy as jnp 24 | from tensorflow_probability.substrates import jax as tfp 25 | 26 | tfd = tfp.distributions 27 | PRNGKey = chex.PRNGKey 28 | 29 | 30 | def estimate_kl_best_effort( 31 | distribution_a: DistributionLike, 32 | distribution_b: DistributionLike, 33 | rng_key: PRNGKey, 34 | num_samples: int, 35 | proposal_distribution: Optional[DistributionLike] = None): 36 | """Estimates KL(distribution_a, distribution_b) exactly or with DiCE. 37 | 38 | If the kl_divergence(distribution_a, distribution_b) is not supported, 39 | the DiCE estimator is used instead. 40 | 41 | Args: 42 | distribution_a: The first distribution. 43 | distribution_b: The second distribution. 44 | rng_key: The PRNGKey random key. 45 | num_samples: The number of samples, if using the DiCE estimator. 46 | proposal_distribution: A proposal distribution for the samples, if using 47 | the DiCE estimator. If None, use `distribution_a` as proposal. 48 | 49 | Returns: 50 | The estimated KL divergence. 51 | """ 52 | distribution_a = conversion.as_distribution(distribution_a) 53 | distribution_b = conversion.as_distribution(distribution_b) 54 | # If possible, compute the exact KL. 55 | try: 56 | return tfd.kl_divergence(distribution_a, distribution_b) 57 | except NotImplementedError: 58 | pass 59 | return mc_estimate_kl(distribution_a, distribution_b, rng_key, 60 | num_samples=num_samples, 61 | proposal_distribution=proposal_distribution) 62 | 63 | 64 | def mc_estimate_kl( 65 | distribution_a: DistributionLike, 66 | distribution_b: DistributionLike, 67 | rng_key: PRNGKey, 68 | num_samples: int, 69 | proposal_distribution: Optional[DistributionLike] = None): 70 | """Estimates KL(distribution_a, distribution_b) with the DiCE estimator. 71 | 72 | To get correct gradients with respect the `distribution_a`, we use the DiCE 73 | estimator, i.e., we stop the gradient with respect to the samples and with 74 | respect to the denominator in the importance weights. We then do not need 75 | reparametrized distributions. 76 | 77 | Args: 78 | distribution_a: The first distribution. 79 | distribution_b: The second distribution. 80 | rng_key: The PRNGKey random key. 81 | num_samples: The number of samples, if using the DiCE estimator. 82 | proposal_distribution: A proposal distribution for the samples, if using the 83 | DiCE estimator. If None, use `distribution_a` as proposal. 84 | 85 | Returns: 86 | The estimated KL divergence. 
87 | """ 88 | if proposal_distribution is None: 89 | proposal_distribution = distribution_a 90 | proposal_distribution = conversion.as_distribution(proposal_distribution) 91 | distribution_a = conversion.as_distribution(distribution_a) 92 | distribution_b = conversion.as_distribution(distribution_b) 93 | 94 | samples, logp_proposal = proposal_distribution.sample_and_log_prob( 95 | seed=rng_key, sample_shape=[num_samples]) 96 | samples = jax.lax.stop_gradient(samples) 97 | logp_proposal = jax.lax.stop_gradient(logp_proposal) 98 | logp_a = distribution_a.log_prob(samples) 99 | logp_b = distribution_b.log_prob(samples) 100 | importance_weight = jnp.exp(logp_a - logp_proposal) 101 | log_ratio = logp_b - logp_a 102 | kl_estimator = -importance_weight * log_ratio 103 | return jnp.mean(kl_estimator, axis=0) 104 | 105 | 106 | def mc_estimate_kl_with_reparameterized( 107 | distribution_a: DistributionLike, 108 | distribution_b: DistributionLike, 109 | rng_key: PRNGKey, 110 | num_samples: int): 111 | """Estimates KL(distribution_a, distribution_b).""" 112 | if isinstance(distribution_a, tfd.Distribution): 113 | if distribution_a.reparameterization_type != tfd.FULLY_REPARAMETERIZED: 114 | raise ValueError( 115 | f'Distribution `{distribution_a.name}` cannot be reparameterized.') 116 | distribution_a = conversion.as_distribution(distribution_a) 117 | distribution_b = conversion.as_distribution(distribution_b) 118 | 119 | samples, logp_a = distribution_a.sample_and_log_prob( 120 | seed=rng_key, sample_shape=[num_samples]) 121 | logp_b = distribution_b.log_prob(samples) 122 | log_ratio = logp_b - logp_a 123 | kl_estimator = -log_ratio 124 | return jnp.mean(kl_estimator, axis=0) 125 | 126 | 127 | def mc_estimate_mode( 128 | distribution: DistributionLike, 129 | rng_key: PRNGKey, 130 | num_samples: int): 131 | """Returns a Monte Carlo estimate of the mode of a distribution.""" 132 | distribution = conversion.as_distribution(distribution) 133 | # Obtain samples from the distribution and their log probability. 134 | samples, log_probs = distribution.sample_and_log_prob( 135 | seed=rng_key, sample_shape=[num_samples]) 136 | # Do argmax over the sample_shape. 137 | index = jnp.expand_dims(jnp.argmax(log_probs, axis=0), axis=0) 138 | # Broadcast index to include event_shape of the sample. 139 | index = index.reshape(index.shape + (1,) * (samples.ndim - index.ndim)) 140 | mode = jnp.squeeze(jnp.take_along_axis(samples, index, axis=0), axis=0) 141 | return mode 142 | -------------------------------------------------------------------------------- /distrax/_src/utils/monte_carlo_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for `monte_carlo.py`.""" 16 | 17 | from absl.testing import absltest 18 | 19 | import chex 20 | from distrax._src.distributions import mvn_diag 21 | from distrax._src.distributions import normal 22 | from distrax._src.utils import monte_carlo 23 | import haiku as hk 24 | import jax 25 | import numpy as np 26 | from tensorflow_probability.substrates import jax as tfp 27 | 28 | 29 | class McTest(absltest.TestCase): 30 | 31 | def test_estimate_kl_with_dice(self): 32 | batch_size = 5 33 | num_actions = 11 34 | num_samples = 1024 35 | rng_seq = hk.PRNGSequence(0) 36 | 37 | distribution_a = tfp.distributions.Categorical( 38 | logits=jax.random.normal(next(rng_seq), [batch_size, num_actions])) 39 | distribution_b = tfp.distributions.Categorical( 40 | logits=jax.random.normal(next(rng_seq), [batch_size, num_actions])) 41 | 42 | kl_estim_exact = monte_carlo.estimate_kl_best_effort( 43 | distribution_a, distribution_b, next(rng_seq), num_samples=num_samples) 44 | kl_estim_mc = monte_carlo.mc_estimate_kl( 45 | distribution_a, distribution_b, next(rng_seq), num_samples=num_samples) 46 | kl = distribution_a.kl_divergence(distribution_b) 47 | np.testing.assert_allclose(kl, kl_estim_exact, rtol=1e-5) 48 | np.testing.assert_allclose(kl, kl_estim_mc, rtol=2e-1) 49 | 50 | def test_estimate_continuous_kl_with_dice(self): 51 | _check_kl_estimator(monte_carlo.mc_estimate_kl, tfp.distributions.Normal) 52 | _check_kl_estimator(monte_carlo.mc_estimate_kl, normal.Normal) 53 | 54 | def test_estimate_continuous_kl_with_reparameterized(self): 55 | _check_kl_estimator(monte_carlo.mc_estimate_kl_with_reparameterized, 56 | tfp.distributions.Normal) 57 | _check_kl_estimator(monte_carlo.mc_estimate_kl_with_reparameterized, 58 | normal.Normal) 59 | 60 | def test_estimate_mode(self): 61 | with self.subTest('ScalarEventShape'): 62 | distribution = normal.Normal( 63 | loc=np.zeros((4, 5, 100)), 64 | scale=np.ones((4, 5, 100))) 65 | # pytype: disable=wrong-arg-types 66 | mode_estimate = monte_carlo.mc_estimate_mode( 67 | distribution, rng_key=42, num_samples=100) 68 | # pytype: enable=wrong-arg-types 69 | mean_mode_estimate = np.abs(np.mean(mode_estimate)) 70 | self.assertLess(mean_mode_estimate, 1e-3) 71 | with self.subTest('NonScalarEventShape'): 72 | distribution = mvn_diag.MultivariateNormalDiag( 73 | loc=np.zeros((4, 5, 100)), 74 | scale_diag=np.ones((4, 5, 100))) 75 | # pytype: disable=wrong-arg-types 76 | mv_mode_estimate = monte_carlo.mc_estimate_mode( 77 | distribution, rng_key=42, num_samples=100) 78 | # pytype: enable=wrong-arg-types 79 | mean_mv_mode_estimate = np.abs(np.mean(mv_mode_estimate)) 80 | self.assertLess(mean_mv_mode_estimate, 1e-1) 81 | # The mean of the mode-estimate of the Normal should be a lot closer 82 | # to 0 compared to the MultivariateNormal, because the 100 less samples 83 | # are taken and most of the mass in a high-dimensional gaussian is NOT 84 | # at 0! 
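    # (The factor of 10 below encodes this: the scalar mode estimates are
    # expected to be at least an order of magnitude closer to 0 than the
    # 100-dimensional ones.)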
85 | self.assertLess(10 * mean_mode_estimate, mean_mv_mode_estimate) 86 | 87 | 88 | def _check_kl_estimator(estimator_fn, distribution_fn, num_samples=10000, 89 | rtol=1e-1, atol=1e-3, grad_rtol=2e-1, grad_atol=1e-1): 90 | """Compares the estimator_fn output and gradient to exact KL.""" 91 | rng_key = jax.random.PRNGKey(0) 92 | 93 | def expected_kl(params): 94 | distribution_a = distribution_fn(**params[0]) 95 | distribution_b = distribution_fn(**params[1]) 96 | return distribution_a.kl_divergence(distribution_b) 97 | 98 | def estimate_kl(params): 99 | distribution_a = distribution_fn(**params[0]) 100 | distribution_b = distribution_fn(**params[1]) 101 | return estimator_fn(distribution_a, distribution_b, rng_key=rng_key, 102 | num_samples=num_samples) 103 | 104 | params = ( 105 | dict(loc=0.0, scale=1.0), 106 | dict(loc=0.1, scale=1.0), 107 | ) 108 | expected_value, expected_grad = jax.value_and_grad(expected_kl)(params) 109 | value, grad = jax.value_and_grad(estimate_kl)(params) 110 | 111 | np.testing.assert_allclose(expected_value, value, rtol=rtol, atol=atol) 112 | chex.assert_trees_all_close( 113 | expected_grad, grad, rtol=grad_rtol, atol=grad_atol 114 | ) 115 | 116 | 117 | if __name__ == '__main__': 118 | absltest.main() 119 | -------------------------------------------------------------------------------- /distrax/distrax_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for distrax.""" 16 | 17 | from absl.testing import absltest 18 | 19 | import distrax 20 | 21 | 22 | class DistraxTest(absltest.TestCase): 23 | """Test distrax can be imported correctly.""" 24 | 25 | def test_import(self): 26 | self.assertTrue(hasattr(distrax, 'Uniform')) 27 | 28 | 29 | if __name__ == '__main__': 30 | absltest.main() 31 | -------------------------------------------------------------------------------- /requirements/requirements-examples.txt: -------------------------------------------------------------------------------- 1 | tensorflow>=2.4.0 2 | tensorflow-datasets>=4.2.0 3 | dm-haiku>=0.0.3 4 | optax>=0.0.6 5 | -------------------------------------------------------------------------------- /requirements/requirements-tests.txt: -------------------------------------------------------------------------------- 1 | dm-haiku>=0.0.3 2 | mock>=4.0.3 3 | -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=0.9.0 2 | chex>=0.1.8 3 | jax>=0.1.55 4 | jaxlib>=0.1.67 5 | numpy>=1.23.0,<2 6 | setuptools;python_version>="3.12" 7 | tensorflow-probability>=0.15.0 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Install script for setuptools.""" 16 | 17 | import os 18 | from setuptools import find_namespace_packages 19 | from setuptools import setup 20 | 21 | _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 22 | 23 | 24 | def _get_version(): 25 | with open(os.path.join(_CURRENT_DIR, 'distrax', '__init__.py')) as fp: 26 | for line in fp: 27 | if line.startswith('__version__') and '=' in line: 28 | version = line[line.find('=') + 1:].strip(' \'"\n') 29 | if version: 30 | return version 31 | raise ValueError('`__version__` not defined in `distrax/__init__.py`') 32 | 33 | 34 | def _parse_requirements(path): 35 | 36 | with open(os.path.join(_CURRENT_DIR, path)) as f: 37 | return [ 38 | line.rstrip() 39 | for line in f 40 | if not (line.isspace() or line.startswith('#')) 41 | ] 42 | 43 | 44 | setup( 45 | name='distrax', 46 | version=_get_version(), 47 | url='https://github.com/deepmind/distrax', 48 | license='Apache 2.0', 49 | author='DeepMind', 50 | description=('Distrax: Probability distributions in JAX.'), 51 | long_description=open(os.path.join(_CURRENT_DIR, 'README.md')).read(), 52 | long_description_content_type='text/markdown', 53 | author_email='distrax-dev@google.com', 54 | keywords='jax probability distribution python machine learning', 55 | packages=find_namespace_packages(exclude=['*_test.py']), 56 | install_requires=_parse_requirements( 57 | os.path.join(_CURRENT_DIR, 'requirements', 'requirements.txt')), 58 | tests_require=_parse_requirements( 59 | os.path.join(_CURRENT_DIR, 'requirements', 'requirements-tests.txt')), 60 | zip_safe=False, # Required for full installation. 61 | include_package_data=True, 62 | python_requires='>=3.9', 63 | classifiers=[ 64 | 'Development Status :: 5 - Production/Stable', 65 | 'Environment :: Console', 66 | 'Intended Audience :: Science/Research', 67 | 'Intended Audience :: Developers', 68 | 'License :: OSI Approved :: Apache Software License', 69 | 'Operating System :: OS Independent', 70 | 'Programming Language :: Python', 71 | 'Programming Language :: Python :: 3', 72 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 73 | 'Topic :: Scientific/Engineering :: Mathematics', 74 | 'Topic :: Scientific/Engineering :: Physics', 75 | 'Topic :: Software Development :: Libraries :: Python Modules', 76 | ], 77 | ) 78 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # Runs CI tests on a local machine. 17 | set -xeuo pipefail 18 | 19 | # Install deps in a virtual env. 
20 | readonly VENV_DIR=/tmp/distrax-env 21 | rm -rf "${VENV_DIR}" 22 | python3 -m venv "${VENV_DIR}" 23 | source "${VENV_DIR}/bin/activate" 24 | python --version 25 | 26 | # Install dependencies. 27 | pip install --upgrade pip setuptools wheel 28 | pip install flake8 pytest-xdist pytest-forked pylint pylint-exit 29 | pip install -r requirements/requirements.txt 30 | pip install -r requirements/requirements-tests.txt 31 | 32 | # Lint with flake8. 33 | flake8 `find distrax -name '*.py' | xargs` --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics 34 | 35 | # Lint with pylint. 36 | # Fail on errors, warning, conventions and refactoring messages. 37 | PYLINT_ARGS="-efail -wfail -cfail -rfail" 38 | # Download Google OSS config. 39 | wget -nd -v -t 3 -O .pylintrc https://google.github.io/styleguide/pylintrc 40 | # Append specific config lines. 41 | echo "disable=abstract-method,unnecessary-lambda-assignment,no-value-for-parameter,use-dict-literal" >> .pylintrc 42 | # Lint modules and tests separately. 43 | # Disable `abstract-method` warnings. 44 | pylint --rcfile=.pylintrc `find distrax -name '*.py' | grep -v 'test.py' | xargs` || pylint-exit $PYLINT_ARGS $? 45 | # Disable `protected-access` and `arguments-differ` warnings for tests. 46 | pylint --rcfile=.pylintrc `find distrax -name '*_test.py' | xargs` -d W0212,W0221 || pylint-exit $PYLINT_ARGS $? 47 | # Cleanup. 48 | rm .pylintrc 49 | 50 | # Build the package. 51 | python setup.py sdist 52 | pip wheel --verbose --no-deps --no-clean dist/distrax*.tar.gz 53 | pip install distrax*.whl 54 | 55 | # Use TFP nightly builds in tests. 56 | pip uninstall tensorflow-probability -y 57 | pip install tfp-nightly 58 | 59 | # Check types with pytype. 60 | # Note: pytype does not support 3.11 as of 25.06.23 61 | # See https://github.com/google/pytype/issues/1308 62 | if [ `python -c 'import sys; print(sys.version_info.minor)'` -lt 11 ]; 63 | then 64 | pip install pytype 65 | pytype `find distrax/_src/ -name "*py" | xargs` -k 66 | fi; 67 | 68 | # Run tests using pytest. 69 | # Change directory to avoid importing the package from repo root. 70 | mkdir _testing && cd _testing 71 | 72 | # Main tests. 73 | 74 | # Disable JAX optimizations to speed up tests. 75 | export JAX_DISABLE_MOST_OPTIMIZATIONS=True 76 | pytest -n"$(grep -c ^processor /proc/cpuinfo)" --forked `find ../distrax/_src/ -name "*_test.py" | sort` -k 'not _float64_test' 77 | 78 | # Isolate tests that set double precision. 79 | pytest -n"$(grep -c ^processor /proc/cpuinfo)" --forked `find ../distrax/_src/ -name "*_test.py" | sort` -k '_float64_test' 80 | unset JAX_DISABLE_MOST_OPTIMIZATIONS 81 | 82 | cd .. 83 | 84 | set +u 85 | deactivate 86 | echo "All tests passed. Congrats!" 87 | --------------------------------------------------------------------------------