├── utils
│   ├── __init__.py
│   └── util.py
├── models
│   ├── __init__.py
│   ├── selector.py
│   ├── wresnet.py
│   ├── resnet.py
│   └── lenet.py
├── trigger
│   ├── signal_cifar10_mask.npy
│   └── best_square_trigger_cifar10.npz
├── weight
│   ├── erasing_net
│   │   └── WRN-16-1.tar
│   ├── s_net
│   │   └── WRN-16-1-S-model_best.pth.tar
│   └── t_net
│       └── WRN-16-1-T-model_best.pth.tar
├── at.py
├── results
│   └── results.csv
├── config.py
├── train_badnet.py
├── README.md
├── main.py
└── data_loader.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Data:2020/7/3 20:58
3 | # @Author:lyg
--------------------------------------------------------------------------------
/trigger/signal_cifar10_mask.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/trigger/signal_cifar10_mask.npy
--------------------------------------------------------------------------------
/trigger/best_square_trigger_cifar10.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/trigger/best_square_trigger_cifar10.npz
--------------------------------------------------------------------------------
/weight/erasing_net/WRN-16-1.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/weight/erasing_net/WRN-16-1.tar
--------------------------------------------------------------------------------
/weight/s_net/WRN-16-1-S-model_best.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/weight/s_net/WRN-16-1-S-model_best.pth.tar
--------------------------------------------------------------------------------
/weight/t_net/WRN-16-1-T-model_best.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/weight/t_net/WRN-16-1-T-model_best.pth.tar
--------------------------------------------------------------------------------
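The `.pth.tar` checkpoints above are plain `torch.save` dictionaries whose `state_dict` and `epoch` keys are read by `models/selector.py`. A minimal sketch of loading the pretrained pair for evaluation, assuming the default paths from `config.py`:

```python
from models.selector import select_model

# Backdoored student and finetuned clean teacher (default paths from config.py).
student = select_model(dataset='CIFAR10', model_name='WRN-16-1', pretrained=True,
                       pretrained_models_path='./weight/s_net/WRN-16-1-S-model_best.pth.tar',
                       n_classes=10)
teacher = select_model(dataset='CIFAR10', model_name='WRN-16-1', pretrained=True,
                       pretrained_models_path='./weight/t_net/WRN-16-1-T-model_best.pth.tar',
                       n_classes=10)
teacher.eval()  # the teacher stays frozen during NAD (see main.py)
```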
/at.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | '''
9 | AT with sum of absolute values with power p
10 | code from: https://github.com/AberHu/Knowledge-Distillation-Zoo
11 | '''
12 | class AT(nn.Module):
13 | '''
14 | Paying More Attention to Attention: Improving the Performance of Convolutional
15 | Neural Networks via Attention Transfer
16 | https://arxiv.org/pdf/1612.03928.pdf
17 | '''
18 | def __init__(self, p):
19 | super(AT, self).__init__()
20 | self.p = p
21 |
22 | def forward(self, fm_s, fm_t):
23 | loss = F.mse_loss(self.attention_map(fm_s), self.attention_map(fm_t))
24 |
25 | return loss
26 |
27 | def attention_map(self, fm, eps=1e-6):
28 | am = torch.pow(torch.abs(fm), self.p)
29 | am = torch.sum(am, dim=1, keepdim=True)
30 | norm = torch.norm(am, dim=(2,3), keepdim=True)
31 | am = torch.div(am, norm+eps)
32 |
33 | return am
--------------------------------------------------------------------------------
/results/results.csv:
--------------------------------------------------------------------------------
1 | epoch,test_clean_acc,test_bad_acc,test_bad_cls_loss,test_bad_at_loss
2 | 0,85.65555555555555,100.0,1.2821621365017362e-07,1.3864030798806084
3 | 1,58.43333333333333,9.6,9.003433816697862,1.207139956580268
4 | 2,72.25555555555556,8.4,5.509578734503852,1.2736398759418064
5 | 3,80.8,3.311111111111111,8.18274724706014,1.2574408405092028
6 | 4,81.41111111111111,4.2,7.806607432471381,1.2187182008955213
7 | 5,81.74444444444444,4.322222222222222,7.909650793711345,1.189901822090149
8 | 6,82.12222222222222,3.577777777777778,8.153920613182915,1.2188563068177964
9 | 7,82.26666666666667,4.688888888888889,7.759241249084472,1.1880263636906943
10 | 8,82.16666666666667,4.788888888888889,7.93371855629815,1.1945802669525147
11 | 9,81.75555555555556,5.322222222222222,7.186521497938368,1.181448416603936
12 | 10,82.7,4.4,8.051894421895344,1.204457118988037
13 | 11,82.33333333333333,4.2444444444444445,8.007042150709363,1.1921954712337919
--------------------------------------------------------------------------------
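For reference, the `test_bad_at_loss` column above is the sum of the three per-layer AT terms computed in `main.py`. A minimal self-contained sketch (not part of the repository) of that computation, using the `AT` module from `at.py` with the default `p=2.0` and dummy feature maps shaped like WRN-16-1 block outputs:

```python
import torch
from at import AT

criterionAT = AT(p=2.0)  # p matches the default in config.py

# Dummy student/teacher activations (N, C, H, W) for the three WRN-16-1 blocks.
act_s = [torch.randn(4, c, s, s) for c, s in [(16, 32), (32, 16), (64, 8)]]
act_t = [torch.randn(4, c, s, s) for c, s in [(16, 32), (32, 16), (64, 8)]]

# Per-layer distillation terms, weighted by beta1/beta2/beta3 as in main.py.
betas = [500, 1000, 1000]
at_loss = sum(b * criterionAT(s, t.detach()) for b, s, t in zip(betas, act_s, act_t))
print(at_loss.item())
```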
/models/selector.py:
--------------------------------------------------------------------------------
1 | from models.wresnet import *
2 | from models.resnet import *
3 | import os
4 |
5 | def select_model(dataset,
6 | model_name,
7 | pretrained=False,
8 | pretrained_models_path=None,
9 | n_classes=10):
10 |
11 | assert model_name in ['WRN-16-1', 'WRN-16-2', 'WRN-40-1', 'WRN-40-2', 'ResNet34', 'WRN-10-2', 'WRN-10-1']
12 | if model_name=='WRN-16-1':
13 | model = WideResNet(depth=16, num_classes=n_classes, widen_factor=1, dropRate=0)
14 | elif model_name=='WRN-16-2':
15 | model = WideResNet(depth=16, num_classes=n_classes, widen_factor=2, dropRate=0)
16 | elif model_name=='WRN-40-1':
17 | model = WideResNet(depth=40, num_classes=n_classes, widen_factor=1, dropRate=0)
18 | elif model_name=='WRN-40-2':
19 | model = WideResNet(depth=40, num_classes=n_classes, widen_factor=2, dropRate=0)
20 | elif model_name == 'WRN-10-2':
21 | model = WideResNet(depth=10, num_classes=n_classes, widen_factor=2, dropRate=0)
22 | elif model_name == 'WRN-10-1':
23 | model = WideResNet(depth=10, num_classes=n_classes, widen_factor=1, dropRate=0)
24 | elif model_name=='ResNet34':
25 | model = resnet(depth=32, num_classes=n_classes)  # note: despite the name, this builds a depth-32 CIFAR ResNet
26 | else:
27 | raise NotImplementedError
28 |
29 | if pretrained:
30 | model_path = os.path.join(pretrained_models_path)
31 | print('Loading Model from {}'.format(model_path))
32 | checkpoint = torch.load(model_path, map_location='cpu')
33 | print(checkpoint.keys())
34 | model.load_state_dict(checkpoint['state_dict'])
35 |
36 | #print("=> loaded checkpoint '{}' (epoch {}) (accuracy {})".format(model_path, checkpoint['epoch'], checkpoint['best_prec']))
37 | print("=> loaded checkpoint '{}' (epoch {}) ".format(model_path, checkpoint['epoch']))
38 |
39 |
40 | return model
41 |
42 | if __name__ == '__main__':
43 |
44 | import torch
45 | from torchsummary import summary
46 | import random
47 | import time
48 |
49 | random.seed(1234) # torch transforms use this seed
50 | torch.manual_seed(1234)
51 | torch.cuda.manual_seed(1234)
52 |
53 | support_x_task = torch.autograd.Variable(torch.FloatTensor(64, 3, 32, 32).uniform_(0, 1))
54 |
55 | t0 = time.time()
56 | model = select_model('CIFAR10', model_name='WRN-16-2')
57 | _, _, _, output = model(support_x_task)  # the model returns three activations plus the logits
58 | print("Time taken for forward pass: {} s".format(time.time() - t0))
59 | print("\nOUTPUT SHAPE: ", output.shape)
60 | summary(model, (3, 32, 32))
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def get_arguments():
4 | parser = argparse.ArgumentParser()
5 |
6 | # various paths
7 | parser.add_argument('--checkpoint_root', type=str, default='./weight/erasing_net', help='model weights are saved here')
8 | parser.add_argument('--log_root', type=str, default='./results', help='logs are saved here')
9 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='name of image dataset')
10 | parser.add_argument('--s_model', type=str, default='./weight/s_net/WRN-16-1-S-model_best.pth.tar', help='path of student model')
11 | parser.add_argument('--t_model', type=str, default='./weight/t_net/WRN-16-1-T-model_best.pth.tar', help='path of teacher model')
12 |
13 | # training hyperparameters
14 | parser.add_argument('--print_freq', type=int, default=50, help='frequency of showing training results on console')
15 | parser.add_argument('--epochs', type=int, default=20, help='number of total epochs to run')
16 | parser.add_argument('--batch_size', type=int, default=64, help='batch size')
17 | parser.add_argument('--lr', type=float, default=0.1, help='initial learning rate')
18 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
19 | parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
20 | parser.add_argument('--num_class', type=int, default=10, help='number of classes')
21 | parser.add_argument('--ratio', type=float, default=0.05, help='ratio of clean training data used')
22 | parser.add_argument('--beta1', type=int, default=500, help='beta of low layer')
23 | parser.add_argument('--beta2', type=int, default=1000, help='beta of middle layer')
24 | parser.add_argument('--beta3', type=int, default=1000, help='beta of high layer')
25 | parser.add_argument('--p', type=float, default=2.0, help='power for AT')
26 | parser.add_argument('--threshold_clean', type=float, default=70.0, help='clean-accuracy threshold for saving weights')
27 | parser.add_argument('--threshold_bad', type=float, default=90.0, help='attack-success-rate threshold for saving weights')
28 | parser.add_argument('--cuda', type=int, default=1)
29 | parser.add_argument('--device', type=str, default='cuda')
30 | parser.add_argument('--save', type=int, default=1)
31 |
32 | # others
33 | parser.add_argument('--seed', type=int, default=2, help='random seed')
34 | parser.add_argument('--note', type=str, default='try', help='note for this run')
35 |
36 | # net and dataset chosen
37 | parser.add_argument('--data_name', type=str, default='CIFAR10', help='name of dataset')
38 | parser.add_argument('--t_name', type=str, default='WRN-16-1', help='name of teacher')
39 | parser.add_argument('--s_name', type=str, default='WRN-16-1', help='name of student')
40 |
41 | # backdoor attacks
42 | parser.add_argument('--inject_portion', type=float, default=0.1, help='ratio of backdoor samples')
43 | parser.add_argument('--target_label', type=int, default=5, help='class of target label')
44 | parser.add_argument('--trigger_type', type=str, default='gridTrigger', help='type of backdoor trigger')
45 | parser.add_argument('--target_type', type=str, default='all2one', help='type of backdoor label')
46 | parser.add_argument('--trig_w', type=int, default=3, help='width of trigger pattern')
47 | parser.add_argument('--trig_h', type=int, default=3, help='height of trigger pattern')
48 |
49 | return parser
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import os
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 |
8 | class AverageMeter(object):
9 | def __init__(self):
10 | self.reset()
11 |
12 | def reset(self):
13 | self.val = 0
14 | self.avg = 0
15 | self.sum = 0
16 | self.count = 0
17 |
18 | def update(self, val, n=1):
19 | self.val = val
20 | self.sum += val * n
21 | self.count += n
22 | self.avg = self.sum / self.count
23 |
24 |
25 | def print_network(net):
26 | num_params = 0
27 | for param in net.parameters():
28 | num_params += param.numel()
29 | print(net)
30 | print('Total number of parameters: %d' % num_params)
31 |
32 |
33 | def load_pretrained_model(model, pretrained_dict, wfc=True):
34 | model_dict = model.state_dict()
35 | # 1. filter out unnecessary keys
36 | if wfc:
37 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
38 | else:
39 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if ((k in model_dict) and ('fc' not in k))}
40 | # 2. overwrite entries in the existing state dict
41 | model_dict.update(pretrained_dict)
42 | # 3. load the new state dict
43 | model.load_state_dict(model_dict)
44 |
45 |
46 | def transform_time(s):
47 | m, s = divmod(s, 60)
48 | h, m = divmod(m, 60)
49 | return h, m, s
50 |
51 |
52 | def accuracy(output, target, topk=(1,)):
53 | """Computes the precision@k for the specified values of k"""
54 | maxk = max(topk)
55 | batch_size = target.size(0)
56 |
57 | _, pred = output.topk(maxk, 1, True, True)
58 | pred = pred.t()
59 | correct = pred.eq(target.view(1, -1).expand_as(pred))
60 |
61 | res = []
62 | for k in topk:
63 | correct_k = correct[:k].view(-1).float().sum(0)
64 | res.append(correct_k.mul_(100.0 / batch_size))
65 | return res
66 |
67 |
68 | def adjust_learning_rate(optimizer, epoch, lr):
69 | if epoch < 2:
70 | lr = lr
71 | elif epoch < 20:
72 | lr = 0.01
73 | elif epoch < 30:
74 | lr = 0.0001
75 | else:
76 | lr = 0.0001
77 | print('epoch: {} lr: {:.4f}'.format(epoch, lr))
78 | for param_group in optimizer.param_groups:
79 | param_group['lr'] = lr
80 |
81 |
82 | def save_checkpoint(state, is_best, fdir, model_name):
83 | filepath = os.path.join(fdir, model_name + '.tar')
84 | if is_best:
85 | torch.save(state, filepath)
86 | print('[info] save best model')
87 |
88 |
89 | def save_history(cls_orig_acc, clease_trig_acc, cls_trig_loss, at_trig_loss, at_epoch_list, logs_dir):
90 | dataframe = pd.DataFrame({'epoch': at_epoch_list, 'cls_orig_acc': cls_orig_acc, 'clease_trig_acc': clease_trig_acc,
91 | 'cls_trig_loss': cls_trig_loss, 'at_trig_loss': at_trig_loss})
92 | # save the DataFrame as CSV; `index` controls whether row labels are written (default=True)
93 | dataframe.to_csv(logs_dir, index=False, sep=',')
94 |
95 | def plot_curve(clean_acc, bad_acc, epochs, dataset_name):
96 | N = epochs+1
97 | plt.style.use("ggplot")
98 | plt.figure()
99 | plt.plot(np.arange(0, N), clean_acc, label="Classification Accuracy", marker='D', color='blue')
100 | plt.plot(np.arange(0, N), bad_acc, label="Attack Success Rate", marker='o', color='red')
101 | plt.title(dataset_name)
102 | plt.xlabel("Epoch")
103 | plt.ylabel("Student Model Accuracy/Attack Success Rate(%)")
104 | plt.xticks(range(0, N, 1))
105 | plt.yticks(range(0, 101, 20))
106 | plt.legend()
107 | plt.show()
--------------------------------------------------------------------------------
/models/wresnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from https://github.com/xternalz/WideResNet-pytorch
3 | Modifications = return activations for use in attention transfer,
4 | as done before, e.g. in https://github.com/BayesWatch/pytorch-moonshine
5 | """
6 |
7 | import math
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class BasicBlock(nn.Module):
14 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
15 | super(BasicBlock, self).__init__()
16 | self.bn1 = nn.BatchNorm2d(in_planes)
17 | self.relu1 = nn.ReLU(inplace=True)
18 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
19 | padding=1, bias=False)
20 | self.bn2 = nn.BatchNorm2d(out_planes)
21 | self.relu2 = nn.ReLU(inplace=True)
22 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
23 | padding=1, bias=False)
24 |
self.droprate = dropRate 25 | self.equalInOut = (in_planes == out_planes) 26 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 27 | padding=0, bias=False) or None 28 | def forward(self, x): 29 | if not self.equalInOut: 30 | x = self.relu1(self.bn1(x)) 31 | else: 32 | out = self.relu1(self.bn1(x)) 33 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 34 | if self.droprate > 0: 35 | out = F.dropout(out, p=self.droprate, training=self.training) 36 | out = self.conv2(out) 37 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 38 | 39 | class NetworkBlock(nn.Module): 40 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 41 | super(NetworkBlock, self).__init__() 42 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 43 | 44 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 45 | layers = [] 46 | for i in range(int(nb_layers)): 47 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 48 | return nn.Sequential(*layers) 49 | 50 | def forward(self, x): 51 | return self.layer(x) 52 | 53 | class WideResNet(nn.Module): 54 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 55 | super(WideResNet, self).__init__() 56 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 57 | assert((depth - 4) % 6 == 0) 58 | n = (depth - 4) / 6 59 | block = BasicBlock 60 | # 1st conv before any network block 61 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 62 | padding=1, bias=False) 63 | # 1st block 64 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 65 | # 2nd block 66 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 67 | # 3rd block 68 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 69 | # global average pooling and classifier 70 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.fc = nn.Linear(nChannels[3], num_classes) 73 | self.nChannels = nChannels[3] 74 | 75 | for m in self.modules(): 76 | if isinstance(m, nn.Conv2d): 77 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 78 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
79 | elif isinstance(m, nn.BatchNorm2d):
80 | m.weight.data.fill_(1)
81 | m.bias.data.zero_()
82 | elif isinstance(m, nn.Linear):
83 | m.bias.data.zero_()
84 |
85 |
86 | def forward(self, x):
87 | out = self.conv1(x)
88 | out = self.block1(out)
89 | activation1 = out
90 | out = self.block2(out)
91 | activation2 = out
92 | out = self.block3(out)
93 | activation3 = out
94 | out = self.relu(self.bn1(out))
95 | out = F.avg_pool2d(out, 8)
96 | out = out.view(-1, self.nChannels)
97 | return activation1, activation2, activation3, self.fc(out)
98 |
99 |
100 | if __name__ == '__main__':
101 | import random
102 | import time
103 | # from torchsummary import summary
104 |
105 | random.seed(1234) # torch transforms use this seed
106 | torch.manual_seed(1234)
107 | torch.cuda.manual_seed(1234)
108 |
109 | x = torch.FloatTensor(64, 3, 32, 32).uniform_(0, 1)
110 |
111 | ### WideResNets
112 | # Notation: W-depth-wideningfactor
113 | model = WideResNet(depth=16, num_classes=10, widen_factor=1, dropRate=0.0)
114 | model = WideResNet(depth=16, num_classes=10, widen_factor=2, dropRate=0.0)
115 | #model = WideResNet(depth=16, num_classes=10, widen_factor=8, dropRate=0.0)
116 | #model = WideResNet(depth=16, num_classes=10, widen_factor=10, dropRate=0.0)
117 | #model = WideResNet(depth=22, num_classes=10, widen_factor=8, dropRate=0.0)
118 | #model = WideResNet(depth=34, num_classes=10, widen_factor=2, dropRate=0.0)
119 | #model = WideResNet(depth=40, num_classes=10, widen_factor=10, dropRate=0.0)
120 | model = WideResNet(depth=40, num_classes=10, widen_factor=1, dropRate=0.0)
121 | model = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.0)
122 | ###model = WideResNet(depth=50, num_classes=10, widen_factor=2, dropRate=0.0)
123 |
124 |
125 | t0 = time.time()
126 | _, _, _, output = model(x)  # forward returns three activations plus the logits
127 | print("Time taken for forward pass: {} s".format(time.time() - t0))
128 | print("\nOUTPUT SHAPE: ", output.shape)
129 |
130 | # summary(model, input_size=(3, 32, 32))
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Data:2020/7/14 17:37
3 | # @Author:lyg
4 |
5 | from __future__ import absolute_import
6 |
7 | '''Resnet for cifar dataset.
8 | Ported from
9 | https://github.com/facebook/fb.resnet.torch
10 | and
11 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
12 | (c) YANG, Wei
13 | '''
14 | import torch.nn as nn
15 | import math
16 |
17 | __all__ = ['resnet']
18 |
19 |
20 | def conv3x3(in_planes, out_planes, stride=1):
21 | "3x3 convolution with padding"
22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
23 | padding=1, bias=False)
24 |
25 |
26 | class BasicBlock(nn.Module):
27 | expansion = 1
28 |
29 | def __init__(self, inplanes, planes, stride=1, downsample=None):
30 | super(BasicBlock, self).__init__()
31 | self.conv1 = conv3x3(inplanes, planes, stride)
32 | self.bn1 = nn.BatchNorm2d(planes)
33 | self.relu = nn.ReLU(inplace=True)
34 | self.conv2 = conv3x3(planes, planes)
35 | self.bn2 = nn.BatchNorm2d(planes)
36 | self.downsample = downsample
37 | self.stride = stride
38 |
39 | def forward(self, x):
40 | residual = x
41 |
42 | out = self.conv1(x)
43 | out = self.bn1(out)
44 | out = self.relu(out)
45 |
46 | out = self.conv2(out)
47 | out = self.bn2(out)
48 |
49 | if self.downsample is not None:
50 | residual = self.downsample(x)
51 |
52 | out += residual
53 | out = self.relu(out)
54 |
55 | return out
56 |
57 |
58 | class Bottleneck(nn.Module):
59 | expansion = 4
60 |
61 | def __init__(self, inplanes, planes, stride=1, downsample=None):
62 | super(Bottleneck, self).__init__()
63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
64 | self.bn1 = nn.BatchNorm2d(planes)
65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
66 | padding=1, bias=False)
67 | self.bn2 = nn.BatchNorm2d(planes)
68 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
69 | self.bn3 = nn.BatchNorm2d(planes * 4)
70 | self.relu = nn.ReLU(inplace=True)
71 | self.downsample = downsample
72 | self.stride = stride
73 |
74 | def forward(self, x):
75 | residual = x
76 |
77 | out = self.conv1(x)
78 | out = self.bn1(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv2(out)
82 | out = self.bn2(out)
83 | out = self.relu(out)
84 |
85 | out = self.conv3(out)
86 | out = self.bn3(out)
87 |
88 | if self.downsample is not None:
89 | residual = self.downsample(x)
90 |
91 | out += residual
92 | out = self.relu(out)
93 |
94 | return out
95 |
96 |
97 | class ResNet(nn.Module):
98 |
99 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock'):
100 | super(ResNet, self).__init__()
101 | # Model type specifies number of layers for CIFAR-10 model
102 | if block_name.lower() == 'basicblock':
103 | assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
104 | n = (depth - 2) // 6
105 | block = BasicBlock
106 | elif block_name.lower() == 'bottleneck':
107 | assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
108 | n = (depth - 2) // 9
109 | block = Bottleneck
110 | else:
111 | raise ValueError('block_name should be BasicBlock or Bottleneck')
112 |
113 | self.inplanes = 16
114 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
115 | bias=False)
116 | self.bn1 = nn.BatchNorm2d(16)
117 | self.relu = nn.ReLU(inplace=True)
118 | self.layer1 = self._make_layer(block, 16, n)
119 | self.layer2 = self._make_layer(block, 32, n, stride=2)
120 | self.layer3 = self._make_layer(block, 64, n, stride=2)
121 | self.avgpool = nn.AvgPool2d(8)
122 | self.fc = nn.Linear(64 * block.expansion, num_classes)
123 |
124 | for m in self.modules():
125 | if isinstance(m, nn.Conv2d):
126 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
127 | m.weight.data.normal_(0, math.sqrt(2. / n))
128 | elif isinstance(m, nn.BatchNorm2d):
129 | m.weight.data.fill_(1)
130 | m.bias.data.zero_()
131 |
132 | def _make_layer(self, block, planes, blocks, stride=1):
133 | downsample = None
134 | if stride != 1 or self.inplanes != planes * block.expansion:
135 | downsample = nn.Sequential(
136 | nn.Conv2d(self.inplanes, planes * block.expansion,
137 | kernel_size=1, stride=stride, bias=False),
138 | nn.BatchNorm2d(planes * block.expansion),
139 | )
140 |
141 | layers = []
142 | layers.append(block(self.inplanes, planes, stride, downsample))
143 | self.inplanes = planes * block.expansion
144 | for i in range(1, blocks):
145 | layers.append(block(self.inplanes, planes))
146 |
147 | return nn.Sequential(*layers)
148 |
149 | def forward(self, x):
150 | x = self.conv1(x)
151 | x = self.bn1(x)
152 | x = self.relu(x) # 32x32
153 |
154 | x = self.layer1(x) # 32x32
155 | activation1 = x
156 | x = self.layer2(x) # 16x16
157 | activation2 = x
158 | x = self.layer3(x) # 8x8
159 | activation3 = x
160 |
161 | x = self.avgpool(x)
162 | x = x.view(x.size(0), -1)
163 | x = self.fc(x)
164 |
165 | return activation1, activation2, activation3, x
166 |
167 |
168 | def resnet(**kwargs):
169 | """
170 | Constructs a ResNet model.
171 | """ 172 | return ResNet(**kwargs) -------------------------------------------------------------------------------- /models/lenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class View(nn.Module): 6 | """ 7 | For convenience so we can add in in nn.Sequential 8 | instead of doing it manually in forward() 9 | """ 10 | def __init__(self, size): 11 | super(View, self).__init__() 12 | self.size = size 13 | 14 | def forward(self, tensor): 15 | return tensor.view(self.size) 16 | 17 | class LeNet5(nn.Module): 18 | """ 19 | For SVHN/CIFAR experiments 20 | """ 21 | def __init__(self, n_classes): 22 | super(LeNet5, self).__init__() 23 | self.n_classes = n_classes 24 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3) 25 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 26 | # self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 27 | # self.conv3 = nn.Conv2d(32, 64, kernel_size=3) 28 | # self.conv4 = nn.Conv2d(64, 64, kernel_size=3) 29 | # self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 30 | # self.conv4_drop = nn.Dropout2d(0.5) 31 | self.fc1 = nn.Linear(64*6*6, 128) 32 | self.fc2 = nn.Linear(128, n_classes) 33 | 34 | def forward(self, x): 35 | out = F.relu(F.max_pool2d(self.conv1(x), 2)) 36 | # print('out;', out.shape) 37 | out = F.relu(F.max_pool2d(self.conv2(out), 2)) 38 | activation = out 39 | #print('out;', out.shape) 40 | out = out.view(-1, 64*6*6) 41 | out = F.relu(self.fc1(out)) 42 | out = F.dropout(out, training=self.training) 43 | out = self.fc2(out) 44 | return activation, out 45 | 46 | 47 | 48 | class LeNet7_T(nn.Module): 49 | """ 50 | For SVHN/MNIST experiments 51 | """ 52 | def __init__(self, n_classes): 53 | super(LeNet7_T, self).__init__() 54 | self.n_classes = n_classes 55 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 56 | self.conv2 = nn.Conv2d(32, 32, kernel_size=3) 57 | # self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 58 | self.conv3 = nn.Conv2d(32, 64, kernel_size=3) 59 | self.conv4 = nn.Conv2d(64, 64, kernel_size=3) 60 | # self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 61 | # self.conv4_drop = nn.Dropout2d(0.5) 62 | self.fc1 = nn.Linear(64*4*4, 200) 63 | self.fc2 = nn.Linear(200, n_classes) 64 | 65 | def forward(self, x): 66 | out = F.relu((self.conv1(x))) 67 | # print('out;', out.shape) 68 | out = F.relu(F.max_pool2d(self.conv2(out), 2)) 69 | out = F.relu((self.conv3(out))) 70 | out = F.relu(F.max_pool2d(self.conv4(out), 2)) 71 | activation = out 72 | # print('out;', out.shape) 73 | out = out.view(-1, 64*4*4) 74 | out = F.relu(self.fc1(out)) 75 | out = F.dropout(out, training=self.training) 76 | out = self.fc2(out) 77 | return activation, out 78 | 79 | 80 | class LeNet7_S(nn.Module): 81 | """ 82 | For SVHN/MNIST experiments 83 | """ 84 | def __init__(self, n_classes): 85 | super(LeNet7_S, self).__init__() 86 | 87 | self.n_classes = n_classes 88 | self.conv1 = nn.Conv2d(1, 64, kernel_size=3) 89 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3) 90 | # self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 91 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3) 92 | self.conv4 = nn.Conv2d(128, 128, kernel_size=3) 93 | # self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 94 | # self.conv4_drop = nn.Dropout2d(0.5) 95 | self.fc1 = nn.Linear(128*4*4, 256) 96 | self.fc2 = nn.Linear(256, n_classes) 97 | 98 | def forward(self, x): 99 | out = F.relu((self.conv1(x))) 100 | out = 
F.relu(F.max_pool2d(self.conv2(out), 2)) 101 | out = F.relu((self.conv3(out))) 102 | out = F.relu(F.max_pool2d(self.conv4(out), 2)) 103 | activation = out 104 | # print('out;', out.shape) 105 | out = out.view(-1, 128*4*4) 106 | out = F.relu(self.fc1(out)) 107 | out = F.dropout(out, training=self.training) 108 | out = self.fc2(out) 109 | return activation, out 110 | 111 | class trojan_model(nn.Module): 112 | """ 113 | For train trojan model 114 | """ 115 | def __init__(self, n_classes): 116 | super(trojan_model, self).__init__() 117 | 118 | self.n_classes = n_classes 119 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3) 120 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3) 121 | # self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 122 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3) 123 | self.conv4 = nn.Conv2d(128, 128, kernel_size=3) 124 | # self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 125 | # self.conv4_drop = nn.Dropout2d(0.5) 126 | self.fc1 = nn.Linear(128*5*5, 256) 127 | self.fc2 = nn.Linear(256, n_classes) 128 | 129 | def forward(self, x): 130 | out = F.relu((self.conv1(x))) 131 | out = F.relu(F.max_pool2d(self.conv2(out), 2)) 132 | activation1 = out 133 | out = F.relu((self.conv3(out))) 134 | activation2 = out 135 | out = F.relu(F.max_pool2d(self.conv4(out), 2)) 136 | activation3 = out 137 | # print('out;', out.shape) 138 | out = out.view(-1, 128*5*5) 139 | out = F.relu(self.fc1(out)) 140 | out = F.dropout(out, training=self.training) 141 | out = self.fc2(out) 142 | return activation1, activation2, activation3, out 143 | 144 | # 145 | # if __name__ == '__main__': 146 | # import random 147 | # import sys 148 | # # from torchsummary import summary 149 | # 150 | # random.seed(1234) # torch transforms use this seed 151 | # torch.manual_seed(1234) 152 | # torch.cuda.manual_seed(1234) 153 | # 154 | # ### LENET5 155 | # x = torch.FloatTensor(64, 3, 32, 32).uniform_(0, 1) 156 | # true_labels = torch.tensor([[2.], [3], [1], [8], [4]], requires_grad=True) 157 | # model = LeNet5(n_classes=10) 158 | # output, act = model(x) 159 | # print("\nOUTPUT SHAPE: ", output.shape) 160 | # 161 | # # summary(model, input_size=(3,32,32)) 162 | 163 | -------------------------------------------------------------------------------- /train_badnet.py: -------------------------------------------------------------------------------- 1 | from models.selector import * 2 | from utils.util import * 3 | from data_loader import get_test_loader, get_backdoor_loader 4 | from config import get_arguments 5 | 6 | 7 | def train_step(opt, train_loader, nets, optimizer, criterions, epoch): 8 | cls_losses = AverageMeter() 9 | top1 = AverageMeter() 10 | top5 = AverageMeter() 11 | 12 | snet = nets['snet'] 13 | 14 | criterionCls = criterions['criterionCls'] 15 | snet.train() 16 | 17 | for idx, (img, target) in enumerate(train_loader, start=1): 18 | if opt.cuda: 19 | img = img.cuda() 20 | target = target.cuda() 21 | 22 | _, _, _, output_s = snet(img) 23 | 24 | cls_loss = criterionCls(output_s, target) 25 | 26 | prec1, prec5 = accuracy(output_s, target, topk=(1, 5)) 27 | cls_losses.update(cls_loss.item(), img.size(0)) 28 | top1.update(prec1.item(), img.size(0)) 29 | top5.update(prec5.item(), img.size(0)) 30 | 31 | optimizer.zero_grad() 32 | cls_loss.backward() 33 | optimizer.step() 34 | 35 | if idx % opt.print_freq == 0: 36 | print('Epoch[{0}]:[{1:03}/{2:03}] ' 37 | 'cls_loss:{losses.val:.4f}({losses.avg:.4f}) ' 38 | 'prec@1:{top1.val:.2f}({top1.avg:.2f}) ' 39 | 
'prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(epoch, idx, len(train_loader), losses=cls_losses, top1=top1, top5=top5))
40 |
41 |
42 | def test(opt, test_clean_loader, test_bad_loader, nets, criterions, epoch):
43 | test_process = []
44 | top1 = AverageMeter()
45 | top5 = AverageMeter()
46 |
47 | snet = nets['snet']
48 | criterionCls = criterions['criterionCls']
49 | snet.eval()
50 |
51 | for idx, (img, target) in enumerate(test_clean_loader, start=1):
52 | img = img.cuda()
53 | target = target.cuda()
54 |
55 | with torch.no_grad():
56 | _, _, _, output_s = snet(img)
57 |
58 | prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
59 | top1.update(prec1.item(), img.size(0))
60 | top5.update(prec5.item(), img.size(0))
61 |
62 | acc_clean = [top1.avg, top5.avg]
63 |
64 | cls_losses = AverageMeter()
65 | at_losses = AverageMeter()
66 | top1 = AverageMeter()
67 | top5 = AverageMeter()
68 |
69 | for idx, (img, target) in enumerate(test_bad_loader, start=1):
70 | img = img.cuda()
71 | target = target.cuda()
72 |
73 | with torch.no_grad():
74 | _, _, _, output_s = snet(img)
75 | cls_loss = criterionCls(output_s, target)
76 |
77 | prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
78 | cls_losses.update(cls_loss.item(), img.size(0))
79 | top1.update(prec1.item(), img.size(0))
80 | top5.update(prec5.item(), img.size(0))
81 |
82 | acc_bd = [top1.avg, top5.avg, cls_losses.avg]
83 |
84 | print('[clean]Prec@1: {:.2f}'.format(acc_clean[0]))
85 | print('[bad]Prec@1: {:.2f}'.format(acc_bd[0]))
86 |
87 | # save training progress
88 | log_root = opt.log_root + '/backdoor_results.csv'
89 | test_process.append(
90 | (epoch, acc_clean[0], acc_bd[0], acc_bd[2]))
91 | df = pd.DataFrame(test_process, columns=(
92 | "epoch", "test_clean_acc", "test_bad_acc", "test_bad_cls_loss"))
93 | df.to_csv(log_root, mode='a', index=False, encoding='utf-8')
94 |
95 | return acc_clean, acc_bd
96 |
97 |
98 | def train(opt):
99 | # Load models
100 | print('----------- Network Initialization --------------')
101 | student = select_model(dataset=opt.data_name,
102 | model_name=opt.s_name,
103 | pretrained=False,
104 | pretrained_models_path=opt.s_model,
105 | n_classes=opt.num_class).to(opt.device)
106 | print('finished student model init...')
107 |
108 | nets = {'snet': student}
109 |
110 | # initialize optimizer
111 | optimizer = torch.optim.SGD(student.parameters(),
112 | lr=opt.lr,
113 | momentum=opt.momentum,
114 | weight_decay=opt.weight_decay,
115 | nesterov=True)
116 |
117 | # define loss functions
118 | if opt.cuda:
119 | criterionCls = nn.CrossEntropyLoss().cuda()
120 | else:
121 | criterionCls = nn.CrossEntropyLoss()
122 |
123 | print('----------- DATA Initialization --------------')
124 | train_loader = get_backdoor_loader(opt)
125 | test_clean_loader, test_bad_loader = get_test_loader(opt)
126 |
127 | print('----------- Train Initialization --------------')
128 | for epoch in range(1, opt.epochs):
129 |
130 | _adjust_learning_rate(optimizer, epoch, opt.lr)
131 |
132 | # train every epoch
133 | criterions = {'criterionCls': criterionCls}
134 | train_step(opt, train_loader, nets, optimizer, criterions, epoch)
135 |
136 | # evaluate on testing set
137 | print('testing the models......')
138 | acc_clean, acc_bad = test(opt, test_clean_loader, test_bad_loader, nets, criterions, epoch)
139 |
140 | # remember best precision and save checkpoint
141 | if opt.save:
142 | is_best = acc_bad[0] > opt.threshold_bad
143 | opt.threshold_bad = max(acc_bad[0], opt.threshold_bad)  # raise the bar so only checkpoints with a higher attack success rate are saved
144 |
145 | best_clean_acc = acc_clean[0]
146 |
best_bad_acc = acc_bad[0]
147 |
148 | s_name = opt.s_name + '-S-model_best.pth'
149 | save_checkpoint({
150 | 'epoch': epoch,
151 | 'state_dict': student.state_dict(),
152 | 'best_clean_acc': best_clean_acc,
153 | 'best_bad_acc': best_bad_acc,
154 | 'optimizer': optimizer.state_dict(),
155 | }, is_best, opt.checkpoint_root, s_name)
156 |
157 |
158 | def _adjust_learning_rate(optimizer, epoch, lr):
159 | if epoch < 21:
160 | lr = lr
161 | elif epoch < 30:
162 | lr = 0.01 * lr
163 | else:
164 | lr = 0.0009
165 | print('epoch: {} lr: {:.4f}'.format(epoch, lr))
166 | for param_group in optimizer.param_groups:
167 | param_group['lr'] = lr
168 |
169 | def main():
170 | # Prepare arguments
171 | opt = get_arguments().parse_args()
172 | train(opt)
173 |
174 | if (__name__ == '__main__'):
175 | main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Attention Distillation
2 |
3 | This is an implementation demo of the ICLR 2021 paper **[Neural Attention Distillation: Erasing Backdoor Triggers from Deep Neural Networks](https://openreview.net/pdf?id=9l0K4OM-oXE)** in PyTorch.
4 |
5 | ![Python 3.6](https://img.shields.io/badge/python-3.6-DodgerBlue.svg?style=plastic)
6 | ![Pytorch 1.2](https://img.shields.io/badge/pytorch-1.2.0-DodgerBlue.svg?style=plastic)
7 | ![CUDA 10.0](https://img.shields.io/badge/cuda-10.0-DodgerBlue.svg?style=plastic)
8 | ![License CC BY-NC](https://img.shields.io/badge/license-CC_BY--NC-DodgerBlue.svg?style=plastic)
9 |
10 | ## NAD: Quick start with pretrained model
11 | We have already uploaded the `all2one` pretrained backdoored student model (i.e. gridTrigger WRN-16-1, target label 5) and the clean teacher model (i.e. WRN-16-1) in the paths `./weight/s_net` and `./weight/t_net`, respectively.
12 |
13 | To evaluate the performance of NAD, you can simply run:
14 |
15 | ```bash
16 | $ python main.py
17 | ```
18 | where the default parameters are shown in `config.py`.
19 |
20 | The trained model will be saved at the path `weight/erasing_net/<s_name>.tar`.
21 |
22 | Please carefully read `main.py` and `config.py`, then change the parameters for your experiment.
23 |
24 | ### Erasing Results on BadNets
25 | - The data augmentation setting used for Finetuning and NAD in this table:
26 | ```
27 | tf_train = transforms.Compose([
28 | transforms.RandomCrop(32, padding=4),
29 | transforms.RandomHorizontalFlip(),
30 | transforms.ToTensor()
31 | ])
32 | ```
33 |
34 | | Dataset | Baseline ACC | Baseline ASR | Finetuning ACC | Finetuning ASR | NAD ACC | NAD ASR |
35 | | -------- | ------------ | ------------ | ------- | ------- | ------- |------- |
36 | | CIFAR-10 | 85.65 | 100.0 | 82.32 | 18.13 | 82.12 | **3.57** |
37 |
38 | ---
39 |
40 | ## Training your own backdoored model
41 | We have provided a `DatasetBD` class in `data_loader.py` for generating training sets for different backdoor attacks.
42 |
43 | To implement a backdoor attack (e.g. the gridTrigger attack), you can run:
44 |
45 | ```bash
46 | $ python train_badnet.py
47 | ```
48 | This command will train the backdoored model and print the clean accuracy and attack success rate. You can also select the other backdoor triggers reported in the paper.
49 |
50 | Please carefully read `train_badnet.py` and `config.py`, then change the parameters for your experiment.
51 |
52 | ## How to get teacher model?
53 | We obtained the teacher model by finetuning all layers of the backdoored model using 5% clean data with data augmentation techniques. In our paper, we only finetune the backdoored model for 5~10 epochs; please check Section 4.1 and Appendix A for more details of our experimental settings. The finetuning code is easy to get by simply setting all the `beta` parameters to 0, which makes the distillation loss zero in the training process.
54 |
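A minimal sketch of that finetuning run, assuming the default flags defined in `config.py` (`--beta1/--beta2/--beta3` weight the three AT terms in `main.py`, so zeroing them leaves only the cross-entropy loss):

```bash
# Pure finetuning = NAD with all three distillation weights set to zero
$ python main.py --beta1 0 --beta2 0 --beta3 0
```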
55 | ## Other sources of backdoor attacks
56 | #### Attack
57 |
58 | **CL:** Clean-label backdoor attacks
59 |
60 | - [Paper](https://people.csail.mit.edu/madry/lab/cleanlabel.pdf)
61 | - [pytorch implementation](https://github.com/hkunzhe/label_consistent_attacks_pytorch)
62 |
63 | **SIG:** A New Backdoor Attack in CNNs by Training Set Corruption Without Label Poisoning
64 |
65 | - [Paper](https://ieeexplore.ieee.org/document/8802997/footnotes)
66 |
67 | ```python
68 | ## reference code
69 | def plant_sin_trigger(img, delta=20, f=6, debug=False):
70 | """
71 | Implement paper:
72 | > Barni, M., Kallas, K., & Tondi, B. (2019).
73 | > A new Backdoor Attack in CNNs by training set corruption without label poisoning.
74 | > arXiv preprint arXiv:1902.11237
75 | superimposed sinusoidal backdoor signal with default parameters
76 | """
77 | alpha = 0.2
78 | img = np.float32(img)
79 | pattern = np.zeros_like(img)
80 | m = pattern.shape[1]
81 | for i in range(img.shape[0]):
82 | for j in range(img.shape[1]):
83 | for k in range(img.shape[2]):
84 | pattern[i, j, k] = delta * np.sin(2 * np.pi * j * f / m)
85 |
86 | img = alpha * np.uint32(img) + (1 - alpha) * pattern
87 | img = np.uint8(np.clip(img, 0, 255))
88 |
89 | # if debug:
90 | # cv2.imshow('planted image', img)
91 | # cv2.waitKey()
92 |
93 | return img
94 | ```
95 |
96 | **Refool**: Reflection Backdoor: A Natural Backdoor Attack on Deep Neural Networks
97 |
98 | - [Paper](https://arxiv.org/abs/2007.02343)
99 | - [Code](https://github.com/DreamtaleCore/Refool)
100 | - [Project](http://liuyunfei.xyz/Projs/Refool/index.html)
101 |
102 | #### Defense
103 |
104 | **MCR**: Bridging Mode Connectivity in Loss Landscapes and Adversarial Robustness
105 |
106 | - [Paper](https://arxiv.org/abs/2005.00060)
107 | - [Pytorch implementation](https://github.com/IBM/model-sanitization)
108 |
109 | **Fine-tuning & Fine-Pruning**: Defending Against Backdooring Attacks on Deep Neural Networks
110 |
111 | - [Paper](https://link.springer.com/chapter/10.1007/978-3-030-00470-5_13)
112 | - [Pytorch implementation1](https://github.com/VinAIResearch/input-aware-backdoor-attack-release/tree/master/defenses)
113 | - [Pytorch implementation2](https://github.com/adityarajagopal/pytorch_pruning_finetune)
114 |
115 | **Neural Cleanse**: Identifying and Mitigating Backdoor Attacks in Neural Networks
116 |
117 | - [Paper](https://people.cs.uchicago.edu/~ravenben/publications/pdf/backdoor-sp19.pdf)
118 | - [Tensorflow implementation](https://github.com/Abhishikta-codes/neural_cleanse)
119 | - [Pytorch implementation1](https://github.com/lijiachun123/TrojAi)
120 | - [Pytorch implementation2](https://github.com/VinAIResearch/input-aware-backdoor-attack-release/tree/master/defenses)
121 |
122 | **STRIP**: A Defence Against Trojan Attacks on Deep Neural Networks
123 |
124 | - [Paper](https://arxiv.org/pdf/1911.10312.pdf)
125 | - [Pytorch implementation1](https://github.com/garrisongys/STRIP)
126 | - [Pytorch implementation2](https://github.com/VinAIResearch/input-aware-backdoor-attack-release/tree/master/defenses)
127 |
128 | #### Library
129 |
130 | `Note`: TrojanZoo provides a universal PyTorch platform for conducting security research (especially on backdoor attacks/defenses) in image classification with deep learning.
131 |
132 | Backdoors 101 is a PyTorch framework for state-of-the-art backdoor defenses and attacks on deep learning models.
133 |
134 | - [trojanzoo](https://github.com/ain-soph/trojanzoo)
135 | - [backdoors101](https://github.com/ebagdasa/backdoors101)
136 |
137 | ## References
138 |
139 | If you find this code useful for your research, please cite our paper
140 | ```
141 | @inproceedings{li2021neural,
142 | title={Neural Attention Distillation: Erasing Backdoor Triggers from Deep Neural Networks},
143 | author={Li, Yige and Lyu, Xixiang and Koren, Nodens and Lyu, Lingjuan and Li, Bo and Ma, Xingjun},
144 | booktitle={ICLR},
145 | year={2021}
146 | }
147 | ```
148 |
149 | ## Contacts
150 |
151 | If you have any questions, please open an issue on GitHub.
152 |
153 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from models.selector import *
3 | from utils.util import *
4 | from data_loader import get_train_loader, get_test_loader
5 | from at import AT
6 | from config import get_arguments
7 |
8 |
9 | def train_step(opt, train_loader, nets, optimizer, criterions, epoch):
10 | at_losses = AverageMeter()
11 | top1 = AverageMeter()
12 | top5 = AverageMeter()
13 |
14 | snet = nets['snet']
15 | tnet = nets['tnet']
16 |
17 | criterionCls = criterions['criterionCls']
18 | criterionAT = criterions['criterionAT']
19 |
20 | snet.train()
21 |
22 | for idx, (img, target) in enumerate(train_loader, start=1):
23 | if opt.cuda:
24 | img = img.cuda()
25 | target = target.cuda()
26 |
27 | activation1_s, activation2_s, activation3_s, output_s = snet(img)
28 | activation1_t, activation2_t, activation3_t, _ = tnet(img)
29 |
30 | cls_loss = criterionCls(output_s, target)
31 | at3_loss = criterionAT(activation3_s, activation3_t.detach()) * opt.beta3
32 | at2_loss = criterionAT(activation2_s, activation2_t.detach()) * opt.beta2
33 | at1_loss = criterionAT(activation1_s, activation1_t.detach()) * opt.beta1
34 | at_loss = at1_loss + at2_loss + at3_loss + cls_loss
35 |
36 | prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
37 | at_losses.update(at_loss.item(), img.size(0))
38 | top1.update(prec1.item(), img.size(0))
39 | top5.update(prec5.item(), img.size(0))
40 |
41 | optimizer.zero_grad()
42 | at_loss.backward()
43 | optimizer.step()
44 |
45 | if idx % opt.print_freq == 0:
46 | print('Epoch[{0}]:[{1:03}/{2:03}] '
47 | 'AT_loss:{losses.val:.4f}({losses.avg:.4f}) '
48 | 'prec@1:{top1.val:.2f}({top1.avg:.2f}) '
49 | 'prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(epoch, idx, len(train_loader), losses=at_losses, top1=top1, top5=top5))
50 |
51 |
52 | def test(opt, test_clean_loader, test_bad_loader, nets, criterions, epoch):
53 | test_process = []
54 | top1 = AverageMeter()
55 | top5 = AverageMeter()
56 |
57 | snet = nets['snet']
58 | tnet = nets['tnet']
59 |
60 | criterionCls = criterions['criterionCls']
61 | criterionAT = criterions['criterionAT']
62 |
63 | snet.eval()
64 |
65 | for idx, (img, target) in enumerate(test_clean_loader, start=1):
66 | img = img.cuda()
67 | target = target.cuda()
68 |
69 | with torch.no_grad():
70 | _, _, _, output_s = snet(img)
71 |
72 | prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
73 | top1.update(prec1.item(), img.size(0))
74 |
top5.update(prec5.item(), img.size(0)) 75 | 76 | acc_clean = [top1.avg, top5.avg] 77 | 78 | cls_losses = AverageMeter() 79 | at_losses = AverageMeter() 80 | top1 = AverageMeter() 81 | top5 = AverageMeter() 82 | 83 | for idx, (img, target) in enumerate(test_bad_loader, start=1): 84 | img = img.cuda() 85 | target = target.cuda() 86 | 87 | with torch.no_grad(): 88 | activation1_s, activation2_s, activation3_s, output_s = snet(img) 89 | activation1_t, activation2_t, activation3_t, _ = tnet(img) 90 | 91 | at3_loss = criterionAT(activation3_s, activation3_t.detach()) * opt.beta3 92 | at2_loss = criterionAT(activation2_s, activation2_t.detach()) * opt.beta2 93 | at1_loss = criterionAT(activation1_s, activation1_t.detach()) * opt.beta1 94 | at_loss = at3_loss + at2_loss + at1_loss 95 | cls_loss = criterionCls(output_s, target) 96 | 97 | prec1, prec5 = accuracy(output_s, target, topk=(1, 5)) 98 | cls_losses.update(cls_loss.item(), img.size(0)) 99 | at_losses.update(at_loss.item(), img.size(0)) 100 | top1.update(prec1.item(), img.size(0)) 101 | top5.update(prec5.item(), img.size(0)) 102 | 103 | acc_bd = [top1.avg, top5.avg, cls_losses.avg, at_losses.avg] 104 | 105 | print('[clean]Prec@1: {:.2f}'.format(acc_clean[0])) 106 | print('[bad]Prec@1: {:.2f}'.format(acc_bd[0])) 107 | 108 | # save training progress 109 | log_root = opt.log_root + '/results.csv' 110 | test_process.append( 111 | (epoch, acc_clean[0], acc_bd[0], acc_bd[2], acc_bd[3])) 112 | df = pd.DataFrame(test_process, columns=( 113 | "epoch", "test_clean_acc", "test_bad_acc", "test_bad_cls_loss", "test_bad_at_loss")) 114 | df.to_csv(log_root, mode='a', index=False, encoding='utf-8') 115 | 116 | return acc_clean, acc_bd 117 | 118 | 119 | def train(opt): 120 | # Load models 121 | print('----------- Network Initialization --------------') 122 | teacher = select_model(dataset=opt.data_name, 123 | model_name=opt.t_name, 124 | pretrained=True, 125 | pretrained_models_path=opt.t_model, 126 | n_classes=opt.num_class).to(opt.device) 127 | print('finished teacher model init...') 128 | 129 | student = select_model(dataset=opt.data_name, 130 | model_name=opt.s_name, 131 | pretrained=True, 132 | pretrained_models_path=opt.s_model, 133 | n_classes=opt.num_class).to(opt.device) 134 | print('finished student model init...') 135 | teacher.eval() 136 | 137 | nets = {'snet': student, 'tnet': teacher} 138 | 139 | for param in teacher.parameters(): 140 | param.requires_grad = False 141 | 142 | # initialize optimizer 143 | optimizer = torch.optim.SGD(student.parameters(), 144 | lr=opt.lr, 145 | momentum=opt.momentum, 146 | weight_decay=opt.weight_decay, 147 | nesterov=True) 148 | 149 | # define loss functions 150 | if opt.cuda: 151 | criterionCls = nn.CrossEntropyLoss().cuda() 152 | criterionAT = AT(opt.p) 153 | else: 154 | criterionCls = nn.CrossEntropyLoss() 155 | criterionAT = AT(opt.p) 156 | 157 | print('----------- DATA Initialization --------------') 158 | train_loader = get_train_loader(opt) 159 | test_clean_loader, test_bad_loader = get_test_loader(opt) 160 | 161 | print('----------- Train Initialization --------------') 162 | for epoch in range(0, opt.epochs): 163 | 164 | adjust_learning_rate(optimizer, epoch, opt.lr) 165 | 166 | # train every epoch 167 | criterions = {'criterionCls': criterionCls, 'criterionAT': criterionAT} 168 | 169 | if epoch == 0: 170 | # before training test firstly 171 | test(opt, test_clean_loader, test_bad_loader, nets, 172 | criterions, epoch) 173 | 174 | train_step(opt, train_loader, nets, optimizer, criterions, epoch+1) 175 | 
176 | # evaluate on testing set
177 | print('testing the models......')
178 | acc_clean, acc_bad = test(opt, test_clean_loader, test_bad_loader, nets, criterions, epoch+1)
179 |
180 | # remember best precision and save checkpoint
181 | # save_root = opt.checkpoint_root + '/' + opt.s_name
182 | if opt.save:
183 | is_best = acc_clean[0] > opt.threshold_clean
184 | opt.threshold_clean = max(acc_clean[0], opt.threshold_clean)  # raise the bar so only checkpoints with a higher clean accuracy are saved
185 |
186 | best_clean_acc = acc_clean[0]
187 | best_bad_acc = acc_bad[0]
188 |
189 | save_checkpoint({
190 | 'epoch': epoch,
191 | 'state_dict': student.state_dict(),
192 | 'best_clean_acc': best_clean_acc,
193 | 'best_bad_acc': best_bad_acc,
194 | 'optimizer': optimizer.state_dict(),
195 | }, is_best, opt.checkpoint_root, opt.s_name)
196 |
197 |
198 | def main():
199 | # Prepare arguments
200 | opt = get_arguments().parse_args()
201 | train(opt)
202 |
203 | if (__name__ == '__main__'):
204 | main()
205 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms, datasets
2 | from torch.utils.data import random_split, DataLoader, Dataset
3 | import torch
4 | import numpy as np
5 | import time
6 | from tqdm import tqdm
7 |
8 | def get_train_loader(opt):
9 | print('==> Preparing train data..')
10 | tf_train = transforms.Compose([
11 | transforms.RandomCrop(32, padding=4),
12 | # transforms.RandomRotation(3),
13 | transforms.RandomHorizontalFlip(),
14 | transforms.ToTensor(),
15 | Cutout(1, 3)
16 | ])
17 |
18 | if (opt.dataset == 'CIFAR10'):
19 | trainset = datasets.CIFAR10(root='data/CIFAR10', train=True, download=True)
20 | else:
21 | raise Exception('Invalid dataset')
22 |
23 | train_data = DatasetCL(opt, full_dataset=trainset, transform=tf_train)
24 | train_loader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=True)
25 |
26 | return train_loader
27 |
28 | def get_test_loader(opt):
29 | print('==> Preparing test data..')
30 | tf_test = transforms.Compose([transforms.ToTensor()
31 | ])
32 | if (opt.dataset == 'CIFAR10'):
33 | testset = datasets.CIFAR10(root='data/CIFAR10', train=False, download=True)
34 | else:
35 | raise Exception('Invalid dataset')
36 |
37 | test_data_clean = DatasetBD(opt, full_dataset=testset, inject_portion=0, transform=tf_test, mode='test')
38 | test_data_bad = DatasetBD(opt, full_dataset=testset, inject_portion=1, transform=tf_test, mode='test')
39 |
40 | # all clean test data
41 | test_clean_loader = DataLoader(dataset=test_data_clean,
42 | batch_size=opt.batch_size,
43 | shuffle=False,
44 | )
45 | # all backdoored test data (samples from the target class are skipped; see DatasetBD)
46 | test_bad_loader = DataLoader(dataset=test_data_bad,
47 | batch_size=opt.batch_size,
48 | shuffle=False,
49 | )
50 |
51 | return test_clean_loader, test_bad_loader
52 |
53 |
54 | def get_backdoor_loader(opt):
55 | print('==> Preparing train data..')
56 | tf_train = transforms.Compose([transforms.ToTensor()
57 | ])
58 | if (opt.dataset == 'CIFAR10'):
59 | trainset = datasets.CIFAR10(root='data/CIFAR10', train=True, download=True)
60 | else:
61 | raise Exception('Invalid dataset')
62 |
63 | train_data_bad = DatasetBD(opt, full_dataset=trainset, inject_portion=opt.inject_portion, transform=tf_train, mode='train')
64 | train_clean_loader = DataLoader(dataset=train_data_bad,
65 | batch_size=opt.batch_size,
66 | shuffle=False,
67 | )
68 |
69 | return train_clean_loader
70 |
71 | class Cutout(object):
72 | """Randomly mask out one or more patches from
class Cutout(object):
    """Randomly mask out one or more patches from an image.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes of dimension length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

        for n in range(self.n_holes):
            # pick a random patch center, then clip the patch to the image bounds
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask

        return img
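
# --- Added illustration (sketch, not in the original file) -------------------
# With Cutout(1, 3), as used in get_train_loader above, the integer halving
# makes the zeroed window 2 * (3 // 2) = 2 pixels per side, i.e. a 2x2 patch
# replicated across all channels. A quick check on a dummy tensor:
#
#   img = torch.ones(3, 32, 32)
#   out = Cutout(n_holes=1, length=3)(img)
#   print(int((out == 0).sum()))   # at most 2 * 2 * 3 = 12 zeroed values
# ------------------------------------------------------------------------------
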
class DatasetCL(Dataset):
    def __init__(self, opt, full_dataset=None, transform=None):
        self.dataset = self.random_split(full_dataset=full_dataset, ratio=opt.ratio)
        self.transform = transform
        self.dataLen = len(self.dataset)

    def __getitem__(self, index):
        image = self.dataset[index][0]
        label = self.dataset[index][1]

        if self.transform:
            image = self.transform(image)

        return image, label

    def __len__(self):
        return self.dataLen

    def random_split(self, full_dataset, ratio):
        print('full_train:', len(full_dataset))
        train_size = int(ratio * len(full_dataset))
        drop_size = len(full_dataset) - train_size
        # delegates to torch.utils.data.random_split imported at the top
        train_dataset, drop_dataset = random_split(full_dataset, [train_size, drop_size])
        print('train_size:', len(train_dataset), 'drop_size:', len(drop_dataset))

        return train_dataset


class DatasetBD(Dataset):
    def __init__(self, opt, full_dataset, inject_portion, transform=None, mode="train", device=torch.device("cuda"), distance=1):
        self.dataset = self.addTrigger(full_dataset, opt.target_label, inject_portion, mode, distance, opt.trig_w, opt.trig_h, opt.trigger_type, opt.target_type)
        self.device = device
        self.transform = transform

    def __getitem__(self, item):
        img = self.dataset[item][0]
        label = self.dataset[item][1]
        img = self.transform(img)

        return img, label

    def __len__(self):
        return len(self.dataset)

    def addTrigger(self, dataset, target_label, inject_portion, mode, distance, trig_w, trig_h, trigger_type, target_type):
        print("Generating " + mode + " bad Imgs")
        # indices of the samples chosen for poisoning
        perm = np.random.permutation(len(dataset))[0: int(len(dataset) * inject_portion)]
        # poisoned dataset
        dataset_ = list()

        cnt = 0
        for i in tqdm(range(len(dataset))):
            data = dataset[i]

            # all2one attack: every poisoned sample is relabeled to the target label
            if target_type == 'all2one':

                if mode == 'train':
                    img = np.array(data[0])
                    width = img.shape[0]
                    height = img.shape[1]
                    if i in perm:
                        # select trigger
                        img = self.selectTrigger(img, width, height, distance, trig_w, trig_h, trigger_type)

                        # change target
                        dataset_.append((img, target_label))
                        cnt += 1
                    else:
                        dataset_.append((img, data[1]))

                else:
                    # test mode: drop samples that already carry the target label
                    if data[1] == target_label:
                        continue

                    img = np.array(data[0])
                    width = img.shape[0]
                    height = img.shape[1]
                    if i in perm:
                        img = self.selectTrigger(img, width, height, distance, trig_w, trig_h, trigger_type)

                        dataset_.append((img, target_label))
                        cnt += 1
                    else:
                        dataset_.append((img, data[1]))

            # all2all attack: each poisoned label is shifted to the next class
            elif target_type == 'all2all':

                if mode == 'train':
                    img = np.array(data[0])
                    width = img.shape[0]
                    height = img.shape[1]
                    if i in perm:
                        img = self.selectTrigger(img, width, height, distance, trig_w, trig_h, trigger_type)
                        target_ = self._change_label_next(data[1])

                        dataset_.append((img, target_))
                        cnt += 1
                    else:
                        dataset_.append((img, data[1]))

                else:
                    img = np.array(data[0])
                    width = img.shape[0]
                    height = img.shape[1]
                    if i in perm:
                        img = self.selectTrigger(img, width, height, distance, trig_w, trig_h, trigger_type)

                        target_ = self._change_label_next(data[1])
                        dataset_.append((img, target_))
                        cnt += 1
                    else:
                        dataset_.append((img, data[1]))

            # clean-label attack: only samples that already have the target label are stamped
            elif target_type == 'cleanLabel':

                if mode == 'train':
                    img = np.array(data[0])
                    width = img.shape[0]
                    height = img.shape[1]

                    if i in perm:
                        if data[1] == target_label:

                            img = self.selectTrigger(img, width, height, distance, trig_w, trig_h, trigger_type)

                            dataset_.append((img, data[1]))
                            cnt += 1
                        else:
                            dataset_.append((img, data[1]))
                    else:
                        dataset_.append((img, data[1]))

                else:
                    if data[1] == target_label:
                        continue

                    img = np.array(data[0])
                    width = img.shape[0]
                    height = img.shape[1]
                    if i in perm:
                        img = self.selectTrigger(img, width, height, distance, trig_w, trig_h, trigger_type)

                        dataset_.append((img, target_label))
                        cnt += 1
                    else:
                        dataset_.append((img, data[1]))

        # brief pause so the tqdm bar finishes rendering before the summary prints
        time.sleep(0.01)
        print("Injecting Over: " + str(cnt) + " Bad Imgs, " + str(len(dataset) - cnt) + " Clean Imgs")

        return dataset_

    def _change_label_next(self, label):
        label_new = ((label + 1) % 10)
        return label_new
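
    # --- Added worked example (sketch, not in the original file) -------------
    # The all2all relabeling above hard-codes CIFAR-10's 10 classes in the
    # modulus: each poisoned label moves to the next class, and class 9 wraps
    # around to 0.
    #
    #   labels  = [0, 1, 8, 9]
    #   shifted = [(l + 1) % 10 for l in labels]   # -> [1, 2, 9, 0]
    # --------------------------------------------------------------------------
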
    def selectTrigger(self, img, width, height, distance, trig_w, trig_h, triggerType):

        assert triggerType in ['squareTrigger', 'gridTrigger', 'fourCornerTrigger', 'randomPixelTrigger',
                               'signalTrigger', 'trojanTrigger']

        if triggerType == 'squareTrigger':
            img = self._squareTrigger(img, width, height, distance, trig_w, trig_h)

        elif triggerType == 'gridTrigger':
            img = self._gridTrigger(img, width, height, distance, trig_w, trig_h)

        elif triggerType == 'fourCornerTrigger':
            img = self._fourCornerTrigger(img, width, height, distance, trig_w, trig_h)

        elif triggerType == 'randomPixelTrigger':
            img = self._randomPixelTrigger(img, width, height, distance, trig_w, trig_h)

        elif triggerType == 'signalTrigger':
            img = self._signalTrigger(img, width, height, distance, trig_w, trig_h)

        elif triggerType == 'trojanTrigger':
            img = self._trojanTrigger(img, width, height, distance, trig_w, trig_h)

        else:
            raise NotImplementedError

        return img

    def _squareTrigger(self, img, width, height, distance, trig_w, trig_h):
        # white trig_w x trig_h square, `distance` pixels in from the bottom-right corner
        for j in range(width - distance - trig_w, width - distance):
            for k in range(height - distance - trig_h, height - distance):
                img[j, k] = 255

        return img

    def _gridTrigger(self, img, width, height, distance, trig_w, trig_h):
        # 3x3 black-and-white grid in the bottom-right corner
        img[width - 1][height - 1] = 255
        img[width - 1][height - 2] = 0
        img[width - 1][height - 3] = 255

        img[width - 2][height - 1] = 0
        img[width - 2][height - 2] = 255
        img[width - 2][height - 3] = 0

        img[width - 3][height - 1] = 255
        img[width - 3][height - 2] = 0
        img[width - 3][height - 3] = 0

        # adaptive center trigger
        # alpha = 1
        # img[width - 14][height - 14] = 255 * alpha
        # img[width - 14][height - 13] = 128 * alpha
        # img[width - 14][height - 12] = 255 * alpha
        #
        # img[width - 13][height - 14] = 128 * alpha
        # img[width - 13][height - 13] = 255 * alpha
        # img[width - 13][height - 12] = 128 * alpha
        #
        # img[width - 12][height - 14] = 255 * alpha
        # img[width - 12][height - 13] = 128 * alpha
        # img[width - 12][height - 12] = 128 * alpha

        return img

    def _fourCornerTrigger(self, img, width, height, distance, trig_w, trig_h):
        # stamp the same 3x3 grid pattern into all four corners
        # right bottom
        img[width - 1][height - 1] = 255
        img[width - 1][height - 2] = 0
        img[width - 1][height - 3] = 255

        img[width - 2][height - 1] = 0
        img[width - 2][height - 2] = 255
        img[width - 2][height - 3] = 0

        img[width - 3][height - 1] = 255
        img[width - 3][height - 2] = 0
        img[width - 3][height - 3] = 0

        # left top
        img[1][1] = 255
        img[1][2] = 0
        img[1][3] = 255

        img[2][1] = 0
        img[2][2] = 255
        img[2][3] = 0

        img[3][1] = 255
        img[3][2] = 0
        img[3][3] = 0

        # right top
        img[width - 1][1] = 255
        img[width - 1][2] = 0
        img[width - 1][3] = 255

        img[width - 2][1] = 0
        img[width - 2][2] = 255
        img[width - 2][3] = 0

        img[width - 3][1] = 255
        img[width - 3][2] = 0
        img[width - 3][3] = 0

        # left bottom
        img[1][height - 1] = 255
        img[2][height - 1] = 0
        img[3][height - 1] = 255

        img[1][height - 2] = 0
        img[2][height - 2] = 255
        img[3][height - 2] = 0

        img[1][height - 3] = 255
        img[2][height - 3] = 0
        img[3][height - 3] = 0

        return img

    def _randomPixelTrigger(self, img, width, height, distance, trig_w, trig_h):
        alpha = 0.2
        mask = np.random.randint(low=0, high=256, size=(width, height), dtype=np.uint8)
        blend_img = (1 - alpha) * img + alpha * mask.reshape((width, height, 1))
        # clip before casting, otherwise out-of-range values wrap around in uint8
        blend_img = np.clip(blend_img, 0, 255).astype('uint8')

        # print(blend_img.dtype)
        return blend_img

    def _signalTrigger(self, img, width, height, distance, trig_w, trig_h):
        alpha = 0.2
        # load signal mask (for CIFAR10)
        signal_mask = np.load('trigger/signal_cifar10_mask.npy')
        blend_img = (1 - alpha) * img + alpha * signal_mask.reshape((width, height, 1))
        blend_img = np.clip(blend_img, 0, 255).astype('uint8')

        return blend_img

    def _trojanTrigger(self, img, width, height, distance, trig_w, trig_h):
        # load trojan mask; trg.shape: (3, 32, 32)
        trg = np.load('trigger/best_square_trigger_cifar10.npz')['x']
        trg = np.transpose(trg, (1, 2, 0))
        img_ = np.clip(img + trg, 0, 255).astype('uint8')

        return img_
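
# --- Added smoke test (a sketch, not in the original repo) -------------------
# Builds the fully poisoned CIFAR-10 test split with a gridTrigger and checks
# one batch. The `opt` fields are assumptions mirroring the repo's typical
# settings, not values taken from config.py.
if __name__ == '__main__':
    from types import SimpleNamespace

    opt = SimpleNamespace(dataset='CIFAR10', batch_size=64,
                          inject_portion=1, target_label=0,
                          trig_w=3, trig_h=3,
                          trigger_type='gridTrigger', target_type='all2one')
    _, bad_loader = get_test_loader(opt)
    imgs, labels = next(iter(bad_loader))
    # every sample in the bad loader should carry the attack target label
    print(imgs.shape, labels.unique())   # e.g. torch.Size([64, 3, 32, 32]) tensor([0])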
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------