├── bayestorch
├── py.typed
├── optim
│ ├── __init__.py
│ ├── sgld.py
│ ├── svgd.py
│ └── sghmc.py
├── __init__.py
├── nn
│ ├── __init__.py
│ ├── utils.py
│ ├── prior_module.py
│ ├── particle_posterior_module.py
│ └── variational_posterior_module.py
├── distributions
│ ├── __init__.py
│ ├── log_scale_normal.py
│ ├── softplus_inv_scale_normal.py
│ ├── constraints.py
│ ├── deterministic.py
│ ├── finite.py
│ └── cat_distribution.py
└── version.py
├── examples
├── regression
│ ├── requirements.txt
│ ├── README.md
│ ├── train.py
│ ├── train_bbb.py
│ ├── train_mcmc.py
│ └── train_svgd.py
└── mnist
│ ├── requirements.txt
│ ├── README.md
│ ├── train.py
│ ├── train_mcmc.py
│ ├── train_svgd.py
│ └── train_bbb.py
├── MANIFEST.in
├── setup.cfg
├── tests
├── nn
│ ├── test_nn_utils.py
│ ├── test_prior_module.py
│ ├── test_particle_posterior_module.py
│ └── test_variational_posterior_module.py
├── distributions
│ ├── test_constraints.py
│ ├── test_deterministic.py
│ ├── test_log_scale_normal.py
│ ├── test_finite.py
│ ├── test_softplus_inv_scale_normal.py
│ ├── test_cat_distribution.py
│ └── test_distributions_utils.py
└── optim
│ ├── test_sgld.py
│ ├── test_sghmc.py
│ └── test_svgd.py
├── .gitignore
├── .pre-commit-config.yaml
├── NOTICE
├── setup.py
├── README.md
└── LICENSE

/bayestorch/py.typed:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/examples/regression/requirements.txt:
--------------------------------------------------------------------------------
1 | bayestorch
2 | torch
3 | 
--------------------------------------------------------------------------------
/examples/mnist/requirements.txt:
--------------------------------------------------------------------------------
1 | bayestorch
2 | torch
3 | torchvision
4 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include bayestorch/py.typed
2 | global-exclude __pycache__
3 | global-exclude *.py[co]
4 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E203, E266, E501, F403, W503
3 | max-line-length = 88
4 | max-complexity = 18
5 | select = B, C, E, F, W, T4
6 | 
7 | [isort]
8 | force_grid_wrap = 0
9 | include_trailing_comma = True
10 | known_local_folder = bayestorch
11 | line_length = 88
12 | lines_after_imports = 2
13 | multi_line_output = 3
14 | skip_gitignore = True
15 | use_parentheses = True
16 | 
17 | [tool:pytest]
18 | addopts = --cache-clear --cov=bayestorch --doctest-modules
19 | testpaths = bayestorch tests
--------------------------------------------------------------------------------
/examples/mnist/README.md:
--------------------------------------------------------------------------------
1 | # Basic MNIST Example
2 | 
3 | Trains a convolutional neural network on MNIST data.
4 | 5 | ## Standard non-Bayesian 6 | 7 | ``` 8 | pip install -r requirements.txt 9 | python train.py 10 | ``` 11 | 12 | ## Bayes by Backprop (BBB) 13 | 14 | ``` 15 | pip install -r requirements.txt 16 | python train_bbb.py 17 | ``` 18 | 19 | ## Markov chain Monte Carlo (MCMC) 20 | 21 | ``` 22 | pip install -r requirements.txt 23 | python train_mcmc.py 24 | ``` 25 | 26 | ## Stein variational gradient descent (SVGD) 27 | 28 | ``` 29 | pip install -r requirements.txt 30 | python train_svgd.py 31 | ``` -------------------------------------------------------------------------------- /examples/regression/README.md: -------------------------------------------------------------------------------- 1 | # Linear regression example 2 | 3 | Trains a single fully-connected layer to fit a 4th degree polynomial. 4 | 5 | ## Standard non-Bayesian 6 | 7 | ``` 8 | pip install -r requirements.txt 9 | python train.py 10 | ``` 11 | 12 | ## Bayes by Backprop (BBB) 13 | 14 | ``` 15 | pip install -r requirements.txt 16 | python train_bbb.py 17 | ``` 18 | 19 | ## Markov chain Monte Carlo (MCMC) 20 | 21 | ``` 22 | pip install -r requirements.txt 23 | python train_mcmc.py 24 | ``` 25 | 26 | ## Stein variational gradient descent (SVGD) 27 | 28 | ``` 29 | pip install -r requirements.txt 30 | python train_svgd.py 31 | ``` 32 | -------------------------------------------------------------------------------- /bayestorch/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Optimization.""" 18 | 19 | from bayestorch.optim.sghmc import * 20 | from bayestorch.optim.sgld import * 21 | from bayestorch.optim.svgd import * 22 | -------------------------------------------------------------------------------- /bayestorch/__init__.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | """Lightweight Bayesian deep learning library for fast prototyping based on PyTorch.""" 18 | 19 | from bayestorch import distributions, nn, optim 20 | from bayestorch.version import VERSION as __version__ 21 | -------------------------------------------------------------------------------- /bayestorch/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Neural networks.""" 18 | 19 | from bayestorch.nn.particle_posterior_module import * 20 | from bayestorch.nn.prior_module import * 21 | from bayestorch.nn.utils import * 22 | from bayestorch.nn.variational_posterior_module import * 23 | -------------------------------------------------------------------------------- /bayestorch/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Distributions.""" 18 | 19 | from bayestorch.distributions.cat_distribution import * 20 | from bayestorch.distributions.constraints import * 21 | from bayestorch.distributions.deterministic import * 22 | from bayestorch.distributions.finite import * 23 | from bayestorch.distributions.log_scale_normal import * 24 | from bayestorch.distributions.softplus_inv_scale_normal import * 25 | from bayestorch.distributions.utils import * 26 | -------------------------------------------------------------------------------- /bayestorch/version.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Version according to SemVer versioning system (https://semver.org/).""" 18 | 19 | 20 | __all__ = [ 21 | "VERSION", 22 | ] 23 | 24 | 25 | _MAJOR = "0" # Major version to increment in case of incompatible API changes 26 | 27 | _MINOR = ( 28 | "0" # Minor version to increment in case of backward compatible new functionality 29 | ) 30 | 31 | _PATCH = "3" # Patch version to increment in case of backward compatible bug fixes 32 | 33 | VERSION = f"{_MAJOR}.{_MINOR}.{_PATCH}" 34 | """The package version.""" 35 | -------------------------------------------------------------------------------- /tests/nn/test_nn_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # ============================================================================== 18 | 19 | """Test neural network utilities.""" 20 | 21 | import pytest 22 | import torch 23 | 24 | from bayestorch.nn.utils import nested_apply 25 | 26 | 27 | def test_nested_apply() -> "None": 28 | num_outputs = 4 29 | inputs = [ 30 | {"a": [torch.rand(2, 3), torch.rand(3, 5)], "b": torch.rand(1, 2)} 31 | for _ in range(num_outputs) 32 | ] 33 | outputs = nested_apply(torch.stack, inputs) 34 | print(f"Shape of first nested input: {inputs[0]['a'][0].shape}") 35 | print(f"Shape of first nested output: {outputs['a'][0].shape}") 36 | 37 | 38 | if __name__ == "__main__": 39 | pytest.main([__file__]) 40 | -------------------------------------------------------------------------------- /examples/regression/train.py: -------------------------------------------------------------------------------- 1 | # Adapted from: 2 | # https://github.com/pytorch/examples/blob/9aad148615b7519eadfa1a60356116a50561f192/regression/main.py 3 | 4 | #!/usr/bin/env python 5 | from __future__ import print_function 6 | from itertools import count 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | POLY_DEGREE = 4 12 | W_target = torch.randn(POLY_DEGREE, 1) * 5 13 | b_target = torch.randn(1) * 5 14 | 15 | 16 | def make_features(x): 17 | """Builds features i.e. 
a matrix with columns [x, x^2, x^3, x^4].""" 18 | x = x.unsqueeze(1) 19 | return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1) 20 | 21 | 22 | def f(x): 23 | """Approximated function.""" 24 | return x.mm(W_target) + b_target.item() 25 | 26 | 27 | def poly_desc(W, b): 28 | """Creates a string description of a polynomial.""" 29 | result = 'y = ' 30 | for i, w in enumerate(W): 31 | result += '{:+.2f} x^{} '.format(w, i + 1) 32 | result += '{:+.2f}'.format(b[0]) 33 | return result 34 | 35 | 36 | def get_batch(batch_size=32): 37 | """Builds a batch i.e. (x, f(x)) pair.""" 38 | random = torch.randn(batch_size) 39 | x = make_features(random) 40 | y = f(x) 41 | return x, y 42 | 43 | 44 | # Define model 45 | fc = torch.nn.Linear(W_target.size(0), 1) 46 | 47 | for batch_idx in count(1): 48 | # Get data 49 | batch_x, batch_y = get_batch() 50 | 51 | # Reset gradients 52 | fc.zero_grad() 53 | 54 | # Forward pass 55 | output = F.smooth_l1_loss(fc(batch_x), batch_y) 56 | loss = output.item() 57 | 58 | # Backward pass 59 | output.backward() 60 | 61 | # Apply gradients 62 | for param in fc.parameters(): 63 | param.data.add_(-0.1 * param.grad) 64 | 65 | # Stop criterion 66 | if loss < 1e-3: 67 | break 68 | 69 | print('Loss: {:.6f} after {} batches'.format(loss, batch_idx)) 70 | print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias)) 71 | print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target)) 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/* 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | 117 | # PyCharm 118 | .idea/ 119 | 120 | # Visual Studio Code 121 | .vscode/ 122 | 123 | # Data 124 | */data/* -------------------------------------------------------------------------------- /tests/distributions/test_constraints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | # ============================================================================== 18 | 19 | """Test constraints.""" 20 | 21 | import pytest 22 | import torch 23 | from torch.distributions.constraints import independent, positive, real 24 | 25 | from bayestorch.distributions.constraints import cat, ordered_real_vector, real_set 26 | 27 | 28 | def test_cat() -> "None": 29 | constraint = cat([independent(real, 1), independent(positive, 1)], lengths=(2, 1)) 30 | check = constraint.check(torch.as_tensor([-0.2, -0.5, 2.3])) 31 | print(f"Constraint: {constraint}") 32 | print(f"Is discrete: {constraint.is_discrete}") 33 | print(f"Check: {check}") 34 | 35 | 36 | def test_ordered_real_vector() -> "None": 37 | constraint = ordered_real_vector 38 | check = constraint.check(torch.as_tensor([0.2, 0.4, 0.8])) 39 | print(f"Constraint: {constraint}") 40 | print(f"Is discrete: {constraint.is_discrete}") 41 | print(f"Check: {check}") 42 | print(f"Is discrete: {constraint.is_discrete}") 43 | 44 | 45 | def test_real_set() -> "None": 46 | constraint = real_set(torch.as_tensor([0.2, 0.4, 0.8])) 47 | check = constraint.check(torch.as_tensor(0.2)) 48 | print(f"Constraint: {constraint}") 49 | print(f"Is discrete: {constraint.is_discrete}") 50 | print(f"Check: {check}") 51 | 52 | 53 | if __name__ == "__main__": 54 | pytest.main([__file__]) 55 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: isort 5 | name: isort 6 | language: system 7 | entry: isort 8 | types: [python] 9 | exclude: (^examples/mnist|^examples/regression) 10 | 11 | - id: black 12 | name: black 13 | language: system 14 | entry: black 15 | types: [python] 16 | exclude: (^examples/mnist|^examples/regression) 17 | 18 | - id: trailing-whitespace 19 | name: trailing-whitespace 20 | language: system 21 | entry: trailing-whitespace-fixer 22 | types: [python] 23 | 24 | - id: end-of-file-fixer 25 | name: end-of-file-fixer 26 | language: system 27 | entry: end-of-file-fixer 28 | types: [python] 29 | 30 | - id: mixed-line-ending 31 | name: mixed-line-ending 32 | language: system 33 | entry: mixed-line-ending 34 | types: [python] 35 | args: ["--fix=lf"] 36 | 37 | - id: fix-encoding-pragma 38 | name: fix-encoding-pragma 39 | language: system 40 | entry: fix-encoding-pragma 41 | types: [python] 42 | args: ["--remove"] 43 | 44 | - id: check-case-conflict 45 | name: check-case-conflict 46 | language: system 47 | entry: check-case-conflict 48 | types: [python] 49 | 50 | - id: check-merge-conflict 51 | name: check-merge-conflict 52 | language: system 53 | entry: check-merge-conflict 54 | types: [file] 55 | 56 | - id: flake8 57 | name: flake8 except __init__.py 58 | language: system 59 | entry: flake8 60 | types: [python] 61 | exclude: (^examples/mnist|^examples/regression|/__init__\.py) 62 | 63 | - id: flake8 64 | name: flake8 only __init__.py 65 | language: system 66 | entry: flake8 67 | types: [python] 68 | # Ignore unused imports in __init__.py 69 | args: ["--extend-ignore=F401"] 70 | files: /__init__\.py 71 | 72 | - id: pytest 73 | name: pytest 74 | language: system 75 | entry: pytest 76 | types: [python] 77 | files: ^test 78 | -------------------------------------------------------------------------------- /tests/distributions/test_deterministic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # 
============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # ============================================================================== 18 | 19 | """Test deterministic distribution.""" 20 | 21 | import pytest 22 | from torch.distributions import kl_divergence 23 | 24 | from bayestorch.distributions import Deterministic 25 | 26 | 27 | def test_deterministic() -> "None": 28 | value = 1.0 29 | distribution = Deterministic(value) 30 | print(distribution) 31 | print(distribution.expand((2, 3))) 32 | if distribution.has_rsample: 33 | distribution.rsample() 34 | else: 35 | distribution.sample() 36 | print(f"Mean: {distribution.mean}") 37 | print(f"Mode: {distribution.mode}") 38 | print(f"Standard deviation: {distribution.stddev}") 39 | print(f"Variance: {distribution.variance}") 40 | print(f"Log prob: {distribution.log_prob(distribution.sample())}") 41 | print(f"CDF: {distribution.cdf(distribution.sample())}") 42 | print(f"Entropy: {distribution.entropy()}") 43 | print(f"Support: {distribution.support}") 44 | print(f"Enumerated support: {distribution.enumerate_support()}") 45 | try: 46 | print(f"Enumerated support: {distribution.enumerate_support(False)}") 47 | except NotImplementedError: 48 | pass 49 | print( 50 | f"Kullback-Leibler divergence: " 51 | f"{kl_divergence(distribution, Deterministic(value, validate_args=True))}" 52 | ) 53 | 54 | 55 | if __name__ == "__main__": 56 | pytest.main([__file__]) 57 | -------------------------------------------------------------------------------- /tests/distributions/test_log_scale_normal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | # ============================================================================== 18 | 19 | """Test log scale normal distribution.""" 20 | 21 | import pytest 22 | from torch.distributions import kl_divergence 23 | 24 | from bayestorch.distributions import LogScaleNormal 25 | 26 | 27 | def test_log_scale_normal() -> "None": 28 | loc = 0.0 29 | log_scale = -1.0 30 | distribution = LogScaleNormal(loc, log_scale) 31 | print(distribution) 32 | print(distribution.expand((2, 3))) 33 | if distribution.has_rsample: 34 | distribution.rsample() 35 | else: 36 | distribution.sample() 37 | print(f"Mean: {distribution.mean}") 38 | print(f"Mode: {distribution.mode}") 39 | print(f"Standard deviation: {distribution.stddev}") 40 | print(f"Variance: {distribution.variance}") 41 | print(f"Log prob: {distribution.log_prob(distribution.sample())}") 42 | print(f"CDF: {distribution.cdf(distribution.sample())}") 43 | print(f"Entropy: {distribution.entropy()}") 44 | print(f"Support: {distribution.support}") 45 | try: 46 | print(f"Enumerated support: {distribution.enumerate_support()}") 47 | print(f"Enumerated support: {distribution.enumerate_support(False)}") 48 | except NotImplementedError: 49 | pass 50 | print( 51 | f"Kullback-Leibler divergence: " 52 | f"{kl_divergence(distribution, LogScaleNormal(loc, log_scale, validate_args=True))}" 53 | ) 54 | 55 | 56 | if __name__ == "__main__": 57 | pytest.main([__file__]) 58 | -------------------------------------------------------------------------------- /tests/nn/test_prior_module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | # ============================================================================== 18 | 19 | """Test prior module.""" 20 | 21 | import pytest 22 | import torch 23 | from torch import nn 24 | 25 | from bayestorch.distributions import LogScaleNormal 26 | from bayestorch.nn import PriorModule 27 | 28 | 29 | def test_prior_module() -> "None": 30 | batch_size = 10 31 | in_features = 4 32 | out_features = 2 33 | model = nn.Linear(in_features, out_features) 34 | num_parameters = sum(parameter.numel() for parameter in model.parameters()) 35 | model = PriorModule( 36 | model, 37 | prior_builder=LogScaleNormal, 38 | prior_kwargs={ 39 | "loc": torch.zeros(num_parameters), 40 | "log_scale": torch.full((num_parameters,), -1.0), 41 | }, 42 | ).to("cpu") 43 | input = torch.rand(batch_size, in_features) 44 | _ = model(input) 45 | output, log_prior = model(input, return_log_prior=True) 46 | print(model) 47 | print(dict(model.named_parameters()).keys()) 48 | print(model.parameters()) 49 | print(dict(model.named_parameters(include_all=False)).keys()) 50 | print(model.parameters(include_all=False)) 51 | state_dict = model.state_dict() 52 | model.load_state_dict(state_dict) 53 | print(f"Batch size: {batch_size}") 54 | print(f"Input shape: {(batch_size, in_features)}") 55 | print(f"Output shape: {output.shape}") 56 | print(f"Log prior shape: {log_prior.shape}") 57 | 58 | 59 | if __name__ == "__main__": 60 | pytest.main([__file__]) 61 | -------------------------------------------------------------------------------- /tests/distributions/test_finite.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | # ============================================================================== 18 | 19 | """Test finite distribution.""" 20 | 21 | import pytest 22 | import torch 23 | from torch.distributions import kl_divergence 24 | 25 | from bayestorch.distributions import Finite 26 | 27 | 28 | def test_finite() -> "None": 29 | logits = torch.as_tensor([0.25, 0.15, 0.10, 0.30, 0.20]) 30 | atoms = torch.as_tensor([5.0, 7.5, 10.0, 12.5, 15.0]) 31 | distribution = Finite(logits, atoms=atoms) 32 | print(distribution) 33 | print(distribution.expand((2, 3))) 34 | if distribution.has_rsample: 35 | distribution.rsample() 36 | else: 37 | distribution.sample() 38 | print(f"Mean: {distribution.mean}") 39 | print(f"Mode: {distribution.mode}") 40 | print(f"Standard deviation: {distribution.stddev}") 41 | print(f"Variance: {distribution.variance}") 42 | print(f"Log prob: {distribution.log_prob(distribution.sample())}") 43 | print(f"CDF: {distribution.cdf(distribution.sample())}") 44 | print(f"Entropy: {distribution.entropy()}") 45 | print(f"Support: {distribution.support}") 46 | print(f"Enumerated support: {distribution.enumerate_support()}") 47 | try: 48 | print(f"Enumerated support: {distribution.enumerate_support(False)}") 49 | except NotImplementedError: 50 | pass 51 | print( 52 | f"Kullback-Leibler divergence: " 53 | f"{kl_divergence(distribution, Finite(atoms, validate_args=True))}" 54 | ) 55 | 56 | 57 | if __name__ == "__main__": 58 | pytest.main([__file__]) 59 | -------------------------------------------------------------------------------- /tests/optim/test_sgld.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | # ============================================================================== 18 | 19 | """Stochastic gradient Langevin dynamics optimizer.""" 20 | 21 | import pytest 22 | import torch 23 | from torch import nn 24 | 25 | from bayestorch.optim import SGLD 26 | 27 | 28 | def test_sgld() -> "None": 29 | batch_size = 10 30 | in_features = 4 31 | out_features = 2 32 | num_burn_in_steps = 100 33 | model = nn.Linear(in_features, out_features) 34 | try: 35 | _ = SGLD(model.parameters(), lr=-1) 36 | _ = SGLD(model.parameters(), num_burn_in_steps=-1) 37 | _ = SGLD(model.parameters(), num_burn_in_steps=0.4) 38 | _ = SGLD(model.parameters(), precondition_decay_rate=-1) 39 | _ = SGLD(model.parameters(), epsilon=-1) 40 | except Exception: 41 | pass 42 | optimizer = SGLD(model.parameters(), num_burn_in_steps=num_burn_in_steps) 43 | input = torch.rand(batch_size, in_features) 44 | output = model(input) 45 | loss = output.sum() 46 | loss.backward() 47 | params_before = nn.utils.parameters_to_vector(model.parameters()) 48 | for _ in range(num_burn_in_steps + 1): 49 | optimizer.step() 50 | params_after = nn.utils.parameters_to_vector(model.parameters()) 51 | print(optimizer) 52 | print(f"Batch size: {batch_size}") 53 | print(f"Input shape: {(batch_size, in_features)}") 54 | print(f"Output shape: {output.shape}") 55 | print(f"Parameters shape: {params_before.shape}") 56 | assert not torch.allclose(params_before, params_after) 57 | 58 | 59 | if __name__ == "__main__": 60 | pytest.main([__file__]) 61 | -------------------------------------------------------------------------------- /tests/distributions/test_softplus_inv_scale_normal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | # ============================================================================== 18 | 19 | """Test inverse softplus scale normal distribution.""" 20 | 21 | import pytest 22 | from torch.distributions import kl_divergence 23 | 24 | from bayestorch.distributions import SoftplusInvScaleNormal 25 | 26 | 27 | def test_softplus_inv_scale_normal() -> "None": 28 | loc = 0.0 29 | softplus_inv_scale = -1.0 30 | distribution = SoftplusInvScaleNormal(loc, softplus_inv_scale) 31 | print(distribution) 32 | print(distribution.expand((2, 3))) 33 | if distribution.has_rsample: 34 | distribution.rsample() 35 | else: 36 | distribution.sample() 37 | print(f"Mean: {distribution.mean}") 38 | print(f"Mode: {distribution.mode}") 39 | print(f"Standard deviation: {distribution.stddev}") 40 | print(f"Variance: {distribution.variance}") 41 | print(f"Log prob: {distribution.log_prob(distribution.sample())}") 42 | print(f"CDF: {distribution.cdf(distribution.sample())}") 43 | print(f"Entropy: {distribution.entropy()}") 44 | print(f"Support: {distribution.support}") 45 | try: 46 | print(f"Enumerated support: {distribution.enumerate_support()}") 47 | print(f"Enumerated support: {distribution.enumerate_support(False)}") 48 | except NotImplementedError: 49 | pass 50 | print( 51 | f"Kullback-Leibler divergence: " 52 | f"{kl_divergence(distribution, SoftplusInvScaleNormal(loc, softplus_inv_scale, validate_args=True))}" 53 | ) 54 | 55 | 56 | if __name__ == "__main__": 57 | pytest.main([__file__]) 58 | -------------------------------------------------------------------------------- /tests/optim/test_sghmc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | # ============================================================================== 18 | 19 | """Test stochastic gradient Hamiltonian Monte Carlo optimizer.""" 20 | 21 | import pytest 22 | import torch 23 | from torch import nn 24 | 25 | from bayestorch.optim import SGHMC 26 | 27 | 28 | def test_sghmc() -> "None": 29 | batch_size = 10 30 | in_features = 4 31 | out_features = 2 32 | num_burn_in_steps = 100 33 | model = nn.Linear(in_features, out_features) 34 | try: 35 | _ = SGHMC(model.parameters(), lr=-1) 36 | _ = SGHMC(model.parameters(), num_burn_in_steps=-1) 37 | _ = SGHMC(model.parameters(), num_burn_in_steps=0.4) 38 | _ = SGHMC(model.parameters(), momentum_decay=-1) 39 | _ = SGHMC(model.parameters(), grad_noise=-1) 40 | _ = SGHMC(model.parameters(), epsilon=-1) 41 | except Exception: 42 | pass 43 | optimizer = SGHMC(model.parameters(), num_burn_in_steps=num_burn_in_steps) 44 | input = torch.rand(batch_size, in_features) 45 | output = model(input) 46 | loss = output.sum() 47 | loss.backward() 48 | params_before = nn.utils.parameters_to_vector(model.parameters()) 49 | for _ in range(num_burn_in_steps + 1): 50 | optimizer.step() 51 | params_after = nn.utils.parameters_to_vector(model.parameters()) 52 | print(optimizer) 53 | print(f"Batch size: {batch_size}") 54 | print(f"Input shape: {(batch_size, in_features)}") 55 | print(f"Output shape: {output.shape}") 56 | print(f"Parameters shape: {params_before.shape}") 57 | assert not torch.allclose(params_before, params_after) 58 | 59 | 60 | if __name__ == "__main__": 61 | pytest.main([__file__]) 62 | -------------------------------------------------------------------------------- /bayestorch/nn/utils.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Neural network utilities.""" 18 | 19 | from typing import Callable, Dict, Sequence, TypeVar, Union 20 | 21 | from torch import Tensor 22 | 23 | 24 | __all__ = [ 25 | "nested_apply", 26 | ] 27 | 28 | 29 | _T = TypeVar("_T") 30 | 31 | _Nested = Union[_T, Sequence[_T], Dict[str, _T]] 32 | 33 | 34 | def nested_apply( 35 | operator: "Callable[[Sequence[Tensor], int], Tensor]", 36 | inputs: "Sequence[_Nested[Tensor]]", 37 | dim: "int" = 0, 38 | ) -> "_Nested[Tensor]": 39 | """Apply an operator to a sequence of possibly 40 | nested tensors along a dimension. 41 | 42 | Parameters 43 | ---------- 44 | operator: 45 | The operator, i.e. a callable that receives a 46 | sequence of possibly nested tensors and a 47 | dimension, and returns a tensor. 48 | inputs: 49 | The sequence of possibly nested tensors. 50 | dim: 51 | The dimension. 
52 | 53 | Examples 54 | -------- 55 | >>> import torch 56 | >>> 57 | >>> from bayestorch.nn.utils import nested_apply 58 | >>> 59 | >>> 60 | >>> num_outputs = 4 61 | >>> inputs = [ 62 | ... {"a": [torch.rand(2, 3), torch.rand(3, 5)], "b": torch.rand(1, 2)} 63 | ... for _ in range(num_outputs) 64 | ... ] 65 | >>> outputs = nested_apply(torch.stack, inputs) 66 | 67 | """ 68 | first_input = inputs[0] 69 | if isinstance(first_input, Tensor): 70 | return operator(inputs, dim) 71 | if isinstance(first_input, dict): 72 | return type(first_input)( 73 | (k, nested_apply(operator, [output[k] for output in inputs], dim)) 74 | for k in first_input 75 | ) 76 | return type(first_input)( 77 | map(lambda inputs: nested_apply(operator, inputs, dim), zip(*inputs)) 78 | ) 79 | -------------------------------------------------------------------------------- /tests/nn/test_particle_posterior_module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # ============================================================================== 18 | 19 | """Test particle posterior module.""" 20 | 21 | import pytest 22 | import torch 23 | from torch import nn 24 | 25 | from bayestorch.distributions import LogScaleNormal 26 | from bayestorch.nn import ParticlePosteriorModule 27 | 28 | 29 | def test_particle_posterior_module() -> "None": 30 | num_particles = 5 31 | batch_size = 10 32 | in_features = 4 33 | out_features = 2 34 | model = nn.Linear(in_features, out_features) 35 | num_parameters = sum(parameter.numel() for parameter in model.parameters()) 36 | model = ParticlePosteriorModule( 37 | model, 38 | prior_builder=LogScaleNormal, 39 | prior_kwargs={ 40 | "loc": torch.zeros(num_parameters), 41 | "log_scale": torch.full((num_parameters,), -1.0), 42 | }, 43 | num_particles=num_particles, 44 | ) 45 | input = torch.rand(batch_size, in_features) 46 | for reduction in ["none", "mean"]: 47 | output = model(input, reduction=reduction) 48 | outputs, log_priors = model( 49 | input, 50 | return_log_prior=True, 51 | reduction=reduction, 52 | ) 53 | print(model) 54 | print(dict(model.named_parameters()).keys()) 55 | print(model.parameters()) 56 | print(dict(model.named_parameters(include_all=False)).keys()) 57 | print(model.parameters(include_all=False)) 58 | print(model.particles.shape) 59 | state_dict = model.state_dict() 60 | model.load_state_dict(state_dict) 61 | print(model.particles.shape) 62 | print(f"Number of particles: {num_particles}") 63 | print(f"Batch size: {batch_size}") 64 | print(f"Input shape: {(batch_size, in_features)}") 65 | print(f"Output shape: {output.shape}") 66 | print(f"Outputs shape: {outputs.shape}") 67 | print(f"Log priors shape: {log_priors.shape}") 68 | 69 | 70 | if __name__ == "__main__": 71 | pytest.main([__file__]) 72 | 
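A quick note on `nested_apply` (defined in `bayestorch/nn/utils.py` above): its docstring example stops at the call itself, so the minimal sketch below illustrates what the call returns. The expected shapes are stated here as an illustration inferred from the utility's source (stacking leaf tensors along `dim=0`), not taken from the repository.

```python
import torch

from bayestorch.nn.utils import nested_apply

# Four nested inputs: each is a dict holding a list of tensors and a tensor.
num_outputs = 4
inputs = [
    {"a": [torch.rand(2, 3), torch.rand(3, 5)], "b": torch.rand(1, 2)}
    for _ in range(num_outputs)
]

# torch.stack is applied leaf-wise along dim 0, so every leaf tensor gains a
# new leading dimension of size num_outputs while the nesting is preserved.
outputs = nested_apply(torch.stack, inputs)
assert outputs["a"][0].shape == (num_outputs, 2, 3)
assert outputs["a"][1].shape == (num_outputs, 3, 5)
assert outputs["b"].shape == (num_outputs, 1, 2)
```

This leaf-wise stacking is what makes the utility convenient for combining structured outputs produced by multiple forward passes, as exercised in the test above.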
-------------------------------------------------------------------------------- /bayestorch/distributions/log_scale_normal.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Log scale normal distribution.""" 18 | 19 | from typing import Optional, Union 20 | 21 | from torch import Size, Tensor 22 | from torch.distributions import Normal, constraints 23 | 24 | 25 | __all__ = [ 26 | "LogScaleNormal", 27 | ] 28 | 29 | 30 | class LogScaleNormal(Normal): 31 | """Normal distribution parameterized by location 32 | and log scale parameters. 33 | 34 | Scale parameter is computed as `exp(log_scale)`. 35 | 36 | Examples 37 | -------- 38 | >>> from bayestorch.distributions import LogScaleNormal 39 | >>> 40 | >>> 41 | >>> loc = 0.0 42 | >>> log_scale = -1.0 43 | >>> distribution = LogScaleNormal(loc, log_scale) 44 | 45 | """ 46 | 47 | arg_constraints = { 48 | "loc": constraints.real, 49 | "log_scale": constraints.real, 50 | } # override 51 | 52 | # override 53 | def __init__( 54 | self, 55 | loc: "Union[int, float, Tensor]", 56 | log_scale: "Union[int, float, Tensor]", 57 | validate_args: "Optional[bool]" = None, 58 | ) -> "None": 59 | super().__init__(loc, log_scale, validate_args) 60 | 61 | @property 62 | def mode(self) -> "Tensor": 63 | return self.mean 64 | 65 | # override 66 | @property 67 | def scale(self) -> "Tensor": 68 | return self.log_scale.exp() 69 | 70 | # override 71 | @scale.setter 72 | def scale(self, value: "Tensor") -> "None": 73 | self.log_scale = value 74 | 75 | # override 76 | def expand( 77 | self, 78 | batch_shape: "Size" = Size(), # noqa: B008 79 | _instance: "Optional[LogScaleNormal]" = None, 80 | ) -> "LogScaleNormal": 81 | new = self._get_checked_instance(LogScaleNormal, _instance) 82 | loc = self.loc.expand(batch_shape) 83 | log_scale = self.log_scale.expand(batch_shape) 84 | super(LogScaleNormal, new).__init__(loc, log_scale, validate_args=False) 85 | new._validate_args = self._validate_args 86 | return new 87 | -------------------------------------------------------------------------------- /tests/optim/test_svgd.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # ============================================================================== 18 | 19 | """Test Stein variational gradient descent preconditioner.""" 20 | 21 | import math 22 | 23 | import pytest 24 | import torch 25 | from torch import Tensor, nn 26 | 27 | from bayestorch.optim import SVGD 28 | 29 | 30 | def rbf_kernel(x1: "Tensor", x2: "Tensor") -> "Tensor": 31 | deltas = torch.cdist(x1, x2) 32 | squared_deltas = deltas**2 33 | bandwidth = squared_deltas.detach().median() / math.log( 34 | min(x1.shape[0], x2.shape[0]) 35 | ) 36 | log_kernels = -squared_deltas / bandwidth 37 | kernels = log_kernels.exp() 38 | return kernels 39 | 40 | 41 | def test_svgd() -> "None": 42 | num_particles = 5 43 | batch_size = 10 44 | in_features = 4 45 | out_features = 2 46 | models = nn.ModuleList( 47 | [nn.Linear(in_features, out_features) for _ in range(num_particles)] 48 | ) 49 | try: 50 | _ = SVGD(models.parameters(), num_particles=-1) 51 | _ = SVGD(models.parameters(), num_particles=0.5) 52 | _ = SVGD(models.parameters(), num_particles=3) 53 | except Exception: 54 | pass 55 | preconditioner = SVGD(models.parameters(), rbf_kernel, num_particles) 56 | input = torch.rand(batch_size, in_features) 57 | outputs = torch.cat([model(input) for model in models]) 58 | loss = outputs.sum() 59 | loss.backward() 60 | grads_before = nn.utils.parameters_to_vector( 61 | (parameter.grad for parameter in models.parameters()) 62 | ) 63 | preconditioner.step() 64 | grads_after = nn.utils.parameters_to_vector( 65 | (parameter.grad for parameter in models.parameters()) 66 | ) 67 | print(preconditioner) 68 | print(f"Number of particles: {num_particles}") 69 | print(f"Batch size: {batch_size}") 70 | print(f"Input shape: {(batch_size, in_features)}") 71 | print(f"Outputs shape: {outputs.shape}") 72 | print(f"Gradients shape: {grads_before.shape}") 73 | assert not torch.allclose(grads_before, grads_after) 74 | 75 | 76 | if __name__ == "__main__": 77 | pytest.main([__file__]) 78 | -------------------------------------------------------------------------------- /bayestorch/distributions/softplus_inv_scale_normal.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | """Inverse softplus scale normal distribution.""" 18 | 19 | from typing import Optional, Union 20 | 21 | import torch.nn.functional as F 22 | from torch import Size, Tensor 23 | from torch.distributions import Normal, constraints 24 | 25 | 26 | __all__ = [ 27 | "SoftplusInvScaleNormal", 28 | ] 29 | 30 | 31 | class SoftplusInvScaleNormal(Normal): 32 | """Normal distribution parameterized by location 33 | and inverse softplus scale parameters. 34 | 35 | Scale parameter is computed as `softplus(softplus_inv_scale)`. 36 | 37 | Examples 38 | -------- 39 | >>> from bayestorch.distributions import SoftplusInvScaleNormal 40 | >>> 41 | >>> 42 | >>> loc = 0.0 43 | >>> softplus_inv_scale = -1.0 44 | >>> distribution = SoftplusInvScaleNormal(loc, softplus_inv_scale) 45 | 46 | """ 47 | 48 | arg_constraints = { 49 | "loc": constraints.real, 50 | "softplus_inv_scale": constraints.real, 51 | } # override 52 | 53 | # override 54 | def __init__( 55 | self, 56 | loc: "Union[int, float, Tensor]", 57 | softplus_inv_scale: "Union[int, float, Tensor]", 58 | validate_args: "Optional[bool]" = None, 59 | ) -> "None": 60 | super().__init__(loc, softplus_inv_scale, validate_args) 61 | 62 | @property 63 | def mode(self) -> "Tensor": 64 | return self.mean 65 | 66 | # override 67 | @property 68 | def scale(self) -> "Tensor": 69 | return F.softplus(self.softplus_inv_scale) 70 | 71 | # override 72 | @scale.setter 73 | def scale(self, value: "Tensor") -> "None": 74 | self.softplus_inv_scale = value 75 | 76 | # override 77 | def expand( 78 | self, 79 | batch_shape: "Size" = Size(), # noqa: B008 80 | _instance: "Optional[SoftplusInvScaleNormal]" = None, 81 | ) -> "SoftplusInvScaleNormal": 82 | new = self._get_checked_instance(SoftplusInvScaleNormal, _instance) 83 | loc = self.loc.expand(batch_shape) 84 | softplus_inv_scale = self.softplus_inv_scale.expand(batch_shape) 85 | super(SoftplusInvScaleNormal, new).__init__( 86 | loc, softplus_inv_scale, validate_args=False 87 | ) 88 | new._validate_args = self._validate_args 89 | return new 90 | -------------------------------------------------------------------------------- /tests/distributions/test_cat_distribution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | # ============================================================================== 18 | 19 | """Test concatenated distribution.""" 20 | 21 | import pytest 22 | import torch 23 | from torch.distributions import Categorical, Independent, Normal, kl_divergence 24 | 25 | from bayestorch.distributions import CatDistribution 26 | 27 | 28 | def test_cat_distribution() -> "None": 29 | loc = 0.0 30 | scale = 1.0 31 | logits = torch.as_tensor([0.25, 0.15, 0.10, 0.30, 0.20]) 32 | try: 33 | _ = CatDistribution([Normal(loc, scale), Categorical(logits)], dim=3) 34 | _ = CatDistribution( 35 | [ 36 | Normal(torch.full((2, 3), loc), torch.full((2, 3), scale)), 37 | Categorical(logits), 38 | ] 39 | ) 40 | _ = CatDistribution( 41 | [ 42 | Independent( 43 | Normal(torch.full((2, 3, 2), loc), torch.full((2, 3, 2), scale)), 2 44 | ), 45 | Categorical(logits.expand(2, 5)), 46 | ], 47 | dim=1, 48 | ) 49 | except Exception: 50 | pass 51 | distribution = CatDistribution([Normal(loc, scale), Categorical(logits)]) 52 | print(distribution) 53 | print(distribution.expand((2, 3))) 54 | if distribution.has_rsample: 55 | distribution.rsample() 56 | else: 57 | distribution.sample() 58 | print(f"Mean: {distribution.mean}") 59 | print(f"Mode: {distribution.mode}") 60 | print(f"Standard deviation: {distribution.stddev}") 61 | print(f"Variance: {distribution.variance}") 62 | print(f"Log prob: {distribution.log_prob(distribution.sample())}") 63 | print(f"Entropy: {distribution.entropy()}") 64 | print(f"Support: {distribution.support}") 65 | try: 66 | print(f"CDF: {distribution.cdf(distribution.sample())}") 67 | print(f"Enumerated support: {distribution.enumerate_support()}") 68 | print(f"Enumerated support: {distribution.enumerate_support(False)}") 69 | except NotImplementedError: 70 | pass 71 | print( 72 | f"Kullback-Leibler divergence: " 73 | f"{kl_divergence(distribution, CatDistribution([Normal(loc, scale), Categorical(logits)], validate_args=True))}" 74 | ) 75 | print( 76 | f"Kullback-Leibler divergence: " 77 | f"{kl_divergence(CatDistribution([Normal(loc, scale)]), Normal(loc, scale))}" 78 | ) 79 | print( 80 | f"Kullback-Leibler divergence: " 81 | f"{kl_divergence(Normal(loc, scale), CatDistribution([Normal(loc, scale)]))}" 82 | ) 83 | 84 | 85 | if __name__ == "__main__": 86 | pytest.main([__file__]) 87 | -------------------------------------------------------------------------------- /examples/regression/train_bbb.py: -------------------------------------------------------------------------------- 1 | # Adapted from: 2 | # https://github.com/pytorch/examples/blob/9aad148615b7519eadfa1a60356116a50561f192/regression/main.py 3 | 4 | # Changes to the code are kept to a minimum to facilitate the comparison with the original example 5 | 6 | #!/usr/bin/env python 7 | from __future__ import print_function 8 | from itertools import count 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from bayestorch.distributions import get_log_scale_normal, get_softplus_inv_scale_normal 14 | from bayestorch.nn import VariationalPosteriorModule 15 | 16 | POLY_DEGREE = 4 17 | W_target = torch.randn(POLY_DEGREE, 1) * 5 18 | b_target = torch.randn(1) * 5 19 | 20 | 21 | def make_features(x): 22 | """Builds features i.e. 
a matrix with columns [x, x^2, x^3, x^4].""" 23 | x = x.unsqueeze(1) 24 | return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1) 25 | 26 | 27 | def f(x): 28 | """Approximated function.""" 29 | return x.mm(W_target) + b_target.item() 30 | 31 | 32 | def poly_desc(W, b): 33 | """Creates a string description of a polynomial.""" 34 | result = 'y = ' 35 | for i, w in enumerate(W): 36 | result += '{:+.2f} x^{} '.format(w, i + 1) 37 | result += '{:+.2f}'.format(b[0]) 38 | return result 39 | 40 | 41 | def get_batch(batch_size=32): 42 | """Builds a batch i.e. (x, f(x)) pair.""" 43 | random = torch.randn(batch_size) 44 | x = make_features(random) 45 | y = f(x) 46 | return x, y 47 | 48 | 49 | # Number of Monte Carlo samples 50 | num_mc_samples = 10 51 | 52 | # Kullback-Leibler divergence weight 53 | kl_div_weight = 1e-1 54 | 55 | # Define model 56 | fc = torch.nn.Linear(W_target.size(0), 1) 57 | 58 | # Prior arguments (WITHOUT gradient tracking) 59 | prior_builder, prior_kwargs = get_log_scale_normal(fc.parameters(), 0.0, -1.0) 60 | 61 | # Posterior arguments (WITH gradient tracking) 62 | posterior_builder, posterior_kwargs = get_softplus_inv_scale_normal(fc.parameters(), 0.0, -7.0, requires_grad=True) 63 | 64 | # Bayesian model 65 | fc = VariationalPosteriorModule(fc, prior_builder, prior_kwargs, posterior_builder, posterior_kwargs) 66 | 67 | for batch_idx in count(1): 68 | # Get data 69 | batch_x, batch_y = get_batch() 70 | 71 | # Reset gradients 72 | fc.zero_grad() 73 | 74 | # Forward pass 75 | #output = F.smooth_l1_loss(fc(batch_x), batch_y) 76 | #loss = output.item() 77 | output, kl_div = fc(batch_x, num_mc_samples=num_mc_samples, return_kl_div=True) 78 | loss = F.smooth_l1_loss(output, batch_y, reduction="sum") + kl_div_weight * kl_div 79 | 80 | # Backward pass 81 | loss.backward() 82 | loss = loss.item() 83 | 84 | # Apply gradients 85 | for param in fc.parameters(): 86 | param.data.add_(-0.1 * param.grad) 87 | 88 | # Stop criterion 89 | if loss < 1e2: 90 | break 91 | 92 | print('Loss: {:.6f} after {} batches'.format(loss, batch_idx)) 93 | #print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias)) 94 | #print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target)) 95 | print('==> Learned function mean: \t' + poly_desc( 96 | fc.posterior_loc[:-1], 97 | fc.posterior_loc[-1][None], 98 | )) 99 | print('==> Learned function standard deviation:\t' + poly_desc( 100 | F.softplus(fc.posterior_softplus_inv_scale[:-1]), 101 | F.softplus(fc.posterior_softplus_inv_scale[-1][None]), 102 | )) 103 | print('==> Actual function: \t' + poly_desc(W_target.view(-1), b_target)) 104 | -------------------------------------------------------------------------------- /examples/regression/train_mcmc.py: -------------------------------------------------------------------------------- 1 | # Adapted from: 2 | # https://github.com/pytorch/examples/blob/9aad148615b7519eadfa1a60356116a50561f192/regression/main.py 3 | 4 | # Changes to the code are kept to a minimum to facilitate the comparison with the original example 5 | 6 | #!/usr/bin/env python 7 | from __future__ import print_function 8 | from itertools import count 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | import copy 14 | from collections import deque 15 | from bayestorch.distributions import get_log_scale_normal 16 | from bayestorch.nn import PriorModule 17 | from bayestorch.optim import SGLD 18 | 19 | POLY_DEGREE = 4 20 | W_target = torch.randn(POLY_DEGREE, 1) * 5 21 | b_target = torch.randn(1) * 5 22 | 23 | 
24 | def make_features(x): 25 | """Builds features i.e. a matrix with columns [x, x^2, x^3, x^4].""" 26 | x = x.unsqueeze(1) 27 | return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1) 28 | 29 | 30 | def f(x): 31 | """Approximated function.""" 32 | return x.mm(W_target) + b_target.item() 33 | 34 | 35 | def poly_desc(W, b): 36 | """Creates a string description of a polynomial.""" 37 | result = 'y = ' 38 | for i, w in enumerate(W): 39 | result += '{:+.2f} x^{} '.format(w, i + 1) 40 | result += '{:+.2f}'.format(b[0]) 41 | return result 42 | 43 | 44 | def get_batch(batch_size=32): 45 | """Builds a batch i.e. (x, f(x)) pair.""" 46 | random = torch.randn(batch_size) 47 | x = make_features(random) 48 | y = f(x) 49 | return x, y 50 | 51 | 52 | # Log prior weight 53 | log_prior_weight = 1e-2 54 | 55 | # Define model 56 | fc = torch.nn.Linear(W_target.size(0), 1) 57 | 58 | # Prior arguments (WITHOUT gradient tracking) 59 | prior_builder, prior_kwargs = get_log_scale_normal(fc.parameters(), 0.0, -1.0) 60 | 61 | # Bayesian model 62 | fc = PriorModule(fc, prior_builder, prior_kwargs) 63 | 64 | # Optimizer 65 | optimizer = SGLD( 66 | fc.parameters(), 67 | lr=1e-2, 68 | num_burn_in_steps=200, 69 | precondition_decay_rate=0.95, 70 | ) 71 | 72 | # Keep track of the last 10 models 73 | models = deque(maxlen=10) 74 | 75 | for batch_idx in count(1): 76 | # Get data 77 | batch_x, batch_y = get_batch() 78 | 79 | # Reset gradients 80 | fc.zero_grad() 81 | 82 | # Forward pass 83 | #output = F.smooth_l1_loss(fc(batch_x), batch_y) 84 | #loss = output.item() 85 | output, log_prior = fc(batch_x, return_log_prior=True) 86 | loss = F.smooth_l1_loss(output, batch_y, reduction="sum") - log_prior_weight * log_prior 87 | 88 | # Backward pass 89 | loss.backward() 90 | loss = loss.item() 91 | 92 | # Optimizer step 93 | optimizer.step() 94 | 95 | # Save model 96 | models.append(copy.deepcopy(fc.module)) 97 | 98 | # Apply gradients 99 | #for param in fc.parameters(): 100 | # param.data.add_(-0.1 * param.grad) 101 | 102 | # Stop criterion 103 | if loss < 1e1: 104 | break 105 | 106 | print('Loss: {:.6f} after {} batches'.format(loss, batch_idx)) 107 | #print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias)) 108 | #print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target)) 109 | print('==> Learned function mean: \t' + poly_desc( 110 | torch.stack([model.weight.view(-1) for model in models]).mean(dim=0), 111 | torch.stack([model.bias for model in models]).mean(dim=0), 112 | )) 113 | print('==> Learned function standard deviation:\t' + poly_desc( 114 | torch.stack([model.weight.view(-1) for model in models]).std(dim=0), 115 | torch.stack([model.bias for model in models]).std(dim=0), 116 | )) 117 | print('==> Actual function: \t' + poly_desc(W_target.view(-1), b_target)) 118 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | This project incorporates components from the projects listed below. The original copyright notices are set forth below. 2 | 3 | ############################################################################################################################################################# 4 | 5 | 1. 
Code in bayestorch/optim/{sghmc, sgld}.py adapted from: 6 | https://github.com/JavierAntoran/Bayesian-Neural-Networks/blob/1f867a5bcbd1abfecede99807eb0b5f97ed8be7c/src/Stochastic_Gradient_HMC_SA/optimizers.py 7 | https://github.com/JavierAntoran/Bayesian-Neural-Networks/blob/1f867a5bcbd1abfecede99807eb0b5f97ed8be7c/src/Stochastic_Gradient_Langevin_Dynamics/optimizers.py 8 | 9 | MIT License 10 | 11 | Copyright (c) 2019 Javier Antoran 12 | 13 | Permission is hereby granted, free of charge, to any person obtaining a copy 14 | of this software and associated documentation files (the "Software"), to deal 15 | in the Software without restriction, including without limitation the rights 16 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 17 | copies of the Software, and to permit persons to whom the Software is 18 | furnished to do so, subject to the following conditions: 19 | 20 | The above copyright notice and this permission notice shall be included in all 21 | copies or substantial portions of the Software. 22 | 23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 29 | SOFTWARE. 30 | 31 | ############################################################################################################################################################# 32 | 33 | 2. Code in examples/{mnist, regression}/{train_bbb, train_mcmc, train_svgd}.py adapted from: 34 | https://github.com/pytorch/examples/blob/9aad148615b7519eadfa1a60356116a50561f192/mnist/main.py 35 | https://github.com/pytorch/examples/blob/9aad148615b7519eadfa1a60356116a50561f192/regression/main.py 36 | 37 | BSD 3-Clause License 38 | 39 | Copyright (c) 2017, 40 | All rights reserved. 41 | 42 | Redistribution and use in source and binary forms, with or without 43 | modification, are permitted provided that the following conditions are met: 44 | 45 | * Redistributions of source code must retain the above copyright notice, this 46 | list of conditions and the following disclaimer. 47 | 48 | * Redistributions in binary form must reproduce the above copyright notice, 49 | this list of conditions and the following disclaimer in the documentation 50 | and/or other materials provided with the distribution. 51 | 52 | * Neither the name of the copyright holder nor the names of its 53 | contributors may be used to endorse or promote products derived from 54 | this software without specific prior written permission. 55 | 56 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 57 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 58 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 59 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 60 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 61 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 62 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 63 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 64 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 65 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 66 | 67 | ############################################################################################################################################################# 68 | -------------------------------------------------------------------------------- /examples/regression/train_svgd.py: -------------------------------------------------------------------------------- 1 | # Adapted from: 2 | # https://github.com/pytorch/examples/blob/9aad148615b7519eadfa1a60356116a50561f192/regression/main.py 3 | 4 | # Changes to the code are kept to a minimum to facilitate the comparison with the original example 5 | 6 | #!/usr/bin/env python 7 | from __future__ import print_function 8 | from itertools import count 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | import math 14 | from bayestorch.distributions import get_log_scale_normal 15 | from bayestorch.nn import ParticlePosteriorModule 16 | from bayestorch.optim import SVGD 17 | 18 | POLY_DEGREE = 4 19 | W_target = torch.randn(POLY_DEGREE, 1) * 5 20 | b_target = torch.randn(1) * 5 21 | 22 | 23 | def make_features(x): 24 | """Builds features i.e. a matrix with columns [x, x^2, x^3, x^4].""" 25 | x = x.unsqueeze(1) 26 | return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1) 27 | 28 | 29 | def f(x): 30 | """Approximated function.""" 31 | return x.mm(W_target) + b_target.item() 32 | 33 | 34 | def poly_desc(W, b): 35 | """Creates a string description of a polynomial.""" 36 | result = 'y = ' 37 | for i, w in enumerate(W): 38 | result += '{:+.2f} x^{} '.format(w, i + 1) 39 | result += '{:+.2f}'.format(b[0]) 40 | return result 41 | 42 | 43 | def get_batch(batch_size=32): 44 | """Builds a batch i.e. 
(x, f(x)) pair.""" 45 | random = torch.randn(batch_size) 46 | x = make_features(random) 47 | y = f(x) 48 | return x, y 49 | 50 | 51 | def rbf_kernel(x1, x2): 52 | deltas = torch.cdist(x1, x2) 53 | squared_deltas = deltas**2 54 | bandwidth = ( 55 | squared_deltas.detach().median() 56 | / math.log(min(x1.shape[0], x2.shape[0])) 57 | ) 58 | log_kernels = -squared_deltas / bandwidth 59 | kernels = log_kernels.exp() 60 | return kernels 61 | 62 | 63 | # Number of particles 64 | num_particles = 10 65 | 66 | # Log prior weight 67 | log_prior_weight = 1e-1 68 | 69 | # Define model 70 | fc = torch.nn.Linear(W_target.size(0), 1) 71 | 72 | # Prior arguments (WITHOUT gradient tracking) 73 | prior_builder, prior_kwargs = get_log_scale_normal(fc.parameters(), 0.0, -1.0) 74 | 75 | # Bayesian model 76 | fc = ParticlePosteriorModule(fc, prior_builder, prior_kwargs, num_particles) 77 | 78 | # SVGD preconditioner 79 | preconditioner = SVGD(fc.parameters(include_all=False), rbf_kernel, num_particles) 80 | 81 | for batch_idx in count(1): 82 | # Get data 83 | batch_x, batch_y = get_batch() 84 | 85 | # Reset gradients 86 | fc.zero_grad() 87 | 88 | # Forward pass 89 | #output = F.smooth_l1_loss(fc(batch_x), batch_y) 90 | #loss = output.item() 91 | output, log_prior = fc(batch_x, return_log_prior=True) 92 | loss = F.smooth_l1_loss(output, batch_y, reduction="sum") - log_prior_weight * log_prior 93 | 94 | # Backward pass 95 | loss.backward() 96 | loss = loss.item() 97 | 98 | # SVGD step 99 | preconditioner.step() 100 | 101 | # Apply gradients 102 | for param in fc.parameters(): 103 | param.data.add_(-0.1 * param.grad) 104 | 105 | # Stop criterion 106 | if loss < 1e2: 107 | break 108 | 109 | print('Loss: {:.6f} after {} batches'.format(loss, batch_idx)) 110 | #print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias)) 111 | #print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target)) 112 | print('==> Learned function mean: \t' + poly_desc( 113 | torch.stack([replica.weight.view(-1) for replica in fc.replicas]).mean(dim=0), 114 | torch.stack([replica.bias for replica in fc.replicas]).mean(dim=0), 115 | )) 116 | print('==> Learned function standard deviation:\t' + poly_desc( 117 | torch.stack([replica.weight.view(-1) for replica in fc.replicas]).std(dim=0), 118 | torch.stack([replica.bias for replica in fc.replicas]).std(dim=0), 119 | )) 120 | print('==> Actual function: \t' + poly_desc(W_target.view(-1), b_target)) 121 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | """Setup script.""" 18 | 19 | 20 | ################## Install setup requirements ################## 21 | def _preinstall_requirement(requirement, options=None): 22 | import subprocess 23 | 24 | args = ["pip", "install", requirement, *(options or [])] 25 | return_code = subprocess.call(args) 26 | if return_code != 0: 27 | raise RuntimeError(f"{requirement} installation failed") 28 | 29 | 30 | for requirement in ["setuptools>=58.0.4", "wheel>=0.37.1"]: 31 | _preinstall_requirement(requirement) 32 | ################################################################ 33 | 34 | 35 | import os # noqa: E402 36 | 37 | from setuptools import find_packages, setup # noqa: E402 38 | 39 | 40 | _ROOT_DIR = os.path.dirname(os.path.realpath(__file__)) 41 | 42 | with open(os.path.join(_ROOT_DIR, "bayestorch", "version.py")) as f: 43 | tmp = {} 44 | exec(f.read(), tmp) 45 | _VERSION = tmp["VERSION"] 46 | 47 | with open(os.path.join(_ROOT_DIR, "README.md"), encoding="utf-8") as f: 48 | _README = f.read() 49 | 50 | setup( 51 | name="bayestorch", 52 | version=_VERSION, 53 | description="Lightweight Bayesian deep learning library for fast prototyping based on PyTorch", 54 | long_description=_README, 55 | long_description_content_type="text/markdown", 56 | author="Luca Della Libera", 57 | author_email="luca.dellalib@gmail.com", 58 | url="https://github.com/lucadellalib/bayestorch", 59 | packages=find_packages(), 60 | classifiers=[ 61 | "Development Status :: 3 - Alpha", 62 | "Environment :: Console", 63 | "Environment :: GPU :: NVIDIA CUDA", 64 | "Intended Audience :: Developers", 65 | "Intended Audience :: Information Technology", 66 | "Intended Audience :: Science/Research", 67 | "License :: OSI Approved :: Apache Software License", 68 | "Natural Language :: English", 69 | "Operating System :: OS Independent", 70 | "Programming Language :: Python :: 3", 71 | "Programming Language :: Python :: 3.6", 72 | "Programming Language :: Python :: 3.7", 73 | "Programming Language :: Python :: 3.8", 74 | "Programming Language :: Python :: 3.9", 75 | "Programming Language :: Python :: 3.10", 76 | "Programming Language :: Python :: 3 :: Only", 77 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 78 | "Topic :: Software Development :: Libraries :: Python Modules", 79 | "Typing :: Typed", 80 | ], 81 | license="Apache License 2.0", 82 | keywords=["Bayesian deep learning", "PyTorch"], 83 | platforms=["OS Independent"], 84 | include_package_data=True, 85 | install_requires=["torch>=1.5.0"], 86 | extras_require={ 87 | "test": [ 88 | "pytest>=5.4.3", 89 | "pytest-cov>=2.9.0", 90 | ], 91 | "dev": [ 92 | "black>=22.3.0", 93 | "cibuildwheel>=2.3.1", 94 | "flake8>=3.8.3", 95 | "flake8-bugbear>=20.1.4", 96 | "isort>=5.4.2", 97 | "pre-commit>=2.6.0", 98 | "pre-commit-hooks>=3.2.0", 99 | "pytest>=5.4.3", 100 | "pytest-cov>=2.9.0", 101 | "twine>=3.3.0", 102 | ], 103 | }, 104 | python_requires=">=3.6", 105 | ) 106 | -------------------------------------------------------------------------------- /tests/nn/test_variational_posterior_module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # ============================================================================== 18 | 19 | """Test variational posterior module.""" 20 | 21 | import pytest 22 | import torch 23 | from torch import nn 24 | 25 | from bayestorch.distributions import LogScaleNormal, SoftplusInvScaleNormal 26 | from bayestorch.nn import VariationalPosteriorModule 27 | 28 | 29 | def test_variational_posterior_module() -> "None": 30 | num_mc_samples = 5 31 | batch_size = 10 32 | in_features = 4 33 | out_features = 2 34 | model = nn.Linear(in_features, out_features) 35 | num_parameters = sum(parameter.numel() for parameter in model.parameters()) 36 | model = VariationalPosteriorModule( 37 | model, 38 | prior_builder=LogScaleNormal, 39 | prior_kwargs={ 40 | "loc": torch.zeros(num_parameters), 41 | "log_scale": torch.full((num_parameters,), -1.0), 42 | }, 43 | posterior_builder=SoftplusInvScaleNormal, 44 | posterior_kwargs={ 45 | "loc": torch.zeros(num_parameters, requires_grad=True), 46 | "softplus_inv_scale": torch.full( 47 | (num_parameters,), 48 | -7.0, 49 | requires_grad=True, 50 | ), 51 | }, 52 | ).to("cpu") 53 | input = torch.rand(batch_size, in_features) 54 | for reduction in ["none", "mean"]: 55 | output = model( 56 | input, 57 | num_mc_samples=num_mc_samples, 58 | reduction=reduction, 59 | ) 60 | loss = output.sum() 61 | loss.backward() 62 | outputs, kl_divs = model( 63 | input, 64 | num_mc_samples=num_mc_samples, 65 | return_kl_div=True, 66 | reduction=reduction, 67 | ) 68 | loss = outputs.sum() + kl_divs.sum() 69 | loss.backward() 70 | with torch.no_grad(): 71 | _ = model( 72 | input, 73 | num_mc_samples=num_mc_samples, 74 | reduction=reduction, 75 | ) 76 | _, _ = model( 77 | input, 78 | num_mc_samples=num_mc_samples, 79 | return_kl_div=True, 80 | reduction=reduction, 81 | ) 82 | outputs, kl_divs = model( 83 | input, 84 | num_mc_samples=num_mc_samples, 85 | return_kl_div=True, 86 | exact_kl_div=True, 87 | reduction=reduction, 88 | ) 89 | loss = outputs.sum() + kl_divs.sum() 90 | loss.backward() 91 | with torch.no_grad(): 92 | outputs, kl_divs = model( 93 | input, 94 | num_mc_samples=num_mc_samples, 95 | return_kl_div=True, 96 | exact_kl_div=True, 97 | reduction=reduction, 98 | ) 99 | print(model) 100 | print(dict(model.named_parameters()).keys()) 101 | print(model.parameters()) 102 | print(dict(model.named_parameters(include_all=False)).keys()) 103 | print(model.parameters(include_all=False)) 104 | state_dict = model.state_dict() 105 | model.load_state_dict(state_dict) 106 | print(f"Number of Monte Carlo samples: {num_mc_samples}") 107 | print(f"Batch size: {batch_size}") 108 | print(f"Input shape: {(batch_size, in_features)}") 109 | print(f"Output shape: {output.shape}") 110 | print(f"Outputs shape: {outputs.shape}") 111 | print(f"Kullback-Leibler divergences shape: {kl_divs.shape}") 112 | 113 | 114 | if __name__ == "__main__": 115 | pytest.main([__file__]) 116 | -------------------------------------------------------------------------------- /bayestorch/distributions/constraints.py: 
-------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Constraints.""" 18 | 19 | import torch 20 | from torch import Tensor 21 | from torch.distributions import constraints 22 | 23 | 24 | __all__ = [ 25 | "cat", 26 | "ordered_real_vector", 27 | "real_set", 28 | ] 29 | 30 | 31 | class _Cat(constraints.cat): 32 | """Extended version of `torch.distributions.constraints.cat` 33 | that implements `is_discrete` and `check` correctly. 34 | 35 | Examples 36 | -------- 37 | >>> import torch 38 | >>> from torch.distributions.constraints import independent, positive, real 39 | >>> 40 | >>> from bayestorch.distributions.constraints import cat 41 | >>> 42 | >>> 43 | >>> constraint = cat([independent(real, 1), independent(positive, 1)], lengths=(2, 1)) 44 | >>> check = constraint.check(torch.as_tensor([-0.2, -0.5, 2.3])) 45 | 46 | """ 47 | 48 | # override 49 | @property 50 | def is_discrete(self) -> "bool": 51 | return all(c.is_discrete for c in self.cseq) 52 | 53 | # override 54 | def check(self, value: "Tensor") -> "Tensor": 55 | if self.dim < -value.ndim or self.dim >= value.ndim: 56 | raise IndexError( 57 | f"`dim` ({self.dim}) must be in the integer interval [-{value.ndim}, {value.ndim})" 58 | ) 59 | chunks = value.split(self.lengths, dim=self.dim) 60 | checks = [c.check(chunks[i]) for i, c in enumerate(self.cseq)] 61 | return torch.stack(checks).all(dim=0) 62 | 63 | # override 64 | def __repr__(self) -> "str": 65 | return ( 66 | f"{type(self).__name__[1:]}" 67 | f"({self.cseq}, " 68 | f"dim: {self.dim}, " 69 | f"lengths: {self.lengths})" 70 | ) 71 | 72 | 73 | class _OrderedRealVector(constraints.Constraint): 74 | """Constrain to a real-valued vector whose elements 75 | are sorted in (strict) ascending order. 76 | 77 | Examples 78 | -------- 79 | >>> import torch 80 | >>> 81 | >>> from bayestorch.distributions.constraints import ordered_real_vector 82 | >>> 83 | >>> 84 | >>> constraint = ordered_real_vector 85 | >>> check = constraint.check(torch.as_tensor([0.2, 0.4, 0.8])) 86 | 87 | """ 88 | 89 | event_dim = 1 90 | 91 | # override 92 | def check(self, value: "Tensor") -> "Tensor": 93 | return (value[..., 1:] > value[..., :-1]).all(dim=-1) 94 | 95 | 96 | class _RealSet(constraints.Constraint): 97 | """Constrain to a set (i.e. no duplicates allowed) 98 | of real values. 
99 | 100 | Examples 101 | -------- 102 | >>> import torch 103 | >>> 104 | >>> from bayestorch.distributions.constraints import real_set 105 | >>> 106 | >>> 107 | >>> constraint = real_set(torch.as_tensor([0.2, 0.4, 0.8])) 108 | >>> check = constraint.check(torch.as_tensor(0.2)) 109 | 110 | """ 111 | 112 | is_discrete = True 113 | 114 | # override 115 | def __init__(self, values: "Tensor") -> "None": 116 | """Initialize the object. 117 | 118 | Parameters 119 | ---------- 120 | values: 121 | The set of real values. All dimensions except for 122 | the last one are interpreted as batch dimensions. 123 | 124 | Raises 125 | ------ 126 | ValueError 127 | If `values` contains duplicates. 128 | 129 | """ 130 | sorted_values = values.sort(dim=-1).values 131 | if (sorted_values[..., 1:] == sorted_values[..., :-1]).any(dim=-1).any(): 132 | raise ValueError(f"`values` ({values}) must not contain duplicates") 133 | self._values = values 134 | 135 | # override 136 | def check(self, value: "Tensor") -> "Tensor": 137 | # Add event dimension 138 | value = value[..., None].expand(*value.shape, self._values.shape[-1]) 139 | expanded_support = self._values.expand_as(value) 140 | return (expanded_support == value).any(dim=-1) 141 | 142 | # override 143 | def __repr__(self) -> "str": 144 | return f"{type(self).__name__[1:]}(values: {self._values})" 145 | 146 | 147 | # Public interface 148 | cat = _Cat 149 | ordered_real_vector = _OrderedRealVector() 150 | real_set = _RealSet 151 | -------------------------------------------------------------------------------- /bayestorch/distributions/deterministic.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Deterministic distribution.""" 18 | 19 | from typing import Optional, Union 20 | 21 | import torch 22 | from torch import Size, Tensor 23 | from torch.distributions import Distribution, constraints, register_kl 24 | 25 | from bayestorch.distributions.finite import Finite 26 | 27 | 28 | __all__ = [ 29 | "Deterministic", 30 | ] 31 | 32 | 33 | class Deterministic(Finite): 34 | """Distribution that returns a single deterministic 35 | value with probability equal to 1. 36 | 37 | Examples 38 | -------- 39 | >>> from bayestorch.distributions import Deterministic 40 | >>> 41 | >>> 42 | >>> value = 1.0 43 | >>> distribution = Deterministic(value) 44 | 45 | """ 46 | 47 | has_rsample = True 48 | arg_constraints = { 49 | "value": constraints.real, 50 | } 51 | 52 | # override 53 | def __init__( 54 | self, 55 | value: "Union[int, float, Tensor]", 56 | validate_args: "Optional[bool]" = None, 57 | ) -> "None": 58 | """Initialize the object. 59 | 60 | Parameters 61 | ---------- 62 | value: 63 | The deterministic value to return. 
64 | validate_args: 65 | True to validate the arguments, False otherwise. 66 | Default to ``__debug__``. 67 | 68 | """ 69 | self.value = torch.as_tensor(value) 70 | atoms = self.value[..., None] 71 | probs = torch.ones_like(atoms) 72 | super().__init__(probs, atoms=atoms, validate_args=validate_args) 73 | 74 | # override 75 | def expand( 76 | self, 77 | batch_shape: "Size" = torch.Size(), # noqa: B008 78 | _instance: "Optional[Deterministic]" = None, 79 | ) -> "Deterministic": 80 | new = self._get_checked_instance(Deterministic, _instance) 81 | new.value = self.value.expand(batch_shape) 82 | atoms = new.value[..., None] 83 | probs = torch.ones_like(atoms) 84 | super(Deterministic, new).__init__(probs, atoms=atoms, validate_args=False) 85 | new._validate_args = self._validate_args 86 | return new 87 | 88 | # override 89 | @property 90 | def mean(self) -> "Tensor": 91 | return self.value 92 | 93 | # override 94 | @property 95 | def mode(self) -> "Tensor": 96 | return self.value 97 | 98 | # override 99 | @property 100 | def variance(self) -> "Tensor": 101 | return torch.zeros_like(self.value) 102 | 103 | # override 104 | def sample(self, sample_shape: "Size" = torch.Size()) -> "Tensor": # noqa: B008 105 | with torch.no_grad(): 106 | return self.rsample(sample_shape) 107 | 108 | # override 109 | def rsample(self, sample_shape: "Size" = torch.Size()) -> "Tensor": # noqa: B008 110 | shape = self._extended_shape(sample_shape) 111 | return self.value.expand(shape) 112 | 113 | # override 114 | def log_prob(self, value: "Tensor") -> "Tensor": 115 | if self._validate_args: 116 | self._validate_sample(value) 117 | expanded_value = self.value.expand_as(value) 118 | return (expanded_value == value).type(value.type()).log() 119 | 120 | # override 121 | def cdf(self, value: "Tensor") -> "Tensor": 122 | if self._validate_args: 123 | self._validate_sample(value) 124 | expanded_value = self.value.expand_as(value) 125 | return (expanded_value <= value).type(value.type()) 126 | 127 | # override 128 | def enumerate_support(self, expand: "bool" = True) -> "Tensor": 129 | try: 130 | return super().enumerate_support(expand) 131 | except NotImplementedError: 132 | raise NotImplementedError( 133 | "`enumerate_support` does not support inhomogeneous values" 134 | ) 135 | 136 | # override 137 | def entropy(self) -> "Tensor": 138 | return torch.zeros_like(self.value) 139 | 140 | # override 141 | def __repr__(self) -> "str": 142 | return ( 143 | f"{type(self).__name__}" 144 | f"(value: {self.value if self.value.numel() == 1 else self.value.shape})" 145 | ) 146 | 147 | 148 | @register_kl(Deterministic, Distribution) 149 | @register_kl(Deterministic, Finite) # Avoid ambiguities 150 | def _kl_deterministic_distribution(p: "Deterministic", q: "Distribution") -> "Tensor": 151 | return -q.log_prob(p.value) 152 | -------------------------------------------------------------------------------- /tests/distributions/test_distributions_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ============================================================================== 4 | # Copyright 2022 Luca Della Libera. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # ============================================================================== 18 | 19 | """Test distribution utilities.""" 20 | 21 | import pytest 22 | import torch 23 | 24 | from bayestorch.distributions.utils import ( 25 | get_deterministic, 26 | get_laplace, 27 | get_log_scale_normal, 28 | get_mixture_laplace, 29 | get_mixture_log_scale_normal, 30 | get_mixture_normal, 31 | get_mixture_softplus_inv_scale_normal, 32 | get_normal, 33 | get_softplus_inv_scale_normal, 34 | ) 35 | 36 | 37 | def test_get_deterministic() -> "None": 38 | model = torch.nn.Linear(4, 2) 39 | _, _ = get_deterministic(model.parameters(), 2.0, requires_grad=True) 40 | builder, kwargs = get_deterministic(model.parameters(), requires_grad=True) 41 | distribution = builder(**kwargs) 42 | print( 43 | f"Number of parameters: {sum(parameter.numel() for parameter in model.parameters())}" 44 | ) 45 | print(f"Sample shape: {distribution.sample().shape}") 46 | 47 | 48 | def test_get_laplace() -> "None": 49 | model = torch.nn.Linear(4, 2) 50 | builder, kwargs = get_laplace(model.parameters(), 0.0, 1.0, requires_grad=True) 51 | distribution = builder(**kwargs) 52 | print( 53 | f"Number of parameters: {sum(parameter.numel() for parameter in model.parameters())}" 54 | ) 55 | print(f"Sample shape: {distribution.sample().shape}") 56 | 57 | 58 | def test_get_normal() -> "None": 59 | model = torch.nn.Linear(4, 2) 60 | builder, kwargs = get_normal(model.parameters(), 0.0, 1.0, requires_grad=True) 61 | distribution = builder(**kwargs) 62 | print( 63 | f"Number of parameters: {sum(parameter.numel() for parameter in model.parameters())}" 64 | ) 65 | print(f"Sample shape: {distribution.sample().shape}") 66 | 67 | 68 | def test_get_log_scale_normal() -> "None": 69 | model = torch.nn.Linear(4, 2) 70 | builder, kwargs = get_log_scale_normal( 71 | model.parameters(), 0.0, -1.0, requires_grad=True 72 | ) 73 | distribution = builder(**kwargs) 74 | print( 75 | f"Number of parameters: {sum(parameter.numel() for parameter in model.parameters())}" 76 | ) 77 | print(f"Sample shape: {distribution.sample().shape}") 78 | 79 | 80 | def test_get_softplus_inv_scale_normal() -> "None": 81 | model = torch.nn.Linear(4, 2) 82 | builder, kwargs = get_softplus_inv_scale_normal( 83 | model.parameters(), 0.0, -1.0, requires_grad=True 84 | ) 85 | distribution = builder(**kwargs) 86 | print( 87 | f"Number of parameters: {sum(parameter.numel() for parameter in model.parameters())}" 88 | ) 89 | print(f"Sample shape: {distribution.sample().shape}") 90 | 91 | 92 | def test_get_mixture_laplace() -> "None": 93 | model = torch.nn.Linear(4, 2) 94 | builder, kwargs = get_mixture_laplace( 95 | model.parameters(), 96 | (0.75, 0.25), 97 | (0.0, 0.0), 98 | (1.0, 2.0), 99 | requires_grad=True, 100 | ) 101 | distribution = builder(**kwargs) 102 | print( 103 | f"Number of parameters: {sum(parameter.numel() for parameter in model.parameters())}" 104 | ) 105 | print(f"Sample shape: {distribution.sample().shape}") 106 | 107 | 108 | def test_get_mixture_normal() -> "None": 109 | model = torch.nn.Linear(4, 2) 110 | builder, kwargs = 
get_mixture_normal( 111 | model.parameters(), 112 | (0.75, 0.25), 113 | (0.0, 0.0), 114 | (1.0, 2.0), 115 | requires_grad=True, 116 | ) 117 | distribution = builder(**kwargs) 118 | print( 119 | f"Number of parameters: {sum(parameter.numel() for parameter in model.parameters())}" 120 | ) 121 | print(f"Sample shape: {distribution.sample().shape}") 122 | 123 | 124 | def test_get_mixture_log_scale_normal() -> "None": 125 | model = torch.nn.Linear(4, 2) 126 | builder, kwargs = get_mixture_log_scale_normal( 127 | model.parameters(), 128 | (0.75, 0.25), 129 | (0.0, 0.0), 130 | (-1.0, -2.0), 131 | requires_grad=True, 132 | ) 133 | distribution = builder(**kwargs) 134 | print( 135 | f"Number of parameters: {sum(parameter.numel() for parameter in model.parameters())}" 136 | ) 137 | print(f"Sample shape: {distribution.sample().shape}") 138 | 139 | 140 | def test_get_mixture_softplus_inv_scale_normal() -> "None": 141 | model = torch.nn.Linear(4, 2) 142 | builder, kwargs = get_mixture_softplus_inv_scale_normal( 143 | model.parameters(), 144 | (0.75, 0.25), 145 | (0.0, 0.0), 146 | (-1.0, -2.0), 147 | requires_grad=True, 148 | ) 149 | distribution = builder(**kwargs) 150 | print( 151 | f"Number of parameters: {sum(parameter.numel() for parameter in model.parameters())}" 152 | ) 153 | print(f"Sample shape: {distribution.sample().shape}") 154 | 155 | 156 | if __name__ == "__main__": 157 | pytest.main([__file__]) 158 | -------------------------------------------------------------------------------- /bayestorch/optim/sgld.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Stochastic gradient Langevin dynamics optimizer.""" 18 | 19 | from typing import Any, Dict, Iterable, Union 20 | 21 | import torch 22 | from torch import Tensor 23 | from torch.optim import Optimizer 24 | 25 | 26 | __all__ = [ 27 | "SGLD", 28 | ] 29 | 30 | 31 | # Adapted from: 32 | # https://github.com/JavierAntoran/Bayesian-Neural-Networks/blob/1f867a5bcbd1abfecede99807eb0b5f97ed8be7c/src/Stochastic_Gradient_Langevin_Dynamics/optimizers.py 33 | class SGLD(Optimizer): 34 | """Stochastic gradient Langevin dynamics optimizer. 35 | 36 | The optimization parameters are interpreted as a posterior 37 | sample under stochastic gradient Langevin dynamics with 38 | noise rescaled in each dimension according to RMSProp. 39 | 40 | References 41 | ---------- 42 | .. [1] C. Li, C. Chen, D. Carlson, and L. Carin. 43 | "Preconditioned Stochastic Gradient Langevin Dynamics for Deep Neural Networks". 44 | In: AAAI. 2016, pp. 1788-1794. 
45 | URL: https://arxiv.org/abs/1512.07666 46 | 47 | Examples 48 | -------- 49 | >>> import torch 50 | >>> from torch import nn 51 | >>> 52 | >>> from bayestorch.optim import SGLD 53 | >>> 54 | >>> 55 | >>> batch_size = 10 56 | >>> in_features = 4 57 | >>> out_features = 2 58 | >>> model = nn.Linear(in_features, out_features) 59 | >>> optimizer = SGLD(model.parameters()) 60 | >>> input = torch.rand(batch_size, in_features) 61 | >>> output = model(input) 62 | >>> loss = output.sum() 63 | >>> loss.backward() 64 | >>> optimizer.step() 65 | 66 | """ 67 | 68 | # override 69 | def __init__( 70 | self, 71 | params: "Union[Iterable[Tensor], Iterable[Dict[str, Any]]]", 72 | lr: "float" = 1e-2, 73 | num_burn_in_steps: "int" = 3000, 74 | precondition_decay_rate: "float" = 0.95, 75 | epsilon: "float" = 1e-8, 76 | ) -> "None": 77 | """Initialize the object. 78 | 79 | Parameters 80 | ---------- 81 | params: 82 | The parameters to optimize. 83 | lr: 84 | The learning rate. 85 | num_burn_in_steps: 86 | The number of steps for which gradient statistics 87 | are collected to update the preconditioner before 88 | drawing noisy samples. 89 | precondition_decay_rate: 90 | The exponential decay rate for rescaling the 91 | preconditioner according to RMSProp. Should 92 | be close to 1 to approximate sampling from 93 | the posterior. 94 | epsilon: 95 | The term for improving numerical stability. 96 | 97 | Raises 98 | ------ 99 | ValueError 100 | If an invalid argument value is given. 101 | 102 | """ 103 | if lr < 0.0: 104 | raise ValueError(f"`lr` ({lr}) must be in the interval [0, inf)") 105 | if num_burn_in_steps < 0 or not float(num_burn_in_steps).is_integer(): 106 | raise ValueError( 107 | f"`num_burn_in_steps` ({num_burn_in_steps}) must be in the integer interval [0, inf)" 108 | ) 109 | if precondition_decay_rate < 0.0 or precondition_decay_rate > 1.0: 110 | raise ValueError( 111 | f"`precondition_decay_rate` ({precondition_decay_rate}) must be in the interval [0, 1]" 112 | ) 113 | if epsilon <= 0.0: 114 | raise ValueError(f"`epsilon` ({epsilon}) must be in the interval (0, inf)") 115 | 116 | defaults = { 117 | "lr": lr, 118 | "num_burn_in_steps": int(num_burn_in_steps), 119 | "precondition_decay_rate": precondition_decay_rate, 120 | "epsilon": epsilon, 121 | } 122 | super().__init__(params, defaults) 123 | 124 | # override 125 | @torch.no_grad() 126 | def step(self) -> "None": 127 | for group in self.param_groups: 128 | for param in group["params"]: 129 | if param.grad is None: 130 | continue 131 | 132 | state = self.state[param] 133 | 134 | # State initialization 135 | if not state: 136 | state["iteration"] = 0 137 | state["momentum"] = torch.ones_like(param) 138 | 139 | lr = group["lr"] 140 | num_burn_in_steps = group["num_burn_in_steps"] 141 | precondition_decay_rate = group["precondition_decay_rate"] 142 | epsilon = group["epsilon"] 143 | 144 | state["iteration"] += 1 145 | iteration = state["iteration"] 146 | momentum = state["momentum"] 147 | 148 | grad = param.grad 149 | 150 | # Momentum update 151 | momentum += (1.0 - precondition_decay_rate) * (grad**2 - momentum) 152 | 153 | # Burn-in steps 154 | if iteration <= num_burn_in_steps: 155 | stddev = torch.zeros_like(param) 156 | else: 157 | stddev = 1.0 / torch.full_like(param, lr).sqrt() 158 | 159 | # Draw random sample 160 | preconditioner = 1.0 / (momentum + epsilon).sqrt() 161 | mean = 0.5 * preconditioner * grad 162 | stddev *= preconditioner.sqrt() 163 | sample = torch.normal(mean, stddev) 164 | 165 | # Parameter update 166 | param += -lr * 
sample 167 | -------------------------------------------------------------------------------- /examples/mnist/train.py: -------------------------------------------------------------------------------- 1 | # Adapted from: 2 | # https://github.com/pytorch/examples/blob/9aad148615b7519eadfa1a60356116a50561f192/mnist/main.py 3 | 4 | from __future__ import print_function 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms 11 | from torch.optim.lr_scheduler import StepLR 12 | 13 | 14 | class Net(nn.Module): 15 | def __init__(self): 16 | super(Net, self).__init__() 17 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 18 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 19 | self.dropout1 = nn.Dropout(0.25) 20 | self.dropout2 = nn.Dropout(0.5) 21 | self.fc1 = nn.Linear(9216, 128) 22 | self.fc2 = nn.Linear(128, 10) 23 | 24 | def forward(self, x): 25 | x = self.conv1(x) 26 | x = F.relu(x) 27 | x = self.conv2(x) 28 | x = F.relu(x) 29 | x = F.max_pool2d(x, 2) 30 | x = self.dropout1(x) 31 | x = torch.flatten(x, 1) 32 | x = self.fc1(x) 33 | x = F.relu(x) 34 | x = self.dropout2(x) 35 | x = self.fc2(x) 36 | output = F.log_softmax(x, dim=1) 37 | return output 38 | 39 | 40 | def train(args, model, device, train_loader, optimizer, epoch): 41 | model.train() 42 | for batch_idx, (data, target) in enumerate(train_loader): 43 | data, target = data.to(device), target.to(device) 44 | optimizer.zero_grad() 45 | output = model(data) 46 | loss = F.nll_loss(output, target) 47 | loss.backward() 48 | optimizer.step() 49 | if batch_idx % args.log_interval == 0: 50 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 51 | epoch, batch_idx * len(data), len(train_loader.dataset), 52 | 100. * batch_idx / len(train_loader), loss.item())) 53 | if args.dry_run: 54 | break 55 | 56 | 57 | def test(model, device, test_loader): 58 | model.eval() 59 | test_loss = 0 60 | correct = 0 61 | with torch.no_grad(): 62 | for data, target in test_loader: 63 | data, target = data.to(device), target.to(device) 64 | output = model(data) 65 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 66 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 67 | correct += pred.eq(target.view_as(pred)).sum().item() 68 | 69 | test_loss /= len(test_loader.dataset) 70 | 71 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 72 | test_loss, correct, len(test_loader.dataset), 73 | 100. 
* correct / len(test_loader.dataset))) 74 | 75 | 76 | def main(): 77 | # Training settings 78 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 79 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 80 | help='input batch size for training (default: 64)') 81 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 82 | help='input batch size for testing (default: 1000)') 83 | parser.add_argument('--epochs', type=int, default=14, metavar='N', 84 | help='number of epochs to train (default: 14)') 85 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 86 | help='learning rate (default: 1.0)') 87 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 88 | help='Learning rate step gamma (default: 0.7)') 89 | parser.add_argument('--no-cuda', action='store_true', default=False, 90 | help='disables CUDA training') 91 | parser.add_argument('--no-mps', action='store_true', default=False, 92 | help='disables macOS GPU training') 93 | parser.add_argument('--dry-run', action='store_true', default=False, 94 | help='quickly check a single pass') 95 | parser.add_argument('--seed', type=int, default=1, metavar='S', 96 | help='random seed (default: 1)') 97 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 98 | help='how many batches to wait before logging training status') 99 | parser.add_argument('--save-model', action='store_true', default=False, 100 | help='For Saving the current Model') 101 | args = parser.parse_args() 102 | use_cuda = not args.no_cuda and torch.cuda.is_available() 103 | #use_mps = not args.no_mps and torch.backends.mps.is_available() 104 | 105 | torch.manual_seed(args.seed) 106 | 107 | if use_cuda: 108 | device = torch.device("cuda") 109 | #elif use_mps: 110 | # device = torch.device("mps") 111 | else: 112 | device = torch.device("cpu") 113 | 114 | train_kwargs = {'batch_size': args.batch_size} 115 | test_kwargs = {'batch_size': args.test_batch_size} 116 | if use_cuda: 117 | cuda_kwargs = {'num_workers': 1, 118 | 'pin_memory': True, 119 | 'shuffle': True} 120 | train_kwargs.update(cuda_kwargs) 121 | test_kwargs.update(cuda_kwargs) 122 | 123 | transform=transforms.Compose([ 124 | transforms.ToTensor(), 125 | transforms.Normalize((0.1307,), (0.3081,)) 126 | ]) 127 | dataset1 = datasets.MNIST('../data', train=True, download=True, 128 | transform=transform) 129 | dataset2 = datasets.MNIST('../data', train=False, 130 | transform=transform) 131 | train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) 132 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 133 | 134 | model = Net().to(device) 135 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 136 | 137 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 138 | for epoch in range(1, args.epochs + 1): 139 | train(args, model, device, train_loader, optimizer, epoch) 140 | test(model, device, test_loader) 141 | scheduler.step() 142 | 143 | if args.save_model: 144 | torch.save(model.state_dict(), "mnist_cnn.pt") 145 | 146 | 147 | if __name__ == '__main__': 148 | main() 149 | -------------------------------------------------------------------------------- /bayestorch/optim/svgd.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Stein variational gradient descent preconditioner.""" 18 | 19 | from typing import Any, Callable, Dict, Iterable, Union 20 | 21 | import torch 22 | from torch import Tensor, nn 23 | from torch.optim import Optimizer 24 | 25 | 26 | __all__ = [ 27 | "SVGD", 28 | ] 29 | 30 | 31 | class SVGD(Optimizer): 32 | """Stein variational gradient descent preconditioner. 33 | 34 | References 35 | ---------- 36 | .. [1] Q. Liu and D. Wang. 37 | "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm". 38 | In: Advances in Neural Information Processing Systems. 2016, pp. 2378-2386. 39 | URL: https://arxiv.org/abs/1608.04471 40 | 41 | Examples 42 | -------- 43 | >>> import math 44 | >>> 45 | >>> import torch 46 | >>> from torch import nn 47 | >>> 48 | >>> from bayestorch.optim import SVGD 49 | >>> 50 | >>> 51 | >>> def rbf_kernel(x1, x2): 52 | ... deltas = torch.cdist(x1, x2) 53 | ... squared_deltas = deltas**2 54 | ... bandwidth = ( 55 | ... squared_deltas.detach().median() 56 | ... / math.log(min(x1.shape[0], x2.shape[0])) 57 | ... ) 58 | ... log_kernels = -squared_deltas / bandwidth 59 | ... kernels = log_kernels.exp() 60 | ... return kernels 61 | >>> 62 | >>> 63 | >>> num_particles = 5 64 | >>> batch_size = 10 65 | >>> in_features = 4 66 | >>> out_features = 2 67 | >>> models = nn.ModuleList( 68 | ... [nn.Linear(in_features, out_features) for _ in range(num_particles)] 69 | ... ) 70 | >>> preconditioner = SVGD(models.parameters(), rbf_kernel, num_particles) 71 | >>> input = torch.rand(batch_size, in_features) 72 | >>> outputs = torch.cat([model(input) for model in models]) 73 | >>> loss = outputs.sum() 74 | >>> loss.backward() 75 | >>> preconditioner.step() 76 | 77 | """ 78 | 79 | # override 80 | def __init__( 81 | self, 82 | params: "Union[Iterable[Tensor], Iterable[Dict[str, Any]]]", 83 | kernel: "Callable[[Tensor, Tensor], Tensor]", 84 | num_particles: "int", 85 | ) -> "None": 86 | """Initialize the object. 87 | 88 | Parameters 89 | ---------- 90 | params: 91 | The parameters to precondition. The total number of 92 | parameters must be a multiple of `num_particles`. 93 | kernel: 94 | The kernel, i.e. a callable that receives two input 95 | tensors and returns the corresponding kernel values 96 | (must be differentiable with respect to both input 97 | tensors). 98 | num_particles: 99 | The number of particles. 100 | 101 | Raises 102 | ------ 103 | ValueError 104 | If an invalid argument value is given. 
105 | 106 | """ 107 | if num_particles < 1 or not float(num_particles).is_integer(): 108 | raise ValueError( 109 | f"`num_particles` ({num_particles}) must be in the integer interval [1, inf)" 110 | ) 111 | num_particles = int(num_particles) 112 | params = list(params) 113 | 114 | # Extract particles 115 | with torch.no_grad(): 116 | particles = nn.utils.parameters_to_vector(params) 117 | 118 | if particles.numel() % num_particles != 0: 119 | raise ValueError( 120 | f"Total number of parameters ({particles.numel()}) must " 121 | f"be a multiple of `num_particles` ({num_particles})" 122 | ) 123 | 124 | defaults = {"kernel": kernel, "num_particles": num_particles} 125 | super().__init__(params, defaults) 126 | 127 | # override 128 | @torch.no_grad() 129 | def step(self) -> "None": 130 | for group in self.param_groups: 131 | params = group["params"] 132 | kernel = group["kernel"] 133 | num_particles = group["num_particles"] 134 | 135 | # Extract particles 136 | particles = nn.utils.parameters_to_vector(params).reshape(num_particles, -1) 137 | 138 | # Extract particle gradients 139 | particle_grads = [] 140 | for param in params: 141 | grad = param.grad 142 | if grad is None: 143 | raise RuntimeError("Gradient of some parameters is None") 144 | particle_grads.append(grad) 145 | particle_grads = nn.utils.parameters_to_vector(particle_grads).reshape( 146 | num_particles, -1 147 | ) 148 | 149 | # Compute kernels and kernel gradients 150 | with torch.enable_grad(): 151 | particles.requires_grad_() 152 | kernels = kernel(particles, particles) 153 | # Need to multiply by -0.5 (see https://github.com/activatedgeek/svgd/issues/1#issuecomment-649235844) 154 | kernels.backward(torch.full_like(kernels, -0.5)) 155 | kernel_grads = particles.grad 156 | kernels.detach_() 157 | particles.requires_grad_(False) 158 | 159 | # Driving gradients (already divided by `num_particles`) 160 | particle_grads = particle_grads.T 161 | particle_grads @= kernels.T 162 | particle_grads = particle_grads.T 163 | 164 | # Repulsive gradients 165 | kernel_grads /= num_particles 166 | particle_grads -= kernel_grads 167 | 168 | # Flatten 169 | particle_grads = particle_grads.flatten() 170 | 171 | # Inject particle gradients 172 | start_idx = 0 173 | for param in params: 174 | end_idx = start_idx + param.numel() 175 | param.grad = particle_grads[start_idx:end_idx].reshape_as(param) 176 | start_idx = end_idx 177 | -------------------------------------------------------------------------------- /bayestorch/optim/sghmc.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | """Stochastic gradient Hamiltonian Monte Carlo optimizer.""" 18 | 19 | from typing import Any, Dict, Iterable, Union 20 | 21 | import torch 22 | from torch import Tensor 23 | from torch.optim import Optimizer 24 | 25 | 26 | __all__ = [ 27 | "SGHMC", 28 | ] 29 | 30 | 31 | # Adapted from: 32 | # https://github.com/JavierAntoran/Bayesian-Neural-Networks/blob/1f867a5bcbd1abfecede99807eb0b5f97ed8be7c/src/Stochastic_Gradient_HMC_SA/optimizers.py 33 | class SGHMC(Optimizer): 34 | """Stochastic gradient Hamiltonian Monte Carlo optimizer. 35 | 36 | A burn-in procedure is used to adapt the hyperparameters 37 | during the initial stages of sampling. 38 | 39 | References 40 | ---------- 41 | .. [1] T. Chen, E. B. Fox, and C. Guestrin. 42 | "Stochastic Gradient Hamiltonian Monte Carlo". 43 | In: ICML. 2014, pp. 1683-1691. 44 | URL: https://arxiv.org/abs/1402.4102 45 | 46 | Examples 47 | -------- 48 | >>> import torch 49 | >>> from torch import nn 50 | >>> 51 | >>> from bayestorch.optim import SGHMC 52 | >>> 53 | >>> 54 | >>> batch_size = 10 55 | >>> in_features = 4 56 | >>> out_features = 2 57 | >>> model = nn.Linear(in_features, out_features) 58 | >>> optimizer = SGHMC(model.parameters()) 59 | >>> input = torch.rand(batch_size, in_features) 60 | >>> output = model(input) 61 | >>> loss = output.sum() 62 | >>> loss.backward() 63 | >>> optimizer.step() 64 | 65 | """ 66 | 67 | # override 68 | def __init__( 69 | self, 70 | params: "Union[Iterable[Tensor], Iterable[Dict[str, Any]]]", 71 | lr: "float" = 1e-2, 72 | num_burn_in_steps: "int" = 3000, 73 | momentum_decay: "float" = 0.05, 74 | grad_noise: "float" = 0.0, 75 | epsilon: "float" = 1e-16, 76 | ) -> "None": 77 | """Initialize the object. 78 | 79 | Parameters 80 | ---------- 81 | params: 82 | The parameters to optimize. 83 | lr: 84 | The learning rate. 85 | num_burn_in_steps: 86 | The number of burn-in steps. At each step, 87 | the optimizer hyperparameters are adapted 88 | to decrease the error. 89 | momentum_decay: 90 | The momentum decay per timestep. 91 | grad_noise: 92 | The constant per-parameter gradient 93 | noise used for sampling. 94 | epsilon: 95 | The term for improving numerical stability. 96 | 97 | Raises 98 | ------ 99 | ValueError 100 | If an invalid argument value is given. 
101 | 102 | """ 103 | if lr < 0.0: 104 | raise ValueError(f"`lr` ({lr}) must be in the interval [0, inf)") 105 | if num_burn_in_steps < 0 or not float(num_burn_in_steps).is_integer(): 106 | raise ValueError( 107 | f"`num_burn_in_steps` ({num_burn_in_steps}) must be in the integer interval [0, inf)" 108 | ) 109 | if momentum_decay < 0.0: 110 | raise ValueError( 111 | f"`momentum_decay` ({momentum_decay}) must be in the interval [0, inf)" 112 | ) 113 | if grad_noise < 0.0: 114 | raise ValueError( 115 | f"`grad_noise` ({grad_noise}) must be in the interval [0, inf)" 116 | ) 117 | if epsilon <= 0.0: 118 | raise ValueError(f"`epsilon` ({epsilon}) must be in the interval (0, inf)") 119 | 120 | defaults = { 121 | "lr": lr, 122 | "num_burn_in_steps": int(num_burn_in_steps), 123 | "momentum_decay": momentum_decay, 124 | "grad_noise": grad_noise, 125 | "epsilon": epsilon, 126 | } 127 | super().__init__(params, defaults) 128 | 129 | # override 130 | @torch.no_grad() 131 | def step(self) -> "None": 132 | for group in self.param_groups: 133 | for param in group["params"]: 134 | if param.grad is None: 135 | continue 136 | 137 | state = self.state[param] 138 | 139 | # State initialization 140 | if not state: 141 | state["iteration"] = 0 142 | state["tau"] = torch.ones_like(param) 143 | state["g"] = torch.ones_like(param) 144 | state["v_hat"] = torch.ones_like(param) 145 | state["momentum"] = torch.zeros_like(param) 146 | 147 | lr = group["lr"] 148 | num_burn_in_steps = group["num_burn_in_steps"] 149 | momentum_decay = group["momentum_decay"] 150 | grad_noise = group["grad_noise"] 151 | epsilon = group["epsilon"] 152 | 153 | state["iteration"] += 1 154 | iteration = state["iteration"] 155 | tau = state["tau"] 156 | g = state["g"] 157 | v_hat = state["v_hat"] 158 | momentum = state["momentum"] 159 | 160 | grad = param.grad 161 | r = 1.0 / (tau + 1.0) 162 | m_inv = 1.0 / v_hat.sqrt() 163 | 164 | # Burn-in steps 165 | if iteration <= num_burn_in_steps: 166 | tau += 1.0 - tau * (g**2 / v_hat) 167 | g += (grad - g) * r 168 | v_hat += (grad**2 - v_hat) * r 169 | 170 | # Draw random sample 171 | grad_noise_var = ( 172 | 2.0 * (lr**2) * momentum_decay * m_inv 173 | - 2.0 * (lr**3) * (m_inv**2) * grad_noise 174 | - (lr**4) 175 | ) 176 | stddev = grad_noise_var.clamp(min=epsilon).sqrt() 177 | sample = torch.normal(0.0, stddev) 178 | 179 | # Parameter update 180 | momentum += sample - lr**2 * m_inv * grad - momentum_decay * momentum 181 | param += momentum 182 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BayesTorch 2 | 3 | [![Python version: 3.6 | 3.7 | 3.8 | 3.9 | 3.10](https://img.shields.io/badge/python-3.6%20|%203.7%20|%203.8%20|%203.9%20|%203.10-blue)](https://www.python.org/downloads/) 4 | [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://github.com/lucadellalib/bayestorch/blob/main/LICENSE) 5 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 6 | [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://github.com/PyCQA/isort) 7 | [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) 8 | ![PyPI version](https://img.shields.io/pypi/v/bayestorch) 9 | [![](https://pepy.tech/badge/bayestorch)](https://pypi.org/project/bayestorch/) 10 | 
11 | Welcome to `bayestorch`, a lightweight Bayesian deep learning library for fast prototyping based on 12 | [PyTorch](https://pytorch.org). It provides the basic building blocks for the following 13 | Bayesian inference algorithms: 14 | 15 | - [Bayes by Backprop (BBB)](https://arxiv.org/abs/1505.05424) 16 | - [Markov chain Monte Carlo (MCMC)](https://www.cs.toronto.edu/~radford/ftp/thesis.pdf) 17 | - [Stein variational gradient descent (SVGD)](https://arxiv.org/abs/1608.04471) 18 | 19 | --------------------------------------------------------------------------------------------------------- 20 | 21 | ## 💡 Key features 22 | 23 | - Low-code definition of Bayesian (or partially Bayesian) models 24 | - Support for custom neural network layers 25 | - Support for custom prior/posterior distributions 26 | - Support for layer/parameter-wise prior/posterior distributions 27 | - Support for composite prior/posterior distributions 28 | - Highly modular object-oriented design 29 | - User-friendly and easily extensible APIs 30 | - Detailed API documentation 31 | 32 | --------------------------------------------------------------------------------------------------------- 33 | 34 | ## 🛠️️ Installation 35 | 36 | ### Using Pip 37 | 38 | First of all, install [Python 3.6 or later](https://www.python.org). Open a terminal and run: 39 | 40 | ``` 41 | pip install bayestorch 42 | ``` 43 | 44 | ### From source 45 | 46 | First of all, install [Python 3.6 or later](https://www.python.org). 47 | Clone or download and extract the repository, navigate to ``, open a 48 | terminal and run: 49 | 50 | ``` 51 | pip install -e . 52 | ``` 53 | 54 | --------------------------------------------------------------------------------------------------------- 55 | 56 | ## ▶️ Quickstart 57 | 58 | Here are a few code snippets showcasing some key features of the library. 59 | For complete training loops, please refer to `examples/mnist` and `examples/regression`. 
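### Bayesian model trainable via Markov chain Monte Carlo (MCMC)

A minimal sketch, mirroring `examples/mnist/train_mcmc.py`: MCMC training only requires wrapping the model in a `PriorModule` and optimizing with one of the provided SGMCMC optimizers. The prior/optimizer hyperparameter values below are illustrative, not prescriptive.

```python
from torch.nn import Linear

from bayestorch.distributions import get_mixture_log_scale_normal
from bayestorch.nn import PriorModule
from bayestorch.optim import SGLD


# Define model
model = Linear(5, 1)

# Define log scale normal mixture prior over the model parameters
prior_builder, prior_kwargs = get_mixture_log_scale_normal(
    model.parameters(),
    weights=[0.75, 0.25],
    locs=(0.0, 0.0),
    log_scales=(-1.0, -6.0),
)

# Define Bayesian model trainable via MCMC
model = PriorModule(model, prior_builder, prior_kwargs)

# Stochastic gradient Langevin dynamics optimizer
# (illustrative hyperparameters, as in examples/mnist/train_mcmc.py)
optimizer = SGLD(
    model.parameters(),
    lr=1e-3,
    num_burn_in_steps=60000,
    precondition_decay_rate=0.95,
)
```

During training, the loss is the negative log likelihood minus a weighted `log_prior`, where `output, log_prior = model(data, return_log_prior=True)`; see `examples/mnist/train_mcmc.py` for the complete loop.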
60 | 61 | ### Bayesian model trainable via Bayes by Backprop 62 | 63 | ```python 64 | from torch.nn import Linear 65 | 66 | from bayestorch.distributions import ( 67 | get_mixture_log_scale_normal, 68 | get_softplus_inv_scale_normal, 69 | ) 70 | from bayestorch.nn import VariationalPosteriorModule 71 | 72 | 73 | # Define model 74 | model = Linear(5, 1) 75 | 76 | # Define log scale normal mixture prior over the model parameters 77 | prior_builder, prior_kwargs = get_mixture_log_scale_normal( 78 | model.parameters(), 79 | weights=[0.75, 0.25], 80 | locs=(0.0, 0.0), 81 | log_scales=(-1.0, -6.0) 82 | ) 83 | 84 | # Define inverse softplus scale normal posterior over the model parameters 85 | posterior_builder, posterior_kwargs = get_softplus_inv_scale_normal( 86 | model.parameters(), loc=0.0, softplus_inv_scale=-7.0, requires_grad=True, 87 | ) 88 | 89 | # Define Bayesian model trainable via Bayes by Backprop 90 | model = VariationalPosteriorModule( 91 | model, prior_builder, prior_kwargs, posterior_builder, posterior_kwargs 92 | ) 93 | ``` 94 | 95 | ### Partially Bayesian model trainable via Bayes by Backprop 96 | 97 | ```python 98 | from torch.nn import Linear 99 | 100 | from bayestorch.distributions import ( 101 | get_mixture_log_scale_normal, 102 | get_softplus_inv_scale_normal, 103 | ) 104 | from bayestorch.nn import VariationalPosteriorModule 105 | 106 | 107 | # Define model 108 | model = Linear(5, 1) 109 | 110 | # Define log scale normal mixture prior over `model.weight` 111 | prior_builder, prior_kwargs = get_mixture_log_scale_normal( 112 | [model.weight], 113 | weights=[0.75, 0.25], 114 | locs=(0.0, 0.0), 115 | log_scales=(-1.0, -6.0) 116 | ) 117 | 118 | # Define inverse softplus scale normal posterior over `model.weight` 119 | posterior_builder, posterior_kwargs = get_softplus_inv_scale_normal( 120 | [model.weight], loc=0.0, softplus_inv_scale=-7.0, requires_grad=True, 121 | ) 122 | 123 | # Define partially Bayesian model trainable via Bayes by Backprop 124 | model = VariationalPosteriorModule( 125 | model, prior_builder, prior_kwargs, 126 | posterior_builder, posterior_kwargs, [model.weight], 127 | ) 128 | ``` 129 | 130 | ### Composite prior 131 | 132 | ```python 133 | from torch.distributions import Independent 134 | from torch.nn import Linear 135 | 136 | from bayestorch.distributions import ( 137 | CatDistribution, 138 | get_laplace, 139 | get_normal, 140 | get_softplus_inv_scale_normal, 141 | ) 142 | from bayestorch.nn import VariationalPosteriorModule 143 | 144 | 145 | # Define model 146 | model = Linear(5, 1) 147 | 148 | # Define normal prior over `model.weight` 149 | weight_prior_builder, weight_prior_kwargs = get_normal( 150 | [model.weight], 151 | loc=0.0, 152 | scale=1.0, 153 | prefix="weight_", 154 | ) 155 | 156 | # Define Laplace prior over `model.bias` 157 | bias_prior_builder, bias_prior_kwargs = get_laplace( 158 | [model.bias], 159 | loc=0.0, 160 | scale=1.0, 161 | prefix="bias_", 162 | ) 163 | 164 | # Define composite prior over the model parameters 165 | prior_builder = ( 166 | lambda **kwargs: CatDistribution([ 167 | Independent(weight_prior_builder(**kwargs), 1), 168 | Independent(bias_prior_builder(**kwargs), 1), 169 | ]) 170 | ) 171 | prior_kwargs = {**weight_prior_kwargs, **bias_prior_kwargs} 172 | 173 | # Define inverse softplus scale normal posterior over the model parameters 174 | posterior_builder, posterior_kwargs = get_softplus_inv_scale_normal( 175 | model.parameters(), loc=0.0, softplus_inv_scale=-7.0, requires_grad=True, 176 | ) 177 | 178 | # Define 
Bayesian model trainable via Bayes by Backprop 179 | model = VariationalPosteriorModule( 180 | model, prior_builder, prior_kwargs, posterior_builder, posterior_kwargs, 181 | ) 182 | ``` 183 | 184 | --------------------------------------------------------------------------------------------------------- 185 | 186 | ## 📧 Contact 187 | 188 | [luca.dellalib@gmail.com](mailto:luca.dellalib@gmail.com) 189 | 190 | --------------------------------------------------------------------------------------------------------- 191 | -------------------------------------------------------------------------------- /bayestorch/distributions/finite.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Finite distribution.""" 18 | 19 | from typing import Optional 20 | 21 | import torch 22 | from torch import Size, Tensor 23 | from torch.distributions import ( 24 | Distribution, 25 | OneHotCategorical, 26 | constraints, 27 | register_kl, 28 | ) 29 | from torch.distributions.constraints import Constraint 30 | 31 | from bayestorch.distributions.constraints import ordered_real_vector, real_set 32 | 33 | 34 | __all__ = [ 35 | "Finite", 36 | ] 37 | 38 | 39 | class Finite(Distribution): 40 | """Distribution defined over an arbitrary finite support. 41 | 42 | Examples 43 | -------- 44 | >>> from bayestorch.distributions import Finite 45 | >>> 46 | >>> 47 | >>> logits = torch.as_tensor([0.25, 0.15, 0.10, 0.30, 0.20]) 48 | >>> atoms = torch.as_tensor([5.0, 7.5, 10.0, 12.5, 15.0]) 49 | >>> distribution = Finite(logits, atoms=atoms) 50 | 51 | """ 52 | 53 | has_enumerate_support = True 54 | arg_constraints = { 55 | "probs": constraints.simplex, 56 | "logits": constraints.real_vector, 57 | "atoms": ordered_real_vector, 58 | } 59 | 60 | # override 61 | def __init__( 62 | self, 63 | probs: "Optional[Tensor]" = None, 64 | logits: "Optional[Tensor]" = None, 65 | atoms: "Optional[Tensor]" = None, 66 | validate_args: "Optional[bool]" = None, 67 | ) -> "None": 68 | """Initialize the object. 69 | 70 | Parameters 71 | ---------- 72 | probs: 73 | The event probabilities. 74 | Must be None if `logits` is given. 75 | logits: 76 | The event log probabilities (unnormalized). 77 | Must be None if `probs` is given. 78 | atoms: 79 | The atoms that form the support of the distribution, 80 | sorted in (strict) ascending order. 81 | Default to `{0, ..., N - 1}` where `N` is 82 | ``probs.shape[-1]`` or ``logits.shape[-1]``. 83 | validate_args: 84 | True to validate the arguments, False otherwise. 85 | Default to ``__debug__``. 
86 | 87 | """ 88 | param = probs if probs is not None else logits 89 | if atoms is None: 90 | atoms = torch.arange(param.shape[-1], device=param.device) 91 | atoms = atoms.type(param.type()) 92 | param, self.atoms = torch.broadcast_tensors(param, atoms) 93 | if probs is not None: 94 | probs = probs.expand_as(param) 95 | if logits is not None: 96 | logits = logits.expand_as(param) 97 | self._one_hot_categorical = OneHotCategorical(probs, logits, validate_args) 98 | super().__init__( 99 | self._one_hot_categorical.batch_shape, validate_args=validate_args 100 | ) 101 | 102 | # override 103 | def expand( 104 | self, 105 | batch_shape: "Size" = torch.Size(), # noqa: B008 106 | _instance: "Optional[Finite]" = None, 107 | ) -> "Finite": 108 | new = self._get_checked_instance(Finite, _instance) 109 | param_shape = batch_shape + self.atoms.shape[-1:] 110 | new.atoms = self.atoms.expand(param_shape) 111 | new._one_hot_categorical = self._one_hot_categorical.expand(batch_shape) 112 | super(Finite, new).__init__(batch_shape, self.event_shape, False) 113 | new._validate_args = self._validate_args 114 | return new 115 | 116 | # override 117 | @property 118 | def support(self) -> "Constraint": 119 | return real_set(self.atoms) 120 | 121 | # override 122 | @property 123 | def probs(self) -> "Tensor": 124 | return self._one_hot_categorical.probs 125 | 126 | # override 127 | @property 128 | def logits(self) -> "Tensor": 129 | return self._one_hot_categorical.logits 130 | 131 | # override 132 | @property 133 | def mean(self) -> "Tensor": 134 | return (self.probs * self.atoms).sum(dim=-1) 135 | 136 | @property 137 | def mode(self) -> "Tensor": 138 | return self.atoms.gather( 139 | -1, 140 | self.probs.argmax(dim=-1, keepdim=True), 141 | )[..., 0] 142 | 143 | # override 144 | @property 145 | def variance(self) -> "Tensor": 146 | return (self.probs * (self.atoms**2)).sum(dim=-1) - self.mean**2 147 | 148 | # override 149 | def sample(self, sample_shape: "Size" = torch.Size()) -> "Tensor": # noqa: B008 150 | one_hot_sample = self._one_hot_categorical.sample(sample_shape) 151 | return (self.atoms * one_hot_sample).sum(dim=-1) 152 | 153 | # override 154 | def log_prob(self, value: "Tensor") -> "Tensor": 155 | if self._validate_args: 156 | self._validate_sample(value) 157 | # Add event dimension 158 | value = value[..., None].expand(*value.shape, self.atoms.shape[-1]) 159 | expanded_atoms = self.atoms.expand_as(value) 160 | mask = expanded_atoms == value 161 | result = (self.logits * mask).sum(dim=-1) 162 | result[mask.sum(dim=-1) == 0] = -float("inf") 163 | return result 164 | 165 | # override 166 | def cdf(self, value: "Tensor") -> "Tensor": 167 | if self._validate_args: 168 | self._validate_sample(value) 169 | # Add event dimension 170 | value = value[..., None].expand(*value.shape, self.atoms.shape[-1]) 171 | expanded_atoms = self.atoms.expand_as(value) 172 | mask = expanded_atoms <= value 173 | return (self.probs * mask).sum(dim=-1) 174 | 175 | # override 176 | def enumerate_support(self, expand: "bool" = True) -> "Tensor": 177 | if not ( 178 | self.atoms == self.atoms[(slice(1),) * len(self.batch_shape) + (...,)] 179 | ).all(): 180 | raise NotImplementedError( 181 | "`enumerate_support` does not support inhomogeneous atoms" 182 | ) 183 | values = self.atoms.movedim(-1, 0) 184 | if not expand: 185 | values = values[ 186 | (...,) 187 | + ( 188 | slice( 189 | 1, 190 | ), 191 | ) 192 | * len(self.batch_shape) 193 | ] 194 | return values 195 | 196 | # override 197 | def entropy(self) -> "Tensor": 198 | return 
self._one_hot_categorical.entropy() 199 | 200 | # override 201 | def __repr__(self) -> "str": 202 | return ( 203 | f"{type(self).__name__}" 204 | f"(probs: {self.probs if self.probs.numel() == 1 else self.probs.shape}, " 205 | f"atoms: {self.atoms if self.atoms.numel() == 1 else self.atoms.shape})" 206 | ) 207 | 208 | 209 | @register_kl(Finite, Finite) 210 | def _kl_finite_finite(p: "Finite", q: "Finite") -> "Tensor": 211 | try: 212 | if (p.atoms == q.atoms).all(): 213 | return torch.distributions.kl._kl_categorical_categorical(p, q) 214 | except Exception: 215 | raise NotImplementedError 216 | -------------------------------------------------------------------------------- /examples/mnist/train_mcmc.py: -------------------------------------------------------------------------------- 1 | # Adapted from: 2 | # https://github.com/pytorch/examples/blob/9aad148615b7519eadfa1a60356116a50561f192/mnist/main.py 3 | 4 | # Changes to the code are kept to a minimum to facilitate the comparison with the original example 5 | 6 | from __future__ import print_function 7 | import argparse 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torchvision import datasets, transforms 12 | from torch.optim.lr_scheduler import StepLR 13 | 14 | from bayestorch.distributions import get_mixture_log_scale_normal 15 | from bayestorch.nn import PriorModule 16 | from bayestorch.optim import SGLD 17 | 18 | 19 | class Net(nn.Module): 20 | def __init__(self): 21 | super(Net, self).__init__() 22 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 23 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 24 | #self.dropout1 = nn.Dropout(0.25) 25 | #self.dropout2 = nn.Dropout(0.5) 26 | self.fc1 = nn.Linear(9216, 128) 27 | self.fc2 = nn.Linear(128, 10) 28 | 29 | def forward(self, x): 30 | x = self.conv1(x) 31 | x = F.relu(x) 32 | x = self.conv2(x) 33 | x = F.relu(x) 34 | x = F.max_pool2d(x, 2) 35 | #x = self.dropout1(x) 36 | x = torch.flatten(x, 1) 37 | x = self.fc1(x) 38 | x = F.relu(x) 39 | #x = self.dropout2(x) 40 | x = self.fc2(x) 41 | output = F.log_softmax(x, dim=1) 42 | return output 43 | 44 | 45 | def train(args, model, device, train_loader, optimizer, epoch, log_prior_weight): 46 | model.train() 47 | for batch_idx, (data, target) in enumerate(train_loader): 48 | data, target = data.to(device), target.to(device) 49 | optimizer.zero_grad() 50 | #output = model(data) 51 | #loss = F.nll_loss(output, target) 52 | output, log_prior = model(data, return_log_prior=True) 53 | loss = F.nll_loss(output, target, reduction="sum") - log_prior_weight * log_prior 54 | loss.backward() 55 | optimizer.step() 56 | if batch_idx % args.log_interval == 0: 57 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 58 | epoch, batch_idx * len(data), len(train_loader.dataset), 59 | 100. 
* batch_idx / len(train_loader), loss.item())) 60 | if args.dry_run: 61 | break 62 | 63 | 64 | def test(model, device, test_loader, log_prior_weight): 65 | model.eval() 66 | test_loss = 0 67 | correct = 0 68 | with torch.no_grad(): 69 | for data, target in test_loader: 70 | data, target = data.to(device), target.to(device) 71 | #output = model(data) 72 | #test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 73 | output, log_prior = model(data, return_log_prior=True) 74 | test_loss += (F.nll_loss(output, target, reduction="sum") - log_prior_weight * log_prior).item() # sum up batch loss 75 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 76 | correct += pred.eq(target.view_as(pred)).sum().item() 77 | 78 | test_loss /= len(test_loader.dataset) 79 | 80 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 81 | test_loss, correct, len(test_loader.dataset), 82 | 100. * correct / len(test_loader.dataset))) 83 | 84 | 85 | def main(): 86 | # Training settings 87 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 88 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 89 | help='input batch size for training (default: 64)') 90 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 91 | help='input batch size for testing (default: 1000)') 92 | parser.add_argument('--epochs', type=int, default=14, metavar='N', 93 | help='number of epochs to train (default: 14)') 94 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', 95 | help='learning rate (default: 1e-3)') 96 | parser.add_argument('--gamma', type=float, default=0.9, metavar='M', 97 | help='Learning rate step gamma (default: 0.9)') 98 | parser.add_argument('--no-cuda', action='store_true', default=False, 99 | help='disables CUDA training') 100 | parser.add_argument('--dry-run', action='store_true', default=False, 101 | help='quickly check a single pass') 102 | parser.add_argument('--seed', type=int, default=1, metavar='S', 103 | help='random seed (default: 1)') 104 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 105 | help='how many batches to wait before logging training status') 106 | parser.add_argument('--save-model', action='store_true', default=False, 107 | help='For Saving the current Model') 108 | parser.add_argument('--normal-mixture-prior-weight', type=float, default=0.75, 109 | help='mixture weight of normal mixture prior (default: 0.75)') 110 | parser.add_argument('--normal-mixture-prior-log-scale1', type=float, default=-1.0, 111 | help='log scale of first component of normal mixture prior (default: -1.0)') 112 | parser.add_argument('--normal-mixture-prior-log-scale2', type=float, default=-6.0, 113 | help='log scale of second component of normal mixture prior (default: -6.0)') 114 | parser.add_argument('--log-prior-weight', type=float, default=1e-6, 115 | help='log prior weight (default: 1e-6)') 116 | parser.add_argument('--num-burn-in-steps', type=int, default=60000, 117 | help='number of burn-in steps (default: 60000)') 118 | parser.add_argument('--precondition-decay-rate', type=float, default=0.95, 119 | help='precondition decay rate (default: 0.95)') 120 | args = parser.parse_args() 121 | use_cuda = not args.no_cuda and torch.cuda.is_available() 122 | #use_mps = not args.no_mps and torch.backends.mps.is_available() 123 | 124 | torch.manual_seed(args.seed) 125 | 126 | if use_cuda: 127 | device = torch.device("cuda") 128 | #elif 
use_mps: 129 | # device = torch.device("mps") 130 | else: 131 | device = torch.device("cpu") 132 | 133 | train_kwargs = {'batch_size': args.batch_size} 134 | test_kwargs = {'batch_size': args.test_batch_size} 135 | if use_cuda: 136 | cuda_kwargs = {'num_workers': 1, 137 | 'pin_memory': True, 138 | 'shuffle': True} 139 | train_kwargs.update(cuda_kwargs) 140 | test_kwargs.update(cuda_kwargs) 141 | 142 | transform=transforms.Compose([ 143 | transforms.ToTensor(), 144 | transforms.Normalize((0.1307,), (0.3081,)) 145 | ]) 146 | dataset1 = datasets.MNIST('../data', train=True, download=True, 147 | transform=transform) 148 | dataset2 = datasets.MNIST('../data', train=False, 149 | transform=transform) 150 | train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) 151 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 152 | 153 | model = Net() 154 | 155 | # Prior is defined as in https://arxiv.org/abs/1505.05424 156 | # Prior arguments (WITHOUT gradient tracking) 157 | prior_builder, prior_kwargs = get_mixture_log_scale_normal( 158 | model.parameters(), 159 | weights=[args.normal_mixture_prior_weight, 1 - args.normal_mixture_prior_weight], 160 | locs=(0.0, 0.0), 161 | log_scales=(args.normal_mixture_prior_log_scale1, args.normal_mixture_prior_log_scale2) 162 | ) 163 | 164 | # Bayesian model 165 | model = PriorModule(model, prior_builder, prior_kwargs).to(device) 166 | 167 | optimizer = SGLD( 168 | model.parameters(), 169 | lr=args.lr, 170 | num_burn_in_steps=args.num_burn_in_steps, 171 | precondition_decay_rate=args.precondition_decay_rate, 172 | ) 173 | 174 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 175 | for epoch in range(1, args.epochs + 1): 176 | train(args, model, device, train_loader, optimizer, epoch, args.log_prior_weight) 177 | test(model, device, test_loader, args.log_prior_weight) 178 | scheduler.step() 179 | 180 | if args.save_model: 181 | torch.save(model.state_dict(), "mnist_cnn.pt") 182 | 183 | 184 | if __name__ == '__main__': 185 | main() 186 | -------------------------------------------------------------------------------- /examples/mnist/train_svgd.py: -------------------------------------------------------------------------------- 1 | # Adapted from: 2 | # https://github.com/pytorch/examples/blob/9aad148615b7519eadfa1a60356116a50561f192/mnist/main.py 3 | 4 | # Changes to the code are kept to a minimum to facilitate the comparison with the original example 5 | 6 | from __future__ import print_function 7 | import argparse 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torchvision import datasets, transforms 13 | from torch.optim.lr_scheduler import StepLR 14 | 15 | import math 16 | from bayestorch.distributions import get_mixture_log_scale_normal 17 | from bayestorch.nn import ParticlePosteriorModule 18 | from bayestorch.optim import SVGD 19 | 20 | 21 | class Net(nn.Module): 22 | def __init__(self): 23 | super(Net, self).__init__() 24 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 25 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 26 | #self.dropout1 = nn.Dropout(0.25) 27 | #self.dropout2 = nn.Dropout(0.5) 28 | self.fc1 = nn.Linear(9216, 128) 29 | self.fc2 = nn.Linear(128, 10) 30 | 31 | def forward(self, x): 32 | x = self.conv1(x) 33 | x = F.relu(x) 34 | x = self.conv2(x) 35 | x = F.relu(x) 36 | x = F.max_pool2d(x, 2) 37 | #x = self.dropout1(x) 38 | x = torch.flatten(x, 1) 39 | x = self.fc1(x) 40 | x = F.relu(x) 41 | #x = self.dropout2(x) 42 | x = self.fc2(x) 43 | output = 
F.log_softmax(x, dim=1) 44 | return output 45 | 46 | 47 | def rbf_kernel(x1, x2): 48 | deltas = torch.cdist(x1, x2) 49 | squared_deltas = deltas**2 50 | bandwidth = ( 51 | squared_deltas.detach().median() 52 | / math.log(min(x1.shape[0], x2.shape[0])) 53 | ) 54 | log_kernels = -squared_deltas / bandwidth 55 | kernels = log_kernels.exp() 56 | return kernels 57 | 58 | 59 | def train(args, model, device, train_loader, preconditioner, optimizer, epoch, log_prior_weight): 60 | model.train() 61 | for batch_idx, (data, target) in enumerate(train_loader): 62 | data, target = data.to(device), target.to(device) 63 | optimizer.zero_grad() 64 | #output = model(data) 65 | #loss = F.nll_loss(output, target) 66 | output, log_prior = model(data, return_log_prior=True) 67 | loss = F.nll_loss(output, target, reduction="sum") - log_prior_weight * log_prior 68 | loss.backward() 69 | preconditioner.step() 70 | optimizer.step() 71 | if batch_idx % args.log_interval == 0: 72 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 73 | epoch, batch_idx * len(data), len(train_loader.dataset), 74 | 100. * batch_idx / len(train_loader), loss.item())) 75 | if args.dry_run: 76 | break 77 | 78 | 79 | def test(model, device, test_loader, log_prior_weight): 80 | model.eval() 81 | test_loss = 0 82 | correct = 0 83 | with torch.no_grad(): 84 | for data, target in test_loader: 85 | data, target = data.to(device), target.to(device) 86 | #output = model(data) 87 | #test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 88 | output, log_prior = model(data, return_log_prior=True) 89 | test_loss += (F.nll_loss(output, target, reduction="sum") - log_prior_weight * log_prior).item() # sum up batch loss 90 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 91 | correct += pred.eq(target.view_as(pred)).sum().item() 92 | 93 | test_loss /= len(test_loader.dataset) 94 | 95 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 96 | test_loss, correct, len(test_loader.dataset), 97 | 100. 
* correct / len(test_loader.dataset))) 98 | 99 | 100 | def main(): 101 | # Training settings 102 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 103 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 104 | help='input batch size for training (default: 64)') 105 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 106 | help='input batch size for testing (default: 1000)') 107 | parser.add_argument('--epochs', type=int, default=14, metavar='N', 108 | help='number of epochs to train (default: 14)') 109 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 110 | help='learning rate (default: 1.0)') 111 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 112 | help='Learning rate step gamma (default: 0.7)') 113 | parser.add_argument('--no-cuda', action='store_true', default=False, 114 | help='disables CUDA training') 115 | parser.add_argument('--dry-run', action='store_true', default=False, 116 | help='quickly check a single pass') 117 | parser.add_argument('--seed', type=int, default=1, metavar='S', 118 | help='random seed (default: 1)') 119 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 120 | help='how many batches to wait before logging training status') 121 | parser.add_argument('--save-model', action='store_true', default=False, 122 | help='For Saving the current Model') 123 | parser.add_argument('--normal-mixture-prior-weight', type=float, default=0.75, 124 | help='mixture weight of normal mixture prior (default: 0.75)') 125 | parser.add_argument('--normal-mixture-prior-log-scale1', type=float, default=-1.0, 126 | help='log scale of first component of normal mixture prior (default: -1.0)') 127 | parser.add_argument('--normal-mixture-prior-log-scale2', type=float, default=-6.0, 128 | help='log scale of second component of normal mixture prior (default: -6.0)') 129 | parser.add_argument('--log-prior-weight', type=float, default=1e-6, 130 | help='log prior weight (default: 1e-6)') 131 | parser.add_argument('--num-particles', type=int, default=10, 132 | help='number of particles (default: 10)') 133 | args = parser.parse_args() 134 | use_cuda = not args.no_cuda and torch.cuda.is_available() 135 | #use_mps = not args.no_mps and torch.backends.mps.is_available() 136 | 137 | torch.manual_seed(args.seed) 138 | 139 | if use_cuda: 140 | device = torch.device("cuda") 141 | #elif use_mps: 142 | # device = torch.device("mps") 143 | else: 144 | device = torch.device("cpu") 145 | 146 | train_kwargs = {'batch_size': args.batch_size} 147 | test_kwargs = {'batch_size': args.test_batch_size} 148 | if use_cuda: 149 | cuda_kwargs = {'num_workers': 1, 150 | 'pin_memory': True, 151 | 'shuffle': True} 152 | train_kwargs.update(cuda_kwargs) 153 | test_kwargs.update(cuda_kwargs) 154 | 155 | transform=transforms.Compose([ 156 | transforms.ToTensor(), 157 | transforms.Normalize((0.1307,), (0.3081,)) 158 | ]) 159 | dataset1 = datasets.MNIST('../data', train=True, download=True, 160 | transform=transform) 161 | dataset2 = datasets.MNIST('../data', train=False, 162 | transform=transform) 163 | train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) 164 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 165 | 166 | model = Net() 167 | 168 | # Prior is defined as in https://arxiv.org/abs/1505.05424 169 | # Prior arguments (WITHOUT gradient tracking) 170 | prior_builder, prior_kwargs = get_mixture_log_scale_normal( 171 | model.parameters(), 172 | 
weights=[args.normal_mixture_prior_weight, 1 - args.normal_mixture_prior_weight], 173 | locs=(0.0, 0.0), 174 | log_scales=(args.normal_mixture_prior_log_scale1, args.normal_mixture_prior_log_scale2) 175 | ) 176 | 177 | # Bayesian model 178 | model = ParticlePosteriorModule(model, prior_builder, prior_kwargs, args.num_particles).to(device) 179 | 180 | # SVGD preconditioner 181 | preconditioner = SVGD(model.parameters(include_all=False), rbf_kernel, args.num_particles) 182 | 183 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 184 | 185 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 186 | for epoch in range(1, args.epochs + 1): 187 | train(args, model, device, train_loader, preconditioner, optimizer, epoch, args.log_prior_weight) 188 | test(model, device, test_loader, args.log_prior_weight) 189 | scheduler.step() 190 | 191 | if args.save_model: 192 | torch.save(model.state_dict(), "mnist_cnn.pt") 193 | 194 | 195 | if __name__ == '__main__': 196 | main() 197 | -------------------------------------------------------------------------------- /bayestorch/distributions/cat_distribution.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Concatenated distribution.""" 18 | 19 | from typing import List, Optional, Sequence 20 | 21 | import torch 22 | from torch import Size, Tensor 23 | from torch.distributions import Distribution, kl_divergence, register_kl 24 | from torch.distributions.constraints import independent 25 | 26 | from bayestorch.distributions.constraints import cat 27 | 28 | 29 | __all__ = [ 30 | "CatDistribution", 31 | ] 32 | 33 | 34 | class CatDistribution(Distribution): 35 | """Concatenate a sequence of base distributions with identical 36 | batch shapes along one of their event dimensions. 37 | 38 | Examples 39 | -------- 40 | >>> from torch.distributions import Categorical, Normal 41 | >>> 42 | >>> from bayestorch.distributions import CatDistribution 43 | >>> 44 | >>> 45 | >>> loc = 0.0 46 | >>> scale = 1.0 47 | >>> logits = torch.as_tensor([0.25, 0.15, 0.10, 0.30, 0.20]) 48 | >>> distribution = CatDistribution([Normal(loc, scale), Categorical(logits)]) 49 | 50 | """ 51 | 52 | has_enumerate_support = False 53 | arg_constraints = {} 54 | 55 | # override 56 | def __init__( 57 | self, 58 | base_distributions: "Sequence[Distribution]", 59 | dim: "int" = 0, 60 | validate_args: "Optional[bool]" = None, 61 | ) -> "None": 62 | """Initialize the object. 63 | 64 | Parameters 65 | ---------- 66 | base_distributions: 67 | The base distributions to concatenate. 68 | dim: 69 | The event dimension along which to concatenate. 70 | validate_args: 71 | True to validate the arguments, False otherwise. 72 | Default to ``__debug__``. 
73 | 74 | Raises 75 | ------ 76 | IndexError 77 | If `dim` is out of range or not integer. 78 | ValueError 79 | If batch shapes of the base distributions are not identical, or the 80 | corresponding expanded event shapes differ along dimensions other 81 | than the `dim`-th. 82 | 83 | """ 84 | self.base_dists = base_distributions 85 | event_ndims = max(len(d.event_shape or [1]) for d in base_distributions) 86 | self._expanded_event_shapes = [ 87 | torch.Size(list(d.event_shape) + [1] * (event_ndims - len(d.event_shape))) 88 | for d in base_distributions 89 | ] 90 | if dim < -event_ndims or dim >= event_ndims or not float(dim).is_integer(): 91 | raise IndexError( 92 | f"`dim` ({dim}) must be in the integer interval [-{event_ndims}, {event_ndims})" 93 | ) 94 | self.dim = dim = int(dim) % event_ndims 95 | batch_shape = base_distributions[0].batch_shape 96 | for base_dist in base_distributions[1:]: 97 | if base_dist.batch_shape != batch_shape: 98 | raise ValueError( 99 | f"Batch shapes of all base distributions " 100 | f"({[d.batch_shape for d in base_distributions]}) " 101 | f"must be identical" 102 | ) 103 | event_shape = list(self._expanded_event_shapes[0]) 104 | for expanded_event_shape in self._expanded_event_shapes[1:]: 105 | if (list(expanded_event_shape[:dim] + expanded_event_shape[dim + 1 :])) != ( 106 | event_shape[:dim] + event_shape[dim + 1 :] 107 | ): 108 | raise ValueError( 109 | f"Expanded event shapes of all base distributions " 110 | f"({self._expanded_event_shapes}) must be identical " 111 | f"except for the `dim`-th ({dim}) dimension" 112 | ) 113 | event_shape[dim] += expanded_event_shape[dim] 114 | super().__init__(batch_shape, torch.Size(event_shape), validate_args) 115 | 116 | # override 117 | def expand( 118 | self, 119 | batch_shape: "Size" = Size(), # noqa: B008 120 | _instance: "Optional[CatDistribution]" = None, 121 | ) -> "CatDistribution": 122 | new = self._get_checked_instance(CatDistribution, _instance) 123 | new.base_dists = [d.expand(batch_shape) for d in self.base_dists] 124 | new.dim = self.dim 125 | super(CatDistribution, new).__init__(batch_shape, self.event_shape, False) 126 | new._validate_args = self._validate_args 127 | return new 128 | 129 | # override 130 | @property 131 | def support(self) -> "cat": 132 | return cat( 133 | [ 134 | independent(d.support, len(self.event_shape) - len(d.event_shape)) 135 | for d in self.base_dists 136 | ], 137 | dim=self.dim - len(self.event_shape), 138 | lengths=[shape[self.dim] for shape in self._expanded_event_shapes], 139 | ) 140 | 141 | # override 142 | @property 143 | def mean(self) -> "Tensor": 144 | return self._cat([d.mean for d in self.base_dists]) 145 | 146 | @property 147 | def mode(self) -> "Tensor": 148 | return self._cat([d.mode for d in self.base_dists]) 149 | 150 | # override 151 | @property 152 | def variance(self) -> "Tensor": 153 | return self._cat([d.variance for d in self.base_dists]) 154 | 155 | # override 156 | def sample(self, sample_shape: "Size" = torch.Size()) -> "Tensor": # noqa: B008 157 | return self._cat([d.sample(sample_shape) for d in self.base_dists]) 158 | 159 | # override 160 | def rsample(self, sample_shape: "Size" = torch.Size()) -> "Tensor": # noqa: B008 161 | return self._cat([d.rsample(sample_shape) for d in self.base_dists]) 162 | 163 | # override 164 | def log_prob(self, value: "Tensor") -> "Tensor": 165 | chunks = self._split(value) 166 | return torch.stack( 167 | [d.log_prob(c) for d, c in zip(self.base_dists, chunks)] 168 | ).sum(dim=0) 169 | 170 | # override 171 | def 
cdf(self, value: "Tensor") -> "Tensor": 172 | chunks = self._split(value) 173 | return torch.stack([d.cdf(c) for d, c in zip(self.base_dists, chunks)]).prod( 174 | dim=0 175 | ) 176 | 177 | # override 178 | @property 179 | def has_rsample(self) -> "bool": 180 | return all(d.has_rsample for d in self.base_dists) 181 | 182 | # override 183 | def entropy(self) -> "Tensor": 184 | return torch.stack([d.entropy() for d in self.base_dists]).sum(dim=0) 185 | 186 | def _cat(self, inputs: "Sequence[Tensor]") -> "Tensor": 187 | inputs = [ 188 | x[(...,) + (None,) * (len(self.event_shape) - len(d.event_shape))] 189 | for x, d in zip(inputs, self.base_dists) 190 | ] 191 | return torch.cat(inputs, dim=self.dim - len(self.event_shape)) 192 | 193 | def _split(self, input: "Tensor") -> "List[Tensor]": 194 | split_sizes = [shape[self.dim] for shape in self._expanded_event_shapes] 195 | chunks = input.split(split_sizes, dim=self.dim - len(self.event_shape)) 196 | return [ 197 | chunk.reshape( 198 | ( 199 | *input.shape[: input.ndim - len(self.event_shape)], 200 | *d.event_shape, 201 | ) 202 | ) 203 | for chunk, d in zip(chunks, self.base_dists) 204 | ] 205 | 206 | # override 207 | def __repr__(self) -> "str": 208 | return f"{type(self).__name__}({self.base_dists}, dim: {self.dim})" 209 | 210 | 211 | @register_kl(CatDistribution, CatDistribution) 212 | def _kl_cat_cat(p: "CatDistribution", q: "CatDistribution") -> "Tensor": 213 | if (p.dim != q.dim) or (len(p.base_dists) != len(q.base_dists)): 214 | raise NotImplementedError 215 | return torch.stack( 216 | [kl_divergence(d, q.base_dists[i]) for i, d in enumerate(p.base_dists)] 217 | ).sum(dim=0) 218 | 219 | 220 | @register_kl(CatDistribution, Distribution) 221 | def _kl_cat_distribution(p: "CatDistribution", q: "Distribution") -> "Tensor": 222 | if len(p.base_dists) > 1: 223 | raise NotImplementedError 224 | return kl_divergence(p.base_dists[0], q) 225 | 226 | 227 | @register_kl(Distribution, CatDistribution) 228 | def _kl_distribution_cat(p: "Distribution", q: "CatDistribution") -> "Tensor": 229 | if len(q.base_dists) > 1: 230 | raise NotImplementedError 231 | return kl_divergence(p, q.base_dists[0]) 232 | -------------------------------------------------------------------------------- /examples/mnist/train_bbb.py: -------------------------------------------------------------------------------- 1 | # Adapted from: 2 | # https://github.com/pytorch/examples/blob/9aad148615b7519eadfa1a60356116a50561f192/mnist/main.py 3 | 4 | # Changes to the code are kept to a minimum to facilitate the comparison with the original example 5 | 6 | from __future__ import print_function 7 | import argparse 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torchvision import datasets, transforms 13 | from torch.optim.lr_scheduler import StepLR 14 | 15 | from bayestorch.distributions import get_mixture_log_scale_normal, get_softplus_inv_scale_normal 16 | from bayestorch.nn import VariationalPosteriorModule 17 | 18 | 19 | class Net(nn.Module): 20 | def __init__(self): 21 | super(Net, self).__init__() 22 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 23 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 24 | #self.dropout1 = nn.Dropout(0.25) 25 | #self.dropout2 = nn.Dropout(0.5) 26 | self.fc1 = nn.Linear(9216, 128) 27 | self.fc2 = nn.Linear(128, 10) 28 | 29 | def forward(self, x): 30 | x = self.conv1(x) 31 | x = F.relu(x) 32 | x = self.conv2(x) 33 | x = F.relu(x) 34 | x = F.max_pool2d(x, 2) 35 | #x = self.dropout1(x) 36 | x = 
torch.flatten(x, 1) 37 | x = self.fc1(x) 38 | x = F.relu(x) 39 | #x = self.dropout2(x) 40 | x = self.fc2(x) 41 | output = F.log_softmax(x, dim=1) 42 | return output 43 | 44 | 45 | def train(args, model, device, train_loader, optimizer, epoch, num_train_mc_samples, kl_div_weight): 46 | model.train() 47 | for batch_idx, (data, target) in enumerate(train_loader): 48 | data, target = data.to(device), target.to(device) 49 | optimizer.zero_grad() 50 | #output = model(data) 51 | #loss = F.nll_loss(output, target) 52 | output, kl_div = model(data, num_mc_samples=num_train_mc_samples, return_kl_div=True) 53 | loss = F.nll_loss(output, target, reduction="sum") + kl_div_weight * kl_div 54 | loss.backward() 55 | optimizer.step() 56 | if batch_idx % args.log_interval == 0: 57 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 58 | epoch, batch_idx * len(data), len(train_loader.dataset), 59 | 100. * batch_idx / len(train_loader), loss.item())) 60 | if args.dry_run: 61 | break 62 | 63 | 64 | def test(model, device, test_loader, num_test_mc_samples, kl_div_weight): 65 | model.eval() 66 | test_loss = 0 67 | correct = 0 68 | with torch.no_grad(): 69 | for data, target in test_loader: 70 | data, target = data.to(device), target.to(device) 71 | #output = model(data) 72 | #test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 73 | output, kl_div = model(data, num_mc_samples=num_test_mc_samples, return_kl_div=True) 74 | test_loss += (F.nll_loss(output, target, reduction="sum") + kl_div_weight * kl_div).item() # sum up batch loss 75 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 76 | correct += pred.eq(target.view_as(pred)).sum().item() 77 | 78 | test_loss /= len(test_loader.dataset) 79 | 80 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 81 | test_loss, correct, len(test_loader.dataset), 82 | 100. 
* correct / len(test_loader.dataset))) 83 | 84 | 85 | def main(): 86 | # Training settings 87 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 88 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 89 | help='input batch size for training (default: 64)') 90 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 91 | help='input batch size for testing (default: 1000)') 92 | parser.add_argument('--epochs', type=int, default=14, metavar='N', 93 | help='number of epochs to train (default: 14)') 94 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 95 | help='learning rate (default: 1.0)') 96 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 97 | help='Learning rate step gamma (default: 0.7)') 98 | parser.add_argument('--no-cuda', action='store_true', default=False, 99 | help='disables CUDA training') 100 | parser.add_argument('--dry-run', action='store_true', default=False, 101 | help='quickly check a single pass') 102 | parser.add_argument('--seed', type=int, default=1, metavar='S', 103 | help='random seed (default: 1)') 104 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 105 | help='how many batches to wait before logging training status') 106 | parser.add_argument('--save-model', action='store_true', default=False, 107 | help='For Saving the current Model') 108 | parser.add_argument('--normal-mixture-prior-weight', type=float, default=0.75, 109 | help='mixture weight of normal mixture prior (default: 0.75)') 110 | parser.add_argument('--normal-mixture-prior-log-scale1', type=float, default=-1.0, 111 | help='log scale of first component of normal mixture prior (default: -1.0)') 112 | parser.add_argument('--normal-mixture-prior-log-scale2', type=float, default=-6.0, 113 | help='log scale of second component of normal mixture prior (default: -6.0)') 114 | parser.add_argument('--normal-posterior-softplus-inv-scale', type=float, default=-7.0, 115 | help='inverse softplus scale of normal posterior (default: -7.0)') 116 | parser.add_argument('--kl-div-weight', type=float, default=1e-6, 117 | help='Kullback-Leibler divergence weight (default: 1e-6)') 118 | parser.add_argument('--num-train-mc-samples', type=int, default=10, 119 | help='number of Monte Carlo samples for training (default: 10)') 120 | parser.add_argument('--num-test-mc-samples', type=int, default=10, 121 | help='number of Monte Carlo samples for testing (default: 10)') 122 | args = parser.parse_args() 123 | use_cuda = not args.no_cuda and torch.cuda.is_available() 124 | #use_mps = not args.no_mps and torch.backends.mps.is_available() 125 | 126 | torch.manual_seed(args.seed) 127 | 128 | if use_cuda: 129 | device = torch.device("cuda") 130 | #elif use_mps: 131 | # device = torch.device("mps") 132 | else: 133 | device = torch.device("cpu") 134 | 135 | train_kwargs = {'batch_size': args.batch_size} 136 | test_kwargs = {'batch_size': args.test_batch_size} 137 | if use_cuda: 138 | cuda_kwargs = {'num_workers': 1, 139 | 'pin_memory': True, 140 | 'shuffle': True} 141 | train_kwargs.update(cuda_kwargs) 142 | test_kwargs.update(cuda_kwargs) 143 | 144 | transform=transforms.Compose([ 145 | transforms.ToTensor(), 146 | transforms.Normalize((0.1307,), (0.3081,)) 147 | ]) 148 | dataset1 = datasets.MNIST('../data', train=True, download=True, 149 | transform=transform) 150 | dataset2 = datasets.MNIST('../data', train=False, 151 | transform=transform) 152 | train_loader = 
torch.utils.data.DataLoader(dataset1,**train_kwargs) 153 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 154 | 155 | model = Net() 156 | 157 | # Prior and posterior are defined as in https://arxiv.org/abs/1505.05424 158 | # Prior arguments (WITHOUT gradient tracking) 159 | prior_builder, prior_kwargs = get_mixture_log_scale_normal( 160 | model.parameters(), 161 | weights=[args.normal_mixture_prior_weight, 1 - args.normal_mixture_prior_weight], 162 | locs=(0.0, 0.0), 163 | log_scales=(args.normal_mixture_prior_log_scale1, args.normal_mixture_prior_log_scale2) 164 | ) 165 | 166 | # Posterior arguments (WITH gradient tracking) 167 | posterior_builder, posterior_kwargs = get_softplus_inv_scale_normal( 168 | model.parameters(), loc=0.0, softplus_inv_scale=args.normal_posterior_softplus_inv_scale, requires_grad=True, 169 | ) 170 | 171 | # Bayesian model 172 | model = VariationalPosteriorModule( 173 | model, prior_builder, prior_kwargs, posterior_builder, posterior_kwargs 174 | ).to(device) 175 | 176 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 177 | 178 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 179 | for epoch in range(1, args.epochs + 1): 180 | train(args, model, device, train_loader, optimizer, epoch, args.num_train_mc_samples, args.kl_div_weight) 181 | test(model, device, test_loader, args.num_test_mc_samples, args.kl_div_weight) 182 | scheduler.step() 183 | 184 | if args.save_model: 185 | torch.save(model.state_dict(), "mnist_cnn.pt") 186 | 187 | 188 | if __name__ == '__main__': 189 | main() 190 | -------------------------------------------------------------------------------- /bayestorch/nn/prior_module.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Prior module.""" 18 | 19 | from typing import ( 20 | Any, 21 | Callable, 22 | Dict, 23 | Iterable, 24 | Iterator, 25 | Optional, 26 | Tuple, 27 | TypeVar, 28 | Union, 29 | ) 30 | 31 | from torch import Tensor, nn 32 | from torch.distributions import Distribution, Independent 33 | from torch.nn import Module, Parameter 34 | 35 | 36 | __all__ = [ 37 | "PriorModule", 38 | ] 39 | 40 | 41 | _T = TypeVar("_T", bound="PriorModule") 42 | 43 | 44 | class PriorModule(Module): 45 | """Bayesian module that defines a prior over its parameters. 
46 | 47 | Examples 48 | -------- 49 | >>> import torch 50 | >>> from torch import nn 51 | >>> 52 | >>> from bayestorch.distributions import LogScaleNormal 53 | >>> from bayestorch.nn import PriorModule 54 | >>> 55 | >>> 56 | >>> batch_size = 10 57 | >>> in_features = 4 58 | >>> out_features = 2 59 | >>> model = nn.Linear(in_features, out_features) 60 | >>> num_parameters = sum(parameter.numel() for parameter in model.parameters()) 61 | >>> model = PriorModule( 62 | ... model, 63 | ... prior_builder=LogScaleNormal, 64 | ... prior_kwargs={ 65 | ... "loc": torch.zeros(num_parameters), 66 | ... "log_scale": torch.full((num_parameters,), -1.0), 67 | ... }, 68 | ... ) 69 | >>> input = torch.rand(batch_size, in_features) 70 | >>> output, log_prior = model(input, return_log_prior=True) 71 | 72 | """ 73 | 74 | prior: "Distribution" 75 | """The prior distribution.""" 76 | 77 | # override 78 | def __init__( 79 | self, 80 | module: "Module", 81 | prior_builder: "Callable[..., Distribution]", 82 | prior_kwargs: "Dict[str, Any]", 83 | module_parameters: "Optional[Iterable[Tensor]]" = None, 84 | ) -> "None": 85 | """Initialize the object. 86 | 87 | Parameters 88 | ---------- 89 | module: 90 | The module. 91 | prior_builder: 92 | The prior builder, i.e. a callable that receives keyword 93 | arguments and returns a prior with size (batch + event) 94 | equal to the length of the 1D tensor obtained by flattening 95 | and concatenating each tensor in `module_parameters`. 96 | prior_kwargs: 97 | The keyword arguments to pass to the prior builder. 98 | Tensor arguments are internally registered as parameters 99 | if their `requires_grad` attribute is True, as persistent 100 | buffers otherwise. 101 | module_parameters: 102 | The module parameters over which the prior is defined. 103 | Useful to selectively define a prior over a restricted 104 | subset of submodules/parameters. 105 | Default to ``module.parameters()``. 106 | 107 | Raises 108 | ------ 109 | ValueError 110 | If an invalid argument value is given. 111 | 112 | """ 113 | super().__init__() 114 | self.module = module 115 | self.prior_builder = prior_builder 116 | self.prior_kwargs = prior_kwargs = { 117 | k: v for k, v in prior_kwargs.items() 118 | } # Avoid side effects 119 | self.module_parameters = module_parameters = list( 120 | module_parameters or module.parameters() 121 | ) 122 | if not set(module_parameters).issubset(set(module.parameters())): 123 | raise ValueError( 124 | f"`module_parameters` ({module_parameters}) must be a subset of `module.parameters()` ({module.parameters()})" 125 | ) 126 | 127 | # Build prior 128 | self.prior = self._build_distribution("prior", prior_builder, prior_kwargs) 129 | 130 | # override 131 | def named_parameters( 132 | self, 133 | *args: "Any", 134 | include_all: "bool" = True, 135 | **kwargs: "Any", 136 | ) -> "Iterator[Tuple[str, Parameter]]": 137 | """Return the named parameters. 138 | 139 | Parameters 140 | ---------- 141 | include_all: 142 | True to include all the named parameters, 143 | False to include only those over which the 144 | prior is defined. 145 | 146 | Returns 147 | ------- 148 | The named parameters. 
149 | 150 | """ 151 | if include_all: 152 | return super().named_parameters(*args, **kwargs) 153 | return ( 154 | (k, v) 155 | for k, v in super().named_parameters(*args, **kwargs) 156 | if any(v is parameter for parameter in self.module_parameters) 157 | ) 158 | 159 | # override 160 | def parameters( 161 | self, 162 | *args: "Any", 163 | **kwargs: "Any", 164 | ) -> "Iterator[Parameter]": 165 | for _, parameter in self.named_parameters(*args, **kwargs): 166 | yield parameter 167 | 168 | # override 169 | def forward( 170 | self, *args: "Any", return_log_prior: "bool" = False, **kwargs: "Any" 171 | ) -> "Union[Any, Tuple[Any, Tensor]]": 172 | """Forward pass. 173 | 174 | In the following, let `B = {B_1, ..., B_k}` denote the 175 | batch shape and `O = {O_1, ..., O_m}` the shape of a 176 | leaf value of the underlying module output (can be a 177 | nested tensor). 178 | 179 | Parameters 180 | ---------- 181 | args: 182 | The positional arguments to pass to the underlying module. 183 | return_log_prior: 184 | True to additionally return the log prior (usually 185 | required during training), False otherwise. 186 | kwargs: 187 | The keyword arguments to pass to the underlying module. 188 | 189 | Returns 190 | ------- 191 | - The output, shape of a leaf value: ``[*B, *O]``; 192 | - if `return_log_prior` is True, the log prior, shape: ``[]``. 193 | 194 | """ 195 | # Forward pass 196 | output = self.module(*args, **kwargs) 197 | 198 | if not return_log_prior: 199 | return output 200 | 201 | # Extract particle 202 | particle = nn.utils.parameters_to_vector(self.module_parameters) 203 | 204 | # Compute log prior 205 | log_prior = self.prior.log_prob(particle) 206 | 207 | return output, log_prior 208 | 209 | # override 210 | def _apply(self, *args: "Any", **kwargs: "Any") -> "_T": 211 | super()._apply(*args, **kwargs) 212 | 213 | # Rebuild prior using updated parameters/buffers 214 | # (`_apply` might create copies of parameters/buffers, 215 | # therefore references within the prior are lost) 216 | self.prior = self._build_distribution( 217 | "prior", self.prior_builder, self.prior_kwargs 218 | ) 219 | 220 | return self 221 | 222 | def _build_distribution( 223 | self, 224 | name: "str", 225 | distribution_builder: "Callable[..., Distribution]", 226 | distribution_kwargs: "Dict[str, Any]", 227 | ) -> "Distribution": 228 | # Extract particle 229 | # particle = module parameters flattened into a 1D vector 230 | particle = nn.utils.parameters_to_vector(self.module_parameters) 231 | 232 | # Build distribution 233 | for k, v in distribution_kwargs.items(): 234 | key = f"{name}_{k}" 235 | if isinstance(v, Tensor): 236 | if key in self._parameters: 237 | v = self._parameters[key] 238 | elif key in self._buffers: 239 | v = self._buffers[key] 240 | else: 241 | v = self._register_tensor(key, v.to(particle)) 242 | distribution_kwargs[k] = v 243 | distribution = distribution_builder(**distribution_kwargs) 244 | 245 | # Adjust distribution shape 246 | batch_ndims = len(distribution.batch_shape) 247 | if batch_ndims > 0: 248 | distribution = Independent(distribution, batch_ndims) 249 | 250 | # Validate distribution event shape 251 | event_shape = distribution.event_shape 252 | if event_shape != (1,) and event_shape != particle.shape: 253 | raise ValueError( 254 | f"{name.capitalize()} size (batch + event) ({event_shape.numel()}) " 255 | f"must be equal to the number of module parameters ({particle.numel()})" 256 | ) 257 | 258 | return distribution 259 | 260 | def _register_tensor(self, name: "str", input: "Tensor") 
-> "Tensor": 261 | if input.requires_grad: 262 | input = Parameter(input) 263 | self.register_parameter(name, input) 264 | else: 265 | self.register_buffer(name, input) 266 | return input 267 | 268 | # override 269 | def __repr__(self) -> "str": 270 | return ( 271 | f"{type(self).__name__}" 272 | f"(module: {self.module}, " 273 | f"prior: {self.prior}, " 274 | f"module_parameters: {sum(parameter.numel() for parameter in self.module_parameters)})" 275 | ) 276 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | Copyright 2022 Luca Della Libera. 179 | 180 | Licensed under the Apache License, Version 2.0 (the "License"); 181 | you may not use this file except in compliance with the License. 182 | You may obtain a copy of the License at 183 | 184 | https://www.apache.org/licenses/LICENSE-2.0 185 | 186 | Unless required by applicable law or agreed to in writing, software 187 | distributed under the License is distributed on an "AS IS" BASIS, 188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 189 | See the License for the specific language governing permissions and 190 | limitations under the License. 191 | -------------------------------------------------------------------------------- /bayestorch/nn/particle_posterior_module.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | """Particle posterior module.""" 18 | 19 | import copy 20 | from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Tuple, Union 21 | 22 | import torch 23 | from torch import Tensor 24 | from torch.distributions import Distribution 25 | from torch.nn import Module, ModuleList, Parameter 26 | 27 | from bayestorch.nn.prior_module import PriorModule 28 | from bayestorch.nn.utils import nested_apply 29 | 30 | 31 | __all__ = [ 32 | "ParticlePosteriorModule", 33 | ] 34 | 35 | 36 | class ParticlePosteriorModule(PriorModule): 37 | """Bayesian module that defines a prior and a particle-based 38 | posterior over its parameters. 39 | 40 | References 41 | ---------- 42 | .. [1] Q. Liu and D. Wang. 43 | "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm". 44 | In: Advances in Neural Information Processing Systems. 2016, pp. 2378-2386. 45 | URL: https://arxiv.org/abs/1608.04471 46 | 47 | Examples 48 | -------- 49 | >>> import torch 50 | >>> from torch import nn 51 | >>> 52 | >>> from bayestorch.distributions import LogScaleNormal 53 | >>> from bayestorch.nn import ParticlePosteriorModule 54 | >>> 55 | >>> 56 | >>> num_particles = 5 57 | >>> batch_size = 10 58 | >>> in_features = 4 59 | >>> out_features = 2 60 | >>> model = nn.Linear(in_features, out_features) 61 | >>> num_parameters = sum(parameter.numel() for parameter in model.parameters()) 62 | >>> model = ParticlePosteriorModule( 63 | ... model, 64 | ... prior_builder=LogScaleNormal, 65 | ... prior_kwargs={ 66 | ... "loc": torch.zeros(num_parameters), 67 | ... "log_scale": torch.full((num_parameters,), -1.0), 68 | ... }, 69 | ... num_particles=num_particles, 70 | ... ) 71 | >>> input = torch.rand(batch_size, in_features) 72 | >>> output = model(input) 73 | >>> outputs, log_priors = model( 74 | ... input, 75 | ... return_log_prior=True, 76 | ... reduction="none", 77 | ... ) 78 | 79 | """ 80 | 81 | replicas: "ModuleList" 82 | """The module replicas (one for each particle).""" 83 | 84 | # override 85 | def __init__( 86 | self, 87 | module: "Module", 88 | prior_builder: "Callable[..., Distribution]", 89 | prior_kwargs: "Dict[str, Any]", 90 | num_particles: "int" = 10, 91 | module_parameters: "Optional[Iterable[Tensor]]" = None, 92 | ) -> "None": 93 | """Initialize the object. 94 | 95 | Parameters 96 | ---------- 97 | module: 98 | The module. 99 | prior_builder: 100 | The prior builder, i.e. a callable that receives keyword 101 | arguments and returns a prior with size (batch + event) 102 | equal to the length of the 1D tensor obtained by flattening 103 | and concatenating each tensor in `module_parameters`. 104 | prior_kwargs: 105 | The keyword arguments to pass to the prior builder. 106 | Tensor arguments are internally registered as parameters 107 | if their `requires_grad` attribute is True, as persistent 108 | buffers otherwise. 109 | num_particles: 110 | The number of particles. 111 | module_parameters: 112 | The module parameters over which the prior is defined. 113 | Useful to selectively define a prior over a restricted 114 | subset of submodules/parameters. 115 | Default to ``module.parameters()``. 116 | 117 | Raises 118 | ------ 119 | ValueError 120 | If an invalid argument value is given. 121 | 122 | Warnings 123 | -------- 124 | High memory usage is to be expected as `num_particles - 1` 125 | replicas of the module must be maintained internally. 
126 | 127 | """ 128 | if num_particles < 1 or not float(num_particles).is_integer(): 129 | raise ValueError( 130 | f"`num_particles` ({num_particles}) must be in the integer interval [1, inf)" 131 | ) 132 | 133 | super().__init__(module, prior_builder, prior_kwargs, module_parameters) 134 | self.num_particles = int(num_particles) 135 | 136 | # Replicate module (one replica for each particle) 137 | self.replicas = ModuleList( 138 | [module] + [copy.deepcopy(module) for _ in range(num_particles - 1)] 139 | ) 140 | 141 | # Retrieve indices of the selected parameters 142 | self._module_parameter_idxes = [] 143 | replica_parameters = list(module.parameters()) 144 | for parameter in self.module_parameters: 145 | for i, x in enumerate(replica_parameters): 146 | if parameter is x: 147 | self._module_parameter_idxes.append(i) 148 | break 149 | 150 | for replica in self.replicas: 151 | # Sample new particle 152 | new_particle = self.prior.sample() 153 | 154 | # Inject sampled particle 155 | start_idx = 0 156 | replica_parameters = list(replica.parameters()) 157 | module_parameters = [ 158 | replica_parameters[idx] for idx in self._module_parameter_idxes 159 | ] 160 | for parameter in module_parameters: 161 | end_idx = start_idx + parameter.numel() 162 | new_parameter = new_particle[start_idx:end_idx].reshape_as(parameter) 163 | parameter.detach_().requires_grad_(False).copy_( 164 | new_parameter 165 | ).requires_grad_() 166 | start_idx = end_idx 167 | 168 | # override 169 | def named_parameters( 170 | self, 171 | *args: "Any", 172 | include_all: "bool" = True, 173 | **kwargs: "Any", 174 | ) -> "Iterator[Tuple[str, Parameter]]": 175 | """Return the named parameters. 176 | 177 | Parameters 178 | ---------- 179 | include_all: 180 | True to include all the named parameters, 181 | False to include only those over which the 182 | prior is defined. 183 | 184 | Returns 185 | ------- 186 | The named parameters. 187 | 188 | """ 189 | if include_all: 190 | return super(PriorModule, self).named_parameters(*args, **kwargs) 191 | named_parameters = dict( 192 | super(PriorModule, self).named_parameters(*args, **kwargs) 193 | ) 194 | result = [] 195 | for replica in self.replicas: 196 | replica_parameters = list(replica.parameters()) 197 | for idx in self._module_parameter_idxes: 198 | for k, v in named_parameters.items(): 199 | if v is replica_parameters[idx]: 200 | result.append((k, v)) 201 | break 202 | return result 203 | 204 | @property 205 | def particles(self) -> "Tensor": 206 | """Return the particles. 207 | 208 | In the following, let `N` denote the number of particles, 209 | and `D` the number of parameters over which the prior is 210 | defined. 211 | 212 | Returns 213 | ------- 214 | The particles, shape: ``[N, D]``. 215 | 216 | """ 217 | result = [] 218 | for replica in self.replicas: 219 | replica_parameters = list(replica.parameters()) 220 | module_parameters = [ 221 | replica_parameters[idx] for idx in self._module_parameter_idxes 222 | ] 223 | for parameter in module_parameters: 224 | result.append(parameter.flatten()) 225 | return torch.cat(result).reshape(self.num_particles, -1) 226 | 227 | # override 228 | def forward( 229 | self, 230 | *args: "Any", 231 | return_log_prior: "bool" = False, 232 | reduction: "str" = "mean", 233 | **kwargs: "Any", 234 | ) -> "Union[Any, Tuple[Any, Tensor]]": 235 | """Forward pass. 
236 | 237 | In the following, let `N` denote the number of particles, 238 | `B = {B_1, ..., B_k}` the batch shape, and `O = {O_1, ..., O_m}` 239 | the shape of a leaf value of the underlying module output (can be 240 | a nested tensor). 241 | 242 | Parameters 243 | ---------- 244 | args: 245 | The positional arguments to pass to the underlying module. 246 | return_log_prior: 247 | True to additionally return the log prior (usually 248 | required during training), False otherwise. 249 | reduction: 250 | The reduction to apply to the leaf values of the underlying 251 | module output and to the log prior (if `return_log_prior` is 252 | True) across particles. Must be one of the following: 253 | - "none": no reduction is applied; 254 | - "mean": the leaf values and the log prior are averaged 255 | across particles. 256 | kwargs: 257 | The keyword arguments to pass to the underlying module. 258 | 259 | Returns 260 | ------- 261 | - The output, shape of a leaf value: ``[N, *B, *O]`` 262 | if `reduction` is "none" , ``[*B, *O]`` otherwise; 263 | - if `return_log_prior` is True, the log prior, shape: 264 | ``[N]`` if `reduction` is "none" , ``[]`` otherwise. 265 | 266 | Raises 267 | ------ 268 | ValueError 269 | If an invalid argument value is given. 270 | 271 | """ 272 | if reduction not in ["none", "mean"]: 273 | raise ValueError( 274 | f"`reduction` ({reduction}) must be one of {['none', 'mean']}" 275 | ) 276 | 277 | # Forward pass 278 | outputs = [replica(*args, **kwargs) for replica in self.replicas] 279 | if reduction == "none": 280 | outputs = nested_apply(torch.stack, outputs) 281 | elif reduction == "mean": 282 | outputs = nested_apply( 283 | lambda inputs, dim: torch.mean(torch.stack(inputs, dim), dim), outputs 284 | ) 285 | 286 | if not return_log_prior: 287 | return outputs 288 | 289 | # Extract particles 290 | particles = self.particles 291 | 292 | # Compute log prior 293 | log_priors = self.prior.log_prob(particles) 294 | if reduction == "mean": 295 | log_priors = log_priors.mean() 296 | 297 | return outputs, log_priors 298 | 299 | # override 300 | def __repr__(self) -> "str": 301 | return ( 302 | f"{type(self).__name__}" 303 | f"(module: {self.module}, " 304 | f"prior: {self.prior}, " 305 | f"num_particles: {self.num_particles}, " 306 | f"module_parameters: {sum(parameter.numel() for parameter in self.module_parameters)})" 307 | ) 308 | -------------------------------------------------------------------------------- /bayestorch/nn/variational_posterior_module.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2022 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | """Variational posterior module.""" 18 | 19 | import copy 20 | import logging 21 | from typing import ( 22 | Any, 23 | Callable, 24 | Dict, 25 | Iterable, 26 | Iterator, 27 | Optional, 28 | Tuple, 29 | TypeVar, 30 | Union, 31 | ) 32 | 33 | import torch 34 | from torch import Tensor 35 | from torch.distributions import Distribution, kl_divergence 36 | from torch.nn import Module, Parameter 37 | 38 | from bayestorch.nn.prior_module import PriorModule 39 | from bayestorch.nn.utils import nested_apply 40 | 41 | 42 | __all__ = [ 43 | "VariationalPosteriorModule", 44 | ] 45 | 46 | 47 | _T = TypeVar("_T", bound="VariationalPosteriorModule") 48 | 49 | _T_destination = Module.T_destination 50 | 51 | _IncompatibleKeys = torch.nn.modules.module._IncompatibleKeys 52 | 53 | _LOGGER = logging.getLogger(__name__) 54 | 55 | 56 | class VariationalPosteriorModule(PriorModule): 57 | """Bayesian module that defines a prior and a variational 58 | posterior over its parameters. 59 | 60 | References 61 | ---------- 62 | .. [1] C. Blundell, J. Cornebise, K. Kavukcuoglu, and D. Wierstra. 63 | "Weight Uncertainty in Neural Networks". 64 | In: ICML. 2015, pp. 1613-1622. 65 | URL: https://arxiv.org/abs/1505.05424 66 | 67 | Examples 68 | -------- 69 | >>> import torch 70 | >>> from torch import nn 71 | >>> 72 | >>> from bayestorch.distributions import LogScaleNormal, SoftplusInvScaleNormal 73 | >>> from bayestorch.nn import VariationalPosteriorModule 74 | >>> 75 | >>> 76 | >>> num_mc_samples = 5 77 | >>> batch_size = 10 78 | >>> in_features = 4 79 | >>> out_features = 2 80 | >>> model = nn.Linear(in_features, out_features) 81 | >>> num_parameters = sum(parameter.numel() for parameter in model.parameters()) 82 | >>> model = VariationalPosteriorModule( 83 | ... model, 84 | ... prior_builder=LogScaleNormal, 85 | ... prior_kwargs={ 86 | ... "loc": torch.zeros(num_parameters), 87 | ... "log_scale": torch.full((num_parameters,), -1.0), 88 | ... }, 89 | ... posterior_builder=SoftplusInvScaleNormal, 90 | ... posterior_kwargs={ 91 | ... "loc": torch.zeros(num_parameters, requires_grad=True), 92 | ... "softplus_inv_scale": torch.full((num_parameters,), -7.0, requires_grad=True), 93 | ... }, 94 | ... ) 95 | >>> input = torch.rand(batch_size, in_features) 96 | >>> output = model( 97 | ... input, 98 | ... num_mc_samples=num_mc_samples, 99 | ... ) 100 | >>> outputs, kl_divs = model( 101 | ... input, 102 | ... num_mc_samples=num_mc_samples, 103 | ... return_kl_div=True, 104 | ... reduction="none", 105 | ... ) 106 | 107 | """ 108 | 109 | posterior: "Distribution" 110 | """The posterior distribution.""" 111 | 112 | # override 113 | def __init__( 114 | self, 115 | module: "Module", 116 | prior_builder: "Callable[..., Distribution]", 117 | prior_kwargs: "Dict[str, Any]", 118 | posterior_builder: "Callable[..., Distribution]", 119 | posterior_kwargs: "Dict[str, Any]", 120 | module_parameters: "Optional[Iterable[Tensor]]" = None, 121 | ) -> "None": 122 | """Initialize the object. 123 | 124 | Parameters 125 | ---------- 126 | module: 127 | The module. 128 | prior_builder: 129 | The prior builder, i.e. a callable that receives keyword 130 | arguments and returns a prior with size (batch + event) 131 | equal to the length of the 1D tensor obtained by flattening 132 | and concatenating each tensor in `module_parameters`. 133 | prior_kwargs: 134 | The keyword arguments to pass to the prior builder. 
135 | Tensor arguments are internally registered as parameters
136 | if their `requires_grad` attribute is True, as persistent
137 | buffers otherwise.
138 | posterior_builder:
139 | The posterior builder, i.e. a callable that receives
140 | keyword arguments and returns a posterior.
141 | posterior_kwargs:
142 | The keyword arguments to pass to the posterior builder.
143 | Tensor arguments are internally registered as parameters
144 | if their `requires_grad` attribute is True, as persistent
145 | buffers otherwise.
146 | module_parameters:
147 | The module parameters over which the prior and posterior
148 | are defined. Useful to selectively define a prior and a
149 | posterior over a restricted subset of submodules/parameters.
150 | Defaults to ``module.parameters()``.
151 |
152 | Raises
153 | ------
154 | ValueError
155 | If an invalid argument value is given.
156 |
157 | """
158 | super().__init__(module, prior_builder, prior_kwargs, module_parameters)
159 | self.posterior_builder = posterior_builder
160 | self.posterior_kwargs = posterior_kwargs = {
161 | k: v for k, v in posterior_kwargs.items()
162 | } # Avoid side effects
163 |
164 | # Build posterior
165 | self.posterior = self._build_distribution(
166 | "posterior", posterior_builder, posterior_kwargs
167 | )
168 |
169 | # Retrieve indices of the selected parameters
170 | self._module_parameter_idxes = []
171 | all_module_parameters = list(module.parameters())
172 | for parameter in self.module_parameters:
173 | for i, x in enumerate(all_module_parameters):
174 | if parameter is x:
175 | self._module_parameter_idxes.append(i)
176 | break
177 |
178 | # Log Kullback-Leibler divergence warning only once
179 | self._log_kl_div_warning = True
180 |
181 | # override
182 | def named_parameters(
183 | self,
184 | *args: "Any",
185 | include_all: "bool" = True,
186 | **kwargs: "Any",
187 | ) -> "Iterator[Tuple[str, Parameter]]":
188 | """Return the named parameters.
189 |
190 | Parameters
191 | ----------
192 | include_all:
193 | True to include all the named parameters,
194 | False to include only those over which the
195 | prior/posterior is defined.
196 |
197 | Returns
198 | -------
199 | The named parameters.
200 | 201 | """ 202 | if include_all: 203 | return ( 204 | (k, v) 205 | for k, v in super(PriorModule, self).named_parameters(*args, **kwargs) 206 | if not any(v is parameter for parameter in self.module_parameters) 207 | ) 208 | return ( 209 | (k, v) 210 | for k, v in super(PriorModule, self).named_parameters(*args, **kwargs) 211 | if not any(v is parameter for parameter in self.module.parameters()) 212 | ) 213 | 214 | # override 215 | def state_dict( 216 | self, 217 | *args, 218 | destination: "_T_destination" = None, 219 | prefix: "str" = "", 220 | keep_vars: "bool" = False, 221 | ) -> "_T_destination": 222 | result = super().state_dict( 223 | destination=destination, prefix=prefix, keep_vars=True 224 | ) 225 | for k, v in list(result.items()): 226 | if any(v is parameter for parameter in self.module_parameters): 227 | result.pop(k) 228 | elif not keep_vars: 229 | result[k] = v.detach() 230 | return result 231 | 232 | # override 233 | def load_state_dict( 234 | self, 235 | state_dict: "Dict[str, Any]", 236 | strict: "bool" = True, 237 | ) -> "_IncompatibleKeys": 238 | parameter_names = [ 239 | f"module.{name}" for name, _ in self.module.named_parameters() 240 | ] 241 | for idx, parameter in zip(self._module_parameter_idxes, self.module_parameters): 242 | state_dict[parameter_names[idx]] = parameter 243 | result = super().load_state_dict(state_dict, strict) 244 | return result 245 | 246 | # override 247 | def forward( 248 | self, 249 | *args: "Any", 250 | num_mc_samples: "int" = 1, 251 | return_kl_div: "bool" = False, 252 | exact_kl_div: "bool" = False, 253 | reduction: "str" = "mean", 254 | **kwargs: "Any", 255 | ) -> "Union[Any, Tuple[Any, Tensor]]": 256 | """Forward pass. 257 | 258 | In the following, let `N` denote the number of Monte Carlo samples, 259 | `B = {B_1, ..., B_k}` the batch shape, and `O = {O_1, ..., O_m}` 260 | the shape of a leaf value of the underlying module output (can be 261 | a nested tensor). 262 | 263 | Parameters 264 | ---------- 265 | args: 266 | The positional arguments to pass to the underlying module. 267 | num_mc_samples: 268 | The number of Monte Carlo samples. 269 | return_kl_div: 270 | True to additionally return the Kullback-Leibler divergence of 271 | the prior from the posterior (usually required during training), 272 | False otherwise. 273 | exact_kl_div: 274 | True to use the exact Kullback-Leibler divergence of the prior 275 | from the posterior (if a closed-form expression exists), 276 | False to use Monte Carlo approximation. 277 | reduction: 278 | The reduction to apply to the leaf values of the underlying 279 | module output and to the Kullback-Leibler divergence (if 280 | `return_kl_div` is True) across Monte Carlo samples. 281 | Must be one of the following: 282 | - "none": no reduction is applied; 283 | - "mean": the leaf values and the Kullback-Leibler divergence 284 | are averaged across Monte Carlo samples. 285 | kwargs: 286 | The keyword arguments to pass to the underlying module. 287 | 288 | Returns 289 | ------- 290 | - The output, shape of a leaf value: ``[N, *B, *O]`` 291 | if `reduction` is "none" , ``[*B, *O]`` otherwise; 292 | - if `return_kl_div` is True, the Kullback-Leibler 293 | divergence of the prior from the posterior, shape: 294 | ``[N]`` if `reduction` is "none" , ``[]`` otherwise. 295 | 296 | Raises 297 | ------ 298 | ValueError 299 | If an invalid argument value is given. 
300 | 301 | Warnings 302 | -------- 303 | High memory usage is to be expected as `num_mc_samples - 1` 304 | replicas of the module must be maintained internally. 305 | 306 | """ 307 | if num_mc_samples < 1 or not float(num_mc_samples).is_integer(): 308 | raise ValueError( 309 | f"`num_mc_samples` ({num_mc_samples}) must be in the integer interval [1, inf)" 310 | ) 311 | if reduction not in ["none", "mean"]: 312 | raise ValueError( 313 | f"`reduction` ({reduction}) must be one of {['none', 'mean']}" 314 | ) 315 | 316 | if not torch.is_grad_enabled(): 317 | return self._fast_forward( 318 | *args, 319 | num_mc_samples=num_mc_samples, 320 | return_kl_div=return_kl_div, 321 | exact_kl_div=exact_kl_div, 322 | reduction=reduction, 323 | **kwargs, 324 | ) 325 | 326 | # Sample new particles 327 | new_particles = self.posterior.rsample((num_mc_samples,)) 328 | 329 | # Replicate module (one replica for each Monte Carlo sample) 330 | replicas = [self.module] + [ 331 | copy.deepcopy(self.module) for _ in range(num_mc_samples - 1) 332 | ] 333 | 334 | # Inject sampled particles 335 | new_particles = new_particles.flatten() 336 | start_idx = 0 337 | for replica in replicas: 338 | replica_parameters = list(replica.parameters()) 339 | module_parameters = [ 340 | replica_parameters[idx] for idx in self._module_parameter_idxes 341 | ] 342 | for parameter in module_parameters: 343 | end_idx = start_idx + parameter.numel() 344 | new_parameter = new_particles[start_idx:end_idx].reshape_as(parameter) 345 | parameter.detach_().requires_grad_(False).copy_(new_parameter) 346 | start_idx = end_idx 347 | 348 | # Forward pass 349 | outputs = [replica(*args, **kwargs) for replica in replicas] 350 | if reduction == "none": 351 | outputs = nested_apply(torch.stack, outputs) 352 | elif reduction == "mean": 353 | outputs = nested_apply( 354 | lambda inputs, dim: torch.mean(torch.stack(inputs, dim), dim), outputs 355 | ) 356 | 357 | if not return_kl_div: 358 | return outputs 359 | 360 | # Compute Kullback-Leibler divergence 361 | kl_divs = None 362 | if exact_kl_div: 363 | try: 364 | kl_divs = kl_divergence(self.posterior, self.prior) 365 | if reduction == "none": 366 | kl_divs = kl_divs.expand(num_mc_samples) 367 | except NotImplementedError: 368 | kl_divs = None 369 | if self._log_kl_div_warning: 370 | _LOGGER.warning( 371 | "Could not compute exact Kullback-Leibler divergence, " 372 | "reverting to Monte Carlo approximation" 373 | ) 374 | self._log_kl_div_warning = False 375 | if kl_divs is None: 376 | new_particles = new_particles.reshape(-1, *self.posterior.event_shape) 377 | log_posteriors = self.posterior.log_prob(new_particles) 378 | log_priors = self.prior.log_prob(new_particles) 379 | kl_divs = log_posteriors - log_priors 380 | if reduction == "mean": 381 | kl_divs = kl_divs.mean() 382 | 383 | return outputs, kl_divs 384 | 385 | # This implementation does not require copying the module 386 | # and can be used when gradient tracking is disabled 387 | def _fast_forward( 388 | self, 389 | *args: "Any", 390 | num_mc_samples: "int" = 1, 391 | return_kl_div: "bool" = False, 392 | exact_kl_div: "bool" = False, 393 | reduction: "str" = "mean", 394 | **kwargs: "Any", 395 | ) -> "Union[Any, Tuple[Any, Tensor]]": 396 | kl_divs = [] 397 | if return_kl_div and exact_kl_div: 398 | try: 399 | kl_divs = kl_divergence(self.posterior, self.prior) 400 | if reduction == "none": 401 | kl_divs = kl_divs.expand(num_mc_samples) 402 | except NotImplementedError: 403 | kl_divs = [] 404 | if self._log_kl_div_warning: 405 | _LOGGER.warning( 
406 | "Could not compute exact Kullback-Leibler divergence, " 407 | "reverting to Monte Carlo approximation" 408 | ) 409 | self._log_kl_div_warning = False 410 | 411 | outputs = [] 412 | for _ in range(num_mc_samples): 413 | # Sample new particle 414 | new_particle = self.posterior.rsample() 415 | 416 | # Inject sampled particle 417 | start_idx = 0 418 | for parameter in self.module_parameters: 419 | end_idx = start_idx + parameter.numel() 420 | new_parameter = new_particle[start_idx:end_idx].reshape_as(parameter) 421 | parameter.detach_().requires_grad_(False).copy_(new_parameter) 422 | start_idx = end_idx 423 | 424 | # Forward pass 425 | output = self.module(*args, **kwargs) 426 | outputs.append(output) 427 | 428 | if isinstance(kl_divs, Tensor) or not return_kl_div: 429 | continue 430 | 431 | # Compute Kullback-Leibler divergence 432 | log_posterior = self.posterior.log_prob(new_particle) 433 | log_prior = self.prior.log_prob(new_particle) 434 | kl_div = log_posterior - log_prior 435 | kl_divs.append(kl_div) 436 | 437 | if reduction == "none": 438 | outputs = nested_apply(torch.stack, outputs) 439 | elif reduction == "mean": 440 | outputs = nested_apply( 441 | lambda inputs, dim: torch.mean(torch.stack(inputs, dim), dim), outputs 442 | ) 443 | 444 | if not return_kl_div: 445 | return outputs 446 | 447 | if isinstance(kl_divs, list): 448 | kl_divs = torch.stack(kl_divs) 449 | if reduction == "mean": 450 | kl_divs = kl_divs.mean() 451 | 452 | return outputs, kl_divs 453 | 454 | # override 455 | def _apply(self, *args: "Any", **kwargs: "Any") -> "_T": 456 | super()._apply(*args, **kwargs) 457 | 458 | # Rebuild posterior using updated parameters/buffers 459 | # (`_apply` might create copies of parameters/buffers, 460 | # therefore references within the posterior are lost) 461 | self.posterior = self._build_distribution( 462 | "posterior", self.posterior_builder, self.posterior_kwargs 463 | ) 464 | 465 | return self 466 | 467 | # override 468 | def __repr__(self) -> "str": 469 | return ( 470 | f"{type(self).__name__}" 471 | f"(module: {self.module}, " 472 | f"prior: {self.prior}, " 473 | f"posterior: {self.posterior}, " 474 | f"module_parameters: {sum(parameter.numel() for parameter in self.module_parameters)})" 475 | ) 476 | --------------------------------------------------------------------------------