├── src ├── datasets │ ├── utils.py │ └── eeg_epilepsy.py ├── schedulers.py ├── optimisers.py ├── losses.py ├── utils.py ├── loaders.py ├── run.py ├── transforms.py ├── metrics.py └── models │ └── res_net_18.py ├── data └── eeg_epilepsy │ └── download.py ├── LICENSE ├── README.md └── train.py /src/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def calculate_sample_weights(y): 6 | 7 | classes, counts = np.unique(y, return_counts=True) 8 | class_weights = dict(zip(classes, sum(counts) / counts)) 9 | sample_weights = torch.DoubleTensor([class_weights[cls] for cls in y]) 10 | 11 | return sample_weights -------------------------------------------------------------------------------- /src/schedulers.py: -------------------------------------------------------------------------------- 1 | import torch.optim.lr_scheduler as scl 2 | 3 | 4 | def get_scheduler(name, optimiser, **kwargs): 5 | 6 | if name == "reduce_on_plateau": 7 | return scl.ReduceLROnPlateau(optimiser, **kwargs) 8 | 9 | elif name == "multi_step": 10 | return scl.MultiStepLR(optimiser, **kwargs) 11 | 12 | elif name is None: 13 | return None 14 | 15 | else: 16 | raise NotImplementedError("scheduler not implemented: '{}'".format(name)) -------------------------------------------------------------------------------- /data/eeg_epilepsy/download.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import requests 3 | import zipfile 4 | 5 | URL = "https://web.archive.org/web/20200318000445/http://archive.ics.uci.edu/ml/machine-learning-databases/00388/data.csv" 6 | 7 | 8 | def download(from_path, to_path): 9 | 10 | if not to_path.exists(): 11 | 12 | try: 13 | r = requests.get(url=from_path) 14 | 15 | with open(to_path, "wb") as file: 16 | file.write(r.content) 17 | 18 | except: 19 | print("error downloading {}".format(str(from_path))) 20 | 21 | 22 | if __name__ == "__main__": 23 | 24 | print("downloading eeg epilepsy data") 25 | download(URL, to_path=pathlib.Path("dataset.csv")) 26 | -------------------------------------------------------------------------------- /src/optimisers.py: -------------------------------------------------------------------------------- 1 | import torch.optim as opt 2 | 3 | 4 | def get_optimiser(name, model, **kwargs): 5 | 6 | if name == "adam": 7 | 8 | if "early_exit" in model.name: 9 | 10 | weight_decay = kwargs.pop("weight_decay") 11 | params = [{"params": [param for name, param in model.named_parameters() if "exit_block" not in name], "weight_decay": weight_decay}] 12 | 13 | for block_idx, exit_block in enumerate(model.exit_blocks): 14 | params += [{"params": exit_block.parameters(), "weight_decay": (block_idx + 1) * weight_decay}] 15 | 16 | return opt.Adam(params, **kwargs) 17 | 18 | else: 19 | 20 | return opt.Adam(model.parameters(), **kwargs) 21 | 22 | else: 23 | raise NotImplementedError("optimiser not implemented: '{}'".format(name)) -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def get_loss(name, ensemble, **kwargs): 7 | 8 | if name == "cross_entropy": 9 | 10 | if ensemble == "early_exit": 11 | return ExitWeightedCrossEntropyLoss(**kwargs) 12 | 13 | else: 14 | return nn.CrossEntropyLoss(**kwargs) 15 | 16 | else: 17 | raise 
ValueError("loss not implemented: '{}'".format(name)) 18 | 19 | 20 | class ExitWeightedCrossEntropyLoss: 21 | 22 | def __init__(self, alpha): 23 | self.alpha=torch.tensor(alpha) 24 | 25 | def __call__(self, logits, labels, gamma): 26 | 27 | batch_size, num_exits, _ = logits.shape 28 | 29 | loss = 0.0 30 | for ex in range(num_exits): 31 | exit_logits = logits[:, ex, :] 32 | loss += self.alpha[ex] * gamma[ex] * F.cross_entropy(exit_logits, labels) 33 | 34 | return loss -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Alex Campbell 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Early Exit Ensembles for Uncertainty Quantification 2 | 3 | [[Paper]](https://proceedings.mlr.press/v158/qendro21a/qendro21a.pdf) [[Poster]]() [[Slides]]() 4 | 5 | 6 |
7 | 8 | # Contact 9 | 10 | Alexander Campbell (ajrc4@cl.cam.ac.uk), Lorena Qendro (lq223@cl.cam.ac.uk) 11 | 12 | 13 |
14 | 15 | # Citation 16 | 17 | If you make use of this code in your work, please cite our paper: 18 | 19 | 20 | @inproceedings{early_exit_ensembles_2021, 21 | title = {Early Exit Ensembles for Uncertainty Quantification}, 22 | booktitle = {Proceedings of Machine Learning for Health}, 23 | publisher = {PMLR}, 24 | author = {Qendro, Lorena and Campbell, Alexander and Liò, Pietro and Mascolo, Cecilia}, 25 | year = {2021}, 26 | pages = {179--193}, 27 | } 28 | 29 |
30 | 31 | Qendro, L., Campbell, A., Liò, P., & Mascolo, C. (2021). Early Exit Ensembles for Uncertainty Quantification. In Proceedings of Machine Learning for Health (pp. 179–193). PMLR. 32 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | 6 | def set_random_seed(seed, is_gpu=False): 7 | """ 8 | Set random seeds for reproducability 9 | """ 10 | max_seed_value = np.iinfo(np.uint32).max 11 | min_seed_value = np.iinfo(np.uint32).min 12 | 13 | if not (min_seed_value <= seed <= max_seed_value): 14 | raise ValueError("{} is not in bounds, numpy accepts from {} to {}".format(seed, min_seed_value, max_seed_value)) 15 | 16 | torch.manual_seed(seed) 17 | np.random.seed(seed) 18 | random.seed(seed) 19 | 20 | if torch.cuda.is_available() and is_gpu: 21 | torch.cuda.manual_seed_all(seed) 22 | 23 | 24 | def get_device(is_gpu=True, gpu_number=0): 25 | """ 26 | Set the backend for model training 27 | """ 28 | gpu_count = torch.cuda.device_count() 29 | if gpu_count < gpu_number: 30 | raise ValueError("number of cuda devices: '{}'".format(gpu_count)) 31 | 32 | else: 33 | if torch.cuda.is_available() and is_gpu: 34 | device = torch.device("cuda:{}".format(gpu_number)) 35 | else: 36 | device = torch.device("cpu") 37 | 38 | return device -------------------------------------------------------------------------------- /src/loaders.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import torch 4 | from torch.utils.data.sampler import WeightedRandomSampler 5 | from torch.utils.data import DataLoader 6 | 7 | from src.datasets.eeg_epilepsy import get_eeg_epilepsy 8 | 9 | 10 | def get_dataset_splits(name, data_dir, valid_prop, test_prop, seed): 11 | data_dir = pathlib.Path(data_dir) 12 | 13 | if name == "eeg_epilepsy": 14 | return get_eeg_epilepsy(data_dir, valid_prop, test_prop, seed) 15 | 16 | else: 17 | raise ValueError("dataset not implemented: '{}'".format(name)) 18 | 19 | 20 | def get_dataloaders(name, data_dir, valid_prop=0.10, test_prop=0.10, batch_size=16, 21 | num_workers=0, seed=1234, device=torch.device("cpu")): 22 | 23 | datasets = get_dataset_splits(name, data_dir, valid_prop, test_prop, seed=seed) 24 | 25 | train_dataset = datasets["train"] 26 | sample_weights = train_dataset.sample_weights 27 | sampler = WeightedRandomSampler(sample_weights, len(sample_weights)) 28 | 29 | pin_memory = True if device.type == "cuda" else False 30 | 31 | train = DataLoader(dataset=train_dataset, 32 | batch_size=batch_size, 33 | shuffle=False, 34 | drop_last=True, 35 | sampler=sampler, 36 | pin_memory=pin_memory, 37 | num_workers=num_workers) 38 | 39 | valid = DataLoader(dataset=datasets["valid"], 40 | batch_size=batch_size, 41 | shuffle=False, 42 | drop_last=True, 43 | pin_memory=pin_memory, 44 | num_workers=num_workers) 45 | 46 | test = DataLoader(dataset=datasets["test"], 47 | batch_size=1, 48 | shuffle=False, 49 | pin_memory=pin_memory, 50 | num_workers=num_workers) 51 | 52 | return {"train": train, "valid": valid, "test": test} -------------------------------------------------------------------------------- /src/datasets/eeg_epilepsy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | from torch.utils.data import Dataset 5 | from sklearn.model_selection import 
train_test_split 6 | 7 | from src.transforms import Compose, FlipTime, Shift, FlipPolarity, GuassianNoise 8 | from src.datasets.utils import calculate_sample_weights 9 | 10 | 11 | def get_eeg_epilepsy(data_dir, valid_prop=0.10, test_prop=0.10, seed=1234): 12 | 13 | data = pd.read_csv(data_dir / "eeg_epilepsy/dataset.csv") 14 | x, y = data.drop(columns=["Unnamed: 0", "y"]), data["y"] 15 | 16 | x, x_test, y, y_test = train_test_split(x, y, test_size=test_prop, shuffle=True, random_state=seed) 17 | x_train, x_valid, y_train, y_valid = train_test_split(x, y, test_size=valid_prop, shuffle=True, random_state=seed) 18 | 19 | train_sample_weights = calculate_sample_weights(y_train) 20 | 21 | reverse = FlipTime(p=0.5) 22 | shift = Shift(p=0.5) 23 | flip = FlipPolarity(p=0.5) 24 | noise = GuassianNoise(min_amplitude=0.01, max_amplitude=1.0, p=0.5) 25 | transforms = Compose([reverse, flip, shift, noise]) 26 | 27 | datasets = {} 28 | for stage, x, y in zip(["train", "valid", "test"], [x_train, x_valid, x_test], [y_train, y_valid, y_test]): 29 | 30 | dataset = EEGEpilepsyDataset(x, y, transforms=transforms if stage=="train" else None) 31 | 32 | if stage == "train": 33 | dataset.sample_weights = train_sample_weights 34 | 35 | datasets[stage] = dataset 36 | 37 | return datasets 38 | 39 | 40 | class EEGEpilepsyDataset(Dataset): 41 | 42 | def __init__(self, data, label, transforms=None): 43 | self.data = data.values 44 | self.label = label.values - 1 45 | self.transforms = transforms 46 | self.num_classes = 5 47 | 48 | def __len__(self): 49 | return len(self.data) 50 | 51 | def __getitem__(self, idx): 52 | x, y = self.data[idx], self.label[idx] 53 | 54 | if self.transforms: 55 | x = self.transforms(x, sample_rate=None) 56 | 57 | x = torch.from_numpy(x).unsqueeze(0).float() 58 | y = torch.tensor(y).long() 59 | 60 | return x, y -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import hydra 3 | import wandb 4 | import pathlib 5 | 6 | from src.loaders import get_dataloaders 7 | from src.models import get_model 8 | from src.optimisers import get_optimiser 9 | from src.schedulers import get_scheduler 10 | from src.losses import get_loss 11 | from src.run import run 12 | from src.utils import set_random_seed, get_device, load_config, save_config, to_dict 13 | 14 | 15 | @hydra.main(config_path="config", config_name="config.yaml") 16 | def train(cfg): 17 | # set random seed for reproducibility 18 | set_random_seed(seed=cfg.train.seed, 19 | is_gpu=cfg.train.is_gpu) 20 | 21 | # get training backend 22 | device = get_device(is_gpu=cfg.train.is_gpu, 23 | gpu_number=cfg.train.gpu_number) 24 | 25 | # unique id 26 | experiment_id = cfg.experiment.id if cfg.experiment.id is not None else uuid.uuid4().hex[:8] 27 | 28 | # initialise logging 29 | if cfg.logging.wb_logging: wandb.init(project=cfg.logging.wb_project, id=experiment_id) 30 | 31 | model_name, ensemble = cfg.model.name, cfg.model.ensemble 32 | models_dir = pathlib.Path("./models") / ((model_name + "_" + ensemble) if ensemble is not None else model_name) / ("run_" + str(cfg.experiment.run)) / experiment_id 33 | models_dir.mkdir(parents=True) 34 | 35 | # initalise dataloaders 36 | dataloaders = get_dataloaders(**to_dict(cfg.data), 37 | seed=cfg.train.seed + cfg.train.run, 38 | device=device) 39 | 40 | # initialise model 41 | model = get_model(**to_dict(cfg.model)).to(device) 42 | 43 | # initialise loss 44 | loss = 
get_loss(ensemble=ensemble, **to_dict(cfg.loss))
45 | 
46 |     # initialise optimiser
47 |     optimiser = get_optimiser(model=model, **to_dict(cfg.optimiser))
48 | 
49 |     # initialise scheduler
50 |     scheduler = get_scheduler(optimiser=optimiser, **to_dict(cfg.scheduler))
51 | 
52 |     # train model
53 |     run(model=model,
54 |         train_loader=dataloaders["train"],
55 |         valid_loader=dataloaders["valid"],
56 |         criterion=loss,
57 |         optimiser=optimiser,
58 |         scheduler=scheduler,
59 |         num_epochs=cfg.train.num_epochs,
60 |         save_dir=models_dir,
61 |         device=device,
62 |         wb_logging=cfg.logging.wb_logging)
63 | 
64 |     # save hyperparameters
65 |     save_config(cfg, models_dir)
66 | 
67 | 
68 | if __name__ == "__main__":
69 |     train()
--------------------------------------------------------------------------------
/src/run.py:
--------------------------------------------------------------------------------
1 | import time
2 | import wandb
3 | import torch
4 | 
5 | 
6 | from src.metrics import total_corrrect
7 | 
8 | 
9 | def _save_model(model, save_dir, is_checkpoint=False):
10 | 
11 |     if is_checkpoint:
12 |         torch.save(model.state_dict(), save_dir / "best_model.pth.tar")
13 | 
14 |     else:
15 |         torch.save(model.state_dict(), save_dir / "last_model.pth.tar")
16 | 
17 | 
18 | def _train_epoch(model, dataloader, criterion, optimiser, device):
19 | 
20 |     model.train()
21 | 
22 |     epoch_loss = 0
23 |     epoch_acc = 0
24 | 
25 |     for x, y in dataloader:
26 |         # transfer signal, y to device
27 |         x, y = x.to(device), y.to(device).reshape(-1)
28 |         # clear gradients of model parameters
29 |         optimiser.zero_grad()
30 |         # forward pass
31 |         logits = model(x)
32 |         # calculate metrics
33 |         loss = criterion(logits, y)
34 |         correct = total_corrrect(logits, y)
35 |         # backward pass
36 |         loss.backward()
37 |         # update model parameters
38 |         optimiser.step()
39 |         # accumulate loss over batch
40 |         epoch_loss += loss.item() / len(dataloader)
41 |         epoch_acc += (100 * correct.item()) / len(dataloader)
42 | 
43 | 
44 | 
45 |     return epoch_loss, epoch_acc
46 | 
47 | 
48 | def _valid_epoch(model, dataloader, criterion, device):
49 | 
50 |     model.eval()
51 | 
52 |     epoch_loss = 0
53 |     epoch_acc = 0
54 | 
55 |     for x, y in dataloader:
56 |         # transfer x, y to device
57 |         x, y = x.to(device), y.to(device).reshape(-1)
58 |         # do not calculate gradients
59 |         with torch.no_grad():
60 |             # forward pass
61 |             logits = model(x)
62 |             # calculate metrics
63 |             loss = criterion(logits, y)
64 |             correct = total_corrrect(logits, y)
65 |         # accumulate loss over batch
66 |         epoch_loss += loss.item() / len(dataloader)
67 |         epoch_acc += (100 * correct.item()) / len(dataloader)
68 | 
69 | 
70 | 
71 |     return epoch_loss, epoch_acc
72 | 
73 | 
74 | def run(model, train_loader, valid_loader, criterion, optimiser, scheduler, num_epochs, save_dir, device=torch.device("cpu"), wb_logging=False):
75 | 
76 |     if wb_logging: wandb.watch(model)
77 | 
78 | 
79 |     best_valid_acc = -1.
80 | 
81 |     for epoch in range(num_epochs):
82 |         start_time = time.time()
83 | 
84 |         train_loss, train_acc = _train_epoch(model, train_loader, criterion, optimiser, device)
85 |         valid_loss, valid_acc = _valid_epoch(model, valid_loader, criterion, device)
86 | 
87 |         is_best = valid_acc > best_valid_acc
88 |         if is_best:
89 |             best_valid_acc = valid_acc
90 |             _save_model(model, save_dir, is_checkpoint=True)
91 | 
92 |         if scheduler is not None:
93 |             scheduler.step()
94 | 
95 |         end_time = time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time))
96 |         to_print = "{} | epoch {:4d} of {:4d} | train loss {:06.3f} | train acc {:05.2f} | valid loss {:06.3f} | valid acc {:05.2f} | time: {} "
97 |         if is_best: to_print = to_print + "| *"
98 |         print(to_print.format(save_dir.stem, epoch + 1, num_epochs, train_loss, train_acc, valid_loss, valid_acc, end_time))
99 | 
100 |         if wb_logging: wandb.log(dict(train={"loss": train_loss, "acc": train_acc}, valid={"loss": valid_loss, "acc": valid_acc}))
101 | 
102 |     _save_model(model, save_dir)
103 | 
--------------------------------------------------------------------------------
/src/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | 
4 | 
5 | class BaseTransform:
6 | 
7 |     def __init__(self, p):
8 |         """
9 |         p: Probability of applying transform.
10 |         """
11 | 
12 |         assert 0 <= p <= 1
13 |         self.p = p
14 | 
15 |     def apply(self, signals):
16 |         raise NotImplementedError
17 | 
18 |     def __call__(self, signals):
19 | 
20 |         if random.random() < self.p:
21 |             signals = self.apply(signals)
22 | 
23 |         return signals
24 | 
25 | 
26 | class Compose:
27 | 
28 |     def __init__(self, transforms, p=1):
29 |         """
30 |         transforms: List of transforms to apply to signals.
31 |         p: Probability of applying the list of transforms.
32 |         """
33 | 
34 |         assert 0 <= p <= 1
35 | 
36 |         self.transforms = transforms
37 |         self.p = p
38 | 
39 |     def __call__(self, signals, sample_rate=None):  # sample_rate is unused; accepted so datasets can call transforms(x, sample_rate=None)
40 | 
41 |         if random.random() < self.p:
42 |             for transform in self.transforms:
43 |                 signals = transform(signals)
44 | 
45 |         return signals
46 | 
47 | 
48 | class FlipTime(BaseTransform):
49 |     """Randomly flip signals along temporal dimension"""
50 | 
51 |     def __init__(self, p=0.5):
52 |         """
53 |         p: Probability of applying transform.
54 |         """
55 |         super().__init__(p)
56 | 
57 | 
58 |     def apply(self, signals):
59 | 
60 |         if len(signals.shape) > 1:
61 |             signals = np.fliplr(signals).copy()  # copy: torch.from_numpy cannot consume negatively-strided views
62 | 
63 |         else:
64 |             signals = np.flipud(signals).copy()
65 | 
66 |         return signals
67 | 
68 | 
69 | class MaskTime(BaseTransform):
70 |     """Randomly mask signal"""
71 | 
72 |     def __init__(self, min_fraction=0.0, max_fraction=0.5, p=0.5):
73 |         """
74 |         min_fraction: Minimum length of the mask as a fraction of the total time series length.
75 |         max_fraction: Maximum length of the mask as a fraction of the total time series length.
76 |         p: Probability of applying transform.
77 | """ 78 | 79 | super().__init__(p) 80 | 81 | assert 0 <= min_fraction <= 1 82 | assert 0 <= max_fraction <= 1 83 | assert max_fraction >= min_fraction 84 | 85 | self.min_fraction = min_fraction 86 | self.max_fraction = max_fraction 87 | 88 | 89 | def apply(self, signals): 90 | 91 | num_samples = signals.shape[-1] 92 | length = random.randint(int(num_samples * self.min_fraction), int(num_samples * self.max_fraction)) 93 | start = random.randint(0, num_samples - length) 94 | 95 | mask = np.zeros(length) 96 | masked_signals = signals.copy() 97 | masked_signals[..., start : start + length] *= mask 98 | 99 | return masked_signals 100 | 101 | 102 | class Shift(BaseTransform): 103 | """Shift the signals forwards or backwards along the temporal dimension""" 104 | 105 | def __init__(self, min_fraction=-0.5, max_fraction=0.5, rollover=True, p=0.5): 106 | """ 107 | min_fraction: Fraction of total timeseries to shift. 108 | max_fraction: Fraction of total timeseries to shift. 109 | rollover: Samples that roll beyond the first or last position are re-introduced at the last or first otherwise set to zero. 110 | p: Probability of applying this transform. 111 | """ 112 | 113 | super().__init__(p) 114 | 115 | assert min_fraction >= -1 116 | assert max_fraction <= 1 117 | 118 | self.min_fraction = min_fraction 119 | self.max_fraction = max_fraction 120 | self.rollover = rollover 121 | 122 | def apply(self, signals): 123 | 124 | num_samples = signals.shape[-1] 125 | num_shift = int(round(random.uniform(self.min_fraction, self.max_fraction) * num_samples)) 126 | signals = np.roll(signals, num_shift, axis=-1) 127 | 128 | if not self.rollover: 129 | if num_shift > 0: 130 | signals[..., :num_shift] = 0.0 131 | 132 | elif num_shift < 0: 133 | signals[..., num_shift:] = 0.0 134 | 135 | return signals 136 | 137 | 138 | class FlipPolarity(BaseTransform): 139 | """Randomly flip sign of signal""" 140 | 141 | def __init__(self, p=0.5): 142 | """ 143 | p: Probability of applying transform. 144 | """ 145 | super().__init__(p) 146 | 147 | def apply(self, signals): 148 | return -signals 149 | 150 | 151 | class GuassianNoise(BaseTransform): 152 | """Add gaussian noise to the signals""" 153 | 154 | def __init__(self, min_amplitude=0.001, max_amplitude=0.015, p=0.5): 155 | """ 156 | min_amplitude: minimum amplitude of noise. 157 | max_amplitude: maximum amplitude of noise. 158 | p: Probability of applying this transform. 
159 | """ 160 | super().__init__(p) 161 | 162 | assert min_amplitude > 0.0 163 | assert max_amplitude > 0.0 164 | assert max_amplitude >= min_amplitude 165 | 166 | self.min_amplitude = min_amplitude 167 | self.max_amplitude = max_amplitude 168 | 169 | def apply(self, signals): 170 | 171 | amplitude = random.uniform(self.min_amplitude, self.max_amplitude) 172 | 173 | noise = np.random.randn(*signals.shape).astype(np.float32) 174 | signals = signals + amplitude * noise 175 | 176 | return signals -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torchmetrics.functional as tm 4 | from torch.distributions import Categorical 5 | 6 | 7 | def count_parameters(model): 8 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 9 | 10 | 11 | def F1(logits, labels, ensemble_weights, average="weighted"): 12 | 13 | 14 | _, num_exits, num_classes = logits.shape 15 | scale = ensemble_weights.sum() 16 | 17 | pred_labels = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale).argmax(-1) 18 | 19 | f1 = tm.f1(pred_labels, labels, num_classes=num_classes, average=average) 20 | 21 | return f1 22 | 23 | 24 | def precision(logits, labels, ensemble_weights, average="weighted"): 25 | 26 | _, num_exits, num_classes = logits.shape 27 | scale = ensemble_weights.sum() 28 | 29 | pred_labels = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale).argmax(-1) 30 | 31 | pr = tm.precision(pred_labels, labels, num_classes=num_classes, average=average) 32 | 33 | return pr 34 | 35 | 36 | def recall(logits, labels, ensemble_weights, average="weighted"): 37 | 38 | _, num_exits, num_classes = logits.shape 39 | scale = ensemble_weights.sum() 40 | 41 | pred_labels = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale).argmax(-1) 42 | 43 | rc = tm.recall(pred_labels, labels, num_classes=num_classes, average=average) 44 | 45 | return rc 46 | 47 | 48 | def negative_loglikelihood(logits, labels, ensemble_weights, reduction="mean"): 49 | 50 | _, num_exits, num_classes = logits.shape 51 | scale = ensemble_weights.sum() 52 | 53 | probs = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale) 54 | 55 | nll = -Categorical(probs=probs).log_prob(labels) 56 | 57 | if reduction == "mean": 58 | nll = nll.mean() 59 | 60 | return nll 61 | 62 | 63 | def brier_score(logits, labels, ensemble_weights, reduction="mean"): 64 | 65 | _, num_exits, num_classes = logits.shape 66 | scale = ensemble_weights.sum() 67 | 68 | probs = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale) 69 | 70 | labels_one_hot = F.one_hot(labels, num_classes=num_classes) 71 | 72 | bs = ((probs - labels_one_hot)**2).sum(dim=-1) 73 | 74 | if reduction == "mean": 75 | bs = bs.mean() 76 | 77 | return bs 78 | 79 | 80 | def predictive_entropy(logits, labels, ensemble_weights, reduction="mean"): 81 | 82 | _, num_exits, num_classes = logits.shape 83 | scale = ensemble_weights.sum() 84 | 85 | probs = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale) 86 | 87 | et = Categorical(probs=probs).entropy() 88 | 89 | if reduction == "mean": 90 | et = et.mean() 91 | 92 | return et 93 | 94 | 95 | def predictive_confidence(logits, labels, ensemble_weights, reduction="mean"): 96 | 97 | num_samples, _, _ = logits.shape 98 | scale = ensemble_weights.sum() 99 | 100 | probs = 
logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale) 101 | 102 | pc = probs[torch.arange(num_samples), labels] 103 | 104 | if reduction == "mean": 105 | pc = pc.mean() 106 | 107 | return pc 108 | 109 | 110 | def expected_calibration_error(logits, labels, ensemble_weights, n_bins=15): 111 | 112 | num_samples, num_exits, num_classes = logits.shape 113 | scale = ensemble_weights.sum() 114 | 115 | pred_probs = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale) 116 | pred_labels = pred_probs.argmax(-1) 117 | 118 | pred_probs = pred_probs[torch.arange(num_samples), pred_labels] 119 | 120 | correct = pred_labels.eq(labels) 121 | 122 | bin_boundaries = torch.linspace(0, 1, n_bins + 1) 123 | 124 | conf_bin = torch.zeros_like(bin_boundaries) 125 | acc_bin = torch.zeros_like(bin_boundaries) 126 | prop_bin = torch.zeros_like(bin_boundaries) 127 | 128 | for i, (bin_lower, bin_upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])): 129 | 130 | in_bin = pred_probs.gt(bin_lower.item()) * pred_probs.le(bin_upper.item()) 131 | prop_in_bin = in_bin.float().mean() 132 | 133 | if prop_in_bin.item() > 0: 134 | # probability of making a correct prediction given a probability bin 135 | acc_bin[i] = correct[in_bin].float().mean() 136 | # average predicted probabily given a probability bin. 137 | conf_bin[i] = pred_probs[in_bin].mean() 138 | # probability of observing a probability bin 139 | prop_bin[i] = prop_in_bin 140 | 141 | ece = ((acc_bin - conf_bin).abs() * prop_bin).sum() 142 | 143 | return ece 144 | 145 | 146 | def calculate_metrics(model, logits, labels, ensemble_weights): 147 | 148 | metrics = dict(f1=F1(logits, labels, ensemble_weights, average="weighted").numpy(), 149 | precision=precision(logits, labels, ensemble_weights, average="weighted").numpy(), 150 | recall=recall(logits, labels, ensemble_weights, average="weighted").numpy(), 151 | negative_loglikelihood=negative_loglikelihood(logits, labels, ensemble_weights, reduction="mean").numpy(), 152 | brier_score=brier_score(logits, labels, ensemble_weights, reduction="mean").numpy(), 153 | predictive_entropy=predictive_entropy(logits, labels, ensemble_weights, reduction="mean").numpy(), 154 | expected_calibration_error=expected_calibration_error(logits, labels, ensemble_weights, n_bins=15).numpy(), 155 | params=count_parameters(model)) 156 | 157 | return metrics -------------------------------------------------------------------------------- /src/models/res_net_18.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def get_res_net_18(ensemble, **kwargs): 8 | 9 | if ensemble is None: 10 | return ResNet18(**kwargs) 11 | 12 | elif ensemble == "early_exit": 13 | return ResNet18EarlyExit(**kwargs) 14 | 15 | elif ensemble == "mc_dropout": 16 | return ResNet18MCDrop(**kwargs) 17 | 18 | elif ensemble == "deep": 19 | return ResNet18(**kwargs) 20 | 21 | elif ensemble == "depth": 22 | return ResNet18Depth(**kwargs) 23 | 24 | else: 25 | NotImplementedError("ensemble not implemented: '{}'".format(ensemble)) 26 | 27 | 28 | def init_weights(model): 29 | 30 | for module in model.modules(): 31 | 32 | if isinstance(module, nn.Conv1d): 33 | nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") 34 | 35 | elif isinstance(module, nn.BatchNorm1d): 36 | nn.init.constant_(module.weight, 1) 37 | nn.init.constant_(module.bias, 0) 38 | 39 | 40 | def conv3x3(in_planes, 
out_planes, stride=1): 41 | return nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 42 | 43 | 44 | def _conv1x1(in_planes, out_planes, stride=1): 45 | return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 46 | 47 | 48 | class BasicBlock(nn.Module): 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super().__init__() 52 | 53 | self.conv1 = conv3x3(inplanes, planes, stride) 54 | self.bn1 = nn.BatchNorm1d(planes) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.conv2 = conv3x3(planes, planes) 57 | self.bn2 = nn.BatchNorm1d(planes) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | 63 | identity = x 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | 70 | if self.downsample is not None: 71 | identity = self.downsample(x) 72 | 73 | out += identity 74 | out = self.relu(out) 75 | 76 | return out 77 | 78 | 79 | class ResNet18(nn.Module): 80 | 81 | name = "res_net_18" 82 | 83 | def __init__(self, out_channels, seed=None): 84 | super().__init__() 85 | 86 | self.out_channels = out_channels 87 | self.seed = seed 88 | 89 | self.hidden_sizes = [64, 128, 256, 512] 90 | self.layers = [2, 2, 2, 2] 91 | self.strides = [1, 2, 2, 2] 92 | self.inplanes = self.hidden_sizes[0] 93 | 94 | in_block = [nn.Conv1d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)] 95 | in_block += [nn.BatchNorm1d(self.inplanes)] 96 | in_block += [nn.ReLU(inplace=True)] 97 | in_block += [nn.MaxPool1d(kernel_size=3, stride=2, padding=1)] 98 | self.in_block = nn.Sequential(*in_block) 99 | 100 | blocks = [] 101 | for h, l, s in zip(self.hidden_sizes, self.layers, self.strides): 102 | blocks += [self._make_layer(h, l, s)] 103 | self.blocks = nn.Sequential(*blocks) 104 | 105 | out_block = [nn.AdaptiveAvgPool1d(1)] 106 | out_block += [nn.Flatten(1)] 107 | out_block += [nn.Linear(self.hidden_sizes[-1], self.out_channels)] 108 | self.out_block = nn.Sequential(*out_block) 109 | 110 | if self.seed is not None: 111 | torch.manual_seed(seed) 112 | 113 | self.apply(init_weights) 114 | 115 | def _make_layer(self, planes, blocks, stride=1): 116 | 117 | downsample = None 118 | 119 | if stride != 1 or self.inplanes != planes: 120 | downsample = nn.Sequential(_conv1x1(self.inplanes, planes, stride), nn.BatchNorm1d(planes)) 121 | 122 | layers = [BasicBlock(self.inplanes, planes, stride, downsample)] 123 | self.inplanes = planes 124 | 125 | for _ in range(1, blocks): 126 | layers += [BasicBlock(self.inplanes, planes)] 127 | 128 | return nn.Sequential(*layers) 129 | 130 | def forward(self, x): 131 | 132 | x = self.in_block(x) 133 | x = self.blocks(x) 134 | x = self.out_block(x) 135 | 136 | return x 137 | 138 | 139 | class ExitBlock(nn.Module): 140 | 141 | def __init__(self, in_channels, hidden_sizes, out_channels): 142 | super().__init__() 143 | 144 | layers = [nn.AdaptiveAvgPool1d(1)] 145 | layers += [nn.Flatten(1)] 146 | layers += [nn.Linear(in_channels, hidden_sizes)] 147 | layers += [nn.ReLU()] 148 | layers += [nn.Linear(hidden_sizes, out_channels)] 149 | self.layers = nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | 153 | return self.layers(x) 154 | 155 | 156 | class ResNet18EarlyExit(ResNet18): 157 | 158 | name = "res_net_18_early_exit" 159 | 160 | def __init__(self, *args, exit_after=-1, complexity_factor=1.2, **kwargs): 161 | self.exit_after = exit_after 162 | self.complexity_factor = 
complexity_factor 163 | 164 | super().__init__(*args, **kwargs) 165 | 166 | to_exit = [2, 8, 15, 24, 31, 40, 47, 56] 167 | hidden_sizes = len(self.hidden_sizes) 168 | 169 | num_hidden = len(self.hidden_sizes) 170 | exit_hidden_sizes = [int(((self.complexity_factor ** 0.5) ** (num_hidden - idx)) * self.hidden_sizes[-1]) for idx in range(num_hidden)] 171 | exit_hidden_sizes = [h for pair in zip(exit_hidden_sizes, exit_hidden_sizes) for h in pair] 172 | 173 | if self.exit_after == -1: 174 | self.exit_after = range(len(to_exit)) 175 | 176 | num_exits = len(to_exit) 177 | 178 | if (len(self.exit_after) > num_exits) or not set(self.exit_after).issubset(list(range(num_exits))): 179 | raise ValueError("valid exit points: {}".format(", ".join(str(n) for n in range(num_exits)))) 180 | 181 | self.exit_hidden_sizes = np.array(exit_hidden_sizes)[self.exit_after] 182 | 183 | blocks = [] 184 | for idx, module in enumerate(self.blocks.modules()): 185 | if idx in to_exit: 186 | blocks += [module] 187 | self.blocks = nn.ModuleList(blocks) 188 | 189 | idx = 0 190 | exit_blocks = [] 191 | for block_idx, block in enumerate(self.blocks): 192 | if block_idx in self.exit_after: 193 | in_channels = block.conv1.out_channels 194 | exit_blocks += [ExitBlock(in_channels, self.exit_hidden_sizes[idx], self.out_channels)] 195 | idx += 1 196 | self.exit_blocks = nn.ModuleList(exit_blocks) 197 | 198 | self.apply(init_weights) 199 | 200 | def forward(self, x): 201 | 202 | out = self.in_block(x) 203 | 204 | out_blocks = [] 205 | for block in self.blocks: 206 | out = block(out) 207 | out_blocks += [out] 208 | 209 | out_exits = [] 210 | for exit_after, exit_block in zip(self.exit_after, self.exit_blocks): 211 | out = exit_block(out_blocks[exit_after]) 212 | out_exits += [out] 213 | 214 | out = self.out_block(out_blocks[-1]) 215 | out = torch.stack(out_exits + [out], dim=1) 216 | 217 | return out 218 | 219 | 220 | class MCDropout(nn.Dropout): 221 | 222 | def forward(self, x): 223 | return F.dropout(x, self.p, True, self.inplace) 224 | 225 | 226 | class ResNet18MCDrop(ResNet18EarlyExit): 227 | 228 | name = "res_net_18_mc_drop" 229 | 230 | def __init__(self, *args, drop_after=-1, drop_prob=0.2, **kwargs): 231 | self.drop_after = drop_after 232 | self.drop_prob = drop_prob 233 | 234 | super().__init__(*args, exit_after=drop_after, **kwargs) 235 | 236 | self.drop_after = self.exit_after 237 | 238 | self.__delattr__("exit_after") 239 | self.__delattr__("exit_blocks") 240 | 241 | for block_idx in self.drop_after: 242 | self.blocks[block_idx].add_module("dropout", MCDropout(self.drop_prob)) 243 | 244 | def forward(self, x): 245 | 246 | x = self.in_block(x) 247 | x = self.blocks(x) 248 | x = self.out_block(x) 249 | 250 | return x 251 | 252 | 253 | class ResNet18Depth(ResNet18): 254 | 255 | name = "res_net_18_depth" 256 | 257 | def __init__(self, *args, max_depth=1, **kwargs): 258 | self.max_depth = max_depth 259 | 260 | super().__init__(*args, **kwargs) 261 | 262 | num_blocks = len(self.hidden_sizes) 263 | 264 | if self.max_depth == -1: 265 | self.max_depth = len(self.hidden_sizes) 266 | 267 | elif (max_depth > num_blocks) or (max_depth < 1): 268 | raise ValueError("valid depths: {}".format(", ".join(str(n) for n in range(1, num_blocks + 1)))) 269 | 270 | self.blocks = self.blocks[:self.max_depth] 271 | 272 | out_block = [nn.AdaptiveAvgPool1d(1)] 273 | out_block += [nn.Flatten(1)] 274 | out_block += [nn.Linear(self.hidden_sizes[self.max_depth - 1], self.out_channels)] 275 | self.out_block = nn.Sequential(*out_block) 
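

if __name__ == "__main__":
    # Illustrative shape check only (not part of the original training pipeline).
    # A single-channel, 178-sample window with 5 classes mirrors the EEG epilepsy
    # dataset used elsewhere in this repository.
    x = torch.randn(2, 1, 178)

    backbone = ResNet18(out_channels=5)
    print(backbone(x).shape)        # torch.Size([2, 5])

    early_exit = ResNet18EarlyExit(out_channels=5)
    print(early_exit(x).shape)      # torch.Size([2, 9, 5]): 8 early exits + the final head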
--------------------------------------------------------------------------------
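
A short usage sketch (not part of the repository) tying the pieces above together: `ResNet18EarlyExit` stacks one set of logits per exit, `ExitWeightedCrossEntropyLoss` weights the per-exit cross-entropy terms, and every metric in `src/metrics.py` reduces the stack with a weighted softmax average. The `alpha`, `gamma`, and `ensemble_weights` values below are illustrative placeholders, not values taken from the paper or its configs.

import torch

from src.losses import ExitWeightedCrossEntropyLoss

batch_size, num_exits, num_classes = 8, 9, 5                # 8 early exits + final head, 5 EEG classes
logits = torch.randn(batch_size, num_exits, num_classes)    # same shape as ResNet18EarlyExit(out_channels=5) returns
labels = torch.randint(0, num_classes, (batch_size,))

# per-exit weighted cross-entropy: alpha fixes the relative weight of each exit,
# gamma lets those weights be rescaled (e.g. annealed) during training
criterion = ExitWeightedCrossEntropyLoss(alpha=[1.0] * num_exits)
loss = criterion(logits, labels, gamma=[1.0] * num_exits)

# ensemble prediction, mirroring the reduction used by the metrics in src/metrics.py:
# weight each exit's softmax, sum over exits, renormalise, then take the argmax
ensemble_weights = torch.ones(num_exits, 1)
probs = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(ensemble_weights.sum())
predictions = probs.argmax(-1)                              # shape: (batch_size,)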