├── LoadingData.py ├── README.md ├── SelectiveNet.py ├── checkpoints └── README.md ├── test.py └── train.py /LoadingData.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Thu Aug 8, 2019 3 | @author: HyunsuKim6(Github), hyunsukim@kaist.ac.kr 4 | """ 5 | 6 | import torch 7 | import torchvision 8 | from torchvision import transforms 9 | 10 | 11 | def load_data(purpose, batch_size=10, num_workers=12): 12 | transform = transforms.Compose([transforms.ToTensor(), 13 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 14 | if purpose == 'train': 15 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 16 | download=True, transform=transform) 17 | dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, 18 | shuffle=True, num_workers=num_workers) 19 | 20 | elif purpose == 'test': 21 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, 22 | download=True, transform=transform) 23 | dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 24 | shuffle=False, num_workers=num_workers) 25 | 26 | else: 27 | print("Incorrect input: Please enter correct purpose input") 28 | return 29 | 30 | return dataloader 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SelectiveNet-Pytorch 2 | 3 | The author of paper has uploaded code written in Keras, but I thought some people are familiar with Pytorch, so I implemented it in Pytorch. 4 | 5 | ## Requirements 6 | 7 | You will need the following to run the above: 8 | - Pytorch 9 | - Python3, Numpy, Matplotlib, tqdm 10 | 11 | Note that I run the code with Windows 10, Pytorch 0.4.1, CUDA 10.1 12 | 13 | ### Training 14 | Use `train.py` to train the network. Example usage: 15 | ```bash 16 | # Example usage 17 | python train.py 18 | ``` 19 | 20 | ### Testing 21 | Use `test.py` to test the network. Example usage: 22 | ```bash 23 | # Example usage 24 | python test.py 25 | ``` 26 | 27 | ## References 28 | 29 | - [SelectiveNet: A Deep Neural Network with an Integrated Reject Option][1] 30 | - I referred to the SelectiveNet paper. 31 | 32 | - [geifmany/selectivenet][2] 33 | - There is author's repository. 34 | 35 | [1]: https://arxiv.org/abs/1901.09192 36 | [2]: https://github.com/geifmany/selectivenet 37 | -------------------------------------------------------------------------------- /SelectiveNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Thu Aug 8, 2019 3 | @author: HyunsuKim6(Github), hyunsukim@kaist.ac.kr 4 | """ 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class SelectiveNet(nn.Module): 10 | def __init__(self, features, num_classes=10, init_weights=True): 11 | super(SelectiveNet, self).__init__() 12 | self.features = features 13 | 14 | self.classifier = nn.Sequential( 15 | nn.Linear(512, 512), 16 | nn.ReLU(True), 17 | nn.Dropout(), 18 | nn.Linear(512, num_classes) 19 | ) 20 | 21 | self.aux_classifier = nn.Sequential( 22 | nn.Linear(512, 512), 23 | nn.ReLU(True), 24 | nn.Dropout(), 25 | nn.Linear(512, num_classes) 26 | ) 27 | 28 | self.selector = nn.Sequential( 29 | nn.Linear(512, 512), 30 | nn.ReLU(True), 31 | nn.BatchNorm1d(512), 32 | nn.Linear(512, 1), 33 | nn.Sigmoid() 34 | ) 35 | 36 | if init_weights: 37 | self._initialize_weights() 38 | 39 | def forward(self, x): 40 | x = self.features(x) 41 | x = x.view(x.size(0), -1) 42 | 43 | # classification head (f) 44 | 45 | prediction_output = self.classifier(x) 46 | 47 | # selection head (g) 48 | 49 | selection_output = self.selector(x) 50 | 51 | # auxiliary head (h) 52 | 53 | auxiliary_output = self.aux_classifier(x) 54 | 55 | return prediction_output, selection_output, auxiliary_output 56 | 57 | def _initialize_weights(self): 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 61 | if m.bias is not None: 62 | nn.init.constant_(m.bias, 0) 63 | elif isinstance(m, nn.BatchNorm2d): 64 | nn.init.constant_(m.weight, 1) 65 | nn.init.constant_(m.bias, 0) 66 | elif isinstance(m, nn.Linear): 67 | nn.init.normal_(m.weight, 0, 0.01) 68 | nn.init.constant_(m.bias, 0) 69 | 70 | 71 | def make_layers(cfg, batch_norm=False): 72 | layers = [] 73 | in_channels = 3 74 | for v in cfg: 75 | if v == 'M': 76 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 77 | else: 78 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 79 | if batch_norm: 80 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 81 | else: 82 | layers += [conv2d, nn.ReLU(inplace=True)] 83 | in_channels = v 84 | return nn.Sequential(*layers) 85 | 86 | 87 | cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] 88 | 89 | 90 | def _vgg(batch_norm, **kwargs): 91 | model = SelectiveNet(make_layers(cfg, batch_norm=batch_norm), **kwargs) 92 | return model 93 | 94 | 95 | def SelectiveNet_vgg16(**kwargs): 96 | return _vgg(False, **kwargs) 97 | 98 | 99 | def SelectiveNet_vgg16_bn(**kwargs): 100 | return _vgg(True, **kwargs) 101 | 102 | 103 | class OverAllLoss(nn.Module): 104 | def __init__(self, alpha=0.5, lambda=32, coverage=0.7): 105 | super(OverAllLoss, self).__init__() 106 | self.alpha = alpha 107 | self.lambda = lambda 108 | self.coverage = coverage 109 | 110 | def forward(self, prediction_input, selection_input, aux_input, labels): 111 | sel_log_prob = -1.0 * F.log_softmax(prediction_input, 1) * selection_input 112 | sel_risk = sel_log_prob.gather(1, labels.unsqueeze(1)) 113 | sel_risk = sel_risk.mean() 114 | 115 | aux_log_prob = -1.0 * F.log_softmax(aux_input, 1) 116 | aux_loss = aux_log_prob.gather(1, labels.unsqueeze(1)) 117 | aux_loss = aux_loss.mean() 118 | 119 | emp_coverage = selection_input.mean() 120 | 121 | return self.alpha * (sel_risk / emp_coverage + self.lambda * max(self.coverage - emp_coverage, 0) ** 2) + ( 122 | 1 - self.alpha) * aux_loss \ 123 | , sel_risk, emp_coverage 124 | 125 | 126 | -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | * Checkpoints will be saved here. 2 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Thu Aug 8, 2019 3 | @author: HyunsuKim6(Github), hyunsukim@kaist.ac.kr 4 | """ 5 | 6 | from __future__ import print_function, division 7 | 8 | from torch.optim import lr_scheduler 9 | 10 | import time 11 | from tqdm import tqdm 12 | from SelectiveNet import * 13 | from LoadingData import * 14 | 15 | 16 | def test(model, test_dataloader, batch_size, pkl_name): 17 | # device setting 18 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 19 | model.to(device) 20 | print(torch.cuda.get_device_name(0)) 21 | 22 | model.load_state_dict(torch.load(pkl_name)) 23 | model.eval() 24 | 25 | since = time.time() 26 | 27 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 28 | class_correct = list(0. for i in range(10)) 29 | class_total = list(0. for i in range(10)) 30 | total_classes_correct = 0. 31 | total_classes_total = 0. 32 | test_coverage = 0 33 | first_time_flag = 0 34 | 35 | with torch.no_grad(): 36 | for data in tqdm(test_dataloader): 37 | inputs, labels = data 38 | 39 | inputs = inputs.to(device) 40 | labels = labels.to(device) 41 | 42 | pred_outputs, sel_outputs, aux_outputs = model(inputs) 43 | _, predicted = pred_outputs.max(1) 44 | 45 | for i in range(batch_size): 46 | if sel_outputs[i].item() >= 0.5: 47 | if labels[i].item() == predicted[i].item(): 48 | class_correct[labels[i].item()] += 1 49 | class_total[labels[i].item()] += 1 50 | 51 | if first_time_flag == 0: 52 | total_sel_outputs = sel_outputs 53 | first_time_flag += 1 54 | 55 | else: 56 | total_sel_outputs = torch.cat((total_sel_outputs, sel_outputs), 0) 57 | 58 | for i in range(10): 59 | print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i])) 60 | total_classes_correct += class_correct[i] 61 | total_classes_total += class_total[i] 62 | 63 | print('Total Accuracy : %2d %%' % (100 * total_classes_correct / total_classes_total)) 64 | test_coverage = total_sel_outputs.mean() 65 | print('Test Coverage : %2d %%' % (100 * test_coverage)) 66 | 67 | time_elapsed = time.time() - since 68 | 69 | print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 70 | 71 | return 72 | 73 | 74 | if __name__ == "__main__": 75 | num_worker = 12 76 | pkl_name = 'checkpoints/SelectiveNet.pkl' # for testing 77 | 78 | model = SelectiveNet_vgg16_bn() 79 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 80 | criterion = OverAllLoss() 81 | scheduler = lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.5) 82 | 83 | # testing 84 | batch_size = 100 85 | test_dataloader = load_data('test', batch_size, num_worker) 86 | test(model, test_dataloader, batch_size, pkl_name) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Thu Aug 8, 2019 3 | @author: HyunsuKim6(Github), hyunsukim@kaist.ac.kr 4 | """ 5 | 6 | from __future__ import print_function, division 7 | 8 | from torch.optim import lr_scheduler 9 | import time 10 | import copy 11 | from SelectiveNet import * 12 | from LoadingData import * 13 | 14 | 15 | def train(model, dataloader, criterion, optimizer, scheduler, pkl_name, num_epochs=30): 16 | # device setting 17 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 18 | model.to(device) 19 | print(torch.cuda.get_device_name(0)) 20 | 21 | since = time.time() 22 | 23 | best_model = copy.deepcopy(model.state_dict()) 24 | best_acc = 0.0 25 | 26 | for epoch in range(num_epochs): 27 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 28 | print('-' * 10) 29 | 30 | scheduler.step() 31 | model.train() # Set model to training mode 32 | 33 | running_loss = 0.0 34 | running_corrects = 0 35 | sel_risk = 0 36 | curr_coverage = 0 37 | 38 | # Iterate over data. 39 | for index, data in enumerate(dataloader): 40 | inputs, labels = data 41 | 42 | inputs = inputs.to(device) 43 | labels = labels.to(device) 44 | 45 | # zero the parameter gradients 46 | optimizer.zero_grad() 47 | 48 | pred_outputs, sel_outputs, aux_outputs = model(inputs) 49 | _, preds = torch.max(pred_outputs, 1) 50 | loss, sel_risk, curr_coverage = criterion(pred_outputs, sel_outputs, aux_outputs, labels) 51 | 52 | loss.backward() 53 | optimizer.step() 54 | 55 | # statistics 56 | running_loss += loss.item() * inputs.size(0) 57 | running_corrects += torch.sum(preds == labels.data) 58 | 59 | if index % 100 == 0: 60 | print('Epoch {} [{}/{} ({:.0f}%)]: Loss: {:.6f} Selective risk: {:.6f} Current coverage: {:.6f}' 61 | .format(epoch, index * len(inputs), len(dataloader.dataset), 62 | 100. * index / len(dataloader), loss, sel_risk, curr_coverage)) 63 | 64 | epoch_loss = running_loss / len(dataloader.dataset) 65 | epoch_acc = running_corrects.double() / len(dataloader.dataset) 66 | 67 | print('Epoch {} result: Loss: {:.4f} Acc: {:.4f}'.format(epoch, epoch_loss, epoch_acc)) 68 | 69 | if epoch % 50 == 0: 70 | torch.save(model.state_dict(), pkl_name) 71 | 72 | print() 73 | 74 | time_elapsed = time.time() - since 75 | 76 | print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 77 | 78 | print("A model was saved") 79 | print("Training is done") 80 | 81 | return model 82 | 83 | 84 | if __name__ == "__main__": 85 | num_worker = 12 86 | pkl_name = 'checkpoints/SelectiveNet.pkl' 87 | 88 | model = SelectiveNet_vgg16_bn() 89 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 90 | criterion = OverAllLoss() 91 | scheduler = lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.5) 92 | 93 | # training 94 | batch_size = 128 95 | train_dataloader = load_data('train', batch_size, num_worker) 96 | best_model = train(model, train_dataloader, criterion, optimizer, scheduler, pkl_name, num_epochs=300) --------------------------------------------------------------------------------