├── models ├── __init__.py ├── autoguide.py ├── wrn.py ├── network.py └── pyro.py ├── utils ├── __init__.py ├── metrics.py ├── test.py ├── utils.py ├── wilds_utils.py └── data_utils.py ├── baselines ├── __init__.py ├── bbb │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── lenet.py │ │ └── wrn.py │ ├── README.md │ ├── train.sh │ └── train.py ├── csghmc │ ├── __init__.py │ ├── README.md │ ├── train.sh │ ├── csghmc.py │ └── train.py ├── swag │ ├── __init__.py │ ├── swag_utils.py │ └── swag.py └── vanilla │ └── models │ ├── __init__.py │ ├── lenet.py │ ├── mlp.py │ ├── wrn.py │ └── wrn_fixup.py ├── run_scripts ├── al_train.sh ├── run_temp_uq.sh ├── ll_train.sh ├── run_al_fmnist_uq.sh ├── run_cifar100_uq.sh ├── run_cifar_uq.sh └── run_fmnist_uq.sh ├── LICENSE ├── README.md ├── .gitignore ├── al_train.py ├── ll_train.py └── uq.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baselines/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baselines/bbb/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baselines/csghmc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baselines/swag/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baselines/bbb/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baselines/vanilla/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baselines/csghmc/README.md: -------------------------------------------------------------------------------- 1 | # Cyclical Stochastic Gradient HMC 2 | 3 | Based on . Reference: . 4 | -------------------------------------------------------------------------------- /baselines/bbb/README.md: -------------------------------------------------------------------------------- 1 | Install bayesian-torch first: https://github.com/IntelLabs/bayesian-torch 2 | 3 | ``` 4 | git clone https://github.com/IntelLabs/bayesian-torch 5 | cd bayesian-torch 6 | pip install . 
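# Optional: quickly verify the install; these are the layer classes used in
# baselines/bbb/models/lenet.py, so the import should succeed.
python -c "from bayesian_torch.layers import Conv2dFlipout, LinearFlipout; print('bayesian-torch OK')"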
7 | ``` 8 | -------------------------------------------------------------------------------- /baselines/bbb/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python baselines/bbb/train.py --dataset MNIST --randseed 972394 4 | python baselines/bbb/train.py --dataset MNIST --randseed 12 5 | python baselines/bbb/train.py --dataset MNIST --randseed 523 6 | python baselines/bbb/train.py --dataset MNIST --randseed 13 7 | python baselines/bbb/train.py --dataset MNIST --randseed 6 8 | 9 | python baselines/bbb/train.py --dataset CIFAR10 --randseed 972394 10 | python baselines/bbb/train.py --dataset CIFAR10 --randseed 12 11 | python baselines/bbb/train.py --dataset CIFAR10 --randseed 523 12 | python baselines/bbb/train.py --dataset CIFAR10 --randseed 13 13 | python baselines/bbb/train.py --dataset CIFAR10 --randseed 6 14 | -------------------------------------------------------------------------------- /baselines/csghmc/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python baselines/csghmc/train.py --dataset MNIST --randseed 972394 4 | python baselines/csghmc/train.py --dataset MNIST --randseed 12 5 | python baselines/csghmc/train.py --dataset MNIST --randseed 523 6 | python baselines/csghmc/train.py --dataset MNIST --randseed 13 7 | python baselines/csghmc/train.py --dataset MNIST --randseed 6 8 | 9 | python baselines/csghmc/train.py --dataset CIFAR10 --randseed 972394 10 | python baselines/csghmc/train.py --dataset CIFAR10 --randseed 12 11 | python baselines/csghmc/train.py --dataset CIFAR10 --randseed 523 12 | python baselines/csghmc/train.py --dataset CIFAR10 --randseed 13 13 | python baselines/csghmc/train.py --dataset CIFAR10 --randseed 6 14 | -------------------------------------------------------------------------------- /baselines/vanilla/models/lenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class LeNet(nn.Module): 10 | 11 | def __init__(self, num_classes=10): 12 | super().__init__() 13 | 14 | self.net = nn.Sequential( 15 | torch.nn.Conv2d(1, 6, 5), 16 | torch.nn.ReLU(), 17 | torch.nn.MaxPool2d(2), 18 | torch.nn.Conv2d(6, 16, 5), 19 | torch.nn.ReLU(), 20 | torch.nn.MaxPool2d(2), 21 | torch.nn.Flatten(), 22 | torch.nn.Linear(16 * 4 * 4, 120), 23 | torch.nn.ReLU(), 24 | torch.nn.Linear(120, 84), 25 | torch.nn.ReLU(), 26 | torch.nn.Linear(84, num_classes) 27 | ) 28 | 29 | def forward(self, x): 30 | return self.net(x) 31 | -------------------------------------------------------------------------------- /baselines/vanilla/models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | # Taken from https://github.com/team-approx-bayes/fromp/blob/main/models.py 5 | # Fully connected network, size: input_size, hidden_size, ..., output_size 6 | class MLP(nn.Module): 7 | def __init__(self, size, act='sigmoid'): 8 | super(type(self), self).__init__() 9 | self.num_layers = len(size) - 1 10 | lower_modules = [] 11 | for i in range(self.num_layers - 1): 12 | lower_modules.append(nn.Linear(size[i], size[i+1])) 13 | if act == 'relu': 14 | lower_modules.append(nn.ReLU()) 15 | elif act == 'sigmoid': 16 | lower_modules.append(nn.Sigmoid()) 17 | else: 18 | raise ValueError(f"{act} activation hasn't been implemented") 19 | self.layer_1 = 
nn.Sequential(*lower_modules) 20 | self.layer_2 = nn.Linear(size[-2], size[-1]) 21 | 22 | def forward(self, x): 23 | o = self.layer_1(x) 24 | o = self.layer_2(o) 25 | return o 26 | -------------------------------------------------------------------------------- /run_scripts/al_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu-2080ti # partition 3 | #SBATCH --gres=gpu:rtx2080ti:1 # type and number of gpus 4 | #SBATCH --time=72:00:00 # job will be cancelled after max. 72h 5 | #SBATCH --output=al_train_%A_%a.out 6 | #SBATCH --array=1-5 7 | 8 | # print info about current job 9 | scontrol show job $SLURM_JOB_ID 10 | 11 | declare -a chain_ids=(2 3) 12 | 13 | # MAP 14 | python al_train.py --method map --dataset fmnist --randseed $SLURM_ARRAY_TASK_ID 15 | 16 | # HMC 17 | for chain_id in "${chain_ids[@]}"; 18 | do 19 | python al_train.py --method hmc --n_burnins 100 --n_samples 200 --dataset fmnist --randseed $SLURM_ARRAY_TASK_ID --chain_id $chain_id 20 | done 21 | 22 | # Refine 23 | python al_train.py --method refine --flow_type radial --n_flows 1 --dataset fmnist --randseed $SLURM_ARRAY_TASK_ID 24 | # 25 | for n_flows in {5..100..5}; 26 | do 27 | echo $n_flows 28 | python al_train.py --method refine --flow_type radial --n_flows $n_flows --dataset fmnist --randseed $SLURM_ARRAY_TASK_ID 29 | done 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Runa Eschenhagen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /run_scripts/run_temp_uq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu-2080ti # partition 3 | #SBATCH --gres=gpu:rtx2080ti:1 # type and number of gpus 4 | #SBATCH --time=72:00:00 # job will be cancelled after max. 72h 5 | #SBATCH --output=temp_uq_%A_%a.out 6 | #SBATCH --array=1-5 7 | 8 | # print info about current job 9 | scontrol show job $SLURM_JOB_ID 10 | 11 | declare -a fmnist_datasets=("FMNIST-OOD" "R-FMNIST") 12 | declare -a cifar_datasets=("CIFAR-10-OOD" "CIFAR-10-C" "CIFAR-100-OOD") 13 | # Set the path to your data and models directories here. 
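# For example, for a purely local run these might look like:
#   data_root=./data
#   models_root=./pretrained_models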
14 | data_root=/mnt/qb/hennig/data 15 | models_root=/mnt/qb/hennig/pretrained_models 16 | 17 | for dataset in "${fmnist_datasets[@]}"; 18 | do 19 | # Assuming you have activated your conda environment 20 | python uq.py --benchmark $dataset --method map --use_temperature_scaling True --model LeNet --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID 21 | done 22 | 23 | for dataset in "${cifar_datasets[@]}"; 24 | do 25 | # Assuming you have activated your conda environment 26 | python uq.py --benchmark $dataset --method map --use_temperature_scaling True --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID 27 | done 28 | -------------------------------------------------------------------------------- /run_scripts/ll_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu-2080ti # partition 3 | #SBATCH --gres=gpu:rtx2080ti:1 # type and number of gpus 4 | #SBATCH --time=72:00:00 # job will be cancelled after max. 72h 5 | #SBATCH --output=ll_train_%A_%a.out 6 | #SBATCH --array=1-5 7 | 8 | # print info about current job 9 | scontrol show job $SLURM_JOB_ID 10 | 11 | declare -a datasets=("fmnist" "cifar10" "cifar100") 12 | declare -a chain_ids=(1 2 3) 13 | 14 | for dataset in "${datasets[@]}"; 15 | do 16 | # MAP 17 | python ll_train.py --method map --dataset $dataset --randseed $SLURM_ARRAY_TASK_ID 18 | 19 | # HMC 20 | for chain_id in "${chain_ids[@]}"; 21 | do 22 | python ll_train.py --method hmc --n_burnins 100 --n_samples 200 --dataset $dataset --randseed $SLURM_ARRAY_TASK_ID --chain_id $chain_id 23 | done 24 | 25 | # NF+N(0,I) 26 | python ll_train.py --method nf_naive --flow_type radial --n_flows 1 --dataset $dataset --randseed $SLURM_ARRAY_TASK_ID 27 | 28 | for n_flows in {5..30..5}; 29 | do 30 | echo $n_flows 31 | python ll_train.py --method nf_naive --flow_type radial --n_flows $n_flows --dataset $dataset --randseed $SLURM_ARRAY_TASK_ID 32 | done 33 | 34 | 35 | # Refine 36 | python ll_train.py --method refine --flow_type radial --n_flows 1 --dataset $dataset --randseed $SLURM_ARRAY_TASK_ID 37 | 38 | for n_flows in {5..30..5}; 39 | do 40 | echo $n_flows 41 | python ll_train.py --method refine --flow_type radial --n_flows $n_flows --dataset $dataset --randseed $SLURM_ARRAY_TASK_ID 42 | done 43 | done 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Posterior Refinement Improves Sample Efficiency in BNNs 2 | 3 | This repository contains the code to run the experiments for the paper [Posterior Refinement Improves Sample Efficiency in Bayesian Neural Networks](https://arxiv.org/abs/2205.10041) (NeurIPS 2022), using our library [laplace](https://github.com/AlexImmer/Laplace/). 4 | 5 | The files `al_train.py` and `ll_train.py` show how to refine an all-layer and last-layer Laplace posterior approximation _post hoc_. 6 | To run them, you can use the commands in the bash scripts in `run_scripts` with the corresponding name. 7 | Specifically, the practically most relevant code for last-layer refinement is in [line 165-230](https://github.com/runame/laplace-refinement/blob/main/ll_train.py#L165-L230) in `ll_train.py`. 8 | 9 | The method boils down to this very simple statement: 10 | > Fine-tune your last-layer Laplace posterior with a normalizing flow. 
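For intuition, here is a minimal self-contained sketch of that idea using `pyro` radial flows (the transform type used in `models/autoguide.py`). It is not the code from `ll_train.py`: the features, labels, dimensions, flow length, and optimizer settings below are toy placeholders, and the Gaussian base merely stands in for the Laplace approximation that the repo obtains with the `laplace` library.

```python
import torch
import torch.nn.functional as F
import pyro.distributions as dist

torch.manual_seed(0)

# Toy stand-in for a last-layer problem: cached penultimate features and labels.
n_data, n_feat, n_classes = 256, 20, 3
feats = torch.randn(n_data, n_feat)
labels = torch.randint(0, n_classes, (n_data,))
prior_prec = 510.0           # cf. --prior_precision 510 used for F-MNIST in run_scripts
d = n_classes * n_feat       # flattened last-layer weights (bias omitted for brevity)

def log_joint(theta):
    """Unnormalized log-posterior of flattened last-layer weights; theta: (S, d)."""
    W = theta.view(-1, n_classes, n_feat)                  # (S, C, F)
    logits = feats @ W.transpose(-1, -2)                   # (S, N, C) via broadcasting
    nll = F.cross_entropy(
        logits.reshape(-1, n_classes), labels.repeat(W.shape[0]), reduction="none"
    ).view(W.shape[0], -1).sum(-1)
    log_prior = -0.5 * prior_prec * (theta ** 2).sum(-1)   # isotropic Gaussian prior
    return -nll + log_prior

# Base distribution: placeholder for the Laplace approximation N(theta_MAP, Sigma).
base = dist.MultivariateNormal(torch.zeros(d), scale_tril=0.1 * torch.eye(d))

# A short radial flow on top of the base, trained by maximizing the ELBO.
flows = [dist.transforms.radial(d) for _ in range(5)]
guide = dist.TransformedDistribution(base, flows)
opt = torch.optim.Adam([p for f in flows for p in f.parameters()], lr=1e-2)

for _ in range(500):
    opt.zero_grad()
    theta = guide.rsample((32,))                                 # reparameterized samples
    loss = -(log_joint(theta) - guide.log_prob(theta)).mean()    # negative ELBO
    loss.backward()
    opt.step()
```

Swapping the toy `log_joint` for the real last-layer log-posterior and the base for the fitted Laplace distribution gives, in essence, the refinement procedure described above.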
No need for a long, complicated flow; no need for many epochs for training the flow. 11 | 12 | The file `uq.py` contains the code to run all the uncertainty quantification experiments on F-MNIST, CIFAR-10, and CIFAR-100. The commands to run the experiments are in the `run_*_uq.sh` files in the `run_scripts` folder. 13 | 14 | _Note: Depending on your setup, you might have to copy the bash files in `run_scripts` to the root of this repo to be able to run the commands unchanged._ 15 | 16 | Please cite the paper if you want to refer to the method: 17 | ```bibtex 18 | @inproceedings{kristiadi2022refinement, 19 | title={Posterior Refinement Improves Sample Efficiency in {B}ayesian Neural Networks}, 20 | author={Agustinus Kristiadi and Runa Eschenhagen and Philipp Hennig}, 21 | booktitle={{N}eur{IPS}}, 22 | year={2022} 23 | } 24 | ``` 25 | -------------------------------------------------------------------------------- /baselines/csghmc/csghmc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class CSGHMCTrainer(): 6 | 7 | def __init__(self, model, n_cycles, n_samples_per_cycle, n_epochs, initial_lr, num_batch, total_iters, data_size, weight_decay=5e-4, alpha=0.9): 8 | self.model = model 9 | self.n_cycles = n_cycles 10 | self.n_samples_per_cycle = n_samples_per_cycle 11 | self.n_epochs = n_epochs 12 | self.epoch_per_cycle = n_epochs // n_cycles 13 | self.num_batch = num_batch 14 | self.total_iters = total_iters 15 | self.data_size = data_size 16 | self.weight_decay = weight_decay 17 | self.alpha = alpha 18 | self.temperature = 1/data_size 19 | 20 | self.initial_lr = initial_lr 21 | self.lr = initial_lr 22 | 23 | def adjust_lr(self, epoch, batch_idx): 24 | rcounter = epoch * self.num_batch + batch_idx 25 | cos_inner = np.pi * (rcounter % (self.total_iters // self.n_cycles)) 26 | cos_inner /= self.total_iters // self.n_cycles 27 | cos_out = np.cos(cos_inner) + 1 28 | 29 | self.lr = 0.5 * cos_out * self.initial_lr 30 | 31 | def update_params(self, epoch): 32 | for p in self.model.parameters(): 33 | if not hasattr(p, 'buf'): 34 | p.buf = torch.zeros(p.size()).cuda() 35 | 36 | d_p = p.grad 37 | d_p.add_(p, alpha=self.weight_decay) 38 | 39 | buf_new = (1-self.alpha) * p.buf - self.lr * d_p 40 | 41 | if (epoch % self.epoch_per_cycle) + 1 > self.epoch_per_cycle - self.n_samples_per_cycle: 42 | eps = torch.randn(p.size()).cuda() 43 | buf_new += (2*self.lr * self.alpha * self.temperature / self.data_size)**.5 * eps 44 | 45 | p.data.add_(buf_new) 46 | p.buf = buf_new 47 | -------------------------------------------------------------------------------- /baselines/bbb/models/lenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from bayesian_torch.layers import (Conv2dFlipout, Conv2dReparameterization, 4 | LinearFlipout, LinearReparameterization) 5 | 6 | 7 | class LeNetBBB(nn.Module): 8 | 9 | def __init__(self, num_classes=10, var0=1, estimator='flipout'): 10 | _check_estimator(estimator) 11 | super().__init__() 12 | 13 | Conv2dVB = Conv2dReparameterization if estimator == 'reparam' else Conv2dFlipout 14 | LinearVB = LinearReparameterization if estimator == 'reparam' else LinearFlipout 15 | 16 | self.conv1 = Conv2dVB(1, 6, 5, prior_variance=var0) 17 | self.conv2 = Conv2dVB(6, 16, 5, prior_variance=var0) 18 | self.flatten = nn.Flatten() 19 | self.fc1 = LinearVB(256, 120, prior_variance=var0) 20 | self.fc2 = LinearVB(120, 
84, prior_variance=var0) 21 | self.fc3 = LinearVB(84, num_classes, prior_variance=var0) 22 | 23 | def forward(self, x): 24 | x, kl_total = self.features(x) 25 | x, kl = self.fc3(x) 26 | kl_total += kl 27 | 28 | return x, kl_total 29 | 30 | 31 | def features(self, x, return_acts=False): 32 | kl_total = 0 33 | 34 | x, kl = self.conv1(x) 35 | kl_total += kl 36 | x = F.max_pool2d(F.relu(x), 2, 2) 37 | 38 | x, kl = self.conv2(x) 39 | kl_total += kl 40 | x = F.max_pool2d(F.relu(x), 2, 2) 41 | 42 | x = self.flatten(x) 43 | 44 | x, kl = self.fc1(x) 45 | kl_total += kl 46 | x = F.relu(x) 47 | 48 | x, kl = self.fc2(x) 49 | kl_total += kl 50 | x = F.relu(x) 51 | 52 | return x, kl_total 53 | 54 | 55 | def _check_estimator(estimator): 56 | assert estimator in ['reparam', 'flipout'], 'Estimator must be either "reparam" or "flipout"' 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /run_scripts/run_al_fmnist_uq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu-v100 # partition 3 | #SBATCH --gres=gpu:v100:1 # type and number of gpus 4 | #SBATCH --time=72:00:00 # job will be cancelled after max. 72h 5 | #SBATCH --output=fmnist_al_uq_%A_%a.out 6 | #SBATCH --array=1-5 7 | 8 | # print info about current job 9 | scontrol show job $SLURM_JOB_ID 10 | 11 | declare -a datasets=("FMNIST-OOD" "R-FMNIST") 12 | declare -a prior_optim=("marglik" "CV") 13 | # Set the path to your data and models directories here. 14 | data_root=/mnt/qb/hennig/data/ 15 | models_root=/mnt/qb/hennig/pretrained_models 16 | 17 | for dataset in "${datasets[@]}"; 18 | do 19 | # Assuming you have activated your conda environment 20 | # MAP 21 | python uq.py --benchmark $dataset --method map --model FMNIST-MLP --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name $dataset/al_map_$SLURM_ARRAY_TASK_ID 22 | 23 | # HMC 24 | python uq.py --benchmark $dataset --method hmc --subset_of_weights all --prior_precision 510 --model FMNIST-MLP --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name $dataset/al_hmc_$SLURM_ARRAY_TASK_ID 25 | 26 | for prior in "${prior_optim[@]}"; 27 | do 28 | # LA-NN-MC 29 | python uq.py --benchmark $dataset --method laplace --subset_of_weights all --hessian_structure diag --optimize_prior_precision $prior --pred_type nn --link_approx mc --n_samples 20 --model FMNIST-MLP --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name $dataset/al_la_nn_mc_${prior}_$SLURM_ARRAY_TASK_ID 30 | 31 | # LA-MC 32 | python uq.py --benchmark $dataset --method laplace --subset_of_weights all --hessian_structure diag --optimize_prior_precision $prior --pred_type glm --link_approx mc --n_samples 20 --model FMNIST-MLP --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name $dataset/al_la_glm_mc_${prior}_$SLURM_ARRAY_TASK_ID 33 | 34 | # LA-Probit 35 | python uq.py --benchmark $dataset --method laplace --subset_of_weights all --hessian_structure diag --optimize_prior_precision $prior --pred_type glm --link_approx probit --model FMNIST-MLP --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name $dataset/al_la_glm_probit_${prior}_$SLURM_ARRAY_TASK_ID 36 | done 37 | 38 | echo 1 39 | python uq.py --benchmark $dataset --method refine_radial_1 --model LeNet --data_root /mnt/qb/hennig/data/ --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --subset_of_weights all 40 | 41 | for n_flows in {5..100..5}; 42 | do 43 | echo $n_flows 44 | python uq.py --benchmark $dataset --method refine_radial_$n_flows --model LeNet 
--data_root /mnt/qb/hennig/data/ --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --subset_of_weights all 45 | done 46 | done 47 | -------------------------------------------------------------------------------- /baselines/swag/swag_utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | # Taken from https://github.com/wjmaddox/swa_gaussian/blob/master/swag/utils.py 7 | 8 | def flatten(lst): 9 | tmp = [i.contiguous().view(-1, 1) for i in lst] 10 | return torch.cat(tmp).view(-1) 11 | 12 | 13 | def unflatten_like(vector, likeTensorList): 14 | # Takes a flat torch.tensor and unflattens it to a list of torch.tensors 15 | # shaped like likeTensorList 16 | outList = [] 17 | i = 0 18 | for tensor in likeTensorList: 19 | # n = module._parameters[name].numel() 20 | n = tensor.numel() 21 | outList.append(vector[:, i: i + n].view(tensor.shape)) 22 | i += n 23 | return outList 24 | 25 | 26 | def _check_bn(module, flag): 27 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 28 | flag[0] = True 29 | 30 | 31 | def check_bn(model): 32 | flag = [False] 33 | model.apply(lambda module: _check_bn(module, flag)) 34 | return flag[0] 35 | 36 | 37 | def reset_bn(module): 38 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 39 | module.running_mean = torch.zeros_like(module.running_mean) 40 | module.running_var = torch.ones_like(module.running_var) 41 | 42 | 43 | def _get_momenta(module, momenta): 44 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 45 | momenta[module] = module.momentum 46 | 47 | 48 | def _set_momenta(module, momenta): 49 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 50 | module.momentum = momenta[module] 51 | 52 | 53 | def bn_update(loader, model, verbose=False, subset=None, **kwargs): 54 | """ 55 | BatchNorm buffers update (if any). 56 | Performs 1 epochs to estimate buffers average using train dataset. 57 | :param loader: train dataset loader for buffers average estimation. 58 | :param model: model being update 59 | :return: None 60 | """ 61 | if not check_bn(model): 62 | return 63 | model.train() 64 | momenta = {} 65 | model.apply(reset_bn) 66 | model.apply(lambda module: _get_momenta(module, momenta)) 67 | n = 0 68 | num_batches = len(loader) 69 | 70 | with torch.no_grad(): 71 | if subset is not None: 72 | loader = itertools.islice(loader, int(subset * len(loader))) 73 | 74 | if verbose: 75 | loader = tqdm(loader, total=num_batches) 76 | 77 | for input, _ in loader: 78 | input = input.cuda(non_blocking=True) 79 | input_var = torch.autograd.Variable(input) 80 | b = input_var.data.size(0) 81 | 82 | momentum = b / (n + b) 83 | for module in momenta.keys(): 84 | module.momentum = momentum 85 | 86 | model(input_var, **kwargs) 87 | n += b 88 | 89 | apply_bn_update(model, momenta) 90 | return momenta 91 | 92 | 93 | def apply_bn_update(model, momenta): 94 | model.apply(lambda module: _set_momenta(module, momenta)) 95 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import pdist 3 | from sklearn import metrics 4 | 5 | 6 | def mmd_rbf(X, Y): 7 | """ 8 | MMD using rbf (gaussian) kernel (i.e., k(x,y) = exp(-gamma * ||x-y||^2 / 2)). 
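Concretely, the returned value is the (biased) empirical estimate of the squared MMD,
mean(k(X, X)) + mean(k(Y, Y)) - 2 * mean(k(X, Y)),
with gamma chosen via the median heuristic on the first 50 samples of X and Y (see below).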
9 | Taken from: https://github.com/jindongwang/transferlearning 10 | 11 | Arguments: 12 | X {[n_sample1, dim]} -- [X matrix] 13 | Y {[n_sample2, dim]} -- [Y matrix] 14 | 15 | Returns: 16 | [scalar] -- [MMD value] 17 | """ 18 | # Median heuristic --- use some hold-out samples 19 | all_samples = np.concatenate([X[:50], Y[:50]], 0) 20 | pdists = pdist(all_samples) 21 | sigma = np.median(pdists) 22 | gamma=1/(sigma**2) 23 | 24 | XX = metrics.pairwise.rbf_kernel(X, X, gamma) 25 | YY = metrics.pairwise.rbf_kernel(Y, Y, gamma) 26 | XY = metrics.pairwise.rbf_kernel(X, Y, gamma) 27 | 28 | return XX.mean() + YY.mean() - 2 * XY.mean() 29 | 30 | 31 | def accuracy(y_pred, y_true): 32 | try: 33 | y_pred, y_true = y_pred.detach().cpu().numpy(), y_true.cpu().numpy() 34 | finally: 35 | return np.mean(y_pred.argmax(1) == y_true).mean()*100 36 | 37 | 38 | def nll(y_pred, y_true): 39 | """ 40 | Mean Categorical negative log-likelihood. `y_pred` is a probability vector. 41 | """ 42 | try: 43 | y_pred, y_true = y_pred.detach().cpu().numpy(), y_true.cpu().numpy() 44 | finally: 45 | return metrics.log_loss(y_true, y_pred) 46 | 47 | 48 | def brier(y_pred, y_true): 49 | try: 50 | y_pred, y_true = y_pred.detach().cpu().numpy(), y_true.cpu().numpy() 51 | finally: 52 | def one_hot(targets, nb_classes): 53 | res = np.eye(nb_classes)[np.array(targets).reshape(-1)] 54 | return res.reshape(list(targets.shape)+[nb_classes]) 55 | 56 | return metrics.mean_squared_error(y_pred, one_hot(y_true, y_pred.shape[-1])) 57 | 58 | 59 | def calibration(pys, y_true, M=100): 60 | try: 61 | pys, y_true = pys.detach().cpu().numpy(), y_true.cpu().numpy() 62 | finally: 63 | # Put the confidence into M bins 64 | _, bins = np.histogram(pys, M, range=(0, 1)) 65 | 66 | labels = pys.argmax(1) 67 | confs = np.max(pys, axis=1) 68 | conf_idxs = np.digitize(confs, bins) 69 | 70 | # Accuracy and avg. 
confidence per bin 71 | accs_bin = [] 72 | confs_bin = [] 73 | nitems_bin = [] 74 | 75 | for i in range(M): 76 | labels_i = labels[conf_idxs == i] 77 | y_true_i = y_true[conf_idxs == i] 78 | confs_i = confs[conf_idxs == i] 79 | 80 | acc = np.nan_to_num(np.mean(labels_i == y_true_i), 0) 81 | conf = np.nan_to_num(np.mean(confs_i), 0) 82 | 83 | accs_bin.append(acc) 84 | confs_bin.append(conf) 85 | nitems_bin.append(len(labels_i)) 86 | 87 | accs_bin, confs_bin = np.array(accs_bin), np.array(confs_bin) 88 | nitems_bin = np.array(nitems_bin) 89 | 90 | ECE = np.average(np.abs(confs_bin-accs_bin), weights=nitems_bin/nitems_bin.sum()) 91 | MCE = np.max(np.abs(accs_bin - confs_bin)) 92 | 93 | # In percent 94 | ECE, MCE = ECE*100, MCE*100 95 | 96 | return ECE, MCE 97 | -------------------------------------------------------------------------------- /models/autoguide.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | from pyro import distributions as dist 5 | from pyro.infer.autoguide import AutoContinuous, AutoNormalizingFlow 6 | from pyro.infer.autoguide.initialization import init_to_feasible 7 | 8 | FLOW_TYPES = { 9 | 'planar': dist.transforms.planar, 10 | 'radial': dist.transforms.radial 11 | } 12 | 13 | 14 | class AutoNormalizingFlowCustom(AutoNormalizingFlow): 15 | """ 16 | AutoNormalizingFlow guide with a custom base distribution 17 | """ 18 | 19 | def __init__(self, model, base_dist_mean, base_dist_Cov, diag=False, flow_type='radial', flow_len=5, cuda=False): 20 | init_transform_fn = functools.partial(dist.transforms.iterated, flow_len, FLOW_TYPES[flow_type]) 21 | super().__init__(model, init_transform_fn) 22 | 23 | self.base_dist_mean = base_dist_mean 24 | self.base_dist_Cov = base_dist_Cov 25 | 26 | if diag: 27 | assert self.base_dist_Cov.shape == self.base_dist_mean.shape 28 | self.base_dist = dist.Normal(self.base_dist_mean, torch.sqrt(self.base_dist_Cov)) 29 | else: 30 | self.base_dist = dist.MultivariateNormal(self.base_dist_mean, self.base_dist_Cov) 31 | 32 | self.transform = None 33 | self._prototype_tensor = torch.tensor(0.0, device='cuda' if cuda else 'cpu') 34 | self.cuda = cuda 35 | 36 | def get_base_dist(self): 37 | return self.base_dist 38 | 39 | def get_posterior(self, *args, **kwargs): 40 | if self.transform is None: 41 | self.transform = self._init_transform_fn(self.latent_dim) 42 | 43 | if self.cuda: 44 | self.transform.to('cuda:0') 45 | 46 | # Update prototype tensor in case transform parameters 47 | # device/dtype is not the same as default tensor type. 
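        # (In this class, named_pyro_params() will yield the flow's learnable parameters
        #  once self.transform has been built, so the first of them serves as the
        #  device/dtype reference; AutoNormalizingFlowCuda below repeats the same pattern.)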
48 | for _, p in self.named_pyro_params(): 49 | self._prototype_tensor = p 50 | break 51 | return super().get_posterior(*args, **kwargs) 52 | 53 | 54 | class AutoNormalizingFlowCuda(AutoContinuous): 55 | 56 | def __init__(self, model, init_transform_fn, cuda=True): 57 | super().__init__(model, init_loc_fn=init_to_feasible) 58 | self._init_transform_fn = init_transform_fn 59 | self.transform = None 60 | self._prototype_tensor = torch.tensor(0.0, device='cuda' if cuda else 'cpu') 61 | self.cuda = cuda 62 | 63 | def get_base_dist(self): 64 | loc = self._prototype_tensor.new_zeros(1) 65 | scale = self._prototype_tensor.new_ones(1) 66 | return dist.Normal(loc, scale).expand([self.latent_dim]).to_event(1) 67 | 68 | 69 | def get_transform(self, *args, **kwargs): 70 | return self.transform 71 | 72 | 73 | def get_posterior(self, *args, **kwargs): 74 | if self.transform is None: 75 | self.transform = self._init_transform_fn(self.latent_dim) 76 | 77 | if self.cuda: 78 | self.transform.to('cuda:0') 79 | 80 | # Update prototype tensor in case transform parameters 81 | # device/dtype is not the same as default tensor type. 82 | for _, p in self.named_pyro_params(): 83 | self._prototype_tensor = p 84 | break 85 | return super().get_posterior(*args, **kwargs) 86 | -------------------------------------------------------------------------------- /run_scripts/run_cifar100_uq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu-v100 # partition 3 | #SBATCH --gres=gpu:v100:1 # type and number of gpus 4 | #SBATCH --time=72:00:00 # job will be cancelled after max. 72h 5 | #SBATCH --output=cifar100_uq_%A_%a.out 6 | #SBATCH --array=1-5 7 | 8 | # print info about current job 9 | scontrol show job $SLURM_JOB_ID 10 | 11 | # Set the path to your data and models directories here. 
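# Note: ./pretrained_models below matches the default save location used by the
# baseline training scripts (e.g. baselines/csghmc/train.py).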
12 | data_root=/mnt/qb/hennig/data/ 13 | models_root=./pretrained_models 14 | 15 | declare -a methods=("map" "hmc") 16 | for method in "${methods[@]}"; 17 | do 18 | echo ${method} 19 | python uq.py --benchmark CIFAR-100-OOD --method ${method} --prior_precision 40 --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID --compute_mmd 20 | done 21 | 22 | declare -a prior_optim=("marglik" "CV") 23 | for prior in "${prior_optim[@]}"; 24 | do 25 | # LA-NN-MC 26 | python uq.py --benchmark CIFAR-100-OOD --method laplace --subset_of_weights last_layer --hessian_structure diag --optimize_prior_precision $prior --pred_type nn --link_approx mc --n_samples 20 --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name CIFAR-100-OOD/la_nn_mc_${prior}_$SLURM_ARRAY_TASK_ID 27 | 28 | # LA-MC 29 | python uq.py --benchmark CIFAR-100-OOD --method laplace --subset_of_weights last_layer --hessian_structure diag --optimize_prior_precision $prior --pred_type glm --link_approx mc --n_samples 20 --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name CIFAR-100-OOD/la_glm_mc_${prior}_$SLURM_ARRAY_TASK_ID 30 | 31 | # LA-Probit 32 | python uq.py --benchmark CIFAR-100-OOD --method laplace --subset_of_weights last_layer --hessian_structure diag --optimize_prior_precision $prior --pred_type glm --link_approx probit --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name CIFAR-100-OOD/la_glm_probit_${prior}_$SLURM_ARRAY_TASK_ID 33 | done 34 | 35 | declare -a nfmethods=("refine" "nf_naive") 36 | for nfmethod in "${nfmethods[@]}"; 37 | do 38 | echo 1 39 | python uq.py --benchmark CIFAR-100-OOD --method ${nfmethod}_radial_1 --prior_precision 40 --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID --compute_mmd 40 | for n_flows in {5..30..5}; 41 | do 42 | echo $n_flows 43 | python uq.py --benchmark CIFAR-100-OOD --method ${nfmethod}_radial_${n_flows} --prior_precision 40 --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID --compute_mmd 44 | done 45 | done 46 | 47 | # Baselines 48 | 49 | python uq.py --benchmark CIFAR-100-OOD --method ensemble --nr_components 5 --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID 50 | 51 | python uq.py --benchmark CIFAR-100-OOD --method bbb --model WRN16-4-BBB-flipout --normalize --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID 52 | 53 | python uq.py --benchmark CIFAR-100-OOD --method csghmc --model WRN16-4-CSGHMC --normalize --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID 54 | -------------------------------------------------------------------------------- /run_scripts/run_cifar_uq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu-2080ti # partition 3 | #SBATCH --gres=gpu:rtx2080ti:1 # type and number of gpus 4 | #SBATCH --time=72:00:00 # job will be cancelled after max. 
72h 5 | #SBATCH --output=cifar10_uq_%A_%a.out 6 | #SBATCH --array=1-5 7 | 8 | # print info about current job 9 | scontrol show job $SLURM_JOB_ID 10 | 11 | declare -a datasets=("CIFAR-10-OOD" "CIFAR-10-C") 12 | declare -a nfmethods=("refine" "nf_naive") 13 | declare -a prior_optim=("marglik" "CV") 14 | # Set the path to your data and models directories here. 15 | data_root=/mnt/qb/hennig/data/ 16 | models_root=/mnt/qb/hennig/pretrained_models/ 17 | 18 | for dataset in "${datasets[@]}"; 19 | do 20 | # Assuming you have activated your conda environment 21 | 22 | # MAP 23 | python uq.py --benchmark $dataset --method map --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID 24 | 25 | for prior in "${prior_optim[@]}"; 26 | do 27 | # LA-NN-MC 28 | python uq.py --benchmark $dataset --method laplace --subset_of_weights last_layer --hessian_structure full --optimize_prior_precision $prior --pred_type nn --link_approx mc --n_samples 20 --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name $dataset/la_nn_mc_${prior}_$SLURM_ARRAY_TASK_ID 29 | 30 | # LA-MC 31 | python uq.py --benchmark $dataset --method laplace --subset_of_weights last_layer --hessian_structure full --optimize_prior_precision $prior --pred_type glm --link_approx mc --n_samples 20 --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name $dataset/la_glm_mc_${prior}_$SLURM_ARRAY_TASK_ID 32 | 33 | # LA-Probit 34 | python uq.py --benchmark $dataset --method laplace --subset_of_weights last_layer --hessian_structure full --optimize_prior_precision $prior --pred_type glm --link_approx probit --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name $dataset/la_glm_probit_${prior}_$SLURM_ARRAY_TASK_ID 35 | done 36 | # HMC 37 | python uq.py --benchmark $dataset --method hmc --prior_precision 40 --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID 38 | 39 | # NF 40 | for nfmethod in "${nfmethods[@]}"; 41 | do 42 | echo ${nfmethod} 43 | echo 1 44 | python uq.py --benchmark $dataset --method ${nfmethod}_radial_1 --prior_precision 40 --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID 45 | 46 | for n_flows in {5..30..5}; 47 | do 48 | echo $n_flows 49 | python uq.py --benchmark $dataset --method ${nfmethod}_radial_${n_flows} --prior_precision 40 --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID 50 | done 51 | done 52 | 53 | # Baselines 54 | python uq.py --benchmark $dataset --method ensemble --nr_components 5 --model WRN16-4 --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID 55 | 56 | python uq.py --benchmark $dataset --method bbb --model WRN16-4-BBB-flipout --normalize --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID 57 | 58 | python uq.py --benchmark $dataset --method csghmc --model WRN16-4-CSGHMC --normalize --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID 59 | 60 | done 61 | -------------------------------------------------------------------------------- /run_scripts/run_fmnist_uq.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu-2080ti # partition 3 | #SBATCH --gres=gpu:rtx2080ti:1 # type and number of gpus 4 | #SBATCH --time=72:00:00 # job will be cancelled after max. 72h 5 | #SBATCH --output=fmnist_uq_%A_%a.out 6 | #SBATCH --array=1-5 7 | 8 | # print info about current job 9 | scontrol show job $SLURM_JOB_ID 10 | 11 | declare -a datasets=("R-FMNIST" "FMNIST-OOD") 12 | declare -a nfmethods=("refine" "nf_naive") 13 | declare -a prior_optim=("marglik" "CV") 14 | # Set the path to your data and models directories here. 15 | data_root=/mnt/qb/hennig/data/ 16 | models_root=/mnt/qb/hennig/pretrained_models 17 | 18 | for dataset in "${datasets[@]}"; 19 | do 20 | # Assuming you have activated your conda environment 21 | 22 | # MAP 23 | python uq.py --benchmark $dataset --method map --model LeNet --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID 24 | 25 | for prior in "${prior_optim[@]}"; 26 | do 27 | # LA-NN-MC 28 | python uq.py --benchmark $dataset --method laplace --subset_of_weights last_layer --hessian_structure full --optimize_prior_precision $prior --pred_type nn --link_approx mc --n_samples 20 --model LeNet --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name $dataset/la_nn_mc_${prior}_$SLURM_ARRAY_TASK_ID 29 | 30 | # LA-MC 31 | python uq.py --benchmark $dataset --method laplace --subset_of_weights last_layer --hessian_structure full --optimize_prior_precision $prior --pred_type glm --link_approx mc --n_samples 20 --model LeNet --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name $dataset/la_glm_mc_${prior}_$SLURM_ARRAY_TASK_ID 32 | 33 | # LA-Probit 34 | python uq.py --benchmark $dataset --method laplace --subset_of_weights last_layer --hessian_structure full --optimize_prior_precision $prior --pred_type glm --link_approx probit --model LeNet --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID --run_name $dataset/la_glm_probit_${prior}_$SLURM_ARRAY_TASK_ID 35 | done 36 | 37 | # HMC 38 | python uq.py --benchmark $dataset --method hmc --prior_precision 510 --model LeNet --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID 39 | 40 | # Refine & NF-N(0,I) 41 | for nfmethod in "${nfmethods[@]}"; 42 | do 43 | echo ${nfmethod} 44 | echo 1 45 | python uq.py --benchmark $dataset --method ${nfmethod}_radial_1 --prior_precision 510 --model LeNet --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID 46 | 47 | for n_flows in {5..30..5}; 48 | do 49 | echo $n_flows 50 | python uq.py --benchmark $dataset --method ${nfmethod}_radial_${n_flows} --prior_precision 510 --model LeNet --data_root ${data_root} --models_root ${models_root} --compute_mmd --model_seed $SLURM_ARRAY_TASK_ID 51 | done 52 | done 53 | 54 | # Baslines 55 | python uq.py --benchmark $dataset --method ensemble --nr_components 5 --model LeNet --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID 56 | 57 | python uq.py --benchmark $dataset --method bbb --model LeNet-BBB-flipout --normalize --data_root ${data_root} --models_root ${models_root} --model_seed $SLURM_ARRAY_TASK_ID 58 | 59 | python uq.py --benchmark $dataset --method csghmc --model LeNet-CSGHMC --normalize --data_root ${data_root} --models_root ${models_root} 
--model_seed $SLURM_ARRAY_TASK_ID 60 | done 61 | -------------------------------------------------------------------------------- /baselines/vanilla/models/wrn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | ########################################################################################## 8 | # Wide ResNet (for WRN16-4) 9 | ########################################################################################## 10 | # Taken from https://github.com/hendrycks/outlier-exposure/blob/master/CIFAR/models/wrn.py 11 | 12 | class BasicBlock(nn.Module): 13 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 14 | super(BasicBlock, self).__init__() 15 | self.bn1 = nn.BatchNorm2d(in_planes) 16 | self.relu1 = nn.ReLU(inplace=True) 17 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | self.bn2 = nn.BatchNorm2d(out_planes) 20 | self.relu2 = nn.ReLU(inplace=True) 21 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 22 | padding=1, bias=False) 23 | self.droprate = dropRate 24 | self.equalInOut = (in_planes == out_planes) 25 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 26 | padding=0, bias=False) or None 27 | 28 | def forward(self, x): 29 | if not self.equalInOut: 30 | x = self.relu1(self.bn1(x)) 31 | else: 32 | out = self.relu1(self.bn1(x)) 33 | if self.equalInOut: 34 | out = self.relu2(self.bn2(self.conv1(out))) 35 | else: 36 | out = self.relu2(self.bn2(self.conv1(x))) 37 | if self.droprate > 0: 38 | out = F.dropout(out, p=self.droprate, training=self.training) 39 | out = self.conv2(out) 40 | if not self.equalInOut: 41 | return torch.add(self.convShortcut(x), out) 42 | else: 43 | return torch.add(x, out) 44 | 45 | 46 | class NetworkBlock(nn.Module): 47 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 48 | super(NetworkBlock, self).__init__() 49 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 50 | 51 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 52 | layers = [] 53 | for i in range(nb_layers): 54 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 55 | return nn.Sequential(*layers) 56 | 57 | def forward(self, x): 58 | return self.layer(x) 59 | 60 | 61 | class WideResNet(nn.Module): 62 | def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0): 63 | super(WideResNet, self).__init__() 64 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 65 | assert ((depth - 4) % 6 == 0) 66 | n = (depth - 4) // 6 67 | block = BasicBlock 68 | # 1st conv before any network block 69 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 70 | padding=1, bias=False) 71 | # 1st block 72 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 73 | # 2nd block 74 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 75 | # 3rd block 76 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 77 | # global average pooling and classifier 78 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.fc = nn.Linear(nChannels[3], num_classes) 81 | self.nChannels = nChannels[3] 82 | 83 | for m in 
self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 86 | m.weight.data.normal_(0, math.sqrt(2. / n)) 87 | elif isinstance(m, nn.BatchNorm2d): 88 | m.weight.data.fill_(1) 89 | m.bias.data.zero_() 90 | elif isinstance(m, nn.Linear): 91 | m.bias.data.zero_() 92 | 93 | def forward(self, x): 94 | out = self.conv1(x) 95 | out = self.block1(out) 96 | out = self.block2(out) 97 | out = self.block3(out) 98 | out = self.relu(self.bn1(out)) 99 | out = F.avg_pool2d(out, 8) 100 | out = out.view(-1, self.nChannels) 101 | return self.fc(out) 102 | -------------------------------------------------------------------------------- /models/wrn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | ########################################################################################## 8 | # Wide ResNet (for WRN16-4) 9 | ########################################################################################## 10 | # Taken from https://github.com/hendrycks/outlier-exposure/blob/master/CIFAR/models/wrn.py 11 | 12 | class BasicBlock(nn.Module): 13 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 14 | super(BasicBlock, self).__init__() 15 | self.bn1 = nn.BatchNorm2d(in_planes) 16 | self.relu1 = nn.ReLU(inplace=True) 17 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | self.bn2 = nn.BatchNorm2d(out_planes) 20 | self.relu2 = nn.ReLU(inplace=True) 21 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 22 | padding=1, bias=False) 23 | self.droprate = dropRate 24 | self.equalInOut = (in_planes == out_planes) 25 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 26 | padding=0, bias=False) or None 27 | 28 | def forward(self, x): 29 | if not self.equalInOut: 30 | x = self.relu1(self.bn1(x)) 31 | else: 32 | out = self.relu1(self.bn1(x)) 33 | if self.equalInOut: 34 | out = self.relu2(self.bn2(self.conv1(out))) 35 | else: 36 | out = self.relu2(self.bn2(self.conv1(x))) 37 | if self.droprate > 0: 38 | out = F.dropout(out, p=self.droprate, training=self.training) 39 | out = self.conv2(out) 40 | if not self.equalInOut: 41 | return torch.add(self.convShortcut(x), out) 42 | else: 43 | return torch.add(x, out) 44 | 45 | 46 | class NetworkBlock(nn.Module): 47 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 48 | super(NetworkBlock, self).__init__() 49 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 50 | 51 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 52 | layers = [] 53 | for i in range(nb_layers): 54 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 55 | return nn.Sequential(*layers) 56 | 57 | def forward(self, x): 58 | return self.layer(x) 59 | 60 | 61 | class WideResNet(nn.Module): 62 | def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0): 63 | super(WideResNet, self).__init__() 64 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 65 | assert ((depth - 4) % 6 == 0) 66 | n = (depth - 4) // 6 67 | block = BasicBlock 68 | # 1st conv before any network block 69 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 70 | padding=1, bias=False) 71 | # 1st 
block 72 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 73 | # 2nd block 74 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 75 | # 3rd block 76 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 77 | # global average pooling and classifier 78 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.fc = nn.Linear(nChannels[3], num_classes) 81 | self.nChannels = nChannels[3] 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 86 | m.weight.data.normal_(0, math.sqrt(2. / n)) 87 | elif isinstance(m, nn.BatchNorm2d): 88 | m.weight.data.fill_(1) 89 | m.bias.data.zero_() 90 | elif isinstance(m, nn.Linear): 91 | m.bias.data.zero_() 92 | 93 | def forward(self, x): 94 | return self.fc(self.forward_features(x)) 95 | 96 | def forward_features(self, x): 97 | out = self.conv1(x) 98 | out = self.block1(out) 99 | out = self.block2(out) 100 | out = self.block3(out) 101 | out = self.relu(self.bn1(out)) 102 | out = F.avg_pool2d(out, 8) 103 | return out.view(-1, self.nChannels) 104 | -------------------------------------------------------------------------------- /baselines/vanilla/models/wrn_fixup.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from asdfghjkl.operations import Bias, Scale 8 | 9 | ########################################################################################## 10 | # Wide ResNet (for WRN16-4) 11 | ########################################################################################## 12 | # Adapted from https://github.com/hendrycks/outlier-exposure/blob/master/CIFAR/models/wrn.py 13 | 14 | class FixupBasicBlock(nn.Module): 15 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 16 | super(FixupBasicBlock, self).__init__() 17 | self.bias1 = Bias() 18 | self.relu1 = nn.ReLU(inplace=True) 19 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | self.bias2 = Bias() 22 | self.relu2 = nn.ReLU(inplace=True) 23 | self.bias3 = Bias() 24 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 25 | padding=1, bias=False) 26 | self.bias4 = Bias() 27 | self.scale1 = Scale() 28 | self.droprate = dropRate 29 | self.equalInOut = (in_planes == out_planes) 30 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 31 | padding=0, bias=False) or None 32 | 33 | def forward(self, x): 34 | if not self.equalInOut: 35 | x = self.relu1(self.bias1(x)) 36 | else: 37 | out = self.relu1(self.bias1(x)) 38 | if self.equalInOut: 39 | out = self.bias3(self.relu2(self.bias2(self.conv1(out)))) 40 | else: 41 | out = self.bias3(self.relu2(self.bias2(self.conv1(x)))) 42 | if self.droprate > 0: 43 | out = F.dropout(out, p=self.droprate, training=self.training) 44 | out = self.bias4(self.scale1(self.conv2(out))) 45 | if not self.equalInOut: 46 | return torch.add(self.convShortcut(x), out) 47 | else: 48 | return torch.add(x, out) 49 | 50 | 51 | class FixupNetworkBlock(nn.Module): 52 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 53 | super(FixupNetworkBlock, self).__init__() 54 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 55 | 56 | 
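    # Same layer-stacking helper as NetworkBlock._make_layer in wrn.py, except that
    # `block` is a FixupBasicBlock (Bias/Scale in place of BatchNorm).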
def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 57 | layers = [] 58 | for i in range(nb_layers): 59 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 60 | return nn.Sequential(*layers) 61 | 62 | def forward(self, x): 63 | return self.layer(x) 64 | 65 | 66 | class FixupWideResNet(nn.Module): 67 | def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0): 68 | super(FixupWideResNet, self).__init__() 69 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 70 | assert ((depth - 4) % 6 == 0) 71 | n = (depth - 4) // 6 72 | block = FixupBasicBlock 73 | # 1st conv before any network block 74 | self.num_layers = n * 3 75 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 76 | padding=1, bias=False) 77 | self.bias1 = Bias() 78 | # 1st block 79 | self.block1 = FixupNetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 80 | # 2nd block 81 | self.block2 = FixupNetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 82 | # 3rd block 83 | self.block3 = FixupNetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 84 | # global average pooling and classifier 85 | self.bias2 = Bias() 86 | self.relu = nn.ReLU(inplace=True) 87 | self.fc = nn.Linear(nChannels[3], num_classes) 88 | self.nChannels = nChannels[3] 89 | 90 | for m in self.modules(): 91 | if isinstance(m, FixupBasicBlock): 92 | conv = m.conv1 93 | k = conv.weight.shape[0] * np.prod(conv.weight.shape[2:]) 94 | nn.init.normal_(conv.weight, 95 | mean=0, 96 | std=np.sqrt(2. / k) * self.num_layers ** (-0.5)) 97 | nn.init.constant_(m.conv2.weight, 0) 98 | if m.convShortcut is not None: 99 | cs = m.convShortcut 100 | k = cs.weight.shape[0] * np.prod(cs.weight.shape[2:]) 101 | nn.init.normal_(cs.weight, 102 | mean=0, 103 | std=np.sqrt(2. 
/ k)) 104 | elif isinstance(m, nn.Linear): 105 | nn.init.constant_(m.weight, 0) 106 | nn.init.constant_(m.bias, 0) 107 | 108 | def forward(self, x): 109 | out = self.bias1(self.conv1(x)) 110 | out = self.block1(out) 111 | out = self.block2(out) 112 | out = self.block3(out) 113 | out = self.relu(out) 114 | out = F.avg_pool2d(out, 8) 115 | out = out.view(-1, self.nChannels) 116 | return self.fc(self.bias2(out)) 117 | 118 | 119 | if __name__ == '__main__': 120 | X = torch.randn(7, 3, 32, 32) 121 | model = FixupWideResNet(16, 4, 10, dropRate=0.3) 122 | print(model(X).shape) 123 | -------------------------------------------------------------------------------- /baselines/csghmc/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as o 3 | import sys 4 | 5 | sys.path.append(o.abspath(o.join(o.dirname(sys.modules[__name__].__file__), '../..'))) 6 | 7 | import argparse 8 | import copy 9 | import math 10 | import pickle 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import optim 16 | from tqdm import tqdm, trange 17 | 18 | from baselines.csghmc.csghmc import CSGHMCTrainer 19 | from baselines.vanilla.models.lenet import LeNet 20 | from baselines.vanilla.models.wrn import WideResNet 21 | from utils import data_utils, test 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--dataset', default='MNIST', choices=['MNIST', 'FMNIST', 'CIFAR10', 'CIFAR100']) 25 | parser.add_argument('--n_epochs', type=int, default=200) 26 | parser.add_argument('--n_cycles', type=int, default=4) 27 | parser.add_argument('--n_samples_per_cycle', type=int, default=3) 28 | parser.add_argument('--initial_lr', type=float, default=0.1) 29 | parser.add_argument('--randseed', type=int, default=123) 30 | args = parser.parse_args() 31 | 32 | np.random.seed(args.randseed) 33 | torch.manual_seed(args.randseed) 34 | torch.backends.cudnn.deterministic = True 35 | torch.backends.cudnn.benchmark = True 36 | 37 | # Just symlink your dataset folder into your home directory like so 38 | # No need to change this code---this way it's more consistent 39 | data_path = os.path.expanduser('~/Datasets') 40 | 41 | if args.dataset == 'MNIST': 42 | train_loader, val_loader, test_loader = data_utils.get_mnist_loaders(data_path) 43 | elif args.dataset == 'FMNIST': 44 | train_loader, val_loader, test_loader = data_utils.get_fmnist_loaders(data_path) 45 | elif args.dataset == 'CIFAR10': 46 | train_loader, val_loader, test_loader = data_utils.get_cifar10_loaders(data_path) 47 | else: 48 | train_loader, val_loader, test_loader = data_utils.get_cifar100_loaders(data_path) 49 | 50 | targets = torch.cat([y for x, y in test_loader], dim=0).numpy() 51 | num_classes = 100 if args.dataset == 'CIFAR100' else 10 52 | 53 | if args.dataset in ['MNIST', 'FMNIST']: 54 | model = LeNet(num_classes) 55 | arch_name = 'lenet' 56 | dir_name = f'lenet_{args.dataset.lower()}' 57 | else: 58 | model = WideResNet(16, 4, num_classes, dropRate=0) 59 | arch_name = 'wrn_16-4' 60 | dir_name = f'wrn_16-4_{args.dataset.lower()}' 61 | 62 | print(f'Num. 
params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}') 63 | 64 | model.cuda() 65 | model.train() 66 | 67 | batch_size = 128 68 | data_size = len(train_loader.dataset) 69 | num_batch = data_size/batch_size + 1 70 | epoch_per_cycle = args.n_epochs // args.n_cycles 71 | pbar = trange(args.n_epochs) 72 | total_iters = args.n_epochs * num_batch 73 | weight_decay = 5e-4 74 | 75 | # Timing stuff 76 | timing_start = torch.cuda.Event(enable_timing=True) 77 | timing_end = torch.cuda.Event(enable_timing=True) 78 | torch.cuda.synchronize() 79 | timing_start.record() 80 | 81 | trainer = CSGHMCTrainer( 82 | model, args.n_cycles, args.n_samples_per_cycle, args.n_epochs, args.initial_lr, 83 | num_batch, total_iters, data_size, weight_decay 84 | ) 85 | samples = [] 86 | 87 | for epoch in pbar: 88 | train_loss = 0 89 | num_batches = len(train_loader) 90 | num_data = len(train_loader.dataset) 91 | 92 | for batch_idx, (x, y) in enumerate(train_loader): 93 | trainer.model.train() 94 | trainer.model.zero_grad() 95 | 96 | x, y = x.cuda(non_blocking=True), y.long().cuda(non_blocking=True) 97 | 98 | out = trainer.model(x) 99 | loss = F.cross_entropy(out, y) 100 | loss.backward() 101 | 102 | # The meat of the CSGMCMC method 103 | lr = trainer.adjust_lr(epoch, batch_idx) 104 | trainer.update_params(epoch) 105 | 106 | train_loss = 0.9*train_loss + 0.1*loss.item() 107 | 108 | # Save the last n_samples_per_cycle iterates of a cycle 109 | if (epoch % epoch_per_cycle) + 1 > epoch_per_cycle - args.n_samples_per_cycle: 110 | samples.append(copy.deepcopy(trainer.model.state_dict())) 111 | 112 | model.eval() 113 | pred = test.predict(test_loader, model).cpu().numpy() 114 | acc_val = np.mean(np.argmax(pred, 1) == targets)*100 115 | mmc_val = pred.max(-1).mean()*100 116 | 117 | pbar.set_description( 118 | f'[Epoch: {epoch+1}; loss: {train_loss:.3f}; acc: {acc_val:.1f}; mmc: {mmc_val:.1f}]' 119 | ) 120 | 121 | # Timing stuff 122 | timing_end.record() 123 | torch.cuda.synchronize() 124 | timing = timing_start.elapsed_time(timing_end)/1000 125 | path_timing = './results/timings_train' 126 | if not os.path.exists(path_timing): os.makedirs(path_timing) 127 | np.save(f'{path_timing}/csghmc_{args.dataset.lower()}_{args.randseed}', timing) 128 | 129 | path = f'./pretrained_models/csghmc/{dir_name}' 130 | 131 | if not os.path.exists(path): 132 | os.makedirs(path) 133 | 134 | save_name = f'{path}/{arch_name}_{args.dataset.lower()}_{args.randseed}_1' 135 | torch.save(samples, save_name) 136 | 137 | ## Try loading and testing 138 | samples_state_dicts = torch.load(save_name) 139 | models = [] 140 | 141 | for state_dict in samples_state_dicts: 142 | if args.dataset in ['MNIST', 'FMNIST']: 143 | _model = LeNet(num_classes) 144 | else: 145 | _model = WideResNet(16, 4, num_classes, dropRate=0) 146 | 147 | _model.load_state_dict(state_dict) 148 | models.append(_model.cuda().eval()) 149 | 150 | print() 151 | 152 | py_in = test.predict_ensemble(test_loader, models).cpu().numpy() 153 | acc_in = np.mean(np.argmax(py_in, 1) == targets)*100 154 | print(f'Accuracy: {acc_in:.1f}') 155 | -------------------------------------------------------------------------------- /baselines/bbb/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as o 3 | import sys 4 | 5 | sys.path.append(o.abspath(o.join(o.dirname(sys.modules[__name__].__file__), '../..'))) 6 | 7 | import argparse 8 | import math 9 | import pickle 10 | 11 | import numpy as np 12 | import torch 13 | import 
torch.nn.functional as F 14 | from torch import optim 15 | from torch.cuda import amp 16 | from tqdm import tqdm, trange 17 | 18 | from baselines.bbb.models.lenet import LeNetBBB 19 | from baselines.bbb.models.wrn import WideResNetBBB 20 | from utils import data_utils, test 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--dataset', default='MNIST', choices=['MNIST', 'FMNIST', 'CIFAR10', 'CIFAR100']) 24 | parser.add_argument('--estimator', default='flipout', choices=['reparam', 'flipout']) 25 | parser.add_argument('--var0', default=None, help='Gaussian prior variance. If None, it will be computed to emulate weight decay') 26 | parser.add_argument('--tau', type=float, default=0.1, help='Tempering parameter for the KL-term') 27 | parser.add_argument('--randseed', type=int, default=123) 28 | args = parser.parse_args() 29 | 30 | np.random.seed(args.randseed) 31 | torch.manual_seed(args.randseed) 32 | torch.backends.cudnn.deterministic = True 33 | torch.backends.cudnn.benchmark = True 34 | 35 | # Just symlink your dataset folder into your home directory like so 36 | # No need to change this code---this way it's more consistent 37 | data_path = os.path.expanduser('~/Datasets') 38 | 39 | if args.dataset == 'MNIST': 40 | train_loader, val_loader, test_loader = data_utils.get_mnist_loaders(data_path) 41 | elif args.dataset == 'FMNIST': 42 | train_loader, val_loader, test_loader = data_utils.get_fmnist_loaders(data_path) 43 | elif args.dataset == 'CIFAR10': 44 | train_loader, val_loader, test_loader = data_utils.get_cifar10_loaders(data_path) 45 | else: 46 | train_loader, val_loader, test_loader = data_utils.get_cifar100_loaders(data_path) 47 | 48 | targets = torch.cat([y for x, y in test_loader], dim=0).numpy() 49 | num_classes = 100 if args.dataset == 'CIFAR100' else 10 50 | 51 | if args.var0 is None: 52 | args.var0 = 1/(5e-4*len(train_loader.dataset)) 53 | else: 54 | args.var0 = float(args.var0) 55 | 56 | if args.dataset in ['MNIST', 'FMNIST']: 57 | model = LeNetBBB(num_classes, var0=args.var0, estimator=args.estimator) 58 | opt = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0) 59 | arch_name = 'lenet' 60 | dir_name = f'lenet_{args.dataset.lower()}' 61 | else: 62 | model = WideResNetBBB(16, 4, num_classes, var0=args.var0, estimator=args.estimator) 63 | opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0, nesterov=True) 64 | arch_name = 'wrn_16-4' 65 | dir_name = f'wrn_16-4_{args.dataset.lower()}' 66 | 67 | print(f'Num. 
params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}') 68 | 69 | model.cuda() 70 | model.train() 71 | 72 | n_epochs = 100 73 | pbar = trange(n_epochs) 74 | ## T_max is the max iterations: n_epochs x n_batches_per_epoch 75 | scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=n_epochs*len(train_loader)) 76 | 77 | ## For automatic-mixed-precision 78 | scaler = amp.GradScaler() 79 | 80 | # Timing stuff 81 | timing_start = torch.cuda.Event(enable_timing=True) 82 | timing_end = torch.cuda.Event(enable_timing=True) 83 | torch.cuda.synchronize() 84 | timing_start.record() 85 | 86 | for epoch in pbar: 87 | train_loss = 0 88 | num_batches = len(train_loader) 89 | num_data = len(train_loader.dataset) 90 | 91 | for batch_idx, (x, y) in enumerate(train_loader): 92 | model.train() 93 | opt.zero_grad() 94 | 95 | m = len(x) # Batch size 96 | x, y = x.cuda(non_blocking=True), y.long().cuda(non_blocking=True) 97 | 98 | with amp.autocast(): 99 | out, kl = model(x) 100 | # Scaled negative-ELBO with 1 MC sample 101 | # See Graves 2011 as to why the KL-term is scaled that way and notice that we use mean instead of sum; tau is the tempering parameter 102 | loss = F.cross_entropy(out.squeeze(), y) + args.tau/num_data*kl 103 | 104 | scaler.scale(loss).backward() 105 | scaler.step(opt) 106 | scaler.update() 107 | scheduler.step() 108 | 109 | train_loss = 0.9*train_loss + 0.1*loss.item() 110 | 111 | model.eval() 112 | pred = test.predict_vb(test_loader, model, n_samples=1).cpu().numpy() 113 | acc_val = np.mean(np.argmax(pred, 1) == targets)*100 114 | mmc_val = pred.max(-1).mean()*100 115 | 116 | pbar.set_description( 117 | f'[Epoch: {epoch+1}; ELBO: {train_loss:.3f}; acc: {acc_val:.1f}; mmc: {mmc_val:.1f}]' 118 | ) 119 | 120 | # Timing stuff 121 | timing_end.record() 122 | torch.cuda.synchronize() 123 | timing = timing_start.elapsed_time(timing_end)/1000 124 | path_timing = './results/timings_train' 125 | if not os.path.exists(path_timing): os.makedirs(path_timing) 126 | np.save(f'{path_timing}/bbb-{args.estimator}_{args.dataset.lower()}_{args.randseed}', timing) 127 | 128 | path = f'./pretrained_models/bbb/{args.estimator}/{dir_name}' 129 | 130 | if not os.path.exists(path): 131 | os.makedirs(path) 132 | 133 | save_name = f'{path}/{arch_name}_{args.dataset.lower()}_{args.randseed}_1' 134 | torch.save(model.state_dict(), save_name) 135 | 136 | ## Try loading and testing 137 | model.load_state_dict(torch.load(save_name)) 138 | model.eval() 139 | 140 | print() 141 | 142 | ## In-distribution 143 | py_in = test.predict_vb(test_loader, model, n_samples=20).cpu().numpy() 144 | acc_in = np.mean(np.argmax(py_in, 1) == targets)*100 145 | print(f'Accuracy: {acc_in:.1f}') 146 | -------------------------------------------------------------------------------- /baselines/bbb/models/wrn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from bayesian_torch.layers import (Conv2dFlipout, Conv2dReparameterization, 7 | LinearFlipout, LinearReparameterization) 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | 12 | def __init__(self, in_planes, out_planes, stride, var0=1, dropRate=0.0, estimator='reparam'): 13 | _check_estimator(estimator) 14 | super().__init__() 15 | 16 | Conv2dVB = Conv2dReparameterization if estimator == 'reparam' else Conv2dFlipout 17 | 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | self.conv1 = Conv2dVB( 21 | 
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False, 22 | prior_variance=var0 23 | ) 24 | 25 | self.bn2 = nn.BatchNorm2d(out_planes) 26 | self.relu2 = nn.ReLU(inplace=True) 27 | self.conv2 = Conv2dVB( 28 | out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False, 29 | prior_variance=var0 30 | ) 31 | self.droprate = dropRate 32 | self.equalInOut = (in_planes == out_planes) 33 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) or None 34 | 35 | def forward(self, x): 36 | kl_total = 0 37 | 38 | if not self.equalInOut: 39 | x = self.relu1(self.bn1(x)) 40 | else: 41 | out = self.relu1(self.bn1(x)) 42 | 43 | if self.equalInOut: 44 | out, kl = self.conv1(out) 45 | out = self.relu2(self.bn2(out)) 46 | else: 47 | out, kl = self.conv1(x) 48 | out = self.relu2(self.bn2(out)) 49 | 50 | kl_total += kl 51 | 52 | out, kl = self.conv2(out) 53 | kl_total += kl 54 | 55 | if not self.equalInOut: 56 | out_shortcut = self.convShortcut(x) 57 | return torch.add(out_shortcut, out), kl_total 58 | else: 59 | return torch.add(x, out), kl_total 60 | 61 | 62 | class NetworkBlock(nn.Module): 63 | 64 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, var0=1, dropRate=0.0, estimator='reparam'): 65 | _check_estimator(estimator) 66 | super().__init__() 67 | self.layer = self._make_layer( 68 | block, in_planes, out_planes, nb_layers, stride, var0, dropRate 69 | ) 70 | 71 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, var0, dropRate): 72 | layers = [] 73 | 74 | for i in range(nb_layers): 75 | layers.append(block( 76 | i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, 77 | var0, dropRate 78 | )) 79 | 80 | return nn.Sequential(*layers) 81 | 82 | def forward(self, x): 83 | out = x 84 | kl_total = 0 85 | 86 | for l in self.layer: 87 | out, kl = l(out) 88 | kl_total += kl 89 | 90 | return out, kl_total 91 | 92 | 93 | class WideResNetBBB(nn.Module): 94 | 95 | def __init__(self, depth, widen_factor, num_classes, num_channel=3, var0=1, droprate=0, estimator='reparam', feature_extractor=False): 96 | _check_estimator(estimator) 97 | super().__init__() 98 | 99 | Conv2dVB = Conv2dReparameterization if estimator == 'reparam' else Conv2dFlipout 100 | LinearVB = LinearReparameterization if estimator == 'reparam' else LinearFlipout 101 | 102 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 103 | assert ((depth - 4) % 6 == 0) 104 | n = (depth - 4) // 6 105 | block = BasicBlock 106 | 107 | # 1st conv before any network block 108 | self.conv1 = Conv2dVB( 109 | num_channel, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False, 110 | prior_variance=var0 111 | ) 112 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, var0, droprate) 113 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, var0, droprate) 114 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, var0, droprate) 115 | # global average pooling and classifier 116 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.fc = LinearVB(nChannels[3], num_classes, prior_variance=var0) 119 | 120 | self.nChannels = nChannels[3] 121 | self.feature_extractor = feature_extractor 122 | 123 | for m in self.modules(): 124 | if isinstance(m, nn.Conv2d): 125 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 126 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 127 | elif isinstance(m, nn.BatchNorm2d): 128 | m.weight.data.fill_(1) 129 | m.bias.data.zero_() 130 | elif isinstance(m, nn.Linear): 131 | m.bias.data.zero_() 132 | 133 | def forward(self, x): 134 | out, kl_total = self.features(x) 135 | 136 | if self.feature_extractor: 137 | return out, kl_total 138 | 139 | out, kl = self.fc(out) 140 | kl_total += kl 141 | 142 | return out, kl_total 143 | 144 | 145 | def features(self, x): 146 | kl_total = 0 147 | 148 | out, kl = self.conv1(x) 149 | kl_total += kl 150 | out, kl = self.block1(out) 151 | kl_total += kl 152 | out, kl = self.block2(out) 153 | kl_total += kl 154 | out, kl = self.block3(out) 155 | kl_total += kl 156 | 157 | out = self.relu(self.bn1(out)) 158 | out = F.avg_pool2d(out, 8) 159 | out = out.view(-1, self.nChannels) 160 | 161 | return out, kl_total 162 | 163 | 164 | def _check_estimator(estimator): 165 | assert estimator in ['reparam', 'flipout'], 'Estimator must be either "reparam" or "flipout"' 166 | -------------------------------------------------------------------------------- /utils/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from tqdm import tqdm 4 | 5 | from utils import utils 6 | from utils.utils import mixture_model_pred 7 | 8 | 9 | @torch.no_grad() 10 | def test(components, test_loader, prediction_mode, pred_type='glm', n_samples=100, 11 | link_approx='probit', no_loss_acc=False, device='cpu', 12 | likelihood='classification', sigma_noise=None): 13 | 14 | temperature_scaling_model = None 15 | if prediction_mode != 'swag': # in ['map', 'laplace', 'bbb', 'csghmc', 'hmc']: 16 | model = components[0] 17 | if prediction_mode in ['map', 'bbb']: 18 | if prediction_mode == 'map' and isinstance(model, tuple): 19 | model, temperature_scaling_model = model[0], model[1] 20 | model.eval() 21 | elif prediction_mode == 'csghmc': 22 | for m in model: 23 | m.eval() 24 | elif prediction_mode == 'swag': 25 | model, swag_samples, swag_bn_params = components[0] 26 | 27 | if likelihood == 'regression' and sigma_noise is None: 28 | raise ValueError('Must provide sigma_noise for regression!') 29 | 30 | if likelihood == 'classification': 31 | loss_fn = nn.NLLLoss() 32 | elif likelihood == 'regression': 33 | loss_fn = nn.GaussianNLLLoss(full=True) 34 | else: 35 | raise ValueError(f'Invalid likelihood type {likelihood}') 36 | 37 | all_y_true = list() 38 | all_y_prob = list() 39 | all_y_var = list() 40 | if prediction_mode in ['map', 'ensemble', 'laplace', 'mola', 'swag', 'multi-swag', 'bbb', 'csghmc']: 41 | for data in tqdm(test_loader): 42 | x, y = data[0].to(device), data[1].to(device) 43 | all_y_true.append(y.cpu()) 44 | 45 | if prediction_mode in ['ensemble', 'mola', 'multi-swag']: 46 | # set uniform mixture weights 47 | K = len(components) 48 | pi = torch.ones(K, device=device) / K 49 | y_prob = mixture_model_pred( 50 | components, x, pi, 51 | prediction_mode=prediction_mode, 52 | pred_type=pred_type, 53 | link_approx=link_approx, 54 | n_samples=n_samples, 55 | likelihood=likelihood) 56 | 57 | elif prediction_mode == 'laplace': 58 | y_prob = model( 59 | x, pred_type=pred_type, link_approx=link_approx, n_samples=n_samples) 60 | 61 | elif prediction_mode == 'map': 62 | y_prob = model(x).detach() 63 | 64 | elif prediction_mode == 'bbb': 65 | y_prob = torch.stack([model(x)[0].softmax(-1) for _ in range(10)]).mean(0) 66 | 67 | elif prediction_mode == 'csghmc': 68 | y_prob = torch.stack([m(x).softmax(-1) for m in model]).mean(0) 69 | 70 | elif 
prediction_mode == 'swag': 71 | from baselines.swag.swag import predict_swag 72 | y_prob = predict_swag(model, x, swag_samples, swag_bn_params) 73 | 74 | if likelihood == 'regression': 75 | y_mean = y_prob if prediction_mode == 'map' else y_prob[0] 76 | y_var = torch.zeros_like(y_mean) if prediction_mode == 'map' else y_prob[1].squeeze(2) 77 | all_y_prob.append(y_mean.cpu()) 78 | all_y_var.append(y_var.cpu()) 79 | else: 80 | all_y_prob.append(y_prob.cpu()) 81 | 82 | # aggregate predictive distributions, true labels and metadata 83 | all_y_prob = torch.cat(all_y_prob, dim=0) 84 | all_y_true = torch.cat(all_y_true, dim=0) 85 | 86 | else: 87 | if likelihood == 'regression': 88 | raise ValueError(f'Prediction mode {prediction_mode} is not supported for regression.') 89 | predictive, net = model 90 | if hasattr(net, 'forward_features'): 91 | all_y_prob, all_y_true = utils.predict_pyro_ll(components, test_loader) 92 | else: 93 | if len(components) > 1: 94 | raise ValueError('Only ll(-refine) methods support multiple components.') 95 | all_y_prob, all_y_true = utils.predict_pyro(predictive, test_loader) 96 | all_y_prob, all_y_true = all_y_prob.cpu(), all_y_true.cpu() 97 | 98 | if temperature_scaling_model is not None: 99 | print('Calibrating predictions using temperature scaling...') 100 | all_y_prob = torch.from_numpy(temperature_scaling_model.predict_proba(all_y_prob.numpy())) 101 | elif prediction_mode == 'map' and likelihood == 'classification': 102 | all_y_prob = all_y_prob.softmax(dim=1) 103 | 104 | # compute some metrics: mean confidence, accuracy and negative log-likelihood 105 | metrics = {} 106 | if likelihood == 'classification': 107 | c, preds = torch.max(all_y_prob, 1) 108 | metrics['conf'] = c.mean().item() 109 | 110 | if not no_loss_acc: 111 | if likelihood == 'regression': 112 | all_y_var = torch.cat(all_y_var, dim=0) + sigma_noise**2 113 | metrics['nll'] = loss_fn(all_y_prob, all_y_true, all_y_var).item() 114 | 115 | else: 116 | all_y_var = None 117 | metrics['nll'] = loss_fn(all_y_prob.log(), all_y_true).item() 118 | metrics['acc'] = (all_y_true == preds).float().mean().item() 119 | 120 | return metrics, all_y_prob, all_y_var 121 | 122 | 123 | @torch.no_grad() 124 | def predict(dataloader, model): 125 | py = [] 126 | 127 | for x, y in dataloader: 128 | x = x.cuda() 129 | py.append(torch.softmax(model(x), -1)) 130 | 131 | return torch.cat(py, dim=0) 132 | 133 | 134 | @torch.no_grad() 135 | def predict_ensemble(dataloader, models): 136 | py = [] 137 | 138 | for x, y in dataloader: 139 | x = x.cuda() 140 | 141 | _py = 0 142 | for model in models: 143 | _py += 1/len(models) * torch.softmax(model(x), -1) 144 | py.append(_py) 145 | 146 | return torch.cat(py, dim=0) 147 | 148 | 149 | @torch.no_grad() 150 | def predict_vb(dataloader, model, n_samples=1): 151 | py = [] 152 | 153 | for x, y in dataloader: 154 | x = x.cuda() 155 | 156 | _py = 0 157 | for _ in range(n_samples): 158 | f_s, _ = model(x) # The second return is KL 159 | _py += torch.softmax(f_s, 1) 160 | _py /= n_samples 161 | 162 | py.append(_py) 163 | 164 | return torch.cat(py, dim=0) 165 | -------------------------------------------------------------------------------- /models/network.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | 8 | def UCI(n_features): 9 | return nn.Sequential( 10 | nn.Linear(n_features, 50), 11 | nn.ReLU(), 12 | nn.Linear(50, 1) 13 | ) 14 | 15 | def 
UCI_hetero(n_features): 16 | return nn.Sequential( 17 | nn.Linear(n_features, 50), 18 | nn.ReLU(), 19 | nn.Linear(50, 2) 20 | ) 21 | 22 | 23 | def MLP(n_features, n_hiddens=50): 24 | return nn.Sequential( 25 | nn.Linear(n_features, n_hiddens), 26 | nn.ReLU(), 27 | nn.Linear(n_hiddens, 10) 28 | ) 29 | 30 | 31 | class LeNet(nn.Module): 32 | 33 | def __init__(self): 34 | super().__init__() 35 | 36 | self.features = nn.Sequential( 37 | nn.Conv2d(1, 6, 5), 38 | nn.ReLU(), 39 | nn.MaxPool2d(2), 40 | nn.Conv2d(6, 16, 5), 41 | nn.ReLU(), 42 | nn.MaxPool2d(2), 43 | nn.Flatten(), 44 | nn.Linear(16*4*4, 120), 45 | nn.ReLU(), 46 | nn.Linear(120, 84), 47 | nn.ReLU() 48 | ) 49 | self.ll = nn.Linear(84, 10) 50 | 51 | def forward_features(self, x): 52 | return self.features(x) 53 | 54 | def forward(self, x): 55 | return self.ll(self.forward_features(x)) 56 | 57 | 58 | class GRUClf(nn.Module): 59 | 60 | def __init__(self, num_classes, vocab_size, feature_extractor=False): 61 | super().__init__() 62 | self.feature_extractor = feature_extractor 63 | self.embedding = nn.Embedding(vocab_size, 50, padding_idx=1) 64 | self.gru = nn.GRU( 65 | input_size=50, hidden_size=128, num_layers=2, 66 | bias=True, batch_first=True, bidirectional=False 67 | ) 68 | self.linear = nn.Linear(128, num_classes) 69 | 70 | def forward(self, x): 71 | hidden = self.forward_features(x) 72 | 73 | if self.feature_extractor: 74 | return hidden 75 | 76 | logits = self.linear(hidden) 77 | return logits 78 | 79 | def forward_features(self, x): 80 | embeds = self.embedding(x) 81 | hidden = self.gru(embeds)[1][1] # select h_n, and select the 2nd layer 82 | return hidden 83 | 84 | 85 | class WideResNet(nn.Module): 86 | def __init__(self, depth, widen_factor, num_classes=10, dropRate=0): 87 | super(WideResNet, self).__init__() 88 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 89 | assert ((depth - 4) % 6 == 0) 90 | n = (depth - 4) // 6 91 | block = BasicBlock 92 | # 1st conv before any network block 93 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 94 | padding=1, bias=False) 95 | # 1st block 96 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 97 | # 2nd block 98 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 99 | # 3rd block 100 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 101 | # global average pooling and classifier 102 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.fc = nn.Linear(nChannels[3], num_classes) 105 | self.nChannels = nChannels[3] 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 111 | elif isinstance(m, nn.BatchNorm2d): 112 | m.weight.data.fill_(1) 113 | m.bias.data.zero_() 114 | elif isinstance(m, nn.Linear): 115 | m.bias.data.zero_() 116 | 117 | def forward(self, x): 118 | return self.fc(self.forward_features(x)) 119 | 120 | def forward_features(self, x): 121 | out = self.conv1(x) 122 | out = self.block1(out) 123 | out = self.block2(out) 124 | out = self.block3(out) 125 | out = self.relu(self.bn1(out)) 126 | out = F.avg_pool2d(out, 8) 127 | return out.view(-1, self.nChannels) 128 | 129 | 130 | class BasicBlock(nn.Module): 131 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 132 | super(BasicBlock, self).__init__() 133 | self.bn1 = nn.BatchNorm2d(in_planes) 134 | self.relu1 = nn.ReLU(inplace=True) 135 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 136 | padding=1, bias=False) 137 | self.bn2 = nn.BatchNorm2d(out_planes) 138 | self.relu2 = nn.ReLU(inplace=True) 139 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 140 | padding=1, bias=False) 141 | self.droprate = dropRate 142 | self.equalInOut = (in_planes == out_planes) 143 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 144 | padding=0, bias=False) or None 145 | 146 | def forward(self, x): 147 | if not self.equalInOut: 148 | x = self.relu1(self.bn1(x)) 149 | else: 150 | out = self.relu1(self.bn1(x)) 151 | if self.equalInOut: 152 | out = self.relu2(self.bn2(self.conv1(out))) 153 | else: 154 | out = self.relu2(self.bn2(self.conv1(x))) 155 | if self.droprate > 0: 156 | out = F.dropout(out, p=self.droprate, training=self.training) 157 | out = self.conv2(out) 158 | if not self.equalInOut: 159 | return torch.add(self.convShortcut(x), out) 160 | else: 161 | return torch.add(x, out) 162 | 163 | 164 | class NetworkBlock(nn.Module): 165 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 166 | super(NetworkBlock, self).__init__() 167 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 168 | 169 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 170 | layers = [] 171 | for i in range(nb_layers): 172 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 173 | return nn.Sequential(*layers) 174 | 175 | def forward(self, x): 176 | return self.layer(x) 177 | 178 | -------------------------------------------------------------------------------- /models/pyro.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import pyro 5 | import pyro.distributions as dist 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | from utils import utils 10 | 11 | 12 | class Model: 13 | 14 | def __init__(self, n_params, n_data, prior_prec=10, cuda=False, 15 | proj_mat=None, base_dist=None, diag=False): 16 | self.uuid = np.random.randint(low=0, high=10000, size=1)[0] 17 | 18 | self.n_data = n_data 19 | self.prior_mean = torch.zeros(n_params, device='cuda' if cuda else 'cpu') 20 | self.prior_std = math.sqrt(1/prior_prec) 21 | self.cuda = cuda 22 | 23 | if proj_mat is not None: 24 | self.A = proj_mat 25 | self.k, self.d = self.A.shape 26 | 27 | self.prior_mean_proj = torch.zeros(self.k) 28 | self.prior_Cov_proj = self.prior_std**2 * self.A @ self.A.T 29 | 30 | if base_dist is not None: 31 | self.base_dist = base_dist 32 | 33 | self.base_mean_proj = self.A @ 
self.base_dist.mean 34 | 35 | if not diag: 36 | self.base_Cov_proj = self.A @ self.base_dist.covariance_matrix @ self.A.T 37 | else: 38 | self.base_Cov_proj = self.A * (self.base_dist.scale**2)[None, :] @ self.A.T 39 | 40 | self.base_dist_proj = dist.MultivariateNormal(self.base_mean_proj, self.base_Cov_proj) 41 | 42 | if cuda: 43 | self.prior_mean = self.prior_mean.cuda() 44 | 45 | if proj_mat is not None and base_dist is not None: 46 | self.prior_mean_proj = self.prior_mean_proj.cuda() 47 | self.prior_Cov_proj = self.prior_Cov_proj.cuda() 48 | 49 | def model(self, X, y=None): 50 | raise NotImplementedError() 51 | 52 | def model_subspace(self, X, y=None): 53 | raise NotImplementedError() 54 | 55 | 56 | class RegressionModel(Model): 57 | 58 | def __init__(self, get_net, n_data, prior_prec=10, log_noise=torch.tensor(1.), cuda=False, proj_mat=None, base_dist=None): 59 | self.get_net = get_net # must be assigned before counting the parameters below 60 | n_params = sum(p.numel() for p in self.get_net().parameters()) 61 | super().__init__(n_params, n_data, prior_prec, cuda, proj_mat, base_dist) 62 | self.noise = F.softplus(log_noise) 63 | 64 | def model(self, X, y=None): 65 | # Sample params from the prior 66 | theta = pyro.sample('theta', dist.Normal(self.prior_mean, self.prior_std).to_event(1)) 67 | 68 | # Put the sample into the net 69 | net = self.get_net() 70 | utils.vector_to_parameters_backpropable(theta, net) 71 | f_X = net(X).squeeze() 72 | 73 | # Likelihood 74 | if y is not None: 75 | # Training 76 | with pyro.plate('data', size=self.n_data, subsample=y.squeeze()): 77 | pyro.sample('obs', dist.Normal(f_X, self.noise), obs=y.squeeze()) 78 | else: 79 | # Testing 80 | pyro.sample('obs', dist.Normal(f_X, self.noise)) 81 | 82 | def model_subspace(self, X, y=None, full_batch=False): 83 | # Sample params from the prior on low-dim, then project it to high-dim 84 | z = pyro.sample('z', dist.MultivariateNormal(self.prior_mean_proj, self.prior_Cov_proj)) 85 | theta = self.A.T @ z 86 | 87 | # Put the sample into the net 88 | net = self.get_net() 89 | utils.vector_to_parameters_backpropable(theta, net) 90 | f_X = net(X).squeeze() 91 | 92 | # Likelihood 93 | if y is not None: 94 | # Training 95 | with pyro.plate('data', size=self.n_data, subsample=y.squeeze()): 96 | pyro.sample('obs', dist.Normal(f_X, self.noise), obs=y.squeeze()) 97 | else: 98 | # Testing 99 | pyro.sample('obs', dist.Normal(f_X, self.noise)) 100 | 101 | 102 | class ClassificationModel(Model): 103 | 104 | def __init__(self, get_net, n_data, prior_prec=10, cuda=False, proj_mat=None, base_dist=None, diag=True): 105 | self.get_net = get_net 106 | 107 | n_params = sum(p.numel() for p in self.get_net().parameters()) 108 | super().__init__(n_params, n_data, prior_prec, cuda, proj_mat, base_dist, diag) 109 | 110 | def model(self, X, y=None, full_batch=False): 111 | # Sample params from the prior 112 | theta = pyro.sample('theta', dist.Normal(self.prior_mean, self.prior_std).to_event(1)) 113 | 114 | # Put the sample into the net 115 | net = self.get_net() 116 | 117 | if self.cuda: 118 | net.cuda() 119 | 120 | utils.vector_to_parameters_backpropable(theta, net) 121 | f_X = net(X) 122 | 123 | # Likelihood 124 | if y is not None: 125 | subsample = None if full_batch else y.squeeze() 126 | 127 | with pyro.plate('data', size=self.n_data, subsample=subsample): 128 | pyro.sample('obs', dist.Categorical(logits=f_X), obs=y.squeeze()) 129 | 130 | return f_X 131 | 132 | def model_subspace(self, X, y=None, full_batch=False): 133 | # Sample params from the prior on low-dim, then project it to high-dim
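# Subspace trick: with A of shape (k, d), the prior over the k-dim coefficients is the projected
# isotropic prior z ~ N(0, prior_std^2 * A A^T); the full d-dim parameter vector is then recovered
# as theta = A.T @ z before being written into the network.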
134 | z = pyro.sample('z', dist.MultivariateNormal(self.prior_mean_proj, self.prior_Cov_proj)) 135 | theta = self.A.T @ z 136 | 137 | # Put the sample into the net 138 | net = self.get_net() 139 | 140 | if self.cuda: 141 | net.cuda() 142 | 143 | utils.vector_to_parameters_backpropable(theta, net) 144 | f_X = net(X) 145 | 146 | # Likelihood 147 | if y is not None: 148 | subsample = None if full_batch else y.squeeze() 149 | 150 | with pyro.plate('data', size=self.n_data, subsample=subsample): 151 | pyro.sample('obs', dist.Categorical(logits=f_X), obs=y.squeeze()) 152 | 153 | return f_X 154 | 155 | 156 | class ClassificationModelLL(Model): 157 | 158 | def __init__(self, n_data, n_features, n_classes, feature_extractor, prior_prec=10, cuda=False, proj_mat=None, base_dist=None, diag=False): 159 | n_params = n_features*n_classes + n_classes # weights and biases 160 | super().__init__(n_params, n_data, prior_prec, cuda, proj_mat, base_dist, diag) 161 | 162 | self.n_features = n_features 163 | self.n_classes = n_classes 164 | self.feature_extractor = feature_extractor 165 | 166 | def model(self, X, y=None, full_batch=False, X_is_features=False): 167 | # Sample params from the prior 168 | theta = pyro.sample('theta', dist.Normal(self.prior_mean, self.prior_std).to_event(1)) 169 | f_X = self._forward(X, theta, X_is_features) 170 | 171 | # Likelihood 172 | if y is not None: 173 | subsample = None if full_batch else y.squeeze() 174 | 175 | with pyro.plate('data', size=self.n_data, subsample=subsample): 176 | pyro.sample('obs', dist.Categorical(logits=f_X), obs=y.squeeze()) 177 | 178 | return f_X 179 | 180 | def _forward(self, X, theta, X_is_features=False): 181 | # Make it compatible with PyTorch's parameters vectorization that Laplace uses: [W.flatten(), b] 182 | W = theta[:self.n_features*self.n_classes].reshape(self.n_classes, self.n_features) 183 | b = theta[self.n_features*self.n_classes:] # the last n_classes entries are the biases 184 | 185 | if X_is_features: 186 | phi_X = X 187 | else: 188 | with torch.no_grad(): 189 | phi_X = self.feature_extractor(X) 190 | 191 | return phi_X @ W.T + b # Transpose following nn.Linear 192 | -------------------------------------------------------------------------------- /al_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | 5 | import numpy as np 6 | import pyro 7 | import pyro.distributions as dist 8 | import torch 9 | import tqdm 10 | from laplace import Laplace 11 | from pyro.infer import MCMC, NUTS, SVI, Trace_ELBO 12 | from pyro.infer.autoguide import * 13 | from pyro.infer.autoguide import initialization as init 14 | from torch.nn import functional as F 15 | 16 | from models import autoguide, network 17 | from models import pyro as pyro_models 18 | from utils import data_utils, metrics, utils 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--method', choices=['map', 'hmc', 'refine', 'refine_sub'], default='map') 22 | parser.add_argument('--dataset', choices=['fmnist', 'mnist'], default='fmnist') 23 | parser.add_argument('--n_burnins', type=int, default=100) 24 | parser.add_argument('--n_samples', type=int, default=200) 25 | parser.add_argument('--chain_id', type=int, choices=[1, 2, 3, 4, 5], default=1) 26 | parser.add_argument('--n_flows', type=int, default=5, help='Only relevant for refine and refine_sub method') 27 | parser.add_argument('--flow_type', choices=['radial', 'planar'], default='radial', help='Only relevant for refine and refine_sub method') 28 | parser.add_argument('--subspace_dim', type=int, default=100,
help='Only relevant for refine_sub method') 29 | parser.add_argument('--prior_precision', type=float, default=30) 30 | parser.add_argument('--randseed', type=int, default=1) 31 | args = parser.parse_args() 32 | 33 | 34 | DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 35 | torch.manual_seed(args.randseed) 36 | np.random.seed(args.randseed) 37 | torch.backends.cudnn.benchmark = True 38 | 39 | # Random seeds for HMC initialization 40 | chain2randseed = {1: 77, 2: 777, 3: 7777, 4: 77777, 5: 777777} 41 | 42 | data_path = './data' 43 | if not os.path.exists(data_path): 44 | os.makedirs(data_path) 45 | # For saving pretrained models 46 | save_path = f'./pretrained_models/{args.dataset}/al' 47 | if not os.path.exists(save_path): 48 | os.makedirs(save_path) 49 | 50 | # The network 51 | get_net = lambda: network.MLP(n_features=784, n_hiddens=50) 52 | 53 | # Dataset --- no data augmentation 54 | loader_fn = data_utils.get_mnist_loaders if args.dataset == 'mnist' else data_utils.get_fmnist_loaders 55 | train_loader, val_loader, test_Loader = loader_fn( 56 | data_path, train=True, batch_size=128, model_class='MLP', download=True, device=DEVICE) 57 | 58 | X_train, y_train = [], [] 59 | for x, y in train_loader: 60 | X_train.append(x); y_train.append(y) 61 | X_train, y_train = torch.cat(X_train, 0), torch.cat(y_train, 0) 62 | 63 | M, N = X_train.shape[0], math.prod(X_train.shape[1:]) 64 | print(f'[Randseed: {args.randseed}] Dataset: {args.dataset.upper()}, n_data: {M}, n_feat: {N}, n_param: {utils.count_params(get_net())}') 65 | 66 | if args.method == 'map': 67 | model = get_net().cuda() 68 | wd = args.prior_precision / M # Since we use averaged NLL 69 | opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=wd) 70 | print(f'Weight decay: {wd}') 71 | 72 | schd = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100*len(train_loader)) 73 | 74 | pbar = tqdm.trange(100) 75 | for it in pbar: 76 | model.train() 77 | epoch_loss = 0 78 | 79 | for x, y in train_loader: 80 | x, y = x.cuda(), y.cuda() 81 | 82 | out = model(x) 83 | loss = F.cross_entropy(out, y) 84 | loss.backward() 85 | opt.step() 86 | schd.step() 87 | opt.zero_grad() 88 | epoch_loss += loss.item() 89 | 90 | model.eval() 91 | train_acc = metrics.accuracy(*utils.predict(model, train_loader)) 92 | pbar.set_description(f'[Loss: {epoch_loss:.3f}; Train acc: {train_acc:.1f}]') 93 | 94 | torch.save(model.state_dict(), f'{save_path}/{args.method}_{args.randseed}.pt') 95 | 96 | elif args.method == 'refine': 97 | net = get_net() 98 | net.load_state_dict(torch.load(f'{save_path}/map_{args.randseed}.pt')) 99 | net.cuda() 100 | net.eval() 101 | 102 | la = Laplace( 103 | net, 'classification', subset_of_weights='all', 104 | hessian_structure='diag', prior_precision=args.prior_precision 105 | ) 106 | la.fit(train_loader) 107 | la.optimize_prior_precision(method='CV', val_loader=val_loader, pred_type='nn', n_samples=10) 108 | 109 | base_dist_mean = la.mean 110 | base_dist_var = la.posterior_variance 111 | 112 | model = pyro_models.ClassificationModel( 113 | get_net, n_data=M, prior_prec=args.prior_precision, cuda=True 114 | ) 115 | guide = autoguide.AutoNormalizingFlowCustom( 116 | model.model, base_dist_mean, base_dist_var, diag=True, 117 | flow_type=args.flow_type, flow_len=args.n_flows, cuda=True 118 | ) 119 | 120 | n_epochs = 20 121 | n_iters = n_epochs * len(train_loader) 122 | schd = pyro.optim.CosineAnnealingLR({ 123 | 'optimizer': torch.optim.Adam, 'optim_args': {'lr': 1e-3, 'weight_decay': 0}, 'T_max': n_iters 124 
| }) 125 | 126 | svi = SVI(model.model, guide, optim=schd, loss=Trace_ELBO()) 127 | pbar = tqdm.trange(n_epochs) 128 | 129 | for it in pbar: 130 | for x, y in train_loader: 131 | x, y = x.cuda(), y.cuda() 132 | loss = svi.step(x, y) 133 | schd.step() 134 | 135 | pbar.set_description(f'[Loss: {loss:.3f}]') 136 | 137 | state_dict = { 138 | 'base_dist_mean': base_dist_mean, 'base_dist_var': base_dist_var, 139 | 'flow_type': args.flow_type, 'flow_len': args.n_flows, 140 | 'state_dict': guide.state_dict() 141 | } 142 | torch.save(state_dict, f'{save_path}/{args.method}_{args.flow_type}_{args.n_flows}_{args.randseed}.pt') 143 | 144 | elif args.method == 'refine_sub': 145 | model = get_net() 146 | model.load_state_dict(torch.load(f'{save_path}/map_{args.randseed}.pt')) 147 | model.cuda() 148 | 149 | la = Laplace( 150 | model, 'classification', subset_of_weights='all', 151 | hessian_structure='diag', prior_precision=args.prior_precision 152 | ) 153 | la.fit(train_loader) 154 | 155 | base_dist_mean = la.mean 156 | base_dist_scale = torch.sqrt(la.posterior_variance) 157 | base_dist = dist.Normal(base_dist_mean, base_dist_scale) 158 | 159 | # Projection matrix 160 | A = torch.zeros(len(base_dist_mean), args.subspace_dim, device=DEVICE) 161 | eigval_idxs = torch.argsort(la.posterior_variance, descending=True)[:args.subspace_dim] 162 | A[eigval_idxs, range(args.subspace_dim)] = 1 163 | A = A.T 164 | 165 | model = pyro_models.ClassificationModel( 166 | get_net, n_data=M, prior_prec=args.prior_precision, cuda=True, 167 | proj_mat=A, base_dist=base_dist, diag=True 168 | ) 169 | guide = autoguide.AutoNormalizingFlowCustom( 170 | model.model_subspace, model.base_mean_proj, model.base_Cov_proj, diag=False, 171 | flow_type=args.flow_type, flow_len=args.n_flows, cuda=True 172 | ) 173 | 174 | n_epochs = 20 175 | n_iters = n_epochs * len(train_loader) 176 | schd = pyro.optim.CosineAnnealingLR({ 177 | 'optimizer': torch.optim.Adam, 'optim_args': {'lr': 1e-1, 'weight_decay': 0}, 'T_max': n_iters 178 | }) 179 | 180 | svi = SVI(model.model_subspace, guide, optim=schd, loss=Trace_ELBO()) 181 | pbar = tqdm.trange(n_epochs) 182 | 183 | for it in pbar: 184 | for x, y in train_loader: 185 | x, y = x.cuda(), y.cuda() 186 | loss = svi.step(x, y) 187 | schd.step() 188 | 189 | pbar.set_description(f'[Loss: {loss:.3f}]') 190 | 191 | state_dict = { 192 | 'proj_mat': A, 'base_dist_mean': model.base_mean_proj, 'base_dist_Cov': model.base_Cov_proj, 193 | 'flow_type': args.flow_type, 'flow_len': args.n_flows, 194 | 'state_dict': guide.state_dict() 195 | } 196 | torch.save(state_dict, f'{save_path}/{args.method}_{args.subspace_dim}_{args.flow_type}_{args.n_flows}_{args.randseed}.pt') 197 | 198 | elif args.method == 'hmc': 199 | model = pyro_models.ClassificationModel( 200 | get_net=get_net, n_data=M, prior_prec=args.prior_precision, cuda=True 201 | ) 202 | 203 | pyro.util.set_rng_seed(chain2randseed[args.chain_id]) 204 | # Initialize the HMC by sampling from the prior 205 | mcmc_kernel = NUTS(model.model, init_strategy=init.init_to_sample) 206 | mcmc = MCMC( 207 | mcmc_kernel, num_samples=args.n_samples, warmup_steps=args.n_burnins, 208 | ) 209 | 210 | mcmc.run(X_train.cuda(), y_train.cuda(), full_batch=True) 211 | 212 | # HMC samples 213 | hmc_samples = mcmc.get_samples()['theta'].flatten(1).cpu().numpy() 214 | np.save(f'{save_path}/{args.method}_{args.randseed}_{args.chain_id}.npy', hmc_samples) 215 | -------------------------------------------------------------------------------- /ll_train.py: 
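A typical sequence of runs for this script (an illustrative sketch based on the argparse options and checkpoint paths below, with placeholder dataset and seed values; see also run_scripts/ll_train.sh): the MAP network has to be trained first, since the refine, nf_naive, and hmc methods all load its checkpoint.

```
python ll_train.py --dataset cifar10 --method map --randseed 1
python ll_train.py --dataset cifar10 --method refine --flow_type radial --n_flows 5 --randseed 1
python ll_train.py --dataset cifar10 --method hmc --chain_id 1 --randseed 1
```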
-------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import math 4 | import os 5 | 6 | import numpy as np 7 | import pyro 8 | import pyro.distributions as dist 9 | import torch 10 | import tqdm 11 | from laplace import Laplace 12 | from pyro.infer import MCMC, NUTS, SVI, Trace_ELBO 13 | from pyro.infer.autoguide import * 14 | from pyro.infer.autoguide import initialization as init 15 | from torch.cuda import amp 16 | from torch.nn import functional as F 17 | from torch.utils.data import DataLoader, TensorDataset 18 | 19 | from models import autoguide, network 20 | from models import pyro as pyro_models 21 | from utils import data_utils, metrics, utils 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--method', choices=['map', 'hmc', 'refine', 'nf_naive'], default='map') 25 | parser.add_argument('--dataset', choices=['fmnist', 'cifar10', 'cifar100'], default='fmnist') 26 | parser.add_argument('--n_burnins', type=int, default=100) 27 | parser.add_argument('--n_samples', type=int, default=200) 28 | parser.add_argument('--chain_id', type=int, choices=[1, 2, 3, 4, 5], default=1) 29 | parser.add_argument('--n_flows', type=int, default=5, help='Only relevant for `args.method == "refine"`') 30 | parser.add_argument('--flow_type', choices=['radial', 'planar'], default='radial', help='Only relevant for `args.method == "refine"`') 31 | parser.add_argument('--prior_precision', type=float, default=None) 32 | parser.add_argument('--randseed', type=int, default=1) 33 | args = parser.parse_args() 34 | 35 | 36 | DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 37 | torch.manual_seed(args.randseed) 38 | np.random.seed(args.randseed) 39 | torch.backends.cudnn.benchmark = True 40 | 41 | # Random seeds for HMC initialization 42 | chain2randseed = {1: 77, 2: 777, 3: 7777, 4: 77777, 5: 777777} 43 | 44 | # For saving pretrained models 45 | save_path = f'./pretrained_models/{args.dataset}/ll' 46 | if not os.path.exists(save_path): 47 | os.makedirs(save_path) 48 | 49 | # Prior precision 50 | if args.prior_precision is None: 51 | # From gridsearch 52 | args.prior_precision = 40 if 'cifar' in args.dataset else 510 53 | 54 | print(f'Prior prec: {args.prior_precision}') 55 | 56 | N_FEAT = 256 if args.dataset != 'fmnist' else 84 57 | N_CLASS = 10 if args.dataset != 'cifar100' else 100 58 | 59 | # The network 60 | get_net = lambda: network.WideResNet(16, 4, num_classes=N_CLASS) if 'cifar' in args.dataset else network.LeNet() 61 | 62 | # Dataset --- no data augmentation 63 | data_path = './data' 64 | if not os.path.exists(data_path): 65 | os.makedirs(data_path) 66 | 67 | if args.dataset == 'fmnist': 68 | train_loader, val_loader, test_loader = data_utils.get_fmnist_loaders(data_path, download=True, device=DEVICE) 69 | elif args.dataset == 'cifar10': 70 | train_loader, val_loader, test_loader = data_utils.get_cifar10_loaders( 71 | data_path, data_augmentation=(args.method == 'map'), normalize=False, download=True) 72 | else: 73 | train_loader, val_loader, test_loader = data_utils.get_cifar100_loaders( 74 | data_path, data_augmentation=(args.method == 'map'), normalize=False, download=True) 75 | 76 | X_train, y_train = [], [] 77 | for x, y in train_loader: 78 | X_train.append(x); y_train.append(y) 79 | X_train, y_train = torch.cat(X_train, 0), torch.cat(y_train, 0) 80 | 81 | M, N = X_train.shape[0], math.prod(X_train.shape[1:]) 82 | print(f'[Randseed: {args.randseed}] Dataset: {args.dataset.upper()}, n_data: 
{M}, n_feat: {N}, n_param: {utils.count_params(get_net())}') 83 | 84 | if args.method == 'map': 85 | model = get_net().cuda() 86 | 87 | if 'cifar' in args.dataset: 88 | opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4) 89 | else: 90 | opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4) 91 | 92 | schd = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100*len(train_loader)) 93 | scaler = amp.GradScaler() 94 | 95 | pbar = tqdm.trange(100) 96 | for it in pbar: 97 | model.train() 98 | epoch_loss = 0 99 | 100 | for x, y in train_loader: 101 | x, y = x.cuda(), y.cuda() 102 | 103 | with amp.autocast(): 104 | out = model(x) 105 | loss = F.cross_entropy(out, y) 106 | 107 | scaler.scale(loss).backward() 108 | scaler.step(opt) 109 | scaler.update() 110 | schd.step() 111 | opt.zero_grad() 112 | 113 | epoch_loss += loss.item() 114 | 115 | model.eval() 116 | train_acc = metrics.accuracy(*utils.predict(model, train_loader)) 117 | pbar.set_description(f'[Loss: {epoch_loss:.3f}; Train acc: {train_acc:.1f}]') 118 | 119 | torch.save(model.state_dict(), f'{save_path}/{args.method}_{args.randseed}.pt') 120 | 121 | elif args.method == 'nf_naive': 122 | net = get_net() 123 | net.load_state_dict(torch.load(f'{save_path}/map_{args.randseed}.pt')) 124 | net.cuda() 125 | net.eval() 126 | 127 | model = pyro_models.ClassificationModelLL( 128 | n_data=M, n_classes=N_CLASS, n_features=N_FEAT, feature_extractor=net.forward_features, 129 | prior_prec=args.prior_precision, cuda=True 130 | ) 131 | 132 | nf = dist.transforms.planar if args.flow_type == 'planar' else dist.transforms.radial 133 | transform_init = functools.partial(dist.transforms.iterated, args.n_flows, nf) 134 | guide = autoguide.AutoNormalizingFlowCuda(model.model, transform_init, cuda=True) 135 | 136 | n_epochs = 20 137 | n_iters = n_epochs * len(train_loader) 138 | schd = pyro.optim.CosineAnnealingLR({ 139 | 'optimizer': torch.optim.Adam, 'optim_args': {'lr': 1e-3, 'weight_decay': 0}, 'T_max': n_iters 140 | }) 141 | 142 | svi = SVI(model.model, guide, optim=schd, loss=Trace_ELBO()) 143 | pbar = tqdm.trange(20) 144 | 145 | # Cache features for faster training 146 | features_train, y_train = utils.get_features(net, train_loader) 147 | train_loader_features = DataLoader( 148 | TensorDataset(features_train.cpu(), y_train.unsqueeze(-1).cpu()), 149 | batch_size=128, shuffle=True, num_workers=0 150 | ) 151 | 152 | for it in pbar: 153 | for x, y in train_loader_features: 154 | x, y = x.cuda(), y.cuda() 155 | loss = svi.step(x, y, X_is_features=True) 156 | schd.step() 157 | 158 | pbar.set_description(f'[Loss: {loss:.3f}]') 159 | 160 | state_dict = { 161 | 'flow_type': args.flow_type, 'flow_len': args.n_flows, 162 | 'state_dict': guide.state_dict() 163 | } 164 | torch.save(state_dict, f'{save_path}/{args.method}_{args.flow_type}_{args.n_flows}_{args.randseed}.pt') 165 | 166 | 167 | elif args.method == 'refine': 168 | net = get_net() 169 | net.load_state_dict(torch.load(f'{save_path}/map_{args.randseed}.pt')) 170 | net.cuda() 171 | net.eval() 172 | 173 | hess_str = 'diag' if args.dataset == 'cifar100' else 'full' 174 | 175 | la = Laplace( 176 | net, 'classification', subset_of_weights='last_layer', 177 | hessian_structure=hess_str, prior_precision=args.prior_precision 178 | ) 179 | la.fit(train_loader) 180 | la.optimize_prior_precision() 181 | 182 | if hess_str == 'diag': 183 | base_dist_cov = la.posterior_variance 184 | base_dist = dist.Normal(la.mean, base_dist_cov.sqrt()) 185 | diag = True 186 
| else: 187 | base_dist_cov = la.posterior_covariance 188 | base_dist = dist.MultivariateNormal(la.mean, base_dist_cov) 189 | diag = False 190 | 191 | model = pyro_models.ClassificationModelLL( 192 | n_data=M, n_classes=N_CLASS, n_features=N_FEAT, feature_extractor=net.forward_features, 193 | prior_prec=args.prior_precision, cuda=True 194 | ) 195 | 196 | guide = autoguide.AutoNormalizingFlowCustom( 197 | model.model, base_dist.mean, base_dist_cov, diag=diag, 198 | flow_type=args.flow_type, flow_len=args.n_flows, cuda=True 199 | ) 200 | 201 | n_epochs = 20 202 | n_iters = n_epochs * len(train_loader) 203 | schd = pyro.optim.CosineAnnealingLR({ 204 | 'optimizer': torch.optim.Adam, 'optim_args': {'lr': 1e-3, 'weight_decay': 0}, 'T_max': n_iters 205 | }) 206 | 207 | svi = SVI(model.model, guide, optim=schd, loss=Trace_ELBO()) 208 | pbar = tqdm.trange(20) 209 | 210 | # Cache features for faster training 211 | features_train, y_train = utils.get_features(net, train_loader) 212 | train_loader_features = DataLoader( 213 | TensorDataset(features_train.cpu(), y_train.unsqueeze(-1).cpu()), 214 | batch_size=128, shuffle=True, num_workers=0 215 | ) 216 | 217 | for it in pbar: 218 | for x, y in train_loader_features: 219 | x, y = x.cuda(), y.cuda() 220 | loss = svi.step(x, y, X_is_features=True) 221 | schd.step() 222 | 223 | pbar.set_description(f'[Loss: {loss:.3f}]') 224 | 225 | cov_key = 'base_dist_var' if diag else 'base_dist_Cov' 226 | state_dict = { 227 | 'base_dist_mean': base_dist.mean, cov_key: base_dist_cov, 228 | 'flow_type': args.flow_type, 'flow_len': args.n_flows, 229 | 'state_dict': guide.state_dict() 230 | } 231 | 232 | torch.save(state_dict, f'{save_path}/{args.method}_{args.flow_type}_{args.n_flows}_{args.randseed}.pt') 233 | 234 | elif args.method == 'hmc': 235 | net = get_net() 236 | net.load_state_dict(torch.load(f'{save_path}/map_{args.randseed}.pt')) 237 | net.cuda() 238 | net.eval() 239 | 240 | model = pyro_models.ClassificationModelLL( 241 | n_data=M, n_classes=N_CLASS, n_features=N_FEAT, feature_extractor=net.forward_features, 242 | prior_prec=args.prior_precision, cuda=True 243 | ) 244 | 245 | pyro.util.set_rng_seed(chain2randseed[args.chain_id]) 246 | # Initialize the HMC by sampling from the prior 247 | mcmc_kernel = NUTS(model.model, init_strategy=init.init_to_sample) 248 | mcmc = MCMC( 249 | mcmc_kernel, num_samples=args.n_samples, warmup_steps=args.n_burnins, 250 | ) 251 | 252 | features_train, y_train = utils.get_features(net, train_loader) 253 | mcmc.run(features_train, y_train, full_batch=True, X_is_features=True) 254 | 255 | # HMC samples 256 | hmc_samples = mcmc.get_samples()['theta'].flatten(1).cpu().numpy() 257 | np.save(f'{save_path}/{args.method}_{args.randseed}_{args.chain_id}.npy', hmc_samples) 258 | -------------------------------------------------------------------------------- /baselines/swag/swag.py: -------------------------------------------------------------------------------- 1 | """ 2 | implementation of SWAG, taken/adapted from: 3 | https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py 4 | """ 5 | 6 | import copy 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from tqdm import tqdm 12 | 13 | from baselines.swag.swag_utils import (apply_bn_update, bn_update, flatten, 14 | unflatten_like) 15 | 16 | 17 | def fit_swag(model, device, train_loader, loss_func, diag_only=True, max_num_models=20, swa_c_epochs=1, swa_c_batches=None, swa_lr=0.01, momentum=0.9, wd=3e-4, mask=None, parallel=False): 18 | 
""" 19 | Fit SWAG model 20 | (adapted from https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py) 21 | 22 | Args: 23 | diag_only: bool flag to only store diagonal of SWAG covariance matrix (Default: True) 24 | max_num_models: int for maximum number of SWAG models to save (Default: 20) 25 | swa_c_epochs: int for SWA model collection frequency/cycle length in epochs (Default: 1) 26 | swa_c_batches: int for SWA model collection frequency/cycle length in batches (Default: None) 27 | swa_lr: float for SWA learning rate; use 0.05 for CIFAR100 and 0.01 otherwise (Default: 0.01) 28 | momentum: float for SGD momentum (Default: 0.9) 29 | wd: float for weight decay (Default: 3e-4) 30 | mask: dict of subnetwork masks (Default: None) 31 | parallel: data parallel model switch (default: False) 32 | """ 33 | 34 | if swa_c_epochs is not None and swa_c_batches is not None: 35 | raise RuntimeError("One of swa_c_epochs or swa_c_batches must be None!") 36 | 37 | if parallel: 38 | print("Using Data Parallel model") 39 | model = torch.nn.DataParallel(model).cuda(device) 40 | 41 | swag_model = SWAG(copy.deepcopy(model), no_cov_mat=diag_only, max_num_models=max_num_models, mask=mask).to(device) 42 | optimizer = torch.optim.SGD(model.parameters(), lr=swa_lr, momentum=momentum, weight_decay=wd) 43 | 44 | print("Running SWAG...") 45 | model.train() 46 | if swa_c_epochs is not None: 47 | n_epochs = swa_c_epochs * max_num_models 48 | else: 49 | n_epochs = 1 + (max_num_models * swa_c_batches) // len(train_loader) 50 | for epoch in tqdm(range(int(n_epochs))): 51 | for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader)): 52 | inputs = inputs.to(device, non_blocking=True) 53 | targets = targets.to(device, non_blocking=True) 54 | loss = loss_func(model(inputs), targets) 55 | 56 | optimizer.zero_grad() 57 | loss.backward() 58 | optimizer.step() 59 | 60 | if swa_c_batches is not None and (batch_idx+1) % swa_c_batches == 0: 61 | swag_model.collect_model(model) 62 | if swag_model.n_models == max_num_models: 63 | break 64 | 65 | if swa_c_epochs is not None and epoch % swa_c_epochs == 0: 66 | swag_model.collect_model(model) 67 | 68 | return swag_model 69 | 70 | 71 | def fit_swag_and_precompute_bn_params(model, device, train_loader, max_num_models, swa_lr, swa_c_epochs, swa_c_batches, parallel, n_samples, bn_update_subset): 72 | """ fit SWAG model on training data and pre-compute SWAG weight samples and corresponding BatchNorm parameters """ 73 | 74 | # fit SWAG model on training data 75 | nll_fun = torch.nn.CrossEntropyLoss(reduction='mean') 76 | swag_model = fit_swag(copy.deepcopy(model), device, train_loader, nll_fun, 77 | diag_only=False, max_num_models=max_num_models, 78 | swa_lr=swa_lr, swa_c_epochs=swa_c_epochs, 79 | swa_c_batches=swa_c_batches, parallel=parallel) 80 | swag_model.base = swag_model.base.to(device) 81 | 82 | # pre-compute SWAG weight samples and corresponding BatchNorm parameters for every component 83 | swag_samples = [swag_model.sample() for _ in range(n_samples)] 84 | swag_bn_params = [] 85 | for i, sample in enumerate(swag_samples): 86 | print(f"Computing BatchNorm statistics for SWAG sample #{i+1}...") 87 | swag_model.set_model_parameters(sample) 88 | swag_bn_params.append(bn_update(train_loader, swag_model, verbose=True, subset=bn_update_subset)) 89 | 90 | return swag_model, swag_samples, swag_bn_params 91 | 92 | 93 | def predict_swag(swag_model, x, swag_samples, swag_bn_params): 94 | """ Make predictions with SWAG on a single data batch (x, y) """ 95 | 96 | 
swag_model.eval() 97 | swag_model.base.eval() 98 | 99 | out = 0. 100 | for sample, bn_param in zip(swag_samples, swag_bn_params): 101 | # set sampled model weights and update BatchNorm statistics 102 | swag_model.set_model_parameters(sample) 103 | apply_bn_update(swag_model, bn_param) 104 | f_s = swag_model(x).detach() 105 | out += torch.softmax(f_s, dim=1) 106 | out /= len(swag_samples) 107 | 108 | return out 109 | 110 | 111 | class SWAG(torch.nn.Module): 112 | def __init__(self, base, no_cov_mat=True, max_num_models=0, var_clamp=1e-30, mask=None): 113 | super(SWAG, self).__init__() 114 | 115 | self.register_buffer("n_models", torch.zeros([1], dtype=torch.long)) 116 | self.params = list() 117 | 118 | self.no_cov_mat = no_cov_mat 119 | self.max_num_models = max_num_models 120 | 121 | self.var_clamp = var_clamp 122 | self.mask = mask 123 | 124 | self.base = base 125 | self.init_swag_parameters(params=self.params, no_cov_mat=self.no_cov_mat, mask=self.mask) 126 | #self.base.apply(lambda module: swag_parameters(module=module, params=self.params, no_cov_mat=self.no_cov_mat, only_nonzero=self.only_nonzero)) 127 | 128 | def forward(self, *args, **kwargs): 129 | return self.base(*args, **kwargs) 130 | 131 | def init_swag_parameters(self, params, no_cov_mat=True, mask=None): 132 | for mod_name, module in self.base.named_modules(): 133 | for name in list(module._parameters.keys()): 134 | if module._parameters[name] is None: 135 | continue 136 | 137 | name_full = f"{mod_name}.{name}".replace(".", "-") 138 | data = module._parameters[name].data 139 | module._parameters.pop(name) 140 | module.register_buffer("%s_mean" % name_full, data.new(data.size()).zero_()) 141 | module.register_buffer("%s_sq_mean" % name_full, data.new(data.size()).zero_()) 142 | 143 | if no_cov_mat is False: 144 | if mask and name_full.replace("-", ".") in mask: 145 | data = data[mask[name_full.replace("-", ".")].nonzero(as_tuple=True)] 146 | module.register_buffer("%s_cov_mat_sqrt" % name_full, data.new_empty((0, data.numel())).zero_()) 147 | 148 | params.append((module, name_full)) 149 | 150 | def get_mean_vector(self, batchnorm_layers, mask=None): 151 | mean_list = [] 152 | for module, name in self.params: 153 | name_full = name.replace("-", ".") 154 | if 'weight' in name_full and name_full not in batchnorm_layers: 155 | mean = module.__getattr__("%s_mean" % name) 156 | if mask is not None: 157 | mean = mean[mask[name_full].nonzero(as_tuple=True)] 158 | mean_list.append(mean.cpu()) 159 | return flatten(mean_list) 160 | 161 | def get_variance_vector(self, batchnorm_layers, mask=None): 162 | mean_list = [] 163 | sq_mean_list = [] 164 | 165 | for module, name in self.params: 166 | name_full = name.replace("-", ".") 167 | if 'weight' in name_full and name_full not in batchnorm_layers: 168 | mean = module.__getattr__("%s_mean" % name) 169 | sq_mean = module.__getattr__("%s_sq_mean" % name) 170 | 171 | if mask is not None: 172 | mean = mean[mask[name_full].nonzero(as_tuple=True)] 173 | sq_mean = sq_mean[mask[name_full].nonzero(as_tuple=True)] 174 | 175 | mean_list.append(mean.cpu()) 176 | sq_mean_list.append(sq_mean.cpu()) 177 | 178 | mean = flatten(mean_list) 179 | sq_mean = flatten(sq_mean_list) 180 | 181 | variances = torch.clamp(sq_mean - mean ** 2, self.var_clamp) 182 | 183 | return variances 184 | 185 | def get_covariance_matrix(self, batchnorm_layers, eps=1e-10): 186 | if self.no_cov_mat: 187 | raise RuntimeError("No covariance matrix was estimated!") 188 | 189 | cov_mat_sqrt_list = [] 190 | for module, name in self.params: 
191 | name_full = name.replace("-", ".") 192 | if 'weight' in name_full and name_full not in batchnorm_layers: 193 | cov_mat_sqrt = module.__getattr__("%s_cov_mat_sqrt" % name) 194 | cov_mat_sqrt_list.append(cov_mat_sqrt.cpu()) 195 | 196 | # build low-rank covariance matrix 197 | cov_mat_sqrt = torch.cat(cov_mat_sqrt_list, dim=1) 198 | print(cov_mat_sqrt.shape) 199 | cov_mat = torch.matmul(cov_mat_sqrt.t(), cov_mat_sqrt) 200 | cov_mat /= (self.max_num_models - 1) 201 | print(cov_mat.shape) 202 | 203 | # obtain covariance matrix by adding variances (+ eps for numerical stability) to diagonal and scaling 204 | var = self.get_variance_vector(batchnorm_layers, mask=self.mask) + eps 205 | cov_mat.add_(torch.diag(var)).mul_(0.5) 206 | 207 | return cov_mat 208 | 209 | def sample(self, scale=0.5, cov=True, seed=None): 210 | if seed is not None: 211 | torch.manual_seed(seed) 212 | 213 | mean_list = [] 214 | sq_mean_list = [] 215 | if cov: 216 | cov_mat_sqrt_list = [] 217 | 218 | for (module, name) in self.params: 219 | mean_list.append(module.__getattr__("%s_mean" % name).cpu()) 220 | sq_mean_list.append(module.__getattr__("%s_sq_mean" % name).cpu()) 221 | if cov: 222 | cov_mat_sqrt_list.append(module.__getattr__("%s_cov_mat_sqrt" % name).cpu()) 223 | 224 | mean = flatten(mean_list) 225 | sq_mean = flatten(sq_mean_list) 226 | 227 | # draw diagonal variance sample 228 | var = torch.clamp(sq_mean - mean ** 2, self.var_clamp) 229 | rand_sample = var.sqrt() * torch.randn_like(var, requires_grad=False) 230 | 231 | # if covariance draw low rank sample 232 | if cov: 233 | cov_mat_sqrt = torch.cat(cov_mat_sqrt_list, dim=1) 234 | eps = cov_mat_sqrt.new_empty((cov_mat_sqrt.size(0),), requires_grad=False).normal_() 235 | cov_sample = cov_mat_sqrt.t().matmul(eps) 236 | cov_sample /= (self.max_num_models - 1) ** 0.5 237 | rand_sample += cov_sample 238 | 239 | # update sample with mean and scale 240 | sample = (mean + scale**0.5 * rand_sample).unsqueeze(0) 241 | 242 | # unflatten new sample like the mean sample 243 | samples_list = unflatten_like(sample, mean_list) 244 | self.set_model_parameters(samples_list) 245 | 246 | return samples_list 247 | 248 | def set_model_parameters(self, parameter_list): 249 | for (module, name), param in zip(self.params, parameter_list): 250 | module.__setattr__(name.split("-")[-1], param.cuda()) 251 | 252 | def collect_model(self, base_model): 253 | for (module, name), base_param in zip(self.params, base_model.parameters()): 254 | data = base_param.data 255 | 256 | mean = module.__getattr__("%s_mean" % name) 257 | sq_mean = module.__getattr__("%s_sq_mean" % name) 258 | 259 | # first moment 260 | mean = mean * self.n_models.item() / ( 261 | self.n_models.item() + 1.0 262 | ) + data / (self.n_models.item() + 1.0) 263 | 264 | # second moment 265 | sq_mean = sq_mean * self.n_models.item() / ( 266 | self.n_models.item() + 1.0 267 | ) + data ** 2 / (self.n_models.item() + 1.0) 268 | 269 | # square root of covariance matrix 270 | if self.no_cov_mat is False: 271 | cov_mat_sqrt = module.__getattr__("%s_cov_mat_sqrt" % name) 272 | 273 | # block covariance matrices, store deviation from current mean 274 | dev = (data - mean) 275 | name_full = name.replace("-", ".") 276 | if self.mask and name_full in self.mask: 277 | dev = dev[self.mask[name_full].nonzero(as_tuple=True)] 278 | cov_mat_sqrt = torch.cat((cov_mat_sqrt, dev.view(-1, 1).t()), dim=0) 279 | 280 | # remove first column if we have stored too many models 281 | if (self.n_models.item() + 1) > self.max_num_models: 282 | cov_mat_sqrt = 
cov_mat_sqrt[1:, :] 283 | module.__setattr__("%s_cov_mat_sqrt" % name, cov_mat_sqrt) 284 | 285 | module.__setattr__("%s_mean" % name, mean) 286 | module.__setattr__("%s_sq_mean" % name, sq_mean) 287 | self.n_models.add_(1) 288 | 289 | def load_state_dict(self, state_dict, strict=False): 290 | if not self.no_cov_mat: 291 | n_models = state_dict["n_models"].item() 292 | rank = min(n_models, self.max_num_models) 293 | for module, name in self.params: 294 | mean = module.__getattr__("%s_mean" % name) 295 | module.__setattr__( 296 | "%s_cov_mat_sqrt" % name, 297 | mean.new_empty((rank, mean.numel())).zero_(), 298 | ) 299 | super(SWAG, self).load_state_dict(state_dict, strict) 300 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import time 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import pyro.distributions as dist 8 | import scipy.stats as st 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.nn.functional as F 12 | import torchvision.models as torch_models 13 | from laplace.curvature import AsdlEF, AsdlGGN, BackPackEF, BackPackGGN 14 | from pyro.infer import SVI, Trace_ELBO 15 | from pyro.optim import ClippedAdam 16 | from sklearn.metrics import mean_squared_error, roc_auc_score 17 | from torch.nn.utils import parameters_to_vector 18 | 19 | import utils.wilds_utils as wu 20 | from baselines.bbb.models import lenet as lenet_bbb 21 | from baselines.bbb.models import wrn as wrn_bbb 22 | from baselines.vanilla.models import lenet, mlp, wrn, wrn_fixup 23 | from models import autoguide, network 24 | 25 | 26 | def set_seed(seed): 27 | np.random.seed(seed) 28 | torch.manual_seed(seed) 29 | if torch.cuda.is_available(): 30 | torch.cuda.manual_seed(seed) 31 | cudnn.deterministic = True 32 | cudnn.benchmark = False 33 | 34 | 35 | def load_pretrained_model(args, model_idx, n_classes, device): 36 | """ Choose appropriate architecture and load pre-trained weights """ 37 | 38 | if 'WILDS' in args.benchmark: 39 | dataset = args.benchmark[6:] 40 | model = wu.load_pretrained_wilds_model(dataset, args.models_root, 41 | device, model_idx, args.model_seed) 42 | else: 43 | model = get_model(args.model, n_classes, no_dropout=args.no_dropout) 44 | if args.benchmark in ['R-MNIST', 'MNIST-OOD']: 45 | fpath = os.path.join(args.models_root, 'lenet_mnist/lenet_mnist_{}_{}') 46 | elif args.benchmark in ['R-FMNIST', 'FMNIST-OOD']: 47 | if args.method == 'csghmc': 48 | fpath = os.path.join(args.models_root, 'csghmc/lenet_fmnist/lenet_fmnist_{}_1') 49 | elif args.method == 'bbb': 50 | fpath = os.path.join(args.models_root, 'bbb/flipout/lenet_fmnist/lenet_fmnist_{}_1') 51 | else: 52 | subset_of_weights = 'll' if args.subset_of_weights == 'last_layer' else 'al' 53 | fpath = os.path.join(args.models_root, f'fmnist/{subset_of_weights}/' + 'map_{}.pt') 54 | elif args.benchmark in ['CIFAR-10-C', 'CIFAR-10-OOD']: 55 | if args.method == 'csghmc': 56 | fpath = os.path.join(args.models_root, 'csghmc/wrn_16-4_cifar10/wrn_16-4_cifar10_{}_1') 57 | elif args.method == 'bbb': 58 | fpath = os.path.join(args.models_root, 'bbb/flipout/wrn_16-4_cifar10/wrn_16-4_cifar10_{}_1') 59 | else: 60 | fpath = os.path.join(args.models_root, 'cifar10/ll/map_{}.pt') 61 | elif args.benchmark == 'CIFAR-100-OOD': 62 | if args.method == 'csghmc': 63 | fpath = os.path.join(args.models_root, 'csghmc/wrn_16-4_cifar100/wrn_16-4_cifar100_{}_1') 64 | elif 
args.method == 'bbb': 65 | fpath = os.path.join(args.models_root, 'bbb/flipout/wrn_16-4_cifar100/wrn_16-4_cifar100_{}_1') 66 | else: 67 | fpath = os.path.join(args.models_root, 'cifar100_new/ll/map_{}.pt') 68 | elif args.benchmark == 'ImageNet-C': 69 | fpath = os.path.join(args.models_root, 'wrn50-2_imagenet/wrn_50-2_imagenet_{}_{}') 70 | 71 | if args.method == 'csghmc': 72 | fname = fpath.format(args.model_seed) 73 | state_dicts = torch.load(fname, map_location=device) 74 | 75 | for m, state_dict in zip(model, state_dicts): 76 | m.load_state_dict(state_dict) 77 | m.to(device) 78 | 79 | if args.data_parallel: 80 | model = [torch.nn.DataParallel(m) for m in model] 81 | else: 82 | if args.model_path is not None: 83 | model.load_state_dict(torch.load(args.model_path, map_location=device), 84 | strict=False) 85 | else: 86 | model_seed = args.model_seed 87 | if args.method == 'ensemble': 88 | model_seed = args.nr_components * model_seed - model_idx 89 | fname = (fpath.format(args.model_seed, model_idx+1) if 'FMNIST' not in args.benchmark 90 | and 'CIFAR' not in args.benchmark else fpath.format(model_seed)) 91 | load_model = ( 92 | model.net 93 | if args.benchmark in ['R-MNIST', 'MNIST-OOD'] and 'BBB' not in args.model 94 | else model 95 | ) 96 | load_model.load_state_dict(torch.load(fname, map_location=device), strict=False) 97 | model.to(device) 98 | 99 | if args.data_parallel and (args.method != 'csghmc'): 100 | model = torch.nn.DataParallel(model) 101 | return model 102 | 103 | 104 | def get_model(model_class, n_classes=10, no_dropout=False): 105 | if model_class == 'MLP': 106 | model = mlp.MLP([784, 100, 100, n_classes], act='relu') 107 | elif model_class == 'FMNIST-MLP': 108 | model = network.MLP(784, n_hiddens=50) 109 | elif model_class == 'LeNet': 110 | model = network.LeNet() 111 | elif model_class == 'LeNet-BBB-reparam': 112 | model = lenet_bbb.LeNetBBB(estimator='reparam') 113 | elif model_class == 'LeNet-BBB-flipout': 114 | model = lenet_bbb.LeNetBBB(estimator='flipout') 115 | elif model_class == 'LeNet-CSGHMC': 116 | model = [lenet.LeNet() for _ in range(12)] # 12 samples in CSGHMC 117 | elif model_class == 'WRN16-4': 118 | model = wrn.WideResNet(16, 4, n_classes, dropRate=0.3) 119 | elif model_class == 'WRN16-4-fixup': 120 | model = wrn_fixup.FixupWideResNet(16, 4, n_classes, dropRate=0.0 if no_dropout else 0.2) 121 | elif model_class == 'WRN16-4-BBB-reparam': 122 | model = wrn_bbb.WideResNetBBB(16, 4, n_classes, estimator='reparam') 123 | elif model_class == 'WRN16-4-BBB-flipout': 124 | model = wrn_bbb.WideResNetBBB(16, 4, n_classes, estimator='flipout') 125 | elif model_class == 'WRN16-4-CSGHMC': 126 | model = [wrn.WideResNet(16, 4, n_classes, dropRate=0) for _ in range(12)] # 12 samples in CSGHMC 127 | elif model_class == 'WRN50-2': 128 | model = torch_models.wide_resnet50_2() 129 | else: 130 | raise ValueError('Choose LeNet, WRN16-4, or WRN50-2 as model_class.') 131 | return model 132 | 133 | 134 | def mixture_model_pred(components, x, mixture_weights, prediction_mode='mola', 135 | pred_type='glm', link_approx='probit', n_samples=100, 136 | likelihood='classification'): 137 | if prediction_mode == 'ensemble': 138 | return ensemble_pred(components, x, likelihood=likelihood) 139 | 140 | out = 0. 
# out will be a tensor 141 | for model, pi in zip(components, mixture_weights): 142 | if prediction_mode == 'mola': 143 | out_prob = model(x, pred_type=pred_type, n_samples=n_samples, link_approx=link_approx) 144 | elif prediction_mode == 'multi-swag': 145 | from baselines.swag.swag import predict_swag 146 | swag_model, swag_samples, swag_bn_params = model 147 | out_prob = predict_swag(swag_model, x, swag_samples, swag_bn_params) 148 | else: 149 | raise ValueError('For now only ensemble, mola, and multi-swag are supported.') 150 | out += pi * out_prob 151 | return out 152 | 153 | 154 | def ensemble_pred(components, x, likelihood='classification'): 155 | """ Make predictions for deep ensemble """ 156 | 157 | outs = [] 158 | for model in components: 159 | model.eval() 160 | out_prob = model(x).detach() 161 | if likelihood == 'classification': 162 | out_prob = out_prob.softmax(1) 163 | outs.append(out_prob) 164 | 165 | outs = torch.stack(outs, dim=0) 166 | out_mean = torch.mean(outs, dim=0) 167 | 168 | if likelihood == 'regression': 169 | out_var = torch.var(outs, dim=0).unsqueeze(2) 170 | return [out_mean, out_var] 171 | else: 172 | return out_mean 173 | 174 | 175 | def get_backend(backend, approx_type): 176 | if backend == 'kazuki': 177 | if approx_type == 'ggn': 178 | return AsdlGGN 179 | else: 180 | return AsdlEF 181 | elif backend == 'backpack': 182 | if approx_type == 'ggn': 183 | return BackPackGGN 184 | else: 185 | return BackPackEF 186 | else: 187 | raise ValueError('Choose a valid combination of backend and approx_type') 188 | 189 | 190 | def expand_prior_precision(prior_prec, model): 191 | theta = parameters_to_vector(model.parameters()) 192 | device, P = theta.device, len(theta) 193 | assert prior_prec.ndim == 1 194 | if len(prior_prec) == 1: # scalar 195 | return torch.ones(P, device=device) * prior_prec 196 | elif len(prior_prec) == P: # full diagonal 197 | return prior_prec.to(device) 198 | else: 199 | return torch.cat([delta * torch.ones_like(m).flatten() for delta, m 200 | in zip(prior_prec, model.parameters())]) 201 | 202 | 203 | def prior_prec_to_tensor(args, prior_prec, model): 204 | H = len(list(model.parameters())) 205 | theta = parameters_to_vector(model.parameters()) 206 | device, P = theta.device, len(theta) 207 | if args.prior_structure == 'scalar': 208 | log_prior_prec = torch.ones(1, device=device) 209 | elif args.prior_structure == 'layerwise': 210 | log_prior_prec = torch.ones(H, device=device) 211 | elif args.prior_structure == 'all': 212 | log_prior_prec = torch.ones(P, device=device) 213 | else: 214 | raise ValueError(f'Invalid prior structure {args.prior_structure}') 215 | return log_prior_prec * prior_prec 216 | 217 | 218 | def get_auroc(py_in, py_out): 219 | py_in, py_out = py_in.cpu().numpy(), py_out.cpu().numpy() 220 | labels = np.zeros(len(py_in)+len(py_out), dtype='int32') 221 | labels[:len(py_in)] = 1 222 | examples = np.concatenate([py_in.max(1), py_out.max(1)]) 223 | return roc_auc_score(labels, examples) 224 | 225 | 226 | def get_fpr95(py_in, py_out): 227 | py_in, py_out = py_in.cpu().numpy(), py_out.cpu().numpy() 228 | conf_in, conf_out = py_in.max(1), py_out.max(1) 229 | tpr = 95 230 | perc = np.percentile(conf_in, 100-tpr) 231 | fp = np.sum(conf_out >= perc) 232 | fpr = np.sum(conf_out >= perc) / len(conf_out) 233 | return fpr.item(), perc.item() 234 | 235 | 236 | def get_brier_score(probs, targets): 237 | targets = F.one_hot(targets, num_classes=probs.shape[1]) 238 | return torch.mean(torch.sum((probs - targets)**2, axis=1)).item() 239 | 240 | 
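# The calibration helper below ("get_calib") bins predictions by confidence and compares
# per-bin accuracy with per-bin mean confidence:
#     ECE = sum_b (n_b / n) * |acc_b - conf_b|  with n = sum_b n_b,    MCE = max_b |acc_b - conf_b|.
# A minimal usage sketch (here `probs` is assumed to be an (N, K) tensor of softmax outputs and
# `targets` an (N,) tensor of integer labels; both names are placeholders):
#
#     ece, mce = get_calib(probs, targets, M=100)
#     brier = get_brier_score(probs, targets)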
241 | def get_calib(pys, y_true, M=100):
242 |     pys, y_true = pys.cpu().numpy(), y_true.cpu().numpy()
243 |     # Put the confidence into M bins
244 |     _, bins = np.histogram(pys, M, range=(0, 1))
245 | 
246 |     labels = pys.argmax(1)
247 |     confs = np.max(pys, axis=1)
248 |     conf_idxs = np.digitize(confs, bins)
249 | 
250 |     # Accuracy and avg. confidence per bin
251 |     accs_bin = []
252 |     confs_bin = []
253 |     nitems_bin = []
254 | 
255 |     for i in range(M):
256 |         labels_i = labels[conf_idxs == i]
257 |         y_true_i = y_true[conf_idxs == i]
258 |         confs_i = confs[conf_idxs == i]
259 | 
260 |         acc = np.nan_to_num(np.mean(labels_i == y_true_i), nan=0.0)
261 |         conf = np.nan_to_num(np.mean(confs_i), nan=0.0)
262 | 
263 |         accs_bin.append(acc)
264 |         confs_bin.append(conf)
265 |         nitems_bin.append(len(labels_i))
266 | 
267 |     accs_bin, confs_bin = np.array(accs_bin), np.array(confs_bin)
268 |     nitems_bin = np.array(nitems_bin)
269 | 
270 |     ECE = np.average(np.abs(confs_bin-accs_bin), weights=nitems_bin/nitems_bin.sum())
271 |     MCE = np.max(np.abs(accs_bin - confs_bin))
272 | 
273 |     return ECE, MCE
274 | 
275 | 
276 | def get_calib_regression(pred_means, pred_stds, y_true, return_hist=False, M=10):
277 |     '''
278 |     Kuleshov et al. ICML 2018, eq. 9
279 |     * pred_means, pred_stds, y_true must be np.array's
280 |     * Set return_hist to True to also return the "histogram"---useful for visualization (see paper)
281 |     '''
282 |     T = len(y_true)
283 |     ps = np.linspace(0, 1, M)
284 |     cdf_vals = [st.norm(m, s).cdf(y_t) for m, s, y_t in zip(pred_means, pred_stds, y_true)]
285 |     p_hats = np.array([len(np.where(cdf_vals <= p)[0]) / T for p in ps])
286 |     cal = T*mean_squared_error(ps, p_hats)  # Sum-squared-error
287 | 
288 |     return (cal, ps, p_hats) if return_hist else cal
289 | 
290 | 
291 | def get_sharpness(pred_stds):
292 |     '''
293 |     Kuleshov et al. ICML 2018, eq. 10
294 | 
295 |     pred_stds must be an np.array
296 |     '''
297 |     return np.mean(pred_stds**2)
298 | 
299 | 
300 | def timing(fun):
301 |     """
302 |     Return the original output(s) and the wall-clock time in seconds.
303 | """ 304 | if torch.cuda.is_available(): 305 | start = torch.cuda.Event(enable_timing=True) 306 | end = torch.cuda.Event(enable_timing=True) 307 | torch.cuda.synchronize() 308 | start.record() 309 | ret = fun() 310 | end.record() 311 | torch.cuda.synchronize() 312 | elapsed_time = start.elapsed_time(end)/1000 313 | else: 314 | start_time = time.time() 315 | ret = fun() 316 | end_time = time.time() 317 | elapsed_time = end_time - start_time 318 | return ret, elapsed_time 319 | 320 | 321 | def save_results(args, metrics): 322 | """ Save the computed metrics """ 323 | 324 | if args.run_name is None: 325 | res_str = f'_{args.subset_of_weights}_{args.hessian_structure}' if args.method in ['laplace', 'mola'] else '' 326 | temp_str = '' if args.temperature == 1.0 else f'_{args.temperature}' 327 | method_str = f'temp' if args.use_temperature_scaling and args.method == 'map' else args.method 328 | frac_str = f'_{args.data_fraction}' if args.data_fraction < 1.0 else '' 329 | layer_str = 'al_' if args.subset_of_weights != 'last_layer' and args.compute_mmd else '' 330 | result_path = f'./results/{args.benchmark}/{layer_str}{method_str}{res_str}{temp_str}_{args.model_seed}{frac_str}.npy' 331 | else: 332 | result_path = f'./results/{args.run_name}.npy' 333 | Path(result_path).parent.mkdir(parents=True, exist_ok=True) 334 | 335 | print(f"Saving results to {result_path}...") 336 | np.save(result_path, metrics) 337 | 338 | 339 | def get_prior_precision(args, device): 340 | """ Obtain the prior precision parameter from the cmd arguments """ 341 | 342 | if type(args.prior_precision) is str: # file path 343 | prior_precision = torch.load(args.prior_precision, map_location=device) 344 | elif type(args.prior_precision) is float: 345 | prior_precision = args.prior_precision 346 | else: 347 | raise ValueError('Algorithm not happy with inputs prior precision :(') 348 | 349 | return prior_precision 350 | 351 | def count_params(net): 352 | return sum(p.numel() for p in net.parameters()) 353 | 354 | 355 | def bias_trick(x): 356 | """ 357 | x is (batch_size, n_features) 358 | """ 359 | ones = torch.ones([x.shape[0], 1], device=x.device) 360 | return torch.cat([x, ones], -1) 361 | 362 | 363 | def vector_to_parameters_backpropable(vec, net): 364 | # Pointer for slicing the vector for each parameter 365 | pointer = 0 366 | for mod in net.children(): 367 | if isinstance(mod, torch.nn.Linear) or isinstance(mod, torch.nn.Conv2d): 368 | weight_size, bias_size = mod.weight.shape, mod.bias.shape 369 | weight_numel, bias_numel = mod.weight.numel(), mod.bias.numel() 370 | 371 | del mod.weight 372 | del mod.bias 373 | 374 | mod.weight = vec[pointer:pointer+weight_numel].reshape(weight_size) 375 | pointer += weight_numel 376 | 377 | mod.bias = vec[pointer:pointer+bias_numel].reshape(bias_size) 378 | pointer += bias_numel 379 | 380 | 381 | @torch.no_grad() 382 | def hetero_to_homo(model_ori, model_mean, model_std): 383 | # Copy params from model but split the "mean branch" and "std branch" 384 | for p_ori, p_mean, p_std in zip(model_ori.parameters(), model_mean.parameters(), model_std.parameters()): 385 | if p_ori.shape == p_mean.shape and p_ori.shape == p_std.shape: 386 | p_mean.data.copy_(p_ori.data) 387 | p_std.data.copy_(p_ori.data) 388 | else: 389 | # Only take weights corresponding to the "mean branch" and "std branch" resp. 
390 | if len(p_ori.data.shape) == 2: # weight 391 | p_mean.data.copy_(p_ori.data[0].unsqueeze(0)) 392 | p_std.data.copy_(p_ori.data[1].unsqueeze(0)) 393 | else: # bias 394 | p_mean.data.copy_(p_ori.data[0]) 395 | p_std.data.copy_(p_ori.data[1]) 396 | 397 | return model_mean, model_std 398 | 399 | 400 | def load_nf_guide(model, state_dict, X_train, y_train, diag=False, cuda=False, method='refine'): 401 | assert method in ['refine', 'nf_naive'] 402 | 403 | if method == 'refine': 404 | var_key = 'base_dist_Cov' if not diag else 'base_dist_var' 405 | 406 | guide = autoguide.AutoNormalizingFlowCustom( 407 | model, state_dict['base_dist_mean'], state_dict[var_key], diag=diag, 408 | flow_type=state_dict['flow_type'], flow_len=state_dict['flow_len'], cuda=cuda 409 | ) 410 | else: 411 | nf = dist.transforms.planar if state_dict['flow_type'] == 'planar' else dist.transforms.radial 412 | transform_init = functools.partial(dist.transforms.iterated, state_dict['flow_len'], nf) 413 | guide = autoguide.AutoNormalizingFlowCuda(model, transform_init, cuda=cuda) 414 | 415 | # Do a single step SVI to initialize the guide 416 | svi = SVI(model, guide, optim=ClippedAdam({'lr': 1e-3}), loss=Trace_ELBO()) 417 | 418 | if cuda: 419 | X_train, y_train = X_train.cuda(), y_train.cuda() 420 | 421 | svi.step(X_train, y_train) 422 | 423 | # Load the saved params 424 | guide.load_state_dict(state_dict['state_dict']) 425 | 426 | return guide 427 | 428 | 429 | @torch.no_grad() 430 | def get_features(model, data_loader, cuda=True): 431 | res_x, res_y = [], [] 432 | 433 | for batch in data_loader: 434 | if len(batch) == 2: # Non-text data 435 | x, y = batch 436 | else: 437 | x = batch.text.t() 438 | y = batch.label - 1 439 | 440 | if cuda: 441 | x = x.cuda() 442 | y = y.cuda() 443 | 444 | res_x.append(model.forward_features(x)) 445 | res_y.append(y) 446 | 447 | return torch.cat(res_x, dim=0), torch.cat(res_y, dim=0) 448 | 449 | 450 | @torch.no_grad() 451 | def predict(model, test_loader, softmax=True, cuda=True): 452 | y_pred, y_true = [], [] 453 | 454 | for x, y in test_loader: 455 | if cuda: 456 | x, y = x.cuda(), y.cuda() 457 | 458 | out = torch.softmax(model(x), -1) if softmax else model(x) 459 | y_pred.append(out) 460 | y_true.append(y) 461 | 462 | return torch.cat(y_pred, 0), torch.cat(y_true, 0) 463 | 464 | 465 | @torch.no_grad() 466 | def predict_la(model, test_loader, pred_type='nn', link_approx='probit', n_samples=10, cuda=True, vectorize_x=True): 467 | y_pred, y_true = [], [] 468 | 469 | for x, y in test_loader: 470 | if vectorize_x: 471 | x = x.flatten(1) 472 | 473 | if cuda: 474 | x, y = x.cuda(), y.cuda() 475 | 476 | y_pred.append(model(x, pred_type=pred_type, link_approx=link_approx, n_samples=n_samples)) 477 | y_true.append(y) 478 | 479 | return torch.cat(y_pred, 0), torch.cat(y_true, 0) 480 | 481 | 482 | @torch.no_grad() 483 | def predict_pyro(predictive, test_loader, cuda=True): 484 | y_pred, y_true = [], [] 485 | 486 | for x, y in test_loader: 487 | if cuda: 488 | x = x.cuda() 489 | 490 | y_pred.append(torch.softmax(predictive(x)['_RETURN'], -1).mean(0)) 491 | y_true.append(y) 492 | 493 | return torch.cat(y_pred, 0), torch.cat(y_true, 0) 494 | 495 | 496 | @torch.no_grad() 497 | def predict_pyro_ll(components, test_loader): 498 | all_y_prob = list() 499 | for predictive, net in components: 500 | # Cache last-layer features 501 | features_test, all_y_true = get_features(net, test_loader) 502 | all_y_prob.append(torch.softmax( 503 | predictive(features_test, full_batch=True, X_is_features=True)['_RETURN'], 
-1)) 504 | all_y_prob = torch.cat(all_y_prob).mean(0) 505 | return all_y_prob, all_y_true 506 | -------------------------------------------------------------------------------- /utils/wilds_utils.py: -------------------------------------------------------------------------------- 1 | """ Utility methods for WILDS benchmark experiments """ 2 | 3 | import urllib.request 4 | from argparse import Namespace 5 | from pathlib import Path 6 | 7 | import torch 8 | from torch import nn 9 | 10 | try: 11 | from configs.utils import populate_defaults 12 | from examples.models.initializer import initialize_model 13 | from examples.transforms import initialize_transform 14 | from wilds import get_dataset 15 | from wilds.common.data_loaders import get_eval_loader, get_train_loader 16 | from wilds.common.grouper import CombinatorialGrouper 17 | except ModuleNotFoundError: 18 | print('WILDS library/dependencies not found -- please install following https://github.com/p-lambda/wilds.') 19 | 20 | 21 | D_OUTS = {"camelyon17": 2, "amazon": 5, "civilcomments": 2, "poverty": 1, "fmow": 62} 22 | 23 | MODEL_URL = 'https://worksheets.codalab.org/rest/bundles/%s/contents/blob/best_model.pth' 24 | 25 | N_SEEDS = {"camelyon17": 10, "amazon": 3, "civilcomments": 5, "poverty": 5, "fmow": 3} 26 | 27 | POVERTY_FOLDS = ['A', 'B', 'C', 'D', 'E'] 28 | 29 | ALGORITHMS = ['ERM', 'IRM', 'deepCORAL', 'groupDRO', 'ERM'] 30 | 31 | MODEL_UUIDS = { 32 | 'camelyon17': { 33 | 'camelyon17_erm_densenet121_seed0': '0x6029addd6f714167a4d34fb5351347c6', 34 | 'camelyon17_erm_densenet121_seed1': '0xb701f5de96064c0fa1771418da5df499', 35 | 'camelyon17_erm_densenet121_seed2': '0x2ce5ec845b07488fb3396ab1ab8e3e17', 36 | 'camelyon17_erm_densenet121_seed3': '0x70f110e8a86e4c3aa2688bc1267e6631', 37 | 'camelyon17_erm_densenet121_seed4': '0x0fe16428860749d6b94dfb1fe9ffe986', 38 | 'camelyon17_erm_densenet121_seed5': '0x0dc383dbf97a491fab9fb630c4119e3d', 39 | 'camelyon17_erm_densenet121_seed6': '0xb7884cbe61584e80bfadd160e1514570', 40 | 'camelyon17_erm_densenet121_seed7': '0x6f1aaa4697944b24af06db6a734f341e', 41 | 'camelyon17_erm_densenet121_seed8': '0x043be722cf50447d9b52d3afd5e55716', 42 | 'camelyon17_erm_densenet121_seed9': '0xc3ce3f5a89f84a84a1ef9a6a4a398109', 43 | 'camelyon17_irm_densenet121_seed0': '0xa63359a5bb1c449085f611f5940278d1', 44 | 'camelyon17_irm_densenet121_seed1': '0x71f860528a8b45b6bd0f0aa26906e6fc', 45 | 'camelyon17_irm_densenet121_seed2': '0x8184a0b3a1d54cf895ce4d36db9110d0', 46 | 'camelyon17_irm_densenet121_seed3': '0xc5fd2d287a6c4f94a424e4025cd03d3f', 47 | 'camelyon17_irm_densenet121_seed4': '0xc1e5f84c7a05476fbcc9ebe98614e110', 48 | 'camelyon17_irm_densenet121_seed5': '0x29c4a95f9ca644f481de41aa167c8830', 49 | 'camelyon17_irm_densenet121_seed6': '0x02c51a59e380417ba516a3b56688c4d3', 50 | 'camelyon17_irm_densenet121_seed7': '0x5e6bfa1e641d4ecd99de2361290209d3', 51 | 'camelyon17_irm_densenet121_seed8': '0x1a0ac11aaeeb4a9495c37b6ab06331c9', 52 | 'camelyon17_irm_densenet121_seed9': '0x0ce8a0a5c8be4da7ad47b1120554c62d', 53 | 'camelyon17_deepCORAL_densenet121_seed0': '0x7966e810326842deb2377bf5f36fb60d', 54 | 'camelyon17_deepCORAL_densenet121_seed1': '0x9d9caa8232d846c3a7ca30718e232157', 55 | 'camelyon17_deepCORAL_densenet121_seed2': '0x8b901447f8714621b1047423844ecd37', 56 | 'camelyon17_deepCORAL_densenet121_seed3': '0xa8f8a5bad2cc4514afe06b997f9fd648', 57 | 'camelyon17_deepCORAL_densenet121_seed4': '0xecb7f3748c9e4640a5a0b47b54977a24', 58 | 'camelyon17_deepCORAL_densenet121_seed5': '0xee62fda4353b42a48a127d374d0f1613', 59 | 
'camelyon17_deepCORAL_densenet121_seed6': '0x98bffe597d264f06af4ca817a01c53fa', 60 | 'camelyon17_deepCORAL_densenet121_seed7': '0x621af9d733234b6db1de187425b8457e', 61 | 'camelyon17_deepCORAL_densenet121_seed8': '0x717afb8719b141a8adeeb634ecbed1a3', 62 | 'camelyon17_deepCORAL_densenet121_seed9': '0xa7731d1d205e4b51a545e75768fe7ea1', 63 | 'camelyon17_groupDRO_densenet121_seed0': '0x583b462eaef54d93ac03f50d210d0adf', 64 | 'camelyon17_groupDRO_densenet121_seed1': '0x296560e13e60464e9bbd8b637df21594', 65 | 'camelyon17_groupDRO_densenet121_seed2': '0xd13b972ca6c442d5961292518ad1e89a', 66 | 'camelyon17_groupDRO_densenet121_seed3': '0x4b031eeb625f47e09b03a801b6fe90d9', 67 | 'camelyon17_groupDRO_densenet121_seed4': '0x8ea2d8ba9f514e56ab6030ffc07c2735', 68 | 'camelyon17_groupDRO_densenet121_seed5': '0x72e318b10a9f4fdf974f775b453ccb58', 69 | 'camelyon17_groupDRO_densenet121_seed6': '0xeea6090106c9458eab1c3aa91e5db63b', 70 | 'camelyon17_groupDRO_densenet121_seed7': '0x5b62bff7317249bdb91df2137cf5f6f0', 71 | 'camelyon17_groupDRO_densenet121_seed8': '0x64852365891f499c946461e842a7b5dc', 72 | 'camelyon17_groupDRO_densenet121_seed9': '0x2659f54957e144809b6f4f5ffe6ddbfb', 73 | 'camelyon17_erm_ID_seed0': '0xa46957fa425f4168a9e6fbfa500d2d4f', 74 | 'camelyon17_erm_ID_seed1': '0x3abd16ec8af7498d9ea1ff63175b5c76', 75 | 'camelyon17_erm_ID_seed2': '0x8119b73f481a4b3c904f227a1305fb88', 76 | 'camelyon17_erm_ID_seed3': '0x480d53a5654543a39fe2ae9296c30304', 77 | 'camelyon17_erm_ID_seed4': '0xec0885b489ad4bc2bc7b8958b08af824', 78 | 'camelyon17_erm_ID_seed5': '0x8b9d4c81b59149b7a3c50a255d0d3a6b', 79 | 'camelyon17_erm_ID_seed6': '0x017048173dd74e9cb47779f3d0534024', 80 | 'camelyon17_erm_ID_seed7': '0x1234db7106d24f94982d176b14c86d1c', 81 | 'camelyon17_erm_ID_seed8': '0xedf8885660a647f49a324b6cece94f15', 82 | 'camelyon17_erm_ID_seed9': '0x90884c39ef114925bd78d6c2e7d1acc3', 83 | }, 84 | 'civilcomments': { 85 | 'civilcomments_distilbert_erm_seed0': '0x17807ae09e364ec3b2680d71ca3d9623', 86 | 'civilcomments_distilbert_erm_seed1': '0x0f6f161391c749beb1d0006238e145d0', 87 | 'civilcomments_distilbert_erm_seed2': '0xb92f899d126d4c6ba73f2730d76ca3e6', 88 | 'civilcomments_distilbert_erm_seed3': '0x090f8d901fad4bd7be5adb4f30e20271', 89 | 'civilcomments_distilbert_erm_seed4': '0x7a2e24652b8d4129bc67368864062bb4', 90 | 'civilcomments_distilbert_irm_seed0': '0x107e65f8c89642bcabe7628221dfa108', 91 | 'civilcomments_distilbert_irm_seed1': '0x6e46b06afff04441940d967126e4a353', 92 | 'civilcomments_distilbert_irm_seed2': '0x45db8e5cbec54c078c9dfb24cb907669', 93 | 'civilcomments_distilbert_irm_seed3': '0x84bb3f7240484b0abfc08f9c85abefeb', 94 | 'civilcomments_distilbert_irm_seed4': '0x7c572477e4bc4aa38679e8409a0504f9', 95 | 'civilcomments_distilbert_deepcoral_seed0': '0x272bffce865c42c5aad565c84fbaefdc', 96 | 'civilcomments_distilbert_deepcoral_seed1': '0xc04f7ffc47bc4544b552ddff0fcf2b5e', 97 | 'civilcomments_distilbert_deepcoral_seed2': '0x24faf0290c174d2e8be0048fd39de6a0', 98 | 'civilcomments_distilbert_deepcoral_seed3': '0x4681a19b29a6443a91bdc5dcc4c2047d', 99 | 'civilcomments_distilbert_deepcoral_seed4': '0x6d282a0b9f4e415bad269947f9d59710', 100 | 'civilcomments_distilbert_groupDRO_groupby-black-y_seed0': '0x3aeeb77983a444878cb75d7f642d6159', 101 | 'civilcomments_distilbert_groupDRO_groupby-black-y_seed1': '0x49a1ef33666f43998f46c5a1b5e6afc9', 102 | 'civilcomments_distilbert_groupDRO_groupby-black-y_seed2': '0xe5394cf75f4b4933b6527f51816f839c', 103 | 'civilcomments_distilbert_groupDRO_groupby-black-y_seed3': 
'0xc4385ed8a1e54a8c9fd6b4a1dd8130ab', 104 | 'civilcomments_distilbert_groupDRO_groupby-black-y_seed4': '0x99e2678b77f1479f88f2256af909f0cc', 105 | 'civilcomments_distilbert_erm_groupby-black-y_seed0': '0x87dbe66862a74a88a718a7c77399437e', 106 | 'civilcomments_distilbert_erm_groupby-black-y_seed1': '0x12cc6150c17d41b299d01e20cf7e9604', 107 | 'civilcomments_distilbert_erm_groupby-black-y_seed2': '0x41e52f9e8a43440fb8071a50dfda581d', 108 | 'civilcomments_distilbert_erm_groupby-black-y_seed3': '0xa90fa233330a42618ef4f2ee21238d3e', 109 | 'civilcomments_distilbert_erm_groupby-black-y_seed4': '0xb5c821486b5e43bfaa479852d4a09ac7', 110 | }, 111 | 'fmow': { 112 | 'fmow_erm_seed0': '0x63a3f824ac6745ea8e9061f736671304', 113 | 'fmow_erm_seed1': '0x2f8b1417709b4f2b8eec5ead67aa6203', 114 | 'fmow_erm_seed2': '0x8d7b4a78f9ba41b1a33c939e8280a156', 115 | 'fmow_irm_seed0': '0x86c1b425c76348f6972279c53862ead3', 116 | 'fmow_irm_seed1': '0xef80dd52c22a4fadb8f27827f2c0cc8e', 117 | 'fmow_irm_seed2': '0x5f855ccf76674e76bc9e0b17e97eccc4', 118 | 'fmow_deepcoral_seed0': '0x84da443d129a4fafa6b0485c60b2a125', 119 | 'fmow_deepcoral_seed1': '0x8f4651313f97465f8022f628daef9044', 120 | 'fmow_deepcoral_seed2': '0x1ef46a680072402b93186eaeb7bd8d55', 121 | 'fmow_groupDRO_seed0': '0xb9fcbbeaf44b4dc2b8c2870ef3b06c1e', 122 | 'fmow_groupDRO_seed1': '0xee70a2a62f5643d4b93054c42cb249d0', 123 | 'fmow_groupDRO_seed2': '0xafdf99b8f1d74685aa01adef79291fa6', 124 | 'fmow_erm_ID_seed0': '0x6836d02f9738458e95d0c320ee9282c4', 125 | 'fmow_erm_ID_seed1': '0x5b6ec2f1be7a4c76873b1db7db6887a4', 126 | 'fmow_erm_ID_seed2': '0x6f524be080bd4bee92152bec8d603444', 127 | }, 128 | 'poverty': { 129 | 'poverty_erm_foldA': '0xed9774bc15d14a31be7e57517989f8b7', 130 | 'poverty_erm_foldB': '0x30c0de563b694cc58e01d8abb48aa276', 131 | 'poverty_erm_foldC': '0xfc22dbce36be44fe80bddaed4ffb3ff4', 132 | 'poverty_erm_foldD': '0xcb986b1511e54a64bbb14f06be2e17a6', 133 | 'poverty_erm_foldE': '0xdd34b17f9b8b4ea2aa4d9f72ed8573f0', 134 | 'poverty_irm_foldA': '0xd0f659eda42f4da4a297667ae2e51b11', 135 | 'poverty_irm_foldB': '0xa22d1c64fe9244058a58ba3853106929', 136 | 'poverty_irm_foldC': '0xd0f659eda42f4da4a297667ae2e51b11', 137 | 'poverty_irm_foldD': '0xd0f659eda42f4da4a297667ae2e51b11', 138 | 'poverty_irm_foldE': '0xd0f659eda42f4da4a297667ae2e51b11', 139 | 'poverty_deepCORAL_foldA': '0x5b4458ef8b8f4bebbc75dc7f9d84b315', 140 | 'poverty_deepCORAL_foldB': '0xa48a6f8a725340919884389c6a1529d0', 141 | 'poverty_deepCORAL_foldC': '0x4159f07dc87c4640aa5d66aedc12f6c4', 142 | 'poverty_deepCORAL_foldD': '0x5e993f29628e453282c00523c59b9c11', 143 | 'poverty_deepCORAL_foldE': '0xe8374f3986f24fbda370d91593172204', 144 | 'poverty_groupDRO_foldA': '0x3f51b739d71440a6816ad4bd3522c7fc', 145 | 'poverty_groupDRO_foldB': '0x4b2c90d800a544c998397ca4a5594b16', 146 | 'poverty_groupDRO_foldC': '0x0adb692f84cc4200969d6b67b8130bb2', 147 | 'poverty_groupDRO_foldD': '0x84c5135b6212436d96eec3d4b6f09812', 148 | 'poverty_groupDRO_foldE': '0x795ec87eb5d848ec812112f0d18afc69', 149 | 'poverty_erm_ID_foldA': '0x89926224750f4d0193acb898c277433b', 150 | 'poverty_erm_ID_foldB': '0xeaea01c922bd4d8f85795759b28c6284', 151 | 'poverty_erm_ID_foldC': '0x4f28a08a0bc14b649a15ef9d4c854e2a', 152 | 'poverty_erm_ID_foldD': '0x3f53ee543ae3497c8fe25e8df245b3b7', 153 | 'poverty_erm_ID_foldE': '0xe93ce4c37ea94f60a6a2ddb96360883a', 154 | }, 155 | 'amazon': { 156 | 'amazonv2.0_erm_seed0': '0xe9fe4a12856f461193018504f8f65977', 157 | 'amazonv2.0_erm_seed1': '0xcbcb1b4c49c0486eacfb082ca22b8691', 158 | 'amazonv2.0_erm_seed2': 
'0xdf5298063529413eaf06654a5f83e4db', 159 | 'amazonv2.0_irm_seed0': '0x9dd41bedfca6410880f84d857303203d', 160 | 'amazonv2.0_irm_seed1': '0x16ab66a0c17e415cb9661779eac64ce2', 161 | 'amazonv2.0_irm_seed2': '0x204d0f8cf55348f4b9b89767a7b1aa21', 162 | 'amazonv2.0_deepcoral_seed0': '0x83232062e07046a999350bfa3d1ad90f', 163 | 'amazonv2.0_deepcoral_seed1': '0x12083bbc081549fd9e943b3f0505bda6', 164 | 'amazonv2.0_deepcoral_seed2': '0x2b3d9ccbbac3406cac4ec12d0370be8e', 165 | 'amazonv2.0_groupDRO_seed0': '0x55e1f00c8c084c07884459331cbc1f3d', 166 | 'amazonv2.0_groupDRO_seed1': '0x94d184bfb931478a8da909d73ed7be71', 167 | 'amazonv2.0_groupDRO_seed2': '0x049d643074314f37845f714e6b07616a', 168 | 'amazonv2.0_reweighted_seed0': '0xe8079c938aeb48afa57b4331e8560f38', 169 | 'amazonv2.0_reweighted_seed1': '0x8c52c3aea8104d39a2ec505517909430', 170 | 'amazonv2.0_reweighted_seed2': '0x68ae284c380a4f529602266ce6a8867f', 171 | }, 172 | } 173 | 174 | DATASET_SPLITS = { 175 | 'camelyon17': ['train', 'id_val', 'val', 'test'], 176 | 'civilcomments': ['train', 'val', 'test'], 177 | 'fmow': ['train', 'id_val', 'id_test', 'val', 'test'], 178 | 'poverty': ['train', 'id_val', 'id_test', 'val', 'test'], 179 | 'amazon': ['train', 'id_val', 'id_test', 'val', 'test'], 180 | } 181 | 182 | AMAZON_MODELS = [f'amazon_seed:{seed}_epoch:best_model.pth' for seed in range(N_SEEDS["amazon"])] 183 | FMOW_MODELS = [f'fmow_seed:{seed}_epoch:best_model.pth' for seed in range(N_SEEDS["fmow"])] 184 | POVERTY_MODELS = [f'poverty_fold:{fold}_epoch:best_model.pth' for fold in POVERTY_FOLDS] 185 | 186 | 187 | class ProperDataLoader: 188 | """ This class defines an iterator that wraps a PyTorch DataLoader 189 | to only return the first two of three elements of the data tuples. 190 | 191 | This is used to make the data loaders from the WILDS benchmark 192 | (which return (X, y, metadata) tuples, where metadata for example 193 | contains domain information) compatible with the uq.py script and 194 | with the laplace library (which both expect (X, y) tuples). 
195 | """ 196 | def __init__(self, data_loader): 197 | self.data_loader = data_loader 198 | self.dataset = self.data_loader.dataset 199 | 200 | def __iter__(self): 201 | self.data_iter = iter(self.data_loader) 202 | return self 203 | 204 | def __next__(self): 205 | X, y, _ = next(self.data_iter) 206 | return X, y 207 | 208 | def __len__(self): 209 | return len(self.data_loader) 210 | 211 | 212 | def load_pretrained_wilds_model(dataset, model_dir, device, model_idx=0, model_seed=0): 213 | """ load pre-trained model """ 214 | 215 | # load default config and instantiate model 216 | config = get_default_config(dataset, algorithm=ALGORITHMS[model_idx]) 217 | is_featurizer = dataset in ["civilcomments", "amazon"] and ALGORITHMS[model_idx] == "deepCORAL" 218 | model = initialize_model(config, D_OUTS[dataset], is_featurizer=is_featurizer) 219 | if is_featurizer: 220 | model = nn.Sequential(*model) 221 | model = model.to(device) 222 | 223 | # define path to pre-trained model parameters 224 | model_list_idx = model_idx * N_SEEDS[dataset] + model_seed 225 | model_name = list(MODEL_UUIDS[dataset].keys())[model_list_idx] 226 | model_path = Path(model_dir) / dataset / f"{model_name}.pth" 227 | 228 | # if required, download pre-trained model parameters 229 | if not model_path.exists(): 230 | model_path.parent.mkdir(exist_ok=True) 231 | model_url = MODEL_URL % MODEL_UUIDS[dataset][model_name] 232 | 233 | # handle special naming cases 234 | if dataset == "amazon": 235 | model_url = model_url.replace("best_model.pth", AMAZON_MODELS[model_seed]) 236 | elif dataset == "fmow" and model_idx == 4: 237 | model_url = model_url.replace("best_model.pth", FMOW_MODELS[model_seed]) 238 | elif dataset == "poverty" and model_idx == 4: 239 | model_url = model_url.replace("best_model.pth", POVERTY_MODELS[model_seed]) 240 | 241 | print(f"Downloading pre-trained model parameters for {model_name} from {model_url}...") 242 | urllib.request.urlretrieve(model_url, model_path) 243 | 244 | # load pre-trained parameters into model 245 | print(f"Loading pre-trained model parameters for {model_name}...") 246 | state_dict = torch.load(model_path)["algorithm"] 247 | model_state_dict_keys = list(model.state_dict().keys()) 248 | 249 | model_state_dict = {} 250 | for m in state_dict: 251 | if dataset in ["civilcomments", "amazon"] and ALGORITHMS[model_idx] == "deepCORAL" and "featurizer" in m: 252 | continue 253 | 254 | m_new = m if m.split('.')[0] == "classifier" else '.'.join(m.split('.')[1:]) 255 | 256 | if "classifier" in m_new: 257 | if dataset == "poverty": 258 | m_new = m_new.replace("classifier", "fc") 259 | elif dataset in ["civilcomments", "amazon"] and ALGORITHMS[model_idx] == "deepCORAL": 260 | m_new = m_new.replace("classifier", "1") 261 | 262 | if m_new not in model_state_dict_keys: 263 | continue 264 | 265 | model_state_dict[m_new] = state_dict[m] 266 | 267 | model.load_state_dict(model_state_dict) 268 | model.eval() 269 | 270 | return model 271 | 272 | 273 | def get_wilds_loaders(dataset, data_dir, data_fraction=1.0, model_seed=0): 274 | """ load in-distribution datasets and return data loaders """ 275 | 276 | # load default config and the full dataset 277 | config = get_default_config(dataset, data_fraction=data_fraction) 278 | dataset_kwargs = {'fold': POVERTY_FOLDS[model_seed]} if dataset == "poverty" else {} 279 | full_dataset = get_dataset(dataset=dataset, root_dir=data_dir, **dataset_kwargs) 280 | train_grouper = CombinatorialGrouper(dataset=full_dataset, groupby_fields=config.groupby_fields) 281 | 282 | if dataset 
== "fmow": 283 | config.batch_size = config.batch_size // 2 284 | 285 | # get the train data loader 286 | train_transform = initialize_transform(transform_name=config.train_transform, config=config, dataset=full_dataset) 287 | train_data = full_dataset.get_subset('train', frac=config.frac, transform=train_transform) 288 | train_loader = get_train_loader(loader=config.train_loader, dataset=train_data, batch_size=config.batch_size, 289 | uniform_over_groups=config.uniform_over_groups, grouper=train_grouper, 290 | distinct_groups=config.distinct_groups, n_groups_per_batch=config.n_groups_per_batch, 291 | **config.loader_kwargs) 292 | 293 | # get the in-distribution validation data loader 294 | eval_transform = initialize_transform(transform_name=config.eval_transform, config=config, dataset=full_dataset) 295 | try: 296 | val_str = "val" if dataset == "fmow" else "id_val" 297 | val_data = full_dataset.get_subset(val_str, frac=config.frac, transform=eval_transform) 298 | val_loader = get_eval_loader(loader=config.eval_loader, dataset=val_data, batch_size=config.batch_size, grouper=train_grouper, **config.loader_kwargs) 299 | except: 300 | print(f"{dataset} dataset doesn't have an in-distribution validation split -- using train split instead!") 301 | val_loader = train_loader 302 | 303 | # get the in-distribution test data loader 304 | try: 305 | in_test_data = full_dataset.get_subset('id_test', frac=config.frac, transform=eval_transform) 306 | in_test_loader = get_eval_loader(loader=config.eval_loader, dataset=in_test_data, batch_size=config.batch_size, grouper=train_grouper, **config.loader_kwargs) 307 | except: 308 | print(f"{dataset} dataset doesn't have an in-distribution test split -- using validation split instead!") 309 | in_test_loader = val_loader 310 | 311 | # wrap data loaders for compatibility with uq.py and laplace library 312 | train_loader = ProperDataLoader(train_loader) 313 | val_loader = ProperDataLoader(val_loader) 314 | in_test_loader = ProperDataLoader(in_test_loader) 315 | 316 | return train_loader, val_loader, in_test_loader 317 | 318 | 319 | def get_wilds_ood_test_loader(dataset, data_dir, data_fraction=1.0, model_seed=0): 320 | """ load out-of-distribution test data and return data loader """ 321 | 322 | # load default config and the full dataset 323 | config = get_default_config(dataset, data_fraction=data_fraction) 324 | dataset_kwargs = {'fold': POVERTY_FOLDS[model_seed]} if dataset == "poverty" else {} 325 | full_dataset = get_dataset(dataset=dataset, root_dir=data_dir, **dataset_kwargs) 326 | train_grouper = CombinatorialGrouper(dataset=full_dataset, groupby_fields=config.groupby_fields) 327 | 328 | if dataset == "fmow": 329 | config.batch_size = config.batch_size // 2 330 | 331 | # get the OOD test data loader 332 | test_transform = initialize_transform(transform_name=config.eval_transform, config=config, dataset=full_dataset) 333 | test_data = full_dataset.get_subset('test', frac=config.frac, transform=test_transform) 334 | test_loader = get_eval_loader(loader=config.eval_loader, dataset=test_data, batch_size=config.batch_size, grouper=train_grouper, **config.loader_kwargs) 335 | 336 | # wrap data loader for compatibility with uq.py and laplace library 337 | test_loader = ProperDataLoader(test_loader) 338 | 339 | return test_loader 340 | 341 | 342 | def get_default_config(dataset, algorithm="ERM", data_fraction=1.0): 343 | config = Namespace(dataset=dataset, algorithm=algorithm, model_kwargs={}, optimizer_kwargs={}, 344 | loader_kwargs={}, dataset_kwargs={}, 
scheduler_kwargs={}, 345 | train_transform=None, eval_transform=None, no_group_logging=True, 346 | distinct_groups=True, frac=data_fraction, scheduler=None) 347 | return populate_defaults(config) 348 | 349 | 350 | def optimize_noise_standard_deviation(model, val_loader, device, lr=1e-1, n_epochs=10): 351 | """ optimizes the noise standard deviation of a Gaussian regression likelihood on the validation data """ 352 | 353 | # define parameter to optimize and optimizer 354 | log_sigma_noise = nn.Parameter(torch.zeros(1, device=device)) 355 | optimizer = torch.optim.Adam([log_sigma_noise], lr=lr) 356 | 357 | # define Gaussian negative log-likelihood loss; this is equivalent to 358 | # lambda y_pred, y, var: -Normal(y_pred, var.sqrt()).log_prob(y).mean(dim=0) 359 | gaussian_nll_loss = nn.GaussianNLLLoss(full=True) 360 | 361 | for e in range(n_epochs): 362 | print(f"Running epoch {e+1}/{n_epochs}...") 363 | for i, (X, y) in enumerate(val_loader): 364 | optimizer.zero_grad() 365 | y_pred = model(X.to(device)) 366 | if isinstance(y_pred, tuple): 367 | y_pred = y_pred[0] 368 | nll = gaussian_nll_loss(y_pred, y.to(device), torch.ones_like(y_pred) * log_sigma_noise.exp()**2) 369 | nll.backward() 370 | optimizer.step() 371 | sigma_noise = log_sigma_noise.exp().item() 372 | print(f"\tIter {i+1}/{len(val_loader)}: sigma_noise = {sigma_noise} (NLL: {nll.item()}).") 373 | print(f"After epoch {e+1}/{n_epochs}: sigma_noise = {sigma_noise}.\n") 374 | 375 | return sigma_noise 376 | -------------------------------------------------------------------------------- /uq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | 5 | import numpy as np 6 | import pycalib.calibration_methods as calib 7 | import pyro 8 | import pyro.distributions as dist 9 | import torch 10 | import yaml 11 | from laplace import Laplace 12 | from pyro.infer import Predictive 13 | from pyro.infer.autoguide import AutoDiagonalNormal 14 | 15 | import utils.data_utils as du 16 | import utils.utils as util 17 | import utils.wilds_utils as wu 18 | from baselines.swag.swag import fit_swag_and_precompute_bn_params 19 | from models import network 20 | from models import pyro as pyro_models 21 | from utils import metrics as metrics_fns 22 | from utils.test import test 23 | 24 | warnings.filterwarnings('ignore') 25 | 26 | 27 | def main(args): 28 | # set device and random seed 29 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 30 | args.prior_precision = util.get_prior_precision(args, device) 31 | util.set_seed(args.seed) 32 | 33 | # load in-distribution data 34 | in_data_loaders, ids, no_loss_acc = du.get_in_distribution_data_loaders( 35 | args, device) 36 | train_loader, val_loader, in_test_loader = in_data_loaders 37 | 38 | # fit models 39 | mixture_components, samples = fit_models( 40 | args, train_loader, val_loader, device) 41 | 42 | # evaluate models 43 | metrics = evaluate_models( 44 | args, mixture_components, in_test_loader, ids, no_loss_acc, samples, device) 45 | 46 | # save results 47 | util.save_results(args, metrics) 48 | 49 | 50 | def fit_models(args, train_loader, val_loader, device): 51 | """ load pre-trained weights, fit inference methods, and tune hyperparameters """ 52 | n_classes = 100 if args.benchmark == 'CIFAR-100-OOD' else 10 53 | 54 | mixture_components = list() 55 | all_samples = list() 56 | for model_idx in range(args.nr_components): 57 | samples = None 58 | if args.method in ['map', 'ensemble', 'laplace', 
'mola', 'swag', 'multi-swag', 'bbb', 'csghmc']: 59 | model = util.load_pretrained_model(args, model_idx, n_classes, device) 60 | 61 | # For saving pretrained models 62 | dataset = 'fmnist' if args.benchmark in ['R-FMNIST', 'FMNIST-OOD'] else 'cifar10' 63 | if args.benchmark == 'CIFAR-100-OOD': 64 | dataset += '0' 65 | is_ll = args.subset_of_weights == 'last_layer' and args.method not in ['map', 'ensemble', 'bbb'] 66 | layers = 'll' if is_ll else 'al' 67 | save_path = os.path.join(args.models_root, f'{dataset}/{layers}') 68 | if args.compute_mmd and model_idx == 0: 69 | CHAIN_IDS = [1, 2, 3] 70 | # Reference HMC samples for computing MMD distances 71 | hmc_samples = np.concatenate([ 72 | np.load(f'{save_path}/hmc_{args.model_seed}_{cid}.npy') 73 | for cid in CHAIN_IDS]) 74 | n_hmc_samples = len(hmc_samples) 75 | if args.nr_components > 1: 76 | if n_hmc_samples % args.nr_components != 0: 77 | raise ValueError('n_hmc_samples must be divisible by nr_components.') 78 | n_hmc_samples = int(n_hmc_samples / args.nr_components) 79 | all_samples.append(hmc_samples) 80 | 81 | if args.method == 'map': 82 | if args.compute_mmd: 83 | if is_ll: 84 | sample_model = model.fc if 'cifar' in dataset else model.ll 85 | else: 86 | sample_model = model 87 | samples = torch.nn.utils.parameters_to_vector( 88 | sample_model.parameters()).detach().repeat(n_hmc_samples, 1) 89 | if args.likelihood == 'classification' and args.use_temperature_scaling: 90 | print('Fitting temperature scaling model on validation data...') 91 | all_y_prob = [model(d[0].to(device)).detach().cpu() for d in val_loader] 92 | all_y_prob = torch.cat(all_y_prob, dim=0) 93 | all_y_true = torch.cat([d[1] for d in val_loader], dim=0) 94 | 95 | temperature_scaling_model = calib.TemperatureScaling() 96 | temperature_scaling_model.fit(all_y_prob.numpy(), all_y_true.numpy()) 97 | model = (model, temperature_scaling_model) 98 | 99 | elif args.method in ['laplace', 'mola']: 100 | if type(args.prior_precision) is str: # file path 101 | prior_precision = torch.load(args.prior_precision, map_location=device) 102 | elif type(args.prior_precision) is float: 103 | prior_precision = args.prior_precision 104 | else: 105 | raise ValueError('prior precision has to be either float or string (file path)') 106 | Backend = util.get_backend(args.backend, args.approx_type) 107 | optional_args = dict() 108 | 109 | if args.subset_of_weights == 'last_layer': 110 | optional_args['last_layer_name'] = args.last_layer_name 111 | 112 | print('Fitting Laplace approximation...') 113 | 114 | model = Laplace(model, args.likelihood, 115 | subset_of_weights=args.subset_of_weights, 116 | hessian_structure=args.hessian_structure, 117 | prior_precision=prior_precision, 118 | temperature=args.temperature, 119 | backend=Backend, **optional_args) 120 | model.fit(train_loader) 121 | 122 | if (args.optimize_prior_precision is not None) and (args.method == 'laplace'): 123 | if (type(prior_precision) is float) and (args.prior_structure != 'scalar'): 124 | n = model.n_params if args.prior_structure == 'all' else model.n_layers 125 | prior_precision = prior_precision * torch.ones(n, device=device) 126 | 127 | print('Optimizing prior precision for Laplace approximation...') 128 | 129 | verbose_prior = args.prior_structure == 'scalar' 130 | model.optimize_prior_precision( 131 | method=args.optimize_prior_precision, 132 | init_prior_prec=prior_precision, 133 | val_loader=val_loader, 134 | pred_type=args.pred_type, 135 | link_approx=args.link_approx, 136 | n_samples=args.n_samples, 137 | 
verbose=verbose_prior 138 | ) 139 | 140 | if args.compute_mmd: 141 | if args.hessian_structure == 'diag': 142 | samples = dist.Normal( 143 | model.mean, model.posterior_scale).sample((n_hmc_samples,)) 144 | else: 145 | samples = dist.MultivariateNormal( 146 | model.mean, scale_tril=model.posterior_scale).sample((n_hmc_samples,)) 147 | 148 | elif args.method in ['swag', 'multi-swag']: 149 | print('Fitting SWAG...') 150 | 151 | model = fit_swag_and_precompute_bn_params( 152 | model, device, train_loader, args.swag_n_snapshots, 153 | args.swag_lr, args.swag_c_epochs, args.swag_c_batches, 154 | args.data_parallel, args.n_samples, args.swag_bn_update_subset) 155 | 156 | elif args.method == 'bbb' and args.compute_mmd: 157 | posterior_mean = list() 158 | posterior_scale = list() 159 | for module in model.modules(): 160 | if hasattr(module, 'mu_weight'): 161 | posterior_mean.append(module.mu_weight) 162 | posterior_scale.append(module.rho_weight) 163 | if hasattr(module, 'mu_kernel'): 164 | posterior_mean.append(module.mu_kernel) 165 | posterior_scale.append(module.rho_kernel) 166 | if hasattr(module, 'mu_bias'): 167 | posterior_mean.append(module.mu_bias) 168 | posterior_scale.append(module.rho_bias) 169 | posterior_mean = torch.nn.utils.parameters_to_vector(posterior_mean) 170 | posterior_scale = torch.log1p(torch.exp( 171 | torch.nn.utils.parameters_to_vector(posterior_scale))) 172 | samples = dist.Normal( 173 | posterior_mean, posterior_scale).sample((n_hmc_samples,)) 174 | 175 | elif args.method not in ['ensemble', 'bbb', 'csghmc']: 176 | if args.method == 'hmc' and not args.compute_mmd: 177 | raise ValueError(f'compute_mmd needs to be set to True for {args.method}.') 178 | 179 | # The network 180 | if is_ll: 181 | get_net = (lambda: network.WideResNet(16, 4, num_classes=n_classes) 182 | if 'CIFAR' in args.benchmark else network.LeNet()) 183 | else: 184 | get_net = lambda: network.MLP(784, n_hiddens=50) 185 | 186 | X_train = torch.cat([x for x, _ in train_loader], dim=0) 187 | y_train = torch.cat([y for _, y in train_loader], dim=0) 188 | M, N = X_train.shape[0], np.prod(X_train.shape[1:]) 189 | if model_idx == 0: 190 | print(f'[Randseed: {args.model_seed}] Dataset: {args.benchmark}, ' 191 | f'n_data: {M}, n_feat: {N}, n_param: {util.count_params(get_net())}') 192 | print() 193 | 194 | net = get_net() 195 | model_seed = args.model_seed 196 | if args.nr_components > 1: 197 | model_seed = model_idx + 1 198 | net.load_state_dict(torch.load(f'{save_path}/map_{model_seed}.pt', map_location=device)) 199 | net.to(device) 200 | net.eval() 201 | 202 | cuda = torch.cuda.is_available() 203 | subset_of_weights = 'last_layer' if is_ll else 'all' 204 | hessian_structure = ( 205 | 'full' if is_ll and args.benchmark != 'CIFAR-100-OOD' 206 | else 'diag' 207 | ) 208 | N_FEAT = 256 if 'CIFAR-10' in args.benchmark else 84 209 | if args.n_samples % args.nr_components != 0: 210 | raise ValueError('n_samples must be divisible by nr_components.') 211 | num_samples = int(args.n_samples / args.nr_components) 212 | 213 | if is_ll: 214 | model = pyro_models.ClassificationModelLL( 215 | n_data=M, n_classes=n_classes, n_features=N_FEAT, 216 | feature_extractor=net.forward_features, 217 | prior_prec=args.prior_precision, cuda=cuda) 218 | else: 219 | model = pyro_models.ClassificationModel( 220 | get_net, n_data=M, prior_prec=args.prior_precision, cuda=cuda) 221 | 222 | if args.method == 'vb': 223 | pyro.get_param_store().clear() 224 | guide = AutoDiagonalNormal(model.model) 225 | 
guide._setup_prototype(X_train[:2].to(device), y_train[:2].to(device)) 226 | guide.load_state_dict(torch.load(f'{save_path}/{args.method}_{model_seed}.pt')) 227 | samples = guide.get_posterior().sample((n_hmc_samples,)) 228 | predictive = Predictive(model.model, guide=guide, num_samples=num_samples, return_sites=('_RETURN',)) 229 | 230 | elif 'nf_naive' in args.method: 231 | pyro.get_param_store().clear() 232 | state_dict = torch.load(f'{save_path}/{args.method}_{model_seed}.pt') 233 | 234 | guide = util.load_nf_guide(model.model, state_dict, *next(iter(train_loader)), cuda=True, method='nf_naive') 235 | samples = guide.get_posterior().sample((n_hmc_samples,)) 236 | predictive = Predictive(model.model, guide=guide, num_samples=num_samples, return_sites=('_RETURN',)) 237 | 238 | elif 'refine' in args.method and 'sub' not in args.method: 239 | pyro.get_param_store().clear() 240 | state_dict = torch.load(f'{save_path}/{args.method}_{model_seed}.pt') 241 | 242 | la = Laplace( 243 | net, 'classification', subset_of_weights=subset_of_weights, 244 | hessian_structure=hessian_structure, prior_precision=args.prior_precision) 245 | la.fit(train_loader) 246 | la.optimize_prior_precision() 247 | 248 | diag = hessian_structure == 'diag' 249 | guide = util.load_nf_guide(model.model, state_dict, *next(iter(train_loader)), diag=diag, cuda=cuda) 250 | samples = guide.get_posterior().sample((n_hmc_samples,)) 251 | predictive = Predictive(model.model, guide=guide, num_samples=num_samples, return_sites=('_RETURN',)) 252 | 253 | elif args.method == 'refine_sub': 254 | pyro.get_param_store().clear() 255 | state_dict = torch.load(f'{save_path}/refine_sub_{model_seed}.pt') 256 | 257 | la = Laplace( 258 | net, 'classification', subset_of_weights='last_layer', 259 | hessian_structure='full', prior_precision=args.prior_precision) 260 | la.fit(train_loader) 261 | la.optimize_prior_precision() 262 | base_dist = dist.MultivariateNormal(la.mean, la.posterior_covariance) 263 | 264 | model = pyro_models.ClassificationModelLL( 265 | n_data=M, n_classes=n_classes, n_features=N_FEAT, 266 | feature_extractor=net.forward_features, 267 | prior_prec=args.prior_precision, cuda=cuda, 268 | proj_mat=state_dict['proj_mat'], base_dist=base_dist) 269 | 270 | guide = util.load_nf_guide(model.model_subspace, state_dict, *next(iter(train_loader)), cuda=cuda) 271 | samples = guide.get_posterior().sample((n_hmc_samples,)) @ state_dict['proj_mat'] 272 | predictive = Predictive(model.model_subspace, guide=guide, num_samples=num_samples, return_sites=('_RETURN',)) 273 | 274 | else: # 'hmc' 275 | samples = torch.as_tensor(hmc_samples, dtype=torch.float, device=device) 276 | predictive = Predictive(model.model, {'theta': samples.to(device)}, return_sites=('_RETURN',)) 277 | 278 | model = (predictive, net) 279 | 280 | if args.likelihood == 'regression' and args.sigma_noise is None: 281 | print('Optimizing noise standard deviation on validation data...') 282 | args.sigma_noise = wu.optimize_noise_standard_deviation(model, val_loader, device) 283 | 284 | mixture_components.append(model) 285 | all_samples.append(samples) 286 | 287 | if len(all_samples) > 2: 288 | method_samples = torch.cat(all_samples[1:]) 289 | assert n_hmc_samples * args.nr_components == len(method_samples) 290 | all_samples = [all_samples[0], method_samples] 291 | 292 | return mixture_components, all_samples 293 | 294 | 295 | def evaluate_models(args, mixture_components, in_test_loader, ids, no_loss_acc, samples, device): 296 | """ evaluate the models and return relevant 
evaluation metrics """
297 | 
298 |     metrics = []
299 |     for i, id in enumerate(ids):
300 |         # load test data
301 |         test_loader = in_test_loader if i == 0 else du.get_ood_test_loader(
302 |             args, id)
303 | 
304 |         use_no_loss_acc = no_loss_acc if i > 0 else False
305 |         # make model predictions and compute some metrics
306 |         test_output, test_time = util.timing(lambda: test(
307 |             mixture_components, test_loader, args.method,
308 |             pred_type=args.pred_type, link_approx=args.link_approx,
309 |             n_samples=args.n_samples, device=device, no_loss_acc=use_no_loss_acc,
310 |             likelihood=args.likelihood, sigma_noise=args.sigma_noise))
311 |         some_metrics, all_y_prob, all_y_var = test_output
312 |         some_metrics['test_time'] = test_time
313 | 
314 |         if i == 0:
315 |             all_y_prob_in = all_y_prob.clone()
316 | 
317 |         # compute more metrics, aggregate and print them:
318 |         # log likelihood, accuracy, confidence, Brier score, ECE, MCE, AUROC, FPR95
319 |         more_metrics = compute_metrics(
320 |             i, id, all_y_prob, test_loader, all_y_prob_in, all_y_var, samples, args)
321 |         metrics.append({**some_metrics, **more_metrics})
322 |         print(', '.join([f'{k}: {v:.4f}' for k, v in metrics[-1].items()]))
323 | 
324 |     return metrics
325 | 
326 | 
327 | def compute_metrics(i, id, all_y_prob, test_loader, all_y_prob_in, all_y_var, samples, args):
328 |     """ compute evaluation metrics """
329 | 
330 |     metrics = {}
331 | 
332 |     # compute Brier, ECE and MCE for in-distribution and distribution shift/WILDS data
333 |     if i == 0 or args.benchmark in ['R-MNIST', 'R-FMNIST', 'CIFAR-10-C', 'ImageNet-C']:
334 |         if args.benchmark in ['R-MNIST', 'R-FMNIST', 'CIFAR-10-C', 'ImageNet-C']:
335 |             print(f'{args.benchmark} with distribution shift intensity {i}')
336 |         labels = torch.cat([data[1] for data in test_loader])
337 |         metrics['brier'] = util.get_brier_score(all_y_prob, labels)
338 |         metrics['ece'], metrics['mce'] = util.get_calib(all_y_prob, labels)
339 | 
340 |     # compute AUROC and FPR95 for OOD benchmarks
341 |     if 'OOD' in args.benchmark:
342 |         print(f'{args.benchmark} - dataset: {id}')
343 |         if i > 0:
344 |             # compute other metrics
345 |             metrics['auroc'] = util.get_auroc(all_y_prob_in, all_y_prob)
346 |             metrics['fpr95'], _ = util.get_fpr95(all_y_prob_in, all_y_prob)
347 | 
348 |     # compute regression calibration
349 |     if args.benchmark == 'WILDS-poverty':
350 |         print(f'{args.benchmark} with distribution shift intensity {i}')
351 |         labels = torch.cat([data[1] for data in test_loader])
352 |         metrics['calib_regression'] = util.get_calib_regression(
353 |             all_y_prob.numpy(), all_y_var.sqrt().numpy(), labels.numpy())
354 | 
355 |     # compute MMD to HMC samples
356 |     if args.compute_mmd and i == 0 and len(samples) == 2:
357 |         hmc_samples, method_samples = samples[0], samples[1]
358 |         if method_samples is not None:
359 |             metrics['mmd_to_hmc'] = metrics_fns.mmd_rbf(
360 |                 hmc_samples, method_samples.cpu().numpy())
361 | 
362 |     return metrics
363 | 
364 | 
365 | if __name__ == '__main__':
366 |     parser = argparse.ArgumentParser()
367 |     parser.add_argument('--benchmark', type=str,
368 |                         choices=['R-MNIST', 'R-FMNIST', 'CIFAR-10-C', 'ImageNet-C',
369 |                                  'MNIST-OOD', 'FMNIST-OOD', 'CIFAR-10-OOD', 'CIFAR-100-OOD',
370 |                                  'WILDS-camelyon17', 'WILDS-iwildcam',
371 |                                  'WILDS-civilcomments', 'WILDS-amazon',
372 |                                  'WILDS-fmow', 'WILDS-poverty'],
373 |                         default='CIFAR-10-C', help='name of benchmark')
374 |     parser.add_argument('--data_root', type=str, default='./data',
375 |                         help='root of dataset')
376 |     parser.add_argument('--download', action='store_true',
377 |                         help='if True, downloads the datasets needed for given benchmark')
downloads the datasets needed for given benchmark') 378 | parser.add_argument('--data_fraction', type=float, default=1.0, 379 | help='fraction of data to use (only supported for WILDS)') 380 | parser.add_argument('--models_root', type=str, default='./models', 381 | help='root of pre-trained models') 382 | parser.add_argument('--model_seed', type=int, default=None, 383 | help='random seed with which model(s) were trained') 384 | parser.add_argument('--model_path', type=str) 385 | parser.add_argument('--hessians_root', type=str, default='./hessians', 386 | help='root of pre-computed Hessians') 387 | parser.add_argument('--method', type=str, default='laplace', 388 | help='name of method to use') 389 | parser.add_argument('--seed', type=int, default=1, 390 | help='random seed') 391 | parser.add_argument('--compute_mmd', action='store_true', 392 | help='Compute MMD to HMC samples.') 393 | 394 | parser.add_argument('--pred_type', type=str, 395 | choices=['nn', 'glm'], 396 | default='glm', 397 | help='type of approximation of predictive distribution') 398 | parser.add_argument('--link_approx', type=str, 399 | choices=['mc', 'probit', 'bridge'], 400 | default='probit', 401 | help='type of approximation of link function') 402 | parser.add_argument('--n_samples', type=int, default=20, 403 | help='nr. of MC samples for approximating the predictive distribution') 404 | 405 | parser.add_argument('--likelihood', type=str, choices=['classification', 'regression'], 406 | default='classification', help='likelihood for Laplace') 407 | parser.add_argument('--subset_of_weights', type=str, choices=['last_layer', 'all'], 408 | default='last_layer', help='subset of weights for Laplace') 409 | parser.add_argument('--backend', type=str, choices=['backpack', 'kazuki'], default='backpack') 410 | parser.add_argument('--approx_type', type=str, choices=['ggn', 'ef'], default='ggn') 411 | parser.add_argument('--hessian_structure', type=str, choices=['diag', 'kron', 'full'], 412 | default='kron', help='structure of the Hessian approximation') 413 | parser.add_argument('--last_layer_name', type=str, default=None, 414 | help='name of the last layer of the model') 415 | parser.add_argument('--prior_precision', type=float, default=1., 416 | help='prior precision to use for computing the covariance matrix') 417 | parser.add_argument('--optimize_prior_precision', default=None, 418 | choices=['marglik', 'CV'], 419 | help='optimize prior precision according to specified method') 420 | parser.add_argument('--prior_structure', type=str, default='scalar', 421 | choices=['scalar', 'layerwise', 'all']) 422 | parser.add_argument('--sigma_noise', type=float, default=None, 423 | help='noise standard deviation for regression (if -1, optimize it)') 424 | parser.add_argument('--temperature', type=float, default=1.0, 425 | help='temperature of the likelihood.') 426 | 427 | parser.add_argument('--swag_n_snapshots', type=int, default=40, 428 | help='number of snapshots for [Multi]SWAG') 429 | parser.add_argument('--swag_c_batches', type=int, default=None, 430 | help='number of batches between snapshots for [Multi]SWAG') 431 | parser.add_argument('--swag_c_epochs', type=int, default=1, 432 | help='number of epochs between snapshots for [Multi]SWAG') 433 | parser.add_argument('--swag_lr', type=float, default=1e-2, 434 | help='learning rate for [Multi]SWAG') 435 | parser.add_argument('--swag_bn_update_subset', type=float, default=1.0, 436 | help='fraction of train data for updating the BatchNorm statistics for [Multi]SWAG') 437 | 438 | 
parser.add_argument('--nr_components', type=int, default=1, 439 | help='number of mixture components to use') 440 | parser.add_argument('--mixture_weights', type=str, 441 | choices=['uniform', 'optimize'], 442 | default='uniform', 443 | help='how the mixture weights for MoLA are chosen') 444 | 445 | parser.add_argument('--model', type=str, default='WRN16-4', 446 | choices=['MLP', 'FMNIST-MLP', 'LeNet', 'WRN16-4', 'WRN16-4-fixup', 'WRN50-2', 447 | 'LeNet-BBB-reparam', 'LeNet-BBB-flipout', 'LeNet-CSGHMC', 448 | 'WRN16-4-BBB-reparam', 'WRN16-4-BBB-flipout', 'WRN16-4-CSGHMC'], 449 | help='the neural network model architecture') 450 | parser.add_argument('--no_dropout', action='store_true', help='only for WRN-fixup.') 451 | parser.add_argument('--data_parallel', action='store_true', 452 | help='if True, use torch.nn.DataParallel(model)') 453 | parser.add_argument('--batch_size', type=int, default=512, 454 | help='batch size for testing') 455 | parser.add_argument('--val_set_size', type=int, default=2000, 456 | help='size of validation set (taken from test set)') 457 | parser.add_argument('--use_temperature_scaling', default=False, 458 | help='if True, calibrate model using temperature scaling') 459 | 460 | parser.add_argument('--job_id', type=int, default=0, 461 | help='job ID, leave at 0 when running locally') 462 | parser.add_argument('--config', default=None, nargs='+', 463 | help='YAML config file path') 464 | parser.add_argument('--run_name', type=str, help='overwrite save file name') 465 | parser.add_argument('--noda', action='store_true', help='if True, disable data augmentation on the training set') 466 | parser.add_argument('--normalize', action='store_true', help='if True, normalize the input data') 467 | 468 | args = parser.parse_args() 469 | args_dict = vars(args) 470 | 471 | # load config file (YAML) 472 | if args.config is not None: 473 | for path in args.config: 474 | with open(path) as f: 475 | config = yaml.full_load(f) 476 | args_dict.update(config) 477 | 478 | if args.data_parallel and (args.method in ['laplace', 'mola']): 479 | raise NotImplementedError( 480 | 'laplace and mola do not support DataParallel yet.') 481 | 482 | if (args.optimize_prior_precision is not None) and (args.method == 'mola'): 483 | raise NotImplementedError( 484 | 'optimizing the prior precision for MoLA is not supported yet.') 485 | 486 | if args.mixture_weights != 'uniform': 487 | raise NotImplementedError( 488 | 'Only uniform mixture weights are supported for now.') 489 | 490 | if ((args.method in ['ensemble', 'mola', 'multi-swag']) 491 | and (args.nr_components <= 1)): 492 | parser.error( 493 | 'Choose nr_components > 1 for ensemble, MoLA, or MultiSWAG.') 494 | 495 | if args.model != 'WRN16-4-fixup' and args.no_dropout: 496 | parser.error( 497 | 'The --no_dropout option is only available for WRN16-4-fixup.') 498 | 499 | if args.benchmark in ['R-MNIST', 'MNIST-OOD', 'R-FMNIST', 'FMNIST-OOD']: 500 | if 'LeNet' not in args.model and 'MLP' not in args.model: 501 | parser.error('Only LeNet or (FMNIST-)MLP works for (F-)MNIST.') 502 | elif args.benchmark in ['CIFAR-10-C', 'CIFAR-10-OOD']: 503 | if 'WRN16-4' not in args.model: 504 | parser.error('Only WRN16-4 works for CIFAR-10-C and CIFAR-10-OOD.') 505 | elif args.benchmark == 'ImageNet-C': 506 | if not (args.model == 'WRN50-2'): 507 | parser.error('Only WRN50-2 works for ImageNet-C.') 508 | 509 | if args.benchmark == 'WILDS-poverty': 510 | args.likelihood = 'regression' 511 | else: 512 | args.likelihood = 'classification' 513 | 514 | for key, val in args_dict.items(): 515 | print(f'{key}: {val}') 516 | print() 517 | 518 | main(args) 519 | 
-------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data_utils 6 | import torchvision.transforms.functional as TF 7 | from PIL import Image 8 | from torchvision import datasets, transforms 9 | 10 | import utils.wilds_utils as wu 11 | 12 | 13 | def get_in_distribution_data_loaders(args, device): 14 | """ load in-distribution datasets and return data loaders """ 15 | 16 | if args.benchmark in ['R-MNIST', 'MNIST-OOD']: 17 | if args.benchmark == 'R-MNIST': 18 | no_loss_acc = False 19 | # here, id is the rotation angle 20 | ids = [0, 15, 30, 45, 60, 75, 90, 105, 120, 135, 150, 165, 180] 21 | else: 22 | no_loss_acc = True 23 | # here, id is the name of the dataset 24 | ids = ['MNIST', 'EMNIST', 'FMNIST', 'KMNIST'] 25 | train_loader, val_loader, in_test_loader = get_mnist_loaders( 26 | args.data_root, 27 | model_class=args.model, 28 | batch_size=args.batch_size, 29 | val_size=args.val_set_size, 30 | download=args.download, 31 | device=device) 32 | 33 | elif args.benchmark in ['R-FMNIST', 'FMNIST-OOD']: 34 | if args.benchmark == 'R-FMNIST': 35 | no_loss_acc = False 36 | # here, id is the rotation angle 37 | ids = [0, 15, 30, 45, 60, 75, 90, 105, 120, 135, 150, 165, 180] 38 | else: 39 | no_loss_acc = True 40 | # here, id is the name of the dataset 41 | ids = ['FMNIST', 'EMNIST', 'MNIST', 'KMNIST'] 42 | train_loader, val_loader, in_test_loader = get_fmnist_loaders( 43 | args.data_root, 44 | model_class=args.model, 45 | batch_size=args.batch_size, 46 | val_size=args.val_set_size, 47 | download=args.download, 48 | device=device) 49 | 50 | elif args.benchmark in ['R-FMNIST', 'FMNIST-OOD']: 51 | if args.benchmark == 'R-FMNIST': 52 | no_loss_acc = False 53 | # here, id is the rotation angle 54 | ids = [0, 15, 30, 45, 60, 75, 90, 105, 120, 135, 150, 165, 180] 55 | else: 56 | no_loss_acc = True 57 | # here, id is the name of the dataset 58 | ids = ['FMNIST', 'EMNIST', 'MNIST', 'KMNIST'] 59 | train_loader, val_loader, in_test_loader = get_fmnist_loaders( 60 | args.data_root, 61 | model_class=args.model, 62 | batch_size=args.batch_size, 63 | val_size=args.val_set_size, 64 | download=args.download, 65 | device=device) 66 | 67 | elif args.benchmark in ['CIFAR-10-C', 'CIFAR-10-OOD']: 68 | if args.benchmark == 'CIFAR-10-C': 69 | no_loss_acc = False 70 | # here, id is the corruption severity 71 | ids = [0, 1, 2, 3, 4, 5] 72 | else: 73 | no_loss_acc = True 74 | # here, id is the name of the OOD dataset 75 | ids = ['CIFAR-10', 'SVHN', 'LSUN', 'CIFAR-100'] 76 | 77 | train_loader, val_loader, in_test_loader = get_cifar10_loaders( 78 | args.data_root, 79 | batch_size=args.batch_size, 80 | train_batch_size=args.batch_size, 81 | val_size=args.val_set_size, 82 | download=args.download, 83 | normalize=args.normalize, 84 | data_augmentation=not args.noda) 85 | 86 | elif args.benchmark == 'CIFAR-100-OOD': 87 | no_loss_acc = True 88 | # here, id is the name of the OOD dataset 89 | ids = ['CIFAR-100', 'SVHN', 'LSUN', 'CIFAR-10'] 90 | 91 | train_loader, val_loader, in_test_loader = get_cifar100_loaders( 92 | args.data_root, 93 | batch_size=args.batch_size, 94 | train_batch_size=args.batch_size, 95 | val_size=args.val_set_size, 96 | download=args.download, 97 | normalize=args.normalize, 98 | data_augmentation=not args.noda) 99 | 100 | elif args.benchmark == 'ImageNet-C': 101 | no_loss_acc = False 102 | # here, 
id is the corruption severity 103 | ids = [0, 1, 2, 3, 4, 5] 104 | train_loader, val_loader, in_test_loader = get_imagenet_loaders( 105 | args.data_root, 106 | batch_size=args.batch_size, 107 | train_batch_size=args.batch_size, 108 | val_size=args.val_set_size) 109 | 110 | elif 'WILDS' in args.benchmark: 111 | dataset = args.benchmark[6:] 112 | no_loss_acc = False 113 | ids = [f'{dataset}-id', f'{dataset}-ood'] 114 | train_loader, val_loader, in_test_loader = wu.get_wilds_loaders( 115 | dataset, args.data_root, args.data_fraction, args.model_seed) 116 | 117 | return (train_loader, val_loader, in_test_loader), ids, no_loss_acc 118 | 119 | 120 | def get_ood_test_loader(args, id): 121 | """ load out-of-distribution test data and return data loader """ 122 | 123 | if args.benchmark == 'R-MNIST': 124 | _, test_loader = get_rotated_mnist_loaders( 125 | id, args.data_root, 126 | model_class=args.model, 127 | download=args.download) 128 | elif args.benchmark == 'R-FMNIST': 129 | _, test_loader = get_rotated_fmnist_loaders( 130 | id, args.data_root, 131 | model_class=args.model, 132 | download=args.download) 133 | elif args.benchmark == 'CIFAR-10-C': 134 | test_loader = load_corrupted_cifar10( 135 | id, data_dir=args.data_root, 136 | batch_size=args.batch_size, 137 | normalize=args.normalize, 138 | cuda=torch.cuda.is_available()) 139 | elif args.benchmark == 'ImageNet-C': 140 | test_loader = load_corrupted_imagenet( 141 | id, data_dir=args.data_root, 142 | batch_size=args.batch_size, 143 | cuda=torch.cuda.is_available()) 144 | elif args.benchmark == 'MNIST-OOD': 145 | _, test_loader = get_mnist_ood_loaders( 146 | id, data_path=args.data_root, 147 | batch_size=args.batch_size, 148 | model_class=args.model, 149 | download=args.download) 150 | elif args.benchmark == 'FMNIST-OOD': 151 | _, test_loader = get_mnist_ood_loaders( 152 | id, data_path=args.data_root, 153 | batch_size=args.batch_size, 154 | model_class=args.model, 155 | download=args.download) 156 | elif args.benchmark == 'CIFAR-10-OOD': 157 | _, test_loader = get_cifar10_ood_loaders( 158 | id, data_path=args.data_root, 159 | batch_size=args.batch_size, 160 | normalize=args.normalize, 161 | download=args.download) 162 | elif args.benchmark == 'CIFAR-100-OOD': 163 | _, test_loader = get_cifar100_ood_loaders( 164 | id, data_path=args.data_root, 165 | batch_size=args.batch_size, 166 | normalize=args.normalize, 167 | download=args.download) 168 | elif 'WILDS' in args.benchmark: 169 | dataset = args.benchmark[6:] 170 | test_loader = wu.get_wilds_ood_test_loader( 171 | dataset, args.data_root, args.data_fraction) 172 | 173 | return test_loader 174 | 175 | 176 | def val_test_split(dataset, val_size=5000, batch_size=512, num_workers=0, pin_memory=False): 177 | # Split into val and test sets 178 | test_size = len(dataset) - val_size 179 | dataset_val, dataset_test = data_utils.random_split( 180 | dataset, (val_size, test_size), generator=torch.Generator().manual_seed(42) 181 | ) 182 | val_loader = data_utils.DataLoader(dataset_val, batch_size=batch_size, shuffle=False, 183 | num_workers=num_workers, pin_memory=pin_memory) 184 | test_loader = data_utils.DataLoader(dataset_test, batch_size=batch_size, shuffle=False, 185 | num_workers=num_workers, pin_memory=pin_memory) 186 | return val_loader, test_loader 187 | 188 | 189 | def get_cifar10_loaders(data_path, batch_size=512, val_size=2000, 190 | train_batch_size=128, download=False, data_augmentation=True, 191 | normalize=True): 192 | mean = [x / 255 for x in [125.3, 123.0, 113.9]] 193 | std = [x / 255 
for x in [63.0, 62.1, 66.7]] 194 | 195 | tforms = [transforms.ToTensor()] 196 | if normalize: 197 | tforms.append(transforms.Normalize(mean, std)) 198 | tforms_test = transforms.Compose(tforms) 199 | if data_augmentation: 200 | tforms_train = transforms.Compose( 201 | [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4)] 202 | + tforms 203 | ) 204 | else: 205 | tforms_train = tforms_test 206 | 207 | # Get datasets and data loaders 208 | train_set = datasets.CIFAR10(data_path, train=True, transform=tforms_train, 209 | download=download) 210 | # train_set = data_utils.Subset(train_set, range(500)) 211 | val_test_set = datasets.CIFAR10(data_path, train=False, transform=tforms_test, 212 | download=download) 213 | 214 | train_loader = data_utils.DataLoader(train_set, 215 | batch_size=train_batch_size, 216 | shuffle=True) 217 | val_loader, test_loader = val_test_split(val_test_set, 218 | batch_size=batch_size, 219 | val_size=val_size) 220 | 221 | return train_loader, val_loader, test_loader 222 | 223 | 224 | def get_cifar100_loaders(data_path, batch_size=512, val_size=2000, 225 | train_batch_size=128, download=False, data_augmentation=True, 226 | normalize=True): 227 | mean = (0.4914, 0.4822, 0.4465) 228 | std = (0.2023, 0.1994, 0.2010) 229 | 230 | tforms = [transforms.ToTensor()] 231 | if normalize: 232 | tforms.append(transforms.Normalize(mean, std)) 233 | tforms_test = transforms.Compose(tforms) 234 | if data_augmentation: 235 | tforms_train = transforms.Compose( 236 | [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4)] 237 | + tforms 238 | ) 239 | else: 240 | tforms_train = tforms_test 241 | 242 | # Get datasets and data loaders 243 | train_set = datasets.CIFAR100(data_path, train=True, transform=tforms_train, 244 | download=download) 245 | val_test_set = datasets.CIFAR100(data_path, train=False, transform=tforms_test, 246 | download=download) 247 | 248 | train_loader = data_utils.DataLoader(train_set, 249 | batch_size=train_batch_size, 250 | shuffle=True, 251 | num_workers=0) 252 | val_loader, test_loader = val_test_split(val_test_set, 253 | batch_size=batch_size, 254 | val_size=val_size) 255 | 256 | return train_loader, val_loader, test_loader 257 | 258 | 259 | def get_imagenet_loaders(data_path, batch_size=128, val_size=2000, 260 | train_batch_size=128, num_workers=0): 261 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 262 | std=[0.229, 0.224, 0.225]) 263 | 264 | tforms_test = transforms.Compose([ 265 | transforms.Resize(256), 266 | transforms.CenterCrop(224), 267 | transforms.ToTensor(), 268 | normalize, 269 | ]) 270 | tforms_train = transforms.Compose([ 271 | transforms.RandomResizedCrop(224), 272 | transforms.RandomHorizontalFlip(), 273 | transforms.ToTensor(), 274 | normalize, 275 | ]) 276 | 277 | data_path_train = os.path.join(data_path, 'ImageNet2012/train') 278 | data_path_val = os.path.join(data_path, 'ImageNet2012/val') 279 | 280 | # Get datasets and data loaders 281 | train_set = datasets.ImageFolder(data_path_train, transform=tforms_train) 282 | val_test_set = datasets.ImageFolder(data_path_val, transform=tforms_test) 283 | 284 | train_loader = data_utils.DataLoader(train_set, 285 | batch_size=train_batch_size, 286 | shuffle=True, 287 | num_workers=num_workers, 288 | pin_memory=True) 289 | val_loader, test_loader = val_test_split(val_test_set, 290 | batch_size=batch_size, 291 | val_size=val_size, 292 | num_workers=num_workers, 293 | pin_memory=True) 294 | 295 | return train_loader, val_loader, test_loader 296 | 297 | 298 
| def get_mnist_loaders(data_path, batch_size=512, model_class='LeNet', 299 | train_batch_size=128, val_size=2000, download=False, device='cpu'): 300 | if 'MLP' in model_class: 301 | tforms = transforms.Compose([transforms.ToTensor(), ReshapeTransform((-1,))]) 302 | else: 303 | tforms = transforms.ToTensor() 304 | 305 | train_set = datasets.MNIST(data_path, train=True, transform=tforms, 306 | download=download) 307 | val_test_set = datasets.MNIST(data_path, train=False, transform=tforms, 308 | download=download) 309 | 310 | Xys = [train_set[i] for i in range(len(train_set))] 311 | Xs = torch.stack([e[0] for e in Xys]).to(device) 312 | ys = torch.stack([torch.tensor(e[1]) for e in Xys]).to(device) 313 | train_loader = FastTensorDataLoader(Xs, ys, batch_size=train_batch_size, shuffle=True) 314 | val_loader, test_loader = val_test_split(val_test_set, 315 | batch_size=batch_size, 316 | val_size=val_size) 317 | 318 | return train_loader, val_loader, test_loader 319 | 320 | 321 | def get_fmnist_loaders(data_path, batch_size=512, model_class='LeNet', 322 | train_batch_size=128, val_size=2000, download=False, device='cpu'): 323 | if 'MLP' in model_class: 324 | tforms = transforms.Compose([transforms.ToTensor(), ReshapeTransform((-1,))]) 325 | else: 326 | tforms = transforms.ToTensor() 327 | 328 | train_set = datasets.FashionMNIST(data_path, train=True, transform=tforms, 329 | download=download) 330 | val_test_set = datasets.FashionMNIST(data_path, train=False, transform=tforms, 331 | download=download) 332 | 333 | Xys = [train_set[i] for i in range(len(train_set))] 334 | Xs = torch.stack([e[0] for e in Xys]).to(device) 335 | ys = torch.stack([torch.tensor(e[1]) for e in Xys]).to(device) 336 | train_loader = FastTensorDataLoader(Xs, ys, batch_size=train_batch_size, shuffle=True) 337 | val_loader, test_loader = val_test_split(val_test_set, 338 | batch_size=batch_size, 339 | val_size=val_size) 340 | 341 | return train_loader, val_loader, test_loader 342 | 343 | 344 | def get_rotated_mnist_loaders(angle, data_path, model_class='LeNet', download=False): 345 | if 'MLP' in model_class: 346 | shift_tforms = transforms.Compose([RotationTransform(angle), transforms.ToTensor(), 347 | ReshapeTransform((-1,))]) 348 | else: 349 | shift_tforms = transforms.Compose([RotationTransform(angle), transforms.ToTensor()]) 350 | 351 | # Get rotated MNIST val/test sets and loaders 352 | rotated_mnist_val_test_set = datasets.MNIST(data_path, train=False, 353 | transform=shift_tforms, 354 | download=download) 355 | shift_val_loader, shift_test_loader = val_test_split(rotated_mnist_val_test_set, 356 | val_size=2000) 357 | 358 | return shift_val_loader, shift_test_loader 359 | 360 | 361 | def get_rotated_fmnist_loaders(angle, data_path, model_class='LeNet', download=False): 362 | if 'MLP' in model_class: 363 | shift_tforms = transforms.Compose([RotationTransform(angle), transforms.ToTensor(), 364 | ReshapeTransform((-1,))]) 365 | else: 366 | shift_tforms = transforms.Compose([RotationTransform(angle), transforms.ToTensor()]) 367 | 368 | # Get rotated FMNIST val/test sets and loaders 369 | rotated_fmnist_val_test_set = datasets.FashionMNIST(data_path, train=False, 370 | transform=shift_tforms, 371 | download=download) 372 | shift_val_loader, shift_test_loader = val_test_split(rotated_fmnist_val_test_set, 373 | val_size=2000) 374 | 375 | return shift_val_loader, shift_test_loader 376 | 377 | 378 | # https://discuss.pytorch.org/t/missing-reshape-in-torchvision/9452/6 379 | class ReshapeTransform: 380 | def __init__(self, 
new_size): 381 | self.new_size = new_size 382 | 383 | def __call__(self, img): 384 | return torch.reshape(img, self.new_size) 385 | 386 | 387 | class RotationTransform: 388 | """Rotate the given angle.""" 389 | def __init__(self, angle): 390 | self.angle = angle 391 | 392 | def __call__(self, x): 393 | return TF.rotate(x, self.angle) 394 | 395 | 396 | def uniform_noise(dataset, delta=1, size=5000, batch_size=512): 397 | if dataset in ['MNIST', 'FMNIST', 'R-MNIST']: 398 | shape = (1, 28, 28) 399 | elif dataset in ['SVHN', 'CIFAR10', 'CIFAR100', 'CIFAR-10-C']: 400 | shape = (3, 32, 32) 401 | elif dataset in ['ImageNet', 'ImageNet-C']: 402 | shape = (3, 256, 256) 403 | 404 | # data = torch.rand((100*batch_size,) + shape) 405 | data = delta * torch.rand((size,) + shape) 406 | train = data_utils.TensorDataset(data, torch.zeros_like(data)) 407 | loader = torch.utils.data.DataLoader(train, batch_size=batch_size, 408 | shuffle=False, num_workers=0) 409 | return loader 410 | 411 | 412 | class DatafeedImage(torch.utils.data.Dataset): 413 | def __init__(self, x_train, y_train, transform=None): 414 | self.x_train = x_train 415 | self.y_train = y_train 416 | self.transform = transform 417 | 418 | def __getitem__(self, index): 419 | img = self.x_train[index] 420 | img = Image.fromarray(img) 421 | if self.transform is not None: 422 | img = self.transform(img) 423 | return img, self.y_train[index] 424 | 425 | def __len__(self): 426 | return len(self.x_train) 427 | 428 | 429 | def load_corrupted_cifar10(severity, data_dir='data', batch_size=256, cuda=True, 430 | workers=0, normalize=True): 431 | """ load corrupted CIFAR10 dataset """ 432 | 433 | x_file = data_dir + '/CIFAR-10-C/CIFAR10_c%d.npy' % severity 434 | np_x = np.load(x_file) 435 | y_file = data_dir + '/CIFAR-10-C/CIFAR10_c_labels.npy' 436 | np_y = np.load(y_file).astype(np.int64) 437 | 438 | transform = [transforms.ToTensor()] 439 | if normalize: 440 | transform.append( 441 | transforms.Normalize((0.4914, 0.4822, 0.4465), 442 | (0.2470, 0.2435, 0.2616)) 443 | ) 444 | transform = transforms.Compose(transform) 445 | dataset = DatafeedImage(np_x, np_y, transform) 446 | dataset = data_utils.Subset(dataset, torch.randint(len(dataset), (10000,))) 447 | 448 | loader = torch.utils.data.DataLoader( 449 | dataset, 450 | batch_size=batch_size, shuffle=False, 451 | num_workers=workers, pin_memory=cuda) 452 | 453 | return loader 454 | 455 | 456 | def load_corrupted_imagenet(severity, data_dir='data', batch_size=128, cuda=True, workers=0): 457 | """ load corrupted ImageNet dataset """ 458 | 459 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 460 | std=[0.229, 0.224, 0.225]) 461 | transform = transforms.Compose([ 462 | transforms.Resize(256), 463 | transforms.CenterCrop(224), 464 | transforms.ToTensor(), 465 | normalize, 466 | ]) 467 | 468 | corruption_types = ['brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog', 469 | 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise', 470 | 'jpeg_compression', 'motion_blur', 'pixelate', 'saturate', 'shot_noise', 471 | 'snow', 'spatter', 'speckle_noise', 'zoom_blur'] 472 | 473 | dsets = list() 474 | for c in corruption_types: 475 | path = os.path.join(data_dir, 'ImageNet-C/' + c + '/' + str(severity)) 476 | dsets.append(datasets.ImageFolder(path, 477 | transform=transform)) 478 | dataset = data_utils.ConcatDataset(dsets) 479 | 480 | loader = torch.utils.data.DataLoader( 481 | dataset, 482 | batch_size=batch_size, shuffle=False, 483 | num_workers=workers, pin_memory=cuda) 
484 | 485 | return loader 486 | 487 | 488 | def get_mnist_ood_loaders(ood_dataset, data_path='./data', model_class='LeNet', batch_size=512, download=False): 489 | '''Get out-of-distribution val/test sets and val/test loaders (in-distribution: MNIST/FMNIST)''' 490 | if 'MLP' in model_class: 491 | tforms = transforms.Compose([transforms.ToTensor(), ReshapeTransform((-1,))]) 492 | else: 493 | tforms = transforms.ToTensor() 494 | if ood_dataset == 'FMNIST': 495 | fmnist_val_test_set = datasets.FashionMNIST(data_path, train=False, 496 | transform=tforms, 497 | download=download) 498 | val_loader, test_loader = val_test_split(fmnist_val_test_set, 499 | batch_size=batch_size, 500 | val_size=0) 501 | elif ood_dataset == 'EMNIST': 502 | emnist_val_test_set = datasets.EMNIST(data_path, split='digits', train=False, 503 | transform=tforms, 504 | download=download) 505 | val_loader, test_loader = val_test_split(emnist_val_test_set, 506 | batch_size=batch_size, 507 | val_size=0) 508 | elif ood_dataset == 'KMNIST': 509 | kmnist_val_test_set = datasets.KMNIST(data_path, train=False, 510 | transform=tforms, 511 | download=download) 512 | val_loader, test_loader = val_test_split(kmnist_val_test_set, 513 | batch_size=batch_size, 514 | val_size=0) 515 | elif ood_dataset == 'MNIST': 516 | mnist_val_test_set = datasets.MNIST(data_path, train=False, 517 | transform=tforms, 518 | download=download) 519 | val_loader, test_loader = val_test_split(mnist_val_test_set, 520 | batch_size=batch_size, 521 | val_size=0) 522 | else: 523 | raise ValueError('Choose one out of FMNIST, EMNIST, MNIST, and KMNIST.') 524 | return val_loader, test_loader 525 | 526 | 527 | def get_cifar10_ood_loaders(ood_dataset, data_path='./data', batch_size=512, normalize=True, download=False): 528 | '''Get out-of-distribution val/test sets and val/test loaders (in-distribution: CIFAR-10)''' 529 | if ood_dataset == 'SVHN': 530 | tforms = [transforms.ToTensor()] 531 | if normalize: 532 | tforms.append(transforms.Normalize((0.4376821, 0.4437697, 0.47280442), 533 | (0.19803012, 0.20101562, 0.19703614))) 534 | svhn_tforms = transforms.Compose(tforms) 535 | svhn_val_test_set = datasets.SVHN(data_path, split='test', 536 | transform=svhn_tforms, 537 | download=download) 538 | val_loader, test_loader = val_test_split(svhn_val_test_set, 539 | batch_size=batch_size, 540 | val_size=0) 541 | elif ood_dataset == 'LSUN': 542 | lsun_tforms = transforms.Compose([transforms.Resize(size=(32, 32)), 543 | transforms.ToTensor()]) 544 | lsun_test_set = datasets.LSUN(data_path, classes=['classroom_val'], # classes='test' 545 | transform=lsun_tforms) 546 | val_loader = None 547 | test_loader = data_utils.DataLoader(lsun_test_set, batch_size=batch_size, 548 | shuffle=False) 549 | elif ood_dataset == 'CIFAR-100': 550 | tforms = [transforms.ToTensor()] 551 | if normalize: 552 | tforms.append(transforms.Normalize((0.4914, 0.4822, 0.4465), 553 | (0.2023, 0.1994, 0.2010))) 554 | cifar100_tforms = transforms.Compose(tforms) 555 | cifar100_val_test_set = datasets.CIFAR100(data_path, train=False, 556 | transform=cifar100_tforms, 557 | download=download) 558 | val_loader, test_loader = val_test_split(cifar100_val_test_set, 559 | batch_size=batch_size, 560 | val_size=0) 561 | else: 562 | raise ValueError('Choose one out of SVHN, LSUN, and CIFAR-100.') 563 | return val_loader, test_loader 564 | 565 | 566 | def get_cifar100_ood_loaders(ood_dataset, data_path='./data', batch_size=512, normalize=True, download=False): 567 | '''Get out-of-distribution val/test sets and val/test loaders 
(in-distribution: CIFAR-100)''' 568 | if ood_dataset == 'SVHN': 569 | tforms = [transforms.ToTensor()] 570 | if normalize: 571 | tforms.append(transforms.Normalize((0.4376821, 0.4437697, 0.47280442), 572 | (0.19803012, 0.20101562, 0.19703614))) 573 | svhn_tforms = transforms.Compose(tforms) 574 | svhn_val_test_set = datasets.SVHN(data_path, split='test', 575 | transform=svhn_tforms, 576 | download=download) 577 | val_loader, test_loader = val_test_split(svhn_val_test_set, 578 | batch_size=batch_size, 579 | val_size=0) 580 | elif ood_dataset == 'LSUN': 581 | lsun_tforms = transforms.Compose([transforms.Resize(size=(32, 32)), 582 | transforms.ToTensor()]) 583 | lsun_test_set = datasets.LSUN(data_path, classes=['classroom_val'], # classes='test' 584 | transform=lsun_tforms) 585 | val_loader = None 586 | test_loader = data_utils.DataLoader(lsun_test_set, batch_size=batch_size, 587 | shuffle=False) 588 | elif ood_dataset == 'CIFAR-10': 589 | tforms = [transforms.ToTensor()] 590 | if normalize: 591 | tforms.append(transforms.Normalize((0.4914, 0.4822, 0.4465), 592 | (0.2470, 0.2435, 0.2616))) 593 | cifar10_tforms = transforms.Compose(tforms) 594 | cifar10_val_test_set = datasets.CIFAR10(data_path, train=False, 595 | transform=cifar10_tforms, 596 | download=download) 597 | val_loader, test_loader = val_test_split(cifar10_val_test_set, 598 | batch_size=batch_size, 599 | val_size=0) 600 | else: 601 | raise ValueError('Choose one out of SVHN, LSUN, and CIFAR-10.') 602 | return val_loader, test_loader 603 | 604 | 605 | class FastTensorDataLoader: 606 | """ 607 | Source: https://github.com/hcarlens/pytorch-tabular/blob/master/fast_tensor_data_loader.py 608 | and https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 609 | """ 610 | def __init__(self, *tensors, batch_size=32, shuffle=False): 611 | assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) 612 | self.tensors = tensors 613 | self.dataset = tensors[0] 614 | 615 | self.dataset_len = self.tensors[0].shape[0] 616 | self.batch_size = batch_size 617 | self.shuffle = shuffle 618 | 619 | # Calculate # batches 620 | n_batches, remainder = divmod(self.dataset_len, self.batch_size) 621 | if remainder > 0: 622 | n_batches += 1 623 | self.n_batches = n_batches 624 | 625 | def __iter__(self): 626 | if self.shuffle: 627 | r = torch.randperm(self.dataset_len) 628 | self.tensors = [t[r] for t in self.tensors] 629 | self.i = 0 630 | return self 631 | 632 | def __next__(self): 633 | if self.i >= self.dataset_len: 634 | raise StopIteration 635 | batch = tuple(t[self.i:self.i+self.batch_size] for t in self.tensors) 636 | self.i += self.batch_size 637 | return batch 638 | 639 | def __len__(self): 640 | return self.n_batches 641 | --------------------------------------------------------------------------------
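For reference, here is a minimal usage sketch (not part of the repository) of `FastTensorDataLoader` from `utils/data_utils.py`, which slices batches directly from pre-loaded tensors instead of going through a standard `DataLoader`. The random tensors below are placeholders standing in for a real dataset such as the stacked MNIST tensors built in `get_mnist_loaders`; the snippet assumes it is run from the repository root so that `utils.data_utils` is importable.

```python
import torch

from utils.data_utils import FastTensorDataLoader

# Dummy data standing in for, e.g., the stacked (F)MNIST tensors in get_mnist_loaders.
Xs = torch.rand(1000, 1, 28, 28)
ys = torch.randint(0, 10, (1000,))

# Batches are sliced directly from the in-memory tensors, avoiding per-item
# Dataset indexing and collation overhead.
train_loader = FastTensorDataLoader(Xs, ys, batch_size=128, shuffle=True)

for X_batch, y_batch in train_loader:
    pass  # forward/backward pass would go here
```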