├── .gitignore
├── LICENSE
├── README.md
├── de_vs_duq.png
├── environment.yml
├── train_deep_ensemble.py
├── train_duq_cifar.py
├── train_duq_fm.py
├── two_moons.ipynb
├── two_moons_ensemble.ipynb
└── utils
    ├── __init__.py
    ├── cnn_duq.py
    ├── datasets.py
    ├── evaluate_ood.py
    ├── resnet_duq.py
    └── wide_resnet.py

/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | runs
3 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Joost van Amersfoort
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deterministic Uncertainty Quantification (DUQ)
2 | 
3 | This repo contains the code for [*Uncertainty Estimation Using a Single Deep Deterministic Neural Network*](https://arxiv.org/abs/2003.02037), which was accepted for publication at ICML 2020.
4 | 
5 | If the code or the paper has been useful in your research, please add a citation to our work:
6 | 
7 | ```
8 | @inproceedings{van2020uncertainty,
9 |   title={Uncertainty Estimation Using a Single Deep Deterministic Neural Network},
10 |   author={van Amersfoort, Joost and Smith, Lewis and Teh, Yee Whye and Gal, Yarin},
11 |   booktitle={International Conference on Machine Learning},
12 |   year={2020}
13 | }
14 | ```
15 | 
16 | ## Dependencies
17 | 
18 | The code is based on PyTorch and requires a few further dependencies, listed in [environment.yml](environment.yml).
19 | The code was tested with the versions specified in the environment file and should also work with newer versions (with the exception of ignite=0.4.3, which is incompatible).
20 | If you find an incompatibility, please let me know and I'll gladly update the code for the newest version of each library.
21 | 
22 | ### Datasets
23 | 
24 | Most datasets will be downloaded on the fly by Torchvision. Only NotMNIST needs to be downloaded in advance, into a subfolder called `data/`:
25 | 
26 | ```
27 | mkdir -p data && cd data && curl -O "http://yaroslavvb.com/upload/notMNIST/notMNIST_small.mat"
28 | ```
29 | 
30 | FastFashionMNIST is based on [this code](https://github.com/y0ast/pytorch-snippets/tree/main/fast_mnist).
31 | The default Torchvision implementation first creates a PIL image (see [here](https://github.com/pytorch/vision/blob/v0.6.1/torchvision/datasets/mnist.py#L94)), which creates a CPU bottleneck while training on the GPU.
32 | The FastFashionMNIST class provides a significant speed-up.
33 | 
34 | ## Running
35 | 
36 | The Two Moons experiments can be replicated using the [Two Moons notebook](two_moons.ipynb).
37 | The FashionMNIST experiment is implemented in [train\_duq\_fm.py](train_duq_fm.py).
38 | For both experiments, the paper's defaults are hardcoded and can be changed in place.
39 | 
40 | The ResNet18-based CIFAR experiments are implemented in [train\_duq\_cifar.py](train_duq_cifar.py).
41 | All command line parameter defaults are as listed in the experimental details in Appendix A of the paper.
42 | We additionally include a Wide ResNet-based architecture.
43 | 
44 | For example, to train on CIFAR-10 with a gradient penalty of weight 0.5 on the full training set:
45 | 
46 | ```
47 | python train_duq_cifar.py --final_model --l_gradient_penalty 0.5
48 | ```
49 | 
50 | Note that omitting `--final_model` will cause 20% of the training data to be held out for validation, so that hyperparameter selection can be done in a responsible manner.
51 | The code also supports the Wide ResNet with `--architecture WRN`.
52 | 
53 | I also include code for my implementation of Deep Ensembles.
54 | It's a very simple implementation that achieves good results (95% accuracy in 75 epochs using 5 models).
55 | 
56 | ```
57 | python train_deep_ensemble.py --dataset CIFAR10
58 | ```
59 | 
60 | This command will train a Deep Ensemble consisting of 5 models (the default) on CIFAR10.
61 | 
62 | ## Questions
63 | 
64 | For questions about the code or the paper, feel free to open an issue or email me directly.
65 | My email can be found on my GitHub profile, my [website](https://joo.st) and the paper above.
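For downstream use, the snippet below is a minimal sketch (it is not part of the training scripts; `model` and `x` stand for a trained DUQ network and a batch of inputs) showing how a DUQ output turns into a prediction and an uncertainty score, following the convention in [utils/evaluate\_ood.py](utils/evaluate_ood.py), where the negated maximum kernel distance serves as the OoD score:

```
import torch

model.eval()  # a trained CNN_DUQ or ResNet_DUQ instance (assumed given)
with torch.no_grad():
    output = model(x)  # (batch, num_classes) RBF kernel distances in [0, 1]
    kernel_distance, pred = output.max(1)  # per-sample confidence and class
    uncertainty = -kernel_distance  # higher means more out-of-distribution
```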
66 | 67 | 68 | ![Deep Ensembles vs DUQ](de_vs_duq.png) 69 | -------------------------------------------------------------------------------- /de_vs_duq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/y0ast/deterministic-uncertainty-quantification/3a659230a1583cc9977ff65743a1c209e7af4bed/de_vs_duq.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: duq 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python>=3.8 7 | - pytorch=1.8.1 8 | - torchvision=0.9.1 9 | - cudatoolkit=10.2 10 | - ignite=0.4.4 11 | - tqdm 12 | - tensorboard 13 | - numpy 14 | - scipy 15 | - matplotlib 16 | - seaborn 17 | - scikit-learn 18 | -------------------------------------------------------------------------------- /train_deep_ensemble.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from tqdm import tqdm 9 | 10 | from utils.datasets import all_datasets 11 | from utils.cnn_duq import SoftmaxModel as CNN 12 | from torchvision.models import resnet18 13 | 14 | 15 | class ResNet(nn.Module): 16 | def __init__(self, input_size, num_classes): 17 | super().__init__() 18 | 19 | self.resnet = resnet18(pretrained=False, num_classes=num_classes) 20 | 21 | # Adapted resnet from: 22 | # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 23 | self.resnet.conv1 = nn.Conv2d( 24 | 3, 64, kernel_size=3, stride=1, padding=1, bias=False 25 | ) 26 | self.resnet.maxpool = nn.Identity() 27 | 28 | def forward(self, x): 29 | x = self.resnet(x) 30 | x = F.log_softmax(x, dim=1) 31 | 32 | return x 33 | 34 | 35 | def train(model, train_loader, optimizer, epoch, loss_fn): 36 | model.train() 37 | 38 | total_loss = [] 39 | 40 | for batch_idx, (data, target) in enumerate(tqdm(train_loader)): 41 | data = data.cuda() 42 | target = target.cuda() 43 | 44 | optimizer.zero_grad() 45 | 46 | prediction = model(data) 47 | loss = loss_fn(prediction, target) 48 | 49 | loss.backward() 50 | optimizer.step() 51 | 52 | total_loss.append(loss.item()) 53 | 54 | avg_loss = torch.tensor(total_loss).mean() 55 | print(f"Epoch: {epoch}:") 56 | print(f"Train Set: Average Loss: {avg_loss:.2f}") 57 | 58 | 59 | def test(models, test_loader, loss_fn): 60 | models.eval() 61 | 62 | loss = 0 63 | correct = 0 64 | 65 | for data, target in test_loader: 66 | with torch.no_grad(): 67 | data = data.cuda() 68 | target = target.cuda() 69 | 70 | losses = torch.empty(len(models), data.shape[0]) 71 | predictions = [] 72 | for i, model in enumerate(models): 73 | predictions.append(model(data)) 74 | losses[i, :] = loss_fn(predictions[i], target, reduction="sum") 75 | 76 | predictions = torch.stack(predictions) 77 | 78 | loss += torch.mean(losses) 79 | avg_prediction = predictions.exp().mean(0) 80 | 81 | # get the index of the max log-probability 82 | class_prediction = avg_prediction.max(1)[1] 83 | correct += ( 84 | class_prediction.eq(target.view_as(class_prediction)).sum().item() 85 | ) 86 | 87 | loss /= len(test_loader.dataset) 88 | 89 | percentage_correct = 100.0 * correct / len(test_loader.dataset) 90 | 91 | print( 92 | "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)".format( 93 | loss, correct, len(test_loader.dataset), percentage_correct 94 | ) 95 | ) 96 | 97 | 
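    # Note: the loss returned below is the per-sample NLL averaged over both
    # the test set and the ensemble members, and accuracy is computed from the
    # ensemble's averaged probabilities (exp of the per-model log-softmax).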
return loss, percentage_correct 98 | 99 | 100 | def main(): 101 | parser = argparse.ArgumentParser() 102 | parser.add_argument( 103 | "--epochs", type=int, default=75, help="number of epochs to train (default: 75)" 104 | ) 105 | parser.add_argument( 106 | "--lr", type=float, default=0.05, help="learning rate (default: 0.05)" 107 | ) 108 | parser.add_argument( 109 | "--ensemble", type=int, default=5, help="Ensemble size (default: 5)" 110 | ) 111 | parser.add_argument( 112 | "--dataset", 113 | required=True, 114 | choices=["FashionMNIST", "CIFAR10"], 115 | help="Select a dataset", 116 | ) 117 | parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)") 118 | args = parser.parse_args() 119 | print(args) 120 | 121 | torch.manual_seed(args.seed) 122 | 123 | loss_fn = F.nll_loss 124 | 125 | ds = all_datasets[args.dataset]() 126 | input_size, num_classes, train_dataset, test_dataset = ds 127 | 128 | kwargs = {"num_workers": 4, "pin_memory": True} 129 | 130 | train_loader = torch.utils.data.DataLoader( 131 | train_dataset, batch_size=128, shuffle=True, **kwargs 132 | ) 133 | test_loader = torch.utils.data.DataLoader( 134 | test_dataset, batch_size=5000, shuffle=False, **kwargs 135 | ) 136 | 137 | if args.dataset == "FashionMNIST": 138 | milestones = [10, 20] 139 | ensemble = [CNN(input_size, num_classes).cuda() for _ in range(args.ensemble)] 140 | else: 141 | # CIFAR-10 142 | milestones = [25, 50] 143 | ensemble = [ 144 | ResNet(input_size, num_classes).cuda() for _ in range(args.ensemble) 145 | ] 146 | 147 | ensemble = torch.nn.ModuleList(ensemble) 148 | 149 | optimizers = [] 150 | schedulers = [] 151 | 152 | for model in ensemble: 153 | # Need different optimisers to apply weight decay and momentum properly 154 | # when only optimising one element of the ensemble 155 | optimizers.append( 156 | torch.optim.SGD( 157 | model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4 158 | ) 159 | ) 160 | 161 | schedulers.append( 162 | torch.optim.lr_scheduler.MultiStepLR( 163 | optimizers[-1], milestones=milestones, gamma=0.1 164 | ) 165 | ) 166 | 167 | for epoch in range(1, args.epochs + 1): 168 | for i, model in enumerate(ensemble): 169 | train(model, train_loader, optimizers[i], epoch, loss_fn) 170 | schedulers[i].step() 171 | 172 | test(ensemble, test_loader, loss_fn) 173 | 174 | pathlib.Path("saved_models").mkdir(exist_ok=True) 175 | path = f"saved_models/{args.dataset}_{len(ensemble)}" 176 | torch.save(ensemble.state_dict(), path + "_ensemble.pt") 177 | 178 | 179 | if __name__ == "__main__": 180 | main() 181 | -------------------------------------------------------------------------------- /train_duq_cifar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pathlib 4 | import random 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.utils.data 9 | from torch.utils.tensorboard.writer import SummaryWriter 10 | 11 | from torchvision.models import resnet18 12 | 13 | from ignite.engine import Events, Engine 14 | from ignite.metrics import Accuracy, Average, Loss 15 | from ignite.contrib.handlers import ProgressBar 16 | 17 | from utils.wide_resnet import WideResNet 18 | from utils.resnet_duq import ResNet_DUQ 19 | from utils.datasets import all_datasets 20 | from utils.evaluate_ood import get_cifar_svhn_ood, get_auroc_classification 21 | 22 | 23 | def main( 24 | architecture, 25 | batch_size, 26 | length_scale, 27 | centroid_size, 28 | learning_rate, 29 | 
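    # The parameters below mirror the command-line flags parsed at the bottom
    # of this file; l_gradient_penalty weights the two-sided gradient penalty,
    # and setting it to 0 skips the penalty term in the training step.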
l_gradient_penalty, 30 | gamma, 31 | weight_decay, 32 | final_model, 33 | output_dir, 34 | ): 35 | writer = SummaryWriter(log_dir=f"runs/{output_dir}") 36 | 37 | ds = all_datasets["CIFAR10"]() 38 | input_size, num_classes, dataset, test_dataset = ds 39 | 40 | # Split up training set 41 | idx = list(range(len(dataset))) 42 | random.shuffle(idx) 43 | 44 | if final_model: 45 | train_dataset = dataset 46 | val_dataset = test_dataset 47 | else: 48 | val_size = int(len(dataset) * 0.8) 49 | train_dataset = torch.utils.data.Subset(dataset, idx[:val_size]) 50 | val_dataset = torch.utils.data.Subset(dataset, idx[val_size:]) 51 | 52 | val_dataset.transform = ( 53 | test_dataset.transform 54 | ) # Test time preprocessing for validation 55 | 56 | if architecture == "WRN": 57 | model_output_size = 640 58 | epochs = 200 59 | milestones = [60, 120, 160] 60 | feature_extractor = WideResNet() 61 | elif architecture == "ResNet18": 62 | model_output_size = 512 63 | epochs = 100 64 | milestones = [25, 50, 75] 65 | feature_extractor = resnet18() 66 | 67 | # Adapted resnet from: 68 | # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 69 | feature_extractor.conv1 = torch.nn.Conv2d( 70 | 3, 64, kernel_size=3, stride=1, padding=1, bias=False 71 | ) 72 | feature_extractor.maxpool = torch.nn.Identity() 73 | feature_extractor.fc = torch.nn.Identity() 74 | 75 | if centroid_size is None: 76 | centroid_size = model_output_size 77 | 78 | model = ResNet_DUQ( 79 | feature_extractor, 80 | num_classes, 81 | centroid_size, 82 | model_output_size, 83 | length_scale, 84 | gamma, 85 | ) 86 | model = model.cuda() 87 | 88 | optimizer = torch.optim.SGD( 89 | model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay 90 | ) 91 | 92 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 93 | optimizer, milestones=milestones, gamma=0.2 94 | ) 95 | 96 | def calc_gradients_input(x, y_pred): 97 | gradients = torch.autograd.grad( 98 | outputs=y_pred, 99 | inputs=x, 100 | grad_outputs=torch.ones_like(y_pred), 101 | create_graph=True, 102 | )[0] 103 | 104 | gradients = gradients.flatten(start_dim=1) 105 | 106 | return gradients 107 | 108 | def calc_gradient_penalty(x, y_pred): 109 | gradients = calc_gradients_input(x, y_pred) 110 | 111 | # L2 norm 112 | grad_norm = gradients.norm(2, dim=1) 113 | 114 | # Two sided penalty 115 | gradient_penalty = ((grad_norm - 1) ** 2).mean() 116 | 117 | return gradient_penalty 118 | 119 | def step(engine, batch): 120 | model.train() 121 | 122 | optimizer.zero_grad() 123 | 124 | x, y = batch 125 | x, y = x.cuda(), y.cuda() 126 | 127 | x.requires_grad_(True) 128 | 129 | y_pred = model(x) 130 | 131 | y = F.one_hot(y, num_classes).float() 132 | 133 | loss = F.binary_cross_entropy(y_pred, y, reduction="mean") 134 | 135 | if l_gradient_penalty > 0: 136 | gp = calc_gradient_penalty(x, y_pred) 137 | loss += l_gradient_penalty * gp 138 | 139 | loss.backward() 140 | optimizer.step() 141 | 142 | x.requires_grad_(False) 143 | 144 | with torch.no_grad(): 145 | model.eval() 146 | model.update_embeddings(x, y) 147 | 148 | return loss.item() 149 | 150 | def eval_step(engine, batch): 151 | model.eval() 152 | 153 | x, y = batch 154 | x, y = x.cuda(), y.cuda() 155 | 156 | x.requires_grad_(True) 157 | 158 | y_pred = model(x) 159 | 160 | return {"x": x, "y": y, "y_pred": y_pred} 161 | 162 | trainer = Engine(step) 163 | evaluator = Engine(eval_step) 164 | 165 | metric = Average() 166 | metric.attach(trainer, "loss") 167 | 168 | metric = Accuracy(output_transform=lambda out: (out["y_pred"], 
out["y"])) 169 | metric.attach(evaluator, "accuracy") 170 | 171 | def bce_output_transform(out): 172 | return (out["y_pred"], F.one_hot(out["y"], num_classes).float()) 173 | 174 | metric = Loss(F.binary_cross_entropy, output_transform=bce_output_transform) 175 | metric.attach(evaluator, "bce") 176 | 177 | metric = Loss( 178 | calc_gradient_penalty, output_transform=lambda out: (out["x"], out["y_pred"]) 179 | ) 180 | metric.attach(evaluator, "gradient_penalty") 181 | 182 | pbar = ProgressBar(dynamic_ncols=True) 183 | pbar.attach(trainer) 184 | 185 | kwargs = {"num_workers": 4, "pin_memory": True} 186 | 187 | train_loader = torch.utils.data.DataLoader( 188 | train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs 189 | ) 190 | 191 | val_loader = torch.utils.data.DataLoader( 192 | val_dataset, batch_size=batch_size, shuffle=False, **kwargs 193 | ) 194 | 195 | test_loader = torch.utils.data.DataLoader( 196 | test_dataset, batch_size=batch_size, shuffle=False, **kwargs 197 | ) 198 | 199 | @trainer.on(Events.EPOCH_COMPLETED) 200 | def log_results(trainer): 201 | metrics = trainer.state.metrics 202 | loss = metrics["loss"] 203 | 204 | print(f"Train - Epoch: {trainer.state.epoch} Loss: {loss:.2f}") 205 | 206 | writer.add_scalar("Loss/train", loss, trainer.state.epoch) 207 | 208 | if trainer.state.epoch > (epochs - 5): 209 | accuracy, auroc = get_cifar_svhn_ood(model) 210 | print(f"Test Accuracy: {accuracy}, AUROC: {auroc}") 211 | writer.add_scalar("OoD/test_accuracy", accuracy, trainer.state.epoch) 212 | writer.add_scalar("OoD/roc_auc", auroc, trainer.state.epoch) 213 | 214 | accuracy, auroc = get_auroc_classification(val_dataset, model) 215 | print(f"AUROC - uncertainty: {auroc}") 216 | writer.add_scalar("OoD/val_accuracy", accuracy, trainer.state.epoch) 217 | writer.add_scalar("OoD/roc_auc_classification", auroc, trainer.state.epoch) 218 | 219 | evaluator.run(val_loader) 220 | metrics = evaluator.state.metrics 221 | acc = metrics["accuracy"] 222 | bce = metrics["bce"] 223 | GP = metrics["gradient_penalty"] 224 | loss = bce + l_gradient_penalty * GP 225 | 226 | print( 227 | ( 228 | f"Valid - Epoch: {trainer.state.epoch} " 229 | f"Acc: {acc:.4f} " 230 | f"Loss: {loss:.2f} " 231 | f"BCE: {bce:.2f} " 232 | f"GP: {GP:.2f} " 233 | ) 234 | ) 235 | 236 | writer.add_scalar("Loss/valid", loss, trainer.state.epoch) 237 | writer.add_scalar("BCE/valid", bce, trainer.state.epoch) 238 | writer.add_scalar("GP/valid", GP, trainer.state.epoch) 239 | writer.add_scalar("Accuracy/valid", acc, trainer.state.epoch) 240 | 241 | scheduler.step() 242 | 243 | trainer.run(train_loader, max_epochs=epochs) 244 | evaluator.run(test_loader) 245 | acc = evaluator.state.metrics["accuracy"] 246 | 247 | print(f"Test - Accuracy {acc:.4f}") 248 | 249 | torch.save(model.state_dict(), f"runs/{output_dir}/model.pt") 250 | writer.close() 251 | 252 | 253 | if __name__ == "__main__": 254 | parser = argparse.ArgumentParser() 255 | 256 | parser.add_argument( 257 | "--architecture", 258 | default="ResNet18", 259 | choices=["ResNet18", "WRN"], 260 | help="Pick an architecture (default: ResNet18)", 261 | ) 262 | 263 | parser.add_argument( 264 | "--batch_size", 265 | type=int, 266 | default=128, 267 | help="Batch size to use for training (default: 128)", 268 | ) 269 | 270 | parser.add_argument( 271 | "--centroid_size", 272 | type=int, 273 | default=None, 274 | help="Size to use for centroids (default: same as model output)", 275 | ) 276 | 277 | parser.add_argument( 278 | "--learning_rate", 279 | type=float, 280 | 
default=0.05, 281 | help="Learning rate (default: 0.05)", 282 | ) 283 | 284 | parser.add_argument( 285 | "--l_gradient_penalty", 286 | type=float, 287 | default=0.75, 288 | help="Weight for gradient penalty (default: 0.75)", 289 | ) 290 | 291 | parser.add_argument( 292 | "--gamma", 293 | type=float, 294 | default=0.999, 295 | help="Decay factor for exponential average (default: 0.999)", 296 | ) 297 | 298 | parser.add_argument( 299 | "--length_scale", 300 | type=float, 301 | default=0.1, 302 | help="Length scale of RBF kernel (default: 0.1)", 303 | ) 304 | 305 | parser.add_argument( 306 | "--weight_decay", type=float, default=5e-4, help="Weight decay (default: 5e-4)" 307 | ) 308 | 309 | parser.add_argument( 310 | "--output_dir", type=str, default="results", help="set output folder" 311 | ) 312 | 313 | # Below setting cannot be used for model selection, 314 | # because the validation set equals the test set. 315 | parser.add_argument( 316 | "--final_model", 317 | action="store_true", 318 | default=False, 319 | help="Use entire training set for final model", 320 | ) 321 | 322 | args = parser.parse_args() 323 | kwargs = vars(args) 324 | print("input args:\n", json.dumps(kwargs, indent=4, separators=(",", ":"))) 325 | 326 | pathlib.Path("runs/" + args.output_dir).mkdir(exist_ok=True) 327 | 328 | main(**kwargs) 329 | -------------------------------------------------------------------------------- /train_duq_fm.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | import torch 5 | import torch.utils.data 6 | from torch.nn import functional as F 7 | 8 | from ignite.engine import Events, Engine 9 | from ignite.metrics import Accuracy, Loss 10 | 11 | from ignite.contrib.handlers.tqdm_logger import ProgressBar 12 | 13 | from utils.evaluate_ood import ( 14 | get_fashionmnist_mnist_ood, 15 | get_fashionmnist_notmnist_ood, 16 | ) 17 | from utils.datasets import FastFashionMNIST, get_FashionMNIST 18 | from utils.cnn_duq import CNN_DUQ 19 | 20 | 21 | def train_model(l_gradient_penalty, length_scale, final_model): 22 | dataset = FastFashionMNIST("data/", train=True, download=True) 23 | test_dataset = FastFashionMNIST("data/", train=False, download=True) 24 | 25 | idx = list(range(60000)) 26 | random.shuffle(idx) 27 | 28 | if final_model: 29 | train_dataset = dataset 30 | val_dataset = test_dataset 31 | else: 32 | train_dataset = torch.utils.data.Subset(dataset, indices=idx[:55000]) 33 | val_dataset = torch.utils.data.Subset(dataset, indices=idx[55000:]) 34 | 35 | num_classes = 10 36 | embedding_size = 256 37 | learnable_length_scale = False 38 | gamma = 0.999 39 | 40 | model = CNN_DUQ( 41 | num_classes, 42 | embedding_size, 43 | learnable_length_scale, 44 | length_scale, 45 | gamma, 46 | ) 47 | model = model.cuda() 48 | 49 | optimizer = torch.optim.SGD( 50 | model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4 51 | ) 52 | 53 | def output_transform_bce(output): 54 | y_pred, y, _, _ = output 55 | return y_pred, y 56 | 57 | def output_transform_acc(output): 58 | y_pred, y, _, _ = output 59 | return y_pred, torch.argmax(y, dim=1) 60 | 61 | def output_transform_gp(output): 62 | y_pred, y, x, y_pred_sum = output 63 | return x, y_pred_sum 64 | 65 | def calc_gradient_penalty(x, y_pred_sum): 66 | gradients = torch.autograd.grad( 67 | outputs=y_pred_sum, 68 | inputs=x, 69 | grad_outputs=torch.ones_like(y_pred_sum), 70 | create_graph=True, 71 | retain_graph=True, 72 | )[0] 73 | 74 | gradients = gradients.flatten(start_dim=1) 75 | 
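        # The penalty below regularizes the input gradient of y_pred_sum, the
        # summed per-class kernel distances: pushing its norm toward 1 keeps
        # the model sensitive to input changes while remaining smooth.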
76 | # L2 norm 77 | grad_norm = gradients.norm(2, dim=1) 78 | 79 | # Two sided penalty 80 | gradient_penalty = ((grad_norm - 1) ** 2).mean() 81 | 82 | return gradient_penalty 83 | 84 | def step(engine, batch): 85 | model.train() 86 | optimizer.zero_grad() 87 | 88 | x, y = batch 89 | y = F.one_hot(y, num_classes=10).float() 90 | 91 | x, y = x.cuda(), y.cuda() 92 | 93 | x.requires_grad_(True) 94 | 95 | y_pred = model(x) 96 | 97 | loss = F.binary_cross_entropy(y_pred, y) 98 | loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred.sum(1)) 99 | 100 | x.requires_grad_(False) 101 | 102 | loss.backward() 103 | optimizer.step() 104 | 105 | with torch.no_grad(): 106 | model.eval() 107 | model.update_embeddings(x, y) 108 | 109 | return loss.item() 110 | 111 | def eval_step(engine, batch): 112 | model.eval() 113 | 114 | x, y = batch 115 | y = F.one_hot(y, num_classes=10).float() 116 | 117 | x, y = x.cuda(), y.cuda() 118 | 119 | x.requires_grad_(True) 120 | 121 | y_pred = model(x) 122 | 123 | return y_pred, y, x, y_pred.sum(1) 124 | 125 | trainer = Engine(step) 126 | evaluator = Engine(eval_step) 127 | 128 | metric = Accuracy(output_transform=output_transform_acc) 129 | metric.attach(evaluator, "accuracy") 130 | 131 | metric = Loss(F.binary_cross_entropy, output_transform=output_transform_bce) 132 | metric.attach(evaluator, "bce") 133 | 134 | metric = Loss(calc_gradient_penalty, output_transform=output_transform_gp) 135 | metric.attach(evaluator, "gradient_penalty") 136 | 137 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 138 | optimizer, milestones=[10, 20], gamma=0.2 139 | ) 140 | 141 | dl_train = torch.utils.data.DataLoader( 142 | train_dataset, batch_size=128, shuffle=True, num_workers=0, drop_last=True 143 | ) 144 | 145 | dl_val = torch.utils.data.DataLoader( 146 | val_dataset, batch_size=2000, shuffle=False, num_workers=0 147 | ) 148 | 149 | dl_test = torch.utils.data.DataLoader( 150 | test_dataset, batch_size=2000, shuffle=False, num_workers=0 151 | ) 152 | 153 | pbar = ProgressBar() 154 | pbar.attach(trainer) 155 | 156 | @trainer.on(Events.EPOCH_COMPLETED) 157 | def log_results(trainer): 158 | scheduler.step() 159 | 160 | if trainer.state.epoch % 5 == 0: 161 | evaluator.run(dl_val) 162 | _, roc_auc_mnist = get_fashionmnist_mnist_ood(model) 163 | _, roc_auc_notmnist = get_fashionmnist_notmnist_ood(model) 164 | 165 | metrics = evaluator.state.metrics 166 | 167 | print( 168 | f"Validation Results - Epoch: {trainer.state.epoch} " 169 | f"Acc: {metrics['accuracy']:.4f} " 170 | f"BCE: {metrics['bce']:.2f} " 171 | f"GP: {metrics['gradient_penalty']:.6f} " 172 | f"AUROC MNIST: {roc_auc_mnist:.2f} " 173 | f"AUROC NotMNIST: {roc_auc_notmnist:.2f} " 174 | ) 175 | print(f"Sigma: {model.sigma}") 176 | 177 | trainer.run(dl_train, max_epochs=30) 178 | 179 | evaluator.run(dl_val) 180 | val_accuracy = evaluator.state.metrics["accuracy"] 181 | 182 | evaluator.run(dl_test) 183 | test_accuracy = evaluator.state.metrics["accuracy"] 184 | 185 | return model, val_accuracy, test_accuracy 186 | 187 | 188 | if __name__ == "__main__": 189 | _, _, _, fashionmnist_test_dataset = get_FashionMNIST() 190 | 191 | # Finding length scale - decided based on validation accuracy 192 | l_gradient_penalties = [0.0] 193 | length_scales = [0.05, 0.1, 0.2, 0.3, 0.5, 1.0] 194 | 195 | # Finding gradient penalty - decided based on AUROC on NotMNIST 196 | # l_gradient_penalties = [0.0, 0.05, 0.1, 0.2, 0.3, 0.5, 1.0] 197 | # length_scales = [0.1] 198 | 199 | repetition = 1 # Increase for multiple repetitions 200 | final_model = 
False # set true for final model to train on full train set 201 | 202 | results = {} 203 | 204 | for l_gradient_penalty in l_gradient_penalties: 205 | for length_scale in length_scales: 206 | val_accuracies = [] 207 | test_accuracies = [] 208 | roc_aucs_mnist = [] 209 | roc_aucs_notmnist = [] 210 | 211 | for _ in range(repetition): 212 | print(" ### NEW MODEL ### ") 213 | model, val_accuracy, test_accuracy = train_model( 214 | l_gradient_penalty, length_scale, final_model 215 | ) 216 | accuracy, roc_auc_mnist = get_fashionmnist_mnist_ood(model) 217 | _, roc_auc_notmnist = get_fashionmnist_notmnist_ood(model) 218 | 219 | val_accuracies.append(val_accuracy) 220 | test_accuracies.append(test_accuracy) 221 | roc_aucs_mnist.append(roc_auc_mnist) 222 | roc_aucs_notmnist.append(roc_auc_notmnist) 223 | 224 | results[f"lgp{l_gradient_penalty}_ls{length_scale}"] = [ 225 | (np.mean(val_accuracies), np.std(val_accuracies)), 226 | (np.mean(test_accuracies), np.std(test_accuracies)), 227 | (np.mean(roc_aucs_mnist), np.std(roc_aucs_mnist)), 228 | (np.mean(roc_aucs_notmnist), np.std(roc_aucs_notmnist)), 229 | ] 230 | print(results[f"lgp{l_gradient_penalty}_ls{length_scale}"]) 231 | 232 | print(results) 233 | -------------------------------------------------------------------------------- /two_moons.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.utils.data\n", 11 | "from torch import nn\n", 12 | "from torch.nn import functional as F\n", 13 | "\n", 14 | "from ignite.engine import Events, Engine\n", 15 | "from ignite.metrics import Accuracy, Loss\n", 16 | "\n", 17 | "import numpy as np\n", 18 | "import sklearn.datasets\n", 19 | "\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "import seaborn as sns\n", 22 | "\n", 23 | "sns.set()" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "class Model_bilinear(nn.Module):\n", 33 | " def __init__(self, features, num_embeddings):\n", 34 | " super().__init__()\n", 35 | " \n", 36 | " self.gamma = 0.99\n", 37 | " self.sigma = 0.3\n", 38 | " \n", 39 | " embedding_size = 10\n", 40 | " \n", 41 | " self.fc1 = nn.Linear(2, features)\n", 42 | " self.fc2 = nn.Linear(features, features)\n", 43 | " self.fc3 = nn.Linear(features, features)\n", 44 | " \n", 45 | " self.W = nn.Parameter(torch.normal(torch.zeros(embedding_size, num_embeddings, features), 1))\n", 46 | " \n", 47 | " self.register_buffer('N', torch.ones(num_embeddings) * 20)\n", 48 | " self.register_buffer('m', torch.normal(torch.zeros(embedding_size, num_embeddings), 1))\n", 49 | " \n", 50 | " self.m = self.m * self.N.unsqueeze(0)\n", 51 | "\n", 52 | " def embed(self, x):\n", 53 | " x = F.relu(self.fc1(x))\n", 54 | " x = F.relu(self.fc2(x))\n", 55 | " x = self.fc3(x)\n", 56 | " \n", 57 | " # i is batch, m is embedding_size, n is num_embeddings (classes)\n", 58 | " x = torch.einsum('ij,mnj->imn', x, self.W)\n", 59 | " \n", 60 | " return x\n", 61 | "\n", 62 | " def bilinear(self, z):\n", 63 | " embeddings = self.m / self.N.unsqueeze(0)\n", 64 | " \n", 65 | " diff = z - embeddings.unsqueeze(0) \n", 66 | " y_pred = (- diff**2).mean(1).div(2 * self.sigma**2).exp()\n", 67 | "\n", 68 | " return y_pred\n", 69 | "\n", 70 | " def forward(self, x):\n", 71 | " z = self.embed(x)\n", 72 | " y_pred = self.bilinear(z)\n", 73 | " 
\n", 74 | " return z, y_pred\n", 75 | "\n", 76 | " def update_embeddings(self, x, y):\n", 77 | " z = self.embed(x)\n", 78 | " \n", 79 | " # normalizing value per class, assumes y is one_hot encoded\n", 80 | " self.N = torch.max(self.gamma * self.N + (1 - self.gamma) * y.sum(0), torch.ones_like(self.N))\n", 81 | " \n", 82 | " # compute sum of embeddings on class by class basis\n", 83 | " features_sum = torch.einsum('ijk,ik->jk', z, y)\n", 84 | " \n", 85 | " self.m = self.gamma * self.m + (1 - self.gamma) * features_sum" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "np.random.seed(0)\n", 95 | "torch.manual_seed(0)\n", 96 | "\n", 97 | "l_gradient_penalty = 1.0\n", 98 | "\n", 99 | "# Moons\n", 100 | "noise = 0.1\n", 101 | "X_train, y_train = sklearn.datasets.make_moons(n_samples=1500, noise=noise)\n", 102 | "X_test, y_test = sklearn.datasets.make_moons(n_samples=200, noise=noise)\n", 103 | "\n", 104 | "num_classes = 2\n", 105 | "batch_size = 64\n", 106 | "\n", 107 | "model = Model_bilinear(20, num_classes)\n", 108 | "\n", 109 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)\n", 110 | "\n", 111 | "\n", 112 | "def calc_gradient_penalty(x, y_pred):\n", 113 | " gradients = torch.autograd.grad(\n", 114 | " outputs=y_pred,\n", 115 | " inputs=x,\n", 116 | " grad_outputs=torch.ones_like(y_pred),\n", 117 | " create_graph=True,\n", 118 | " )[0]\n", 119 | "\n", 120 | "\n", 121 | " gradients = gradients.flatten(start_dim=1)\n", 122 | " \n", 123 | " # L2 norm\n", 124 | " grad_norm = gradients.norm(2, dim=1)\n", 125 | "\n", 126 | " # Two sided penalty\n", 127 | " gradient_penalty = ((grad_norm - 1) ** 2).mean()\n", 128 | " \n", 129 | " # One sided penalty - down\n", 130 | "# gradient_penalty = F.relu(grad_norm - 1).mean()\n", 131 | "\n", 132 | " return gradient_penalty\n", 133 | "\n", 134 | "\n", 135 | "def output_transform_acc(output):\n", 136 | " y_pred, y, x, z = output\n", 137 | " \n", 138 | " y = torch.argmax(y, dim=1)\n", 139 | " \n", 140 | " return y_pred, y\n", 141 | "\n", 142 | "\n", 143 | "def output_transform_bce(output):\n", 144 | " y_pred, y, x, z = output\n", 145 | "\n", 146 | " return y_pred, y\n", 147 | "\n", 148 | "\n", 149 | "def output_transform_gp(output):\n", 150 | " y_pred, y, x, z = output\n", 151 | "\n", 152 | " return x, y_pred\n", 153 | "\n", 154 | "\n", 155 | "def step(engine, batch):\n", 156 | " model.train()\n", 157 | " optimizer.zero_grad()\n", 158 | " \n", 159 | " x, y = batch\n", 160 | " x.requires_grad_(True)\n", 161 | " \n", 162 | " z, y_pred = model(x)\n", 163 | " \n", 164 | " loss1 = F.binary_cross_entropy(y_pred, y)\n", 165 | " loss2 = l_gradient_penalty * calc_gradient_penalty(x, y_pred)\n", 166 | " \n", 167 | " loss = loss1 + loss2\n", 168 | " \n", 169 | " loss.backward()\n", 170 | " optimizer.step()\n", 171 | " \n", 172 | " with torch.no_grad():\n", 173 | " model.update_embeddings(x, y)\n", 174 | " \n", 175 | " return loss.item()\n", 176 | "\n", 177 | "\n", 178 | "def eval_step(engine, batch):\n", 179 | " model.eval()\n", 180 | "\n", 181 | " x, y = batch\n", 182 | "\n", 183 | " x.requires_grad_(True)\n", 184 | "\n", 185 | " z, y_pred = model(x)\n", 186 | "\n", 187 | " return y_pred, y, x, z\n", 188 | " \n", 189 | "\n", 190 | "trainer = Engine(step)\n", 191 | "evaluator = Engine(eval_step)\n", 192 | "\n", 193 | "metric = Accuracy(output_transform=output_transform_acc)\n", 194 | "metric.attach(evaluator, \"accuracy\")\n", 195 | 
"\n", 196 | "metric = Loss(F.binary_cross_entropy, output_transform=output_transform_bce)\n", 197 | "metric.attach(evaluator, \"bce\")\n", 198 | "\n", 199 | "metric = Loss(calc_gradient_penalty, output_transform=output_transform_gp)\n", 200 | "metric.attach(evaluator, \"gp\")\n", 201 | "\n", 202 | "\n", 203 | "ds_train = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(), F.one_hot(torch.from_numpy(y_train)).float())\n", 204 | "dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, drop_last=True)\n", 205 | "\n", 206 | "ds_test = torch.utils.data.TensorDataset(torch.from_numpy(X_test).float(), F.one_hot(torch.from_numpy(y_test)).float())\n", 207 | "dl_test = torch.utils.data.DataLoader(ds_test, batch_size=200, shuffle=False)\n", 208 | "\n", 209 | "@trainer.on(Events.EPOCH_COMPLETED)\n", 210 | "def log_results(trainer):\n", 211 | " evaluator.run(dl_test)\n", 212 | " metrics = evaluator.state.metrics\n", 213 | "\n", 214 | " print(\"Test Results - Epoch: {} Acc: {:.4f} BCE: {:.2f} GP {:.2f}\"\n", 215 | " .format(trainer.state.epoch, metrics['accuracy'], metrics['bce'], metrics['gp']))" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": { 222 | "scrolled": false 223 | }, 224 | "outputs": [], 225 | "source": [ 226 | "trainer.run(dl_train, max_epochs=30)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "domain = 3\n", 236 | "x_lin = np.linspace(-domain+0.5, domain+0.5, 100)\n", 237 | "y_lin = np.linspace(-domain, domain, 100)\n", 238 | "\n", 239 | "xx, yy = np.meshgrid(x_lin, y_lin)\n", 240 | "\n", 241 | "X_grid = np.column_stack([xx.flatten(), yy.flatten()])\n", 242 | "\n", 243 | "X_vis, y_vis = sklearn.datasets.make_moons(n_samples=1000, noise=noise)\n", 244 | "mask = y_vis.astype(np.bool)\n", 245 | "\n", 246 | "with torch.no_grad():\n", 247 | " output = model(torch.from_numpy(X_grid).float())[1]\n", 248 | " confidence = output.max(1)[0].numpy()\n", 249 | "\n", 250 | "\n", 251 | "z = confidence.reshape(xx.shape)\n", 252 | "\n", 253 | "plt.figure()\n", 254 | "plt.contourf(x_lin, y_lin, z, cmap='cividis')\n", 255 | "\n", 256 | "plt.scatter(X_vis[mask,0], X_vis[mask,1])\n", 257 | "plt.scatter(X_vis[~mask,0], X_vis[~mask,1])" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [] 266 | } 267 | ], 268 | "metadata": { 269 | "kernelspec": { 270 | "display_name": "Python 3", 271 | "language": "python", 272 | "name": "python3" 273 | }, 274 | "language_info": { 275 | "codemirror_mode": { 276 | "name": "ipython", 277 | "version": 3 278 | }, 279 | "file_extension": ".py", 280 | "mimetype": "text/x-python", 281 | "name": "python", 282 | "nbconvert_exporter": "python", 283 | "pygments_lexer": "ipython3", 284 | "version": "3.7.7" 285 | } 286 | }, 287 | "nbformat": 4, 288 | "nbformat_minor": 2 289 | } 290 | -------------------------------------------------------------------------------- /two_moons_ensemble.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.utils.data\n", 11 | "from torch import nn\n", 12 | "from torch.nn import functional as F\n", 13 | "\n", 14 | "from ignite.engine import Events, Engine\n", 15 | 
"from ignite.metrics import Accuracy, Loss\n", 16 | "\n", 17 | "import numpy as np\n", 18 | "import sklearn.datasets\n", 19 | "\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "import seaborn as sns\n", 22 | "\n", 23 | "sns.set()\n", 24 | "\n", 25 | "torch.manual_seed(1)\n", 26 | "np.random.seed(1)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "class Model(nn.Module):\n", 36 | " def __init__(self, features):\n", 37 | " super().__init__()\n", 38 | " \n", 39 | " self.fc1 = nn.Linear(2, features)\n", 40 | " self.fc2 = nn.Linear(features, features)\n", 41 | " self.fc3 = nn.Linear(features, features)\n", 42 | " self.fc4 = nn.Linear(features, 2)\n", 43 | "\n", 44 | " def forward(self, x):\n", 45 | " x = F.relu(self.fc1(x))\n", 46 | " x = F.relu(self.fc2(x))\n", 47 | " x = F.relu(self.fc3(x))\n", 48 | " x = self.fc4(x)\n", 49 | " \n", 50 | " return F.log_softmax(x, dim=1)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "noise = 0.1\n", 60 | "\n", 61 | "X_train, y_train = sklearn.datasets.make_moons(n_samples=1000, noise=noise)\n", 62 | "X_test, y_test = sklearn.datasets.make_moons(n_samples=200, noise=noise)\n", 63 | "\n", 64 | "num_classes = 2\n", 65 | "batch_size = 64\n", 66 | "\n", 67 | "def train_model(max_epochs):\n", 68 | " model = Model(20)\n", 69 | "\n", 70 | " optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)\n", 71 | "\n", 72 | " def step(engine, batch):\n", 73 | " model.train()\n", 74 | " optimizer.zero_grad()\n", 75 | "\n", 76 | " x, y = batch\n", 77 | "\n", 78 | " y_pred = model(x)\n", 79 | " loss = F.nll_loss(y_pred, y)\n", 80 | "\n", 81 | " loss.backward()\n", 82 | " optimizer.step()\n", 83 | "\n", 84 | " return loss.item()\n", 85 | "\n", 86 | " def eval_step(engine, batch):\n", 87 | " model.eval()\n", 88 | "\n", 89 | " x, y = batch\n", 90 | " y_pred = model(x)\n", 91 | "\n", 92 | " return y_pred, y\n", 93 | "\n", 94 | "\n", 95 | " trainer = Engine(step)\n", 96 | " evaluator = Engine(eval_step)\n", 97 | "\n", 98 | " metric = Accuracy()\n", 99 | " metric.attach(evaluator, \"accuracy\")\n", 100 | "\n", 101 | " metric = Loss(F.nll_loss)\n", 102 | " metric.attach(evaluator, \"nll\")\n", 103 | "\n", 104 | " ds_train = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(y_train))\n", 105 | " dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, drop_last=True)\n", 106 | "\n", 107 | " ds_test = torch.utils.data.TensorDataset(torch.from_numpy(X_test).float(), torch.from_numpy(y_test))\n", 108 | " dl_test = torch.utils.data.DataLoader(ds_test, batch_size=200, shuffle=False)\n", 109 | "\n", 110 | " @trainer.on(Events.EPOCH_COMPLETED)\n", 111 | " def log_results(trainer):\n", 112 | " evaluator.run(dl_test)\n", 113 | " metrics = evaluator.state.metrics\n", 114 | "\n", 115 | " print(f\"Test Results - Epoch: {trainer.state.epoch} Acc: {metrics['accuracy']:.4f} NLL: {metrics['nll']:.2f}\")\n", 116 | " \n", 117 | " trainer.run(dl_train, max_epochs=max_epochs)\n", 118 | " \n", 119 | " return model" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": { 126 | "scrolled": false 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "ensemble = 5\n", 131 | "models = [train_model(50) for _ in range(ensemble)]" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 
| "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "domain = 3\n", 141 | "x = np.linspace(-domain+0.5, domain+0.5, 100)\n", 142 | "y = np.linspace(-domain, domain, 100)\n", 143 | "\n", 144 | "xx, yy = np.meshgrid(x, y)\n", 145 | "\n", 146 | "X = np.column_stack([xx.flatten(), yy.flatten()])\n", 147 | "\n", 148 | "X_vis, y_vis = sklearn.datasets.make_moons(n_samples=500, noise=noise)\n", 149 | "mask = y_vis.astype(np.bool)\n", 150 | "\n", 151 | "for model in models:\n", 152 | " model.eval()\n", 153 | "\n", 154 | "with torch.no_grad():\n", 155 | " predictions = torch.stack([model(torch.from_numpy(X).float()) for model in models])\n", 156 | "\n", 157 | " mean_prediction = torch.mean(predictions.exp(), dim=0)\n", 158 | " confidence = torch.sum(mean_prediction * torch.log(mean_prediction), dim=1)\n", 159 | "\n", 160 | "z = confidence.reshape(xx.shape)\n", 161 | "\n", 162 | "plt.figure()\n", 163 | "plt.contourf(x, y, z, cmap='cividis')\n", 164 | "\n", 165 | "plt.scatter(X_vis[mask,0], X_vis[mask,1])\n", 166 | "plt.scatter(X_vis[~mask,0], X_vis[~mask,1])\n", 167 | "\n", 168 | "plt.figure()\n", 169 | "plt.contourf(x, y, z, cmap='cividis')" 170 | ] 171 | } 172 | ], 173 | "metadata": { 174 | "kernelspec": { 175 | "display_name": "Python 3", 176 | "language": "python", 177 | "name": "python3" 178 | }, 179 | "language_info": { 180 | "codemirror_mode": { 181 | "name": "ipython", 182 | "version": 3 183 | }, 184 | "file_extension": ".py", 185 | "mimetype": "text/x-python", 186 | "name": "python", 187 | "nbconvert_exporter": "python", 188 | "pygments_lexer": "ipython3", 189 | "version": "3.8.5" 190 | } 191 | }, 192 | "nbformat": 4, 193 | "nbformat_minor": 2 194 | } 195 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/y0ast/deterministic-uncertainty-quantification/3a659230a1583cc9977ff65743a1c209e7af4bed/utils/__init__.py -------------------------------------------------------------------------------- /utils/cnn_duq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Model(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | self.conv1 = nn.Conv2d(1, 64, 3, padding=1) 11 | self.bn1 = nn.BatchNorm2d(64) 12 | 13 | self.conv2 = nn.Conv2d(64, 128, 3, padding=1) 14 | self.bn2 = nn.BatchNorm2d(128) 15 | 16 | self.conv3 = nn.Conv2d(128, 128, 3) 17 | self.bn3 = nn.BatchNorm2d(128) 18 | 19 | self.fc1 = nn.Linear(2 * 2 * 128, 256) 20 | 21 | def compute_features(self, x): 22 | x = F.relu(self.bn1(self.conv1(x))) 23 | x = F.max_pool2d(x, 2, 2) 24 | 25 | x = F.relu(self.bn2(self.conv2(x))) 26 | x = F.max_pool2d(x, 2, 2) 27 | 28 | x = F.relu(self.bn3(self.conv3(x))) 29 | x = F.max_pool2d(x, 2, 2) 30 | 31 | x = x.flatten(1) 32 | 33 | x = F.relu(self.fc1(x)) 34 | 35 | return x 36 | 37 | 38 | class CNN_DUQ(Model): 39 | def __init__( 40 | self, 41 | num_classes, 42 | embedding_size, 43 | learnable_length_scale, 44 | length_scale, 45 | gamma, 46 | ): 47 | super().__init__() 48 | 49 | self.gamma = gamma 50 | 51 | self.W = nn.Parameter( 52 | torch.normal(torch.zeros(embedding_size, num_classes, 256), 0.05) 53 | ) 54 | 55 | self.register_buffer("N", torch.ones(num_classes) * 12) 56 | self.register_buffer( 57 | "m", torch.normal(torch.zeros(embedding_size, num_classes), 1) 58 | ) 
59 | 60 | self.m = self.m * self.N.unsqueeze(0) 61 | 62 | if learnable_length_scale: 63 | self.sigma = nn.Parameter(torch.zeros(num_classes) + length_scale) 64 | else: 65 | self.sigma = length_scale 66 | 67 | def update_embeddings(self, x, y): 68 | z = self.last_layer(self.compute_features(x)) 69 | 70 | # normalizing value per class, assumes y is one_hot encoded 71 | self.N = self.gamma * self.N + (1 - self.gamma) * y.sum(0) 72 | 73 | # compute sum of embeddings on class by class basis 74 | features_sum = torch.einsum("ijk,ik->jk", z, y) 75 | 76 | self.m = self.gamma * self.m + (1 - self.gamma) * features_sum 77 | 78 | def last_layer(self, z): 79 | z = torch.einsum("ij,mnj->imn", z, self.W) 80 | return z 81 | 82 | def output_layer(self, z): 83 | embeddings = self.m / self.N.unsqueeze(0) 84 | 85 | diff = z - embeddings.unsqueeze(0) 86 | distances = (-(diff**2)).mean(1).div(2 * self.sigma**2).exp() 87 | 88 | return distances 89 | 90 | def forward(self, x): 91 | z = self.last_layer(self.compute_features(x)) 92 | y_pred = self.output_layer(z) 93 | 94 | return y_pred 95 | 96 | 97 | class SoftmaxModel(Model): 98 | def __init__(self, input_size, num_classes): 99 | super().__init__() 100 | 101 | self.last_layer = nn.Linear(256, num_classes) 102 | self.output_layer = nn.LogSoftmax(dim=1) 103 | 104 | def forward(self, x): 105 | z = self.last_layer(self.compute_features(x)) 106 | y_pred = F.log_softmax(z, dim=1) 107 | 108 | return y_pred 109 | -------------------------------------------------------------------------------- /utils/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | from torchvision import datasets, transforms 7 | 8 | from scipy.io import loadmat 9 | from PIL import Image 10 | 11 | 12 | def get_MNIST(root="./"): 13 | input_size = 28 14 | num_classes = 10 15 | transform = transforms.Compose( 16 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 17 | ) 18 | 19 | train_dataset = datasets.MNIST( 20 | root + "data/", train=True, download=True, transform=transform 21 | ) 22 | 23 | test_dataset = datasets.MNIST( 24 | root + "data/", train=False, download=True, transform=transform 25 | ) 26 | return input_size, num_classes, train_dataset, test_dataset 27 | 28 | 29 | def get_FashionMNIST(root="./"): 30 | input_size = 28 31 | num_classes = 10 32 | 33 | transform_list = [transforms.ToTensor(), transforms.Normalize((0.2861,), (0.3530,))] 34 | transform = transforms.Compose(transform_list) 35 | 36 | train_dataset = datasets.FashionMNIST( 37 | root + "data/", train=True, download=True, transform=transform 38 | ) 39 | test_dataset = datasets.FashionMNIST( 40 | root + "data/", train=False, download=True, transform=transform 41 | ) 42 | return input_size, num_classes, train_dataset, test_dataset 43 | 44 | 45 | def get_SVHN(root="./"): 46 | input_size = 32 47 | num_classes = 10 48 | transform = transforms.Compose( 49 | [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 50 | ) 51 | train_dataset = datasets.SVHN( 52 | root + "data/SVHN", split="train", transform=transform, download=True 53 | ) 54 | test_dataset = datasets.SVHN( 55 | root + "data/SVHN", split="test", transform=transform, download=True 56 | ) 57 | return input_size, num_classes, train_dataset, test_dataset 58 | 59 | 60 | def get_CIFAR10(root="./"): 61 | input_size = 32 62 | num_classes = 10 63 | train_transform = transforms.Compose( 64 | [ 65 | transforms.RandomCrop(32, 
padding=4), 66 | transforms.RandomHorizontalFlip(), 67 | transforms.ToTensor(), 68 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 69 | ] 70 | ) 71 | train_dataset = datasets.CIFAR10( 72 | root + "data/CIFAR10", train=True, transform=train_transform, download=True 73 | ) 74 | 75 | test_transform = transforms.Compose( 76 | [ 77 | transforms.ToTensor(), 78 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 79 | ] 80 | ) 81 | test_dataset = datasets.CIFAR10( 82 | root + "data/CIFAR10", train=False, transform=test_transform, download=True 83 | ) 84 | 85 | return input_size, num_classes, train_dataset, test_dataset 86 | 87 | 88 | def get_notMNIST(root="./"): 89 | input_size = 28 90 | num_classes = 10 91 | 92 | transform = transforms.Compose( 93 | [transforms.ToTensor(), transforms.Normalize((0.4254,), (0.4586,))] 94 | ) 95 | 96 | test_dataset = NotMNIST(root + "data/", transform=transform) 97 | 98 | return input_size, num_classes, None, test_dataset 99 | 100 | 101 | all_datasets = { 102 | "MNIST": get_MNIST, 103 | "notMNIST": get_notMNIST, 104 | "FashionMNIST": get_FashionMNIST, 105 | "SVHN": get_SVHN, 106 | "CIFAR10": get_CIFAR10, 107 | } 108 | 109 | 110 | class NotMNIST(Dataset): 111 | def __init__(self, root, transform=None): 112 | root = os.path.expanduser(root) 113 | 114 | self.transform = transform 115 | 116 | data_dict = loadmat(os.path.join(root, "notMNIST_small.mat")) 117 | 118 | self.data = torch.tensor( 119 | data_dict["images"].transpose(2, 0, 1), dtype=torch.uint8 120 | ).unsqueeze(1) 121 | 122 | self.targets = torch.tensor(data_dict["labels"], dtype=torch.int64) 123 | 124 | def __getitem__(self, index): 125 | img = self.data[index] 126 | target = self.targets[index] 127 | 128 | if self.transform is not None: 129 | img = Image.fromarray(img.squeeze().numpy(), mode="L") 130 | img = self.transform(img) 131 | 132 | return img, target 133 | 134 | def __len__(self): 135 | return len(self.data) 136 | 137 | 138 | class FastFashionMNIST(datasets.FashionMNIST): 139 | def __init__(self, *args, **kwargs): 140 | super().__init__(*args, **kwargs) 141 | 142 | self.data = self.data.unsqueeze(1).float().div(255) 143 | self.data = self.data.sub_(0.2861).div_(0.3530) 144 | 145 | self.data, self.targets = self.data.to("cuda"), self.targets.to("cuda") 146 | 147 | def __getitem__(self, index): 148 | """ 149 | Args: 150 | index (int): Index 151 | 152 | Returns: 153 | tuple: (image, target) where target is index of the target class. 
154 | """ 155 | img, target = self.data[index], self.targets[index] 156 | 157 | return img, target 158 | -------------------------------------------------------------------------------- /utils/evaluate_ood.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.metrics import roc_auc_score 4 | 5 | from utils.datasets import ( 6 | get_CIFAR10, 7 | get_SVHN, 8 | get_FashionMNIST, 9 | get_MNIST, 10 | get_notMNIST, 11 | ) 12 | 13 | 14 | def prepare_ood_datasets(true_dataset, ood_dataset): 15 | # Preprocess OoD dataset same as true dataset 16 | ood_dataset.transform = true_dataset.transform 17 | 18 | datasets = [true_dataset, ood_dataset] 19 | 20 | anomaly_targets = torch.cat( 21 | (torch.zeros(len(true_dataset)), torch.ones(len(ood_dataset))) 22 | ) 23 | 24 | concat_datasets = torch.utils.data.ConcatDataset(datasets) 25 | 26 | dataloader = torch.utils.data.DataLoader( 27 | concat_datasets, batch_size=500, shuffle=False, num_workers=4, pin_memory=False 28 | ) 29 | 30 | return dataloader, anomaly_targets 31 | 32 | 33 | def loop_over_dataloader(model, dataloader): 34 | model.eval() 35 | 36 | with torch.no_grad(): 37 | scores = [] 38 | accuracies = [] 39 | for data, target in dataloader: 40 | data = data.cuda() 41 | target = target.cuda() 42 | 43 | output = model(data) 44 | kernel_distance, pred = output.max(1) 45 | 46 | accuracy = pred.eq(target) 47 | accuracies.append(accuracy.cpu().numpy()) 48 | 49 | scores.append(-kernel_distance.cpu().numpy()) 50 | 51 | scores = np.concatenate(scores) 52 | accuracies = np.concatenate(accuracies) 53 | 54 | return scores, accuracies 55 | 56 | 57 | def get_auroc_ood(true_dataset, ood_dataset, model): 58 | dataloader, anomaly_targets = prepare_ood_datasets(true_dataset, ood_dataset) 59 | 60 | scores, accuracies = loop_over_dataloader(model, dataloader) 61 | 62 | accuracy = np.mean(accuracies[: len(true_dataset)]) 63 | roc_auc = roc_auc_score(anomaly_targets, scores) 64 | 65 | return accuracy, roc_auc 66 | 67 | 68 | def get_auroc_classification(dataset, model): 69 | dataloader = torch.utils.data.DataLoader( 70 | dataset, batch_size=500, shuffle=False, num_workers=4, pin_memory=False 71 | ) 72 | 73 | scores, accuracies = loop_over_dataloader(model, dataloader) 74 | 75 | accuracy = np.mean(accuracies) 76 | roc_auc = roc_auc_score(1 - accuracies, scores) 77 | 78 | return accuracy, roc_auc 79 | 80 | 81 | def get_cifar_svhn_ood(model): 82 | _, _, _, cifar_test_dataset = get_CIFAR10() 83 | _, _, _, svhn_test_dataset = get_SVHN() 84 | 85 | return get_auroc_ood(cifar_test_dataset, svhn_test_dataset, model) 86 | 87 | 88 | def get_fashionmnist_mnist_ood(model): 89 | _, _, _, fashionmnist_test_dataset = get_FashionMNIST() 90 | _, _, _, mnist_test_dataset = get_MNIST() 91 | 92 | return get_auroc_ood(fashionmnist_test_dataset, mnist_test_dataset, model) 93 | 94 | 95 | def get_fashionmnist_notmnist_ood(model): 96 | _, _, _, fashionmnist_test_dataset = get_FashionMNIST() 97 | _, _, _, notmnist_test_dataset = get_notMNIST() 98 | 99 | return get_auroc_ood(fashionmnist_test_dataset, notmnist_test_dataset, model) 100 | -------------------------------------------------------------------------------- /utils/resnet_duq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ResNet_DUQ(nn.Module): 6 | def __init__( 7 | self, 8 | feature_extractor, 9 | num_classes, 10 | centroid_size, 11 | model_output_size, 12 | 
length_scale, 13 | gamma, 14 | ): 15 | super().__init__() 16 | 17 | self.gamma = gamma 18 | 19 | self.W = nn.Parameter( 20 | torch.zeros(centroid_size, num_classes, model_output_size) 21 | ) 22 | nn.init.kaiming_normal_(self.W, nonlinearity="relu") 23 | 24 | self.feature_extractor = feature_extractor 25 | 26 | self.register_buffer("N", torch.zeros(num_classes) + 13) 27 | self.register_buffer( 28 | "m", torch.normal(torch.zeros(centroid_size, num_classes), 0.05) 29 | ) 30 | self.m = self.m * self.N 31 | 32 | self.sigma = length_scale 33 | 34 | def rbf(self, z): 35 | z = torch.einsum("ij,mnj->imn", z, self.W) 36 | 37 | embeddings = self.m / self.N.unsqueeze(0) 38 | 39 | diff = z - embeddings.unsqueeze(0) 40 | diff = (diff ** 2).mean(1).div(2 * self.sigma ** 2).mul(-1).exp() 41 | 42 | return diff 43 | 44 | def update_embeddings(self, x, y): 45 | self.N = self.gamma * self.N + (1 - self.gamma) * y.sum(0) 46 | 47 | z = self.feature_extractor(x) 48 | 49 | z = torch.einsum("ij,mnj->imn", z, self.W) 50 | embedding_sum = torch.einsum("ijk,ik->jk", z, y) 51 | 52 | self.m = self.gamma * self.m + (1 - self.gamma) * embedding_sum 53 | 54 | def forward(self, x): 55 | z = self.feature_extractor(x) 56 | y_pred = self.rbf(z) 57 | 58 | return y_pred 59 | -------------------------------------------------------------------------------- /utils/wide_resnet.py: -------------------------------------------------------------------------------- 1 | # Obtained from: https://github.com/meliketoy/wide-resnet.pytorch 2 | # Adapted to match: 3 | # https://github.com/szagoruyko/wide-residual-networks/tree/master/pytorch 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class WideBasic(nn.Module): 10 | def __init__(self, in_c, out_c, stride, dropout_rate): 11 | super().__init__() 12 | self.bn1 = nn.BatchNorm2d(in_c) 13 | kernel = 3 14 | padding = 1 15 | self.conv1 = nn.Conv2d(in_c, out_c, kernel, stride, padding, bias=False) 16 | 17 | self.bn2 = nn.BatchNorm2d(out_c) 18 | self.conv2 = nn.Conv2d(out_c, out_c, kernel, 1, padding, bias=False) 19 | 20 | self.dropout_rate = dropout_rate 21 | if dropout_rate > 0: 22 | self.dropout = nn.Dropout(dropout_rate) 23 | 24 | if stride != 1 or in_c != out_c: 25 | self.shortcut = nn.Conv2d(in_c, out_c, 1, stride, bias=False) 26 | else: 27 | self.shortcut = nn.Identity() 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(x)) 31 | 32 | out = self.conv1(out) 33 | 34 | out = F.relu(self.bn2(out)) 35 | 36 | if self.dropout_rate > 0: 37 | out = self.dropout(out) 38 | 39 | out = self.conv2(out) 40 | out += self.shortcut(x) 41 | 42 | return out 43 | 44 | 45 | class WideResNet(nn.Module): 46 | def __init__( 47 | self, depth=28, widen_factor=10, num_classes=None, dropout_rate=0.3, 48 | ): 49 | super().__init__() 50 | 51 | assert (depth - 4) % 6 == 0, "Wide-resnet depth should be 6n+4" 52 | 53 | self.dropout_rate = dropout_rate 54 | 55 | n = (depth - 4) // 6 56 | k = widen_factor 57 | 58 | nStages = [16, 16 * k, 32 * k, 64 * k] 59 | strides = [1, 1, 2, 2] 60 | 61 | self.conv1 = nn.Conv2d(3, nStages[0], 3, strides[0], 1, bias=False) 62 | self.layer1 = self._wide_layer(nStages[0:2], n, strides[1]) 63 | self.layer2 = self._wide_layer(nStages[1:3], n, strides[2]) 64 | self.layer3 = self._wide_layer(nStages[2:4], n, strides[3]) 65 | 66 | self.bn1 = nn.BatchNorm2d(nStages[3]) 67 | 68 | self.num_classes = num_classes 69 | if num_classes is not None: 70 | self.linear = nn.Linear(nStages[3], num_classes) 71 | 72 | for m in self.modules(): 73 | if isinstance(m, nn.Conv2d): 
74 | # Sergey implementation has no mode/nonlinearity 75 | # https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/utils.py#L17 76 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 77 | elif isinstance(m, nn.BatchNorm2d): 78 | # https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/utils.py#L25 79 | nn.init.uniform_(m.weight) 80 | nn.init.constant_(m.bias, 0) 81 | elif isinstance(m, nn.Linear): 82 | # Sergey implementation has no mode/nonlinearity 83 | # https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/utils.py#L21 84 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 85 | nn.init.constant_(m.bias, 0) 86 | 87 | def _wide_layer(self, channels, num_blocks, stride): 88 | strides = [stride] + [1] * (num_blocks - 1) 89 | layers = [] 90 | 91 | in_c, out_c = channels 92 | 93 | for stride in strides: 94 | layers.append(WideBasic(in_c, out_c, stride, self.dropout_rate)) 95 | in_c = out_c 96 | 97 | return nn.Sequential(*layers) 98 | 99 | def forward(self, x): 100 | out = self.conv1(x) 101 | out = self.layer1(out) 102 | out = self.layer2(out) 103 | out = self.layer3(out) 104 | out = F.relu(self.bn1(out)) 105 | out = F.avg_pool2d(out, 8) 106 | out = out.flatten(1) 107 | 108 | if self.num_classes is not None: 109 | out = self.linear(out) 110 | 111 | return out 112 | --------------------------------------------------------------------------------
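To close, here is a minimal sketch (mirroring the wiring in train_duq_cifar.py, with that script's default hyperparameter values) of how this WideResNet plugs into ResNet_DUQ as a feature extractor:

```
import torch

from utils.resnet_duq import ResNet_DUQ
from utils.wide_resnet import WideResNet

# With num_classes=None (the default), WideResNet returns 640-d features.
feature_extractor = WideResNet()

model = ResNet_DUQ(
    feature_extractor,
    num_classes=10,         # CIFAR-10
    centroid_size=640,      # defaults to the model output size in the script
    model_output_size=640,  # WRN-28-10 feature dimension (64 * widen_factor)
    length_scale=0.1,       # RBF length scale (--length_scale default)
    gamma=0.999,            # EMA decay for the centroid updates
)

x = torch.randn(2, 3, 32, 32)
y_pred = model(x)  # shape (2, 10): one RBF kernel distance per class
```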