├── .gitignore ├── LICENSE.md ├── README.md ├── blob ├── regression_ml.png └── regression_samples.png ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── _autosummary │ ├── tyxe.bnn.rst │ ├── tyxe.guides.rst │ ├── tyxe.likelihoods.rst │ ├── tyxe.poutine.handlers.rst │ ├── tyxe.poutine.reparameterization_messengers.rst │ ├── tyxe.poutine.rst │ ├── tyxe.poutine.selective_messengers.rst │ ├── tyxe.priors.rst │ ├── tyxe.rst │ └── tyxe.util.rst │ ├── _templates │ ├── autosummary │ │ └── class.rst │ ├── module.rst_t │ ├── package.rst_t │ └── toc.rst_t │ ├── conf.py │ ├── index.rst │ ├── tutorials.rst │ ├── tyxe.bnn.rst │ ├── tyxe.guides.rst │ ├── tyxe.likelihoods.rst │ ├── tyxe.poutine.handlers.rst │ ├── tyxe.poutine.reparameterization_messengers.rst │ ├── tyxe.poutine.rst │ ├── tyxe.poutine.selective_messengers.rst │ ├── tyxe.priors.rst │ ├── tyxe.rst │ └── tyxe.util.rst ├── environment.yml ├── examples ├── gnn.py ├── nerf.py ├── resnet.py └── vcl.py ├── notebooks ├── regression.ipynb └── viz_nerf.ipynb ├── setup.py ├── tests ├── test_bnn.py ├── test_guides.py ├── test_likelihoods.py ├── test_poutine.py └── test_priors.py └── tyxe ├── __init__.py ├── bnn.py ├── guides.py ├── likelihoods.py ├── poutine ├── __init__.py ├── handlers.py ├── reparameterization_messengers.py └── selective_messengers.py ├── priors.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | .ipynb_checkpoints/ 4 | .idea/ 5 | *.egg-info/ 6 | docs/build/ 7 | env/ 8 | .coverage 9 | data/ -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Hippolyt Ritter, Theofanis Karaletsos 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to 
deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TyXe: Pyro-based BNNs for Pytorch users 2 | 3 | TyXe aims to simplify the process of turning [Pytorch](https://pytorch.org) neural networks into Bayesian neural networks by 4 | leveraging the model definition and inference capabilities of [Pyro](https://pyro.ai). 5 | Our core design principle is to cleanly separate the construction of the neural architecture, prior, inference distribution 6 | and likelihood, enabling a flexible workflow where each component can be exchanged independently.
7 | Defining a BNN in TyXe takes as little as 5 lines of code: 8 | ``` 9 | net = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 1)) 10 | prior = tyxe.priors.IIDPrior(dist.Normal(0, 1)) 11 | likelihood = tyxe.likelihoods.HomoskedasticGaussian(scale=0.1) 12 | inference = tyxe.guides.AutoNormal 13 | bnn = tyxe.VariationalBNN(net, prior, likelihood, inference) 14 | ``` 15 | 16 | In the following, we assume that you (roughly) know what a BNN is mathematically. 17 | 18 | 19 | ## Motivating example 20 | Standard neural networks give us a single function that fits the data, but many different ones are typically plausible. 21 | With only a single fit, we don't know for what inputs the model is 'certain' (because there is training data nearby) and 22 | where it is uncertain. 23 | 24 | | ![ML](blob/regression_ml.png) | ![Samples](blob/regression_samples.png) | 25 | |:---:|:---:| 26 | | Maximum likelihood fit | Posterior samples | 27 | 28 | Implementing the former can be achieved easily in a few lines of Pytorch code, but training a BNN that gives a 29 | distribution over different fits is typically more complicated and is specifically what we aim to simplify. 30 | 31 | ## Training 32 | 33 | Constructing a BNN object has been shown in the example above. 
34 | For fitting the posterior approximation, we provide a high-level `.fit` method similar to libraries such as scikit-learn 35 | or Keras: 36 | 37 | ``` 38 | optim = pyro.optim.Adam({"lr": 1e-3}) 39 | bnn.fit(data_loader, optim, num_epochs) 40 | ``` 41 | 42 | ## Prediction & evaluation 43 | 44 | Further, we provide `.predict` and `.evaluate` methods, which make predictions based on multiple samples from the approximate posterior, average them based on the observation model, and return an error measure and log likelihoods: 45 | ``` 46 | predictions = bnn.predict(x_test, num_predictions=num_samples) 47 | error, log_likelihood = bnn.evaluate(x_test, y_test, num_predictions=num_samples) 48 | ``` 49 | 50 | ## Local reparameterization 51 | 52 | We implement [local reparameterization](https://arxiv.org/abs/1506.02557) for factorized Gaussians as a poutine, which reduces gradient noise during training. 53 | Since it is a poutine, it can be enabled during both training and prediction with a context manager: 54 | ``` 55 | with tyxe.poutine.local_reparameterization(): 56 | bnn.fit(data_loader, optim, num_epochs) 57 | bnn.predict(x_test, num_predictions=num_predictions) 58 | ``` 59 | At the moment, this poutine does not work with pyro's own `AutoNormal` and `AutoDiagonalNormal` guides, since those draw the weights from a Delta distribution. Use `tyxe.guides.AutoNormal` as your guide instead. 60 | 61 | ## MCMC 62 | 63 | We provide a unified interface to pyro's MCMC implementations: simply use the `tyxe.MCMC_BNN` class and pass a kernel in place of the guide: 64 | ``` 65 | kernel = pyro.infer.mcmc.NUTS 66 | bnn = tyxe.MCMC_BNN(net, prior, likelihood, kernel) 67 | ``` 68 | Any parameters that pyro's `MCMC` class accepts can be passed through the keyword arguments of the `.fit` method. 69 | 70 | ## Continual learning 71 | 72 | Because our design cleanly separates the prior from the guide, architecture and likelihood, it is easy to update the prior in a continual setting.
73 | For example, you can construct a `tyxe.priors.DictPrior` by extracting the distributions over all weights and biases from a `tyxe.guides.AutoNormal` instance using its `get_detached_distributions` method, and pass it to `bnn.update_prior` to implement [Variational Continual Learning](https://arxiv.org/abs/1710.10628) in a few lines of code. 74 | See `examples/vcl.py` for a basic example on split-MNIST and split-CIFAR. 75 | 76 | ## Network architectures 77 | 78 | We don't implement any layer classes. 79 | You construct your network in Pytorch and then turn it into a BNN, which makes it easy to apply the same prior and inference strategies to different neural networks. 80 | 81 | ## Inference 82 | 83 | For inference, `tyxe.guides` mainly provides an equivalent to pyro's `AutoDiagonalNormal` that is compatible with local reparameterization. 84 | This module also contains a few helper functions for initializing Gaussian mean parameters, e.g. to the values of a pre-trained network. 85 | It should be possible to use any of pyro's autoguides for variational inference. 86 | See `examples/resnet.py` for a few options, including initialization to pre-trained weights. 87 | 88 | ## Priors 89 | 90 | The priors can be found in `tyxe.priors`. 91 | We currently only support placing priors on the parameters. 92 | Through the expose and hide arguments of the constructor, you can specify the layers, layer types and individual parameters over which you want to place a prior. 93 | This is useful, for example, for learning the parameters of BatchNorm layers deterministically. 94 | 95 | 96 | ## Likelihoods 97 | 98 | `tyxe.likelihoods` contains classes that wrap the most common `torch.distributions` for specifying the noise model of the data. 99 | 100 | 101 | # Installation 102 | 103 | We recommend installing TyXe using conda with the provided `environment.yml`, which also installs all the dependencies for the examples except for PyTorch3D, which needs to be added manually.
104 | The environment assumes that you are using CUDA 11.6. If this is not the case, change the `cudatoolkit` and `dgl` entries in `environment.yml` accordingly before running: 105 | ``` 106 | conda env create -f environment.yml 107 | conda activate tyxe 108 | pip install -e . 109 | ``` 110 | 111 | ## Citation 112 | If you use TyXe, please consider citing: 113 | ``` 114 | @inproceedings{ritter2022tyxe, 115 | title={Ty{X}e: Pyro-based {B}ayesian neural nets for {P}ytorch}, 116 | author={Ritter, Hippolyt and Karaletsos, Theofanis}, 117 | booktitle={Proceedings of Machine Learning and Systems}, 118 | year={2022} 119 | } 120 | ``` 121 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | APIDOC = sphinx-apidoc 9 | SPHINXPROJ = TyXe 10 | SOURCEDIR = source 11 | TEMPLATEDIR = "$(SOURCEDIR)/_templates" 12 | PROJECTDIR = ../tyxe 13 | BUILDDIR = build 14 | 15 | # Put it first so that "make" without argument is like "make help".
16 | help: 17 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 18 | 19 | .PHONY: help Makefile 20 | 21 | apidoc: 22 | $(APIDOC) --templatedir "$(TEMPLATEDIR)" -fMeT -o "$(SOURCEDIR)" "$(PROJECTDIR)" 23 | 24 | # hack to avoid trailing module/package headers from: 25 | # https://stackoverflow.com/questions/21003122/sphinx-apidoc-section-titles-for-python-module-package-names 26 | docs: 27 | $(MAKE) clean 28 | $(MAKE) apidoc 29 | $(MAKE) html 30 | 31 | clean: 32 | @echo "Removing everything under 'build' and 'source/generated'.." 33 | @rm -rf $(BUILDDIR)/html/ $(BUILDDIR)/doctrees $(SOURCEDIR)/generated $(SOURCEDIR)/_autosummary 34 | 35 | # Catch-all target: route all unknown targets to Sphinx using the new 36 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 37 | %: Makefile 38 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 39 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx_rtd_theme 3 | pyro-ppl>=1.4.0 4 | torch>=1.7.0 5 | -------------------------------------------------------------------------------- /docs/source/_autosummary/tyxe.bnn.rst: -------------------------------------------------------------------------------- 1 | tyxe.bnn 2 | ======== 3 | 4 | .. automodule:: tyxe.bnn 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | .. rubric:: Classes 17 | 18 | .. autosummary:: 19 | 20 | GuidedBNN 21 | MCMC_BNN 22 | PytorchBNN 23 | VariationalBNN 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /docs/source/_autosummary/tyxe.guides.rst: -------------------------------------------------------------------------------- 1 | tyxe.guides 2 | =========== 3 | 4 | .. automodule:: tyxe.guides 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | .. rubric:: Functions 13 | 14 | .. autosummary:: 15 | 16 | init_to_constant 17 | init_to_normal 18 | init_to_normal_kaiming 19 | init_to_normal_radford 20 | init_to_normal_xavier 21 | init_to_sample 22 | init_to_zero 23 | 24 | 25 | 26 | 27 | 28 | .. rubric:: Classes 29 | 30 | .. 
autosummary:: 31 | 32 | AutoNormal 33 | PretrainedInitializer 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /docs/source/_autosummary/tyxe.likelihoods.rst: -------------------------------------------------------------------------------- 1 | tyxe.likelihoods 2 | ================ 3 | 4 | .. automodule:: tyxe.likelihoods 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | .. rubric:: Functions 13 | 14 | .. autosummary:: 15 | 16 | inverse_softplus 17 | 18 | 19 | 20 | 21 | 22 | .. rubric:: Classes 23 | 24 | .. autosummary:: 25 | 26 | Bernoulli 27 | Categorical 28 | Gaussian 29 | HeteroskedasticGaussian 30 | HomoskedasticGaussian 31 | Likelihood 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /docs/source/_autosummary/tyxe.poutine.handlers.rst: -------------------------------------------------------------------------------- 1 | tyxe.poutine.handlers 2 | ===================== 3 | 4 | .. automodule:: tyxe.poutine.handlers 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | .. rubric:: Functions 13 | 14 | .. autosummary:: 15 | 16 | flipout 17 | local_reparameterization 18 | selective_mask 19 | selective_scale 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /docs/source/_autosummary/tyxe.poutine.reparameterization_messengers.rst: -------------------------------------------------------------------------------- 1 | tyxe.poutine.reparameterization\_messengers 2 | =========================================== 3 | 4 | .. automodule:: tyxe.poutine.reparameterization_messengers 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | .. rubric:: Classes 17 | 18 | .. 
autosummary:: 19 | 20 | FlipoutMessenger 21 | LocalReparameterizationMessenger 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /docs/source/_autosummary/tyxe.poutine.rst: -------------------------------------------------------------------------------- 1 | tyxe.poutine 2 | ============ 3 | 4 | .. automodule:: tyxe.poutine 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Modules 25 | 26 | .. autosummary:: 27 | :toctree: 28 | :recursive: 29 | 30 | tyxe.poutine.handlers 31 | tyxe.poutine.reparameterization_messengers 32 | tyxe.poutine.selective_messengers 33 | 34 | -------------------------------------------------------------------------------- /docs/source/_autosummary/tyxe.poutine.selective_messengers.rst: -------------------------------------------------------------------------------- 1 | tyxe.poutine.selective\_messengers 2 | ================================== 3 | 4 | .. automodule:: tyxe.poutine.selective_messengers 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | .. rubric:: Classes 17 | 18 | .. autosummary:: 19 | 20 | SelectiveMaskMessenger 21 | SelectiveMixin 22 | SelectiveScaleMessenger 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/source/_autosummary/tyxe.priors.rst: -------------------------------------------------------------------------------- 1 | tyxe.priors 2 | =========== 3 | 4 | .. automodule:: tyxe.priors 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | .. rubric:: Classes 17 | 18 | .. 
autosummary:: 19 | 20 | DictPrior 21 | IIDPrior 22 | LambdaPrior 23 | LayerwiseNormalPrior 24 | Prior 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/source/_autosummary/tyxe.rst: -------------------------------------------------------------------------------- 1 | tyxe 2 | ==== 3 | 4 | .. automodule:: tyxe 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Modules 25 | 26 | .. autosummary:: 27 | :toctree: 28 | :recursive: 29 | 30 | tyxe.bnn 31 | tyxe.guides 32 | tyxe.likelihoods 33 | tyxe.poutine 34 | tyxe.priors 35 | tyxe.util 36 | 37 | -------------------------------------------------------------------------------- /docs/source/_autosummary/tyxe.util.rst: -------------------------------------------------------------------------------- 1 | tyxe.util 2 | ========= 3 | 4 | .. automodule:: tyxe.util 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | .. rubric:: Functions 13 | 14 | .. autosummary:: 15 | 16 | calculate_prior_std 17 | deep_hasattr 18 | fan_in_fan_out 19 | named_pyro_samples 20 | prod 21 | pyro_sample_sites 22 | to_pyro_module 23 | to_pyro_module_ 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. autoclass:: {{ name }} 9 | :inherited-members: 10 | :members: 11 | 12 | .. 
autogenerated from source/_templates/autosummary/class.rst 13 | -------------------------------------------------------------------------------- /docs/source/_templates/module.rst_t: -------------------------------------------------------------------------------- 1 | {%- if show_headings %} 2 | {{- basename | e | heading }} 3 | 4 | {% endif -%} 5 | .. automodule:: {{ qualname }} 6 | {%- for option in automodule_options %} 7 | :{{ option }}: 8 | {%- endfor %} 9 | 10 | -------------------------------------------------------------------------------- /docs/source/_templates/package.rst_t: -------------------------------------------------------------------------------- 1 | {%- macro automodule(modname, options) -%} 2 | .. automodule:: {{ modname }} 3 | {%- for option in options %} 4 | :{{ option }}: 5 | {%- endfor %} 6 | {%- endmacro %} 7 | 8 | {%- macro toctree(docnames) -%} 9 | .. toctree:: 10 | :maxdepth: {{ maxdepth }} 11 | {% for docname in docnames %} 12 | {{ docname }} 13 | {%- endfor %} 14 | {%- endmacro %} 15 | 16 | {%- if is_namespace %} 17 | {{- pkgname | e | heading }} 18 | {% else %} 19 | {{- pkgname | e | heading }} 20 | {% endif %} 21 | 22 | {%- if modulefirst and not is_namespace %} 23 | {{ automodule(pkgname, automodule_options) }} 24 | {% endif %} 25 | 26 | {%- if subpackages %} 27 | {{ toctree(subpackages) }} 28 | {% endif %} 29 | 30 | {%- if submodules %} 31 | {% if separatemodules %} 32 | {{ toctree(submodules) }} 33 | {% else %} 34 | {%- for submodule in submodules %} 35 | {% if show_headings %} 36 | {{- submodule | e | heading(2) }} 37 | {% endif %} 38 | {{ automodule(submodule, automodule_options) }} 39 | {% endfor %} 40 | {%- endif %} 41 | {%- endif %} 42 | 43 | {%- if not modulefirst and not is_namespace %} 44 | {{ automodule(pkgname, automodule_options) }} 45 | {% endif %} 46 | -------------------------------------------------------------------------------- /docs/source/_templates/toc.rst_t: 
-------------------------------------------------------------------------------- 1 | {{ header | heading }} 2 | 3 | .. toctree:: 4 | :maxdepth: {{ maxdepth }} 5 | {% for docname in docnames %} 6 | {{ docname }} 7 | {%- endfor %} 8 | 9 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | d = os.path.dirname(os.path.abspath(__file__)) 16 | sys.path.insert(0, os.path.realpath(os.path.join(d, '../..'))) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = 'TyXe' 22 | copyright = '2021, Hippolyt Ritter, Theofanis Karaletsos' 23 | author = 'Hippolyt Ritter, Theofanis Karaletsos' 24 | 25 | # The full version, including alpha/beta/rc tags 26 | release = '0.0.1' 27 | 28 | 29 | # -- General configuration --------------------------------------------------- 30 | 31 | import sphinx_rtd_theme 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 35 | # ones. 
36 | extensions = [ 37 | 'sphinx.ext.intersphinx', # 38 | 'sphinx.ext.todo', # 39 | 'sphinx.ext.mathjax', # 40 | 'sphinx.ext.ifconfig', # 41 | 'sphinx.ext.viewcode', # 42 | 'sphinx.ext.githubpages', # 43 | 'sphinx.ext.graphviz', # 44 | 'sphinx.ext.autodoc', # 45 | 'sphinx.ext.autosummary', # 46 | 'sphinx.ext.doctest' 47 | ] 48 | 49 | add_module_names = False 50 | autodoc_inherit_docstrings = False 51 | autosummary_generate = True 52 | autodoc_inherit_docstrings = False 53 | 54 | # Add any paths that contain templates here, relative to this directory. 55 | templates_path = ['_templates'] 56 | 57 | # List of patterns, relative to source directory, that match files and 58 | # directories to ignore when looking for source files. 59 | # This pattern also affects html_static_path and html_extra_path. 60 | exclude_patterns = [] 61 | 62 | 63 | # -- Options for HTML output ------------------------------------------------- 64 | 65 | # The theme to use for HTML and HTML Help pages. See the documentation for 66 | # a list of builtin themes. 67 | # 68 | html_theme = 'sphinx_rtd_theme' 69 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 70 | 71 | # Add any paths that contain custom static files (such as style sheets) here, 72 | # relative to this directory. They are copied after the builtin static files, 73 | # so a file named "default.css" will overwrite the builtin "default.css". 74 | html_static_path = ['_static'] 75 | 76 | 77 | # -- Extension configuration ------------------------------------------------- 78 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/TyXe-BDL/TyXe 2 | 3 | 4 | TyXe Documentation 5 | ================== 6 | 7 | 8 | TyXe website and docs. Under construction! 9 | 10 | .. 
toctree:: 11 | :maxdepth: 2 12 | :caption: Tutorials 13 | 14 | tutorials 15 | 16 | 17 | API Reference 18 | ============= 19 | 20 | .. toctree:: 21 | 22 | .. autosummary:: 23 | :toctree: _autosummary 24 | :recursive: 25 | :caption: API Reference 26 | 27 | tyxe 28 | tyxe.poutine 29 | 30 | 31 | Indices and tables 32 | ================== 33 | 34 | * :ref:`genindex` 35 | * :ref:`search` 36 | -------------------------------------------------------------------------------- /docs/source/tutorials.rst: -------------------------------------------------------------------------------- 1 | TyXe Tutorials 2 | ============== 3 | 4 | Coming soon! 5 | -------------------------------------------------------------------------------- /docs/source/tyxe.bnn.rst: -------------------------------------------------------------------------------- 1 | tyxe.bnn 2 | ======== 3 | 4 | .. automodule:: tyxe.bnn 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/tyxe.guides.rst: -------------------------------------------------------------------------------- 1 | tyxe.guides 2 | =========== 3 | 4 | .. automodule:: tyxe.guides 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/tyxe.likelihoods.rst: -------------------------------------------------------------------------------- 1 | tyxe.likelihoods 2 | ================ 3 | 4 | .. automodule:: tyxe.likelihoods 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/tyxe.poutine.handlers.rst: -------------------------------------------------------------------------------- 1 | tyxe.poutine.handlers 2 | ===================== 3 | 4 | .. 
automodule:: tyxe.poutine.handlers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/tyxe.poutine.reparameterization_messengers.rst: -------------------------------------------------------------------------------- 1 | tyxe.poutine.reparameterization\_messengers 2 | =========================================== 3 | 4 | .. automodule:: tyxe.poutine.reparameterization_messengers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/tyxe.poutine.rst: -------------------------------------------------------------------------------- 1 | tyxe.poutine 2 | ============ 3 | 4 | .. automodule:: tyxe.poutine 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | 10 | .. toctree:: 11 | :maxdepth: 4 12 | 13 | tyxe.poutine.handlers 14 | tyxe.poutine.reparameterization_messengers 15 | tyxe.poutine.selective_messengers 16 | -------------------------------------------------------------------------------- /docs/source/tyxe.poutine.selective_messengers.rst: -------------------------------------------------------------------------------- 1 | tyxe.poutine.selective\_messengers 2 | ================================== 3 | 4 | .. automodule:: tyxe.poutine.selective_messengers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/tyxe.priors.rst: -------------------------------------------------------------------------------- 1 | tyxe.priors 2 | =========== 3 | 4 | .. automodule:: tyxe.priors 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/tyxe.rst: -------------------------------------------------------------------------------- 1 | tyxe 2 | ==== 3 | 4 | .. 
automodule:: tyxe 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | .. toctree:: 10 | :maxdepth: 4 11 | 12 | tyxe.poutine 13 | 14 | 15 | .. toctree:: 16 | :maxdepth: 4 17 | 18 | tyxe.bnn 19 | tyxe.guides 20 | tyxe.likelihoods 21 | tyxe.priors 22 | tyxe.util 23 | -------------------------------------------------------------------------------- /docs/source/tyxe.util.rst: -------------------------------------------------------------------------------- 1 | tyxe.util 2 | ========= 3 | 4 | .. automodule:: tyxe.util 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: tyxe 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - dglteam 6 | - defaults 7 | dependencies: 8 | - python=3.9 9 | - pytorch=1.12.0 10 | - torchvision=0.13.0 11 | - cudatoolkit=11.6 12 | - numpy=1.22.3 13 | - scipy=1.8.1 14 | - tqdm=4.64.0 15 | - matplotlib=3.5.1 16 | - pytest=7.1.2 17 | - dgl=0.8.2 18 | - requests=2.28.1 19 | - psutil=5.9.0 20 | - pip 21 | - pip: 22 | - pyro-ppl==1.8.0 23 | -------------------------------------------------------------------------------- /examples/gnn.py: -------------------------------------------------------------------------------- 1 | """Bayesian Graph Neural Net, based on the DGL tutorial at https://docs.dgl.ai/tutorials/models/1_gnn/1_gcn.html 2 | Also calculates expected calibration error as a metric """ 3 | import argparse 4 | from functools import partial 5 | 6 | import dgl 7 | import dgl.function as fn 8 | from dgl.data import citation_graph as citegrh 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | import pyro 14 | import pyro.distributions as dist 15 | 16 | 17 | import tyxe 18 | 19 | 20 | gcn_msg = fn.copy_src(src='h', out='m') 21 | gcn_reduce = fn.sum(msg='m', out='h') 22 | 23 | 24 | class GCNLayer(nn.Module): 25 | def __init__(self, 
in_feats, out_feats): 26 | super(GCNLayer, self).__init__() 27 | self.linear = nn.Linear(in_feats, out_feats) 28 | 29 | def forward(self, g, feature): 30 | # Creating a local scope so that all the stored ndata and edata 31 | # (such as the `'h'` ndata below) are automatically popped out 32 | # when the scope exits. 33 | with g.local_scope(): 34 | g.ndata['h'] = feature 35 | g.update_all(gcn_msg, gcn_reduce) 36 | h = g.ndata['h'] 37 | return self.linear(h) 38 | 39 | 40 | class Net(nn.Module): 41 | def __init__(self): 42 | super(Net, self).__init__() 43 | self.layer1 = GCNLayer(1433, 16) 44 | self.layer2 = GCNLayer(16, 7) 45 | 46 | def forward(self, g, features): 47 | x = torch.relu(self.layer1(g, features)) 48 | x = self.layer2(g, x) 49 | return x 50 | 51 | 52 | def load_cora_data(): 53 | data = citegrh.load_cora() 54 | features = torch.FloatTensor(data.features) 55 | labels = torch.LongTensor(data.labels) 56 | train_mask = torch.BoolTensor(data.train_mask) 57 | test_mask = torch.BoolTensor(data.test_mask) 58 | val_mask = torch.BoolTensor(data.val_mask) 59 | g = dgl.from_networkx(data.graph) 60 | return g, features, labels, train_mask, test_mask, val_mask 61 | 62 | 63 | def calc_ece(probs, labels, num_bins): 64 | maxp, predictions = probs.max(-1, keepdims=True) 65 | boundaries = torch.linspace(0, 1, num_bins+1) 66 | lower_bound, upper_bound = boundaries[:-1], boundaries[1:] 67 | in_bin = maxp.ge(lower_bound).logical_and(maxp.lt(upper_bound)).float() 68 | bin_sizes = in_bin.sum(0) 69 | correct = predictions.eq(labels.unsqueeze(-1)).float() 70 | 71 | non_empty = bin_sizes.gt(0) 72 | accs = torch.where(non_empty, correct.mul(in_bin).sum(0) / bin_sizes, torch.zeros_like(bin_sizes)) 73 | pred_probs = torch.where(non_empty, maxp.mul(in_bin).sum(0) / bin_sizes, torch.zeros_like(bin_sizes)) 74 | bin_weight = bin_sizes / bin_sizes.sum() 75 | 76 | return accs.sub(pred_probs).abs().mul(bin_weight).sum() 77 | 78 | 79 | def main(inference, lr, num_epochs, milestones): 80 | net = 
Net() 81 | 82 | g, features, labels, train_mask, test_mask, val_mask = load_cora_data() 83 | # Add edges between each node and itself to preserve old node representations 84 | g.add_edges(g.nodes(), g.nodes()) 85 | total_nodes = len(train_mask) 86 | training_nodes = train_mask.float().sum().item() 87 | 88 | prior_kwargs = {} 89 | test_samples = 1 90 | if inference == "ml": 91 | # hide everything from the prior so that every nn.Parameter becomes a PyroParam 92 | prior_kwargs.update(expose_all=False, hide_all=True) 93 | # a guide is not needed in that case 94 | guide = None 95 | elif inference == "map": 96 | guide = pyro.infer.autoguide.AutoDelta 97 | elif inference == "mean-field": 98 | guide = partial(tyxe.guides.AutoNormal, init_scale=1e-4, max_guide_scale=0.3, 99 | init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net)) 100 | test_samples = 8 101 | else: 102 | raise RuntimeError("Unreachable") 103 | prior = tyxe.priors.IIDPrior(dist.Normal(0, 1), **prior_kwargs) 104 | 105 | # the dataset size needs to be set to the **total** number of nodes, since the pyro.plate receives all nodes 106 | # as a subsample, i.e. measures the batch size as equal to the total number of nodes, so we need to set the 107 | # dataset size accordingly to achieve the correct scaling of the log likelihood 108 | obs = tyxe.likelihoods.Categorical(dataset_size=total_nodes) 109 | bnn = tyxe.VariationalBNN(net, prior, obs, guide) 110 | 111 | optim = torch.optim.Adam 112 | scheduler = pyro.optim.MultiStepLR({"optimizer": optim, "optim_args": {"lr": lr}, "milestones": milestones}) 113 | # we only have one batch of data so it can go into a single-element list. 
BNN.fit assumes that the loader is an 114 | # iterator over two-element tuples, where the first element is a single element or tuple/list that is fed into the 115 | # NN and the second element is a tensor that contains the labels 116 | loader = [((g, features), labels)] 117 | 118 | acc_list, ece_list, nll_list = [], [], [] 119 | 120 | def callback(b, i, e): 121 | errs, lls = b.evaluate((g, features), labels, num_predictions=test_samples, reduction="none") 122 | test_acc = 1 - errs[test_mask].mean().item() 123 | val_nll = -lls[val_mask].mean().item() 124 | ece = calc_ece(b.predict(g, features, num_predictions=test_samples).softmax(-1)[test_mask], 125 | labels[test_mask], 10).item() 126 | print(f"Epoch {i+1:03d} | ELBO {e/training_nodes:03.4f} | Test Acc {100 * test_acc:.1f}% |" 127 | f" ECE {100 * ece:.2f}% | Val NLL {val_nll:.4f}") 128 | 129 | scheduler.step() 130 | acc_list.append(test_acc) 131 | ece_list.append(ece) 132 | nll_list.append(val_nll) 133 | 134 | # the mask poutine is needed to only evaluate the log likelihood of the training nodes 135 | with tyxe.poutine.selective_mask(mask=train_mask, hide_all=False, expose=["likelihood.data"]): 136 | bnn.fit(loader, scheduler, num_epochs, callback, num_particles=1) 137 | 138 | min_nll_epoch = torch.tensor(nll_list).argmin().item() 139 | print(f"At lowest validation NLL (Epoch {min_nll_epoch:03d}; NLL={nll_list[min_nll_epoch]:.4f}): " 140 | f"Test accuracy {100*acc_list[min_nll_epoch]:.1f}% | ECE {100*ece_list[min_nll_epoch]:.1f}%") 141 | 142 | 143 | if __name__ == "__main__": 144 | def list_of_ints(s): 145 | return list(map(int, s.split(","))) 146 | 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument("--inference", default="mean-field", choices=["ml", "map", "mean-field"]) 149 | parser.add_argument("--lr", default=1e-1, type=float) 150 | parser.add_argument("--milestones", default=[100, 200, 300], type=list_of_ints) 151 | parser.add_argument("--num-epochs", default=400, type=int) 152 | 
main(**vars(parser.parse_args())) 153 | -------------------------------------------------------------------------------- /examples/nerf.py: -------------------------------------------------------------------------------- 1 | """Adapted version of the NeRF example from pytorch3d. Includes a Bayesian version of NeRF based on TyXe. 2 | The original notebook is available at: https://github.com/facebookresearch/pytorch3d/blob/master/docs/tutorials/fit_simple_neural_radiance_field.ipynb""" 3 | 4 | # LICENSE FROM THE pytorch3D repo: 5 | # -------------------------------- 6 | # BSD 3-Clause License 7 | # 8 | # For PyTorch3D software 9 | # 10 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 11 | # 12 | # Redistribution and use in source and binary forms, with or without modification, 13 | # are permitted provided that the following conditions are met: 14 | # 15 | # * Redistributions of source code must retain the above copyright notice, this 16 | # list of conditions and the following disclaimer. 17 | # 18 | # * Redistributions in binary form must reproduce the above copyright notice, 19 | # this list of conditions and the following disclaimer in the documentation 20 | # and/or other materials provided with the distribution. 21 | # 22 | # * Neither the name Facebook nor the names of its contributors may be used to 23 | # endorse or promote products derived from this software without specific 24 | # prior written permission. 25 | # 26 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 27 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 28 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 29 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 30 | # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 31 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 32 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 33 | # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 34 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 35 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 36 | 37 | # END OF LICENSE 38 | 39 | 40 | import argparse 41 | from functools import partial 42 | import os 43 | import torch 44 | import matplotlib.pyplot as plt 45 | import numpy as np 46 | from tqdm import tqdm 47 | from collections import namedtuple 48 | 49 | # Data structures and functions for rendering 50 | try: 51 | from pytorch3d.io import load_objs_as_meshes 52 | from pytorch3d.structures import Volumes 53 | from pytorch3d.transforms import so3_exponential_map 54 | from pytorch3d.renderer import ( 55 | BlendParams, 56 | EmissionAbsorptionRaymarcher, 57 | FoVPerspectiveCameras, 58 | ImplicitRenderer, 59 | look_at_view_transform, 60 | MeshRasterizer, 61 | MeshRenderer, 62 | MonteCarloRaysampler, 63 | NDCGridRaysampler, 64 | PointLights, 65 | RasterizationSettings, 66 | RayBundle, 67 | ray_bundle_to_ray_points, 68 | SoftPhongShader, 69 | SoftSilhouetteShader, 70 | ) 71 | except ImportError: 72 | print("Failed to import from pytorch3d. This is not a dependency of Tyxe and needs to be installed separately. 
" 73 | "You may need to install the nightly rather than the stable version") 74 | raise 75 | 76 | import pyro 77 | import pyro.distributions as dist 78 | 79 | 80 | import tyxe 81 | 82 | 83 | # create the default data directory 84 | current_dir = os.path.dirname(os.path.realpath(__file__)) 85 | DATA_DIR = os.path.join(current_dir, "..", "data", "cow_mesh") 86 | 87 | 88 | def generate_cow_renders( 89 | num_views: int = 40, data_dir: str = DATA_DIR, azimuth_low: float = -180, azimuth_high: float = 180 90 | ): 91 | """ 92 | This function generates `num_views` renders of a cow mesh. 93 | The renders are generated from viewpoints sampled at uniformly distributed 94 | azimuth intervals. The elevation is kept constant so that the camera's 95 | vertical position coincides with the equator. 96 | For a more detailed explanation of this code, please refer to the 97 | docs/tutorials/fit_textured_mesh.ipynb notebook. 98 | Args: 99 | num_views: The number of generated renders. 100 | data_dir: The folder that contains the cow mesh files. If the cow mesh 101 | files do not exist in the folder, this function will automatically 102 | download them. 103 | Returns: 104 | cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the 105 | images are rendered. 106 | images: A tensor of shape `(num_views, height, width, 3)` containing 107 | the rendered images. 108 | silhouettes: A tensor of shape `(num_views, height, width)` containing 109 | the rendered silhouettes. 
110 | """ 111 | 112 | # set the paths 113 | 114 | # download the cow mesh if not done before 115 | cow_mesh_files = [ 116 | os.path.join(data_dir, fl) for fl in ("cow.obj", "cow.mtl", "cow_texture.png") 117 | ] 118 | if any(not os.path.isfile(f) for f in cow_mesh_files): 119 | os.makedirs(data_dir, exist_ok=True) 120 | os.system( 121 | f"wget -P {data_dir} " 122 | + "https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj" 123 | ) 124 | os.system( 125 | f"wget -P {data_dir} " 126 | + "https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl" 127 | ) 128 | os.system( 129 | f"wget -P {data_dir} " 130 | + "https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png" 131 | ) 132 | 133 | # Setup 134 | if torch.cuda.is_available(): 135 | device = torch.device("cuda:0") 136 | torch.cuda.set_device(device) 137 | else: 138 | device = torch.device("cpu") 139 | 140 | # Load obj file 141 | obj_filename = os.path.join(data_dir, "cow.obj") 142 | mesh = load_objs_as_meshes([obj_filename], device=device) 143 | 144 | # We scale normalize and center the target mesh to fit in a sphere of radius 1 145 | # centered at (0,0,0). (scale, center) will be used to bring the predicted mesh 146 | # to its original center and scale. Note that normalizing the target mesh, 147 | # speeds up the optimization but is not necessary! 148 | verts = mesh.verts_packed() 149 | N = verts.shape[0] 150 | center = verts.mean(0) 151 | scale = max((verts - center).abs().max(0)[0]) 152 | mesh.offset_verts_(-(center.expand(N, 3))) 153 | mesh.scale_verts_((1.0 / float(scale))) 154 | 155 | # Get a batch of viewing angles. 156 | elev = torch.linspace(0, 0, num_views) # keep constant 157 | azim = torch.linspace(azimuth_low, azimuth_high, num_views) + 180.0 158 | 159 | # Place a point light in front of the object. As mentioned above, the front of 160 | # the cow is facing the -z direction. 
161 | lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]]) 162 | 163 | # Initialize an OpenGL perspective camera that represents a batch of different 164 | # viewing angles. All the camera helper methods support mixed-type inputs and 165 | # broadcasting. So we can view the camera from a distance of dist=2.7, and 166 | # then specify elevation and azimuth angles for each viewpoint as tensors. 167 | R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim) 168 | cameras = FoVPerspectiveCameras(device=device, R=R, T=T) 169 | 170 | # Define the settings for rasterization and shading. Here we set the output 171 | # image to be of size 128x128. As we are rendering images for visualization 172 | # purposes only, we set faces_per_pixel=1 and blur_radius=0.0. Refer to 173 | # rasterize_meshes.py for explanations of these parameters. We also leave 174 | # bin_size and max_faces_per_bin to their default values of None, which sets 175 | # their values using heuristics and ensures that the faster coarse-to-fine 176 | # rasterization method is used. Refer to docs/notes/renderer.md for an 177 | # explanation of the difference between naive and coarse-to-fine rasterization. 178 | raster_settings = RasterizationSettings( 179 | image_size=128, blur_radius=0.0, faces_per_pixel=1 180 | ) 181 | 182 | # Create a Phong renderer by composing a rasterizer and a shader. The textured 183 | # Phong shader will interpolate the texture uv coordinates for each vertex, 184 | # sample from a texture image and apply the Phong lighting model. 185 | blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(0.0, 0.0, 0.0)) 186 | renderer = MeshRenderer( 187 | rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), 188 | shader=SoftPhongShader( 189 | device=device, cameras=cameras, lights=lights, blend_params=blend_params 190 | ), 191 | ) 192 | 193 | # Create a batch of meshes by repeating the cow mesh and associated textures.
194 | # Meshes has a useful `extend` method which allows us to do this very easily. 195 | # This also extends the textures. 196 | meshes = mesh.extend(num_views) 197 | 198 | # Render the cow mesh from each viewing angle 199 | target_images = renderer(meshes, cameras=cameras, lights=lights) 200 | 201 | # Rasterization settings for silhouette rendering 202 | sigma = 1e-4 203 | raster_settings_silhouette = RasterizationSettings( 204 | image_size=128, blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma, faces_per_pixel=50 205 | ) 206 | 207 | # Silhouette renderer 208 | renderer_silhouette = MeshRenderer( 209 | rasterizer=MeshRasterizer( 210 | cameras=cameras, raster_settings=raster_settings_silhouette 211 | ), 212 | shader=SoftSilhouetteShader(), 213 | ) 214 | 215 | # Render silhouette images. The 4th channel (index 3) of the rendering output is 216 | # the alpha/silhouette channel. 217 | silhouette_images = renderer_silhouette(meshes, cameras=cameras, lights=lights) 218 | 219 | # binary silhouettes 220 | silhouette_binary = (silhouette_images[..., 3] > 1e-4).float() 221 | 222 | return cameras, target_images[..., :3], silhouette_binary 223 | 224 | 225 | def image_grid( 226 | images, 227 | rows=None, 228 | cols=None, 229 | fill: bool = True, 230 | show_axes: bool = False, 231 | rgb: bool = True, 232 | **kwargs 233 | ): 234 | """ 235 | A utility function for plotting a grid of images. 236 | Args: 237 | images: (N, H, W, 4) array of RGBA images 238 | rows: number of rows in the grid 239 | cols: number of columns in the grid 240 | fill: boolean indicating if the space between images should be filled 241 | show_axes: boolean indicating if the axes of the plots should be visible 242 | rgb: boolean. If True, only RGB channels are plotted. 243 | If False, only the alpha channel is plotted.
244 | Returns: 245 | None 246 | """ 247 | if (rows is None) != (cols is None): 248 | raise ValueError("Specify either both rows and cols or neither.") 249 | 250 | if rows is None: 251 | rows = len(images) 252 | cols = 1 253 | 254 | gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {} 255 | fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9)) 256 | bleed = 0 257 | fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed)) 258 | 259 | for ax, im in zip(axarr.ravel(), images): 260 | if rgb: 261 | # only render RGB channels 262 | ax.imshow(im[..., :3], **kwargs) 263 | else: 264 | # only render Alpha channel 265 | ax.imshow(im[..., 3], **kwargs) 266 | if not show_axes: 267 | ax.set_axis_off() 268 | 269 | 270 | class HarmonicEmbedding(torch.nn.Module): 271 | def __init__(self, n_harmonic_functions=60, omega0=0.1): 272 | """ 273 | Given an input tensor `x` of shape [minibatch, ... , dim], 274 | the harmonic embedding layer converts each feature 275 | in `x` into a series of harmonic features `embedding` 276 | as follows: 277 | embedding[..., i*dim:(i+1)*dim] = [ 278 | sin(x[..., i]), 279 | sin(2*x[..., i]), 280 | sin(4*x[..., i]), 281 | ... 282 | sin(2**(n_harmonic_functions - 1) * x[..., i]), 283 | cos(x[..., i]), 284 | cos(2*x[..., i]), 285 | cos(4*x[..., i]), 286 | ... 287 | cos(2**(n_harmonic_functions - 1) * x[..., i]) 288 | ] 289 | 290 | Note that `x` is also premultiplied by `omega0` before 291 | evaluating the harmonic functions.
292 | """ 293 | super().__init__() 294 | self.register_buffer( 295 | 'frequencies', 296 | omega0 * (2.0 ** torch.arange(n_harmonic_functions)), 297 | ) 298 | 299 | def forward(self, x): 300 | """ 301 | Args: 302 | x: tensor of shape [..., dim] 303 | Returns: 304 | embedding: a harmonic embedding of `x` 305 | of shape [..., n_harmonic_functions * dim * 2] 306 | """ 307 | embed = (x[..., None] * self.frequencies).view(*x.shape[:-1], -1) 308 | return torch.cat((embed.sin(), embed.cos()), dim=-1) 309 | 310 | 311 | def save_io_nerf(rays_points_world, rays_densities, rays_colors, path='./data/'): 312 | os.makedirs(path, exist_ok=True) 313 | torch.save(rays_points_world.detach().to("cpu"), os.path.join(path, 'rays_points_world.pt')) 314 | torch.save(rays_densities.detach().to("cpu"), os.path.join(path, "rays_densities.pt")) 315 | torch.save(rays_colors.detach().to("cpu"), os.path.join(path, "rays_colors.pt")) 316 | 317 | 318 | class NeuralRadianceField(torch.nn.Module): 319 | def __init__(self, n_harmonic_functions=60, n_hidden_neurons=256): 320 | super().__init__() 321 | """ 322 | Args: 323 | n_harmonic_functions: The number of harmonic functions 324 | used to form the harmonic embedding of each point. 325 | n_hidden_neurons: The number of hidden units in the 326 | fully connected layers of the MLPs of the model. 327 | """ 328 | 329 | # The harmonic embedding layer converts input 3D coordinates 330 | # to a representation that is more suitable for 331 | # processing with a deep neural network. 332 | self.harmonic_embedding = HarmonicEmbedding(n_harmonic_functions) 333 | 334 | # The dimension of the harmonic embedding. 335 | embedding_dim = n_harmonic_functions * 2 * 3 336 | 337 | # self.mlp is a simple 2-layer multi-layer perceptron 338 | # which converts the input per-point harmonic embeddings 339 | # to a latent representation. 340 | # Not that we use Softplus activations instead of ReLU. 
341 | self.mlp = torch.nn.Sequential( 342 | torch.nn.Linear(embedding_dim, n_hidden_neurons), 343 | torch.nn.Softplus(beta=10.0), 344 | torch.nn.Linear(n_hidden_neurons, n_hidden_neurons), 345 | torch.nn.Softplus(beta=10.0), 346 | ) 347 | 348 | # Given features predicted by self.mlp, self.color_layer 349 | # is responsible for predicting a 3-D per-point vector 350 | # that represents the RGB color of the point. 351 | self.color_layer = torch.nn.Sequential( 352 | torch.nn.Linear(n_hidden_neurons + embedding_dim, n_hidden_neurons), 353 | torch.nn.Softplus(beta=10.0), 354 | torch.nn.Linear(n_hidden_neurons, 3), 355 | torch.nn.Sigmoid(), 356 | # To ensure that the colors lie in the range [0, 1], 357 | # the layer is terminated with a sigmoid layer. 358 | ) 359 | 360 | # The density layer converts the features of self.mlp 361 | # to a 1D density value representing the raw opacity 362 | # of each point. 363 | self.density_layer = torch.nn.Sequential( 364 | torch.nn.Linear(n_hidden_neurons, 1), 365 | torch.nn.Softplus(beta=10.0), 366 | # The Softplus activation ensures that the raw opacity 367 | # is a non-negative number. 368 | ) 369 | 370 | # We set the bias of the density layer to -1.5 in order to initialize the opacities of the 371 | # ray points to values close to 0. 372 | # This is a crucial detail for ensuring convergence of the model. 373 | self.density_layer[0].bias.data[0] = -1.5 374 | 375 | def _get_densities(self, features): 376 | """ 377 | This function takes `features` predicted by `self.mlp` 378 | and converts them to `raw_densities` with `self.density_layer`. 379 | `raw_densities` are later mapped to the [0, 1] range with 380 | `1 - exp(-raw_densities)`.
381 | """ 382 | raw_densities = self.density_layer(features) 383 | return 1 - (-raw_densities).exp() 384 | 385 | def _get_colors(self, features, rays_directions): 386 | """ 387 | This function takes per-point `features` predicted by `self.mlp` 388 | and evaluates the color model in order to attach to each 389 | point a 3D vector of its RGB color. 390 | 391 | In order to represent viewpoint dependent effects, 392 | before evaluating `self.color_layer`, `NeuralRadianceField` 393 | concatenates to the `features` a harmonic embedding 394 | of `ray_directions`, which are per-point directions 395 | of point rays expressed as 3D l2-normalized vectors 396 | in world coordinates. 397 | """ 398 | spatial_size = features.shape[:-1] 399 | 400 | # Normalize the ray_directions to unit l2 norm. 401 | rays_directions_normed = torch.nn.functional.normalize( 402 | rays_directions, dim=-1 403 | ) 404 | 405 | # Obtain the harmonic embedding of the normalized ray directions. 406 | rays_embedding = self.harmonic_embedding( 407 | rays_directions_normed 408 | ) 409 | 410 | # Expand the ray directions tensor so that its spatial size 411 | # is equal to the size of features. 412 | rays_embedding_expand = rays_embedding[..., None, :].expand( 413 | *spatial_size, rays_embedding.shape[-1] 414 | ) 415 | 416 | # Concatenate ray direction embeddings with 417 | # features and evaluate the color model. 418 | color_layer_input = torch.cat( 419 | (features, rays_embedding_expand), 420 | dim=-1) 421 | return self.color_layer(color_layer_input) 422 | 423 | def forward( 424 | self, 425 | ray_bundle: RayBundle, 426 | path = None, 427 | **kwargs, 428 | ): 429 | """ 430 | The forward function accepts the parametrizations of 431 | 3D points sampled along projection rays. The forward 432 | pass is responsible for attaching a 3D vector 433 | and a 1D scalar representing the point's 434 | RGB color and opacity respectively. 
435 | 436 | Args: 437 | ray_bundle: A RayBundle object containing the following variables: 438 | origins: A tensor of shape `(minibatch, ..., 3)` denoting the 439 | origins of the sampling rays in world coords. 440 | directions: A tensor of shape `(minibatch, ..., 3)` 441 | containing the direction vectors of sampling rays in world coords. 442 | lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)` 443 | containing the lengths at which the rays are sampled. 444 | path: If not None, the ray points, densities and colors are saved to this directory with `save_io_nerf`. 445 | Returns: 446 | rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)` 447 | denoting the opacity of each ray point. 448 | rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)` 449 | denoting the color of each ray point. 450 | """ 451 | # We first convert the ray parametrizations to world 452 | # coordinates with `ray_bundle_to_ray_points`. 453 | rays_points_world = ray_bundle_to_ray_points(ray_bundle) 454 | # rays_points_world.shape = [minibatch x ... x 3] 455 | 456 | # For each 3D world coordinate, we obtain its harmonic embedding. 457 | embeds = self.harmonic_embedding( 458 | rays_points_world 459 | ) 460 | # embeds.shape = [minibatch x ... x self.n_harmonic_functions*6] 461 | 462 | # self.mlp maps each harmonic embedding to a latent feature space. 463 | features = self.mlp(embeds) 464 | # features.shape = [minibatch x ... x n_hidden_neurons] 465 | 466 | # Finally, given the per-point features, 467 | # execute the density and color branches. 468 | 469 | rays_densities = self._get_densities(features) 470 | # rays_densities.shape = [minibatch x ... x 1] 471 | 472 | rays_colors = self._get_colors(features, ray_bundle.directions) 473 | # rays_colors.shape = [minibatch x ...
x 3] 474 | 475 | if path is not None: 476 | save_io_nerf(rays_points_world, rays_densities, rays_colors, path) 477 | 478 | return rays_densities, rays_colors 479 | 480 | 481 | # TYXE comment: this was previously a method of the NRF class, I'm now passing a PytorchBNN instance as the net argument 482 | def batched_forward( 483 | net, 484 | ray_bundle: RayBundle, 485 | n_batches: int = 16, 486 | path = None, 487 | **kwargs, 488 | ): 489 | """ 490 | This function is used to allow for memory-efficient processing 491 | of input rays. The input rays are first split into `n_batches` 492 | chunks and passed through `net` (here the `PytorchBNN`) one at a time 493 | in a for loop. Combined with disabling PyTorch gradient caching 494 | (`torch.no_grad()`), this allows for rendering large batches 495 | of rays that do not all fit into GPU memory in a single forward pass. 496 | In our case, batched_forward is used to export a fully-sized render 497 | of the radiance field for visualisation purposes. 498 | 499 | Args: 500 | ray_bundle: A RayBundle object containing the following variables: 501 | origins: A tensor of shape `(minibatch, ..., 3)` denoting the 502 | origins of the sampling rays in world coords. 503 | directions: A tensor of shape `(minibatch, ..., 3)` 504 | containing the direction vectors of sampling rays in world coords. 505 | lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)` 506 | containing the lengths at which the rays are sampled. 507 | n_batches: Specifies the number of batches the input rays are split into. 508 | The larger the number of batches, the smaller the memory footprint 509 | and the lower the processing speed. 510 | 511 | Returns: 512 | rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)` 513 | denoting the opacity of each ray point. 514 | rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)` 515 | denoting the color of each ray point.
516 | 517 | """ 518 | 519 | # Parse out shapes needed for tensor reshaping in this function. 520 | n_pts_per_ray = ray_bundle.lengths.shape[-1] 521 | spatial_size = [*ray_bundle.origins.shape[:-1], n_pts_per_ray] 522 | 523 | # Split the rays into `n_batches` batches. 524 | tot_samples = ray_bundle.origins.shape[:-1].numel() 525 | batches = torch.chunk(torch.arange(tot_samples), n_batches) 526 | 527 | # For each batch, execute the standard forward pass. 528 | batch_outputs = [ 529 | net( 530 | RayBundle( 531 | origins=ray_bundle.origins.view(-1, 3)[batch_idx], 532 | directions=ray_bundle.directions.view(-1, 3)[batch_idx], 533 | lengths=ray_bundle.lengths.view(-1, n_pts_per_ray)[batch_idx], 534 | xys=None, 535 | ), path=None if path is None else f"{path}/batch{i}" 536 | ) for i, batch_idx in enumerate(batches) 537 | ] 538 | 539 | # Concatenate the per-batch rays_densities and rays_colors 540 | # and reshape according to the sizes of the inputs. 541 | rays_densities, rays_colors = [ 542 | torch.cat( 543 | [batch_output[output_i] for batch_output in batch_outputs], dim=0 544 | ).view(*spatial_size, -1) for output_i in (0, 1) 545 | ] 546 | if path is not None: 547 | torch.save(spatial_size, f"{path}/spatial_size.pt") 548 | return rays_densities, rays_colors 549 | 550 | 551 | def huber(x, y, scaling=0.1): 552 | """ 553 | A helper function for evaluating the pseudo-Huber loss, a smooth approximation 554 | of the L1 loss, between rendered and target silhouettes or colors. 555 | """ 556 | diff_sq = (x - y) ** 2 557 | loss = ((1 + diff_sq / (scaling ** 2)).clamp(1e-4).sqrt() - 1) * float(scaling) 558 | return loss 559 | 560 | 561 | def sample_images_at_mc_locs(target_images, sampled_rays_xy): 562 | """ 563 | Given a set of Monte Carlo pixel locations `sampled_rays_xy`, 564 | this method samples the tensor `target_images` at the 565 | respective 2D locations.
566 | 567 | This function is used in order to extract the colors from 568 | ground truth images that correspond to the colors 569 | rendered using `MonteCarloRaysampler`. 570 | """ 571 | ba = target_images.shape[0] 572 | dim = target_images.shape[-1] 573 | spatial_size = sampled_rays_xy.shape[1:-1] 574 | # In order to sample target_images, we utilize 575 | # the grid_sample function which implements a 576 | # bilinear image sampler. 577 | # Note that we have to invert the sign of the 578 | # sampled ray positions to convert the NDC xy locations 579 | # of the MonteCarloRaysampler to the coordinate 580 | # convention of grid_sample. 581 | images_sampled = torch.nn.functional.grid_sample( 582 | target_images.permute(0, 3, 1, 2), 583 | -sampled_rays_xy.view(ba, -1, 1, 2), # note the sign inversion 584 | align_corners=True 585 | ) 586 | return images_sampled.permute(0, 2, 3, 1).view( 587 | ba, *spatial_size, dim 588 | ) 589 | 590 | 591 | def show_full_render( 592 | neural_radiance_field, camera, 593 | target_image, target_silhouette, 594 | loss_history_color, loss_history_sil, 595 | renderer_grid, num_forward=1 596 | ): 597 | """ 598 | This is a helper function for visualizing the 599 | intermediate results of the learning. 600 | 601 | Since the `NeuralRadianceField` suffers from 602 | a large memory footprint, which prevents rendering 603 | the full image grid in a single forward pass, 604 | we utilize the module-level `batched_forward` 605 | function in combination with disabling the gradient caching. 606 | This chunks the set of emitted rays into batches and 607 | evaluates the implicit function on one batch at a time 608 | to prevent GPU memory overflow. 609 | """ 610 | 611 | rendered_image_list, rendered_silhouette_list = [], [] 612 | # Prevent gradient caching. 613 | with torch.no_grad(): 614 | for _ in range(num_forward): 615 | # Render using the grid renderer and the 616 | # module-level batched_forward function applied to neural_radiance_field.
617 | rendered_image_silhouette, _ = renderer_grid( 618 | cameras=camera, 619 | volumetric_function=partial(batched_forward, net=neural_radiance_field) 620 | ) 621 | # Split the rendering result into a silhouette render 622 | # and the image render. 623 | rendered_image_, rendered_silhouette_ = ( 624 | rendered_image_silhouette[0].split([3, 1], dim=-1) 625 | ) 626 | rendered_image_list.append(rendered_image_) 627 | rendered_silhouette_list.append(rendered_silhouette_) 628 | 629 | rendered_images = torch.stack(rendered_image_list) 630 | rendered_image = rendered_images.mean(0) 631 | 632 | rendered_silhouettes = torch.stack(rendered_silhouette_list) 633 | rendered_silhouette = rendered_silhouettes.mean(0) 634 | 635 | if num_forward > 1: 636 | rendered_image_std = rendered_images.var(0).sum(-1).sqrt() 637 | rendered_silhouette_std = rendered_silhouettes.std(0) 638 | else: 639 | rendered_image_std = torch.zeros_like(rendered_image[..., 0]) 640 | rendered_silhouette_std = torch.zeros_like(rendered_silhouette) 641 | 642 | print(f"Max image std: {rendered_image_std.max().item():.4f}; " 643 | f"max image: {rendered_image.max().item():.4f}; " 644 | f"max silhouette std: {rendered_silhouette_std.max().item():.4f}; " 645 | f"max silhouette: {rendered_silhouette.max().item():.4f}") 646 | # Generate plots.
647 | fig, ax = plt.subplots(2, 4, figsize=(20, 10)) 648 | ax = ax.ravel() 649 | clamp_and_detach = lambda x: x.clamp(0.0, 1.0).cpu().detach().numpy() 650 | ax[0].plot(list(range(len(loss_history_color))), loss_history_color, linewidth=1) 651 | ax[1].imshow(clamp_and_detach(rendered_image)) 652 | ax[2].imshow(clamp_and_detach(rendered_silhouette[..., 0])) 653 | ax[3].imshow(clamp_and_detach(rendered_image_std), cmap="hot", vmax=0.75 ** 0.5) 654 | ax[4].plot(list(range(len(loss_history_sil))), loss_history_sil, linewidth=1) 655 | ax[5].imshow(clamp_and_detach(target_image)) 656 | ax[6].imshow(clamp_and_detach(target_silhouette)) 657 | ax[7].imshow(clamp_and_detach(rendered_silhouette_std), cmap="hot", vmax=0.5) 658 | for ax_, title_ in zip( 659 | ax, 660 | ( 661 | "loss color", "rendered image", "rendered silhouette", "image uncertainty", 662 | "loss silhouette", "target image", "target silhouette", "silhouette uncertainty" 663 | ) 664 | ): 665 | if not title_.startswith('loss'): 666 | ax_.grid(False) 667 | ax_.axis("off") 668 | ax_.set_title(title_) 669 | fig.canvas.draw() 670 | fig.show() 671 | return fig 672 | 673 | 674 | def generate_rotating_nerf(neural_radiance_field, target_cameras, renderer_grid, device, n_frames=50, num_forward=1, save_visualization=False): 675 | logRs = torch.zeros(n_frames, 3, device=device) 676 | logRs[:, 1] = torch.linspace(-3.14, 3.14, n_frames, device=device) 677 | Rs = so3_exponential_map(logRs) 678 | Ts = torch.zeros(n_frames, 3, device=device) 679 | Ts[:, 2] = 2.7 680 | frames = [] 681 | uncertainties = [] 682 | # the visualization path depends on the view index i and the sample index j, so it is built inside the loop 683 | print('Rendering rotating NeRF ...') 684 | for i, (R, T) in enumerate(zip(tqdm(Rs), Ts)): 685 | camera = FoVPerspectiveCameras( 686 | R=R[None], 687 | T=T[None], 688 | znear=target_cameras.znear[0], 689 | zfar=target_cameras.zfar[0], 690 | aspect_ratio=target_cameras.aspect_ratio[0], 691 | fov=target_cameras.fov[0], 692 | device=device, 693 | ) 694 |
# Note that we again render with `NDCGridRaysampler` 695 | # and the batched_forward function of neural_radiance_field. 696 | frame_samples = torch.stack([renderer_grid( 697 | cameras=camera, 698 | volumetric_function=partial(batched_forward, net=neural_radiance_field, path=(f"nerf_vis/view{i}/sample{j}" if save_visualization else None)) 699 | )[0][..., :3] for j in range(num_forward)]) 700 | frames.append(frame_samples.mean(0)) 701 | uncertainties.append(frame_samples.var(0).sum(-1).sqrt() if num_forward > 1 else torch.zeros_like(frame_samples[0, ..., 0])) 702 | return torch.cat(frames), torch.cat(uncertainties) 703 | 704 | 705 | def main(inference, n_iter, save_state_dict, load_state_dict, kl_annealing_iters, zero_kl_iters, max_kl_factor, 706 | init_scale, save_visualization): 707 | if torch.cuda.is_available(): 708 | device = torch.device("cuda:0") 709 | torch.cuda.set_device(device) 710 | else: 711 | print( 712 | 'Please note that NeRF is a resource-demanding method.' 713 | + ' Running this example on CPU will be extremely slow.' 714 | + ' We recommend running the example on a GPU' 715 | + ' with at least 10 GB of memory.' 716 | ) 717 | device = torch.device("cpu") 718 | 719 | target_cameras, target_images, target_silhouettes = generate_cow_renders( 720 | num_views=30, azimuth_low=-180, azimuth_high=90) 721 | print(f'Generated {len(target_images)} images/silhouettes/cameras.') 722 | 723 | # render_size describes the size of both sides of the 724 | # rendered images in pixels. Since an advantage of 725 | # Neural Radiance Fields is their high-quality renders 726 | # with a significant amount of detail, we render 727 | # the implicit function at double the size of 728 | # target images. 729 | render_size = target_images.shape[1] * 2 730 | 731 | # Our rendered scene is centered around (0,0,0) 732 | # and is enclosed inside a bounding box 733 | # whose side is roughly equal to 3.0 (world units). 734 | volume_extent_world = 3.0 735 | 736 | # 1) Instantiate the raysamplers.
737 | 738 | # Here, NDCGridRaysampler generates a rectangular image 739 | # grid of rays whose coordinates follow the PyTorch3D 740 | # coordinate conventions. 741 | raysampler_grid = NDCGridRaysampler( 742 | image_height=render_size, 743 | image_width=render_size, 744 | n_pts_per_ray=128, 745 | min_depth=0.1, 746 | max_depth=volume_extent_world, 747 | ) 748 | 749 | # MonteCarloRaysampler generates a random subset 750 | # of `n_rays_per_image` rays emitted from the image plane. 751 | raysampler_mc = MonteCarloRaysampler( 752 | min_x=-1.0, 753 | max_x=1.0, 754 | min_y=-1.0, 755 | max_y=1.0, 756 | n_rays_per_image=750, 757 | n_pts_per_ray=128, 758 | min_depth=0.1, 759 | max_depth=volume_extent_world, 760 | ) 761 | 762 | # 2) Instantiate the raymarcher. 763 | # Here, we use the standard EmissionAbsorptionRaymarcher 764 | # which marches along each ray in order to render 765 | # the ray into a single 3D color vector 766 | # and an opacity scalar. 767 | raymarcher = EmissionAbsorptionRaymarcher() 768 | 769 | # Finally, instantiate the implicit renderers 770 | # for both raysamplers. 771 | renderer_grid = ImplicitRenderer( 772 | raysampler=raysampler_grid, raymarcher=raymarcher, 773 | ) 774 | renderer_mc = ImplicitRenderer( 775 | raysampler=raysampler_mc, raymarcher=raymarcher, 776 | ) 777 | 778 | # Move all relevant variables to the correct device. 779 | renderer_grid = renderer_grid.to(device) 780 | renderer_mc = renderer_mc.to(device) 781 | target_cameras = target_cameras.to(device) 782 | target_images = target_images.to(device) 783 | target_silhouettes = target_silhouettes.to(device) 784 | 785 | # Set the seed for reproducibility 786 | torch.manual_seed(1) 787 | 788 | # Instantiate the radiance field model.
789 | neural_radiance_field_net = NeuralRadianceField().to(device) 790 | if load_state_dict is not None: 791 | sd = torch.load(load_state_dict, map_location=device) 792 | sd["harmonic_embedding.frequencies"] = neural_radiance_field_net.harmonic_embedding.frequencies 793 | neural_radiance_field_net.load_state_dict(sd) 794 | 795 | # TYXE comment: set up the BNN depending on the desired inference 796 | standard_normal = dist.Normal(torch.tensor(0.).to(device), torch.tensor(1.).to(device)) 797 | prior_kwargs = {} 798 | test_samples = 1 799 | if inference == "ml": 800 | prior_kwargs.update(expose_all=False, hide_all=True) 801 | guide = None 802 | elif inference == "map": 803 | guide = partial(pyro.infer.autoguide.AutoDelta, 804 | init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(neural_radiance_field_net)) 805 | elif inference == "mean-field": 806 | guide = partial(tyxe.guides.AutoNormal, init_scale=init_scale, 807 | init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(neural_radiance_field_net)) 808 | test_samples = 8 809 | else: 810 | raise RuntimeError(f"Unreachable inference: {inference}") 811 | 812 | prior = tyxe.priors.IIDPrior(standard_normal, **prior_kwargs) 813 | neural_radiance_field = tyxe.PytorchBNN(neural_radiance_field_net, prior, guide) 814 | 815 | # TYXE comment: we need a batch of dummy data for the BNN to trace the parameters 816 | dummy_data = namedtuple("RayBundle", "origins directions lengths")( 817 | torch.randn(1, 1, 3).to(device), 818 | torch.randn(1, 1, 3).to(device), 819 | torch.randn(1, 1, 8).to(device) 820 | ) 821 | # Instantiate the Adam optimizer. We set its master learning rate to 1e-3. 822 | lr = 1e-3 823 | optimizer = torch.optim.Adam(neural_radiance_field.pytorch_parameters(dummy_data), lr=lr) 824 | 825 | # We sample 6 random cameras in a minibatch. Each camera 826 | # emits n_rays_per_image=750 rays (see raysampler_mc). 827 | batch_size = 6 828 | 829 | # Init the loss history buffers.
830 | loss_history_color, loss_history_sil = [], [] 831 | 832 | if kl_annealing_iters > 0 or zero_kl_iters > 0: 833 | kl_factor = 0. 834 | kl_annealing_rate = max_kl_factor / max(kl_annealing_iters, 1) 835 | else: 836 | kl_factor = max_kl_factor 837 | kl_annealing_rate = 0. 838 | # The main optimization loop. 839 | for iteration in range(n_iter): 840 | # Once 75% of the iterations have passed, re-create Adam with a 841 | # 10x smaller learning rate (this also resets its moment estimates). 842 | if iteration == round(n_iter * 0.75): 843 | print('Decreasing LR 10-fold ...') 844 | optimizer = torch.optim.Adam( 845 | neural_radiance_field.pytorch_parameters(dummy_data), lr=lr * 0.1 846 | ) 847 | 848 | # Zero the optimizer gradient. 849 | optimizer.zero_grad() 850 | 851 | # Sample random batch indices. 852 | batch_idx = torch.randperm(len(target_cameras))[:batch_size] 853 | 854 | # Sample the minibatch of cameras. 855 | batch_cameras = FoVPerspectiveCameras( 856 | R=target_cameras.R[batch_idx], 857 | T=target_cameras.T[batch_idx], 858 | znear=target_cameras.znear[batch_idx], 859 | zfar=target_cameras.zfar[batch_idx], 860 | aspect_ratio=target_cameras.aspect_ratio[batch_idx], 861 | fov=target_cameras.fov[batch_idx], 862 | device=device, 863 | ) 864 | 865 | rendered_images_silhouettes, sampled_rays = renderer_mc( 866 | cameras=batch_cameras, 867 | volumetric_function=partial(batched_forward, net=neural_radiance_field) 868 | ) 869 | rendered_images, rendered_silhouettes = ( 870 | rendered_images_silhouettes.split([3, 1], dim=-1) 871 | ) 872 | 873 | # Compute the silhouette error as the mean huber 874 | # loss between the predicted masks and the 875 | # sampled target silhouettes.
876 | silhouettes_at_rays = sample_images_at_mc_locs( 877 | target_silhouettes[batch_idx, ..., None], 878 | sampled_rays.xys 879 | ) 880 | sil_err = huber( 881 | rendered_silhouettes, 882 | silhouettes_at_rays, 883 | ).abs().mean() 884 | 885 | # Compute the color error as the mean huber 886 | # loss between the rendered colors and the 887 | # sampled target images. 888 | colors_at_rays = sample_images_at_mc_locs( 889 | target_images[batch_idx], 890 | sampled_rays.xys 891 | ) 892 | color_err = huber( 893 | rendered_images, 894 | colors_at_rays, 895 | ).abs().mean() 896 | 897 | # The optimization loss is a simple 898 | # sum of the color and silhouette errors. 899 | # TYXE comment: we also add a kl loss for the variational posterior scaled by the size of the data 900 | # i.e. the total number of data points times the number of values that the data-dependent part of the 901 | # objective averages over. Effectively I'm treating this as if this was something like a Bernoulli likelihood 902 | # in a VAE where the expected log likelihood is averaged over both data points and pixels 903 | beta = kl_factor / (target_images.numel() + target_silhouettes.numel()) 904 | kl_err = neural_radiance_field.cached_kl_loss 905 | loss = color_err + sil_err + beta * kl_err 906 | 907 | # Log the loss history. 908 | loss_history_color.append(float(color_err)) 909 | loss_history_sil.append(float(sil_err)) 910 | 911 | # Every 10 iterations, print the current values of the losses. 912 | if iteration % 10 == 0: 913 | print( 914 | f'Iteration {iteration:05d}:' 915 | + f' loss color = {float(color_err):1.2e}' 916 | + f' loss silhouette = {float(sil_err):1.2e}' 917 | + f' loss kl = {float(kl_err):1.2e}' 918 | + f' kl_factor = {kl_factor:1.3e}' 919 | ) 920 | 921 | # Take the optimization step. 
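The KL weighting in the training loop combines two pieces: a warm-up schedule for `kl_factor` (zero for `zero_kl_iters` iterations, then a linear ramp over `kl_annealing_iters` iterations up to `max_kl_factor`) and a division by the total number of target image and silhouette values, so that the KL term is on the same per-value scale as the averaged Huber errors. A minimal sketch with hypothetical helper names (`kl_factor_at`, `kl_weight`); it reproduces the loop's values up to floating-point rounding, and its first branch corresponds to the `else:` branch in `main()`.

```python
def kl_factor_at(iteration, zero_kl_iters, kl_annealing_iters, max_kl_factor):
    """KL factor used in the loss at `iteration` (hypothetical helper mirroring main())."""
    if kl_annealing_iters <= 0 and zero_kl_iters <= 0:
        # No warm-up requested: the full factor is used from the first iteration.
        return max_kl_factor
    rate = max_kl_factor / max(kl_annealing_iters, 1)
    # main() adds `rate` at the end of every iteration k >= zero_kl_iters, so the value
    # used at `iteration` has received one increment per such k < iteration.
    increments = max(0, iteration - max(zero_kl_iters, 0))
    return min(max_kl_factor, increments * rate)


def kl_weight(kl_factor, num_image_values, num_silhouette_values):
    """beta in `loss = color_err + sil_err + beta * kl_err` (hypothetical helper)."""
    return kl_factor / (num_image_values + num_silhouette_values)


# 5 iterations without the KL term, then a 10-iteration linear ramp up to 1.0.
print([round(kl_factor_at(t, 5, 10, 1.0), 2) for t in (0, 5, 6, 10, 15, 20)])
# prints [0.0, 0.0, 0.1, 0.5, 1.0, 1.0]
```

In `main()` the two denominators correspond to `target_images.numel()` and `target_silhouettes.numel()`.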
922 | loss.backward() 923 | optimizer.step() 924 | 925 | # TYXE comment: anneal the KL factor 926 | if iteration >= zero_kl_iters: 927 | kl_factor = min(max_kl_factor, kl_factor + kl_annealing_rate) 928 | 929 | # Visualize the full renders every 1000 iterations. 930 | if iteration % 1000 == 0: 931 | show_idx = torch.randperm(len(target_cameras))[:1] 932 | fig = show_full_render( 933 | neural_radiance_field, 934 | FoVPerspectiveCameras( 935 | R=target_cameras.R[show_idx], 936 | T=target_cameras.T[show_idx], 937 | znear=target_cameras.znear[show_idx], 938 | zfar=target_cameras.zfar[show_idx], 939 | aspect_ratio=target_cameras.aspect_ratio[show_idx], 940 | fov=target_cameras.fov[show_idx], 941 | device=device, 942 | ), 943 | target_images[show_idx][0], 944 | target_silhouettes[show_idx][0], 945 | loss_history_color, 946 | loss_history_sil, 947 | renderer_grid, 948 | num_forward=test_samples 949 | ) 950 | fig.savefig(f"nerf/full_render{iteration}.png") 951 | plt.close(fig) 952 | 953 | with torch.no_grad(): 954 | rotating_nerf_frames, uncertainty_frames = generate_rotating_nerf( 955 | neural_radiance_field, 956 | target_cameras, 957 | renderer_grid, 958 | device, 959 | n_frames=3 * 5, 960 | num_forward=test_samples, 961 | save_visualization=save_visualization 962 | ) 963 | 964 | for i, (img, uncertainty) in enumerate(zip( 965 | rotating_nerf_frames.clamp(0., 1.).cpu().numpy(), uncertainty_frames.cpu().numpy())): 966 | f, ax = plt.subplots(figsize=(1.625, 1.625)) 967 | f.subplots_adjust(0, 0, 1, 1) 968 | ax.imshow(img) 969 | ax.set_axis_off() 970 | f.savefig(f"nerf/final_image{i}.jpg", bbox_inches="tight", pad_inches=0) 971 | plt.close(f) 972 | 973 | f, ax = plt.subplots(figsize=(1.625, 1.625)) 974 | f.subplots_adjust(0, 0, 1, 1) 975 | ax.imshow(uncertainty, cmap="hot", vmax=0.75 ** 0.5) 976 | ax.set_axis_off() 977 | f.savefig(f"nerf/final_uncertainty{i}.jpg", bbox_inches="tight", pad_inches=0) 978 | plt.close(f) 979 | 980 | if save_state_dict is not None: 981 | if
inference != "ml": 982 | raise ValueError("Saving the state dict is only available for ml inference for now.") 983 | state_dict = dict(neural_radiance_field.named_pytorch_parameters(dummy_data)) 984 | torch.save(state_dict, save_state_dict) 985 | 986 | test_cameras, test_images, test_silhouettes = generate_cow_renders( 987 | num_views=10, azimuth_low=90, azimuth_high=180) 988 | 989 | del renderer_mc 990 | del target_cameras 991 | del target_images 992 | del target_silhouettes 993 | torch.cuda.empty_cache() 994 | 995 | test_cameras = test_cameras.to(device) 996 | test_images = test_images.to(device) 997 | test_silhouettes = test_silhouettes.to(device) 998 | 999 | # TODO remove duplication from training code for test error 1000 | with torch.no_grad(): 1001 | sil_err = 0. 1002 | color_err = 0. 1003 | for i in range(len(test_cameras)): 1004 | batch_idx = [i] 1005 | 1006 | # Select the i-th test camera as a minibatch of size one. 1007 | batch_cameras = FoVPerspectiveCameras( 1008 | R=test_cameras.R[batch_idx], 1009 | T=test_cameras.T[batch_idx], 1010 | znear=test_cameras.znear[batch_idx], 1011 | zfar=test_cameras.zfar[batch_idx], 1012 | aspect_ratio=test_cameras.aspect_ratio[batch_idx], 1013 | fov=test_cameras.fov[batch_idx], 1014 | device=device, 1015 | ) 1016 | 1017 | img_list, sils_list, sampled_rays_list = [], [], [] 1018 | for _ in range(test_samples): 1019 | rendered_images_silhouettes, sampled_rays = renderer_grid( 1020 | cameras=batch_cameras, 1021 | volumetric_function=partial(batched_forward, net=neural_radiance_field) 1022 | ) 1023 | imgs, sils = ( 1024 | rendered_images_silhouettes.split([3, 1], dim=-1) 1025 | ) 1026 | img_list.append(imgs) 1027 | sils_list.append(sils) 1028 | sampled_rays_list.append(sampled_rays.xys) 1029 | 1030 | assert sampled_rays_list[0].eq(torch.stack(sampled_rays_list)).all() 1031 | 1032 | rendered_images = torch.stack(img_list).mean(0) 1033 | rendered_silhouettes = torch.stack(sils_list).mean(0) 1034 | 1035 | # Compute the silhouette error as the mean
huber 1036 | # loss between the predicted masks and the 1037 | # sampled target silhouettes. 1038 | # TYXE comment: sampled_rays are always the same for renderer_grid 1039 | silhouettes_at_rays = sample_images_at_mc_locs( 1040 | test_silhouettes[batch_idx, ..., None], 1041 | sampled_rays.xys 1042 | ) 1043 | sil_err += huber( 1044 | rendered_silhouettes, 1045 | silhouettes_at_rays, 1046 | ).abs().mean().item() / len(test_cameras) 1047 | 1048 | # Compute the color error as the mean huber 1049 | # loss between the rendered colors and the 1050 | # sampled target images. 1051 | colors_at_rays = sample_images_at_mc_locs( 1052 | test_images[batch_idx], 1053 | sampled_rays.xys 1054 | ) 1055 | color_err += huber( 1056 | rendered_images, 1057 | colors_at_rays, 1058 | ).abs().mean().item() / len(test_cameras) 1059 | 1060 | print(f"Test error: sil={sil_err:1.3e}; col={color_err:1.3e}") 1061 | 1062 | 1063 | if __name__ == '__main__': 1064 | parser = argparse.ArgumentParser() 1065 | parser.add_argument("--inference", choices=["ml", "map", "mean-field"], required=True) 1066 | parser.add_argument("--n-iter", type=int, default=20000) 1067 | parser.add_argument("--save-state-dict") 1068 | parser.add_argument("--load-state-dict") 1069 | parser.add_argument("--kl-annealing-iters", type=int, default=0) 1070 | parser.add_argument("--zero-kl-iters", type=int, default=0) 1071 | parser.add_argument("--max-kl-factor", type=float, default=1.) 
1072 | parser.add_argument("--init-scale", type=float, default=1e-2) 1073 | parser.add_argument("--save-visualization", action="store_true") 1074 | main(**vars(parser.parse_args())) 1075 | -------------------------------------------------------------------------------- /examples/resnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import contextlib 3 | import functools 4 | import os 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.utils.data as data 9 | 10 | import torchvision 11 | 12 | import pyro 13 | import pyro.distributions as dist 14 | import pyro.infer.autoguide as ag 15 | 16 | 17 | import tyxe 18 | 19 | 20 | NORMALIZERS = { 21 | "cifar10": ((0.49139968, 0.48215841, 0.44653091), (0.24703223, 0.24348513, 0.26158784)), 22 | "cifar100": ((0.50707516, 0.48654887, 0.44091784), (0.26733429, 0.25643846, 0.27615047)), 23 | "svhn": ((0.4376821, 0.4437697, 0.47280442), (0.19803012, 0.20101562, 0.19703614)) 24 | } 25 | 26 | 27 | def make_loaders(dataset, root, train_batch_size, test_batch_size, use_cuda): 28 | train_img_transforms = [torchvision.transforms.RandomCrop(size=32, padding=4, padding_mode='reflect'), 29 | torchvision.transforms.RandomHorizontalFlip()] 30 | test_img_transforms = [] 31 | tensor_transforms = [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(*NORMALIZERS[dataset])] 32 | 33 | dataset_fn = getattr(torchvision.datasets, dataset.upper()) 34 | train_transform = torchvision.transforms.Compose(train_img_transforms + tensor_transforms) 35 | train_data = dataset_fn(root, train=True, transform=train_transform, download=True) 36 | train_loader = data.DataLoader( 37 | train_data, train_batch_size, pin_memory=use_cuda, num_workers=2 * int(use_cuda), shuffle=True) 38 | 39 | test_transform = torchvision.transforms.Compose(test_img_transforms + tensor_transforms) 40 | test_data = dataset_fn(root, train=False, transform=test_transform, download=True) 41 | test_loader 
= data.DataLoader(test_data, test_batch_size) 42 | 43 | ood_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), 44 | torchvision.transforms.Normalize(*NORMALIZERS["svhn"])]) 45 | ood_data = torchvision.datasets.SVHN(root, split="test", transform=ood_transform, download=True) 46 | ood_loader = data.DataLoader(ood_data, test_batch_size) 47 | 48 | return train_loader, test_loader, ood_loader 49 | 50 | 51 | def make_net(dataset, architecture): 52 | net = getattr(torchvision.models, architecture)(pretrained=True) 53 | if dataset.startswith("cifar"): 54 | net.conv1 = nn.Conv2d(3, net.conv1.out_channels, kernel_size=3, stride=1, padding=1, bias=False) 55 | net.maxpool = nn.Identity() 56 | num_classes = 10 if dataset.endswith("10") else 100 57 | net.fc = nn.Linear(net.fc.in_features, num_classes) 58 | return net 59 | 60 | 61 | def main(dataset, architecture, inference, train_batch_size, test_batch_size, local_reparameterization, flipout, 62 | num_epochs, test_samples, max_guide_scale, rank, root, seed, output_dir, pretrained_weights, scale_only, lr, milestones, gamma): 63 | pyro.set_rng_seed(seed) 64 | use_cuda = torch.cuda.is_available() 65 | device = torch.device("cuda" if use_cuda else "cpu") 66 | 67 | train_loader, test_loader, ood_loader = make_loaders(dataset, root, train_batch_size, test_batch_size, use_cuda) 68 | net = make_net(dataset, architecture).to(device) 69 | if pretrained_weights is not None: 70 | sd = torch.load(pretrained_weights, map_location=device) 71 | net.load_state_dict(sd) 72 | 73 | observation_model = tyxe.likelihoods.Categorical(len(train_loader.sampler)) 74 | 75 | prior_kwargs = dict(expose_all=False, hide_module_types=(nn.BatchNorm2d,)) 76 | if inference == "ml": 77 | test_samples = 1 78 | guide = None 79 | prior_kwargs["hide_all"] = True 80 | elif inference == "map": 81 | test_samples = 1 82 | guide = functools.partial(ag.AutoDelta, init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net, prefix="net")) 83 | elif 
inference == "mean-field": 84 | guide = functools.partial(tyxe.guides.AutoNormal, 85 | init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net, prefix="net"), init_scale=1e-4, 86 | max_guide_scale=max_guide_scale, train_loc=not scale_only) 87 | elif inference.startswith("last-layer"): 88 | if pretrained_weights is None: 89 | raise ValueError("Asked to do last-layer inference, but no pre-trained weights were provided.") 90 | # Turn all parameters except those of the last layer into buffers so that they are not trained; 91 | # this might also be achievable via poutine.block 92 | for module in net.modules(): 93 | if module is not net.fc: 94 | for param_name, param in list(module.named_parameters(recurse=False)): 95 | delattr(module, param_name) 96 | module.register_buffer(param_name, param.detach().data) 97 | del prior_kwargs['hide_module_types'] 98 | prior_kwargs["expose_modules"] = [net.fc] 99 | if inference == "last-layer-mean-field": 100 | guide = functools.partial(tyxe.guides.AutoNormal, 101 | init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net, prefix="net"), init_scale=1e-4) 102 | elif inference == "last-layer-full": 103 | guide = functools.partial(ag.AutoMultivariateNormal, 104 | init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net, prefix="net"), init_scale=1e-4) 105 | elif inference == "last-layer-low-rank": 106 | guide = functools.partial(ag.AutoLowRankMultivariateNormal, rank=rank, 107 | init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net, prefix="net"), init_scale=1e-4) 108 | else: 109 | raise RuntimeError("Unreachable") 110 | else: 111 | raise RuntimeError("Unreachable") 112 | 113 | prior = tyxe.priors.IIDPrior(dist.Normal(torch.zeros(1, device=device), torch.ones(1, device=device)), 114 | **prior_kwargs) 115 | bnn = tyxe.VariationalBNN(net, prior, observation_model, guide) 116 | 117 | if local_reparameterization: 118 | if flipout: 119 | raise RuntimeError("Can't use both local reparameterization and flipout, pick one.") 120 | fit_ctxt =
tyxe.poutine.local_reparameterization 121 | elif flipout: 122 | fit_ctxt = tyxe.poutine.flipout 123 | else: 124 | fit_ctxt = contextlib.nullcontext 125 | 126 | if milestones is None: 127 | optim = pyro.optim.Adam({"lr": lr}) 128 | else: 129 | optimizer = torch.optim.Adam 130 | optim = pyro.optim.MultiStepLR({"optimizer": optimizer, "optim_args": {"lr": lr}, "milestones": milestones, "gamma": gamma}) 131 | 132 | def callback(b, i, avg_elbo): 133 | avg_err, avg_ll = 0., 0. 134 | b.eval() 135 | for x, y in iter(test_loader): 136 | err, ll = b.evaluate(x.to(device), y.to(device), num_predictions=test_samples) 137 | avg_err += err / len(test_loader.sampler) 138 | avg_ll += ll / len(test_loader.sampler) 139 | print(f"ELBO={avg_elbo}; test error={100 * avg_err:.2f}%; LL={avg_ll:.4f}") 140 | b.train() 141 | if milestones is not None: optim.step() # advance the MultiStepLR schedule once per epoch 142 | with fit_ctxt(): 143 | bnn.fit(train_loader, optim, num_epochs, callback=callback, device=device) 144 | 145 | if output_dir is not None: 146 | pyro.get_param_store().save(os.path.join(output_dir, "param_store.pt")) 147 | torch.save(bnn.state_dict(), os.path.join(output_dir, "state_dict.pt")) 148 | 149 | test_predictions = torch.cat([bnn.predict(x.to(device), num_predictions=test_samples) 150 | for x, _ in iter(test_loader)]) 151 | torch.save(test_predictions.detach().cpu(), os.path.join(output_dir, "test_predictions.pt")) 152 | 153 | ood_predictions = torch.cat([bnn.predict(x.to(device), num_predictions=test_samples) 154 | for x, _ in iter(ood_loader)]) 155 | torch.save(ood_predictions.detach().cpu(), os.path.join(output_dir, "ood_predictions.pt")) 156 | 157 | 158 | if __name__ == '__main__': 159 | resnets = [n for n in dir(torchvision.models) 160 | if (n.startswith("resnet") or n.startswith("wide_resnet")) and n[-1].isdigit()] 161 | 162 | parser = argparse.ArgumentParser() 163 | parser.add_argument("--dataset", default="cifar10", choices=["cifar10", "cifar100"]) 164 | parser.add_argument("--architecture", default="resnet18", choices=resnets)
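When `--milestones` is given, `pyro.optim.MultiStepLR` wraps PyTorch's `torch.optim.lr_scheduler.MultiStepLR`, which multiplies the learning rate by `gamma` each time its step count reaches a milestone. The rate only changes when the scheduler's `step()` is called (tests/test_bnn.py does this from its fit callback). A small sketch of the resulting schedule; `parse_milestones` and `multistep_lr` are hypothetical helpers, not part of the script.

```python
from bisect import bisect_right


def parse_milestones(s):
    """Same conversion as the --milestones argument type: "100,150" -> [100, 150]."""
    return list(map(int, s.split(",")))


def multistep_lr(base_lr, milestones, gamma, num_scheduler_steps):
    """Learning rate after `num_scheduler_steps` calls to the scheduler's step().

    The rate is multiplied by `gamma` once for every milestone that the step
    count has reached, i.e. base_lr * gamma ** (number of milestones <= steps).
    """
    return base_lr * gamma ** bisect_right(sorted(milestones), num_scheduler_steps)


milestones = parse_milestones("100,150")
for steps in (0, 99, 100, 149, 150):
    print(steps, multistep_lr(1e-3, milestones, 0.1, steps))
```

With the script's defaults (`--lr 1e-3`, `--gamma 0.1`), this gives 1e-3 before the first milestone, then roughly 1e-4 and 1e-5.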
parser.add_argument("--inference", required=True, choices=[ 166 | "ml", "map", "mean-field", "last-layer-mean-field", "last-layer-full", "last-layer-low-rank"]) 167 | 168 | parser.add_argument("--train-batch-size", type=int, default=100) 169 | parser.add_argument("--test-batch-size", type=int, default=1000) 170 | parser.add_argument("--num-epochs", type=int, default=200) 171 | parser.add_argument("--test-samples", type=int, default=20) 172 | parser.add_argument("--local-reparameterization", action="store_true") 173 | parser.add_argument("--flipout", action="store_true") 174 | parser.add_argument("--max-guide-scale", type=float) 175 | parser.add_argument("--rank", type=int, default=10) 176 | 177 | parser.add_argument("--root", default=os.environ.get("DATASETS_PATH", "./data")) 178 | parser.add_argument("--seed", type=int, default=42) 179 | parser.add_argument("--output-dir") 180 | parser.add_argument("--pretrained-weights") 181 | parser.add_argument("--scale-only", action="store_true") 182 | 183 | parser.add_argument("--lr", type=float, default=1e-3) 184 | parser.add_argument("--milestones", type=lambda s: list(map(int, s.split(",")))) 185 | parser.add_argument("--gamma", type=float, default=0.1) 186 | 187 | main(**vars((parser.parse_args()))) 188 | -------------------------------------------------------------------------------- /examples/vcl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import functools 4 | import os 5 | 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.data as data 11 | 12 | import torchvision.transforms as tf 13 | from torchvision.datasets import MNIST, CIFAR10, CIFAR100 14 | 15 | import pyro 16 | import pyro.distributions as dist 17 | 18 | 19 | import tyxe 20 | 21 | 22 | ROOT = os.environ.get("DATASETS_PATH", "./data") 23 | USE_CUDA = torch.cuda.is_available() 24 | DEVICE = torch.device("cuda") if USE_CUDA else 
torch.device("cpu") 25 | C10_MEAN = (0.49139968, 0.48215841, 0.44653091) 26 | C10_SD = (0.24703223, 0.24348513, 0.26158784) 27 | 28 | 29 | def conv_3x3(c_in, c_out): 30 | return nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1) 31 | 32 | 33 | class ConvNet(nn.Sequential): 34 | 35 | def __init__(self): 36 | super().__init__() 37 | self.add_module("Conv1_1", conv_3x3(3, 32)) 38 | self.add_module("ReLU1_1", nn.ReLU(inplace=True)) 39 | self.add_module("Conv1_2", conv_3x3(32, 32)) 40 | self.add_module("ReLU1_2", nn.ReLU(inplace=True)) 41 | self.add_module("MaxPool1", nn.MaxPool2d(2, stride=2)) 42 | 43 | self.add_module("Conv2_1", conv_3x3(32, 64)) 44 | self.add_module("ReLU2_1", nn.ReLU(inplace=True)) 45 | self.add_module("Conv2_2", conv_3x3(64, 64)) 46 | self.add_module("ReLU2_2", nn.ReLU(inplace=True)) 47 | self.add_module("MaxPool2", nn.MaxPool2d(2, stride=2)) 48 | 49 | self.add_module("Flatten", nn.Flatten()) 50 | 51 | self.add_module("Linear", nn.Linear(64 * 8 * 8, 512)) 52 | self.add_module("ReLU", nn.ReLU(inplace=True)) 53 | 54 | self.add_module("Head", nn.Linear(512, 10)) 55 | 56 | 57 | class FCNet(nn.Sequential): 58 | 59 | def __init__(self): 60 | super().__init__() 61 | self.add_module("Linear", nn.Linear(784, 200)) 62 | self.add_module("ReLU", nn.ReLU(inplace=True)) 63 | self.add_module("Head", nn.Linear(200, 1)) 64 | 65 | 66 | def make_mnist_dataloaders(root, train_batch_size, test_batch_size): 67 | train_loaders = [] 68 | test_loaders = [] 69 | 70 | for train, loaders, bs in zip((True, False), (train_loaders, test_loaders), (train_batch_size, test_batch_size)): 71 | mnist = MNIST(os.path.join(root, "mnist"), train=train, download=True) 72 | x = mnist.data.flatten(1) / 255. 
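The Split-MNIST construction in `make_mnist_dataloaders` turns the 10-way problem into five binary tasks: task `i` keeps digits `2i` and `2i + 1` and relabels them to 0 and 1, stored as floats for the `Bernoulli` likelihood. A plain-Python sketch of the same index logic, using a hypothetical `split_mnist_tasks` helper on a list of integer labels instead of tensors:

```python
def split_mnist_tasks(labels, num_tasks=5):
    """Group example indices into binary tasks {0, 1}, {2, 3}, ..., {8, 9}.

    Task i keeps labels y with 2i <= y < 2(i + 1) (the y.ge / y.lt mask) and
    remaps them to y - 2i, i.e. to 0.0 or 1.0.
    """
    tasks = []
    for i in range(num_tasks):
        idx = [k for k, y in enumerate(labels) if 2 * i <= y < 2 * (i + 1)]
        targets = [float(labels[k] - 2 * i) for k in idx]
        tasks.append((idx, targets))
    return tasks


print(split_mnist_tasks([0, 3, 2, 9, 1, 8])[1])  # ([1, 2], [1.0, 0.0])
```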
73 | y = mnist.targets 74 | for i in range(5): 75 | index = y.ge(i * 2) & y.lt((i + 1) * 2) 76 | loaders.append(data.DataLoader(data.TensorDataset(x[index], y[index].sub(2 * i).float().unsqueeze(-1)), 77 | bs, shuffle=True, pin_memory=USE_CUDA)) 78 | 79 | return train_loaders, test_loaders 80 | 81 | 82 | def make_cifar_dataloaders(root, train_batch_size, test_batch_size): 83 | train_loaders = [] 84 | test_loaders = [] 85 | 86 | c100_means = [] 87 | c100_sds = [] 88 | for train, loaders, bs in zip((True, False), (train_loaders, test_loaders), (train_batch_size, test_batch_size)): 89 | c10 = CIFAR10( 90 | os.path.join(root, "cifar10"), 91 | train=train, download=True, 92 | transform=tf.Compose([tf.ToTensor(), 93 | tf.Normalize(C10_MEAN, C10_SD)]) 94 | ) 95 | 96 | 97 | loaders.append(data.DataLoader(c10, bs, shuffle=train, pin_memory=USE_CUDA)) 98 | 99 | c100 = CIFAR100(os.path.join(root, "cifar100"), train=train, download=True) 100 | unnormalized_data = torch.from_numpy(c100.data).permute(0, 3, 1, 2).div(255.) # HWC uint8 arrays -> NCHW float tensors in [0, 1] 101 | targets = torch.tensor(c100.targets) 102 | 103 | for i in range(5): 104 | index = targets.ge(i * 10) & targets.lt((i + 1) * 10) 105 | 106 | unnormalized_data_i = unnormalized_data[index] 107 | if train: 108 | c100_means.append(unnormalized_data_i.mean((0, 2, 3), keepdim=True)) 109 | c100_sds.append(unnormalized_data_i.std((0, 2, 3), keepdim=True)) 110 | normalized_data_i = (unnormalized_data_i - c100_means[i]) / c100_sds[i] 111 | targets_i = targets[index] - i * 10 112 | 113 | dataset_i = data.TensorDataset(normalized_data_i, targets_i) 114 | loaders.append(data.DataLoader(dataset_i, bs, shuffle=train, pin_memory=USE_CUDA)) 115 | 116 | return train_loaders, test_loaders 117 | 118 | 119 | def main(root, dataset, inference, num_epochs=0): 120 | train_batch_size = 250 121 | test_batch_size = 1000 122 | 123 | if dataset == "cifar": 124 | net = ConvNet() 125 | obs = tyxe.likelihoods.Categorical(-1) 126 | train_loaders, test_loaders =
make_cifar_dataloaders(root, train_batch_size, test_batch_size) 127 | num_epochs = 60 if not num_epochs else num_epochs 128 | elif dataset == "mnist": 129 | net = FCNet() 130 | obs = tyxe.likelihoods.Bernoulli(-1, event_dim=1) 131 | train_loaders, test_loaders = make_mnist_dataloaders(root, train_batch_size, test_batch_size) 132 | num_epochs = 600 if not num_epochs else num_epochs 133 | else: 134 | raise RuntimeError("Unreachable") 135 | 136 | net.to(DEVICE) 137 | if inference == "mean-field": 138 | prior = tyxe.priors.IIDPrior(dist.Normal(torch.tensor(0., device=DEVICE), torch.tensor(1., device=DEVICE)), 139 | expose_all=False, hide_modules=[net.Head]) 140 | guide = functools.partial( 141 | tyxe.guides.AutoNormal, 142 | init_scale=1e-4, 143 | init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net, prefix="net") 144 | ) 145 | test_samples = 8 146 | elif inference == "ml": 147 | prior = tyxe.priors.IIDPrior(dist.Normal(0, 1), expose_all=False, hide_all=True) 148 | guide = None 149 | else: 150 | raise RuntimeError("Unreachable") 151 | bnn = tyxe.VariationalBNN(net, prior, obs, guide) 152 | n_tasks = len(train_loaders) 153 | test_errors = torch.ones(n_tasks, n_tasks) 154 | 155 | head_state_dicts = [] 156 | init_head_sd = copy.deepcopy(net.Head.state_dict()) 157 | for i, train_loader in enumerate(train_loaders, 1): 158 | elbos = [] 159 | net.Head.load_state_dict(init_head_sd) 160 | 161 | pbar = tqdm(total=num_epochs, unit="Epochs", postfix=f"Task {i}") 162 | 163 | def callback(_i, _ii, e): 164 | elbos.append(e / len(train_loader.sampler)) 165 | pbar.update() 166 | 167 | obs.dataset_size = len(train_loader.sampler) 168 | optim = pyro.optim.Adam({"lr": 1e-3}) 169 | with tyxe.poutine.local_reparameterization(): 170 | bnn.fit(train_loader, optim, num_epochs, device=DEVICE, callback=callback) 171 | 172 | pbar.close() 173 | 174 | head_state_dicts.append(copy.deepcopy(net.Head.state_dict())) 175 | for j, (test_loader, head_params) in enumerate(zip(test_loaders, 
head_state_dicts)): 176 | net.Head.load_state_dict(head_params) 177 | err = sum(bnn.evaluate(x.to(DEVICE), y.to(DEVICE), num_predictions=8)[0] for x, y in test_loader) 178 | test_errors[i-1, j] = err / len(test_loader.sampler) 179 | 180 | print("\t".join(["Error"] + [f"Task {j}" for j in range(1, i+1)])) 181 | print("\t" + "\t".join([f"{100 * e:.2f}%" for e in test_errors[i-1, :i]])) 182 | 183 | if inference == "mean-field": 184 | site_names = tyxe.util.pyro_sample_sites(bnn) 185 | bnn.update_prior(tyxe.priors.DictPrior(bnn.net_guide.get_detached_distributions(site_names))) 186 | 187 | 188 | if __name__ == '__main__': 189 | parser = argparse.ArgumentParser() 190 | parser.add_argument("--root", default=ROOT) 191 | parser.add_argument("--dataset", choices=["mnist", "cifar"], required=True) 192 | parser.add_argument("--inference", choices=["mean-field", "ml"], required=True) 193 | parser.add_argument("--num-epochs", default=0, required=False, type=int) 194 | main(**vars(parser.parse_args())) 195 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='tyxe', 5 | version='0.0.2', 6 | url='https://github.com/TyXe-BDL/TyXe', 7 | author='Hippolyt Ritter, Theofanis Karaletsos', 8 | author_email='j.ritter@cs.ucl.ac.uk', 9 | description='BNNs for PyTorch using Pyro.', 10 | packages=find_packages(), 11 | install_requires=[ 12 | 'torch == 1.12.0', 13 | 'torchvision == 0.13.0', 14 | 'pyro-ppl == 1.8.1' 15 | ], 16 | ) 17 | -------------------------------------------------------------------------------- /tests/test_bnn.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.data as data 6 | 7 | import pyro 8 | import pyro.distributions as dist 9 | import
pyro.infer.autoguide as ag 10 | from pyro.infer.mcmc import HMC 11 | 12 | 13 | import tyxe 14 | 15 | 16 | def bayesian_regression(n, d, weight_precision, noise_precision): 17 | x = torch.randn(n, d) 18 | w = weight_precision ** -0.5 * torch.randn(d, 1) 19 | y = x @ w + noise_precision ** -0.5 * torch.randn(n, 1) 20 | 21 | posterior_precision = noise_precision * x.t().mm(x) + weight_precision * torch.eye(d) 22 | posterior_mean = torch.cholesky_solve(noise_precision * x.t().mm(y), torch.linalg.cholesky(posterior_precision)) 23 | 24 | return x, y, w, posterior_precision, posterior_mean 25 | 26 | 27 | def get_linear_bnn(n, d, wp, np, guide, variational=True): 28 | l = nn.Linear(d, 1, bias=False) 29 | prior = tyxe.priors.IIDPrior(dist.Normal(0, wp ** -0.5)) 30 | likelihood = tyxe.likelihoods.HomoskedasticGaussian(n, precision=np) 31 | if variational: 32 | return tyxe.VariationalBNN(l, prior, likelihood, guide) 33 | else: 34 | return tyxe.MCMC_BNN(l, prior, likelihood, guide) 35 | 36 | def test_diagonal_svi(): 37 | torch.manual_seed(42) 38 | n, d, wp, np = 20, 2, 1, 100 39 | x, y, w, pp, pm = bayesian_regression(n, d, wp, np) 40 | bnn = get_linear_bnn(n, d, wp, np, partial(ag.AutoDiagonalNormal, init_scale=1e-2)) 41 | 42 | loader = data.DataLoader(data.TensorDataset(x, y), n // 2, shuffle=True) 43 | 44 | optim = torch.optim.Adam 45 | sched = pyro.optim.StepLR({"optimizer": optim, "optim_args": {"lr": 1e-1}, "step_size": 100}) 46 | bnn.fit(loader, sched, int(500), callback=lambda *args: sched.step()) 47 | 48 | vm = pyro.get_param_store()["net_guide.loc"].data.squeeze() 49 | vp = pyro.get_param_store()["net_guide.scale"].data.squeeze() 50 | 51 | assert torch.allclose(vm, pm.squeeze(), atol=1e-2) 52 | assert torch.allclose(vp, pp.diagonal().sqrt().reciprocal(), atol=1e-2) 53 | 54 | 55 | def test_multivariate_svi(): 56 | torch.manual_seed(42) 57 | n, d, wp, np = 20, 2, 1, 100 58 | x, y, w, pp, pm = bayesian_regression(n, d, wp, np) # These are unchanged by the upgrade 59 | 
bnn = get_linear_bnn(n, d, wp, np, partial(ag.AutoMultivariateNormal, init_scale=1e-2)) 60 | loader = data.DataLoader(data.TensorDataset(x, y), n // 2, shuffle=True) 61 | 62 | optim = torch.optim.Adam 63 | sched = pyro.optim.StepLR({"optimizer": optim, "optim_args": {"lr": 1e-1}, "step_size": 500}) 64 | bnn.fit(loader, sched, 2500, num_particles=4, callback=lambda *args: sched.step()) 65 | 66 | vm = pyro.get_param_store()["net_guide.loc"].data.squeeze() 67 | vsd = pyro.get_param_store()["net_guide.scale"].data 68 | vst = pyro.get_param_store()["net_guide.scale_tril"].data 69 | 70 | vs = vsd.unsqueeze(-1) * vst # full Cholesky factor: rows of the unit lower-triangular factor scaled by scale 71 | 72 | assert torch.allclose(vm, pm.squeeze(), atol=0.01) 73 | 74 | cov_prec_mm = vs.mm(vs.t()).mm(pp) 75 | 76 | assert torch.allclose(cov_prec_mm, torch.eye(d), atol=0.05) 77 | 78 | site_names = tyxe.util.pyro_sample_sites(bnn.net) 79 | assert "weight" in site_names 80 | 81 | samples = next(tyxe.util.named_pyro_samples(bnn.net)) 82 | assert "weight" in samples 83 | 84 | 85 | def test_hmc(): 86 | torch.manual_seed(42) 87 | n, d, wp, np = 20, 2, 1, 100 88 | x, y, w, pp, pm = bayesian_regression(n, d, wp, np) 89 | bnn = get_linear_bnn(n, d, wp, np, partial(HMC, step_size=1e-2, num_steps=10, target_accept_prob=0.7), 90 | variational=False) 91 | 92 | loader = data.DataLoader(data.TensorDataset(x, y), n // 2, shuffle=True) 93 | mcmc = bnn.fit(loader, num_samples=4000, warmup_steps=1000, disable_progbar=True).get_samples() 94 | w_mcmc = mcmc["net.weight"] 95 | 96 | 97 | w_mean = w_mcmc.mean(0) 98 | w_cov = w_mcmc.transpose(-2, -1).mul(w_mcmc).mean(0) - w_mean.t().mm(w_mean) 99 | 100 | assert torch.allclose(w_mean.squeeze(), pm.squeeze(), atol=1e-2) 101 | cov_prec_mm = w_cov @ pp 102 | assert torch.allclose(cov_prec_mm, torch.eye(d), atol=0.05) 103 | -------------------------------------------------------------------------------- /tests/test_guides.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from functools import
partial 4 | 5 | import torch 6 | 7 | import pyro 8 | import pyro.distributions as dist 9 | from pyro.infer import TraceMeanField_ELBO 10 | from pyro.nn import PyroModule, PyroSample 11 | 12 | import tyxe 13 | 14 | 15 | @pytest.fixture(autouse=True) 16 | def setup(): 17 | pyro.clear_param_store() 18 | 19 | 20 | def test_constant_kl(): 21 | model = lambda: pyro.sample("a", dist.Normal(0, 1.)) 22 | guide = tyxe.guides.AutoNormal(model) 23 | elbo = TraceMeanField_ELBO() 24 | assert elbo.loss(model, guide) == elbo.loss(model, guide) 25 | 26 | 27 | def test_auto_normal(): 28 | model = lambda: pyro.sample("a", dist.Normal(0., 1.)) 29 | guide = tyxe.guides.AutoNormal(model) 30 | tr = pyro.poutine.trace(guide).get_trace() 31 | fn = tr.nodes["a"]["fn"] 32 | assert isinstance(fn, dist.Normal) 33 | assert fn.scale.isclose(torch.tensor(0.1)) 34 | assert fn.loc.requires_grad 35 | assert fn.scale.requires_grad 36 | 37 | 38 | def test_auto_normal_constrained(): 39 | model = lambda: pyro.sample("a", dist.Gamma(1., 1.)) 40 | guide = tyxe.guides.AutoNormal(model) 41 | tr = pyro.poutine.trace(guide).get_trace() 42 | fn = tr.nodes["a"]["fn"] 43 | assert isinstance(fn, dist.TransformedDistribution) 44 | assert isinstance(fn.base_dist, dist.Normal) 45 | assert fn.base_dist.scale.isclose(torch.tensor(0.1)) 46 | assert fn.base_dist.loc.requires_grad 47 | assert fn.base_dist.scale.requires_grad 48 | 49 | 50 | def test_auto_normal_constant_loc(): 51 | model = lambda: pyro.sample("a", dist.Normal(0., 1.)) 52 | guide = tyxe.guides.AutoNormal(model, train_loc=False) 53 | guide() 54 | assert not guide.a.loc.requires_grad 55 | 56 | 57 | def test_auto_normal_constant_scale(): 58 | model = lambda: pyro.sample("a", dist.Normal(0., 1.)) 59 | guide = tyxe.guides.AutoNormal(model, train_scale=False) 60 | guide() 61 | assert not guide.a.scale.requires_grad 62 | 63 | 64 | def test_auto_normal_init_scale(): 65 | model = lambda: pyro.sample("a", dist.Normal(0., 1.)) 66 | guide = 
tyxe.guides.AutoNormal(model, init_scale=1e-2) 67 | guide() 68 | assert guide.a.scale.isclose(torch.tensor(1e-2)) 69 | 70 | 71 | def test_auto_normal_max_scale(): 72 | model = lambda: pyro.sample("a", dist.Normal(0., 1.)) 73 | guide = tyxe.guides.AutoNormal(model, init_scale=1e-2, max_guide_scale=1e-3) 74 | guide() 75 | assert guide.a.scale.isclose(torch.tensor(1e-3)) 76 | 77 | 78 | def test_auto_normal_detached_distributions(): 79 | model = lambda: pyro.sample("a", dist.Normal(0., 1.)) 80 | guide = tyxe.guides.AutoNormal(model) 81 | guide() 82 | fn = guide.get_detached_distributions()["a"] 83 | assert isinstance(fn, dist.Normal) 84 | assert fn.scale.isclose(torch.tensor(0.1)) 85 | assert not fn.loc.requires_grad 86 | assert not fn.scale.requires_grad 87 | 88 | 89 | def test_constant_init(): 90 | model = lambda: pyro.sample("a", dist.Normal(torch.ones(3), 1.)) 91 | guide = tyxe.guides.AutoNormal(model, init_loc_fn=partial(tyxe.guides.init_to_constant, c=2)) 92 | guide() 93 | assert guide.a.loc.eq(2).all() 94 | 95 | 96 | def test_zero_init(): 97 | model = lambda: pyro.sample("a", dist.Normal(torch.ones(3, 2), 1.)) 98 | guide = tyxe.guides.AutoNormal(model, init_loc_fn=tyxe.guides.init_to_zero) 99 | guide() 100 | assert guide.a.loc.eq(0).all() 101 | 102 | 103 | def test_sample_init(): 104 | model = lambda: pyro.sample("a", dist.Normal(torch.ones(10000), 1.).to_event()) 105 | guide = tyxe.guides.AutoNormal( 106 | model, init_loc_fn=partial(tyxe.guides.init_to_sample, distribution=dist.Normal(0, 1))) 107 | guide() 108 | assert guide.a.loc.mean().isclose(torch.tensor(0.), atol=3e-2).item() 109 | assert guide.a.loc.std().isclose(torch.tensor(1.), atol=3e-2).item() 110 | 111 | 112 | def test_xavier_init(): 113 | model = lambda: pyro.sample("a", dist.Normal(torch.ones(50, 150), 1.).to_event()) 114 | guide = tyxe.guides.AutoNormal(model, init_loc_fn=tyxe.guides.init_to_normal_kaiming) 115 | guide() 116 | assert guide.a.loc.std().isclose(torch.tensor(0.1), atol=3e-2).item() 
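# Sketch (hypothetical helper, not part of tyxe.guides): this function computes the
# fan-in based reference standard deviation that the init tests in this file
# compare against. It assumes a weight of shape (fan_out, fan_in, *kernel_dims) and
# uses std = gain * fan_in ** -0.5:
#   - gain=1 gives the Radford value checked in test_radford_init (144 ** -0.5);
#   - gain=2 gives the 0.2 expected by test_kaiming_init;
#   - gain=sqrt(2) gives the kaiming prior scale (2 / 3.) ** 0.5 that
#     test_priors.py checks for fan_in=3.
# The formula is inferred from those expected values. It is not read from the
# library's implementation.
import math


def _reference_fan_in_std(shape, gain=1.0):
    # fan_in counts input features times kernel elements, e.g. 3 * 3 * 3 = 27
    # for a conv weight of shape (8, 3, 3, 3)
    fan_in = math.prod(shape[1:])
    return gain * fan_in ** -0.5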
117 | 118 | 119 | def test_radford_init(): 120 | model = lambda: pyro.sample("a", dist.Normal(torch.ones(50, 144), 1.).to_event()) 121 | guide = tyxe.guides.AutoNormal(model, init_loc_fn=tyxe.guides.init_to_normal_kaiming) 122 | guide() 123 | assert guide.a.loc.std().isclose(torch.tensor(144 ** -0.5), atol=3e-2).item() 124 | 125 | 126 | def test_kaiming_init(): 127 | model = lambda: pyro.sample("a", dist.Normal(torch.ones(100, 100), 1.).to_event()) 128 | guide = tyxe.guides.AutoNormal(model, init_loc_fn=partial(tyxe.guides.init_to_normal_kaiming, gain=2.)) 129 | guide() 130 | assert guide.a.loc.std().isclose(torch.tensor(0.2), atol=3e-2).item() 131 | 132 | 133 | def test_pretrained_init(): 134 | model = lambda: pyro.sample("a", dist.Normal(torch.ones(5), 1.).to_event()) 135 | mean_init = torch.randn(5) 136 | guide = tyxe.guides.AutoNormal(model, init_loc_fn=tyxe.guides.PretrainedInitializer({"a": mean_init})) 137 | guide() 138 | assert guide.a.loc.eq(mean_init).all().item() 139 | 140 | 141 | def test_pretrained_from_net_init(): 142 | l = torch.nn.Linear(3, 2, bias=False) 143 | model = PyroModule[torch.nn.Linear](3, 2, bias=False) 144 | model.weight = PyroSample(dist.Normal(torch.zeros_like(model.weight), torch.ones_like(model.weight))) 145 | guide = tyxe.guides.AutoNormal(model, init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(l, prefix="")) 146 | guide(torch.randn(3)) 147 | assert guide.weight.loc.eq(l.weight).all().item() 148 | -------------------------------------------------------------------------------- /tests/test_likelihoods.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | import torch.distributions as dist 5 | 6 | import pyro 7 | 8 | 9 | from tyxe.likelihoods import * 10 | 11 | 12 | @pytest.mark.parametrize("event_dim,shape,scale", [ 13 | (0, (3, 2), 0.1), 14 | (1, (4, 3), 1.), 15 | (2, (5, 4, 3, 2), 3.) 
16 | ]) 17 | def test_hom_gaussian_log_likelihood(event_dim, shape, scale): 18 | lik = HomoskedasticGaussian(event_dim=event_dim, dataset_size=-1, scale=scale) 19 | predictions = torch.randn(shape) 20 | targets = torch.randn(shape) 21 | log_probs = dist.Normal(predictions, scale).log_prob(targets) 22 | if event_dim != 0: 23 | log_probs = log_probs.flatten(-event_dim).sum(-1) 24 | assert torch.allclose(log_probs, lik.log_likelihood(predictions, targets)) 25 | 26 | 27 | @pytest.mark.parametrize("event_dim,shape", [ 28 | (0, (3, 2)), 29 | (1, (4, 3)), 30 | (2, (5, 4, 3, 2)) 31 | ]) 32 | def test_het_gaussian_log_likelihood(event_dim, shape): 33 | lik = HeteroskedasticGaussian(event_dim=event_dim, dataset_size=-1, positive_scale=True) 34 | pred_means = torch.randn(shape) 35 | pred_scales = torch.rand(shape) 36 | predictions = torch.cat((pred_means, pred_scales), dim=-1) 37 | targets = torch.randn(shape) 38 | log_probs = dist.Normal(pred_means, pred_scales).log_prob(targets) 39 | if event_dim != 0: 40 | log_probs = log_probs.flatten(-event_dim).sum(-1) 41 | assert torch.allclose(log_probs, lik.log_likelihood(predictions, targets)) 42 | 43 | 44 | @pytest.mark.parametrize("logits", [True, False]) 45 | @pytest.mark.parametrize("event_dim,shape", [ 46 | (0, (3, 2)), 47 | (1, (4, 3)), 48 | (2, (5, 4, 3, 2)) 49 | ]) 50 | def test_bernoulli_log_likelihood(event_dim, shape, logits): 51 | lik = Bernoulli(event_dim=event_dim, dataset_size=-1, logit_predictions=logits) 52 | targets = torch.randint(2, size=shape).float() 53 | if logits: 54 | predictions = torch.randn(shape) 55 | d = dist.Bernoulli(logits=predictions) 56 | else: 57 | predictions = torch.rand(shape) 58 | d = dist.Bernoulli(probs=predictions) 59 | log_probs = d.log_prob(targets) 60 | if event_dim != 0: 61 | log_probs = log_probs.flatten(-event_dim).sum(-1) 62 | assert torch.allclose(log_probs, lik.log_likelihood(predictions, targets)) 63 | 64 | 65 | @pytest.mark.parametrize("logits", [True, False]) 66 | 
@pytest.mark.parametrize("event_dim,shape", [ 67 | (0, (3, 2, 5)), 68 | (1, (4, 3, 10)), 69 | (2, (5, 4, 3, 2, 7)) 70 | ]) 71 | def test_categorical_log_likelihood(event_dim, shape, logits): 72 | lik = Categorical(event_dim=event_dim, dataset_size=-1, logit_predictions=logits) 73 | targets = torch.randint(shape[-1], size=shape[:-1]) 74 | if logits: 75 | predictions = torch.randn(shape) 76 | d = dist.Categorical(logits=predictions) 77 | else: 78 | predictions = torch.randn(shape).softmax(-1) 79 | d = dist.Categorical(probs=predictions) 80 | log_probs = d.log_prob(targets) 81 | if event_dim != 0: 82 | log_probs = log_probs.flatten(-event_dim).sum(-1) 83 | assert torch.allclose(log_probs, lik.log_likelihood(predictions, targets)) 84 | 85 | 86 | @pytest.mark.parametrize("event_dim,shape", [ 87 | (0, (3, 2)), 88 | (1, (4, 3)), 89 | (2, (5, 4, 3, 2)) 90 | ]) 91 | def test_hom_gaussian_error(event_dim, shape): 92 | lik = HomoskedasticGaussian(event_dim=event_dim, dataset_size=-1, scale=1) 93 | predictions = torch.randn(shape) 94 | targets = torch.randn(shape) 95 | errors = (predictions - targets) ** 2 96 | if event_dim != 0: 97 | errors = errors.flatten(-event_dim).sum(-1) 98 | assert torch.allclose(errors, lik.error(predictions, targets)) 99 | 100 | 101 | @pytest.mark.parametrize("event_dim,shape", [ 102 | (0, (3, 2)), 103 | (1, (4, 3)), 104 | (2, (5, 4, 3, 2)) 105 | ]) 106 | def test_het_gaussian_error(event_dim, shape): 107 | lik = HeteroskedasticGaussian(event_dim=event_dim, dataset_size=-1, positive_scale=True) 108 | pred_means = torch.randn(shape) 109 | predictions = torch.cat((pred_means, torch.rand(shape)), dim=-1) 110 | targets = torch.randn(shape) 111 | errors = (pred_means - targets) ** 2 112 | if event_dim != 0: 113 | errors = errors.flatten(-event_dim).sum(-1) 114 | assert torch.allclose(errors, lik.error(predictions, targets)) 115 | 116 | 117 | @pytest.mark.parametrize("logits", [True, False]) 118 | @pytest.mark.parametrize("event_dim,shape", [ 119 | (0, (3, 
2)), 120 | (1, (4, 3)), 121 | (2, (5, 4, 3, 2)) 122 | ]) 123 | def test_bernoulli_error(event_dim, shape, logits): 124 | lik = Bernoulli(event_dim=event_dim, dataset_size=-1, logit_predictions=logits) 125 | targets = torch.randint(2, size=shape).bool() 126 | if logits: 127 | predictions = torch.randn(shape) 128 | hard_predictions = predictions.gt(0) 129 | else: 130 | predictions = torch.rand(shape) 131 | hard_predictions = predictions.gt(0.5) 132 | errors = hard_predictions.ne(targets).float() 133 | if event_dim != 0: 134 | errors = errors.flatten(-event_dim).sum(-1) 135 | assert torch.allclose(errors, lik.error(predictions, targets.float())) 136 | 137 | 138 | @pytest.mark.parametrize("logits", [True, False]) 139 | @pytest.mark.parametrize("event_dim,shape", [ 140 | (0, (3, 2, 5)), 141 | (1, (4, 3, 10)), 142 | (2, (5, 4, 3, 2, 7)) 143 | ]) 144 | def test_categorical_error(event_dim, shape, logits): 145 | lik = Categorical(event_dim=event_dim, dataset_size=-1, logit_predictions=logits) 146 | targets = torch.randint(shape[-1], size=shape[:-1]) 147 | if logits: 148 | predictions = torch.randn(shape) 149 | else: 150 | predictions = torch.randn(shape).softmax(-1) 151 | hard_predictions = predictions.argmax(-1) 152 | errors = hard_predictions.ne(targets).float() 153 | if event_dim != 0: 154 | errors = errors.flatten(-event_dim).sum(-1) 155 | assert torch.allclose(errors, lik.error(predictions, targets.float())) 156 | 157 | 158 | @pytest.mark.parametrize("agg_dim,scale", [(0, 0.1), (1, 1.), (2, 3.)]) 159 | def test_hom_gaussian_aggregate(agg_dim, scale): 160 | shape = (5, 4, 3) 161 | lik = HomoskedasticGaussian(event_dim=1, dataset_size=-1, scale=scale) 162 | predictions = torch.randn(shape) 163 | means, scales = lik.aggregate_predictions(predictions, agg_dim) 164 | true_means = predictions.mean(agg_dim) 165 | true_scale = predictions.var(agg_dim).add(scale ** 2).sqrt() 166 | 167 | assert torch.allclose(means, true_means) 168 | assert torch.allclose(scales, true_scale) 
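# Sketch (hypothetical helper, not part of tyxe.likelihoods): the aggregation rule
# that test_hom_gaussian_aggregate above encodes, written with the standard library
# for a single output. Averaging N Gaussian predictions with a shared observation
# scale gives the mean of the predicted means, and (law of total variance) a
# predictive scale of sqrt(var(means) + scale ** 2). statistics.variance uses the
# N - 1 denominator, matching the default unbiased Tensor.var used in the test.
import math
import statistics


def _aggregate_homoskedastic(means, scale):
    mean = statistics.fmean(means)
    # spread of the predicted means across forward passes (0 for a single pass)
    spread = statistics.variance(means) if len(means) > 1 else 0.0
    return mean, math.sqrt(spread + scale ** 2)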
169 | 170 | 171 | @pytest.mark.parametrize("agg_dim", [0, 1, 2]) 172 | def test_het_gaussian_aggregate(agg_dim): 173 | shape = (5, 4, 3) 174 | lik = HeteroskedasticGaussian(event_dim=1, dataset_size=-1, positive_scale=True) 175 | pred_means = torch.randn(shape) 176 | pred_scales = torch.rand(shape) 177 | pred_precisions = pred_scales.pow(-2) 178 | predictions = torch.cat((pred_means, pred_scales), dim=-1) 179 | means, scales = lik.aggregate_predictions(predictions, agg_dim).chunk(2, dim=-1) 180 | true_means = pred_means.mul(pred_precisions).sum(agg_dim) / pred_precisions.sum(agg_dim) 181 | true_scale = pred_means.var(agg_dim).add(pred_scales.pow(2).mean(agg_dim)).sqrt() 182 | 183 | assert torch.allclose(means, true_means) 184 | assert torch.allclose(scales, true_scale) 185 | 186 | 187 | @pytest.mark.parametrize("logits", [True, False]) 188 | @pytest.mark.parametrize("agg_dim", [0, 1, 2]) 189 | def test_bernoulli_aggregate(agg_dim, logits): 190 | shape = (5, 4, 3) 191 | lik = Bernoulli(event_dim=1, dataset_size=-1, logit_predictions=logits) 192 | if logits: 193 | predictions = torch.randn(shape) 194 | avg_probs = predictions.sigmoid().mean(agg_dim) 195 | agg_predictions = avg_probs.log() - avg_probs.mul(-1).log1p() 196 | else: 197 | predictions = torch.rand(shape) 198 | agg_predictions = predictions.mean(agg_dim) 199 | 200 | assert torch.allclose(agg_predictions, lik.aggregate_predictions(predictions, agg_dim)) 201 | 202 | 203 | @pytest.mark.parametrize("logits", [True, False]) 204 | @pytest.mark.parametrize("agg_dim", [0, 1, 2]) 205 | def test_categorical_aggregate(agg_dim, logits): 206 | shape = (5, 4, 3, 10) 207 | lik = Categorical(event_dim=1, dataset_size=-1, logit_predictions=logits) 208 | if logits: 209 | predictions = torch.randn(shape) 210 | agg_predictions = predictions.softmax(-1).mean(agg_dim).log() 211 | else: 212 | predictions = torch.randn(shape).softmax(-1) 213 | agg_predictions = predictions.mean(agg_dim) 214 | assert torch.allclose(agg_predictions, 
lik.aggregate_predictions(predictions, agg_dim)) 215 | 216 | 217 | def test_forward_batch(): 218 | shape = (4, 3) 219 | lik = Bernoulli(event_dim=1, dataset_size=10, logit_predictions=True) 220 | predictions = torch.randn(shape) 221 | tr = pyro.poutine.trace(lik).get_trace(predictions) 222 | assert tr.nodes["data"]["scale"] == 2.5 223 | 224 | 225 | def test_forward_single(): 226 | shape = (3,) 227 | lik = Bernoulli(event_dim=1, dataset_size=10, logit_predictions=True) 228 | predictions = torch.randn(shape) 229 | tr = pyro.poutine.trace(lik).get_trace(predictions) 230 | assert tr.nodes["data"]["scale"] == 10 231 | 232 | 233 | def test_hom_gaussian_dist(): pass 234 | -------------------------------------------------------------------------------- /tests/test_poutine.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import pyro.distributions as dist 7 | from pyro.nn import PyroModule, PyroSample 8 | from pyro.nn.module import to_pyro_module_ 9 | 10 | 11 | import tyxe 12 | 13 | 14 | def as_pyro_module(module): 15 | to_pyro_module_(module, recurse=True) 16 | for m in module.modules(): 17 | for n, p in list(m.named_parameters(recurse=False)): 18 | setattr(m, n, PyroSample(dist.Normal(torch.zeros_like(p), torch.ones_like(p)).to_event())) 19 | return module 20 | 21 | 22 | @pytest.mark.parametrize("reparameterization_ctxt", [tyxe.poutine.local_reparameterization, tyxe.poutine.flipout]) 23 | def test_different_outputs(reparameterization_ctxt): 24 | l = as_pyro_module(nn.Linear(3, 2)) 25 | x = torch.randn(1, 3).repeat(128, 1) 26 | with reparameterization_ctxt(): 27 | a = l(x) 28 | 29 | # compare the pairwise outputs by broadcasting, fail if any distinct pairs are equal 30 | assert not a.unsqueeze(-2).eq(a).all(-1).tril(diagonal=-1).any().item() 31 | 32 | 33 | @pytest.mark.parametrize("reparameterization_ctxt", [tyxe.poutine.local_reparameterization, 
tyxe.poutine.flipout]) 34 | def test_mean_std(reparameterization_ctxt): 35 | d = 8 36 | n_samples = int(2 ** (d + 1)) 37 | repeats = 1000 38 | 39 | x = torch.randn(d) 40 | 41 | l = PyroModule[nn.Linear](x.shape[0], 2) 42 | weight_mean = torch.randn_like(l.weight) 43 | weight_sd = torch.rand_like(l.weight) 44 | l.weight = PyroSample(dist.Normal(weight_mean, weight_sd).to_event()) 45 | bias_mean = torch.randn_like(l.bias) 46 | bias_sd = torch.rand_like(l.bias) 47 | l.bias = PyroSample(dist.Normal(bias_mean, bias_sd).to_event()) 48 | 49 | m = x @ weight_mean.t() + bias_mean 50 | s = torch.sqrt(x.pow(2) @ weight_sd.t().pow(2) + bias_sd.pow(2)) 51 | 52 | x = x.unsqueeze(0).repeat(n_samples, 1) 53 | a = [] 54 | for _ in range(repeats): 55 | with reparameterization_ctxt(): 56 | a.append(l(x)) 57 | a = torch.cat(a) 58 | 59 | assert torch.allclose(m, a.mean(0), atol=1e-2) 60 | assert torch.allclose(s, a.std(0), atol=1e-1) 61 | 62 | 63 | def test_two_parameterizations_raises(): 64 | l = as_pyro_module(nn.Linear(3,2)) 65 | x = torch.randn(8, 3) 66 | with pytest.raises(ValueError): 67 | with tyxe.poutine.local_reparameterization(), tyxe.poutine.flipout(): 68 | l(x) 69 | 70 | 71 | def test_multiple_reparameterizers_compatible(): 72 | net = as_pyro_module(nn.Sequential( 73 | nn.Conv2d(1, 2, 3), 74 | nn.Flatten(), 75 | nn.Linear(2, 3) 76 | )) 77 | x = torch.randn(1, 3, 3).repeat(8, 1, 1, 1) 78 | with tyxe.poutine.local_reparameterization(reparameterizable_functions="linear"),\ 79 | tyxe.poutine.flipout(reparameterizable_functions="conv2d"): 80 | a = net(x) 81 | assert not a.unsqueeze(-2).eq(a).all(-1).tril(diagonal=-1).any().item() 82 | 83 | 84 | def test_ignores_not_given_fn(): 85 | l = as_pyro_module(nn.Linear(3, 2)) 86 | x = torch.randn(1, 3).repeat(8, 1) 87 | with tyxe.poutine.local_reparameterization(reparameterizable_functions=["conv1d", "conv2d", "conv3d"]): 88 | a = l(x) 89 | assert a[0].eq(a).all().item() 90 | 91 | 92 | def test_flipout_no_batch_dim(): 93 | l = 
as_pyro_module(nn.Linear(3, 2)) 94 | with tyxe.poutine.flipout(): 95 | assert l(torch.randn(3)).shape == (2,) 96 | 97 | 98 | def test_flipout_multi_batch_dim(): 99 | l = as_pyro_module(nn.Linear(3, 2)) 100 | with tyxe.poutine.flipout(): 101 | assert l(torch.randn(5, 4, 3)).shape == (5, 4, 2) 102 | -------------------------------------------------------------------------------- /tests/test_priors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import pyro.distributions as dist 5 | from pyro.nn import PyroModule 6 | 7 | 8 | import tyxe 9 | 10 | 11 | def test_iid(): 12 | l = PyroModule[nn.Linear](3, 2, bias=False) 13 | prior = tyxe.priors.IIDPrior(dist.Normal(0, 1)) 14 | prior.apply_(l) 15 | p = l._pyro_samples["weight"] 16 | assert isinstance(p, dist.Independent) 17 | assert isinstance(p.base_dist, dist.Normal) 18 | assert p.base_dist.loc.allclose(torch.tensor(0.)) 19 | assert p.base_dist.scale.allclose(torch.tensor(1.)) 20 | 21 | 22 | def test_layerwise_normal_kaiming(): 23 | l = PyroModule[nn.Linear](3, 2, bias=False) 24 | prior = tyxe.priors.LayerwiseNormalPrior(method="kaiming") 25 | prior.apply_(l) 26 | p = l._pyro_samples["weight"] 27 | assert p.base_dist.scale.allclose(torch.tensor((2 / 3.) 
** 0.5)) 28 | 29 | 30 | def test_layerwise_normal_radford(): 31 | l = PyroModule[nn.Linear](3, 2, bias=False) 32 | prior = tyxe.priors.LayerwiseNormalPrior(method="radford") 33 | prior.apply_(l) 34 | p = l._pyro_samples["weight"] 35 | assert p.base_dist.scale.allclose(torch.tensor(3 ** -0.5)) 36 | 37 | 38 | def test_layerwise_normal_xavier(): 39 | l = PyroModule[nn.Linear](3, 2, bias=False) 40 | prior = tyxe.priors.LayerwiseNormalPrior(method="xavier") 41 | prior.apply_(l) 42 | p = l._pyro_samples["weight"] 43 | assert p.base_dist.scale.allclose(torch.tensor(0.8 ** 0.5)) 44 | 45 | 46 | def test_expose_all(): 47 | net = PyroModule[nn.Sequential](PyroModule[nn.Linear](4, 3), PyroModule[nn.Linear](3, 2)) 48 | tyxe.priors.IIDPrior(dist.Normal(0, 1), expose_all=True).apply_(net) 49 | assert "weight" in net[0]._pyro_samples 50 | assert "bias" in net[0]._pyro_samples 51 | assert "weight" in net[1]._pyro_samples 52 | assert "bias" in net[1]._pyro_samples 53 | 54 | 55 | def test_hide_all(): 56 | net = PyroModule[nn.Sequential](PyroModule[nn.Linear](4, 3), PyroModule[nn.Linear](3, 2)) 57 | tyxe.priors.IIDPrior(dist.Normal(0, 1), expose_all=False, hide_all=True).apply_(net) 58 | assert "weight" in net[0]._pyro_params 59 | assert "bias" in net[0]._pyro_params 60 | assert "weight" in net[1]._pyro_params 61 | assert "bias" in net[1]._pyro_params 62 | 63 | 64 | def test_expose_modules(): 65 | net = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2)) 66 | prior = tyxe.priors.IIDPrior(dist.Normal(0, 1), expose_all=False, expose_modules=[net[0]]) 67 | tyxe.util.to_pyro_module_(net) 68 | prior.apply_(net) 69 | assert "weight" in net[0]._pyro_samples 70 | assert "bias" in net[0]._pyro_samples 71 | assert "weight" in net[1]._pyro_params 72 | assert "bias" in net[1]._pyro_params 73 | 74 | 75 | def test_hide_modules(): 76 | net = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2)) 77 | prior = tyxe.priors.IIDPrior(dist.Normal(0, 1), expose_all=False, hide_modules=[net[0]]) 78 | 
tyxe.util.to_pyro_module_(net) 79 | prior.apply_(net) 80 | assert "weight" in net[0]._pyro_params 81 | assert "bias" in net[0]._pyro_params 82 | assert "weight" in net[1]._pyro_samples 83 | assert "bias" in net[1]._pyro_samples 84 | 85 | 86 | def test_expose_types(): 87 | net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(3, 2)) 88 | prior = tyxe.priors.IIDPrior(dist.Normal(0, 1), expose_all=False, expose_module_types=(nn.Conv2d,)) 89 | tyxe.util.to_pyro_module_(net) 90 | prior.apply_(net) 91 | assert "weight" in net[0]._pyro_samples 92 | assert "bias" in net[0]._pyro_samples 93 | assert "weight" in net[1]._pyro_params 94 | assert "bias" in net[1]._pyro_params 95 | 96 | 97 | def test_hide_types(): 98 | net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(3, 2)) 99 | prior = tyxe.priors.IIDPrior(dist.Normal(0, 1), expose_all=False, hide_module_types=(nn.Linear,)) 100 | tyxe.util.to_pyro_module_(net) 101 | prior.apply_(net) 102 | assert "weight" in net[0]._pyro_samples 103 | assert "bias" in net[0]._pyro_samples 104 | assert "weight" in net[1]._pyro_params 105 | assert "bias" in net[1]._pyro_params 106 | 107 | 108 | def test_expose_parameters(): 109 | net = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2)) 110 | prior = tyxe.priors.IIDPrior(dist.Normal(0, 1), expose_all=False, expose_parameters=["weight"]) 111 | tyxe.util.to_pyro_module_(net) 112 | prior.apply_(net) 113 | assert "weight" in net[0]._pyro_samples 114 | assert "bias" in net[0]._pyro_params 115 | assert "weight" in net[1]._pyro_samples 116 | assert "bias" in net[1]._pyro_params 117 | 118 | 119 | def test_hide_parameters(): 120 | net = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2)) 121 | prior = tyxe.priors.IIDPrior(dist.Normal(0, 1), expose_all=False, hide_parameters=["weight"]) 122 | tyxe.util.to_pyro_module_(net) 123 | prior.apply_(net) 124 | assert "weight" in net[0]._pyro_params 125 | assert "bias" in net[0]._pyro_samples 126 | assert "weight" in net[1]._pyro_params 127 | assert "bias" in 
net[1]._pyro_samples 128 | 129 | 130 | def test_expose(): 131 | net = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2)) 132 | prior = tyxe.priors.IIDPrior(dist.Normal(0, 1), expose_all=False, expose=["0.weight", "1.weight"]) 133 | tyxe.util.to_pyro_module_(net) 134 | prior.apply_(net) 135 | assert "weight" in net[0]._pyro_samples 136 | assert "bias" in net[0]._pyro_params 137 | assert "weight" in net[1]._pyro_samples 138 | assert "bias" in net[1]._pyro_params 139 | 140 | 141 | def test_hide(): 142 | net = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2)) 143 | prior = tyxe.priors.IIDPrior(dist.Normal(0, 1), expose_all=False, hide=["0.weight", "1.weight"]) 144 | tyxe.util.to_pyro_module_(net) 145 | prior.apply_(net) 146 | assert "weight" in net[0]._pyro_params 147 | assert "bias" in net[0]._pyro_samples 148 | assert "weight" in net[1]._pyro_params 149 | assert "bias" in net[1]._pyro_samples 150 | -------------------------------------------------------------------------------- /tyxe/__init__.py: -------------------------------------------------------------------------------- 1 | from . import guides 2 | from . import likelihoods 3 | from . import poutine 4 | from . import priors 5 | from . import util 6 | 7 | from .bnn import * 8 | -------------------------------------------------------------------------------- /tyxe/bnn.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import itertools 3 | from operator import itemgetter 4 | 5 | import torch 6 | 7 | import pyro.nn as pynn 8 | import pyro.poutine as poutine 9 | from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO, MCMC 10 | 11 | from . 
import util 12 | 13 | 14 | __all__ = ["PytorchBNN", "VariationalBNN", "MCMC_BNN"] 15 | 16 | 17 | def _empty_guide(*args, **kwargs): 18 | return {} 19 | 20 | 21 | def _as_tuple(x): 22 | if isinstance(x, (list, tuple)): 23 | return x 24 | return x, 25 | 26 | 27 | def _to(x, device): 28 | return map(lambda t: t.to(device) if device is not None else t, _as_tuple(x)) 29 | 30 | 31 | class _BNN(pynn.PyroModule): 32 | """BNN base class that takes an nn.Module, turns it into a PyroModule and applies a prior to it, i.e. replaces 33 | nn.Parameter attributes by PyroSamples according to the specification in the prior. The forward method wraps the 34 | forward pass of the net and samples weights from the prior distributions. 35 | 36 | :param nn.Module net: pytorch neural network to be turned into a BNN. 37 | :param tyxe.priors.Prior prior: prior object that specifies over which parameters we want uncertainty. 38 | :param str name: base name for the BNN PyroModule.""" 39 | 40 | def __init__(self, net, prior, name=""): 41 | super().__init__(name) 42 | self.net = net 43 | pynn.module.to_pyro_module_(self.net) 44 | self.prior = prior 45 | self.prior.apply_(self.net) 46 | 47 | def forward(self, *args, **kwargs): 48 | return self.net(*args, **kwargs) 49 | 50 | def update_prior(self, new_prior): 51 | """Updates the prior of the network, i.e. calls the new prior's update_ method on this BNN. 52 | 53 | :param tyxe.priors.Prior new_prior: Prior for replacing the previous prior, i.e. substituting the PyroSample 54 | attributes of the net.""" 55 | self.prior = new_prior 56 | self.prior.update_(self) 57 | 58 | 59 |
63 | 64 | :param callable guide_builder: callable that takes a probabilistic pyro function with sample statements and returns 65 | an object that helps with inference, i.e. a callable guide function that samples from an approximate posterior 66 | for variational BNNs or an MCMC kernel for MCMC-based BNNs. May be None for maximum likelihood inference if 67 | the prior leaves all parameters of the net as such.""" 68 | 69 | def __init__(self, net, prior, guide_builder=None, name=""): 70 | super().__init__(net, prior, name=name) 71 | self.net_guide = guide_builder(self.net) if guide_builder is not None else _empty_guide 72 | 73 | def guided_forward(self, *args, guide_tr=None, **kwargs): 74 | if guide_tr is None: 75 | guide_tr = poutine.trace(self.net_guide).get_trace(*args, **kwargs) 76 | return poutine.replay(self.net, trace=guide_tr)(*args, **kwargs) 77 | 78 | 79 | class PytorchBNN(GuidedBNN): 80 | """Low-level variational BNN class that can serve as a drop-in replacement for an nn.Module. 81 | 82 | :param bool closed_form_kl: whether to use TraceMeanField_ELBO or Trace_ELBO as the loss, i.e. calculate KL 83 | divergences in closed form or via a Monte Carlo approximate of the difference of log densities between 84 | variational posterior and prior.""" 85 | 86 | def __init__(self, net, prior, guide_builder=None, name="", closed_form_kl=True): 87 | super().__init__(net, prior, guide_builder=guide_builder, name=name) 88 | self.cached_output = None 89 | self.cached_kl_loss = None 90 | self._loss = TraceMeanField_ELBO() if closed_form_kl else Trace_ELBO() 91 | 92 | def named_pytorch_parameters(self, *input_data): 93 | """Equivalent of the named_parameters method of an nn.Module. Ensures that prior and guide are run once to 94 | initialize all pyro parameters. 
Those are then collected and returned via the trace poutine.""" 95 | model_trace = poutine.trace(self.net, param_only=True).get_trace(*input_data) 96 | guide_trace = poutine.trace(self.net_guide, param_only=True).get_trace(*input_data) 97 | for name, msg in itertools.chain(model_trace.nodes.items(), guide_trace.nodes.items()): 98 | yield name, msg["value"].unconstrained() 99 | 100 | def pytorch_parameters(self, *input_data): 101 | yield from map(itemgetter(1), self.named_pytorch_parameters(*input_data)) 102 | 103 | def cached_forward(self, *args, **kwargs): 104 | # cache the output of forward to make it effectful, so that we can access the output when running forward with 105 | # posterior rather than prior samples 106 | self.cached_output = super().forward(*args, **kwargs) 107 | return self.cached_output 108 | 109 | def forward(self, *args, **kwargs): 110 | self.cached_kl_loss = self._loss.differentiable_loss(self.cached_forward, self.net_guide, *args, **kwargs) 111 | return self.cached_output 112 | 113 | 114 | class _SupervisedBNN(GuidedBNN): 115 | """Base class for supervised BNNs that defines the interface of the predict method and implements 116 | evaluate. Agnostic to the kind of inference performed.
117 | 118 | :param tyxe.likelihoods.Likelihood likelihood: Likelihood object that implements a forward method including 119 | a pyro.sample statement for labelled data given neural network predictions and implements logic for aggregating 120 | multiple predictions and evaluating them.""" 121 | 122 | def __init__(self, net, prior, likelihood, net_guide_builder=None, name=""): 123 | super().__init__(net, prior, net_guide_builder, name=name) 124 | self.likelihood = likelihood 125 | 126 | def model(self, x, obs=None): 127 | predictions = self(*_as_tuple(x)) 128 | self.likelihood(predictions, obs) 129 | return predictions 130 | 131 | def evaluate(self, input_data, y, num_predictions=1, aggregate=True, reduction="sum"): 132 | """Utility method for evaluation. Calculates a likelihood-dependent error measure, e.g. squared errors or 133 | misclassifications, and the log likelihood of the predictions. Returns both as a tuple (error, log_likelihood). 134 | 135 | :param input_data: inputs to the neural net, either a single object, e.g. a torch.Tensor, or a tuple/list of them. 136 | :param y: observations, e.g. class labels. 137 | :param int num_predictions: number of forward passes. 138 | :param bool aggregate: whether to aggregate the outputs of the forward passes before evaluating. 139 | :param str reduction: "sum", "mean" or "none". How to process the tensors of errors and log likelihoods. "sum" adds them up, 140 | "mean" averages them and "none" simply returns the tensors.""" 141 | predictions = self.predict(*_as_tuple(input_data), num_predictions=num_predictions, aggregate=aggregate) 142 | error = self.likelihood.error(predictions, y, reduction=reduction) 143 | ll = self.likelihood.log_likelihood(predictions, y, reduction=reduction) 144 | return error, ll 145 | 146 | def predict(self, *input_data, num_predictions=1, aggregate=True): 147 | """Makes predictions on the input data. 148 | 149 | :param input_data: inputs to the neural net, e.g.
torch.Tensors 150 | :param int num_predictions: number of forward passes through the net 151 | :param bool aggregate: whether to aggregate the predictions depending on the likelihood, e.g. averaging them.""" 152 | raise NotImplementedError 153 | 154 | 155 | class VariationalBNN(_SupervisedBNN): 156 | """Variational BNN class for supervised problems. Requires a likelihood that describes the data noise and an 157 | optional guide builder for it should it contain any variables that need to be inferred. Provides high-level utility 158 | methods such as fit, predict and evaluate. 159 | 160 | :param callable net_guide_builder: pyro.infer.autoguide.AutoCallable style class that, given a pyro function, 161 | constructs a variational posterior that samples the same unobserved sites from distributions with learnable 162 | parameters. 163 | :param callable likelihood_guide_builder: optional callable that constructs a guide for the likelihood if it 164 | contains any unknown variables, such as the precision/scale of a Gaussian.""" 165 | def __init__(self, net, prior, likelihood, net_guide_builder=None, likelihood_guide_builder=None, name=""): 166 | super().__init__(net, prior, likelihood, net_guide_builder, name=name) 167 | weight_sample_sites = list(util.pyro_sample_sites(self.net)) 168 | if likelihood_guide_builder is not None: 169 | self.likelihood_guide = likelihood_guide_builder(poutine.block( 170 | self.model, hide=weight_sample_sites + [self.likelihood.data_name])) 171 | else: 172 | self.likelihood_guide = _empty_guide 173 | 174 | def guide(self, x, obs=None): 175 | result = self.net_guide(*_as_tuple(x)) or {} 176 | result.update(self.likelihood_guide(*_as_tuple(x), obs) or {}) 177 | return result 178 | 179 | def fit(self, data_loader, optim, num_epochs, callback=None, num_particles=1, closed_form_kl=True, device=None): 180 | """Optimizes the variational parameters on data from data_loader using optim for num_epochs.
181 | 182 | :param Iterable data_loader: iterable over batches of data, e.g. a torch.utils.data.DataLoader. Assumes that 183 | each element consists of a length-two tuple or list, with the first element either containing a single 184 | object or a list of objects, e.g. torch.Tensors, that are the inputs to the neural network. The second 185 | element is a single torch.Tensor, e.g. of class labels. 186 | :param optim: pyro optimizer to be used for constructing an SVI object, e.g. pyro.optim.Adam({"lr": 1e-3}). 187 | :param int num_epochs: number of passes over data_loader. 188 | :param callable callback: optional function to invoke after every training epoch. Receives the BNN object, 189 | the epoch number and the average loss over the epoch, i.e. the negative ELBO as returned by SVI.step. May return True to terminate 190 | optimization before num_epochs, e.g. if it finds that a validation log likelihood saturates. 191 | :param int num_particles: number of MC samples for estimating the ELBO. 192 | :param bool closed_form_kl: whether to use TraceMeanField_ELBO or Trace_ELBO, i.e. calculate KL divergence 193 | between approximate posterior and prior in closed form or via a Monte Carlo estimate. 194 | :param torch.device device: optional device to send the data to. 195 | """ 196 | old_training_state = self.net.training 197 | self.net.train(True) 198 | 199 | loss = TraceMeanField_ELBO(num_particles) if closed_form_kl else Trace_ELBO(num_particles) 200 | svi = SVI(self.model, self.guide, optim, loss=loss) 201 | 202 | for i in range(num_epochs): 203 | elbo = 0.
204 | num_batch = 1 205 | for num_batch, (input_data, observation_data) in enumerate(iter(data_loader), 1): 206 | elbo += svi.step(tuple(_to(input_data, device)), tuple(_to(observation_data, device))[0]) 207 | 208 | # the callback can stop training by returning True 209 | if callback is not None and callback(self, i, elbo / num_batch): 210 | break 211 | 212 | self.net.train(old_training_state) 213 | return svi 214 | 215 | def predict(self, *input_data, num_predictions=1, aggregate=True, guide_traces=None): 216 | if guide_traces is None: 217 | guide_traces = [None] * num_predictions 218 | 219 | preds = [] 220 | with torch.autograd.no_grad(): 221 | for trace in guide_traces: 222 | preds.append(self.guided_forward(*input_data, guide_tr=trace)) 223 | predictions = torch.stack(preds) 224 | return self.likelihood.aggregate_predictions(predictions) if aggregate else predictions 225 | 226 | 227 | # TODO inherit from _SupervisedBNN to unify the class hierarchy. This will require changing the GuidedBNN base class to 228 | # construct the guide on top of self.model rather than self.net (model of GuidedBNN could just call the net and 229 | # SupervisedBNN adds the likelihood on top) and consequently removing the likelihood_guide_builder parameter for 230 | # the VariationalBNN class. This will however require hiding the likelihood.data site from the guide_builder in the 231 | # base class. 232 | class MCMC_BNN(_BNN): 233 | """Supervised BNN class with an interface to pyro's MCMC that is unified with the VariationalBNN class. 234 | 235 | :param callable kernel_builder: function or class that returns an object that will be accepted as a kernel by 236 | pyro.infer.mcmc.MCMC, e.g. pyro.infer.mcmc.HMC or NUTS. Will be called with the entire model, i.e. it will
also 237 | infer variables in the likelihood.""" 238 | 239 | def __init__(self, net, prior, likelihood, kernel_builder, name=""): 240 | super().__init__(net, prior, name=name) 241 | self.likelihood = likelihood 242 | self.kernel = kernel_builder(self.model) 243 | self._mcmc = None 244 | 245 | def model(self, x, obs=None): 246 | predictions = self(*_as_tuple(x)) 247 | self.likelihood(predictions, obs) 248 | return predictions 249 | 250 | def fit(self, data_loader, num_samples, device=None, batch_data=False, **mcmc_kwargs): 251 | """Runs MCMC on the data from data_loader using the kernel that was used to instantiate the class. 252 | 253 | :param data_loader: iterable over mini-batches or a single full batch of data. If batch_data is False, it is treated 254 | like the data_loader of VariationalBNN and all network inputs and targets are concatenated via torch.cat. Otherwise 255 | it must be a tuple of a single network input or a list of network inputs and a tensor for the targets. 256 | :param int num_samples: number of MCMC samples to draw. 257 | :param torch.device device: optional device to send the data to. Only used if batch_data is False. 258 | :param bool batch_data: if True, data_loader is treated as a single full batch of data, otherwise as an iterable over mini-batches.
259 | :param dict mcmc_kwargs: keyword arguments for initializing the pyro.infer.mcmc.MCMC object.""" 260 | if batch_data: 261 | input_data, observation_data = data_loader 262 | else: 263 | input_data_lists = defaultdict(list) 264 | observation_data_list = [] 265 | for in_data, obs_data in iter(data_loader): 266 | for i, data in enumerate(_as_tuple(in_data)): 267 | input_data_lists[i].append(data.to(device)) 268 | observation_data_list.append(obs_data.to(device)) 269 | input_data = tuple(torch.cat(input_data_lists[i]) for i in range(len(input_data_lists))) 270 | observation_data = torch.cat(observation_data_list) 271 | self._mcmc = MCMC(self.kernel, num_samples, **mcmc_kwargs) 272 | self._mcmc.run(input_data, observation_data) 273 | 274 | return self._mcmc 275 | 276 | def predict(self, *input_data, num_predictions=1, aggregate=True): 277 | if self._mcmc is None: 278 | raise RuntimeError("Call .fit to run MCMC and obtain samples from the posterior first.") 279 | 280 | preds = [] 281 | weight_samples = self._mcmc.get_samples(num_samples=num_predictions) 282 | with torch.no_grad(): 283 | for i in range(num_predictions): 284 | weights = {name: sample[i] for name, sample in weight_samples.items()} 285 | preds.append(poutine.condition(self, weights)(*input_data)) 286 | predictions = torch.stack(preds) 287 | return self.likelihood.aggregate_predictions(predictions) if aggregate else predictions 288 | -------------------------------------------------------------------------------- /tyxe/guides.py: -------------------------------------------------------------------------------- 1 | from contextlib import ExitStack 2 | import numbers 3 | 4 | import torch 5 | from torch.distributions import biject_to, transform_to 6 | 7 | import pyro 8 | import pyro.distributions as dist 9 | from pyro.distributions import constraints 10 | import pyro.nn as pynn 11 | import pyro.infer.autoguide as ag 12 | import pyro.infer.autoguide.initialization as ag_init 13 | import pyro.util as pyutil 14 
| 15 | 16 | from . import util 17 | 18 | 19 | def _get_base_dist(distribution): 20 | while isinstance(distribution, dist.Independent): 21 | distribution = distribution.base_dist 22 | return distribution 23 | 24 | 25 | def init_to_constant(site, c): 26 | """Helper function to set site value to a constant value.""" 27 | site_fn = site["fn"] 28 | value = torch.full_like(site_fn.sample(), c) 29 | if hasattr(site_fn, "_validate_sample"): 30 | site_fn._validate_sample(value) 31 | return value 32 | 33 | 34 | def init_to_zero(site): 35 | """Helper function to set site value to 0.""" 36 | return init_to_constant(site, 0.) 37 | 38 | 39 | def init_to_sample(site, distribution): 40 | """Helper function to set site value to a sample from some given distribution.""" 41 | value = distribution.expand(site["fn"].event_shape).sample().detach() 42 | t = transform_to(site["fn"].support) 43 | return t(value) 44 | 45 | 46 | def init_to_normal(site, loc=0., std="xavier", gain=1.): 47 | """Helper function to set site value to a sample from a normal distribution with variance according to 48 | xavier/kaiming/radford neural network weight initialization methods.""" 49 | if isinstance(std, str): 50 | std = util.calculate_prior_std(std, site["fn"].sample(), gain=gain) 51 | return init_to_sample(site, dist.Normal(loc, std)) 52 | 53 | 54 | def init_to_normal_xavier(site): 55 | return init_to_normal(site, std="xavier") 56 | 57 | 58 | def init_to_normal_radford(site): 59 | return init_to_normal(site, std="radford") 60 | 61 | 62 | def init_to_normal_kaiming(site, gain=1.): 63 | return init_to_normal(site, std="kaiming", gain=gain) 64 | 65 | 66 | class PretrainedInitializer: 67 | """Utility class for setting the values of a site to known constants, e.g. from a trained neural network. 
68 | 69 | :param dict values: dictionary of parameter values, mapping names of sites to tensors""" 70 | 71 | def __init__(self, values): 72 | self.values = values 73 | 74 | def __call__(self, site): 75 | return self.values[site["name"]] 76 | 77 | @classmethod 78 | def from_net(cls, net, prefix="net"): 79 | """Alternative init method for instantiating the class from the parameter values of an nn.Module. 80 | 81 | :param torch.nn.Module net: module to extract parameters from 82 | :param str prefix: prefix to pass to the module's `named_parameters` method, so that the resulting names match the site names 83 | 84 | :rtype: PretrainedInitializer 85 | """ 86 | values = {} 87 | for name, parameter in net.named_parameters(prefix): 88 | values[name] = parameter.data.clone() 89 | return cls(values) 90 | 91 | 92 | class AutoNormal(ag.AutoGuide): 93 | """Variant of pyro.infer.autoguide.AutoNormal. Samples sites from TransformedDistribution objects of normal 94 | distributions to allow for calculating KL divergences in closed form. Further makes training the means or standard deviations 95 | optional and allows for capping the standard deviations at some upper limit. Provides a helper function for 96 | returning a subset of all site distributions with the parameters detached to be used as priors in variational 97 | continual learning. 98 | 99 | :param module: PyroModule or pyro model to perform inference in. 100 | :param callable init_loc_fn: function that sets the means of variational distributions of each site. 101 | :param init_scale: initial standard deviation of the variational distributions. May be a float, a string naming an initialization scheme for util.calculate_prior_std (e.g. "xavier") or a dict mapping site names to tensors. 102 | :param bool train_loc: Whether the variational means should be learnable. 103 | :param bool train_scale: Whether the variational standard deviations should be learnable.
104 | :param float max_guide_scale: Optional upper limit on the variational standard deviations.""" 105 | 106 | def __init__(self, module, init_loc_fn=ag_init.init_to_median, init_scale=1e-1, train_loc=True, train_scale=True, 107 | max_guide_scale=None): 108 | module = ag_init.InitMessenger(init_loc_fn)(module) 109 | self.init_scale = init_scale 110 | self.train_loc = train_loc 111 | self.train_scale = train_scale 112 | self.max_guide_scale = max_guide_scale 113 | super().__init__(module) 114 | 115 | def _setup_prototype(self, *args, **kwargs): 116 | super()._setup_prototype(*args, **kwargs) 117 | 118 | for name, site in self.prototype_trace.iter_stochastic_nodes(): 119 | constrained_value = site["value"] 120 | unconstrained_value = biject_to(site["fn"].support).inv(constrained_value) 121 | if self.train_loc: 122 | unconstrained_value = pynn.PyroParam(unconstrained_value) 123 | ag.guides.deep_setattr(self, name + ".loc", unconstrained_value) 124 | if isinstance(self.init_scale, numbers.Real): 125 | scale_value = torch.full_like(site["value"], self.init_scale) 126 | elif isinstance(self.init_scale, str): 127 | scale_value = torch.full_like(site["value"], util.calculate_prior_std(self.init_scale, site["value"])) 128 | else: 129 | scale_value = self.init_scale[site["name"]] 130 | scale_constraint = constraints.positive if self.max_guide_scale is None else constraints.interval(0., self.max_guide_scale) 131 | scale = pynn.PyroParam(scale_value, constraint=scale_constraint) if self.train_scale else scale_value 132 | ag.guides.deep_setattr(self, name + ".scale", scale) 133 | 134 | def get_loc(self, site_name): 135 | return pyro.util.deep_getattr(self, site_name + ".loc") 136 | 137 | def get_scale(self, site_name): 138 | return pyro.util.deep_getattr(self, site_name + ".scale") 139 | 140 | def get_detached_distributions(self, site_names=None): 141 | """Returns a dictionary mapping the site names to their variational posteriors. 
All parameters are detached.""" 142 | if site_names is None: 143 | site_names = list(name for name, _ in self.prototype_trace.iter_stochastic_nodes()) 144 | 145 | result = dict() 146 | for name, site in self.prototype_trace.iter_stochastic_nodes(): 147 | if name not in site_names: 148 | continue 149 | loc = self.get_loc(name).detach().clone() 150 | scale = self.get_scale(name).detach().clone() 151 | fn = dist.Normal(loc, scale).to_event(max(loc.dim(), scale.dim())) 152 | base_fn = _get_base_dist(site["fn"]) 153 | if base_fn.support is not dist.constraints.real: 154 | fn = dist.TransformedDistribution(fn, biject_to(base_fn.support)) 155 | result[name] = fn 156 | return result 157 | 158 | def forward(self, *args, **kwargs): 159 | if self.prototype_trace is None: 160 | self._setup_prototype(*args, **kwargs) 161 | 162 | plates = self._create_plates() 163 | result = {} 164 | for name, site in self.prototype_trace.iter_stochastic_nodes(): 165 | with ExitStack() as stack: 166 | for frame in site["cond_indep_stack"]: 167 | if frame.vectorized: 168 | stack.enter_context(plates[frame.name]) 169 | loc = self.get_loc(name) 170 | scale = self.get_scale(name) 171 | fn = dist.Normal(loc, scale).to_event(site["fn"].event_dim) 172 | base_fn = _get_base_dist(site["fn"]) 173 | if base_fn.support is not dist.constraints.real: 174 | fn = dist.TransformedDistribution(fn, biject_to(base_fn.support)) 175 | result[name] = pyro.sample(name, fn) 176 | return result 177 | 178 | def median(self, *args, **kwargs): 179 | return {site["name"]: biject_to(site["fn"].support)(self.get_loc(site["name"])) 180 | for site in self.prototype_trace.iter_stochastic_nodes()} 181 | -------------------------------------------------------------------------------- /tyxe/likelihoods.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.distributions.utils as dist_utils 4 | import torch.distributions as torchdist 5 
| from torch.distributions import transforms 6 | 7 | import pyro 8 | import pyro.distributions as dist 9 | from pyro.nn import PyroModule, PyroSample 10 | 11 | 12 | __all__ = ["Bernoulli", "Categorical", "HomoskedasticGaussian", "HeteroskedasticGaussian"] 13 | 14 | 15 | def inverse_softplus(t): 16 | return t.expm1().log() 17 | 18 | 19 | def _reduce(tensor, reduction): 20 | if reduction == "none": 21 | return tensor 22 | elif reduction == "sum": 23 | return tensor.sum() 24 | elif reduction == "mean": 25 | return tensor.mean() 26 | else: 27 | raise ValueError("Invalid reduction: '{}'. Must be one of ('none', 'sum', 'mean').".format(reduction)) 28 | 29 | 30 | def _make_name(prefix, suffix): 31 | return ".".join([prefix, suffix]) if prefix else suffix 32 | 33 | 34 | class Likelihood(PyroModule): 35 | """Base class for BNN likelihoods. PyroModule wrapper around the most common distribution class for data noise. 36 | The forward method draws a pyro sample to be used in a model function given some predictions. log_likelihood and 37 | error are utility functions for evaluation. 38 | 39 | :param int dataset_size: Number of data points in the dataset for rescaling the log likelihood in the forward 40 | method when using mini-batches. May be None to disable rescaling. 41 | :param int event_dim: Number of dimensions of the predictive distribution to be interpreted as independent. 42 | :param str name: Base name of the PyroModule. 
43 | :param str data_name: Site name of the pyro sample for the data in forward.""" 44 | 45 | def __init__(self, dataset_size, event_dim=0, name="", data_name="data"): 46 | super().__init__(name) 47 | self.dataset_size = dataset_size 48 | self.event_dim = event_dim 49 | self._data_name = data_name 50 | 51 | @property 52 | def data_name(self): 53 | return self.var_name(self._data_name) 54 | 55 | def var_name(self, name): 56 | return _make_name(self._pyro_name, name) 57 | 58 | def forward(self, predictions, obs=None): 59 | """Executes a pyro sample statement to sample from the distribution corresponding to the likelihood class 60 | given some predictions. The value of the sample can be set to optional observations obs. 61 | 62 | :param torch.Tensor predictions: tensor of predictions. 63 | :param torch.Tensor obs: optional known values for the samples.""" 64 | predictive_distribution = self.predictive_distribution(predictions) 65 | if predictive_distribution.batch_shape: 66 | dataset_size = self.dataset_size if self.dataset_size is not None else len(predictions) 67 | with pyro.plate(self.data_name+"_plate", subsample=predictions, size=dataset_size): 68 | return pyro.sample(self.data_name, predictive_distribution, obs=obs) 69 | else: 70 | dataset_size = self.dataset_size if self.dataset_size is not None else 1 71 | with pyro.poutine.scale(scale=dataset_size): 72 | return pyro.sample(self.data_name, predictive_distribution, obs=obs) 73 | 74 | def log_likelihood(self, predictions, data, aggregation_dim=None, reduction="none"): 75 | if aggregation_dim is not None: 76 | predictions = self.aggregate_predictions(predictions, aggregation_dim) 77 | log_probs = self.predictive_distribution(predictions).log_prob(data) 78 | return _reduce(log_probs, reduction) 79 | 80 | def error(self, predictions, data, aggregation_dim=None, reduction="none"): 81 | if aggregation_dim is not None: 82 | predictions = self.aggregate_predictions(predictions, aggregation_dim) 83 | errors =
dist.util.sum_rightmost(self._calc_error(self._point_predictions(predictions), data), self.event_dim) 84 | return _reduce(errors, reduction) 85 | 86 | def sample(self, predictions, sample_shape=torch.Size()): 87 | return self.predictive_distribution(predictions).sample(sample_shape) 88 | 89 | def predictive_distribution(self, predictions): 90 | return self.batch_predictive_distribution(predictions).to_event(self.event_dim) 91 | 92 | def batch_predictive_distribution(self, predictions): 93 | """Returns a batched object of predictive distributions.""" 94 | raise NotImplementedError 95 | 96 | def aggregate_predictions(self, predictions, dim=0): 97 | """Aggregates multiple samples of predictions along dim, e.g. by averaging the means of a Gaussian or the class probabilities of a discrete likelihood.""" 98 | raise NotImplementedError 99 | 100 | def _point_predictions(self, predictions): 101 | """Point predictions without noise, e.g. hard class labels for Bernoulli or Categorical.""" 102 | raise NotImplementedError 103 | 104 | def _calc_error(self, point_predictions, data): 105 | """Typical error measure, e.g.
squared errors for Gaussians or number of mis-classifications for Categorical.""" 106 | raise NotImplementedError 107 | 108 | 109 | class _Discrete(Likelihood): 110 | """Discrete base class that unifies logic for Bernoulli and Categorical likelihood classes.""" 111 | 112 | def __init__(self, dataset_size, logit_predictions=True, event_dim=0, name="", data_name="data"): 113 | super().__init__(dataset_size, event_dim=event_dim, name=name, data_name=data_name) 114 | self.logit_predictions = logit_predictions 115 | 116 | def base_dist(self, probs=None, logits=None): 117 | raise NotImplementedError 118 | 119 | def batch_predictive_distribution(self, predictions): 120 | return self.base_dist(logits=predictions) if self.logit_predictions else self.base_dist(probs=predictions) 121 | 122 | def _calc_error(self, point_predictions, data): 123 | return point_predictions.ne(data).float() 124 | 125 | def aggregate_predictions(self, predictions, dim=0): 126 | probs = dist_utils.logits_to_probs(predictions, is_binary=self.is_binary) if self.logit_predictions else predictions 127 | avg_probs = probs.mean(dim) 128 | return dist_utils.probs_to_logits(avg_probs, is_binary=self.is_binary) if self.logit_predictions else avg_probs 129 | 130 | @property 131 | def is_binary(self): 132 | raise NotImplementedError 133 | 134 | 135 | class Bernoulli(_Discrete): 136 | """Bernoulli likelihood for binary observations.""" 137 | 138 | base_dist = dist.Bernoulli 139 | 140 | def _point_predictions(self, predictions): 141 | return predictions.gt(0.) 
if self.logit_predictions else predictions.gt(0.5) 142 | 143 | @property 144 | def is_binary(self): 145 | return True 146 | 147 | 148 | class Categorical(_Discrete): 149 | """Categorical likelihood for multi-class observations.""" 150 | 151 | base_dist = dist.Categorical 152 | 153 | def _point_predictions(self, predictions): 154 | return predictions.argmax(-1) 155 | 156 | @property 157 | def is_binary(self): 158 | return False 159 | 160 | 161 | class Gaussian(Likelihood): 162 | """Base class for Gaussian likelihoods.""" 163 | 164 | def __init__(self, dataset_size, event_dim=1, name="", data_name="data"): 165 | super().__init__(dataset_size, event_dim=event_dim, name=name, data_name=data_name) 166 | self.event_dim = event_dim 167 | 168 | def batch_predictive_distribution(self, predictions): 169 | loc, scale = self._predictive_loc_scale(predictions) 170 | return dist.Normal(loc, scale) 171 | 172 | def _point_predictions(self, predictions): 173 | return self._predictive_loc_scale(predictions)[0] 174 | 175 | def _calc_error(self, point_predictions, data): 176 | return point_predictions.sub(data).pow(2) 177 | 178 | def _predictive_loc_scale(self, predictions): 179 | raise NotImplementedError 180 | 181 | 182 | class HeteroskedasticGaussian(Gaussian): 183 | """Heteroskedastic Gaussian likelihood, i.e. Gaussian with data-dependent observation noise that is assumed to be 184 | part of the predictions. For d-dimensional observations, the predictions are expected to be 2d-dimensional, with the tensor 185 | of predictions being split in half along the final dimension, the first half corresponding to the predicted 186 | means and the second half to the standard deviations (which are passed through a softplus function unless 187 | positive_scale is True).
188 | 189 | :param bool positive_scale: Whether the predicted scales can be assumed to be positive.""" 190 | 191 | def __init__(self, dataset_size, positive_scale=False, event_dim=1, name="", data_name="data"): 192 | super().__init__(dataset_size, event_dim=event_dim, name=name, data_name=data_name) 193 | self.positive_scale = positive_scale 194 | 195 | def aggregate_predictions(self, predictions, dim=0): 196 | """Aggregates multiple predictions for the same data by averaging them according to their predicted noise. 197 | Means with lower predicted noise are given higher weight in the average. Predictive variance is the variance 198 | of the means plus the average predicted variance.""" 199 | loc, scale = self._predictive_loc_scale(predictions) 200 | precision = scale.pow(-2) 201 | total_precision = precision.sum(dim) 202 | agg_loc = loc.mul(precision).sum(dim).div(total_precision) 203 | agg_scale = precision.reciprocal().mean(dim).add(loc.var(dim)).sqrt() 204 | if not self.positive_scale: 205 | agg_scale = inverse_softplus(agg_scale) 206 | return torch.cat([agg_loc, agg_scale], -1) 207 | 208 | def _predictive_loc_scale(self, predictions): 209 | loc, pred_scale = predictions.chunk(2, dim=-1) 210 | scale = pred_scale if self.positive_scale else F.softplus(pred_scale) 211 | return loc, scale 212 | 213 | 214 | class HomoskedasticGaussian(Gaussian): 215 | """Homoskedastic Gaussian likelihood, i.e. a likelihood that assumes the noise to be data-independent. The scale 216 | or precision may be a distribution, i.e. unknown with a prior placed on it so that it can be inferred, or a 217 | PyroParam so that it can be learned. 218 | 219 | :param scale: tensor, parameter or prior distribution for the scale. Mutually exclusive with precision. 220 | :param precision: tensor, parameter or prior distribution for the precision.
Mutually exclusive with scale.""" 221 | 222 | def __init__(self, dataset_size, scale=None, precision=None, event_dim=1, name="", data_name="data"): 223 | super().__init__(dataset_size, event_dim=event_dim, name=name, data_name=data_name) 224 | if int(scale is None) + int(precision is None) != 1: 225 | raise ValueError("Exactly one of scale and precision must be specified") 226 | elif isinstance(scale, (dist.Distribution, torchdist.Distribution)): 227 | # if the scale or precision is a distribution, that is used as the prior for a PyroSample. I'm not 228 | # completely sure if it is a good idea to allow regular pytorch distributions, since they might not have the 229 | # correct event_dim, so perhaps it's safer to check e.g. if the batch shape is empty and raise an error 230 | # otherwise 231 | precision = PyroSample(prior=dist.TransformedDistribution(scale, transforms.PowerTransform(-2.))) 232 | scale = PyroSample(prior=scale) 233 | elif isinstance(precision, (dist.Distribution, torchdist.Distribution)): 234 | scale = PyroSample(prior=dist.TransformedDistribution(precision, transforms.PowerTransform(-0.5))) 235 | precision = PyroSample(prior=precision) 236 | else: 237 | # nothing to do, precision or scale is a number/tensor/parameter 238 | pass 239 | self._scale = scale 240 | self._precision = precision 241 | 242 | @property 243 | def scale(self): 244 | if self._scale is None: 245 | return self.precision ** -0.5 246 | else: 247 | return self._scale 248 | 249 | @property 250 | def precision(self): 251 | if self._precision is None: 252 | return self.scale ** -2 253 | else: 254 | return self._precision 255 | 256 | def aggregate_predictions(self, predictions, dim=0): 257 | """Aggregates multiple predictions for the same data by averaging them. 
Predictive variance is the variance 258 | of the predictions plus the known variance term.""" 259 | loc = predictions.mean(dim) 260 | scale = predictions.var(dim).add(self.scale ** 2).sqrt() 261 | return loc, scale 262 | 263 | def _predictive_loc_scale(self, predictions): 264 | if isinstance(predictions, tuple): 265 | loc, scale = predictions 266 | else: 267 | loc = predictions 268 | scale = self.scale 269 | return loc, scale 270 | -------------------------------------------------------------------------------- /tyxe/poutine/__init__.py: -------------------------------------------------------------------------------- 1 | from .handlers import * 2 | from .reparameterization_messengers import * 3 | 4 | -------------------------------------------------------------------------------- /tyxe/poutine/handlers.py: -------------------------------------------------------------------------------- 1 | from .reparameterization_messengers import LocalReparameterizationMessenger, FlipoutMessenger 2 | from .selective_messengers import SelectiveMaskMessenger, SelectiveScaleMessenger 3 | 4 | 5 | __all__ = [ 6 | "local_reparameterization", 7 | "flipout", 8 | "selective_mask", 9 | "selective_scale" 10 | ] 11 | 12 | 13 | # automate the following as in pyro.poutine.handlers 14 | def local_reparameterization(fn=None, reparameterizable_functions=None): 15 | msngr = LocalReparameterizationMessenger(reparameterizable_functions=reparameterizable_functions) 16 | return msngr(fn) if fn is not None else msngr 17 | 18 | 19 | def flipout(fn=None, reparameterizable_functions=None): 20 | msngr = FlipoutMessenger(reparameterizable_functions=reparameterizable_functions) 21 | return msngr(fn) if fn is not None else msngr 22 | 23 | 24 | def selective_mask(fn=None, mask=None, **block_kwargs): 25 | msngr = SelectiveMaskMessenger(mask, **block_kwargs) 26 | return msngr(fn) if fn is not None else msngr 27 | 28 | 29 | def selective_scale(fn=None, scale=1.0, **block_kwargs): 30 | msngr = 
SelectiveScaleMessenger(scale, **block_kwargs) 31 | return msngr(fn) if fn is not None else msngr 32 | -------------------------------------------------------------------------------- /tyxe/poutine/reparameterization_messengers.py: -------------------------------------------------------------------------------- 1 | from functools import update_wrapper 2 | from weakref import WeakValueDictionary, WeakKeyDictionary 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import pyro.distributions as dist 8 | from pyro.poutine.messenger import Messenger 9 | from pyro.poutine.runtime import effectful 10 | 11 | 12 | __all__ = [ 13 | "LocalReparameterizationMessenger", 14 | "FlipoutMessenger" 15 | ] 16 | 17 | 18 | def _get_base_dist(distribution): 19 | while isinstance(distribution, dist.Independent): 20 | distribution = distribution.base_dist 21 | return distribution 22 | 23 | 24 | def _is_reparameterizable(distribution): 25 | if distribution is None: 26 | # bias terms may be None, which does not prevent reparameterization 27 | return True 28 | return isinstance(_get_base_dist(distribution), (dist.Normal, dist.Delta)) 29 | 30 | 31 | def _get_loc_var(distribution): 32 | if distribution is None: 33 | return None, None 34 | if torch.is_tensor(distribution): 35 | # distribution might be a pyro param, which is equivalent to a delta distribution 36 | return distribution, torch.zeros_like(distribution) 37 | distribution = _get_base_dist(distribution) 38 | return distribution.mean, distribution.variance 39 | 40 | 41 | class _ReparameterizationMessenger(Messenger): 42 | """Base class for reparameterization of sampling sites where a transformation of a stochastic variable by a deterministic 43 | one allows for analytically calculating (or approximating) the distribution of the result and sampling 44 | the result instead of the original stochastic variable. See subclasses for examples.
45 | 46 | Within the context of this messenger, the functions listed in reparameterizable_functions (REPARAMETERIZABLE_FUNCTIONS by default) 47 | have their outputs sampled directly instead of their weight and bias inputs. This can reduce gradient noise. For now, 48 | reparameterization is limited to F.linear and F.conv1d/F.conv2d/F.conv3d, which are used by the corresponding nn.Linear and 49 | nn.Conv1d/nn.Conv2d/nn.Conv3d modules in pytorch.""" 50 | 51 | # TODO check if transposed convolutions could be added as well, might be useful for Bayesian conv VAEs 52 | REPARAMETERIZABLE_FUNCTIONS = ["linear", "conv1d", "conv2d", "conv3d"] 53 | 54 | def __init__(self, reparameterizable_functions=None): 55 | super().__init__() 56 | if reparameterizable_functions is None: 57 | reparameterizable_functions = self.REPARAMETERIZABLE_FUNCTIONS 58 | elif isinstance(reparameterizable_functions, str): 59 | reparameterizable_functions = [reparameterizable_functions] 60 | elif isinstance(reparameterizable_functions, (list, tuple)): 61 | reparameterizable_functions = list(reparameterizable_functions) 62 | else: 63 | raise ValueError(f"Unrecognized type for argument 'reparameterizable_functions'. Must be str, list, tuple or " 64 | f"None, but got '{reparameterizable_functions.__class__.__name__}'.") 65 | self.reparameterizable_functions = reparameterizable_functions 66 | 67 | def __enter__(self): 68 | # deps maps (the ids of) sampled tensors to their distribution objects to check if local reparameterization is possible. 69 | # I'm using a weakref dictionary here for memory efficiency -- a standard dict would create references to all 70 | # kinds of intermediate tensors, preventing them from being garbage collected. This would be a problem if the 71 | # Messenger is used as a context outside of a training loop. Ideally I would like to use a WeakKeyDictionary, 72 | # since I would expect that the samples from the distribution are much less likely to be kept around than the 73 | # distribution object itself.
I'm using id(tensor) as dictionary keys in order to avoid creating references to 74 | # the samples from the distributions. However, this still means that the self.deps dictionary will keep growing 75 | # if the distribution objects from the model/guide are kept around. 76 | self.deps = WeakValueDictionary() 77 | self.original_fns = [getattr(F, name) for name in self.reparameterizable_functions] 78 | self._make_reparameterizable_functions_effectful() 79 | return super().__enter__() 80 | 81 | def __exit__(self, exc_type, exc_val, exc_tb): 82 | self._reset_reparameterizable_functions() 83 | del self.deps 84 | del self.original_fns 85 | return super().__exit__(exc_type, exc_val, exc_tb) 86 | 87 | def _make_reparameterizable_functions_effectful(self): 88 | for name, fn in zip(self.reparameterizable_functions, self.original_fns): 89 | effectful_fn = update_wrapper(effectful(fn, type="reparameterizable"), fn) 90 | setattr(F, name, effectful_fn) 91 | 92 | def _reset_reparameterizable_functions(self): 93 | for name, fn in zip(self.reparameterizable_functions, self.original_fns): 94 | setattr(F, name, fn) 95 | 96 | def _pyro_post_sample(self, msg): 97 | if id(msg["value"]) not in self.deps: 98 | self.deps[id(msg["value"])] = msg["fn"] 99 | 100 | def _pyro_reparameterizable(self, msg): 101 | if msg["fn"].__name__ not in self.reparameterizable_functions: 102 | return 103 | 104 | if msg["done"]: 105 | raise ValueError(f"Trying to reparameterize a {msg['fn'].__name__} site that has already been processed. 
" 106 | f"Did you use multiple reparameterization messengers for the same function?") 107 | 108 | args = list(msg["args"]) 109 | kwargs = msg["kwargs"] 110 | x = kwargs.pop("input") if "input" in kwargs else args.pop(0) 111 | # if w is positional, x must have been positional too, so w is now the first element of args 112 | w = kwargs.pop("weight") if "weight" in kwargs else args.pop(0) 113 | # bias might be None, so check explicitly if it's in kwargs -- if it is positional, x and w 114 | # must have been positional arguments as well; if it is omitted entirely, treat it as None 115 | b = kwargs.pop("bias") if "bias" in kwargs else (args.pop(0) if args else None) 116 | if id(w) in self.deps: 117 | w_fn = self.deps[id(w)] 118 | b_fn = self.deps[id(b)] if b is not None else None 119 | if torch.is_tensor(x) and _is_reparameterizable(w_fn) and _is_reparameterizable(b_fn): 120 | msg["value"] = self._reparameterize(msg, x, w_fn, w, b_fn, b, *args, **kwargs) 121 | msg["done"] = True 122 | 123 | def _reparameterize(self, msg, x, w_fn, w, b_fn, b, *args, **kwargs): 124 | raise NotImplementedError 125 | 126 | 127 | class LocalReparameterizationMessenger(_ReparameterizationMessenger): 128 | """Implements local reparameterization: https://arxiv.org/abs/1506.02557""" 129 | 130 | def _reparameterize(self, msg, x, w_fn, w, b_fn, b, *args, **kwargs): 131 | w_loc, w_var = _get_loc_var(w_fn) 132 | b_loc, b_var = _get_loc_var(b_fn) 133 | loc = msg["fn"](x, w_loc, b_loc, *args, **kwargs) 134 | var = msg["fn"](x.pow(2), w_var, b_var, *args, **kwargs) 135 | # shift negative variances (from numerical error) to ~1e-6; detached, so gradients pass through, avoiding NaNs in sqrt 136 | var = var + var.lt(0).float().mul(var.abs() + 1e-6).detach() 137 | scale = var.sqrt() 138 | return dist.Normal(loc, scale).rsample() 139 | 140 | 141 | def _pad_right_like(tensor1, tensor2): 142 | while tensor1.ndim < tensor2.ndim: 143 | tensor1 = tensor1.unsqueeze(-1) 144 | return tensor1 145 | 146 | 147 | def 
_rand_signs(*args, **kwargs): 148 | return torch.rand(*args, **kwargs).gt(0.5).float().mul(2).sub(1) 149 | 150 | 151 | class
FlipoutMessenger(_ReparameterizationMessenger): 152 | """Implements flipout: https://arxiv.org/abs/1803.04386""" 153 | 154 | FUNCTION_RANKS = {"linear": 1, "conv1d": 2, "conv2d": 3, "conv3d": 4} 155 | 156 | def _reparameterize(self, msg, x, w_fn, w, b_fn, b, *args, **kwargs): 157 | fn = msg["fn"] 158 | w_loc, _ = _get_loc_var(w_fn) 159 | loc = fn(x, w_loc, None, *args, **kwargs) 160 | 161 | # x may be 1-d when F.linear is applied to a single unbatched datapoint; F.conv inputs are assumed to have a batch dim 162 | batch_shape = x.shape[:-self.FUNCTION_RANKS[fn.__name__]] if x.ndim > 1 else tuple() 163 | # w may be 1-d when F.linear has a scalar (0-d) output 164 | output_shape = (w_loc.shape[0],) if w_loc.ndim > 1 else tuple() 165 | input_shape = (w_loc.shape[1],) if w_loc.ndim > 1 else (w_loc.shape[0],) 166 | 167 | if not hasattr(w, "sign_input"): 168 | w.sign_input = _pad_right_like(_rand_signs(batch_shape + input_shape, device=loc.device), x) 169 | w.sign_output = _pad_right_like(_rand_signs(batch_shape + output_shape, device=loc.device), x) 170 | 171 | w_perturbation = w - w_loc 172 | perturbation = fn(x * w.sign_input, w_perturbation, None, *args, **kwargs) * w.sign_output 173 | 174 | output = loc + perturbation 175 | if b is not None: 176 | b_loc, b_var = _get_loc_var(b_fn) 177 | bias = _pad_right_like(dist.Normal(b_loc, b_var.sqrt()).rsample(batch_shape), output) 178 | output += bias 179 | return output 180 | -------------------------------------------------------------------------------- /tyxe/poutine/selective_messengers.py: -------------------------------------------------------------------------------- 1 | from pyro.poutine.mask_messenger import MaskMessenger 2 | from pyro.poutine.scale_messenger import ScaleMessenger 3 | from pyro.poutine.block_messenger import _make_default_hide_fn 4 | 5 | 6 | __all__ = [ 7 | "SelectiveMaskMessenger", 8 | "SelectiveScaleMessenger" 9 | ] 10 | 11 | 12 | class SelectiveMixin(object): 13 | 14 | def __init__(self, *args, 15 | 
hide_fn=None,
expose_fn=None, 16 | hide_all=True, expose_all=False, 17 | hide=None, expose=None, 18 | hide_types=None, expose_types=None, 19 | **kwargs): 20 | super().__init__(*args, **kwargs) 21 | if not (hide_fn is None or expose_fn is None): 22 | raise ValueError("Only specify one of hide_fn or expose_fn") 23 | if hide_fn is not None: 24 | self.hide_fn = hide_fn 25 | elif expose_fn is not None: 26 | self.hide_fn = lambda msg: not expose_fn(msg) 27 | else: 28 | self.hide_fn = _make_default_hide_fn(hide_all, expose_all, 29 | hide, expose, 30 | hide_types, expose_types) 31 | 32 | def _process_message(self, msg): 33 | if not self.hide_fn(msg): 34 | super()._process_message(msg) 35 | 36 | 37 | class SelectiveMaskMessenger(SelectiveMixin, MaskMessenger): pass 38 | class SelectiveScaleMessenger(SelectiveMixin, ScaleMessenger): pass 39 | -------------------------------------------------------------------------------- /tyxe/priors.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch.nn.init as nn_init 4 | 5 | import pyro.distributions as dist 6 | from pyro.nn.module import PyroSample, PyroParam 7 | 8 | 9 | from . 
import util 10 | 11 | 12 | def _make_expose_fn(hide_modules, expose_modules, hide_module_types, expose_module_types, 13 | hide_parameters, expose_parameters, hide, expose, expose_all=False): # default if no option below is set 14 | if expose_modules is None: 15 | expose_modules = [] 16 | else: 17 | expose_all = False 18 | 19 | if hide_modules is None: 20 | hide_modules = [] 21 | else: 22 | expose_all = True 23 | 24 | if expose_module_types is None: 25 | expose_module_types = tuple() 26 | else: 27 | expose_all = False 28 | 29 | if hide_module_types is None: 30 | hide_module_types = tuple() 31 | else: 32 | expose_all = True 33 | 34 | if expose_parameters is None: 35 | expose_parameters = [] 36 | else: 37 | expose_all = False 38 | 39 | if hide_parameters is None: 40 | hide_parameters = [] 41 | else: 42 | expose_all = True 43 | 44 | if expose is None: 45 | expose = [] 46 | else: 47 | expose_all = False 48 | 49 | if hide is None: 50 | hide = [] 51 | else: 52 | expose_all = True 53 | 54 | if not set(hide_modules).isdisjoint(set(expose_modules)): 55 | raise ValueError("Cannot hide and expose a module.") 56 | 57 | if not set(hide_parameters).isdisjoint(set(expose_parameters)): 58 | raise ValueError("Cannot hide and expose a parameter type.") 59 | 60 | if not set(hide).isdisjoint(set(expose)): 61 | raise ValueError("Cannot hide and expose a parameter.") 62 | 63 | def expose_fn(module, param_name): 64 | if param_name in hide: 65 | return False 66 | if param_name in expose: 67 | return True 68 | 69 | param_suffix = param_name.rsplit(".", 1)[-1] 70 | if param_suffix in hide_parameters: 71 | return False 72 | if param_suffix in expose_parameters: 73 | return True 74 | 75 | if isinstance(module, hide_module_types): 76 | return False 77 | if isinstance(module, expose_module_types): 78 | return True 79 | 80 | if module in hide_modules: 81 | return False 82 | if module in expose_modules: 83 | return True 84 | 85 | return expose_all 86 | 87 | return expose_fn 88 | 89 | 90 | class 
Prior(metaclass=ABCMeta): 91 | """Base class for TyXe's BNN
priors that helps with replacing nn.Parameter attributes on PyroModule objects 92 | with PyroSamples via its apply_ method or updating them via update_, and handles the logic for excluding some 93 | parameters from being given priors via the hide/expose arguments of the init method. Subclasses must implement 94 | a prior_dist method that returns a distribution object given a parameter name, module and nn.Parameter object.""" 95 | 96 | def __init__(self, hide_all=False, expose_all=True, 97 | hide_modules=None, expose_modules=None, 98 | hide_module_types=None, expose_module_types=None, 99 | hide_parameters=None, expose_parameters=None, 100 | hide=None, expose=None, 101 | hide_fn=None, expose_fn=None): 102 | """Hides/exposes parameter attributes from/to being replaced by PyroSamples. The options are: 103 | * all: hides/exposes all parameters. expose_all must be set to False to use the modules, module_types, parameters or hide/expose options. 104 | * modules: nn.Module objects that are part of the net being passed to apply_. 105 | * module_types: tuple of classes inheriting from nn.Module, e.g. nn.Linear. 106 | * parameters: list of parameter attribute names, e.g. "weight" for hiding/exposing the weight attribute of 107 | an nn.Linear module. 108 | * hide/expose: list of full parameter names, e.g. "0.weight" for an nn.Sequential net where the first layer is 109 | an nn.Conv or nn.Linear module that has a weight attribute.
110 | * hide_fn/expose_fn: function that returns True or False given a module and a full parameter name string.""" 111 | if hide_all: 112 | self.expose_fn = lambda module, name: False 113 | elif expose_fn is not None: 114 | self.expose_fn = expose_fn 115 | elif hide_fn is not None: 116 | self.expose_fn = lambda module, name: not hide_fn(module, name) 117 | elif expose_all: 118 | self.expose_fn = lambda module, name: True 119 | else: 120 | self.expose_fn = _make_expose_fn( 121 | hide_modules, expose_modules, hide_module_types, expose_module_types, 122 | hide_parameters, expose_parameters, hide, expose) 123 | 124 | def apply_(self, net): 125 | """Replaces all nn.Parameter attributes on a given PyroModule net according to the hide/expose logic and 126 | the class's prior_dist method.""" 127 | for module_name, module in net.named_modules(): 128 | for param_name, param in list(module.named_parameters(recurse=False)): 129 | full_name = module_name + "." + param_name 130 | if self.expose_fn(module, full_name): 131 | prior_dist = self.prior_dist(full_name, module, param).expand(param.shape).to_event(param.dim()) 132 | setattr(module, param_name, PyroSample(prior_dist)) 133 | else: 134 | setattr(module, param_name, PyroParam(param.data.detach())) 135 | 136 | def update_(self, net): 137 | """Replaces PyroSample attributes on a given PyroModule net according to the hide/expose logic and 138 | the class's prior_dist method.""" 139 | for module_name, module in net.named_modules(): 140 | for site_name, site in list(util.named_pyro_samples(module, recurse=False)): 141 | full_name = module_name + "." + site_name 142 | # See change in DictPrior as an alternative 143 | # if type(self)== DictPrior: 144 | # full_name = 'net.
' + full_name 145 | if self.expose_fn(module, full_name): 146 | prior_dist = self.prior_dist(full_name, module, site) 147 | setattr(module, site_name, PyroSample(prior_dist)) 148 | 149 | @abstractmethod 150 | def prior_dist(self, name, module, param): 151 | pass 152 | 153 | 154 | class IIDPrior(Prior): 155 | """Independent and identically distributed prior that is the same across all sites. Intended to be used with a 156 | univariate distribution that can be expanded to the shape of each site, e.g. dist.Normal.""" 157 | 158 | def __init__(self, distribution, *args, **kwargs): 159 | super().__init__(*args, **kwargs) 160 | self._distribution = distribution 161 | 162 | def prior_dist(self, name, module, param): 163 | return self._distribution 164 | 165 | 166 | class LayerwiseNormalPrior(Prior): 167 | """Normal prior with module-dependent variance to preserve the variance of an input passed through the layer. 168 | "radford" sets the variance to the inverse of the number of inputs, "kaiming" scales this by the square of a 169 | gain factor depending on the nonlinearity and "xavier" sets it to the squared gain times the inverse of the average 170 | of the number of inputs and outputs (the latter two correspond to weight initialization methods for deterministic neural networks).""" 171 | 172 | def __init__(self, method="radford", nonlinearity="relu", *args, **kwargs): 173 | super().__init__(*args, **kwargs) 174 | if method not in ("radford", "xavier", "kaiming"): 175 | raise ValueError(f"method must be one of ('radford', 'xavier', 'kaiming'), but is {method}") 176 | self.method = method 177 | self.nonlinearity = nonlinearity 178 | 179 | def prior_dist(self, name, module, param): 180 | module_nonl = self.nonlinearity if isinstance(self.nonlinearity, str) else self.nonlinearity.get(module) 181 | gain = nn_init.calculate_gain(module_nonl) if module_nonl is not None else 1.
182 | std = util.calculate_prior_std(self.method, param, gain) 183 | return dist.Normal(0., std) 184 | 185 | 186 | class DictPrior(Prior): 187 | """Prior given by a dictionary mapping parameter names, as in module.named_parameters(), to distribution 188 | objects.""" 189 | 190 | def __init__(self, prior_dict, *args, **kwargs): 191 | super().__init__(*args, **kwargs) 192 | self.prior_dict = prior_dict 193 | 194 | def prior_dist(self, name, module, param): 195 | return self.prior_dict[name] 196 | 197 | class LambdaPrior(Prior): 198 | """Prior given by a function fn(name, module, param) that returns a distribution, avoiding the need to subclass Prior.""" 199 | 200 | def __init__(self, fn, *args, **kwargs): 201 | super().__init__(*args, **kwargs) 202 | self.fn = fn 203 | 204 | def prior_dist(self, name, module, param): 205 | return self.fn(name, module, param) 206 | -------------------------------------------------------------------------------- /tyxe/util.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import copy 3 | from functools import reduce 4 | from operator import mul, itemgetter 5 | from warnings import warn 6 | 7 | import torch 8 | 9 | import pyro.util 10 | import pyro.infer.autoguide.guides 11 | import pyro.nn.module as pyromodule 12 | 13 | def deep_hasattr(obj, name): 14 | warn('deep_hasattr is deprecated.', DeprecationWarning, stacklevel=2) 15 | try: 16 | pyro.util.deep_getattr(obj, name) 17 | return True 18 | except AttributeError: 19 | return False 20 | 21 | 22 | def deep_setattr(obj, key, val): 23 | warn('deep_setattr is deprecated.', DeprecationWarning, stacklevel=2) 24 | return pyro.infer.autoguide.guides.deep_setattr(obj, key, val) 25 | 26 | def deep_getattr(obj, name): 27 | warn('deep_getattr is deprecated.', DeprecationWarning, stacklevel=2) 28 | return pyro.util.deep_getattr(obj, name) 29 | 30 | 31 | def to_pyro_module_(m, name="", recurse=True): 32 | """ 33 | 
Same as `pyro.nn.module.to_pyro_module_`
except that it also accepts a name argument and returns the modified 34 | module, following the PyTorch convention for in-place functions. 35 | """ 36 | if not isinstance(m, torch.nn.Module): 37 | raise TypeError("Expected an nn.Module instance but got a {}".format(type(m))) 38 | 39 | if isinstance(m, pyromodule.PyroModule): 40 | if recurse: 41 | for name, value in list(m._modules.items()): 42 | to_pyro_module_(value) 43 | setattr(m, name, value) 44 | return m 45 | 46 | # Change m's type in-place. 47 | m.__class__ = pyromodule.PyroModule[m.__class__] 48 | m._pyro_name = name 49 | m._pyro_context = pyromodule._Context() 50 | m._pyro_params = OrderedDict() 51 | m._pyro_samples = OrderedDict() 52 | 53 | # Reregister parameters and submodules. 54 | for name, value in list(m._parameters.items()): 55 | setattr(m, name, value) 56 | for name, value in list(m._modules.items()): 57 | if recurse: 58 | to_pyro_module_(value) 59 | setattr(m, name, value) 60 | 61 | return m 62 | 63 | 64 | def to_pyro_module(m, name="", recurse=True): 65 | return to_pyro_module_(copy.deepcopy(m), name, recurse) 66 | 67 | 68 | def named_pyro_samples(pyro_module, prefix='', recurse=True): 69 | yield from pyro_module._named_members(lambda module: module._pyro_samples.items(), prefix=prefix, recurse=recurse) 70 | 71 | 72 | def pyro_sample_sites(pyro_module, prefix='', recurse=True): 73 | yield from map(itemgetter(0), named_pyro_samples(pyro_module, prefix=prefix, recurse=recurse)) 74 | 75 | 76 | def prod(iterable, initial_value=1): 77 | return reduce(mul, iterable, initial_value) 78 | 79 | 80 | def fan_in_fan_out(weight): 81 | # this holds for linear and conv layers, but check e.g.
transposed conv 82 | fan_in = prod(weight.shape[1:]) 83 | fan_out = weight.shape[0] 84 | return fan_in, fan_out 85 | 86 | 87 | def calculate_prior_std(method, weight, gain=1., mode="fan_in"): 88 | fan_in, fan_out = fan_in_fan_out(weight) 89 | if method == "radford": 90 | std = fan_in ** -0.5 91 | elif method == "xavier": 92 | std = gain * (2 / (fan_in + fan_out)) ** 0.5 93 | elif method == "kaiming": 94 | fan = fan_in if mode == "fan_in" else fan_out 95 | std = gain * fan ** -0.5 96 | else: 97 | raise ValueError(f"Invalid method: '{method}'. Must be one of ('radford', 'xavier', 'kaiming').") 98 | return torch.tensor(std, device=weight.device) 99 | --------------------------------------------------------------------------------
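As a worked example of the arithmetic in `calculate_prior_std` above, the following sketch mirrors its three formulas in plain Python so the resulting standard deviations can be checked by hand. It is an illustration only: `prior_std` and the shape-tuple input are hypothetical stand-ins (the real function takes a weight tensor and returns a tensor), and the ReLU gain of sqrt(2) is the value `torch.nn.init.calculate_gain("relu")` returns.

```python
import math


def fan_in_fan_out(shape):
    # Mirrors tyxe.util.fan_in_fan_out on a plain shape tuple:
    # fan_in is the product of all but the first dimension, fan_out is the first dimension.
    fan_in = math.prod(shape[1:])
    fan_out = shape[0]
    return fan_in, fan_out


def prior_std(method, shape, gain=1.0, mode="fan_in"):
    # Hypothetical torch-free mirror of tyxe.util.calculate_prior_std.
    fan_in, fan_out = fan_in_fan_out(shape)
    if method == "radford":
        return fan_in ** -0.5
    if method == "xavier":
        return gain * (2 / (fan_in + fan_out)) ** 0.5
    if method == "kaiming":
        fan = fan_in if mode == "fan_in" else fan_out
        return gain * fan ** -0.5
    raise ValueError(f"Invalid method: '{method}'.")


relu_gain = math.sqrt(2.0)  # value of torch.nn.init.calculate_gain("relu")

# The weight of nn.Linear(100, 50) has shape (out_features, in_features) = (50, 100).
linear_shape = (50, 100)
print(prior_std("radford", linear_shape))             # 1/sqrt(100), i.e. 0.1
print(prior_std("xavier", linear_shape))              # sqrt(2/150), about 0.1155
print(prior_std("kaiming", linear_shape, relu_gain))  # sqrt(2)/10, about 0.1414

# The weight of nn.Conv2d(3, 16, kernel_size=3) has shape (16, 3, 3, 3), so fan_in = 3 * 3 * 3 = 27.
conv_shape = (16, 3, 3, 3)
print(prior_std("radford", conv_shape))               # 1/sqrt(27), about 0.1925
```

For a 1-d bias of shape (n,), `fan_in_fan_out` gives fan_in = 1 (the product over an empty shape), so the "radford" prior on biases has a standard deviation of 1.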
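`LocalReparameterizationMessenger._reparameterize` above calls the linear map twice: once on `x` with the weight and bias means, and once on `x.pow(2)` with their variances. This works because, when every weight and bias entry is an independent Gaussian, each output unit y_j = sum_i x_i W_ji + b_j is itself Gaussian with mean sum_i x_i mu_ji + b_loc_j and variance sum_i x_i^2 var_ji + b_var_j. The sketch below checks that identity in plain Python for a single input vector. The helper names (`linear`, `local_reparam_moments`, `sample_output_via_weights`) and the numbers are hypothetical, chosen only for illustration, and the comparison is a Monte Carlo estimate rather than TyXe code.

```python
import math
import random


def linear(x, w, b=None):
    # Pure-Python stand-in for F.linear on one input vector: y_j = sum_i x_i * w[j][i] + b[j]
    out = [sum(xi * wji for xi, wji in zip(x, row)) for row in w]
    if b is not None:
        out = [o + bj for o, bj in zip(out, b)]
    return out


def local_reparam_moments(x, w_loc, w_var, b_loc, b_var):
    # Output mean and variance under independent Gaussian weights and biases,
    # mirroring the two F.linear calls made by the messenger.
    loc = linear(x, w_loc, b_loc)
    var = linear([xi ** 2 for xi in x], w_var, b_var)
    return loc, var


def sample_output_via_weights(x, w_loc, w_var, b_loc, b_var, rng):
    # Baseline: sample every weight and bias entry, then apply the linear map.
    w = [[rng.gauss(m, math.sqrt(v)) for m, v in zip(ml, vl)] for ml, vl in zip(w_loc, w_var)]
    b = [rng.gauss(m, math.sqrt(v)) for m, v in zip(b_loc, b_var)]
    return linear(x, w, b)


x = [1.0, -2.0, 0.5]
w_loc = [[0.1, 0.2, -0.3], [0.0, -0.5, 0.4]]
w_var = [[0.01, 0.04, 0.09], [0.25, 0.01, 0.16]]
b_loc = [0.1, -0.1]
b_var = [0.01, 0.04]

loc, var = local_reparam_moments(x, w_loc, w_var, b_loc, b_var)
print(loc, var)  # analytically: loc = [-0.35, 1.1], var = [0.2025, 0.37] up to float rounding

rng = random.Random(0)
samples = [sample_output_via_weights(x, w_loc, w_var, b_loc, b_var, rng) for _ in range(20000)]
for j in range(len(loc)):
    ys = [s[j] for s in samples]
    mean = sum(ys) / len(ys)
    emp_var = sum((y - mean) ** 2 for y in ys) / len(ys)
    print(j, round(mean, 3), round(emp_var, 3))  # should be close to loc[j] and var[j]
```

Sampling each output unit once from N(loc, var), as the messenger does with `dist.Normal(loc, scale).rsample()`, gives each datapoint's output the same marginal distribution as sampling the weights. The difference is that each datapoint in a minibatch then gets independent noise, which is the source of the gradient-noise reduction described at https://arxiv.org/abs/1506.02557.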