├── src
│   ├── datasets
│   │   ├── utils.py
│   │   └── eeg_epilepsy.py
│   ├── schedulers.py
│   ├── optimisers.py
│   ├── losses.py
│   ├── utils.py
│   ├── loaders.py
│   ├── run.py
│   ├── transforms.py
│   ├── metrics.py
│   └── models
│       └── res_net_18.py
├── data
│   └── eeg_epilepsy
│       └── download.py
├── LICENSE
├── README.md
└── train.py
/src/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def calculate_sample_weights(y):
6 |
7 | classes, counts = np.unique(y, return_counts=True)
8 |     class_weights = dict(zip(classes, sum(counts) / counts))  # inverse class frequency
9 | sample_weights = torch.DoubleTensor([class_weights[cls] for cls in y])
10 |
11 | return sample_weights
--------------------------------------------------------------------------------
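
calculate_sample_weights assigns every sample the inverse frequency of its class, sized for torch's WeightedRandomSampler as used in src/loaders.py. A minimal sketch of the intended use, on a made-up label array:

    import numpy as np
    from torch.utils.data.sampler import WeightedRandomSampler

    from src.datasets.utils import calculate_sample_weights

    # toy imbalanced labels: class 0 appears three times as often as class 1
    y = np.array([0, 0, 0, 0, 0, 0, 1, 1])
    weights = calculate_sample_weights(y)              # rarer classes receive larger weights
    sampler = WeightedRandomSampler(weights, num_samples=len(weights))
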
/src/schedulers.py:
--------------------------------------------------------------------------------
1 | import torch.optim.lr_scheduler as scl
2 |
3 |
4 | def get_scheduler(name, optimiser, **kwargs):
5 |
6 | if name == "reduce_on_plateau":
7 | return scl.ReduceLROnPlateau(optimiser, **kwargs)
8 |
9 | elif name == "multi_step":
10 | return scl.MultiStepLR(optimiser, **kwargs)
11 |
12 | elif name is None:
13 | return None
14 |
15 | else:
16 | raise NotImplementedError("scheduler not implemented: '{}'".format(name))
--------------------------------------------------------------------------------
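
get_scheduler forwards its keyword arguments unchanged to the underlying torch scheduler, so the config has to use the argument names torch expects. A small sketch with illustrative values (not taken from the project config):

    import torch
    import torch.nn as nn

    from src.schedulers import get_scheduler

    optimiser = torch.optim.Adam(nn.Linear(4, 2).parameters(), lr=1e-3)
    # kwargs pass straight through to torch.optim.lr_scheduler.MultiStepLR
    scheduler = get_scheduler("multi_step", optimiser, milestones=[50, 75], gamma=0.1)
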
/data/eeg_epilepsy/download.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import requests
3 | import zipfile
4 |
5 | URL = "https://web.archive.org/web/20200318000445/http://archive.ics.uci.edu/ml/machine-learning-databases/00388/data.csv"
6 |
7 |
8 | def download(from_path, to_path):
9 |
10 | if not to_path.exists():
11 |
12 | try:
13 | r = requests.get(url=from_path)
14 |
15 | with open(to_path, "wb") as file:
16 | file.write(r.content)
17 |
18 |         except requests.exceptions.RequestException as err:
19 |             print("error downloading {}: {}".format(str(from_path), err))
20 |
21 |
22 | if __name__ == "__main__":
23 |
24 | print("downloading eeg epilepsy data")
25 | download(URL, to_path=pathlib.Path("dataset.csv"))
26 |
--------------------------------------------------------------------------------
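
The script writes dataset.csv relative to the current working directory, while src/datasets/eeg_epilepsy.py expects it at data/eeg_epilepsy/dataset.csv, so the natural way to fetch the data is to run `python download.py` from inside data/eeg_epilepsy before training.
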
/src/optimisers.py:
--------------------------------------------------------------------------------
1 | import torch.optim as opt
2 |
3 |
4 | def get_optimiser(name, model, **kwargs):
5 |
6 | if name == "adam":
7 |
8 | if "early_exit" in model.name:
9 |
10 | weight_decay = kwargs.pop("weight_decay")
11 | params = [{"params": [param for name, param in model.named_parameters() if "exit_block" not in name], "weight_decay": weight_decay}]
12 |
13 | for block_idx, exit_block in enumerate(model.exit_blocks):
14 | params += [{"params": exit_block.parameters(), "weight_decay": (block_idx + 1) * weight_decay}]
15 |
16 | return opt.Adam(params, **kwargs)
17 |
18 | else:
19 |
20 | return opt.Adam(model.parameters(), **kwargs)
21 |
22 | else:
23 | raise NotImplementedError("optimiser not implemented: '{}'".format(name))
--------------------------------------------------------------------------------
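
For early-exit models, get_optimiser builds one Adam parameter group per exit block and scales weight decay with exit depth; every other model gets a single group over all parameters. A sketch of a call, assuming the early-exit ResNet from src/models/res_net_18.py and illustrative lr/weight_decay values:

    from src.models.res_net_18 import ResNet18EarlyExit
    from src.optimisers import get_optimiser

    model = ResNet18EarlyExit(out_channels=5)
    # exit_block parameters receive weight decay scaled by (block_idx + 1)
    optimiser = get_optimiser("adam", model, lr=1e-3, weight_decay=1e-4)
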
/src/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def get_loss(name, ensemble, **kwargs):
7 |
8 | if name == "cross_entropy":
9 |
10 | if ensemble == "early_exit":
11 | return ExitWeightedCrossEntropyLoss(**kwargs)
12 |
13 | else:
14 | return nn.CrossEntropyLoss(**kwargs)
15 |
16 | else:
17 | raise ValueError("loss not implemented: '{}'".format(name))
18 |
19 |
20 | class ExitWeightedCrossEntropyLoss:
21 |
22 | def __init__(self, alpha):
23 |         self.alpha = torch.tensor(alpha)
24 |
25 | def __call__(self, logits, labels, gamma):
26 |
27 | batch_size, num_exits, _ = logits.shape
28 |
29 | loss = 0.0
30 | for ex in range(num_exits):
31 | exit_logits = logits[:, ex, :]
32 | loss += self.alpha[ex] * gamma[ex] * F.cross_entropy(exit_logits, labels)
33 |
34 | return loss
--------------------------------------------------------------------------------
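
ExitWeightedCrossEntropyLoss expects logits of shape (batch, num_exits, num_classes), one slot per early exit plus the final head as stacked by ResNet18EarlyExit; alpha is fixed at construction and gamma is supplied per call. A minimal shape-check sketch with illustrative (uniform) weights:

    import torch

    from src.losses import get_loss

    batch_size, num_exits, num_classes = 4, 9, 5      # 8 early exits + the final head
    logits = torch.randn(batch_size, num_exits, num_classes)
    labels = torch.randint(0, num_classes, (batch_size,))

    criterion = get_loss("cross_entropy", ensemble="early_exit", alpha=[1.0] * num_exits)
    loss = criterion(logits, labels, gamma=[1.0] * num_exits)
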
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Alex Campbell
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Early Exit Ensembles for Uncertainty Quantification
2 |
3 | [[Paper]](https://proceedings.mlr.press/v158/qendro21a/qendro21a.pdf) [[Poster]]() [[Slides]]()
4 |
5 |
6 |
7 |
8 | # Contact
9 |
10 | Alexander Campbell (ajrc4@cl.cam.ac.uk), Lorena Qendro (lq223@cl.cam.ac.uk)
11 |
12 |
13 |
14 |
15 | # Citation
16 |
17 | If you make use of this code in your work, please cite our paper:
18 |
19 |
20 | @inproceedings{early_exit_ensembles_2021,
21 | title = {Early Exit Ensembles for Uncertainty Quantification},
22 | booktitle = {Proceedings of Machine Learning for Health},
23 | publisher = {PMLR},
24 | author = {Qendro, Lorena and Campbell, Alexander and Liò, Pietro and Mascolo, Cecilia},
25 | year = {2021},
26 | pages = {179--193},
27 | }
28 |
29 |
30 |
31 | Qendro, L., Campbell, A., Liò, P., & Mascolo, C. (2021). Early Exit Ensembles for Uncertainty Quantification. In Proceedings of Machine Learning for Health (pp. 179–193). PMLR.
32 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 |
5 |
6 | def set_random_seed(seed, is_gpu=False):
7 | """
8 |     Set random seeds for reproducibility
9 | """
10 | max_seed_value = np.iinfo(np.uint32).max
11 | min_seed_value = np.iinfo(np.uint32).min
12 |
13 | if not (min_seed_value <= seed <= max_seed_value):
14 | raise ValueError("{} is not in bounds, numpy accepts from {} to {}".format(seed, min_seed_value, max_seed_value))
15 |
16 | torch.manual_seed(seed)
17 | np.random.seed(seed)
18 | random.seed(seed)
19 |
20 | if torch.cuda.is_available() and is_gpu:
21 | torch.cuda.manual_seed_all(seed)
22 |
23 |
24 | def get_device(is_gpu=True, gpu_number=0):
25 | """
26 | Set the backend for model training
27 | """
28 | gpu_count = torch.cuda.device_count()
29 |     if torch.cuda.is_available() and is_gpu and gpu_number >= gpu_count:
30 |         raise ValueError("requested cuda device {} but only {} available".format(gpu_number, gpu_count))
31 |
32 | else:
33 | if torch.cuda.is_available() and is_gpu:
34 | device = torch.device("cuda:{}".format(gpu_number))
35 | else:
36 | device = torch.device("cpu")
37 |
38 | return device
--------------------------------------------------------------------------------
/src/loaders.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | import torch
4 | from torch.utils.data.sampler import WeightedRandomSampler
5 | from torch.utils.data import DataLoader
6 |
7 | from src.datasets.eeg_epilepsy import get_eeg_epilepsy
8 |
9 |
10 | def get_dataset_splits(name, data_dir, valid_prop, test_prop, seed):
11 | data_dir = pathlib.Path(data_dir)
12 |
13 | if name == "eeg_epilepsy":
14 | return get_eeg_epilepsy(data_dir, valid_prop, test_prop, seed)
15 |
16 | else:
17 | raise ValueError("dataset not implemented: '{}'".format(name))
18 |
19 |
20 | def get_dataloaders(name, data_dir, valid_prop=0.10, test_prop=0.10, batch_size=16,
21 | num_workers=0, seed=1234, device=torch.device("cpu")):
22 |
23 | datasets = get_dataset_splits(name, data_dir, valid_prop, test_prop, seed=seed)
24 |
25 | train_dataset = datasets["train"]
26 | sample_weights = train_dataset.sample_weights
27 | sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
28 |
29 | pin_memory = True if device.type == "cuda" else False
30 |
31 | train = DataLoader(dataset=train_dataset,
32 | batch_size=batch_size,
33 | shuffle=False,
34 | drop_last=True,
35 | sampler=sampler,
36 | pin_memory=pin_memory,
37 | num_workers=num_workers)
38 |
39 | valid = DataLoader(dataset=datasets["valid"],
40 | batch_size=batch_size,
41 | shuffle=False,
42 | drop_last=True,
43 | pin_memory=pin_memory,
44 | num_workers=num_workers)
45 |
46 | test = DataLoader(dataset=datasets["test"],
47 | batch_size=1,
48 | shuffle=False,
49 | pin_memory=pin_memory,
50 | num_workers=num_workers)
51 |
52 | return {"train": train, "valid": valid, "test": test}
--------------------------------------------------------------------------------
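
A hedged example of building the three loaders, assuming dataset.csv has already been downloaded into data/eeg_epilepsy:

    import torch

    from src.loaders import get_dataloaders

    loaders = get_dataloaders("eeg_epilepsy", data_dir="data",
                              batch_size=16, seed=1234, device=torch.device("cpu"))
    x, y = next(iter(loaders["train"]))               # x: (batch, 1, num_features), y: (batch,)
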
/src/datasets/eeg_epilepsy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pandas as pd
4 | from torch.utils.data import Dataset
5 | from sklearn.model_selection import train_test_split
6 |
7 | from src.transforms import Compose, FlipTime, Shift, FlipPolarity, GaussianNoise
8 | from src.datasets.utils import calculate_sample_weights
9 |
10 |
11 | def get_eeg_epilepsy(data_dir, valid_prop=0.10, test_prop=0.10, seed=1234):
12 |
13 | data = pd.read_csv(data_dir / "eeg_epilepsy/dataset.csv")
14 | x, y = data.drop(columns=["Unnamed: 0", "y"]), data["y"]
15 |
16 | x, x_test, y, y_test = train_test_split(x, y, test_size=test_prop, shuffle=True, random_state=seed)
17 | x_train, x_valid, y_train, y_valid = train_test_split(x, y, test_size=valid_prop, shuffle=True, random_state=seed)
18 |
19 | train_sample_weights = calculate_sample_weights(y_train)
20 |
21 | reverse = FlipTime(p=0.5)
22 | shift = Shift(p=0.5)
23 | flip = FlipPolarity(p=0.5)
24 |     noise = GaussianNoise(min_amplitude=0.01, max_amplitude=1.0, p=0.5)
25 | transforms = Compose([reverse, flip, shift, noise])
26 |
27 | datasets = {}
28 | for stage, x, y in zip(["train", "valid", "test"], [x_train, x_valid, x_test], [y_train, y_valid, y_test]):
29 |
30 | dataset = EEGEpilepsyDataset(x, y, transforms=transforms if stage=="train" else None)
31 |
32 | if stage == "train":
33 | dataset.sample_weights = train_sample_weights
34 |
35 | datasets[stage] = dataset
36 |
37 | return datasets
38 |
39 |
40 | class EEGEpilepsyDataset(Dataset):
41 |
42 | def __init__(self, data, label, transforms=None):
43 | self.data = data.values
44 | self.label = label.values - 1
45 | self.transforms = transforms
46 | self.num_classes = 5
47 |
48 | def __len__(self):
49 | return len(self.data)
50 |
51 | def __getitem__(self, idx):
52 | x, y = self.data[idx], self.label[idx]
53 |
54 | if self.transforms:
55 |             x = self.transforms(x)
56 |
57 | x = torch.from_numpy(x).unsqueeze(0).float()
58 | y = torch.tensor(y).long()
59 |
60 | return x, y
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | import hydra
3 | import wandb
4 | import pathlib
5 |
6 | from src.loaders import get_dataloaders
7 | from src.models import get_model
8 | from src.optimisers import get_optimiser
9 | from src.schedulers import get_scheduler
10 | from src.losses import get_loss
11 | from src.run import run
12 | from src.utils import set_random_seed, get_device, load_config, save_config, to_dict
13 |
14 |
15 | @hydra.main(config_path="config", config_name="config.yaml")
16 | def train(cfg):
17 | # set random seed for reproducibility
18 | set_random_seed(seed=cfg.train.seed,
19 | is_gpu=cfg.train.is_gpu)
20 |
21 | # get training backend
22 | device = get_device(is_gpu=cfg.train.is_gpu,
23 | gpu_number=cfg.train.gpu_number)
24 |
25 | # unique id
26 | experiment_id = cfg.experiment.id if cfg.experiment.id is not None else uuid.uuid4().hex[:8]
27 |
28 | # initialise logging
29 | if cfg.logging.wb_logging: wandb.init(project=cfg.logging.wb_project, id=experiment_id)
30 |
31 | model_name, ensemble = cfg.model.name, cfg.model.ensemble
32 | models_dir = pathlib.Path("./models") / ((model_name + "_" + ensemble) if ensemble is not None else model_name) / ("run_" + str(cfg.experiment.run)) / experiment_id
33 | models_dir.mkdir(parents=True)
34 |
35 |     # initialise dataloaders
36 | dataloaders = get_dataloaders(**to_dict(cfg.data),
37 | seed=cfg.train.seed + cfg.train.run,
38 | device=device)
39 |
40 | # initialise model
41 | model = get_model(**to_dict(cfg.model)).to(device)
42 |
43 | # initialise loss
44 | loss = get_loss(ensemble=ensemble, **to_dict(cfg.loss))
45 |
46 | # initialise optimiser
47 | optimiser = get_optimiser(model=model, **to_dict(cfg.optimiser))
48 |
49 | # initialise scheduler
50 | scheduler = get_scheduler(optimiser=optimiser, **to_dict(cfg.scheduler))
51 |
52 | # train model
53 | run(model=model,
54 | train_loader=dataloaders["train"],
55 | valid_loader=dataloaders["valid"],
56 | criterion=loss,
57 | optimiser=optimiser,
58 | scheduler=scheduler,
59 | num_epochs=cfg.train.num_epochs,
60 | save_dir=models_dir,
61 | device=device,
62 | wb_logging=cfg.logging.wb_logging)
63 |
64 | # save hyperparameters
65 | save_config(cfg, models_dir)
66 |
67 |
68 | if __name__ == "__main__":
69 | train()
--------------------------------------------------------------------------------
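
Because train() is wrapped in @hydra.main(config_path="config", config_name="config.yaml"), hyperparameters live in a config/ directory that is not part of this listing, and any key read above (e.g. train.seed, train.num_epochs, model.name, model.ensemble) can be overridden on the command line in the usual Hydra style, e.g. `python train.py train.num_epochs=50 model.ensemble=early_exit`.
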
/src/run.py:
--------------------------------------------------------------------------------
1 | import time
2 | import wandb
3 | import torch
4 | import copy
5 |
6 | from src.metrics import total_correct
7 |
8 |
9 | def _save_model(model, save_dir, is_checkpoint=False):
10 |
11 | if is_checkpoint:
12 | torch.save(model.state_dict(), save_dir / "best_model.pth.tar")
13 |
14 | else:
15 | torch.save(model.state_dict(), save_dir / "last_model.pth.tar")
16 |
17 |
18 | def _train_epoch(model, dataloader, criterion, optimiser, device):
19 |
20 | model.train()
21 |
22 | epoch_loss = 0
23 | epoch_acc = 0
24 |
25 | for x, y in dataloader:
26 |         # transfer x, y to device
27 | x, y = x.to(device), y.to(device).reshape(-1)
28 | # clear gradients of model parameters
29 | optimiser.zero_grad()
30 | # forward pass
31 | logits = model(x)
32 | # calculate metrics
33 | loss = criterion(logits, y)
34 |         correct = total_correct(logits, y)
35 | # backward pass
36 | loss.backward()
37 | # update model parameters
38 | optimiser.step()
39 | # accumulate loss over batch
40 | epoch_loss += loss.item() / len(dataloader)
41 | epoch_acc += (100 * correct.item()) / len(dataloader)
44 |
45 | return epoch_loss, epoch_acc
46 |
47 |
48 | def _valid_epoch(model, dataloader, criterion, device):
49 |
50 | model.eval()
51 |
52 | epoch_loss = 0
53 | epoch_acc = 0
54 |
55 | for x, y in dataloader:
56 | # transfer x, y to device
57 | x, y = x.to(device), y.to(device).reshape(-1)
58 | # do not calculate gradients
59 | with torch.no_grad():
60 | # forward pass
61 | logits = model(x)
62 | # calculate metrics
63 | loss = criterion(logits, y)
64 |             correct = total_correct(logits, y)
65 | # accumulate loss over batch
66 | epoch_loss += loss.item() / len(dataloader)
67 | epoch_acc += (100 * correct.item()) / len(dataloader)
70 |
71 | return epoch_loss, epoch_acc
72 |
73 |
74 | def run(model, train_loader, valid_loader, criterion, optimiser, scheduler, num_epochs, save_dir, device=torch.device("cpu"), wb_logging=False):
75 |
76 | if wb_logging: wandb.watch(model)
77 |
79 | best_valid_acc = -1.
80 |
81 | for epoch in range(num_epochs):
82 | start_time = time.time()
83 |
84 | train_loss, train_acc = _train_epoch(model, train_loader, criterion, optimiser, device)
85 | valid_loss, valid_acc = _valid_epoch(model, valid_loader, criterion, device)
86 |
87 | is_best = valid_acc > best_valid_acc
88 | if is_best:
89 | best_valid_acc = valid_acc
90 | _save_model(model, save_dir, is_checkpoint=True)
91 |
92 | if scheduler is not None:
93 | scheduler.step()
94 |
95 | end_time = time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time))
96 | to_print = "{} | epoch {:4d} of {:4d} | train loss {:06.3f} | train acc {:05.2f} | valid loss {:06.3f} | valid acc {:05.2f} | time: {} "
97 | if is_best: to_print = to_print + "| *"
98 | print(to_print.format(save_dir.stem, epoch + 1, num_epochs, train_loss, train_acc, valid_loss, valid_acc, end_time))
99 |
100 | if wb_logging: wandb.log(dict(train={"loss": train_loss, "acc": train_acc}, valid={"loss": valid_loss, "acc": valid_acc}))
101 |
102 |     _save_model(model, save_dir)
103 |
--------------------------------------------------------------------------------
/src/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 |
4 |
5 | class BaseTransform:
6 |
7 | def __init__(self, p):
8 | """
9 | p: Probability of applying transform.
10 | """
11 |
12 | assert 0 <= p <= 1
13 | self.p = p
14 |
15 | def apply(self, signals):
16 | raise NotImplementedError
17 |
18 | def __call__(self, signals):
19 |
20 | if random.random() < self.p:
21 | signals = self.apply(signals)
22 |
23 | return signals
24 |
25 |
26 | class Compose:
27 |
28 | def __init__(self, transforms, p=1):
29 | """
30 | transforms: List of transforms to apply to signals.
31 | p: Probability of applying the list of transforms.
32 | """
33 |
34 | assert 0 <= p <= 1
35 |
36 | self.transforms = transforms
37 | self.p = p
38 |
39 | def __call__(self, signals):
40 |
41 | if random.random() < self.p:
42 | for transform in self.transforms:
43 | signals = transform(signals)
44 |
45 | return signals
46 |
47 |
48 | class FlipTime(BaseTransform):
49 | """Randomly flip signals along temporal dimension"""
50 |
51 | def __init__(self, p=0.5):
52 | """
53 | p: Probability of applying transform.
54 | """
55 | super().__init__(p)
56 |
57 |
58 | def apply(self, signals):
59 |
60 | if len(signals.shape) > 1:
61 |             signals = np.fliplr(signals).copy()  # copy: torch.from_numpy rejects negative-stride views
62 |
63 | else:
64 |             signals = np.flipud(signals).copy()  # copy: torch.from_numpy rejects negative-stride views
65 |
66 | return signals
67 |
68 |
69 | class MaskTime(BaseTransform):
70 | """Randomly mask signal"""
71 |
72 | def __init__(self, min_fraction=0.0, max_fraction=0.5, p=0.5):
73 | """
74 | min_fraction: Minimum length of the mask as a fraction of the total time series length.
75 | max_fraction: Maximum length of the mask as a fraction of the total time series length.
76 | p: Probability of applying transform.
77 | """
78 |
79 | super().__init__(p)
80 |
81 | assert 0 <= min_fraction <= 1
82 | assert 0 <= max_fraction <= 1
83 | assert max_fraction >= min_fraction
84 |
85 | self.min_fraction = min_fraction
86 | self.max_fraction = max_fraction
87 |
88 |
89 | def apply(self, signals):
90 |
91 | num_samples = signals.shape[-1]
92 | length = random.randint(int(num_samples * self.min_fraction), int(num_samples * self.max_fraction))
93 | start = random.randint(0, num_samples - length)
94 |
95 | mask = np.zeros(length)
96 | masked_signals = signals.copy()
97 | masked_signals[..., start : start + length] *= mask
98 |
99 | return masked_signals
100 |
101 |
102 | class Shift(BaseTransform):
103 | """Shift the signals forwards or backwards along the temporal dimension"""
104 |
105 | def __init__(self, min_fraction=-0.5, max_fraction=0.5, rollover=True, p=0.5):
106 | """
107 |         min_fraction: Minimum fraction of the total time series length to shift by (negative values shift backwards).
108 |         max_fraction: Maximum fraction of the total time series length to shift by.
109 |         rollover: If True, samples that roll beyond the first or last position are re-introduced at the other end; otherwise they are set to zero.
110 | p: Probability of applying this transform.
111 | """
112 |
113 | super().__init__(p)
114 |
115 | assert min_fraction >= -1
116 | assert max_fraction <= 1
117 |
118 | self.min_fraction = min_fraction
119 | self.max_fraction = max_fraction
120 | self.rollover = rollover
121 |
122 | def apply(self, signals):
123 |
124 | num_samples = signals.shape[-1]
125 | num_shift = int(round(random.uniform(self.min_fraction, self.max_fraction) * num_samples))
126 | signals = np.roll(signals, num_shift, axis=-1)
127 |
128 | if not self.rollover:
129 | if num_shift > 0:
130 | signals[..., :num_shift] = 0.0
131 |
132 | elif num_shift < 0:
133 | signals[..., num_shift:] = 0.0
134 |
135 | return signals
136 |
137 |
138 | class FlipPolarity(BaseTransform):
139 | """Randomly flip sign of signal"""
140 |
141 | def __init__(self, p=0.5):
142 | """
143 | p: Probability of applying transform.
144 | """
145 | super().__init__(p)
146 |
147 | def apply(self, signals):
148 | return -signals
149 |
150 |
151 | class GaussianNoise(BaseTransform):
152 |     """Add Gaussian noise to the signals"""
153 |
154 | def __init__(self, min_amplitude=0.001, max_amplitude=0.015, p=0.5):
155 | """
156 | min_amplitude: minimum amplitude of noise.
157 | max_amplitude: maximum amplitude of noise.
158 | p: Probability of applying this transform.
159 | """
160 | super().__init__(p)
161 |
162 | assert min_amplitude > 0.0
163 | assert max_amplitude > 0.0
164 | assert max_amplitude >= min_amplitude
165 |
166 | self.min_amplitude = min_amplitude
167 | self.max_amplitude = max_amplitude
168 |
169 | def apply(self, signals):
170 |
171 | amplitude = random.uniform(self.min_amplitude, self.max_amplitude)
172 |
173 | noise = np.random.randn(*signals.shape).astype(np.float32)
174 | signals = signals + amplitude * noise
175 |
176 | return signals
--------------------------------------------------------------------------------
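
Compose applies its transforms in order, each one gating independently on its own probability p. A minimal sketch mirroring the augmentation pipeline built in src/datasets/eeg_epilepsy.py, on a dummy 1-D signal:

    import numpy as np

    from src.transforms import Compose, FlipTime, Shift, FlipPolarity, GaussianNoise

    augment = Compose([FlipTime(p=0.5), FlipPolarity(p=0.5), Shift(p=0.5),
                       GaussianNoise(min_amplitude=0.01, max_amplitude=1.0, p=0.5)])

    signal = np.random.randn(178).astype(np.float32)  # one EEG epoch
    augmented = augment(signal)                       # same shape, randomly reversed/flipped/shifted/noised
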
/src/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torchmetrics.functional as tm
4 | from torch.distributions import Categorical
5 |
6 |
7 | def count_parameters(model):
8 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
9 |
10 |
11 | def F1(logits, labels, ensemble_weights, average="weighted"):
12 |
13 |
14 | _, num_exits, num_classes = logits.shape
15 | scale = ensemble_weights.sum()
16 |
17 | pred_labels = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale).argmax(-1)
18 |
19 | f1 = tm.f1(pred_labels, labels, num_classes=num_classes, average=average)
20 |
21 | return f1
22 |
23 |
24 | def precision(logits, labels, ensemble_weights, average="weighted"):
25 |
26 | _, num_exits, num_classes = logits.shape
27 | scale = ensemble_weights.sum()
28 |
29 | pred_labels = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale).argmax(-1)
30 |
31 | pr = tm.precision(pred_labels, labels, num_classes=num_classes, average=average)
32 |
33 | return pr
34 |
35 |
36 | def recall(logits, labels, ensemble_weights, average="weighted"):
37 |
38 | _, num_exits, num_classes = logits.shape
39 | scale = ensemble_weights.sum()
40 |
41 | pred_labels = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale).argmax(-1)
42 |
43 | rc = tm.recall(pred_labels, labels, num_classes=num_classes, average=average)
44 |
45 | return rc
46 |
47 |
48 | def negative_loglikelihood(logits, labels, ensemble_weights, reduction="mean"):
49 |
50 | _, num_exits, num_classes = logits.shape
51 | scale = ensemble_weights.sum()
52 |
53 | probs = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale)
54 |
55 | nll = -Categorical(probs=probs).log_prob(labels)
56 |
57 | if reduction == "mean":
58 | nll = nll.mean()
59 |
60 | return nll
61 |
62 |
63 | def brier_score(logits, labels, ensemble_weights, reduction="mean"):
64 |
65 | _, num_exits, num_classes = logits.shape
66 | scale = ensemble_weights.sum()
67 |
68 | probs = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale)
69 |
70 | labels_one_hot = F.one_hot(labels, num_classes=num_classes)
71 |
72 | bs = ((probs - labels_one_hot)**2).sum(dim=-1)
73 |
74 | if reduction == "mean":
75 | bs = bs.mean()
76 |
77 | return bs
78 |
79 |
80 | def predictive_entropy(logits, labels, ensemble_weights, reduction="mean"):
81 |
82 | _, num_exits, num_classes = logits.shape
83 | scale = ensemble_weights.sum()
84 |
85 | probs = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale)
86 |
87 | et = Categorical(probs=probs).entropy()
88 |
89 | if reduction == "mean":
90 | et = et.mean()
91 |
92 | return et
93 |
94 |
95 | def predictive_confidence(logits, labels, ensemble_weights, reduction="mean"):
96 |
97 | num_samples, _, _ = logits.shape
98 | scale = ensemble_weights.sum()
99 |
100 | probs = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale)
101 |
102 | pc = probs[torch.arange(num_samples), labels]
103 |
104 | if reduction == "mean":
105 | pc = pc.mean()
106 |
107 | return pc
108 |
109 |
110 | def expected_calibration_error(logits, labels, ensemble_weights, n_bins=15):
111 |
112 | num_samples, num_exits, num_classes = logits.shape
113 | scale = ensemble_weights.sum()
114 |
115 | pred_probs = logits.softmax(dim=-1).mul(ensemble_weights).sum(dim=-2).div(scale)
116 | pred_labels = pred_probs.argmax(-1)
117 |
118 | pred_probs = pred_probs[torch.arange(num_samples), pred_labels]
119 |
120 | correct = pred_labels.eq(labels)
121 |
122 | bin_boundaries = torch.linspace(0, 1, n_bins + 1)
123 |
124 | conf_bin = torch.zeros_like(bin_boundaries)
125 | acc_bin = torch.zeros_like(bin_boundaries)
126 | prop_bin = torch.zeros_like(bin_boundaries)
127 |
128 | for i, (bin_lower, bin_upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])):
129 |
130 | in_bin = pred_probs.gt(bin_lower.item()) * pred_probs.le(bin_upper.item())
131 | prop_in_bin = in_bin.float().mean()
132 |
133 | if prop_in_bin.item() > 0:
134 | # probability of making a correct prediction given a probability bin
135 | acc_bin[i] = correct[in_bin].float().mean()
136 |             # average predicted probability given a probability bin.
137 | conf_bin[i] = pred_probs[in_bin].mean()
138 | # probability of observing a probability bin
139 | prop_bin[i] = prop_in_bin
140 |
141 | ece = ((acc_bin - conf_bin).abs() * prop_bin).sum()
142 |
143 | return ece
144 |
145 |
146 | def calculate_metrics(model, logits, labels, ensemble_weights):
147 |
148 | metrics = dict(f1=F1(logits, labels, ensemble_weights, average="weighted").numpy(),
149 | precision=precision(logits, labels, ensemble_weights, average="weighted").numpy(),
150 | recall=recall(logits, labels, ensemble_weights, average="weighted").numpy(),
151 | negative_loglikelihood=negative_loglikelihood(logits, labels, ensemble_weights, reduction="mean").numpy(),
152 | brier_score=brier_score(logits, labels, ensemble_weights, reduction="mean").numpy(),
153 | predictive_entropy=predictive_entropy(logits, labels, ensemble_weights, reduction="mean").numpy(),
154 | expected_calibration_error=expected_calibration_error(logits, labels, ensemble_weights, n_bins=15).numpy(),
155 | params=count_parameters(model))
156 |
157 | return metrics
--------------------------------------------------------------------------------
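
src/run.py imports a total_correct helper from this module that is not included in the listing above. A minimal sketch of what it presumably computes; the exit-averaging behaviour and the per-batch fraction it returns (which makes run.py's `100 * correct / len(dataloader)` accumulation a percentage) are assumptions, not the authors' code:

    import torch


    def total_correct(logits, labels):
        """Assumed helper: fraction of correct argmax predictions in a batch.

        Early-exit logits of shape (batch, num_exits, classes) are averaged over
        the exit dimension before taking the argmax.
        """
        if logits.dim() == 3:
            probs = logits.softmax(dim=-1).mean(dim=1)
        else:
            probs = logits.softmax(dim=-1)
        return probs.argmax(dim=-1).eq(labels).float().mean()


    logits = torch.randn(4, 9, 5)                     # (batch, exits + final head, classes)
    labels = torch.randint(0, 5, (4,))
    acc_fraction = total_correct(logits, labels)      # scalar tensor in [0, 1]
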
/src/models/res_net_18.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def get_res_net_18(ensemble, **kwargs):
8 |
9 | if ensemble is None:
10 | return ResNet18(**kwargs)
11 |
12 | elif ensemble == "early_exit":
13 | return ResNet18EarlyExit(**kwargs)
14 |
15 | elif ensemble == "mc_dropout":
16 | return ResNet18MCDrop(**kwargs)
17 |
18 | elif ensemble == "deep":
19 | return ResNet18(**kwargs)
20 |
21 | elif ensemble == "depth":
22 | return ResNet18Depth(**kwargs)
23 |
24 | else:
25 |         raise NotImplementedError("ensemble not implemented: '{}'".format(ensemble))
26 |
27 |
28 | def init_weights(model):
29 |
30 | for module in model.modules():
31 |
32 | if isinstance(module, nn.Conv1d):
33 | nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
34 |
35 | elif isinstance(module, nn.BatchNorm1d):
36 | nn.init.constant_(module.weight, 1)
37 | nn.init.constant_(module.bias, 0)
38 |
39 |
40 | def conv3x3(in_planes, out_planes, stride=1):
41 | return nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
42 |
43 |
44 | def _conv1x1(in_planes, out_planes, stride=1):
45 | return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
46 |
47 |
48 | class BasicBlock(nn.Module):
49 |
50 | def __init__(self, inplanes, planes, stride=1, downsample=None):
51 | super().__init__()
52 |
53 | self.conv1 = conv3x3(inplanes, planes, stride)
54 | self.bn1 = nn.BatchNorm1d(planes)
55 | self.relu = nn.ReLU(inplace=True)
56 | self.conv2 = conv3x3(planes, planes)
57 | self.bn2 = nn.BatchNorm1d(planes)
58 | self.downsample = downsample
59 | self.stride = stride
60 |
61 | def forward(self, x):
62 |
63 | identity = x
64 | out = self.conv1(x)
65 | out = self.bn1(out)
66 | out = self.relu(out)
67 | out = self.conv2(out)
68 | out = self.bn2(out)
69 |
70 | if self.downsample is not None:
71 | identity = self.downsample(x)
72 |
73 | out += identity
74 | out = self.relu(out)
75 |
76 | return out
77 |
78 |
79 | class ResNet18(nn.Module):
80 |
81 | name = "res_net_18"
82 |
83 | def __init__(self, out_channels, seed=None):
84 | super().__init__()
85 |
86 | self.out_channels = out_channels
87 | self.seed = seed
88 |
89 | self.hidden_sizes = [64, 128, 256, 512]
90 | self.layers = [2, 2, 2, 2]
91 | self.strides = [1, 2, 2, 2]
92 | self.inplanes = self.hidden_sizes[0]
93 |
94 | in_block = [nn.Conv1d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)]
95 | in_block += [nn.BatchNorm1d(self.inplanes)]
96 | in_block += [nn.ReLU(inplace=True)]
97 | in_block += [nn.MaxPool1d(kernel_size=3, stride=2, padding=1)]
98 | self.in_block = nn.Sequential(*in_block)
99 |
100 | blocks = []
101 | for h, l, s in zip(self.hidden_sizes, self.layers, self.strides):
102 | blocks += [self._make_layer(h, l, s)]
103 | self.blocks = nn.Sequential(*blocks)
104 |
105 | out_block = [nn.AdaptiveAvgPool1d(1)]
106 | out_block += [nn.Flatten(1)]
107 | out_block += [nn.Linear(self.hidden_sizes[-1], self.out_channels)]
108 | self.out_block = nn.Sequential(*out_block)
109 |
110 | if self.seed is not None:
111 | torch.manual_seed(seed)
112 |
113 | self.apply(init_weights)
114 |
115 | def _make_layer(self, planes, blocks, stride=1):
116 |
117 | downsample = None
118 |
119 | if stride != 1 or self.inplanes != planes:
120 | downsample = nn.Sequential(_conv1x1(self.inplanes, planes, stride), nn.BatchNorm1d(planes))
121 |
122 | layers = [BasicBlock(self.inplanes, planes, stride, downsample)]
123 | self.inplanes = planes
124 |
125 | for _ in range(1, blocks):
126 | layers += [BasicBlock(self.inplanes, planes)]
127 |
128 | return nn.Sequential(*layers)
129 |
130 | def forward(self, x):
131 |
132 | x = self.in_block(x)
133 | x = self.blocks(x)
134 | x = self.out_block(x)
135 |
136 | return x
137 |
138 |
139 | class ExitBlock(nn.Module):
140 |
141 | def __init__(self, in_channels, hidden_sizes, out_channels):
142 | super().__init__()
143 |
144 | layers = [nn.AdaptiveAvgPool1d(1)]
145 | layers += [nn.Flatten(1)]
146 | layers += [nn.Linear(in_channels, hidden_sizes)]
147 | layers += [nn.ReLU()]
148 | layers += [nn.Linear(hidden_sizes, out_channels)]
149 | self.layers = nn.Sequential(*layers)
150 |
151 | def forward(self, x):
152 |
153 | return self.layers(x)
154 |
155 |
156 | class ResNet18EarlyExit(ResNet18):
157 |
158 | name = "res_net_18_early_exit"
159 |
160 | def __init__(self, *args, exit_after=-1, complexity_factor=1.2, **kwargs):
161 | self.exit_after = exit_after
162 | self.complexity_factor = complexity_factor
163 |
164 | super().__init__(*args, **kwargs)
165 |
166 | to_exit = [2, 8, 15, 24, 31, 40, 47, 56]
168 |
169 | num_hidden = len(self.hidden_sizes)
170 | exit_hidden_sizes = [int(((self.complexity_factor ** 0.5) ** (num_hidden - idx)) * self.hidden_sizes[-1]) for idx in range(num_hidden)]
171 | exit_hidden_sizes = [h for pair in zip(exit_hidden_sizes, exit_hidden_sizes) for h in pair]
172 |
173 | if self.exit_after == -1:
174 | self.exit_after = range(len(to_exit))
175 |
176 | num_exits = len(to_exit)
177 |
178 | if (len(self.exit_after) > num_exits) or not set(self.exit_after).issubset(list(range(num_exits))):
179 | raise ValueError("valid exit points: {}".format(", ".join(str(n) for n in range(num_exits))))
180 |
181 | self.exit_hidden_sizes = np.array(exit_hidden_sizes)[self.exit_after]
182 |
183 | blocks = []
184 | for idx, module in enumerate(self.blocks.modules()):
185 | if idx in to_exit:
186 | blocks += [module]
187 | self.blocks = nn.ModuleList(blocks)
188 |
189 | idx = 0
190 | exit_blocks = []
191 | for block_idx, block in enumerate(self.blocks):
192 | if block_idx in self.exit_after:
193 | in_channels = block.conv1.out_channels
194 | exit_blocks += [ExitBlock(in_channels, self.exit_hidden_sizes[idx], self.out_channels)]
195 | idx += 1
196 | self.exit_blocks = nn.ModuleList(exit_blocks)
197 |
198 | self.apply(init_weights)
199 |
200 | def forward(self, x):
201 |
202 | out = self.in_block(x)
203 |
204 | out_blocks = []
205 | for block in self.blocks:
206 | out = block(out)
207 | out_blocks += [out]
208 |
209 | out_exits = []
210 | for exit_after, exit_block in zip(self.exit_after, self.exit_blocks):
211 | out = exit_block(out_blocks[exit_after])
212 | out_exits += [out]
213 |
214 | out = self.out_block(out_blocks[-1])
215 | out = torch.stack(out_exits + [out], dim=1)
216 |
217 | return out
218 |
219 |
220 | class MCDropout(nn.Dropout):
221 |
222 | def forward(self, x):
223 | return F.dropout(x, self.p, True, self.inplace)
224 |
225 |
226 | class ResNet18MCDrop(ResNet18EarlyExit):
227 |
228 | name = "res_net_18_mc_drop"
229 |
230 | def __init__(self, *args, drop_after=-1, drop_prob=0.2, **kwargs):
231 | self.drop_after = drop_after
232 | self.drop_prob = drop_prob
233 |
234 | super().__init__(*args, exit_after=drop_after, **kwargs)
235 |
236 | self.drop_after = self.exit_after
237 |
238 | self.__delattr__("exit_after")
239 | self.__delattr__("exit_blocks")
240 |
241 | for block_idx in self.drop_after:
242 | self.blocks[block_idx].add_module("dropout", MCDropout(self.drop_prob))
243 |
244 | def forward(self, x):
245 |
246 | x = self.in_block(x)
247 | x = self.blocks(x)
248 | x = self.out_block(x)
249 |
250 | return x
251 |
252 |
253 | class ResNet18Depth(ResNet18):
254 |
255 | name = "res_net_18_depth"
256 |
257 | def __init__(self, *args, max_depth=1, **kwargs):
258 | self.max_depth = max_depth
259 |
260 | super().__init__(*args, **kwargs)
261 |
262 | num_blocks = len(self.hidden_sizes)
263 |
264 | if self.max_depth == -1:
265 | self.max_depth = len(self.hidden_sizes)
266 |
267 | elif (max_depth > num_blocks) or (max_depth < 1):
268 | raise ValueError("valid depths: {}".format(", ".join(str(n) for n in range(1, num_blocks + 1))))
269 |
270 | self.blocks = self.blocks[:self.max_depth]
271 |
272 | out_block = [nn.AdaptiveAvgPool1d(1)]
273 | out_block += [nn.Flatten(1)]
274 | out_block += [nn.Linear(self.hidden_sizes[self.max_depth - 1], self.out_channels)]
275 | self.out_block = nn.Sequential(*out_block)
--------------------------------------------------------------------------------