├── .gitignore
├── LICENSE
├── README.md
├── attacks
│   ├── __init__.py
│   ├── attack_BSS.py
│   └── helpers.py
├── models
│   ├── __init__.py
│   └── resnet.py
├── results
│   └── Res26_C10
│       └── 320_epoch.t7
└── train_BSS_distillation.py

/.gitignore:
--------------------------------------------------------------------------------
data/
checkpoint/
.idea/
results/
models/__pycache__
*__pycache__
paper/
module/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Byeongho Heo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Knowledge Distillation with Adversarial Samples Supporting Decision Boundary

Official PyTorch implementation of the paper:

[Knowledge Distillation with Adversarial Samples Supporting Decision Boundary](https://arxiv.org/abs/1805.05532) (AAAI 2019).

Spotlight and poster are available on the [homepage](https://sites.google.com/view/byeongho-heo/home).

## Environment
Python 3.6, PyTorch 0.4.1, torchvision


## Knowledge distillation [(CIFAR-10)](https://www.cs.toronto.edu/~kriz/cifar.html)

```shell
python train_BSS_distillation.py
```

Distillation from ResNet 26 (teacher) to ResNet 10 (student) on the CIFAR-10 dataset.

Pre-trained teacher network (ResNet 26) is included.
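
A minimal sketch (not part of the original scripts) of how the bundled teacher checkpoint can be loaded on its own, mirroring what `train_BSS_distillation.py` does; it assumes the snippet is run from the repository root:

```python
import torch
from models import ResNet26, BN_version_fix

# Load the checkpoint shipped under results/Res26_C10 (mapped to CPU for portability).
checkpoint = torch.load('./results/Res26_C10/320_epoch.t7',
                        map_location=lambda storage, loc: storage)

# The checkpoint stores the whole teacher module under 'net'; BN_version_fix adds the
# num_batches_tracked buffers that BatchNorm layers expect in PyTorch >= 0.4.1.
teacher = BN_version_fix(checkpoint['net'])

t_net = ResNet26()
t_net.load_state_dict(teacher.state_dict())
t_net.eval()
```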

## Citation

```
@inproceedings{BSSdistill,
  title = {Knowledge Distillation with Adversarial Samples Supporting Decision Boundary},
  author = {Byeongho Heo and Minsik Lee and Sangdoo Yun and Jin Young Choi},
  booktitle = {AAAI Conference on Artificial Intelligence (AAAI)},
  year = {2019}
}
```
--------------------------------------------------------------------------------
/attacks/__init__.py:
--------------------------------------------------------------------------------
from .attack_BSS import AttackBSS
--------------------------------------------------------------------------------
/attacks/attack_BSS.py:
--------------------------------------------------------------------------------
import torch
from torch import autograd
from torch.autograd.gradcheck import zero_gradients
import torch.nn.functional as F
from .helpers import *


class AttackBSS:

    def __init__(
            self,
            targeted=True, max_epsilon=16, norm=float('inf'),
            step_alpha=None, num_steps=None, cuda=True, debug=False):

        self.targeted = targeted
        self.eps = 5.0 * max_epsilon / 255.0
        self.num_steps = num_steps or 10
        self.norm = norm
        if not step_alpha:
            if norm == float('inf'):
                self.step_alpha = self.eps / self.num_steps
            else:
                if norm == 1:
                    self.step_alpha = 500.0
                else:
                    self.step_alpha = 1.0
        else:
            self.step_alpha = step_alpha
        self.loss_fn = torch.nn.CrossEntropyLoss(size_average=False)
        if cuda:
            self.loss_fn = self.loss_fn.cuda()
        self.debug = debug

    def run(self, model, input, target, batch_idx=0):
        input_var = autograd.Variable(input, requires_grad=True)
        target_var = autograd.Variable(target)
        GT_var = autograd.Variable(target)
        eps = self.eps

        step = 0
        while step < self.num_steps:
            zero_gradients(input_var)
            output = model(input_var)

            if not step:
                # on the first step, take the model's current prediction as the base (GT) class
                GT_var.data = output.data.max(1)[1]

            score = output

            score_GT = score.gather(1, GT_var.unsqueeze(1))
            score_target = score.gather(1, target_var.unsqueeze(1))

            # push the target-class score up relative to the base-class score
            loss = (score_target - score_GT).sum()
            loss.backward()

            # only keep perturbing samples that are still classified as the base class
            step_alpha = self.step_alpha * (GT_var.data == output.data.max(1)[1]).float()
            step_alpha = step_alpha.unsqueeze(1).unsqueeze(1).unsqueeze(1)

            if step_alpha.sum() == 0:
                break

            pert = (score_GT.data - score_target.data).unsqueeze(1).unsqueeze(1)
            normed_grad = step_alpha * (pert + 1e-4) * input_var.grad.data / l2_norm(input_var.grad.data)

            # perturb current input image by normalized and scaled gradient
            overshoot = 0.0
            step_adv = input_var.data + (1 + overshoot) * normed_grad

            total_adv = step_adv - input

            # apply total adversarial perturbation to original image and clip to valid pixel range
            input_adv = input + total_adv
            input_adv = torch.clamp(input_adv, -2.5, 2.5)
            input_var.data = input_adv
            step += 1

        return input_adv
--------------------------------------------------------------------------------
/attacks/helpers.py:
--------------------------------------------------------------------------------
import torch
import operator as op
import functools as ft


'''reduce_* helper functions reduce tensors on all dimensions but the first.
They are intended to be used on batched tensors where dim 0 is the batch dim.
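
A short usage example (added for illustration):

    >>> import torch
    >>> x = torch.ones(4, 3, 2)
    >>> reduce_sum(x).shape       # batch dim kept; reduced dims collapse to size 1
    torch.Size([4, 1, 1])
    >>> reduce_sum(x, keepdim=False).shape
    torch.Size([4])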
'''


def reduce_sum(x, keepdim=True):
    # silly PyTorch, when will you get proper reducing sums/means?
    for a in reversed(range(1, x.dim())):
        x = x.sum(a, keepdim=keepdim)
    return x


def reduce_mean(x, keepdim=True):
    numel = ft.reduce(op.mul, x.size()[1:])
    x = reduce_sum(x, keepdim=keepdim)
    return x / numel


def reduce_min(x, keepdim=True):
    for a in reversed(range(1, x.dim())):
        x = x.min(a, keepdim=keepdim)[0]
    return x


def reduce_max(x, keepdim=True):
    for a in reversed(range(1, x.dim())):
        x = x.max(a, keepdim=keepdim)[0]
    return x


def torch_arctanh(x, eps=1e-6):
    x = x * (1. - eps)  # avoid modifying the caller's tensor in place
    return (torch.log((1 + x) / (1 - x))) * 0.5


def l2r_dist(x, y, keepdim=True, eps=1e-8):
    d = (x - y)**2
    d = reduce_sum(d, keepdim=keepdim)
    d += eps  # to prevent infinite gradient at 0
    return d.sqrt()


def l2_dist(x, y, keepdim=True):
    d = (x - y)**2
    return reduce_sum(d, keepdim=keepdim)


def l1_dist(x, y, keepdim=True):
    d = torch.abs(x - y)
    return reduce_sum(d, keepdim=keepdim)


def l2_norm(x, keepdim=True):
    norm = reduce_sum(x*x, keepdim=keepdim)
    return norm.sqrt()


def l1_norm(x, keepdim=True):
    return reduce_sum(x.abs(), keepdim=keepdim)


def rescale(x, x_min=-1., x_max=1.):
    return x * (x_max - x_min) + x_min


def tanh_rescale(x, x_min=-1., x_max=1.):
    return (torch.tanh(x) + 1) * 0.5 * (x_max - x_min) + x_min
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .resnet import *
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
'''ResNet in PyTorch.

[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385

'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from torch.autograd import Variable

class ZeroPadBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(ZeroPadBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Parameter-free shortcut: spatial down-sampling via average pooling; the missing
        # channels are zero-padded in forward().
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.AvgPool2d(kernel_size=1, stride=stride)
            )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # zero-pad the shortcut along the channel dimension so it matches out, then add
        out += F.pad(self.shortcut(x), (0, 0, 0, 0, 0, out.size()[1] - x.size()[1]), 'constant', 0)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        multiplier = 1
        self.in_planes = multiplier*16

        self.conv1 = nn.Conv2d(3, multiplier*16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(multiplier*16)
        self.layer1 = self._make_layer(block, multiplier*16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, multiplier*32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, multiplier*64, num_blocks[2], stride=2)
        self.linear = nn.Linear(multiplier*64*block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def BN_version_fix(net):
    # Add the num_batches_tracked buffer so that BatchNorm modules saved with
    # PyTorch <= 0.4.0 can be used under 0.4.1+.
    for m in net.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.track_running_stats = True
            m.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

    return net

def ResNet8():
    return ResNet(ZeroPadBlock, [1,1,1])

def ResNet14():
    return ResNet(ZeroPadBlock, [2,2,2])

def ResNet20():
    return ResNet(ZeroPadBlock, [3,3,3])

def ResNet26():
    return ResNet(ZeroPadBlock, [4,4,4])
--------------------------------------------------------------------------------
/results/Res26_C10/320_epoch.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bhheo/BSS_distillation/586b1411ed067af92ab781ac4fc58f8cf58798dc/results/Res26_C10/320_epoch.t7
--------------------------------------------------------------------------------
/train_BSS_distillation.py:
--------------------------------------------------------------------------------
'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function

import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import attacks

from torch.autograd import Variable
from models import *

# Parameters
dataset_name = 'CIFAR-10'
res_folder = 'results/BSS_distillation_80epoch_res8_C10'
temperature = 3
gpu_num = 0
attack_size = 64
max_epoch = 80

if not os.path.isdir(res_folder):
    os.mkdir(res_folder)

use_cuda = torch.cuda.is_available()

# Dataset
if dataset_name == 'CIFAR-10':
    # CIFAR-10
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=4)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=4)
else:
    raise Exception('Undefined Dataset')

# Teacher network
teacher = BN_version_fix(torch.load('./results/Res26_C10/320_epoch.t7', map_location=lambda storage, loc: storage.cuda(0))['net'])
t_net = ResNet26()
t_net.load_state_dict(teacher.state_dict())

# Student network
s_net = ResNet8()


if use_cuda:
    torch.cuda.set_device(gpu_num)
    t_net.cuda()
    s_net.cuda()
    cudnn.benchmark = True

# Proposed adversarial attack algorithm (BSS)
attack = attacks.AttackBSS(targeted=True, num_steps=10, max_epsilon=16, step_alpha=0.3, cuda=True, norm=2)

criterion_MSE = nn.MSELoss(size_average=False)
criterion_CE = nn.CrossEntropyLoss()

# Training
def train_attack_KD(t_net, s_net, ratio, ratio_attack, epoch):
    epoch_start_time = time.time()
    print('\nStage 1 Epoch: %d' % epoch)
    s_net.train()
    t_net.eval()
    train_loss = 0
    correct = 0
    total = 0
    global optimizer
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        batch_size1 = inputs.shape[0]

        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)

        out_s = s_net(inputs)

        # Cross-entropy loss
        loss = criterion_CE(out_s[0:batch_size1, :], targets)
        out_t = t_net(inputs)

        # KD loss
        loss += - ratio * (F.softmax(out_t/temperature, 1).detach() * F.log_softmax(out_s/temperature, 1)).sum() / batch_size1

        if ratio_attack > 0:

            # attack only samples that both teacher and student currently classify correctly
            condition1 = targets.data == out_t.sort(dim=1, descending=True)[1][:, 0].data
            condition2 = targets.data == out_s.sort(dim=1, descending=True)[1][:, 0].data

            attack_flag = condition1 & condition2

            if attack_flag.sum():
                # Base sample selection
                attack_idx = attack_flag.nonzero().squeeze()
                if attack_idx.shape[0] > attack_size:
                    diff = (F.softmax(out_t[attack_idx,:], 1).data - F.softmax(out_s[attack_idx,:], 1).data) ** 2
                    distill_score = diff.sum(dim=1) - diff.gather(1, targets[attack_idx].data.unsqueeze(1)).squeeze()
                    attack_idx = attack_idx[distill_score.sort(descending=True)[1][:attack_size]]

                # Target class sampling
                attack_class = out_t.sort(dim=1, descending=True)[1][:, 1][attack_idx].data
                class_score, class_idx = F.softmax(out_t, 1)[attack_idx, :].data.sort(dim=1, descending=True)
                class_score = class_score[:, 1:]
                class_idx = class_idx[:, 1:]

                # sample the attack target in proportion to the teacher's scores over the non-top-1 classes
                rand_seed = 1 * (class_score.sum(dim=1) * torch.rand([attack_idx.shape[0]]).cuda()).unsqueeze(1)
                prob = class_score.cumsum(dim=1)
                for k in range(attack_idx.shape[0]):
                    for c in range(prob.shape[1]):
                        if (prob[k, c] >= rand_seed[k]).cpu().numpy():
                            attack_class[k] = class_idx[k, c]
                            break

                # Forward and backward for adversarial samples
                attacked_inputs = Variable(attack.run(t_net, inputs[attack_idx, :, :, :].data, attack_class))
                batch_size2 = attacked_inputs.shape[0]

                attack_out_t = t_net(attacked_inputs)
                attack_out_s = s_net(attacked_inputs)

                # KD loss for Boundary Supporting Samples (BSS)
                loss += - ratio_attack * (F.softmax(attack_out_t / temperature, 1).detach() * F.log_softmax(attack_out_s / temperature, 1)).sum() / batch_size2

        loss.backward()
        optimizer.step()

        train_loss += loss.data.item()
        _, predicted = torch.max(out_s[0:batch_size1, :].data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().float().sum()
        b_idx = batch_idx

    print('Train \t Time Taken: %.2f sec' % (time.time() - epoch_start_time))
    print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss / (b_idx + 1), 100. * correct / total, correct, total))

def test(net, epoch, save=False):
    epoch_start_time = time.time()
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        loss = criterion_CE(outputs, targets)

        test_loss += loss.data.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().float().sum()
        b_idx = batch_idx

    print('Test \t Time Taken: %.2f sec' % (time.time() - epoch_start_time))
    print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/(b_idx+1), 100.*correct/total, correct, total))

    if save:
        # Save checkpoint.
        acc = 100.*correct/total
        if epoch != 0 and epoch % 80 == 0:
            print('Saving..')
            state = {
                'net': net if use_cuda else net,
                'acc': acc,
                'epoch': epoch,
            }
            torch.save(state, './' + res_folder + '/%d_epoch.t7' % epoch)

for epoch in range(1, max_epoch+1):
    # step-wise learning rate: 0.1, then 0.01 at half of training, then 0.001 at three quarters
    if epoch == 1:
        optimizer = optim.SGD(s_net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    elif epoch == max_epoch/2:
        optimizer = optim.SGD(s_net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    elif epoch == max_epoch/4*3:
        optimizer = optim.SGD(s_net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)

    # KD weight anneals from ~4 to 1; BSS weight anneals from ~2 to 0 by 3/4 of training
    ratio = max(3 * (1 - epoch / max_epoch), 0) + 1
    attack_ratio = max(2 * (1 - 4 / 3 * epoch / max_epoch), 0) + 0

    train_attack_KD(t_net, s_net, ratio, attack_ratio, epoch)

    test(s_net, epoch, save=True)

state = {
    'net': s_net,
    'epoch': max_epoch,
}
torch.save(state, './' + res_folder + '/%depoch_final.t7' % (max_epoch))
--------------------------------------------------------------------------------
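
A minimal follow-up sketch (an editor's addition, assuming the default `res_folder` and `max_epoch` values above) showing how the distilled student saved by the script could be loaded back for evaluation:

```python
import torch
from models import *  # make the ResNet classes importable for unpickling

# The script saves the whole student module under the 'net' key.
checkpoint = torch.load('./results/BSS_distillation_80epoch_res8_C10/80epoch_final.t7',
                        map_location=lambda storage, loc: storage)
student = checkpoint['net']
student.eval()
```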