├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── bnn ├── __init__.py ├── bayesianize.py ├── calibration.py ├── distributions │ ├── __init__.py │ ├── kl.py │ └── matrix_normal.py └── nn │ ├── __init__.py │ ├── mixins │ ├── __init__.py │ ├── base.py │ └── variational │ │ ├── __init__.py │ │ ├── base.py │ │ ├── fcg.py │ │ ├── ffg.py │ │ ├── inducing.py │ │ └── utils.py │ ├── modules.py │ └── nets.py ├── configs ├── ensemble_u_cifar10.json ├── ensemble_u_cifar100.json ├── ffg_u_cifar10.json ├── ffg_u_cifar100.json ├── ffg_w_maxsd01.json └── ffg_w_unconstrained.json ├── environment.yml ├── requirements.txt ├── scripts └── cifar_resnet.py ├── setup.py └── tests ├── bnn ├── distributions │ ├── test_kl.py │ └── test_matrix_normal.py ├── nn │ └── mixins │ │ └── variational │ │ ├── test_fcg.py │ │ ├── test_ffg.py │ │ └── test_inducing.py └── test_bayesianize.py └── regression └── test_variational_synthetic.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bayesianize: a Bayesian neural network wrapper in pytorch 2 | 3 | Bayesianize is a lightweight Bayesian neural network (BNN) wrapper in pytorch. The overall goal is to allow for easy conversion of neural networks in existing scripts to BNNs with minimal changes to the code. 
4 | 
5 | Currently the wrapper supports the following uncertainty estimation methods for feed-forward neural networks and convnets:
6 | 
7 | * Mean-field variational inference (MFVI): variational inference with a fully factorised Gaussian (FFG) approximation.
8 | * Variational inference with full-covariance Gaussian approximation (for each layer).
9 | * Variational inference with inducing weights: each layer is augmented with a small matrix of inducing weights, then MFVI is performed in the inducing weight space.
10 | * Ensemble in inducing weight space: same augmentation as above, but with ensembles in the inducing weight space.
11 | 
12 | ## Usage
13 | 
14 | The main workhorse of our library is the `bayesianize_` function.
15 | It can be applied to a pytorch neural network and turns deterministic `nn.Linear` and `nn.Conv` layers into their Bayesian counterparts.
16 | For example, to construct a Bayesian ResNet-18 that uses the variational inducing weight method, run:
17 | ```
18 | import bnn
19 | net = torchvision.models.resnet18()
20 | bnn.bayesianize_(net, inference="inducing", inducing_rows=64, inducing_cols=64)
21 | ```
22 | 
23 | Then the converted BNN can be trained in almost the same way as a deterministic net:
24 | ```
25 | yhat = net(x_train)
26 | nll = F.cross_entropy(yhat, y_train)
27 | kl = sum(m.kl_divergence() for m in net.modules()
28 |          if hasattr(m, "kl_divergence"))
29 | loss = nll + kl / dataset_size
30 | loss.backward()
31 | optim.step()
32 | ```
33 | The main difference to training a deterministic net is the extra KL-divergence regulariser in the objective function.
34 | Note that while the call to the forward method of the net looks the same, it is no longer deterministic because the weights are sampled, so subsequent calls will lead to different predictions.
35 | Therefore, when testing, an average of multiple predictions is needed. For example, in BNN classification:
36 | ```
37 | net.eval()
38 | with torch.no_grad():
39 |     logits = torch.stack([net(x_test) for _ in range(num_samples)])
40 |     probs = logits.softmax(-1).mean(0)
41 | ```
42 | 
43 | Besides the inducing weight method, other variational inference approaches can be used by setting `inference="ffg"` for MFVI or `inference="fcg"` for VI with full-covariance Gaussians.
44 | 
45 | `bayesianize_` also supports using different methods or arguments for different layers by passing in a dictionary for the `inference` argument.
46 | This way you can, for example, take a pre-trained ResNet and only perform (approximate) Bayesian inference over the weights of the final linear layer (which you can access via the `net.fc` attribute):
47 | ```
48 | bnn.bayesianize_(net, inference={
49 |     net.fc: {"inference": "fcg"}
50 | })
51 | optim = torch.optim.Adam(net.fc.parameters(), 1e-3)
52 | ```
53 | If `net` is an instance of `nn.Sequential`, the network layers can also be indexed as `net[i]`, e.g. `net[-1]` for the output layer.
54 | Alternatively, it is possible to use the names of layers (e.g. `"fc"` for the linear output layer of a ResNet), the names of classes (`"Linear"`) or the corresponding objects as keys for the dictionary to specify the inference arguments for individual layers or groups of layers.
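For instance, the following sketch (the particular combination of methods and the inducing dimensions are purely illustrative, not a recommendation) uses class names as keys to apply the inducing weight method to every convolutional layer and mean-field VI to the linear output layer:
```
import torchvision
import bnn

net = torchvision.models.resnet18()
bnn.bayesianize_(net, inference={
    "Conv2d": {"inference": "inducing", "inducing_rows": 64, "inducing_cols": 64},
    "Linear": "ffg"
})
```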
55 | 
56 | ## Installation
57 | 
58 | The easiest option for installing the library is to first create a `bayesianize` conda environment from the `.yml` file we provide:
59 | ```
60 | conda env create -f environment.yml
61 | ```
62 | Depending on your system, you might need to add a `cudatoolkit` or `cpuonly` entry as the final line of the `environment.yml` to install the correct version of pytorch, e.g. add
63 | ```
64 | - cudatoolkit=11.0
65 | ```
66 | to install pytorch with CUDA11 support.
67 | 
68 | Then you can load the environment and pip install our module from source:
69 | ```
70 | conda activate bayesianize
71 | pip install -e .
72 | ```
73 | 
74 | Alternatively, you can copy the `bnn/` folder to your project or add the directory containing `bnn/` to your `PYTHONPATH`:
75 | ```
76 | export PYTHONPATH=PATH_TO_INDUCING_WEIGHT_DIR:$PYTHONPATH
77 | ```
78 | with `PATH_TO_INDUCING_WEIGHT_DIR=/your_path/` if the `bnn/` folder is located at `/your_path/bnn/`.
79 | 
80 | ## Code structure
81 | 
82 | The variational inference logic is mostly contained inside the `bnn.nn.mixins.variational` module.
83 | There we implement mixin classes that contain logic for sampling `.weight` and `.bias` parameters from a variational posterior and calculating its KL divergence from a prior.
84 | Those classes are mixed with pytorch's `nn.Linear` and all `nn.Conv` classes in `bnn/nn/modules.py`.
85 | Our `bayesianize_` method automatically collects classes that inherit from `bnn.nn.mixins.base.BayesianMixin` and the Linear or Conv class.
86 | So if you want to add your own variational layer classes, e.g. with a low-rank or matrix normal variational posterior, you only need to make them inherit from our `BayesianMixin` class and create the corresponding linear and conv layers in `modules`.
87 | 
88 | ## Example script
89 | 
90 | We provide an example script for training Bayesian ResNets on CIFAR10 and CIFAR100 in `scripts/cifar_resnet.py`.
91 | The most important command line argument is `--inference-config`.
92 | If you do not provide a value, your network will remain unchanged and the script will train using maximum likelihood.
93 | Otherwise you can pass in the path to one of the inference config files in the `configs/` directory.
94 | We provide configs for Gaussian mean-field VI with either no constraints on the variational posterior or one where the maximum standard deviation is set to 0.1.
95 | There are also, of course, configs for our inducing weight method, both with an ensemble and with fully-factorised Gaussian VI in inducing space.
96 | Note that there are separate configs for CIFAR10 and CIFAR100 due to the different number of classes.
97 | 
98 | To train a BNN with our inducing weight method, you can run the script for example as:
99 | ```
100 | python scripts/cifar_resnet.py --inference-config=configs/ffg_u_cifar10.json \
101 |     --num-epochs=200 --ml-epochs=100 --annealing-epochs=50 --lr=1e-3 \
102 |     --milestones=100 --resnet=18 --cifar=10 --verbose --progress-bar
103 | ```
104 | The full list of command line options for the script is:
105 | ```
106 | --num-epochs: Total number of training epochs
107 | --train-samples: Number of MC samples to draw from the variational posterior during training for the data log likelihood
108 | --test-samples: Number of samples to average over for the predictive posterior during testing.
109 | --annealing-epochs: Number of training epochs over which the weight of the KL term is annealed linearly.
110 | --ml-epochs: Number of training epochs where the weight of the KL term is 0.
111 | --inference-config: Path to the inference config file 112 | --output-dir: Directory in which to store state dicts of the network and optimizer, and the final calibration plot. 113 | --verbose: Switch for printing validation accuracy and calibration at every epoch. 114 | --progress-bar: Switch for tqdm progress bar for epochs and batches. 115 | --lr: Initial learning rate. 116 | --seed: Random seed. 117 | --cifar: 10 or 100 for the corresponding CIFAR dataset. 118 | --optimizer: sgd or adam for the corresponding optimizer. 119 | --momentum: momentum if using sgd. 120 | --milestones: Comma-separated list of epochs after which to decay the learning rate by a factor of gamma. 121 | --gamma: Multiplicative decay factor for the learning rate. 122 | --resnet: Which ResNet architecture from torchvision to use (must be one of {18, 34, 50, 101, 152}). 123 | ``` 124 | 125 | 126 | ## Contributing 127 | 128 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 129 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 130 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 131 | 132 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 133 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 134 | provided by the bot. You will only need to do this once across all repos using our CLA. 135 | 136 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 137 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 138 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 139 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. 
If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /bnn/__init__.py: -------------------------------------------------------------------------------- 1 | from . import calibration 2 | from . import distributions 3 | from . 
import nn 4 | from .bayesianize import bayesianize_ 5 | from .nn.mixins.variational.inducing import register_global_inducing_weights_ 6 | -------------------------------------------------------------------------------- /bnn/bayesianize.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import itertools 3 | from typing import Any, Dict, Union, Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | from .nn.mixins.base import BayesianMixin 10 | 11 | 12 | def _subclasses(cls: type): 13 | return set(cls.__subclasses__()).union([s for c in cls.__subclasses__() for s in _subclasses(c)]) 14 | 15 | 16 | _BAYESIANIZABLE_CLASSES = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d) 17 | # dictionary mapping nn.Module classes for which we support 'Bayesianization' to a dictionary mapping 18 | # an inference name (the base name of the mixin class) to a class inheriting from the nn.Module and the 19 | # inference Mixin. For example, nn.Linear would map to a dictionary containing `'ffg'=FFGLinear` as one 20 | # of the key-value pairs. The dictionary construction assumes that implementations of Bayesian layers 21 | # inherit from the nn.Module as the final baseclass, so that an arbitrary number of additional mixin 22 | # classes is supported 23 | _BAYESIAN_MODULES = { 24 | t_cls: { 25 | m_cls.__bases__[0].__name__.rstrip("Mixin").lower(): m_cls 26 | for m_cls in _subclasses(BayesianMixin) 27 | if len(m_cls.__bases__) >= 2 and m_cls.__bases__[-1] == t_cls} 28 | for t_cls in _BAYESIANIZABLE_CLASSES 29 | } 30 | 31 | 32 | def _deep_setattr(obj: Any, name: str, value: Any) -> None: 33 | attr_names = name.split(".") 34 | for attr_name in attr_names[:-1]: 35 | obj = getattr(obj, attr_name) 36 | setattr(obj, attr_names[-1], value) 37 | 38 | 39 | def _module_in_sd(module_name: str, sd: Dict): 40 | return any(k.startswith(module_name) for k in sd.keys()) 41 | 42 | 43 | def bayesian_from_template(layer: nn.Module, inference: str, **bayes_kwargs) -> BayesianMixin: 44 | """Takes a pytorch module and turns it into an equivalent Bayesian module depending on inference. 45 | For example, if layer is an nn.Linear instance and inference is ffg, the method constructs an 46 | FFGLinear object with the same number of in_features and out_features.""" 47 | bayes_module_class = _BAYESIAN_MODULES[layer.__class__][inference.lower()] 48 | # pulls nn.Module init arguments from the attributes of the object. bias needs to be treated separately 49 | # since init expects a bool, while it is a torch.tensor or None as an attribute 50 | init_parameters = inspect.signature(layer.__class__).parameters 51 | layer_kwargs = {k: v for k, v in vars(layer).items() if k in init_parameters} 52 | layer_kwargs["bias"] = getattr(layer, "bias", None) is not None 53 | return bayes_module_class(**layer_kwargs, **bayes_kwargs) 54 | 55 | 56 | def bayesianize_(network: nn.Module, 57 | inference: Union[str, Dict[Union[str, nn.Module, type, int], Any]], 58 | reference_state_dict: Optional[Dict[str, torch.tensor]] = None, 59 | **default_params) -> None: 60 | """Method for turning a pytorch neural network that is an instance of nn.Module into a 'Bayesian' 61 | variant, where all of the nn.Linear and nn.ConvNd layers are replaced (inplace) with Bayesian layers. 62 | Which type of Bayesian layer gets used is specified by inference. 
If it is a string, the same type of
63 | layer is used throughout the net, otherwise a dictionary mapping specific layer names or objects or
64 | entire classes or the module's index to a string can be used. That way it is possible to, for example,
65 | only do variational inference on the output layer and learn the parameters of the remaining layers
66 | via maximum likelihood."""
67 |     if reference_state_dict is None:
68 |         reference_state_dict = {}
69 | 
70 |     num_modules = len(list(network.modules()))
71 |     for i, (name, module) in enumerate(network.named_modules()):
72 |         if isinstance(inference, str):
73 |             module_inference = inference
74 |         elif module in inference:
75 |             module_inference = inference[module]
76 |         elif name in inference:
77 |             module_inference = inference[name]
78 |         elif i in inference:
79 |             module_inference = inference[i]
80 |         elif i - num_modules in inference:
81 |             module_inference = inference[i - num_modules]
82 |         elif module.__class__ in inference:
83 |             module_inference = inference[module.__class__]
84 |         elif module.__class__.__name__ in inference:
85 |             module_inference = inference[module.__class__.__name__]
86 |         else:
87 |             continue
88 | 
89 |         if isinstance(module_inference, str):
90 |             module_inference = {"inference": module_inference}
91 | 
92 |         for k, v in default_params.items():
93 |             module_inference.setdefault(k, v)
94 | 
95 |         cls = module.__class__
96 |         if cls in _BAYESIAN_MODULES and module_inference["inference"] in _BAYESIAN_MODULES[cls]:
97 |             bayesian_module = bayesian_from_template(module, **module_inference)
98 |             _deep_setattr(network, name, bayesian_module)
99 | 
100 |             if _module_in_sd(name, reference_state_dict):
101 |                 param_dict = {param_name: reference_state_dict[name + "." + param_name]
102 |                               for param_name, _ in module.named_parameters()}
103 |                 bayesian_module.init_from_deterministic_params(param_dict)
104 |         else:
105 |             if _module_in_sd(name, reference_state_dict):
106 |                 module_sd = {
107 |                     attr_name: reference_state_dict[name + "." + attr_name] for attr_name, _ in
108 |                     itertools.chain(module.named_parameters(recurse=False), module.named_buffers(recurse=False))}
109 |                 # check for non-empty state dict, some modules, e.g. BasicBlock in Resnet, only contain other
110 |                 # modules and don't have any parameters of their own
111 |                 if module_sd:
112 |                     module.load_state_dict(module_sd)
113 | 
-------------------------------------------------------------------------------- /bnn/calibration.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def calibration_curve(probabilities: np.ndarray, targets: np.ndarray, num_bins: int,
5 |                       top_class_only: bool = True, equal_size_bins: bool = False, min_p: float = 0.0):
6 |     """Calculates the calibration of a classifier (binary or multi-class). Specifically it takes
7 |     predicted probability values, assigns them to a given number of bins (keeping either the width
8 |     of the bins fixed or the number of predictions assigned to each bin) and then returns for each
9 |     bin the mean predicted probability of the positive class occurring as well as the empirically observed
10 |     frequency as per the targets. Additionally, the relative size of each bin is returned. Note that
11 |     all inputs are assumed to be well-specified, i.e. probabilities between 0 and 1 and, for multi-class
12 |     targets, to sum to 1 across the final dimension.
13 | 
14 |     Using the default options top_class_only=True and equal_size_bins=False returns mean probabilities,
15 |     bin_frequency and bin_weights values as used for the standard ECE formulation, e.g. in
16 |     http://openaccess.thecvf.com/content_CVPRW_2019/papers/Uncertainty%20and%20Robustness%20in%20Deep%20Visual%20Learning/Nixon_Measuring_Calibration_in_Deep_Learning_CVPRW_2019_paper.pdf
17 |     top_class_only=False gives results for the Static Calibration Error, equal_size_bins=True the Adaptive Calibration
18 |     Error (the paper does not specify whether to set top_class_only to True or False). Setting min_p > 0
19 |     corresponds to the Thresholded Adaptive Calibration Error. To calculate these calibration errors, the outputs
20 |     of this function can be passed directly into the expected_calibration_error function.
21 | 
22 |     Args:
23 |         probabilities: Array containing probability predictions.
24 |         targets: Array containing classification targets.
25 |         num_bins: Number of bins for probability values.
26 |         top_class_only: Whether to only use the maximum predicted probability for multi-class classification
27 |             or all probabilities.
28 |         equal_size_bins: Whether each bin should be assigned an equal number of predictions vs. bins of equal width.
29 |         min_p: Minimum threshold for the probabilities to count.
30 |     Returns:
31 |         bin_probability: Average predicted probability. NaN for empty bins.
32 |         bin_frequency: Average observed true class frequency. NaN for empty bins.
33 |         bin_weights: Relative size of each bin. Zero for empty bins.
34 |     """
35 | 
36 |     if probabilities.ndim == targets.ndim + 1:
37 |         # multi-class
38 |         if top_class_only:
39 |             # targets are converted to per-datapoint accuracies, i.e. checking whether or not the predicted
40 |             # class was observed
41 |             predictions = np.cast[targets.dtype](probabilities.argmax(-1))
42 |             targets = targets == predictions
43 |             probabilities = probabilities.max(-1)
44 |         else:
45 |             # convert the targets to one-hot encodings and flatten both those targets and the probabilities,
46 |             # treating them as independent predictions for binary classification
47 |             num_classes = probabilities.shape[-1]
48 |             one_hot_targets = np.cast[targets.dtype](targets[..., np.newaxis] == np.arange(num_classes))
49 |             targets = one_hot_targets.reshape(*targets.shape[:-1], -1)
50 |             probabilities = probabilities.reshape(*probabilities.shape[:-2], -1)
51 | 
52 |     elif probabilities.ndim != targets.ndim:
53 |         raise ValueError("Shapes of probabilities and targets do not match. "
54 |                          "Must be either equal (binary classification) or probabilities "
55 |                          "must have exactly one dimension more (multi-class).")
56 |     else:
57 |         # binary predictions, no pre-processing to do
58 |         pass
59 | 
60 |     if equal_size_bins:
61 |         quantiles = np.linspace(0, 1, num_bins + 1)
62 |         bin_edges = np.quantile(probabilities, quantiles)
63 |         # explicitly set upper and lower edge to be 0/1
64 |         bin_edges[0] = 0
65 |         bin_edges[-1] = 1
66 |     else:
67 |         bin_edges = np.linspace(0, 1, num_bins + 1)
68 | 
69 |     # bin membership has to be checked with strict inequality to either the lower or upper
70 |     # edge to avoid predictions exactly on a boundary being included in multiple bins.
71 | # Therefore the exclusive boundary has to be slightly below or above the actual value 72 | # to avoid 0 or 1 predictions to not be assigned to any bin 73 | bin_edges[0] -= 1e-6 74 | lower = bin_edges[:-1] 75 | upper = bin_edges[1:] 76 | probabilities = probabilities.reshape(-1, 1) 77 | targets = targets.reshape(-1, 1) 78 | 79 | # set up masks for checking which bin probabilities fall into and whether they are above the minimum 80 | # threshold. I'm doing this by multiplication with those booleans rather than indexing in order to 81 | # allow for the code to be extensible for broadcasting 82 | bin_membership = (probabilities > lower) & (probabilities <= upper) 83 | exceeds_threshold = probabilities >= min_p 84 | 85 | bin_sizes = (bin_membership * exceeds_threshold).sum(-2) 86 | non_empty = bin_sizes > 0 87 | 88 | bin_probability = np.full(num_bins, np.nan) 89 | np.divide((probabilities * bin_membership * exceeds_threshold).sum(-2), bin_sizes, 90 | out=bin_probability, where=non_empty) 91 | 92 | bin_frequency = np.full(num_bins, np.nan) 93 | np.divide((targets * bin_membership * exceeds_threshold).sum(-2), bin_sizes, 94 | out=bin_frequency, where=non_empty) 95 | 96 | bin_weights = np.zeros(num_bins) 97 | np.divide(bin_sizes, bin_sizes.sum(), out=bin_weights, where=non_empty) 98 | 99 | return bin_probability, bin_frequency, bin_weights 100 | 101 | 102 | def expected_calibration_error(mean_probability_predicted: np.ndarray, observed_frequency: np.ndarray, 103 | bin_weights: np.ndarray): 104 | """Calculates the ECE, i.e. the average absolute difference between predicted probabilities and 105 | true observed frequencies for a classifier and its targets. Inputs are expected to be formatted 106 | as the return values from the calibration_curve method. NaNs in mean_probability_predicted and 107 | observed_frequency are ignored if the corresponding entry in bin_weights is 0.""" 108 | idx = bin_weights > 0 109 | return np.sum(np.abs(mean_probability_predicted[idx] - observed_frequency[idx]) * bin_weights[idx]) 110 | -------------------------------------------------------------------------------- /bnn/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | from .matrix_normal import * 2 | 3 | from . import kl 4 | -------------------------------------------------------------------------------- /bnn/distributions/kl.py: -------------------------------------------------------------------------------- 1 | import torch.distributions as dist 2 | 3 | from .matrix_normal import MatrixNormal, _batch_kf_mahalanobis 4 | 5 | 6 | # below are functions for calculating KL divergences with matrix normals. They are registered with pytorch's 7 | # distribution module, through the dist.register_kl decorator so that calling dist.kl_divergence will dispatch 8 | # the call to these functions. 
The KLs that use an efficient implementation are between two matrix normals and 9 | # between a matrix normal and a diagonal Normal, otherwise the distributions are converted to multivariate normals 10 | # which might be prohibitive in terms of memory usage 11 | @dist.register_kl(MatrixNormal, MatrixNormal) 12 | def _kl_matrixnormal_matrixnormal(p, q): 13 | if p.event_shape != q.event_shape: 14 | raise NotImplementedError("Cannot calculate Kl-divergence for matrix normals of different shape.") 15 | half_log_det_diff = q._half_log_det() - p._half_log_det() 16 | d = p.event_shape[0] * p.event_shape[1] 17 | row_trace_term = p.row_scale_tril.triangular_solve(q.row_scale_tril, upper=False)[0].pow(2).sum((-2, -1)) 18 | col_trace_term = p.col_scale_tril.triangular_solve(q.col_scale_tril, upper=False)[0].pow(2).sum((-2, -1)) 19 | trace_term = row_trace_term * col_trace_term 20 | mahalanobis_term = _batch_kf_mahalanobis(q.row_scale_tril, q.col_scale_tril, q.loc - p.loc) 21 | return half_log_det_diff + 0.5 * (trace_term + mahalanobis_term - d) 22 | 23 | 24 | @dist.register_kl(MatrixNormal, dist.Normal) 25 | def _kl_matrixnormal_normal(p, q): 26 | if p.event_shape != q.batch_shape[-2:]: 27 | raise ValueError("Cannot calculate KL-divergence if trailing batch dimensions of the factorized normal do not" 28 | "match the event shape of the matrix normal") 29 | p_halflogdet = p._half_log_det() 30 | q_halflogdet = q.scale.log().sum() 31 | diff_halflogdet = q_halflogdet - p_halflogdet 32 | 33 | d = p.event_shape[0] * p.event_shape[1] 34 | 35 | row_cov_diag = p.row_scale_tril.pow(2).sum(-1).unsqueeze(-1) 36 | col_cov_diag = p.col_scale_tril.pow(2).sum(-1).unsqueeze(-2) 37 | trace_term = (q.scale.pow(-2) * row_cov_diag * col_cov_diag).sum((-2, -1)) 38 | 39 | delta = q.loc - p.loc 40 | mahalanobis_term = delta.pow(2).div(q.scale.pow(2)).sum((-2, -1)) 41 | 42 | return diff_halflogdet + 0.5 * (trace_term - d + mahalanobis_term) 43 | 44 | 45 | @dist.register_kl(dist.Normal, MatrixNormal) 46 | def _kl_normal_matrixnormal(p, q): 47 | return dist.kl_divergence(dist.MultivariateNormal(p.loc.flatten(-2), scale_tril=p.scale.diag_embed()), 48 | q.to_multivariatenormal()) 49 | 50 | 51 | @dist.register_kl(MatrixNormal, dist.MultivariateNormal) 52 | def _kl_matrixnormal_multivariatenormal(p, q): 53 | return dist.kl_divergence(p.to_multivariatenormal(), q) 54 | 55 | 56 | @dist.register_kl(dist.MultivariateNormal, MatrixNormal) 57 | def _kl_multivariatenormal_matrixnormal(p, q): 58 | return dist.kl_divergence(p, q.to_multivariatenormal()) 59 | -------------------------------------------------------------------------------- /bnn/distributions/matrix_normal.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.distributions as dist 5 | from torch.distributions import constraints 6 | 7 | 8 | def kron(t1: torch.Tensor, t2: torch.Tensor): 9 | r"""Calculates the Kronecker product between batches of matrices, i.e. the elementwise product between every 10 | element of t1 and the entire matrix t2. Note that in contrast to the numpy implementation, the present one treats 11 | any leading dimensions before the first two as batch dimensions and calculates the batch of kronecker products. 12 | Numpy calculates the product along all dimensions, i.e. 
for a, b of shape m x n x k, np.kron(a, b) returns an
13 |     m^2 x n^2 x k^2 array."""
14 | 
15 |     # a x b -> a x 1 x b x 1
16 |     t1 = t1.unsqueeze(-1).unsqueeze(-3)
17 |     # c x d -> 1 x c x 1 x d
18 |     t2 = t2.unsqueeze(-2).unsqueeze(-4)
19 |     # a x c x b x d
20 |     t_stacked = t1 * t2
21 |     # ac x bd
22 |     return t_stacked.flatten(-4, -3).flatten(-2, -1)
23 | 
24 | 
25 | class _RealMatrix(constraints.Constraint):
26 |     def check(self, value: torch.Tensor):
27 |         return value.eq(value).all(-1).all(-1)
28 | 
29 | 
30 | real_matrix = _RealMatrix()
31 | 
32 | 
33 | def _batch_kf_mahalanobis(bLu: torch.Tensor, bLv: torch.Tensor, bX: torch.Tensor):
34 |     r"""Calculates the squared Mahalanobis distance :math:`vec(\mathbf{X})^\top\mathbf{M}^{-1}vec(\mathbf{X})`
35 |     for a Kronecker factored :math:`\mathbf{M} = \mathbf{U} \otimes \mathbf{V}` with the lower Cholesky factors
36 |     :math:`\mathbf{U} = \mathbf{L}_u\mathbf{L}_u^\top` and :math:`\mathbf{V} = \mathbf{L}_v\mathbf{L}_v^\top`
37 |     being provided as inputs. This implementation defines :math:`vec(\mathbf{X})` as the column vector of the
38 |     concatenated rows of :math:`\mathbf{X}`, which means that the matrix-vector product for a Kronecker factored
39 |     matrix can be calculated as
40 |     :math:`(\mathbf{U}\otimes\mathbf{V})vec(\mathbf{X}) = vec(\mathbf{U}\mathbf{X}\mathbf{V}^\top)`. This definition
41 |     of the vec operation is equivalent to using `.flatten` with the usual C/row-major order, whereas the definition
42 |     e.g. on Wikipedia (see MatrixNormal for link) of concatenating the columns uses Fortran/column-major order."""
43 |     x1 = bX.transpose(-2, -1).triangular_solve(bLv, upper=False).solution
44 |     x2 = x1.transpose(-2, -1).triangular_solve(bLu, upper=False).solution
45 |     return x2.pow(2).flatten(-2).sum(-1)
46 | 
47 | 
48 | class MatrixNormal(dist.Distribution):
49 |     r"""Basic implementation of a matrix normal distribution, which is equivalent to a multivariate normal with
50 |     a Kronecker factored covariance matrix. Makes use of Kronecker product identities to support efficient sampling
51 |     and calculation of log probabilities, entropy and KL divergences.
52 | 
53 |     See https://en.wikipedia.org/wiki/Matrix_normal_distribution for an overview.
54 | 
55 |     Besides the mean matrix, it expects the square roots of the row and
56 |     column covariance matrices.
Precision matrices are not currently supported, but could be incorporated in a 57 | similar manner to pytorch's multivariate normal distribution.""" 58 | 59 | arg_constraints = {"loc": real_matrix, 60 | "row_scale_tril": constraints.lower_cholesky, 61 | "col_scale_tril": constraints.lower_cholesky} 62 | support = constraints.real 63 | has_rsample = True 64 | 65 | def __init__(self, loc, row_scale_tril, col_scale_tril, validate_args=None): 66 | if loc.dim() < 2: 67 | raise ValueError("loc must be at least two-dimensional") 68 | rows, cols = loc.shape[-2:] 69 | if row_scale_tril.dim() < 2 or row_scale_tril.size(-1) != rows or row_scale_tril.size(-2) != rows: 70 | raise ValueError("row_scale_tril must be at-least two dimensional with the final two dimensions matching " 71 | "the penultimate dimension of loc") 72 | if col_scale_tril.dim() < 2 or col_scale_tril.size(-1) != cols or col_scale_tril.size(-2) != cols: 73 | raise ValueError("col_scale_tril must be at-least two dimensional with the final two dimensions matching " 74 | "the last dimension of loc") 75 | 76 | batch_shape, event_shape = loc.shape[:-2], loc.shape[-2:] 77 | super(MatrixNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args) 78 | 79 | self.loc = loc 80 | self.row_scale_tril = row_scale_tril 81 | self.col_scale_tril = col_scale_tril 82 | 83 | @property 84 | def num_rows(self): 85 | return self._event_shape[-2] 86 | 87 | @property 88 | def num_cols(self): 89 | return self._event_shape[-1] 90 | 91 | @property 92 | def mean(self): 93 | return self.loc 94 | 95 | @property 96 | def variance(self): 97 | return self.row_scale_tril.pow(2).sum(-1).unsqueeze(-1).mul( 98 | self.col_scale_tril.pow(2).sum(-1)).expand(self._batch_shape + self._event_shape) 99 | 100 | def _half_log_det(self): 101 | """Calculates the determinant of the covariance times 0.5""" 102 | row_half_log_det = self.row_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) 103 | col_half_log_det = self.col_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) 104 | return self.num_cols * row_half_log_det + self.num_rows * col_half_log_det 105 | 106 | def rsample(self, sample_shape=torch.Size()): 107 | shape = self._extended_shape(sample_shape) 108 | eps = torch.randn(shape, dtype=self.loc.dtype, device=self.loc.device) 109 | return self.loc + self.row_scale_tril.matmul(eps).matmul(self.col_scale_tril.transpose(-2, -1)) 110 | 111 | def log_prob(self, value): 112 | if self._validate_args: 113 | self._validate_sample(value) 114 | diff = value - self.loc 115 | M = _batch_kf_mahalanobis(self.row_scale_tril, self.col_scale_tril, diff) 116 | return -0.5 * (self.num_rows * self.num_cols * math.log(2 * math.pi) + M) - self._half_log_det() 117 | 118 | def entropy(self): 119 | H = 0.5 * self.num_rows * self.num_cols * (1.0 + math.log(2 * math.pi)) + self._half_log_det() 120 | if len(self._batch_shape) == 0: 121 | return H 122 | else: 123 | return H.expand(self._batch_shape) 124 | 125 | def to_multivariatenormal(self): 126 | """Converts this to a 'flat' multivariate normal. 
This is mainly for testing and to support KL divergences to 127 | multivariate normals, but typically this will be highly inefficient both in terms of memory and computation.""" 128 | return dist.MultivariateNormal(self.loc.flatten(-2), scale_tril=kron(self.row_scale_tril, self.col_scale_tril)) 129 | -------------------------------------------------------------------------------- /bnn/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | 3 | from . import mixins 4 | from . import nets 5 | -------------------------------------------------------------------------------- /bnn/nn/mixins/__init__.py: -------------------------------------------------------------------------------- 1 | from .variational import * 2 | -------------------------------------------------------------------------------- /bnn/nn/mixins/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class BayesianMixin(ABC, nn.Module): 7 | 8 | @abstractmethod 9 | def parameter_loss(self): 10 | """Calculates generic parameter-dependent loss. For a probabilistic module with some prior over the parameters, 11 | e.g. for MAP inference or MCMC sampling, this would be the negative log prior, for Variational inference the 12 | KL divergence between approximate posterior and prior.""" 13 | raise NotImplementedError 14 | 15 | @abstractmethod 16 | def init_from_deterministic_params(self, param_dict): 17 | """Initializes from the parameters of a deterministic network. For a variational module, this might mean 18 | setting the mean of the approximate posterior to those parameters, whereas a MAP/MCMC module would simply 19 | copy the parameter values.""" 20 | raise NotImplementedError 21 | -------------------------------------------------------------------------------- /bnn/nn/mixins/variational/__init__.py: -------------------------------------------------------------------------------- 1 | from .fcg import * 2 | from .ffg import * 3 | from .inducing import * 4 | -------------------------------------------------------------------------------- /bnn/nn/mixins/variational/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | from ..base import BayesianMixin 5 | 6 | 7 | class VariationalMixin(BayesianMixin): 8 | 9 | def parameter_loss(self): 10 | return self.kl_divergence() 11 | 12 | @abstractmethod 13 | def kl_divergence(self): 14 | raise NotImplementedError 15 | 16 | -------------------------------------------------------------------------------- /bnn/nn/mixins/variational/fcg.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import reduce 3 | import operator 4 | from typing import Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.distributions as dist 10 | from torch.nn.utils import parameters_to_vector 11 | 12 | 13 | from .base import VariationalMixin 14 | 15 | 16 | __all__ = ["FCGMixin"] 17 | 18 | 19 | def _prod(iterable): 20 | return reduce(operator.mul, iterable, 1) 21 | 22 | 23 | def _normal_sample(mean, scale_tril): 24 | return mean + scale_tril @ torch.randn_like(mean) 25 | 26 | 27 | class FCGMixin(VariationalMixin): 28 | """Variational module that places a multivariate Gaussian with full covariance jointly 29 | over .weight and .bias 
attributes. The forward pass always explicitly samples the weights.""" 30 | def __init__(self, *args, prior_mean: float = 0., prior_weight_sd: Union[float, str] = 1., 31 | prior_bias_sd: float = 1., init_sd: float = 1e-4, nonlinearity_scale: float = 1., **kwargs): 32 | super().__init__(*args, **kwargs) 33 | self.has_bias = self.bias is not None 34 | self.weight_shape = self.weight.shape 35 | self.bias_shape = self.bias.shape if self.has_bias else None 36 | 37 | current_parameters = parameters_to_vector(self.parameters()) 38 | num_params = current_parameters.numel() 39 | 40 | self.mean = nn.Parameter(current_parameters.data.detach().clone()) 41 | _init_sd = math.log(math.expm1(init_sd)) 42 | self._scale_diag = nn.Parameter(torch.full((num_params,), _init_sd)) 43 | self._scale_tril = nn.Parameter(torch.zeros(num_params, num_params)) 44 | 45 | if prior_weight_sd == "neal": 46 | input_dim = _prod(self.weight.shape[1:]) 47 | prior_weight_sd = input_dim ** -0.5 48 | 49 | prior_weight_sd *= nonlinearity_scale 50 | 51 | prior_weight_sd_tensor = torch.full((self.weight.flatten().shape), prior_weight_sd) 52 | if self.has_bias: 53 | prior_bias_sd_tensor = torch.full((self.bias.flatten().shape), prior_bias_sd) 54 | prior_sd_diag = torch.cat((prior_weight_sd_tensor, prior_bias_sd_tensor)) 55 | else: 56 | prior_sd_diag = prior_weight_sd_tensor 57 | 58 | del self._parameters["weight"] 59 | if self.has_bias: 60 | del self._parameters["bias"] 61 | self.assign_params(self.mean.data) 62 | 63 | self.register_buffer("prior_mean", torch.full((num_params,), prior_mean)) 64 | self.register_buffer("prior_scale_tril", prior_sd_diag.diag_embed()) 65 | 66 | def extra_repr(self): 67 | s = super().extra_repr() 68 | m = self.prior_mean.data[0] 69 | if torch.allclose(m, self.prior_mean): 70 | s += f", prior mean={m.item():.2f}" 71 | sd = self.prior_scale_tril[0, 0] 72 | if torch.allclose(sd, self.prior_scale_tril) and torch.allclose(self.prior_scale_tril.tril(diagonal=-1), 0): 73 | s += f", prior sd={sd.item():.2f}" 74 | return s 75 | 76 | def init_from_deterministic_params(self, param_dict): 77 | weight = param_dict["weight"] 78 | bias = param_dict.get("bias") 79 | with torch.no_grad(): 80 | mean = weight.flatten() 81 | if bias is not None: 82 | mean = torch.cat([mean, bias.flatten()]) 83 | self.mean.data.copy_(mean) 84 | 85 | @property 86 | def parameter_tensors(self): 87 | if self.has_bias: 88 | return self.weight, self.bias 89 | return (self.weight,) 90 | 91 | @property 92 | def scale_tril(self): 93 | return F.softplus(self._scale_diag).diagflat() + torch.tril(self._scale_tril, diagonal=-1) 94 | 95 | @property 96 | def parameter_distribution(self): 97 | return dist.MultivariateNormal(self.mean, scale_tril=self.scale_tril) 98 | 99 | @property 100 | def prior_distribution(self): 101 | return dist.MultivariateNormal(self.prior_mean, scale_tril=self.prior_scale_tril) 102 | 103 | def assign_params(self, parameters: torch.Tensor): 104 | if self.has_bias: 105 | num_bias_params = _prod(self.bias_shape) 106 | self.weight = parameters[:-num_bias_params].view(self.weight_shape) 107 | self.bias = parameters[-num_bias_params:].view(self.bias_shape) 108 | else: 109 | self.weight = parameters.view(self.weight_shape) 110 | 111 | def forward(self, x: torch.Tensor): 112 | parameter_sample = _normal_sample(self.mean, self.scale_tril) 113 | self.assign_params(parameter_sample) 114 | return super().forward(x) 115 | 116 | def kl_divergence(self): 117 | return dist.kl_divergence(self.parameter_distribution, self.prior_distribution) 118 | 
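# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for documentation purposes, not part of the
# original fcg.py): it shows how a layer converted with
# `bnn.bayesianize_(..., inference="fcg")`, i.e. the FCGMixin above combined
# with nn.Linear, is trained with the usual objective of data negative log
# likelihood plus the KL divergence scaled by the dataset size. The toy data,
# layer sizes, sample count and learning rate are placeholder values. Since
# the full covariance has one entry per pair of parameters, "fcg" is typically
# reserved for small layers such as the output layer.
if __name__ == "__main__":
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    import bnn

    torch.manual_seed(0)
    x, y = torch.randn(256, 20), torch.randint(0, 3, (256,))

    net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 3))
    # convert only the output layer; the hidden layer stays deterministic
    bnn.bayesianize_(net, inference={net[-1]: {"inference": "fcg"}})

    optim = torch.optim.Adam(net.parameters(), 1e-3)
    for _ in range(100):
        optim.zero_grad()
        nll = F.cross_entropy(net(x), y)
        kl = sum(m.kl_divergence() for m in net.modules() if hasattr(m, "kl_divergence"))
        loss = nll + kl / x.shape[0]
        loss.backward()
        optim.step()

    # every forward pass samples a new set of weights, so test predictions
    # are averaged over multiple samples
    net.eval()
    with torch.no_grad():
        probs = torch.stack([net(x).softmax(-1) for _ in range(8)]).mean(0)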
-------------------------------------------------------------------------------- /bnn/nn/mixins/variational/ffg.py: -------------------------------------------------------------------------------- 1 | import operator 2 | from functools import reduce 3 | import math 4 | from typing import Optional, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.distributions as dist 10 | 11 | 12 | from .base import VariationalMixin 13 | 14 | 15 | EPS = 1e-6 16 | 17 | 18 | __all__ = ["FFGMixin"] 19 | 20 | 21 | def _prod(iterable): 22 | return reduce(operator.mul, iterable, 1) 23 | 24 | 25 | def _normal_sample(mean, sd): 26 | return mean + torch.randn_like(mean) * sd 27 | 28 | 29 | class FFGMixin(VariationalMixin): 30 | """Variational module that places a fully factorized Gaussian over .weight and .bias attributes. 31 | In the forward pass it marginalizes over the weights and directly samples the outputs from a Gaussian 32 | when in training mode, in testing mode it simply samples the weights and performs the forward pass 33 | as usual (in either case the outputs come from the same distribution, but the gradients have lower 34 | variance for the training mode; see https://arxiv.org/abs/1506.02557). This Mixin class can be combined 35 | with linear and convolutions layers, probably also with deconvolutional ones (need to check the math first). 36 | """ 37 | 38 | def __init__(self, *args, prior_mean: float = 0., prior_weight_sd: Union[float, str] = 1., 39 | prior_bias_sd: float = 1., init_sd: float = 1e-4, max_sd: Optional[float] = None, 40 | local_reparameterization: bool = True, nonlinearity_scale: float = 1., 41 | sqrt_width_scaling: bool = False, **kwargs): 42 | super().__init__(*args, **kwargs) 43 | self.has_bias = self.bias is not None 44 | self.local_reparameterization = local_reparameterization 45 | self.max_sd = max_sd 46 | 47 | # I use a softplus to ensure that the sd is positive, so need to map it through 48 | # the inverse for initialization of the parameter 49 | _init_sd = math.log(math.expm1(init_sd)) 50 | self.weight_mean = nn.Parameter(self.weight.data.detach().clone()) 51 | self._weight_sd = nn.Parameter(torch.full_like(self.weight.data, _init_sd)) 52 | if self.has_bias: 53 | self.bias_mean = nn.Parameter(self.bias.data.detach().clone()) 54 | self._bias_sd = nn.Parameter(torch.full_like(self.bias.data, _init_sd)) 55 | else: 56 | self.register_parameter("bias_mean", None) 57 | self.register_parameter("_bias_sd", None) 58 | 59 | del self._parameters["weight"] 60 | if self.has_bias: 61 | del self._parameters["bias"] 62 | 63 | self.weight = self.weight_mean.data 64 | self.bias = self.bias_mean.data if self.has_bias else None 65 | 66 | if sqrt_width_scaling: 67 | input_dim = _prod(self.weight_mean.shape[1:]) + int(self.has_bias) 68 | prior_weight_sd /= input_dim ** 0.5 69 | prior_bias_sd /= input_dim ** 0.5 70 | 71 | prior_weight_sd *= nonlinearity_scale 72 | self.register_buffer("prior_weight_mean", torch.full_like(self.weight_mean, prior_mean)) 73 | self.register_buffer("prior_weight_sd", torch.full_like(self.weight_sd, prior_weight_sd)) 74 | 75 | prior_bias_mean = torch.full_like(self.bias_mean, prior_mean) if self.has_bias else None 76 | prior_bias_sd = torch.full_like(self.bias_sd, prior_bias_sd) if self.has_bias else None 77 | self.register_buffer("prior_bias_mean", prior_bias_mean) 78 | self.register_buffer("prior_bias_sd", prior_bias_sd) 79 | 80 | def extra_repr(self): 81 | s = super().extra_repr() 82 | m = 
self.prior_weight_mean.data.flatten()[0] 83 | if torch.allclose(m, self.prior_weight_mean) and (not self.has_bias or 84 | torch.allclose(m, self.prior_bias_mean)): 85 | s += f", prior mean={m.item():.2f}" 86 | sd = self.prior_weight_sd.flatten()[0] 87 | if torch.allclose(sd, self.prior_weight_sd) and (not self.has_bias or torch.allclose(sd, self.prior_bias_sd)): 88 | s += f", prior sd={sd.item():.2f}" 89 | return s 90 | 91 | def init_from_deterministic_params(self, param_dict): 92 | weight = param_dict["weight"] 93 | bias = param_dict.get("bias") 94 | with torch.no_grad(): 95 | self.weight_mean.data.copy_(weight.detach()) 96 | if bias is not None: 97 | self.bias_mean.data.copy_(bias.detach()) 98 | 99 | @property 100 | def weight_sd(self): 101 | weight_sd = F.softplus(self._weight_sd) 102 | return weight_sd if self.max_sd is None else weight_sd.clamp(0, self.max_sd) 103 | 104 | @property 105 | def bias_sd(self): 106 | if self.has_bias: 107 | bias_sd = F.softplus(self._bias_sd) 108 | return bias_sd if self.max_sd is None else bias_sd.clamp(0, self.max_sd) 109 | return None 110 | 111 | @property 112 | def weight_dist(self): 113 | return dist.Normal(self.weight_mean, self.weight_sd) 114 | 115 | @property 116 | def bias_dist(self): 117 | if self.has_bias: 118 | return dist.Normal(self.bias_mean, self.bias_sd) 119 | return None 120 | 121 | @property 122 | def prior_weight_dist(self): 123 | return dist.Normal(self.prior_weight_mean, self.prior_weight_sd) 124 | 125 | @property 126 | def prior_bias_dist(self): 127 | if self.has_bias: 128 | return dist.Normal(self.prior_bias_mean, self.prior_bias_sd) 129 | return None 130 | 131 | def kl_divergence(self): 132 | kl = dist.kl_divergence(self.weight_dist, self.prior_weight_dist).sum() 133 | if self.has_bias: 134 | kl += dist.kl_divergence(self.bias_dist, self.prior_bias_dist).sum() 135 | return kl 136 | 137 | def forward(self, x: torch.Tensor): 138 | if self.local_reparameterization: 139 | # use local reparameterization during training, i.e. sample the linear outputs 140 | # a ~ N(x^T \mu_w + \mu_b, (x^2)^T \sigma_W^2 + \sigma_b^2) 141 | self.weight = self.weight_mean.add(0) 142 | self.bias = self.bias_mean.add(0) if self.has_bias else None 143 | output_mean = super().forward(x) 144 | self.weight = self.weight_sd.pow(2) 145 | self.bias = self.bias_sd.pow(2) if self.has_bias else None 146 | output_var = super().forward(x.pow(2)) 147 | output_var += output_var.abs().mul(output_var.lt(0.).float()).detach() + EPS 148 | return _normal_sample(output_mean, output_var.sqrt()) 149 | else: 150 | # sample the weights during testing, i.e. 
W ~ N(\mu_W, \sigma_W^2), b ~ N(\mu_b, \sigma_b^2) 151 | # and calculate x^T W + b 152 | self.weight = _normal_sample(self.weight_mean, self.weight_sd) 153 | self.bias = _normal_sample(self.bias_mean, self.bias_sd) if self.has_bias else None 154 | return super().forward(x) 155 | -------------------------------------------------------------------------------- /bnn/nn/mixins/variational/inducing.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.distributions as dist 7 | 8 | 9 | from .base import VariationalMixin 10 | from .utils import EPS, prod, inverse_softplus, vec_to_chol 11 | from ....distributions import MatrixNormal 12 | 13 | 14 | __all__ = [ 15 | "InducingDeterministicMixin", 16 | "InducingMixin" 17 | ] 18 | 19 | 20 | def _jittered_cholesky(m): 21 | j = EPS * m.detach().diagonal().mean() * torch.eye(m.shape[-1], device=m.device) 22 | return m.add(j).cholesky() 23 | 24 | 25 | class _InducingBase(VariationalMixin): 26 | 27 | def __init__(self, *args, inducing_rows=None, inducing_cols=None, prior_sd=1., init_lamda=1e-4, learn_lamda=True, 28 | max_lamda=None, q_inducing="diagonal", whitened_u=True, max_sd_u=None, sqrt_width_scaling=False, 29 | cache_cholesky=False, ensemble_size=8, **kwargs): 30 | """Base class for inducing weight mixins. Implements all parameter-related logic as well as methods for 31 | sampling parameters and converting them to the correct shape, the forward pass and calculating KL divergences. 32 | Subclasses only need to implement the `conditional_mean_and_noise` method to calculate the mean and noise 33 | as in the reparameterization of a Gaussian for the conditional distribution on the weights. 34 | The Mixin class is expected to be inherited from in a class that also inherits from a torch.nn.Module that has 35 | .weight and .bias parameter attributes, such as Linear and Conv layers. 36 | 37 | Args: 38 | inducing_rows (int): number of inducing rows, i.e. first dimension of U 39 | inducing_cols (int): number of inducing columns , i.e. second dimension of U 40 | prior_sd (float): standard deviation of the prior on W. 41 | init_lamda (float): initial value for the lamda parameter that down-scale the conditional covariance of the 42 | inference distribution on W (typo is intentional since lambda is a keyword in python) 43 | learn_lamda (bool): whether to learn lamda or keep it constant and the initial value 44 | max_lamda (float): maximum value of lamda 45 | q_inducing (str): 'diagonal', 'matrix', 'full' or 'ensemble'. Determines the structure of the covariance 46 | that is used for the Gaussian that is the variational posterior on U if not 'ensemble', mixture of 47 | Delta distributions otherwise. 48 | whitened_u (bool): whether to whiten U, i.e. make the marginal prior a standard normal 49 | max_sd_u (float): maximum value of the scale of the variational posterior on U if q_inducing is "diagonal" 50 | sqrt_width_scaling (bool): whether or not to divide prior_sd by the square root of the number of inputs. 51 | Setting this flag to true and the prior_sd to 1 corresponds to the 'neal' prior scale. 
52 | """ 53 | super().__init__(*args, **kwargs) 54 | self.has_bias = self.bias is not None 55 | self.whitened_u = whitened_u 56 | self._weight_shape = self.weight.shape 57 | 58 | # _u caches the per-layer u sample -- this will be needed for the between-layer correlations 59 | self._u = None 60 | 61 | self._caching_mode = False 62 | self.reset_cache() 63 | 64 | self._i = 0 65 | 66 | del self.weight 67 | if self.has_bias: 68 | del self.bias 69 | 70 | self._d_out = self._weight_shape[0] 71 | self._d_in = prod(self._weight_shape[1:]) + int(self.has_bias) 72 | 73 | # if no number of inducing rows or columns is given, use the square root of the corresponding dimension 74 | if inducing_rows is None: 75 | inducing_rows = int(self._d_out ** 0.5) 76 | if inducing_cols is None: 77 | inducing_cols = int(self._d_in ** 0.5) 78 | 79 | self.inducing_rows = inducing_rows 80 | self.inducing_cols = inducing_cols 81 | 82 | if sqrt_width_scaling: 83 | prior_sd /= self._d_in ** 0.5 84 | 85 | self.prior_sd = prior_sd 86 | self.max_sd_u = max_sd_u 87 | 88 | # set attributes with the size of the augmented space, i.e. total_cols is the number of columns 89 | # in weight space plus the number of inducing columns 90 | self.total_cols = self._d_in + self.inducing_cols 91 | self.total_rows = self._d_out + self.inducing_rows 92 | 93 | self.z_row = nn.Parameter(torch.randn(inducing_rows, self._d_out) * inducing_rows ** -0.5) 94 | self.z_col = nn.Parameter(torch.randn(inducing_cols, self._d_in) * inducing_cols ** -0.5) 95 | 96 | d = inducing_rows * inducing_cols 97 | if q_inducing == "full": 98 | # use a multivariate Gaussian with full covariance for inference on U, parameterized by the 99 | # Cholesky decomposition of the covariance 100 | self.inducing_mean = nn.Parameter(torch.randn(d)) 101 | self._inducing_scale_tril = nn.Parameter(torch.cat([ 102 | torch.full((d,), inverse_softplus(1e-3)), torch.zeros((d * (d - 1)) // 2) 103 | ])) 104 | self.register_parameter("_inducing_row_scale_tril", None) 105 | self.register_parameter("_inducing_col_scale_tril", None) 106 | self.register_parameter("_inducing_sd", None) 107 | elif q_inducing == "matrix": 108 | # use a matrix Gaussian with full row and column covariance for inference on U, both parameterized 109 | # by the corresponding Cholesky decomposition 110 | self.inducing_mean = nn.Parameter(torch.randn(inducing_rows, inducing_cols)) 111 | self._inducing_row_scale_tril = nn.Parameter(torch.cat([ 112 | torch.full((inducing_rows,), inverse_softplus(1e-3)), 113 | torch.zeros((inducing_rows * (inducing_rows - 1)) // 2) 114 | ])) 115 | self._inducing_col_scale_tril = nn.Parameter(torch.cat([ 116 | torch.full((inducing_cols,), inverse_softplus(1e-3)), 117 | torch.zeros((inducing_cols * (inducing_cols - 1)) // 2) 118 | ])) 119 | self.register_parameter("_inducing_scale_tril", None) 120 | self.register_parameter("_inducing_sd", None) 121 | elif q_inducing == "diagonal": 122 | self.inducing_mean = nn.Parameter(torch.randn(inducing_rows, inducing_cols)) 123 | self._inducing_sd = nn.Parameter(torch.full((inducing_rows, inducing_cols), inverse_softplus(1e-3))) 124 | self.register_parameter("_inducing_scale_tril", None) 125 | self.register_parameter("_inducing_row_scale_tril", None) 126 | self.register_parameter("_inducing_col_scale_tril", None) 127 | elif q_inducing == "ensemble": 128 | tmp = torch.randn(inducing_rows, inducing_cols) 129 | self.inducing_mean = nn.Parameter(tmp + torch.randn(ensemble_size, inducing_rows, inducing_cols) * 0.1) 130 | #self.inducing_mean = 
nn.Parameter(torch.stack([tmp for _ in range(ensemble_size)], dim=0)) 131 | #self.inducing_mean = nn.Parameter(torch.randn(ensemble_size, inducing_rows, inducing_cols) * 0.01) 132 | self.register_parameter("_inducing_scale_tril", None) 133 | self.register_parameter("_inducing_row_scale_tril", None) 134 | self.register_parameter("_inducing_col_scale_tril", None) 135 | self.register_parameter("_inducing_sd", None) 136 | else: 137 | raise ValueError("q_inducing must be one of 'full', 'matrix', 'diagonal', 'ensemble'.") 138 | self.q_inducing = q_inducing 139 | 140 | self.max_lamda = max_lamda 141 | if learn_lamda: 142 | # unconstrained parameterization of the lambda parameter which rescales the covariance of the 143 | # conditional Gaussian on W given U. Positivity is ensured through softplus function 144 | self._lamda = nn.Parameter(torch.tensor(inverse_softplus(init_lamda))) 145 | else: 146 | self.register_buffer("_lamda", torch.tensor(inverse_softplus(init_lamda))) 147 | 148 | self._d_row = nn.Parameter(torch.full((inducing_rows,), inverse_softplus(1e-3))) 149 | self._d_col = nn.Parameter(torch.full((inducing_cols,), inverse_softplus(1e-3))) 150 | 151 | self.weight, self.bias = self.sample_shaped_parameters() 152 | self._caching_mode = cache_cholesky 153 | 154 | def reset_cache(self): 155 | self._L_r_cached = None 156 | self._L_c_cached = None 157 | self._prior_inducing_row_scale_tril_cached = None 158 | self._prior_inducing_col_scale_tril_cached = None 159 | self._row_transform_cached = None 160 | self._col_transform_cached = None 161 | 162 | def extra_repr(self): 163 | """Helper function for displaying and printing nn.Module objects.""" 164 | s = super().extra_repr() 165 | s += f", inducing_cols={self.inducing_cols}, inducing_rows={self.inducing_rows}" 166 | return s 167 | 168 | def init_from_deterministic_params(self, param_dict): 169 | raise NotImplementedError 170 | 171 | def forward(self, x): 172 | self.weight, self.bias = self.sample_shaped_parameters() 173 | return super().forward(x) 174 | 175 | def sample_shaped_parameters(self): 176 | parameters = self.sample_parameters() 177 | if self.has_bias: 178 | w = parameters[:, :-1] 179 | b = parameters[:, -1] 180 | else: 181 | w = parameters 182 | b = None 183 | return w.view(self._weight_shape), b 184 | 185 | def sample_parameters(self): 186 | # self._u will be set from outside the forward pass if we're using global inducing weights 187 | # if that is the case, we use the sample that has been written into the instance attribute, 188 | # otherwise sample it here independently of other layers 189 | if self._u is None: 190 | self._u = self.sample_u() 191 | u = self._u 192 | self._u = None 193 | row_chol = self.prior_inducing_row_scale_tril 194 | col_chol = self.prior_inducing_col_scale_tril 195 | if self.whitened_u: 196 | u = row_chol @ u @ col_chol.t() 197 | return self.sample_conditional_parameters(u, row_chol, col_chol) 198 | 199 | def sample_u(self): 200 | if self.q_inducing == "ensemble": 201 | i = self._i 202 | self._i = (self._i + 1) % len(self.inducing_mean) 203 | u = self.inducing_mean[i] 204 | else: 205 | u = self.inducing_dist.rsample() 206 | return u.view(self.inducing_rows, self.inducing_cols) 207 | 208 | def sample_conditional_parameters(self, u, row_chol, col_chol): 209 | # mean and noise are the conditional mean and additive noise s.t. 
210 | # W = mean + noise follows a Gaussian with covariance equal to the covariance of noise 211 | mean, noise = self.conditional_mean_and_noise(u, row_chol, col_chol) 212 | rescaled_noise = self.lamda * noise 213 | # self.prior_sd is the standard deviation of the prior on W set in the init method 214 | return self.prior_sd * (mean + rescaled_noise) 215 | 216 | def kl_divergence(self): 217 | # TODO potentially add log prob of ensemble u positions to kl 218 | 219 | # kl divergence is the sum of the kl divergence in U space and the expected KL in W space. 220 | # The latter can be evaluated analytically for our cases, since everything is Gaussian 221 | if self.inducing_dist is None: 222 | inducing_kl = 0. 223 | else: 224 | inducing_kl = dist.kl_divergence(self.inducing_dist, self.inducing_prior_dist).sum() 225 | conditional_kl = self.conditional_kl_divergence() 226 | return inducing_kl + conditional_kl 227 | 228 | def conditional_kl_divergence(self): 229 | """Computes KL[q(W|U) || p(W|U)]. q and p are assumed to be Gaussians that share the same mean 230 | and only differ in covariance through a rescaling by self.lambda^2.""" 231 | return self._d_in * self._d_out * (0.5 * self.lamda ** 2 - self.lamda.log() - 0.5) 232 | 233 | def compute_row_transform(self, row_chol): 234 | # Z_r^T (Z_r Z_r^T + D_r^2)^-1 235 | if self._row_transform_cached is None: 236 | row_transform = self.z_row.cholesky_solve(row_chol).t() 237 | else: 238 | row_transform = self._row_transform_cached 239 | 240 | if self.caching_mode: 241 | self._row_transform_cached = row_transform 242 | return row_transform 243 | 244 | def compute_col_transform(self, col_chol): 245 | # (Z_c Z_c^T + D_c^2)^-1 Z_c 246 | if self._col_transform_cached is None: 247 | col_transform = self.z_col.cholesky_solve(col_chol) 248 | else: 249 | col_transform = self._col_transform_cached 250 | 251 | if self.caching_mode: 252 | self._col_transform_cached = col_transform 253 | return col_transform 254 | 255 | @abstractmethod 256 | def conditional_mean_and_noise(self, u, row_chol, col_chol): 257 | pass 258 | 259 | @property 260 | def caching_mode(self): 261 | return self._caching_mode 262 | 263 | @caching_mode.setter 264 | def caching_mode(self, mode): 265 | self._caching_mode = mode 266 | if mode is False: 267 | self.reset_cache() 268 | 269 | # below are properties that transform the unconstrained pytorch parameters into constrained space 270 | @property 271 | def d_row(self): 272 | return F.softplus(self._d_row).diag_embed() 273 | 274 | @property 275 | def d_col(self): 276 | return F.softplus(self._d_col).diag_embed() 277 | 278 | @property 279 | def lamda(self): 280 | return F.softplus(self._lamda).clamp(0., self.max_lamda) if self._lamda is not None else None 281 | 282 | @property 283 | def prior_inducing_row_cov(self): 284 | return self.z_row @ self.z_row.t() + self.d_row.pow(2) 285 | 286 | @property 287 | def prior_inducing_row_scale_tril(self): 288 | if self._prior_inducing_row_scale_tril_cached is not None: 289 | prior_inducing_row_scale_tril = self._prior_inducing_row_scale_tril_cached 290 | else: 291 | prior_inducing_row_scale_tril = _jittered_cholesky(self.prior_inducing_row_cov) 292 | 293 | if self.caching_mode: 294 | self._prior_inducing_row_scale_tril_cached = prior_inducing_row_scale_tril 295 | return prior_inducing_row_scale_tril 296 | 297 | @property 298 | def prior_inducing_col_cov(self): 299 | return self.z_col @ self.z_col.t() + self.d_col.pow(2) 300 | 301 | @property 302 | def prior_inducing_col_scale_tril(self): 303 | if 
self._prior_inducing_col_scale_tril_cached is not None: 304 | prior_inducing_col_scale_tril = self._prior_inducing_col_scale_tril_cached 305 | else: 306 | prior_inducing_col_scale_tril = _jittered_cholesky(self.prior_inducing_col_cov) 307 | 308 | if self.caching_mode: 309 | self._prior_inducing_col_scale_tril_cached = prior_inducing_col_scale_tril 310 | return prior_inducing_col_scale_tril 311 | 312 | @property 313 | def inducing_scale_tril(self): 314 | return vec_to_chol(self._inducing_scale_tril) if self._inducing_scale_tril is not None else None 315 | 316 | @property 317 | def inducing_row_scale_tril(self): 318 | return vec_to_chol(self._inducing_row_scale_tril) if self._inducing_row_scale_tril is not None else None 319 | 320 | @property 321 | def inducing_col_scale_tril(self): 322 | return vec_to_chol(self._inducing_col_scale_tril) if self._inducing_col_scale_tril is not None else None 323 | 324 | @property 325 | def inducing_sd(self): 326 | return F.softplus(self._inducing_sd).clamp(0., self.max_sd_u) if self._inducing_sd is not None else None 327 | 328 | @property 329 | def inducing_prior_dist(self): 330 | if self.inducing_mean is None: 331 | # If the inducing_mean has been set to None, i.e. we're using global inducing weights, the marginal prior 332 | # distribution calculated below is no longer correct (as we're conditioning on some global inducing 333 | # weight), so return None. We could also make the forward hook for the global inducing weight assign 334 | # the marginal distribution on the per-layer inducing weights of each layer conditioned on the sample 335 | # for the global inducing weights to an _inducing_prior_dist attribute and return that here, however 336 | # I can't think of an actual use case for this 337 | return None 338 | 339 | loc = self.z_row.new_zeros(self.inducing_rows, self.inducing_cols) 340 | if self.whitened_u: 341 | # after whitening the prior on u it is a standard Gaussian N(0, 1), below parameterising q 342 | # in different ways for computational efficiency 343 | if self.q_inducing == "full": 344 | cov = torch.eye(self.inducing_rows * self.inducing_cols, device=self.device) 345 | return dist.MultivariateNormal(loc.flatten(), cov) 346 | elif self.q_inducing == "matrix": 347 | row_scale_tril = torch.eye(self.inducing_rows, device=self.device) 348 | col_scale_tril = torch.eye(self.inducing_cols, device=self.device) 349 | return MatrixNormal(loc, row_scale_tril, col_scale_tril) 350 | else: # self.q_inducing == "diagonal" 351 | scale = self.z_row.new_ones(self.inducing_rows, self.inducing_cols) 352 | return dist.Normal(loc, scale) 353 | else: 354 | return MatrixNormal(loc, self.prior_inducing_row_scale_tril, self.prior_inducing_col_scale_tril) 355 | 356 | @property 357 | def inducing_dist(self): 358 | # TODO potentially return mixture of deltas for q_inducing == "ensemble" 359 | if self.inducing_mean is None or self.q_inducing == "ensemble": 360 | return None 361 | 362 | if self.q_inducing == "full": 363 | return dist.MultivariateNormal(self.inducing_mean, scale_tril=self.inducing_scale_tril) 364 | elif self.q_inducing == "matrix": 365 | return MatrixNormal(self.inducing_mean, self.inducing_row_scale_tril, self.inducing_col_scale_tril) 366 | else: # self.q_inducing == "diagonal" 367 | return dist.Normal(self.inducing_mean, self.inducing_sd) 368 | 369 | @property 370 | def device(self): 371 | return self.z_row.device 372 | 373 | 374 | class InducingDeterministicMixin(_InducingBase): 375 | """Inducing model where we only sample u and deterministiclly transform 
it into the weights.""" 376 | 377 | def conditional_mean_and_noise(self, u, row_chol, col_chol): 378 | # Z_r^T (Z_r Z_r^T + D_r^2)^-1 379 | row_transform = self.compute_row_transform(row_chol) 380 | # (Z_c Z_c^T + D_c^2)^-1 Z_c 381 | col_transform = self.compute_col_transform(col_chol) 382 | M_w = row_transform @ u @ col_transform 383 | return M_w, torch.zeros_like(M_w) 384 | 385 | 386 | class InducingMixin(_InducingBase): 387 | """Inducing mixin which uses Matheron's rule to sample from the conditional multivariate normal on W, i.e. 388 | sample from the joint and linearly transform the samples to get the correct mean and covariance. This is necessary 389 | since the covariance of W is a difference of two Kronecker products, i.e. has no simple structure.""" 390 | 391 | def conditional_mean_and_noise(self, u, row_chol, col_chol): 392 | # Z_r^T (Z_r Z_r^T + D_r^2)^-1 -- row_chol is the Cholesky of Z_r Z_r^T + D_r^2 393 | row_transform = self.compute_row_transform(row_chol) 394 | # (Z_c Z_c^T + D_c^2)^-1 Z_c -- col_chol is the Cholesky of Z_c Z_c^T + D_c^2 395 | col_transform = self.compute_col_transform(col_chol) 396 | 397 | M_w = row_transform @ u @ col_transform 398 | 399 | # Below is the implementation of Matheron's rule where we transform a sample from the joint 400 | # p(W, U) such that the resulting M_w + noise_term follows the conditional Normal p(W|U) 401 | # 402 | # Sampling from the joint p(W, U) cannot be done directly in an efficient manner since it is 403 | # a multivariate normal with a covariance that has Khatri-Rao product structure (Kronecker factored blocks). 404 | # Instead we augment the joint variable to: 405 | # W U_c 406 | # U_r U 407 | # such that the joint variable has a matrix normal distribution 408 | # We then only compute the W and U samples. The latter can be structured into a sum of four blocks 409 | # of noise terms. The first term is a transformation of the noise needed for W, however two of the 410 | # remaining three terms are projections from the corresponding dimension (rows or columns) of W into 411 | # the smaller dimension of U by Z. Hence the noise can also be seen as coming from a lower dimensional 412 | # matrix normal distribution with covariance ZZ^T. 
Hence we calculate the Cholesky of Z and directly 413 | # sample in U space to reduce noise 414 | e1 = self.z_row.new_empty(self._d_out, self._d_in).normal_() 415 | e2, e3, e4 = self.z_row.new_empty(3, self.inducing_rows, self.inducing_cols).normal_() 416 | 417 | w_bar = e1 418 | 419 | if self._L_r_cached is None: 420 | # either not in caching mode or Cholesky has not been computed yet 421 | L_r = _jittered_cholesky(self.z_row.mm(self.z_row.t())) 422 | else: 423 | L_r = self._L_r_cached 424 | 425 | if self._L_c_cached is None: 426 | # either not in caching mode or Cholesky has not been computed yet 427 | L_c = _jittered_cholesky(self.z_col.mm(self.z_col.t())) 428 | else: 429 | L_c = self._L_c_cached 430 | 431 | if self.caching_mode: 432 | self._L_r_cached = L_r 433 | self._L_c_cached = L_c 434 | 435 | t1 = self.z_row @ e1 @ self.z_col.t() 436 | t2 = L_r @ e2 @ self.d_col 437 | t3 = self.d_row @ e3 @ L_c.t() 438 | t4 = self.d_row @ e4 @ self.d_col 439 | u_bar = t1 + t2 + t3 + t4 440 | 441 | noise_term = w_bar - row_transform @ u_bar @ col_transform 442 | 443 | return M_w, noise_term 444 | 445 | 446 | def _iter_inducing_modules(module): 447 | yield from filter(lambda m: isinstance(m, _InducingBase), module.modules()) 448 | 449 | 450 | def register_global_inducing_weights_(module, inducing_rows, inducing_cols, cat_dim=0, **inducing_kwargs): 451 | """Function for adding global inducing weights to an nn.Module that contains inducing modules, in order to add correlation 452 | between their inducing weights. For this we need to concatenate/stack the per-layer inducing weights along either 453 | the rows or columns. cat_dim indicates which dimension we concatenate (0 for rows, 1 for columns), so the other 454 | dimension needs to be the same for all inducing layers of module. 455 | 456 | Args: 457 | module (nn.Module): A pytorch module that contains at least two submodules that inherit from _InducingBase. 458 | The variational parameters of those modules will be removed and sampling the inducing weights taken 459 | care of through a forward hook that samples their respective inducing weights jointly. Note that an 460 | extra nn.Module is added to module, so this will show up when iterating over the .modules(). The forward 461 | pass of this module is implemented as the identity function to be compatible e.g. with nn.Sequential, 462 | which does such an iteration in the forward pass. 463 | inducing_rows (int): number of global inducing rows 464 | inducing_cols (int): number of global inducing columns 465 | cat_dim (int; 0 or 1): dimension of the per-layer inducing weights to concatenate for creating the joint 466 | object. 0 for concatenating the rows (hence the number of inducing columns must be the same across all 467 | inducing layers), 1 for concatenating columns (hence number of rows must be the same) 468 | inducing_kwargs: inducing parameters for the global InducingMixin layer 469 | 470 | Returns: 471 | The pre-forward hook registered on module that samples the inducing weights. 472 | """ 473 | 474 | class _GlobalInducingModule(InducingMixin, nn.Linear): 475 | """Dummy class to use the sampling logic from the inducing mixin for sampling global inducing weights. The 476 | in_features and out_features attributes correspond to the total number of inducing columns and rows in the network.
477 | Internal to the function to avoid having it registered as a subclass of VariationalMixin for bayesianize.""" 478 | 479 | def __init__(self, *args, **kwargs): 480 | kwargs["bias"] = False 481 | super().__init__(*args, **kwargs) 482 | 483 | def forward(self, input): 484 | # implement identity function so that this works inside nn.Sequential, which calls the forward 485 | # method of all modules that are registered within it 486 | return input 487 | 488 | if cat_dim not in [0, 1]: 489 | raise ValueError("Must concatenate either the rows (cat_dim=0) or columns (cat_dim=1) of the inducing weights") 490 | 491 | if len(list(_iter_inducing_modules(module))) < 2: 492 | raise ValueError("'module' must contain at least two inducing layers as the point of this function is to add " 493 | "correlation between such layers.") 494 | 495 | non_cat_size = None 496 | cat_size = 0 497 | # check that the dimensions of all inducing layers match for concatenation and that they are whitened, and add up 498 | # the sizes along the concatenation dimension. Note that rather than iterating over all inducing modules we 499 | # could also have the function accept an iterator as an argument, e.g. to exclude input or output layers 500 | # which might have differently shaped inducing weights than the inner modules 501 | for m in _iter_inducing_modules(module): 502 | m_non_cat_size = m.inducing_rows if cat_dim == 1 else m.inducing_cols 503 | if non_cat_size is None: 504 | non_cat_size = m_non_cat_size 505 | elif m_non_cat_size != non_cat_size: 506 | raise ValueError("Size of the inducing dimension which is not concatenated must match across all layers.") 507 | 508 | if not m.whitened_u: 509 | raise ValueError("All inducing weight layers must use whitened inducing weights.") 510 | 511 | cat_size += m.inducing_rows if cat_dim == 0 else m.inducing_cols 512 | 513 | # delete the variational parameters over u by setting them to None -- this allows for the unconstrained 514 | # properties to still be None rather than raising an AttributeError 515 | m.inducing_mean = None 516 | m._inducing_row_scale_tril = None 517 | m._inducing_col_scale_tril = None 518 | m._inducing_scale_tril = None 519 | m._inducing_sd = None 520 | 521 | if cat_dim == 0: 522 | num_cols = non_cat_size 523 | num_rows = cat_size 524 | else: 525 | num_cols = cat_size 526 | num_rows = non_cat_size 527 | 528 | # set the global inducing module as an attribute on the original module so that the parameters are registered 529 | module._global_inducing_module = _GlobalInducingModule( 530 | num_cols, num_rows, inducing_rows=inducing_rows, inducing_cols=inducing_cols, **inducing_kwargs) 531 | 532 | # hook function that is called before the forward pass of module is executed. Jointly samples the inducing weights 533 | # of all layers and assigns the corresponding slice to the ._u attribute of each inducing module, such that the 534 | # per-layer weights are conditioned on those 535 | def inducing_sampling_hook(m, input): 536 | # draw the joint sample of inducing weights for all layers stacked into one big matrix. Note that 537 | # the 'weight matrix' of the global inducing module corresponds to the joint inducing weights of the layers.
538 | # The inducing weights of the module are the 'global' inducing weights 539 | inducing_weights = m._global_inducing_module.sample_parameters() 540 | 541 | offset = 0 542 | for im in _iter_inducing_modules(m): 543 | # skip the global inducing module 544 | if im is m._global_inducing_module: 545 | continue 546 | 547 | if cat_dim == 0: 548 | delta = im.inducing_rows 549 | im._u = inducing_weights[offset:offset+delta] 550 | else: 551 | delta = im.inducing_cols 552 | im._u = inducing_weights[:, offset:offset+delta] 553 | offset += delta 554 | 555 | return module.register_forward_pre_hook(inducing_sampling_hook) 556 | 557 | 558 | class InducingDeterministicConditionalCholeskyMixin(_InducingBase): 559 | """Hybrid model which has the mean of the marginalized inducing model (i.e. the mean of the conditional 560 | distribution is a function of only U, but not U_r or U_c) and the covariance is that of the fully 561 | conditional model, i.e. the matrix normal that corresponds to p(W|U_r, U_c, U).""" 562 | 563 | def conditional_mean_and_noise(self, u, row_chol, col_chol): 564 | # Z_r^T (Z_r Z_r^T + D_r^2)^-1 565 | row_transform = self.compute_row_transform(row_chol) 566 | # (Z_c Z_c^T + D_c^2)^-1 Z_c 567 | col_transform = self.compute_col_transform(col_chol) 568 | 569 | M_w = row_transform @ u @ col_transform 570 | 571 | row_cov = (1 + EPS) * torch.eye(self._d_out, device=self.device) - row_transform @ self.z_row 572 | L_r = row_cov.cholesky() 573 | col_cov = (1 + EPS) * torch.eye(self._d_in, device=self.device) - self.z_col.t() @ col_transform 574 | L_c = col_cov.cholesky() 575 | 576 | noise_term = L_r @ torch.randn_like(M_w) @ L_c.t() 577 | 578 | return M_w, noise_term 579 | 580 | 581 | class InducingMarginalizedCholeskyMixin(_InducingBase): 582 | """Marginalized inducing model that explicitly calculates the Cholesky of the multivariate normal conditional 583 | distribution over the weights. 
This is intractable for even moderately sized layers, the main purpose of this 584 | class is to test the sampling of the Matheron variant.""" 585 | 586 | def conditional_mean_and_noise(self, u, row_chol, col_chol): 587 | # Z_r^T (Z_r Z_r^T + D_r^2)^-1 588 | row_transform = self.compute_row_transform(row_chol) 589 | # (Z_c Z_c^T + D_c^2)^-1 Z_c 590 | col_transform = self.compute_col_transform(col_chol) 591 | 592 | M_w = row_transform @ u @ col_transform 593 | 594 | row_cov = row_transform @ self.z_row 595 | col_cov = self.z_col.t() @ col_transform 596 | weight_cov = (1 + EPS) * torch.eye(self._d_out * self._d_in, device=self.device) - kron(row_cov, col_cov) 597 | L_w = weight_cov.cholesky() 598 | noise_term = L_w.mv(torch.randn_like(M_w).flatten()).view_as(M_w) 599 | 600 | return M_w, noise_term 601 | 602 | -------------------------------------------------------------------------------- /bnn/nn/mixins/variational/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import math 3 | import operator 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | EPS = 1e-4 10 | 11 | 12 | def prod(iterable): 13 | return functools.reduce(operator.mul, iterable, 1) 14 | 15 | 16 | def inverse_softplus(x): 17 | if torch.is_tensor(x): 18 | return x.expm1().log() 19 | else: 20 | return math.log(math.expm1(x)) 21 | 22 | 23 | def vec_to_chol(x, out=None): 24 | """Transforms a batch of (d+1)*d/2-dimensional vector to a batch of d x d dimensional lower triangular matrices 25 | with positive diagonals.""" 26 | 27 | # calculate d using quadratic formula from x.shape[-1] 28 | d_float = 0.5 * (math.sqrt(1 + 8 * x.shape[-1]) - 1) 29 | d = int(d_float) 30 | if d != d_float: 31 | raise ValueError("Trailing dimension of input must be of size d+1 choose 2.") 32 | 33 | if out is None: 34 | out = x.new_zeros(x.shape[:-1] + (d, d)) 35 | # set diagonal to positive elements 36 | out[..., torch.arange(d), torch.arange(d)] = F.softplus(x[..., :d]) 37 | # set off-diagonal elements 38 | tril_row_indices, tril_col_indices = torch.tril_indices(d, d, offset=-1) 39 | out[..., tril_row_indices, tril_col_indices] = x[..., d:] 40 | return out 41 | -------------------------------------------------------------------------------- /bnn/nn/modules.py: -------------------------------------------------------------------------------- 1 | """This module creates the actual Bayesian layers classes. 
Currently 2 | this is done by hand, but it might be worth defining a list of modules 3 | of which Bayesian variants should exist (currently Linear and Conv1-3d) 4 | and generate the corresponding subclasses in combination with all Mixin 5 | classes, which could be found programmatically.""" 6 | 7 | import torch.nn as nn 8 | 9 | 10 | from .mixins import * 11 | 12 | 13 | __all__ = [ 14 | "FFGLinear", "FFGConv1d", "FFGConv2d", "FFGConv3d", 15 | "FCGLinear", "FCGConv1d", "FCGConv2d", "FCGConv3d", 16 | "InducingDeterministicLinear", "InducingDeterministicConv1d", 17 | "InducingDeterministicConv2d", "InducingDeterministicConv3d", 18 | "InducingLinear", "InducingConv1d", "InducingConv2d", "InducingConv3d", 19 | ] 20 | 21 | 22 | class FFGLinear(FFGMixin, nn.Linear): pass 23 | class FFGConv1d(FFGMixin, nn.Conv1d): pass 24 | class FFGConv2d(FFGMixin, nn.Conv2d): pass 25 | class FFGConv3d(FFGMixin, nn.Conv3d): pass 26 | 27 | 28 | class FCGLinear(FCGMixin, nn.Linear): pass 29 | class FCGConv1d(FCGMixin, nn.Conv1d): pass 30 | class FCGConv2d(FCGMixin, nn.Conv2d): pass 31 | class FCGConv3d(FCGMixin, nn.Conv3d): pass 32 | 33 | 34 | class InducingDeterministicLinear(InducingDeterministicMixin, nn.Linear): pass 35 | class InducingDeterministicConv1d(InducingDeterministicMixin, nn.Conv1d): pass 36 | class InducingDeterministicConv2d(InducingDeterministicMixin, nn.Conv2d): pass 37 | class InducingDeterministicConv3d(InducingDeterministicMixin, nn.Conv3d): pass 38 | 39 | 40 | class InducingLinear(InducingMixin, nn.Linear): pass 41 | class InducingConv1d(InducingMixin, nn.Conv1d): pass 42 | class InducingConv2d(InducingMixin, nn.Conv2d): pass 43 | class InducingConv3d(InducingMixin, nn.Conv3d): pass 44 | -------------------------------------------------------------------------------- /bnn/nn/nets.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import torch.nn as nn 4 | 5 | import torchvision 6 | 7 | 8 | def make_network(architecture: str, *args, **kwargs): 9 | if architecture == "fcn": 10 | return FCN(**kwargs) 11 | elif architecture == "cnn": 12 | return CNN(**kwargs) 13 | elif architecture.startswith("resnet"): 14 | net = getattr(torchvision.models, architecture)(num_classes=kwargs["out_features"]) 15 | if "kernel_size" in kwargs: 16 | kernel_size = kwargs["kernel_size"] 17 | stride = kwargs.get("stride", 1) 18 | padding = kwargs.get("padding", kernel_size // 2) 19 | in_channels = kwargs.get("in_channels", 3) 20 | bias = net.conv1.bias is not None 21 | net.conv1 = nn.Conv2d(in_channels, net.conv1.out_channels, kernel_size, stride, padding, bias=bias) 22 | if kwargs.get("remove_maxpool", False): 23 | net.maxpool = nn.Identity() 24 | return net 25 | else: 26 | raise ValueError("Unrecognized network architecture:", architecture) 27 | 28 | 29 | class FCN(nn.Sequential): 30 | """Basic fully connected network class.""" 31 | 32 | def __init__(self, sizes: List[int], nonlinearity: Union[str, type] = "ReLU", bn: bool = False, **layer_kwargs): 33 | super().__init__() 34 | nonl_class = getattr(nn, nonlinearity) if isinstance(nonlinearity, str) else nonlinearity 35 | 36 | layer_kwargs.setdefault("bias", not bn) 37 | for i, (s0, s1) in enumerate(zip(sizes[:-1], sizes[1:])): 38 | self.add_module(f"Linear{i}", nn.Linear(s0, s1, **layer_kwargs)) 39 | if bn: 40 | self.add_module(f"BN{i}", nn.BatchNorm1d(s1)) 41 | if i < len(sizes) - 2: 42 | self.add_module(f"Nonlinarity{i}", nonl_class()) 43 | 44 | 45 | class CNN(nn.Sequential): 46 | """Basic CNN 
class with Conv/BN/Nonl/Maxpool blocks followed by a fully connected net. Batchnorm and maxpooling 47 | are optional and the latter can also only be included after every nth block.""" 48 | 49 | def __init__(self, channels: List[int], lin_sizes: List[int], nonlinearity: Union[str, type] = "ReLU", 50 | maxpool_freq: int = 1, conv_bn: bool = False, linear_bn: bool = False, kernel_size: int = 3, 51 | **conv_kwargs): 52 | super().__init__() 53 | nonl_class = getattr(nn, nonlinearity) if isinstance(nonlinearity, str) else nonlinearity 54 | conv_kwargs.setdefault("bias", not conv_bn) 55 | for i, (c0, c1) in enumerate(zip(channels[:-1], channels[1:])): 56 | self.add_module(f"Conv{i}", nn.Conv2d(c0, c1, kernel_size, **conv_kwargs)) 57 | if conv_bn: 58 | self.add_module(f"ConvBN{i}", nn.BatchNorm2d(c1)) 59 | self.add_module(f"ConvNonlinearity{i}", nonl_class()) 60 | if maxpool_freq and (i + 1) % maxpool_freq == 0: 61 | self.add_module(f"Maxpool{i//maxpool_freq}", nn.MaxPool2d(2, 2)) 62 | self.add_module("Flatten", nn.Flatten()) 63 | 64 | self.add_module("fc", FCN(lin_sizes, nonlinearity=nonlinearity, bn=linear_bn)) 65 | -------------------------------------------------------------------------------- /configs/ensemble_u_cifar10.json: -------------------------------------------------------------------------------- 1 | { 2 | "inference": { 3 | "Conv2d": { 4 | "inference": "inducing", 5 | "inducing_rows": 128, 6 | "inducing_cols": 128 7 | }, "Linear": { 8 | "inference": "inducing", 9 | "inducing_rows": 10, 10 | "inducing_cols": 128 11 | } 12 | }, 13 | "whitened_u": true, 14 | "q_inducing": "ensemble", 15 | "ensemble_size": 5, 16 | "learn_lamda": true, 17 | "init_lamda": 0.001, 18 | "max_lamda": 0.1, 19 | "cache_cholesky": true, 20 | "prior_sd": 1.0, 21 | "sqrt_width_scaling": true 22 | } 23 | -------------------------------------------------------------------------------- /configs/ensemble_u_cifar100.json: -------------------------------------------------------------------------------- 1 | { 2 | "inference": { 3 | "Conv2d": { 4 | "inference": "inducing", 5 | "inducing_rows": 128, 6 | "inducing_cols": 128 7 | }, "Linear": { 8 | "inference": "inducing", 9 | "inducing_rows": 100, 10 | "inducing_cols": 128 11 | } 12 | }, 13 | "whitened_u": true, 14 | "q_inducing": "ensemble", 15 | "ensemble_size": 5, 16 | "learn_lamda": true, 17 | "init_lamda": 0.001, 18 | "max_lamda": 0.1, 19 | "cache_cholesky": true, 20 | "prior_sd": 1.0, 21 | "sqrt_width_scaling": true 22 | } 23 | -------------------------------------------------------------------------------- /configs/ffg_u_cifar10.json: -------------------------------------------------------------------------------- 1 | { 2 | "inference": { 3 | "Conv2d": { 4 | "inference": "inducing", 5 | "inducing_rows": 128, 6 | "inducing_cols": 128 7 | }, "Linear": { 8 | "inference": "inducing", 9 | "inducing_rows": 10, 10 | "inducing_cols": 128 11 | } 12 | }, 13 | "whitened_u": true, 14 | "q_inducing": "diagonal", 15 | "learn_lamda": true, 16 | "init_lamda": 0.001, 17 | "max_lamda": 0.03, 18 | "max_sd_u": 0.1, 19 | "cache_cholesky": true, 20 | "prior_sd": 1.0, 21 | "sqrt_width_scaling": true 22 | } 23 | -------------------------------------------------------------------------------- /configs/ffg_u_cifar100.json: -------------------------------------------------------------------------------- 1 | { 2 | "inference": { 3 | "Conv2d": { 4 | "inference": "inducing", 5 | "inducing_rows": 128, 6 | "inducing_cols": 128 7 | }, "Linear": { 8 | "inference": "inducing", 9 | "inducing_rows": 
100, 10 | "inducing_cols": 128 11 | } 12 | }, 13 | "whitened_u": true, 14 | "q_inducing": "diagonal", 15 | "learn_lamda": true, 16 | "init_lamda": 0.001, 17 | "max_lamda": 0.03, 18 | "max_sd_u": 0.1, 19 | "cache_cholesky": true, 20 | "prior_sd": 1.0, 21 | "sqrt_width_scaling": true 22 | } 23 | -------------------------------------------------------------------------------- /configs/ffg_w_maxsd01.json: -------------------------------------------------------------------------------- 1 | { 2 | "inference": "ffg", 3 | "max_sd": 0.1 4 | } -------------------------------------------------------------------------------- /configs/ffg_w_unconstrained.json: -------------------------------------------------------------------------------- 1 | { 2 | "inference": "ffg" 3 | } -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: bayesianize 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python==3.8.5 7 | - numpy==1.19.2 8 | - pytorch==1.7.0 9 | - torchvision==0.8.1 10 | - scipy==1.5.2 11 | - tqdm==4.50.2 12 | - matplotlib==3.3.2 13 | - seaborn==0.11.0 14 | - pytest==6.1.1 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.19.2 2 | torch>=1.7.0 3 | torchvision>=0.8.1 4 | scipy>=1.5.2 5 | tqdm>=4.50.2 6 | matplotlib>=3.3.2 7 | seaborn>=0.11.0 8 | pytest>=6.1.1 9 | -------------------------------------------------------------------------------- /scripts/cifar_resnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | import json 4 | import os 5 | import pickle 6 | 7 | import numpy as np 8 | 9 | from tqdm import tqdm, trange 10 | 11 | import matplotlib.pyplot as plt 12 | import seaborn as sns 13 | 14 | import torch 15 | import torch.distributions as dist 16 | import torch.utils.data as data 17 | 18 | import torchvision 19 | import torchvision.transforms as tf 20 | 21 | 22 | import bnn 23 | from bnn.calibration import calibration_curve, expected_calibration_error as ece 24 | 25 | 26 | STATS = { 27 | "CIFAR10": {"mean": (0.49139968, 0.48215841, 0.44653091), "std": (0.24703223, 0.24348513, 0.26158784)}, 28 | "CIFAR100": {"mean": (0.50707516, 0.48654887, 0.44091784), "std": (0.26733429, 0.25643846, 0.27615047)} 29 | } 30 | ROOT = os.environ.get("DATASETS_PATH", "./data") 31 | NUM_BINS = 10 32 | 33 | 34 | def reset_cache(module): 35 | if hasattr(module, "reset_cache"): 36 | module.reset_cache() 37 | 38 | 39 | def main(seed, num_epochs, inference_config, output_dir, ml_epochs, annealing_epochs, train_samples, test_samples, 40 | verbose, progress_bar, lr, cifar, optimizer, momentum, milestones, gamma, resnet): 41 | torch.manual_seed(seed) 42 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 43 | 44 | # set up data loaders 45 | dataset_name = f"CIFAR{cifar}" 46 | dataset_cls = getattr(torchvision.datasets, dataset_name) 47 | root = f"{ROOT}/{dataset_name.lower()}" 48 | print(f"Loading dataset {dataset_cls} from {root}") 49 | aug_tf = [tf.RandomCrop(32, padding=4, padding_mode="reflect"), tf.RandomHorizontalFlip()] 50 | norm_tf = [tf.ToTensor(), tf.Normalize(**STATS[dataset_name])] 51 | train_data = dataset_cls(root, train=True, transform=tf.Compose(aug_tf + norm_tf), download=True) 52 | test_data = 
dataset_cls(root, train=False, transform=tf.Compose(norm_tf), download=True) 53 | 54 | train_loader = data.DataLoader(train_data, batch_size=100, shuffle=True, num_workers=1, pin_memory=True) 55 | test_loader = data.DataLoader(test_data, batch_size=1000) 56 | 57 | # set up net and optimizer 58 | net = bnn.nn.nets.make_network(f"resnet{resnet}", kernel_size=3, remove_maxpool=True, out_features=cifar) 59 | if inference_config is not None: 60 | with open(inference_config) as f: 61 | cfg = json.load(f) 62 | bnn.bayesianize_(net, **cfg) 63 | 64 | if verbose: 65 | print(net) 66 | net.to(device) 67 | 68 | if optimizer == "adam": 69 | optim = torch.optim.Adam(net.parameters(), lr) 70 | elif optimizer == "sgd": 71 | optim = torch.optim.SGD(net.parameters(), lr, momentum=momentum) 72 | else: 73 | raise RuntimeError("Unknown optimizer:", optimizer) 74 | 75 | # set up dict for tracking losses and load state dicts if applicable 76 | metrics = defaultdict(list) 77 | if output_dir is not None: 78 | os.makedirs(output_dir, exist_ok=True) 79 | 80 | snapshot_sd_path = os.path.join(output_dir, "snapshot_sd.pt") 81 | snapshot_optim_path = os.path.join(output_dir, "snapshot_optim.sd") 82 | metrics_path = os.path.join(output_dir, "metrics.pkl") 83 | if os.path.isfile(snapshot_sd_path): 84 | net.load_state_dict(torch.load(snapshot_sd_path, map_location=device)) 85 | optim.load_state_dict(torch.load(snapshot_optim_path, map_location=device)) 86 | with open(metrics_path, "rb") as f: 87 | metrics = pickle.load(f) 88 | else: 89 | torch.save(net.state_dict(), os.path.join(output_dir, "initial_sd.pt")) 90 | else: 91 | snapshot_sd_path = None 92 | snapshot_optim_path = None 93 | metrics_path = None 94 | 95 | last_epoch = len(metrics["acc"]) - 1 96 | 97 | if milestones is not None: 98 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, milestones, gamma=gamma, last_epoch=last_epoch) 99 | else: 100 | scheduler = None 101 | 102 | kl_factor = 0. if ml_epochs > 0 or annealing_epochs > 0 else 1. 103 | annealing_rate = annealing_epochs ** -1 if annealing_epochs > 0 else 1. 104 | 105 | epoch_iter = trange(last_epoch + 1, num_epochs, desc="Epochs") if progress_bar else range(last_epoch + 1, num_epochs) 106 | for i in epoch_iter: 107 | net.train() 108 | net.apply(reset_cache) 109 | batch_iter = tqdm(iter(train_loader), desc="Batches") if progress_bar else iter(train_loader) 110 | for j, (x, y) in enumerate(batch_iter): 111 | x = x.to(device) 112 | y = y.to(device) 113 | 114 | optim.zero_grad() 115 | avg_nll = 0. 
116 | for k in range(train_samples): 117 | yhat = net(x) 118 | nll = -dist.Categorical(logits=yhat).log_prob(y).mean() / train_samples 119 | if k == 0: 120 | kl = torch.tensor(0., device=device) 121 | for module in net.modules(): 122 | if hasattr(module, "parameter_loss"): 123 | kl = kl + module.parameter_loss().sum() 124 | metrics["kl"].append(kl.item()) 125 | loss = nll + kl * kl_factor / len(train_data) 126 | else: 127 | loss = nll 128 | 129 | avg_nll += nll.item() 130 | loss.backward(retain_graph=train_samples > 1) 131 | 132 | optim.step() 133 | 134 | net.apply(reset_cache) 135 | 136 | metrics["nll"].append(avg_nll) 137 | 138 | if scheduler is not None: 139 | scheduler.step() 140 | 141 | net.eval() 142 | with torch.no_grad(): 143 | probs, targets = map(torch.cat, zip(*( 144 | (sum(net(x.to(device)).softmax(-1) for _ in range(test_samples)).div(test_samples).to("cpu"), y) 145 | for x, y in iter(test_loader) 146 | ))) 147 | 148 | metrics["acc"].append(probs.argmax(-1).eq(targets).float().mean().item()) 149 | p, f, w = calibration_curve(probs.numpy(), targets.numpy(), NUM_BINS) 150 | metrics["ece"].append(ece(p, f, w).item()) 151 | 152 | if verbose: 153 | print(f"Epoch {i} -- Accuracy: {100 * metrics['acc'][-1]:.2f}%; ECE={100 * metrics['ece'][-1]:.2f}") 154 | 155 | # free up some variables for garbage collection to save memory for smaller GPUs 156 | del probs 157 | del targets 158 | del p 159 | del f 160 | del w 161 | torch.cuda.empty_cache() 162 | 163 | if output_dir is not None: 164 | torch.save(net.state_dict(), snapshot_sd_path) 165 | torch.save(optim.state_dict(), snapshot_optim_path) 166 | with open(metrics_path, "wb") as fn: 167 | pickle.dump(metrics, fn) 168 | 169 | if i >= ml_epochs: 170 | kl_factor = min(1., kl_factor + annealing_rate) 171 | 172 | print(f"Final test accuracy: {100 * metrics['acc'][-1]:.2f}") 173 | 174 | if output_dir is not None: 175 | bin_width = NUM_BINS ** -1 176 | bin_centers = np.linspace(bin_width / 2, 1 - bin_width / 2, NUM_BINS) 177 | 178 | plt.figure(figsize=(5, 5)) 179 | plt.plot([0, 100], [0, 100], color="black", linestyle="dashed", alpha=0.5) 180 | plt.plot(100 * p[w > 0], 100 * f[w > 0], marker="o", markersize=8) 181 | plt.bar(100 * bin_centers[w > 0], 100 * w[w > 0], width=100 * bin_width, alpha=0.5) 182 | plt.xlabel("Mean probability predicted") 183 | plt.ylabel("Empirical accuracy") 184 | plt.title(f"Calibration curve (Accuracy={100 * metrics['acc'][-1]:.2f}%; ECE={100 * metrics['ece'][-1]:.4f})") 185 | plt.savefig(os.path.join(output_dir, "calibration.png"), bbox_inches="tight") 186 | 187 | 188 | if __name__ == '__main__': 189 | def list_of_ints(s): 190 | return list(map(int, s.split(","))) 191 | 192 | parser = argparse.ArgumentParser() 193 | parser.add_argument("--num-epochs", type=int, default=200) 194 | parser.add_argument("--train-samples", type=int, default=1) 195 | parser.add_argument("--test-samples", type=int, default=8) 196 | parser.add_argument("--annealing-epochs", type=int, default=0) 197 | parser.add_argument("--ml-epochs", type=int, default=0) 198 | parser.add_argument("--inference-config") 199 | parser.add_argument("--output-dir") 200 | parser.add_argument("--verbose", action="store_true") 201 | parser.add_argument("--progress-bar", action="store_true") 202 | parser.add_argument("--lr", type=float, default=1e-3) 203 | parser.add_argument("--seed", type=int, default=42) 204 | parser.add_argument("--cifar", type=int, default=10, choices=[10, 100]) 205 | parser.add_argument("--optimizer", choices=["sgd", "adam"], default="adam") 206 
| parser.add_argument("--momentum", type=float, default=0.9) 207 | parser.add_argument("--milestones", type=list_of_ints) 208 | parser.add_argument("--gamma", type=float, default=0.1) 209 | parser.add_argument("--resnet", type=int, default=18, choices=[18, 34, 50, 101, 152]) 210 | 211 | args = parser.parse_args() 212 | if args.verbose: 213 | print(vars(args)) 214 | main(**vars(args)) 215 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | with open("requirements.txt") as f: 5 | requirements = f.read().splitlines() 6 | 7 | 8 | setup( 9 | name='bayesianize', 10 | version='0.0.1', 11 | url='https://github.com/TODO', 12 | author=['Hippolyt Ritter', 'Martin Kukla', 'Cheng Zhang', 'Yingzhen Li'], 13 | description='Lightweight BNN wrapper for pytorch.', 14 | packages=find_packages(), 15 | install_requires=requirements, 16 | license="MIT" 17 | ) 18 | -------------------------------------------------------------------------------- /tests/bnn/distributions/test_kl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions as dist 3 | 4 | 5 | import models.bnn.distributions.matrix_normal as mn 6 | 7 | 8 | def _rand_psd(d, batch_shape=None): 9 | if batch_shape is None: 10 | batch_shape = tuple() 11 | elif isinstance(batch_shape, int): 12 | batch_shape = (batch_shape,) 13 | a = torch.randn(*batch_shape, d, d+1) 14 | return a @ a.transpose(-2, -1) + 1e-2 * torch.eye(d) 15 | 16 | 17 | def test_matmat_kl(): 18 | rows, cols = 3, 4 19 | loc_p = torch.randn(rows, cols) 20 | row_cov_p = _rand_psd(rows) 21 | col_cov_p = _rand_psd(cols) 22 | p = mn.MatrixNormal(loc_p, row_cov_p.cholesky(), col_cov_p.cholesky()) 23 | 24 | loc_q = torch.randn(rows, cols) 25 | row_cov_q = _rand_psd(rows) 26 | col_cov_q = _rand_psd(cols) 27 | q = mn.MatrixNormal(loc_q, row_cov_q.cholesky(), col_cov_q.cholesky()) 28 | 29 | kl1 = dist.kl_divergence(p, q) 30 | kl2 = dist.kl_divergence(p.to_multivariatenormal(), q.to_multivariatenormal()) 31 | assert torch.isclose(kl1, kl2) 32 | 33 | 34 | def test_matdiag_kl(): 35 | rows, cols = 3, 2 36 | loc_p = torch.randn(rows, cols) 37 | row_cov_p = _rand_psd(rows) 38 | col_cov_p = _rand_psd(cols) 39 | p = mn.MatrixNormal(loc_p, row_cov_p.cholesky(), col_cov_p.cholesky()) 40 | p_multi = p.to_multivariatenormal() 41 | 42 | loc_q = torch.randn(rows, cols) 43 | scale_q = torch.rand(rows, cols) 44 | q = dist.Normal(loc_q, scale_q) 45 | q_multi = dist.MultivariateNormal(loc_q.flatten(), scale_tril=scale_q.flatten().diag_embed()) 46 | 47 | kl1 = dist.kl_divergence(p, q) 48 | kl2 = dist.kl_divergence(p_multi, q_multi) 49 | assert torch.isclose(kl1, kl2) 50 | -------------------------------------------------------------------------------- /tests/bnn/distributions/test_matrix_normal.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.distributions as dist 7 | from torch.distributions.multivariate_normal import _batch_mahalanobis 8 | 9 | 10 | import models.bnn.distributions.matrix_normal as mn 11 | 12 | 13 | def _rand_psd(d, batch_shape=None): 14 | if batch_shape is None: 15 | batch_shape = tuple() 16 | elif isinstance(batch_shape, int): 17 | batch_shape = (batch_shape,) 18 | a = torch.randn(*batch_shape, d, d+1) 19 | return a @ 
a.transpose(-2, -1) + 1e-2 * torch.eye(d) 20 | 21 | 22 | def test_kron(): 23 | a = torch.ones(2, 2) 24 | b = torch.arange(6).reshape(3, 2).float() 25 | 26 | assert torch.allclose(mn.kron(a, b), b.repeat(2, 2)) 27 | 28 | 29 | @pytest.mark.parametrize("a_rows,a_cols,b_rows,b_cols", [ 30 | (2, 3, 4, 5), 31 | (10, 12, 3, 6), 32 | (11, 3, 8, 19) 33 | ]) 34 | def test_kron_numpy(a_rows, a_cols, b_rows, b_cols): 35 | a = torch.randn(a_rows, a_cols) 36 | b = torch.randn(b_rows, b_cols) 37 | mn_result = mn.kron(a, b) 38 | np_result = torch.from_numpy(np.kron(a.numpy(), b.numpy())) 39 | 40 | assert torch.allclose(mn_result, np_result) 41 | 42 | 43 | def test_kron_batch(): 44 | a = torch.randn(2, 2, 2) 45 | b = torch.randn(2, 3, 3) 46 | batched_kron = mn.kron(a, b) 47 | iter_kron = torch.stack([mn.kron(aa, bb) for aa, bb in zip(a, b)]) 48 | assert torch.allclose(batched_kron, iter_kron) 49 | 50 | 51 | @pytest.mark.parametrize("x_batch,u_batch,v_batch", [ 52 | (None, None, None), 53 | ((2,), 2, 2), 54 | ((2,), None, None), 55 | ((3, 2), 2, 2) 56 | ]) 57 | def test_batch_kf_mahalanobis(x_batch, u_batch, v_batch): 58 | rows, cols = 3, 4 59 | if x_batch is None: 60 | x_batch = tuple() 61 | bX = torch.randn(*x_batch, rows, cols) 62 | bLU = _rand_psd(rows, u_batch).cholesky() 63 | bLV = _rand_psd(cols, v_batch).cholesky() 64 | bL = mn.kron(bLU, bLV) 65 | 66 | kf_result = mn._batch_kf_mahalanobis(bLU, bLV, bX) 67 | t_result = _batch_mahalanobis(bL, bX.flatten(-2)) 68 | 69 | assert torch.allclose(kf_result, t_result) 70 | 71 | 72 | def test_single_sample_shape(): 73 | matnorm = mn.MatrixNormal(torch.zeros(2, 3), torch.eye(2), torch.eye(3)) 74 | x = matnorm.rsample() 75 | assert x.shape == (2, 3) 76 | 77 | 78 | def test_batch_sample_shape(): 79 | matnorm = mn.MatrixNormal(torch.zeros(5, 6), torch.eye(5), torch.eye(6)) 80 | x = matnorm.rsample((10,)) 81 | assert x.shape == (10, 5, 6) 82 | 83 | 84 | def test_to_multivariate_normal(): 85 | rows, cols = 3, 4 86 | loc = torch.randn(rows, cols) 87 | row_cov = _rand_psd(rows) 88 | col_cov = _rand_psd(cols) 89 | 90 | matnorm = mn.MatrixNormal(loc, row_cov.cholesky(), col_cov.cholesky()) 91 | multinorm = dist.MultivariateNormal(loc.flatten(), mn.kron(row_cov, col_cov)) 92 | 93 | assert torch.allclose(matnorm.to_multivariatenormal().loc, multinorm.loc) 94 | assert torch.allclose(matnorm.to_multivariatenormal().scale_tril, multinorm.scale_tril, atol=1e-4) 95 | 96 | 97 | def test_batch_to_multivariate_normal(): 98 | batch, rows, cols = 2, 3, 4 99 | loc = torch.randn(batch, rows, cols) 100 | row_cov = _rand_psd(rows, batch) 101 | col_cov = _rand_psd(cols, batch) 102 | 103 | matnorm = mn.MatrixNormal(loc, row_cov.cholesky(), col_cov.cholesky()) 104 | multinorm = dist.MultivariateNormal(loc.flatten(-2), mn.kron(row_cov, col_cov)) 105 | 106 | assert torch.allclose(matnorm.to_multivariatenormal().loc, multinorm.loc) 107 | assert torch.allclose(matnorm.to_multivariatenormal().scale_tril, multinorm.scale_tril, atol=1e-4) 108 | 109 | 110 | def test_mean(): 111 | rows, cols = 3, 4 112 | loc = torch.randn(3, 4) 113 | row_cov = _rand_psd(rows) 114 | col_cov = _rand_psd(cols) 115 | 116 | matnorm = mn.MatrixNormal(loc, row_cov.cholesky(), col_cov.cholesky()) 117 | multinorm = matnorm.to_multivariatenormal() 118 | 119 | assert torch.allclose(matnorm.mean.flatten(-2), multinorm.mean) 120 | 121 | 122 | def test_variance(): 123 | rows, cols = 3, 4 124 | loc = torch.randn(3, 4) 125 | row_cov = _rand_psd(rows) 126 | col_cov = _rand_psd(cols) 127 | 128 | matnorm = mn.MatrixNormal(loc, 
row_cov.cholesky(), col_cov.cholesky()) 129 | multinorm = matnorm.to_multivariatenormal() 130 | 131 | assert torch.allclose(matnorm.variance.flatten(-2), multinorm.variance) 132 | 133 | 134 | def test_entropy(): 135 | rows, cols = 3, 4 136 | loc = torch.randn(3, 4) 137 | row_cov = _rand_psd(rows) 138 | col_cov = _rand_psd(cols) 139 | 140 | matnorm = mn.MatrixNormal(loc, row_cov.cholesky(), col_cov.cholesky()) 141 | multinorm = matnorm.to_multivariatenormal() 142 | 143 | assert torch.isclose(matnorm.entropy(), multinorm.entropy()) 144 | 145 | 146 | def test_log_prob(): 147 | rows, cols = 3, 4 148 | loc = torch.randn(rows, cols) 149 | row_cov = _rand_psd(rows) 150 | col_cov = _rand_psd(cols) 151 | 152 | matnorm = mn.MatrixNormal(loc, row_cov.cholesky(), col_cov.cholesky()) 153 | # using row major order for vec/flatten operation as implemented in pytorch means that the covariance of the 154 | # multivariate normal equivalent to the matrix normal has covariance U \otimes V rather than V \otimes U 155 | # as on Wikipedia (where theye use column major order) 156 | multinorm = matnorm.to_multivariatenormal() 157 | 158 | x = matnorm.sample() 159 | 160 | assert torch.isclose(matnorm.log_prob(x), multinorm.log_prob(x.flatten())) 161 | 162 | 163 | def test_sample_cov(): 164 | torch.manual_seed(42) 165 | rows, cols = 3, 4 166 | sample_size = int(1e6) 167 | 168 | loc = torch.randn(rows, cols) 169 | row_cov = _rand_psd(rows) 170 | col_cov = _rand_psd(cols) 171 | 172 | matnorm = mn.MatrixNormal(loc, row_cov.cholesky(), col_cov.cholesky()) 173 | true_cov = mn.kron(row_cov, col_cov) 174 | multinorm = matnorm.to_multivariatenormal() 175 | 176 | matnorm_sample = matnorm.sample((sample_size,)).flatten(-2) 177 | multinorm_sample = multinorm.sample((sample_size,)) 178 | 179 | matnorm_cov = torch.from_numpy(np.cov(matnorm_sample.t().numpy())).float() 180 | multinorm_cov = torch.from_numpy(np.cov(multinorm_sample.t().numpy())).float() 181 | 182 | # we're testing if the empirical covariance of samples from the matrix normal is similarly close to the true 183 | # covariance as samples from the equivalent multivariate normal. The threshold is relatively high at 0.01, but 184 | # increasing the sample size -- which slows down the test noticeably -- would enable a smaller threshold, 185 | # e.g. 
0.001 for a 10 times larger sample 186 | assert torch.isclose(matnorm_cov.sub(true_cov).abs().mean(), multinorm_cov.sub(true_cov).abs().mean(), atol=1e-2) 187 | 188 | 189 | def test_1d_loc_raises(): 190 | with pytest.raises(ValueError): 191 | mn.MatrixNormal(torch.randn(3), _rand_psd(3), _rand_psd(3)) 192 | 193 | 194 | def test_1d_row_cov_raises(): 195 | with pytest.raises(ValueError): 196 | mn.MatrixNormal(torch.randn(3, 4), torch.rand(3), _rand_psd(4)) 197 | 198 | 199 | def test_1d_col_cov_raises(): 200 | with pytest.raises(ValueError): 201 | mn.MatrixNormal(torch.randn(3, 4), _rand_psd(3), torch.rand(4)) 202 | 203 | 204 | def test_row_cov_mismatch_raises(): 205 | with pytest.raises(ValueError): 206 | mn.MatrixNormal(torch.randn(3, 4), _rand_psd(2), _rand_psd(4)) 207 | 208 | 209 | def test_col_cov_mismatch_raises(): 210 | with pytest.raises(ValueError): 211 | mn.MatrixNormal(torch.randn(3, 4), _rand_psd(3), _rand_psd(2)) 212 | -------------------------------------------------------------------------------- /tests/bnn/nn/mixins/variational/test_fcg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | import models.bnn as bnn 5 | 6 | 7 | def test_init_from_deterministic_params(): 8 | layer = bnn.nn.FCGLinear(5, 3) 9 | weight = torch.randn(3, 5) 10 | bias = torch.randn(3) 11 | layer.init_from_deterministic_params({"weight": weight, "bias": bias}) 12 | assert torch.allclose(torch.cat([weight.flatten(), bias]), layer.mean) 13 | 14 | 15 | def test_init_from_deterministic_params_no_bias(): 16 | layer = bnn.nn.FCGLinear(5, 3, bias=False) 17 | weight = torch.randn(3, 5) 18 | layer.init_from_deterministic_params({"weight": weight}) 19 | assert torch.allclose(weight.flatten(), layer.mean) 20 | 21 | 22 | def test_sampling(): 23 | """Tests that the ffg layer samples from the correct distribution.""" 24 | torch.manual_seed(24) 25 | 26 | layer = bnn.nn.FCGLinear(3, 1, bias=False, init_sd=0.1) 27 | x = torch.randn(1, 3) 28 | 29 | # for w ~ N(\mu, \Sigma), x^T w ~ N(x^T \mu, x^T \Sigma x) 30 | mu = x.mv(layer.mean).squeeze() 31 | sd = x.mm(layer.scale_tril).pow(2).sum().sqrt() 32 | 33 | a = torch.stack([layer(x).squeeze() for _ in range(1000)]) 34 | assert torch.isclose(mu, a.mean(), atol=1e-2) 35 | assert torch.isclose(sd, a.std(), atol=1e-2) 36 | -------------------------------------------------------------------------------- /tests/bnn/nn/mixins/variational/test_ffg.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | 5 | 6 | import models.bnn as bnn 7 | 8 | 9 | @pytest.mark.parametrize("local_reparam", [True, False]) 10 | def test_sampling(local_reparam): 11 | """Tests that the ffg layer samples from the correct distribution.""" 12 | torch.manual_seed(24) 13 | 14 | layer = bnn.nn.FFGLinear(2, 3, bias=False, init_sd=0.1, local_reparameterization=local_reparam) 15 | x = torch.randn(1, 2) 16 | 17 | mu = x.mm(layer.weight_mean.t()) 18 | sd = x.pow(2).mm(layer.weight_sd.pow(2).t()).sqrt() 19 | 20 | a = torch.stack([layer(x) for _ in range(1000)]) 21 | assert torch.allclose(mu, a.mean(0), atol=1e-2) 22 | assert torch.allclose(sd, a.std(0), atol=1e-2) 23 | 24 | 25 | def test_init_from_deterministic_params(): 26 | layer = bnn.nn.FFGLinear(5, 3) 27 | weight = torch.randn(3, 5) 28 | bias = torch.randn(3) 29 | layer.init_from_deterministic_params({"weight": weight, "bias": bias}) 30 | assert torch.allclose(weight, layer.weight_mean) 31 | assert torch.allclose(bias, 
layer.bias_mean) 32 | 33 | 34 | def test_init_from_deterministic_params_no_bias(): 35 | layer = bnn.nn.FFGLinear(5, 3, bias=False) 36 | weight = torch.randn(3, 5) 37 | layer.init_from_deterministic_params({"weight": weight}) 38 | assert torch.allclose(weight, layer.weight_mean) 39 | -------------------------------------------------------------------------------- /tests/bnn/nn/mixins/variational/test_inducing.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import itertools 3 | 4 | import pytest 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | from models.bnn import bayesianize_ 11 | from models.bnn.nn import InducingLinear, InducingConv2d, InducingDeterministicLinear, InducingDeterministicConv2d 12 | 13 | 14 | @pytest.mark.parametrize("q_inducing,whitened,max_lamda,max_sd_u,bias,layer_type,sqrt_width_scaling", itertools.product( 15 | ("diagonal", "matrix", "full"), 16 | (False, True), 17 | (None, 0.3), 18 | (None, 0.3), 19 | (False, True), 20 | ("linear", "conv"), 21 | (False, True) 22 | )) 23 | def test_forward_shape(q_inducing, whitened, max_lamda, max_sd_u, bias, layer_type, sqrt_width_scaling): 24 | inducing_rows = 4 25 | inducing_cols = 2 26 | batch_size = 5 27 | inducing_kwargs = dict( 28 | inducing_rows=inducing_rows, inducing_cols=inducing_cols, q_inducing=q_inducing, whitened_u=whitened, 29 | max_lamda=max_lamda, max_sd_u=max_sd_u, init_lamda=1, bias=bias, sqrt_width_scaling=sqrt_width_scaling) 30 | 31 | if layer_type == "linear": 32 | in_features = 3 33 | out_features = 5 34 | layer = InducingLinear(in_features, out_features, **inducing_kwargs) 35 | 36 | x = torch.randn(batch_size, in_features) 37 | expected_shape = (batch_size, out_features) 38 | elif layer_type == "conv": 39 | in_channels = 3 40 | out_channels = 6 41 | kernel_size = 3 42 | padding = 1 43 | layer = InducingConv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, **inducing_kwargs) 44 | 45 | h, w = 7, 7 46 | x = torch.randn(batch_size, in_channels, h, w) 47 | expected_shape = (batch_size, out_channels, h, w) 48 | else: 49 | raise ValueError(f"Invalid layer_type: {layer_type}") 50 | 51 | assert layer(x).shape == expected_shape 52 | 53 | 54 | @pytest.mark.parametrize("inference", ["inducing", "inducingdeterministic"]) 55 | def test_bayesianize_compatible(inference): 56 | net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3), nn.Linear(32, 16), nn.Linear(16, 8)) 57 | bnn = copy.deepcopy(net) 58 | bayesianize_(bnn, inference) 59 | 60 | for m, bm in zip(net.modules(), bnn.modules()): 61 | if m is net: 62 | continue 63 | 64 | if inference == "inducing": 65 | if isinstance(m, nn.Linear): 66 | assert isinstance(bm, InducingLinear) 67 | elif isinstance(m, nn.Conv2d): 68 | assert isinstance(bm, InducingConv2d) 69 | else: # unreachable 70 | assert False 71 | else: 72 | if isinstance(m, nn.Linear): 73 | assert isinstance(bm, InducingDeterministicLinear) 74 | elif isinstance(m, nn.Conv2d): 75 | assert isinstance(bm, InducingDeterministicConv2d) 76 | else: # unreachable 77 | assert False 78 | -------------------------------------------------------------------------------- /tests/bnn/test_bayesianize.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import pytest 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision.models import resnet18 8 | 9 | 10 | import models.bnn as bnn 11 | from models.bnn.bayesianize import bayesianize_, bayesian_from_template 12 | 13 | 14 | 
def assert_layers_equal(l1, l2): 15 | if isinstance(l1, nn.Linear): 16 | assert_linear_equal(l1, l2) 17 | elif isinstance(l1, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 18 | assert_conv_equal(l1, l2) 19 | else: 20 | raise ValueError("Unrecognized torch layer class:", l1.__class__.__name__) 21 | 22 | 23 | def assert_linear_equal(l1, l2): 24 | assert l1.in_features == l2.in_features 25 | assert l1.out_features == l2.out_features 26 | assert l1.weight.shape == l2.weight.shape 27 | assert (l1.bias is not None) == (l2.bias is not None) 28 | if l1.bias is not None: 29 | assert l1.bias.shape == l2.bias.shape 30 | 31 | 32 | def assert_conv_equal(l1, l2): 33 | assert all(getattr(l1, a) == getattr(l2, a) for a in 34 | ["in_channels", "out_channels", "kernel_size", "stride", "padding", "dilation", "groups"]) 35 | assert l1.weight.shape == l2.weight.shape 36 | assert (l1.bias is not None) == (l2.bias is not None) 37 | if l1.bias is not None: 38 | assert l1.bias.shape == l2.bias.shape 39 | 40 | 41 | @pytest.mark.parametrize("torch_layer,inference,target_class", [ 42 | (nn.Linear(5, 3, True), "ffg", bnn.nn.FFGLinear), 43 | (nn.Linear(5, 3, False), "ffg", bnn.nn.FFGLinear), 44 | (nn.Linear(5, 3, True), "fcg", bnn.nn.FCGLinear), 45 | (nn.Linear(5, 3, False), "fcg", bnn.nn.FCGLinear), 46 | (nn.Conv2d(16, 32, 3, bias=True), "ffg", bnn.nn.FFGConv2d), 47 | (nn.Conv2d(16, 32, 3, bias=False), "ffg", bnn.nn.FFGConv2d), 48 | (nn.Conv2d(16, 32, 3, stride=2, padding=1), "ffg", bnn.nn.FFGConv2d), 49 | (nn.Conv2d(16, 32, 3, bias=True), "fcg", bnn.nn.FCGConv2d), 50 | (nn.Conv2d(16, 32, 3, bias=False), "fcg", bnn.nn.FCGConv2d) 51 | ]) 52 | def test_template(torch_layer, inference, target_class): 53 | bayesian_layer = bayesian_from_template(torch_layer, inference) 54 | assert isinstance(bayesian_layer, target_class) 55 | assert_layers_equal(torch_layer, bayesian_layer) 56 | 57 | 58 | def test_template_raises_on_unkown_inference(): 59 | layer = nn.Linear(3, 2) 60 | with pytest.raises(KeyError): 61 | bayesian_from_template(layer, "abcdefghi") 62 | 63 | 64 | @pytest.mark.parametrize("inference,target_class", [ 65 | ("ffg", bnn.nn.FFGLinear), 66 | ("fcg", bnn.nn.FCGLinear) 67 | ]) 68 | def test_bayesianize_all(inference, target_class): 69 | net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2)) 70 | bnet = copy.deepcopy(net) 71 | bayesianize_(bnet, inference) 72 | 73 | assert isinstance(bnet[0], target_class) 74 | assert_layers_equal(net[0], bnet[0]) 75 | assert isinstance(bnet[2], target_class) 76 | assert_layers_equal(net[2], bnet[2]) 77 | assert isinstance(bnet[1], nn.ReLU) 78 | 79 | 80 | def test_bayesianize_resnet(): 81 | net = bnn.nn.nets.make_network("resnet18", kernel_size=3, out_features=10) 82 | bnet = copy.deepcopy(net) 83 | bayesianize_(bnet, "ffg") 84 | assert len(list(net.modules())) == len(list(bnet.modules())) 85 | for module, bmodule in zip(net.modules(), bnet.modules()): 86 | if isinstance(module, nn.Linear): 87 | assert isinstance(bmodule, bnn.nn.FFGLinear) 88 | assert_linear_equal(module, bmodule) 89 | elif isinstance(module, nn.Conv2d): 90 | assert isinstance(bmodule, bnn.nn.FFGConv2d) 91 | assert_conv_equal(module, bmodule) 92 | elif not list(module.modules()): 93 | # check for "elementary" modules like batchnorm, nonlinearities etc that the 94 | # class hasn't been changed. Checking for equality would be better, but isn't 95 | # really supported by pytorch, e.g. 
96 |             # for a BatchNorm layer
97 |             assert module.__class__ == bmodule.__class__
98 |         else:
99 |             # skip modules that collect other modules
100 |             pass
101 | 
102 | 
103 | @pytest.mark.parametrize("inference", ["ffg", "fcg"])
104 | def test_output_shapes(inference):
105 |     net = nn.Sequential(nn.Linear(5, 4), nn.ReLU(), nn.Linear(4, 3))
106 |     bnet = copy.deepcopy(net)
107 |     bayesianize_(bnet, inference)
108 | 
109 |     x = torch.randn(1, 5)
110 |     assert net(x).shape == bnet(x).shape
111 | 
112 | 
113 | @pytest.mark.parametrize("inference", ["ffg", "fcg"])
114 | def test_incorrect_input_error(inference):
115 |     bnet = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
116 |     bayesianize_(bnet, inference)
117 |     x = torch.randn(1, 5)
118 |     with pytest.raises(RuntimeError):
119 |         bnet(x)
120 | 
121 | 
122 | def test_bayesianize_last_layer():
123 |     bnet = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
124 |     bayesianize_(bnet, inference={"2": "ffg"})
125 |     # explicitly comparing the class for the first layer, since an FFGLinear object is an instance of nn.Linear
126 |     assert bnet[0].__class__ == nn.Linear
127 |     assert isinstance(bnet[2], bnn.nn.FFGLinear)
128 | 
129 | 
130 | def test_bayesianize_last_layer_index():
131 |     bnet = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
132 |     bayesianize_(bnet, inference={-1: "fcg"})
133 |     assert bnet[0].__class__ == nn.Linear
134 |     assert isinstance(bnet[2], bnn.nn.FCGLinear)
135 | 
136 | 
137 | def test_bayesianize_class_name():
138 |     bnet = bnn.nn.nets.CNN(channels=[3, 3], lin_sizes=[9, 10, 2], maxpool_freq=0)
139 |     bayesianize_(bnet, inference={"Conv2d": "ffg", "Linear": "fcg"})
140 |     assert isinstance(bnet[0], bnn.nn.FFGConv2d)
141 |     assert isinstance(bnet[3][0], bnn.nn.FCGLinear)
142 |     assert isinstance(bnet[3][2], bnn.nn.FCGLinear)
143 | 
144 | 
145 | def test_module_name_priority_over_class_name():
146 |     bnet = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
147 |     bayesianize_(bnet, inference={"0": "fcg", "Linear": "ffg"})
148 |     assert isinstance(bnet[0], bnn.nn.FCGLinear)
149 |     assert isinstance(bnet[2], bnn.nn.FFGLinear)
150 | 
151 | 
152 | def test_initialize_ffg():
153 |     ref_net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
154 |     bnn = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
155 |     bayesianize_(bnn, reference_state_dict=ref_net.state_dict(), inference="ffg")
156 |     for m, bm in zip(ref_net.modules(), bnn.modules()):
157 |         if isinstance(m, nn.Linear):
158 |             assert torch.allclose(m.weight, bm.weight_mean)
159 |             assert torch.allclose(m.bias, bm.bias_mean)
160 | 
161 | 
162 | def test_initialize_fcg():
163 |     ref_net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
164 |     bnn = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
165 |     bayesianize_(bnn, reference_state_dict=ref_net.state_dict(), inference="fcg")
166 |     for m, bm in zip(ref_net.modules(), bnn.modules()):
167 |         if isinstance(m, nn.Linear):
168 |             assert torch.allclose(torch.cat([m.weight.flatten(), m.bias]), bm.mean)
169 | 
--------------------------------------------------------------------------------
/tests/regression/test_variational_synthetic.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | import torch
4 | import torch.distributions as dist
5 | 
6 | 
7 | import bnn
8 | 
9 | 
10 | def bayesian_regression(n, d, data_precision, prior_precision):
11 |     x = torch.randn(n, d)
12 |     w = torch.randn(d, 1) * prior_precision ** -0.5
13 |     y = x.mm(w) + torch.randn(n, 1) * data_precision ** -0.5
14 |     posterior_precision = data_precision * x.t().mm(x) + prior_precision * torch.eye(d)
15 |     posterior_mean = data_precision * torch.solve(x.t().mm(y), posterior_precision).solution
16 | 
17 |     return x, y, posterior_mean, posterior_precision
18 | 
19 | 
20 | def fit_(module, x, y, data_precision, lr, epochs, num_decays):
21 |     optim = torch.optim.Adam(module.parameters(), lr)
22 |     scheduler = torch.optim.lr_scheduler.StepLR(optim, epochs // num_decays, 0.1)
23 |     for _ in range(epochs):
24 |         optim.zero_grad()
25 |         yhat = module(x)
26 |         nll = -dist.Normal(yhat, data_precision ** -0.5).log_prob(y).sum()
27 |         kl = module.kl_divergence()
28 |         loss = nll + kl
29 |         loss.backward()
30 |         optim.step()
31 |         scheduler.step()
32 | 
33 | @pytest.mark.parametrize("n,d,data_precision,prior_precision,local_reparameterization", [
34 |     (10, 3, 100., 1., True),
35 |     (10, 3, 100., 1., False),
36 |     (100, 6, 10, 5., True)
37 | ])
38 | def test_ffg(n, d, data_precision, prior_precision, local_reparameterization):
39 |     torch.manual_seed(42)
40 |     x, y, mu, lamda = bayesian_regression(n, d, data_precision, prior_precision)
41 | 
42 |     l = bnn.nn.FFGLinear(d, 1, bias=False, prior_weight_sd=prior_precision ** -0.5)
43 |     if not local_reparameterization:
44 |         l.eval()
45 |     fit_(l, x, y, data_precision, 1e-1, 4000, 4)
46 | 
47 |     assert torch.allclose(mu.squeeze(), l.weight_mean.squeeze(), atol=1e-2)
48 |     assert torch.allclose(lamda.diagonal() ** -0.5, l.weight_sd.squeeze(), atol=1e-2)
49 |     assert dist.kl_divergence(dist.Normal(l.weight_mean.squeeze(), l.weight_sd.squeeze()),
50 |                               dist.Normal(mu.squeeze(), lamda.diag() ** -0.5)).sum().item() < 1e-2
51 | 
52 | 
53 | @pytest.mark.parametrize("n,d,data_precision,prior_precision", [
54 |     (10, 3, 100., 1.),
55 |     (25, 4, 250., 100.)
56 | ])
57 | def test_fcg(n, d, data_precision, prior_precision):
58 |     torch.manual_seed(42)
59 |     x, y, mu, lamda = bayesian_regression(n, d, data_precision, prior_precision)
60 | 
61 |     l = bnn.nn.FCGLinear(d, 1, bias=False, prior_weight_sd=prior_precision ** -0.5)
62 |     fit_(l, x, y, data_precision, 0.1, 4000, 4)
63 | 
64 |     assert torch.allclose(mu.squeeze(), l.mean, atol=1e-2)
65 |     assert torch.allclose(lamda.inverse(), l.scale_tril.mm(l.scale_tril.t()), atol=1e-2)
66 |     assert dist.kl_divergence(l.parameter_distribution,
67 |                               dist.MultivariateNormal(mu.squeeze(), precision_matrix=lamda)).item() < 1e-2
68 | 
--------------------------------------------------------------------------------
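A note on the ground truth used in tests/regression/test_variational_synthetic.py: for Gaussian noise with precision beta and a zero-mean Gaussian prior with precision alpha, the exact weight posterior of the linear model is Gaussian with precision Lambda = beta * X^T X + alpha * I and mean mu = beta * Lambda^{-1} X^T y, which is what `bayesian_regression` computes and the variational fits are compared against. The helper calls `torch.solve`, which works on older PyTorch releases but is deprecated in favour of `torch.linalg.solve` (note the swapped argument order) in more recent ones. The sketch below shows the same reference posterior written against the newer API; it is a standalone illustration with a made-up helper name, sizes and precisions, not code from the test suite:
```
import torch

def exact_posterior(x, y, data_precision, prior_precision):
    # Lambda = beta * X^T X + alpha * I
    posterior_precision = data_precision * x.t().mm(x) + prior_precision * torch.eye(x.shape[1])
    # mu = beta * Lambda^{-1} X^T y; torch.linalg.solve(A, B) solves A Z = B,
    # replacing the deprecated torch.solve(B, A).solution
    posterior_mean = data_precision * torch.linalg.solve(posterior_precision, x.t().mm(y))
    return posterior_mean, posterior_precision

x = torch.randn(10, 3)
y = x.mm(torch.randn(3, 1)) + 0.1 * torch.randn(10, 1)
mu, lamda = exact_posterior(x, y, data_precision=100.0, prior_precision=1.0)
print(mu.shape, lamda.shape)  # torch.Size([3, 1]) torch.Size([3, 3])
```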