├── utils
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── util.cpython-36.pyc
│   │   └── __init__.cpython-36.pyc
│   └── util.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── resnet.cpython-36.pyc
│   │   ├── __init__.cpython-35.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── selector.cpython-35.pyc
│   │   ├── selector.cpython-36.pyc
│   │   ├── wresnet.cpython-35.pyc
│   │   └── wresnet.cpython-36.pyc
│   ├── selector.py
│   ├── wresnet.py
│   ├── resnet.py
│   └── lenet.py
├── __pycache__
│   ├── at.cpython-36.pyc
│   ├── config.cpython-36.pyc
│   └── data_loader.cpython-36.pyc
├── trigger
│   ├── signal_cifar10_mask.npy
│   └── best_square_trigger_cifar10.npz
├── weight
│   ├── erasing_net
│   │   └── WRN-16-1.tar
│   ├── s_net
│   │   └── WRN-16-1-S-model_best.pth.tar
│   └── t_net
│       └── WRN-16-1-T-model_best.pth.tar
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   ├── NAD.iml
│   └── workspace.xml
├── at.py
├── results
│   └── results.csv
├── config.py
├── train_badnet.py
├── README.md
├── main.py
└── data_loader.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Data:2020/7/3 20:58
3 | # @Author:lyg
4 |
--------------------------------------------------------------------------------
/__pycache__/at.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/__pycache__/at.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/trigger/signal_cifar10_mask.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/trigger/signal_cifar10_mask.npy
--------------------------------------------------------------------------------
/weight/erasing_net/WRN-16-1.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/weight/erasing_net/WRN-16-1.tar
--------------------------------------------------------------------------------
/__pycache__/data_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/__pycache__/data_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/utils/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/models/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/trigger/best_square_trigger_cifar10.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/trigger/best_square_trigger_cifar10.npz
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/models/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/selector.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/models/__pycache__/selector.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/selector.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/models/__pycache__/selector.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/wresnet.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/models/__pycache__/wresnet.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/wresnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/models/__pycache__/wresnet.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/weight/s_net/WRN-16-1-S-model_best.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/weight/s_net/WRN-16-1-S-model_best.pth.tar
--------------------------------------------------------------------------------
/weight/t_net/WRN-16-1-T-model_best.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/NAD/HEAD/weight/t_net/WRN-16-1-T-model_best.pth.tar
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/NAD.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/at.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | '''
9 | AT with sum of absolute values with power p
10 | code from: https://github.com/AberHu/Knowledge-Distillation-Zoo
11 | '''
class AT(nn.Module):
    '''
    Attention-transfer loss from "Paying More Attention to Attention: Improving
    the Performance of Convolutional Neural Networks via Attention Transfer",
    https://arxiv.org/pdf/1612.03928.pdf

    A feature map is collapsed into an attention map by raising |activation|
    to the power ``p``, summing over dim 1 and dividing by the L2 norm taken
    over dims 2 and 3. The loss is the mean squared error between the student's
    and the teacher's attention maps.
    '''

    def __init__(self, p):
        super(AT, self).__init__()
        # exponent applied to |activation| before the channel sum
        self.p = p

    def forward(self, fm_s, fm_t):
        """Return the MSE between the attention maps of ``fm_s`` and ``fm_t``."""
        student_map = self.attention_map(fm_s)
        teacher_map = self.attention_map(fm_t)
        return F.mse_loss(student_map, teacher_map)

    def attention_map(self, fm, eps=1e-6):
        """Collapse ``fm`` to a normalised single-channel attention map.

        Assumes ``fm`` is 4-D, presumably (N, C, H, W) -- TODO confirm with
        callers. The result keeps dim 1 with size 1. ``eps`` guards against
        division by zero for all-zero maps.
        """
        energy = fm.abs().pow(self.p).sum(dim=1, keepdim=True)
        scale = energy.norm(dim=(2, 3), keepdim=True) + eps
        return energy / scale
--------------------------------------------------------------------------------
/results/results.csv:
--------------------------------------------------------------------------------
1 | epoch,test_clean_acc,test_bad_acc,test_bad_cls_loss,test_bad_at_loss
2 | 0,85.65555555555555,100.0,1.2821621365017362e-07,1.3864030798806084
3 | epoch,test_clean_acc,test_bad_acc,test_bad_cls_loss,test_bad_at_loss
4 | 1,58.43333333333333,9.6,9.003433816697862,1.207139956580268
5 | epoch,test_clean_acc,test_bad_acc,test_bad_cls_loss,test_bad_at_loss
6 | 2,72.25555555555556,8.4,5.509578734503852,1.2736398759418064
7 | epoch,test_clean_acc,test_bad_acc,test_bad_cls_loss,test_bad_at_loss
8 | 3,80.8,3.311111111111111,8.18274724706014,1.2574408405092028
9 | epoch,test_clean_acc,test_bad_acc,test_bad_cls_loss,test_bad_at_loss
10 | 4,81.41111111111111,4.2,7.806607432471381,1.2187182008955213
11 | epoch,test_clean_acc,test_bad_acc,test_bad_cls_loss,test_bad_at_loss
12 | 5,81.74444444444444,4.322222222222222,7.909650793711345,1.189901822090149
13 | epoch,test_clean_acc,test_bad_acc,test_bad_cls_loss,test_bad_at_loss
14 | 6,82.12222222222222,3.577777777777778,8.153920613182915,1.2188563068177964
15 | epoch,test_clean_acc,test_bad_acc,test_bad_cls_loss,test_bad_at_loss
16 | 7,82.26666666666667,4.688888888888889,7.759241249084472,1.1880263636906943
17 | epoch,test_clean_acc,test_bad_acc,test_bad_cls_loss,test_bad_at_loss
18 | 8,82.16666666666667,4.788888888888889,7.93371855629815,1.1945802669525147
19 | epoch,test_clean_acc,test_bad_acc,test_bad_cls_loss,test_bad_at_loss
20 | 9,81.75555555555556,5.322222222222222,7.186521497938368,1.181448416603936
21 | epoch,test_clean_acc,test_bad_acc,test_bad_cls_loss,test_bad_at_loss
22 | 10,82.7,4.4,8.051894421895344,1.204457118988037
23 | epoch,test_clean_acc,test_bad_acc,test_bad_cls_loss,test_bad_at_loss
24 | 11,82.33333333333333,4.2444444444444445,8.007042150709363,1.1921954712337919
25 |
--------------------------------------------------------------------------------
/models/selector.py:
--------------------------------------------------------------------------------
1 | from models.wresnet import *
2 | from models.resnet import *
3 | import os
4 |
def select_model(dataset,
                 model_name,
                 pretrained=False,
                 pretrained_models_path=None,
                 n_classes=10):
    """Instantiate a backbone by name and optionally restore its weights.

    Args:
        dataset: accepted for API compatibility; not used here.
        model_name: 'WRN-16-1', 'WRN-16-2', 'WRN-40-1', 'WRN-40-2',
            'WRN-10-2', 'WRN-10-1' (WideResNet depth/width) or 'ResNet34'.
            NOTE(review): 'ResNet34' actually builds ``resnet(depth=32)``.
        pretrained: if true, load ``checkpoint['state_dict']`` from
            ``pretrained_models_path`` (tensors mapped to the CPU).
        pretrained_models_path: checkpoint file path; needed when
            ``pretrained`` is true.
        n_classes: number of classifier outputs.

    Returns:
        The constructed model.

    Raises:
        AssertionError: ``model_name`` is not in the supported list
            (``NotImplementedError`` instead when asserts are disabled).
        KeyError: the checkpoint lacks 'state_dict' or 'epoch'.
    """
    assert model_name in ['WRN-16-1', 'WRN-16-2', 'WRN-40-1', 'WRN-40-2', 'ResNet34', 'WRN-10-2', 'WRN-10-1']

    # (depth, widen_factor) of every supported Wide-ResNet variant
    wrn_specs = {
        'WRN-16-1': (16, 1),
        'WRN-16-2': (16, 2),
        'WRN-40-1': (40, 1),
        'WRN-40-2': (40, 2),
        'WRN-10-2': (10, 2),
        'WRN-10-1': (10, 1),
    }
    if model_name in wrn_specs:
        depth, widen = wrn_specs[model_name]
        model = WideResNet(depth=depth, num_classes=n_classes, widen_factor=widen, dropRate=0)
    elif model_name == 'ResNet34':
        model = resnet(depth=32, num_classes=n_classes)
    else:
        # only reachable when the assert above is stripped (python -O)
        raise NotImplementedError

    if not pretrained:
        return model

    model_path = os.path.join(pretrained_models_path)
    print('Loading Model from {}'.format(model_path))
    checkpoint = torch.load(model_path, map_location='cpu')
    print(checkpoint.keys())
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {}) ".format(model_path, checkpoint['epoch']))
    return model
41 |
if __name__ == '__main__':

    import torch
    from torchsummary import summary
    import random
    import time

    random.seed(1234)  # torch transforms use this seed
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)

    # Random CIFAR-sized batch. A plain tensor replaces the deprecated
    # torch.autograd.Variable wrapper.
    support_x_task = torch.FloatTensor(64, 3, 32, 32).uniform_(0, 1)

    t0 = time.time()
    model = select_model('CIFAR10', model_name='WRN-16-2')
    # WideResNet.forward returns (activation1, activation2, activation3, logits).
    # Unpacking into two names raised "too many values to unpack".
    act1, act2, act3, output = model(support_x_task)
    print("Time taken for forward pass: {} s".format(time.time() - t0))
    print("\nOUTPUT SHAPE: ", output.shape)
    # The model was never moved to a GPU. torchsummary's summary() defaults
    # to device="cuda", so request the CPU explicitly.
    summary(model, (3, 32, 32), device='cpu')
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
def get_arguments():
    """Return an (unparsed) argparse parser holding every NAD option.

    Callers typically do ``get_arguments().parse_args()``.
    """
    parser = argparse.ArgumentParser()

    # Each entry is (flag, type, default, help). Registration follows list
    # order, which is also the order options appear in --help.
    options = [
        # various paths
        ('--checkpoint_root', str, './weight/erasing_net', 'models weight are saved here'),
        ('--log_root', str, './results', 'logs are saved here'),
        ('--dataset', str, 'CIFAR10', 'name of image dataset'),
        ('--s_model', str, './weight/s_net/WRN-16-1-S-model_best.pth.tar', 'path of student model'),
        ('--t_model', str, './weight/t_net/WRN-16-1-T-model_best.pth.tar', 'path of teacher model'),

        # training hyper-parameters
        ('--print_freq', int, 50, 'frequency of showing training results on console'),
        ('--epochs', int, 20, 'number of total epochs to run'),
        ('--batch_size', int, 64, 'The size of batch'),
        ('--lr', float, 0.1, 'initial learning rate'),
        ('--momentum', float, 0.9, 'momentum'),
        ('--weight_decay', float, 1e-4, 'weight decay'),
        ('--num_class', int, 10, 'number of classes'),
        ('--ratio', float, 0.05, 'ratio of training data'),
        ('--beta1', int, 500, 'beta of low layer'),
        ('--beta2', int, 1000, 'beta of middle layer'),
        ('--beta3', int, 1000, 'beta of high layer'),
        ('--p', float, 2.0, 'power for AT'),
        ('--threshold_clean', float, 70.0, 'threshold of save weight'),
        ('--threshold_bad', float, 90.0, 'threshold of save weight'),
        ('--cuda', int, 1, None),
        ('--device', str, 'cuda', None),
        ('--save', int, 1, None),

        # others
        ('--seed', int, 2, 'random seed'),
        ('--note', str, 'try', 'note for this run'),

        # network and dataset choice
        ('--data_name', str, 'CIFAR10', 'name of dataset'),
        ('--t_name', str, 'WRN-16-1', 'name of teacher'),
        ('--s_name', str, 'WRN-16-1', 'name of student'),

        # backdoor attacks
        ('--inject_portion', float, 0.1, 'ratio of backdoor samples'),
        ('--target_label', int, 5, 'class of target label'),
        ('--trigger_type', str, 'gridTrigger', 'type of backdoor trigger'),
        ('--target_type', str, 'all2one', 'type of backdoor label'),
        ('--trig_w', int, 3, 'width of trigger pattern'),
        ('--trig_h', int, 3, 'height of trigger pattern'),
    ]

    for flag, kind, default, text in options:
        parser.add_argument(flag, type=kind, default=default, help=text)

    return parser
50 |
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import os
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 |
class AverageMeter(object):
    """Keep the most recent value and a count-weighted running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running mean, weighted sum and count."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean.

        ``n`` is the weight (e.g. batch size); ``n=0`` on a fresh meter divides by zero.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
23 |
24 |
def print_network(net):
    """Print ``net``'s repr followed by the total number of parameter elements."""
    total = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
31 |
32 |
def load_pretrained_model(model, pretrained_dict, wfc=True):
    """Copy matching entries of ``pretrained_dict`` into ``model``.

    Keys of ``pretrained_dict`` that are absent from ``model.state_dict()``
    are ignored. When ``wfc`` is false, any key containing the substring
    'fc' is also skipped, so those entries keep the model's current values.
    The merged state dict is then loaded with ``model.load_state_dict``.
    """
    merged = model.state_dict()

    def keep(key):
        return key in merged and (wfc or 'fc' not in key)

    merged.update({key: value for key, value in pretrained_dict.items() if keep(key)})
    model.load_state_dict(merged)
44 |
45 |
def transform_time(s):
    """Split a duration of ``s`` seconds into ``(hours, minutes, seconds)``."""
    minutes, seconds = divmod(s, 60)
    hours, minutes = divmod(minutes, 60)
    return hours, minutes, seconds
50 |
51 |
def accuracy(output, target, topk=(1,)):
    """Compute precision@k (in percent) for every k in ``topk``.

    Args:
        output: score tensor; ``topk`` is taken along dim 1, so presumably
            (batch, num_classes) -- TODO confirm with callers.
        target: ground-truth labels, one per row of ``output``.
        topk: k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k, equal to
        100 * (#rows whose target is among the k highest scores) / batch_size.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # After the transpose, pred[j, i] is the j-th highest-scoring class of row i.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: pred.t() is non-contiguous, so correct[:k] cannot
        # always be viewed as 1-D (view raises for k > 1).
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
66 |
67 |
def adjust_learning_rate(optimizer, epoch, lr):
    """Apply a step learning-rate schedule to every param group of ``optimizer``.

    epoch < 2: the supplied ``lr``; 2 <= epoch < 20: 0.01; epoch >= 20: 0.0001.
    Only the first phase depends on ``lr``; later phases are absolute values.
    """
    if epoch >= 20:
        lr = 0.0001
    elif epoch >= 2:
        lr = 0.01
    print('epoch: {} lr: {:.4f}'.format(epoch, lr))
    for group in optimizer.param_groups:
        group['lr'] = lr
80 |
81 |
def save_checkpoint(state, is_best, fdir, model_name):
    """Save ``state`` to ``<fdir>/<model_name>.tar``, but only when ``is_best``.

    Nothing is written (and nothing printed) when ``is_best`` is false.
    """
    target = os.path.join(fdir, model_name + '.tar')
    if not is_best:
        return
    torch.save(state, target)
    print('[info] save best model')
87 |
88 |
def save_history(cls_orig_acc, clease_trig_acc, cls_trig_loss, at_trig_loss, at_epoch_list, logs_dir):
    """Write per-epoch metrics to a CSV file.

    Despite its name, ``logs_dir`` is used directly as the output file path.
    Columns, in order: epoch, cls_orig_acc, clease_trig_acc, cls_trig_loss,
    at_trig_loss.
    """
    columns = {
        'epoch': at_epoch_list,
        'cls_orig_acc': cls_orig_acc,
        'clease_trig_acc': clease_trig_acc,
        'cls_trig_loss': cls_trig_loss,
        'at_trig_loss': at_trig_loss,
    }
    history = pd.DataFrame(columns)
    # Store the DataFrame as CSV; index=False omits the row labels
    # (pandas writes them by default).
    history.to_csv(logs_dir, index=False, sep=',')
94 |
def plot_curve(clean_acc, bad_acc, epochs, dataset_name):
    """Plot clean accuracy and attack success rate per epoch, then show the figure.

    Both series are plotted against x = 0..epochs, so each is expected to
    hold ``epochs + 1`` values.
    """
    n_points = epochs + 1
    xs = np.arange(0, n_points)
    plt.style.use("ggplot")
    plt.figure()
    curves = (
        (clean_acc, "Classification Accuracy", 'D', 'blue'),
        (bad_acc, "Attack Success Rate", 'o', 'red'),
    )
    for values, label, marker, color in curves:
        plt.plot(xs, values, label=label, marker=marker, color=color)
    plt.title(dataset_name)
    plt.xlabel("Epoch")
    plt.ylabel("Student Model Accuracy/Attack Success Rate(%)")
    plt.xticks(range(0, n_points, 1))
    plt.yticks(range(0, 101, 20))
    plt.legend()
    plt.show()
--------------------------------------------------------------------------------
/models/wresnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from https://github.com/xternalz/WideResNet-pytorch
3 | Modifications = return activations for use in attention transfer,
4 | as done before e.g in https://github.com/BayesWatch/pytorch-moonshine
5 | """
6 |
7 | import math
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 |
class BasicBlock(nn.Module):
    """Pre-activation residual unit used by WideResNet.

    Main branch: BN -> ReLU -> conv3x3(stride) -> BN -> ReLU -> [dropout]
    -> conv3x3. Shortcut: the raw input when channel counts match; otherwise
    a strided 1x1 conv applied to the pre-activated (BN -> ReLU) input.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        # Construction order is deliberately unchanged: module construction may
        # consume the global RNG, so reordering could alter seeded weights.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # A 1x1 projection is only needed when the channel count changes.
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                          padding=0, bias=False)

    def forward(self, x):
        # bn1 returns a new tensor, so the in-place ReLU never touches ``x``.
        pre = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(pre)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        shortcut = x if self.equalInOut else self.convShortcut(pre)
        return torch.add(shortcut, out)
38 |
class NetworkBlock(nn.Module):
    """A stage of ``nb_layers`` residual blocks wrapped in ``nn.Sequential``.

    Only the first block receives ``in_planes`` and ``stride``; the remaining
    blocks map ``out_planes -> out_planes`` with stride 1.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Build ``int(nb_layers)`` blocks in order (floats are truncated)."""
        layers = []
        for i in range(int(nb_layers)):
            # Conditional expressions replace the old ``i == 0 and a or b``
            # idiom, which silently chose ``b`` whenever ``a`` was falsy.
            block_in = in_planes if i == 0 else out_planes
            block_stride = stride if i == 0 else 1
            layers.append(block(block_in, out_planes, block_stride, dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)
52 |
class WideResNet(nn.Module):
    """Wide residual network WRN-<depth>-<widen_factor>.

    Layout: 3x3 stem conv (16 channels) -> three NetworkBlock stages with
    16k / 32k / 64k channels (k = widen_factor; strides 1, 2, 2) -> BN ->
    ReLU -> global average pooling -> linear classifier.

    forward() returns ``(activation1, activation2, activation3, logits)``,
    the activations being the outputs of block1, block2 and block3 (used for
    attention transfer).
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert (depth - 4) % 6 == 0, 'WideResNet depth must be 6n+4, e.g. 10, 16, 40'
        # residual blocks per stage
        n = (depth - 4) // 6
        block = BasicBlock
        # Construction order is unchanged so seeded initial weights stay the same.
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        # He-normal init for convs (fan-out based), unit/zero BN affine params,
        # zero classifier bias; the Linear weight keeps its default init.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        activation1 = out
        out = self.block2(out)
        activation2 = out
        out = self.block3(out)
        activation3 = out
        out = self.relu(self.bn1(out))
        # Pool over the whole remaining spatial extent. For 32x32 inputs this is
        # the same 8x8 window the code used to hard-code. For other input sizes
        # that fixed window either errored or left extra spatial cells, which
        # the old view(-1, nChannels) folded into the batch dimension,
        # misaligning logits with samples.
        out = F.avg_pool2d(out, kernel_size=out.shape[2:])
        out = out.view(out.size(0), -1)
        return activation1, activation2, activation3, self.fc(out)
98 |
99 |
if __name__ == '__main__':
    import random
    import time
    # from torchsummary import summary

    random.seed(1234)  # torch transforms use this seed
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)

    x = torch.FloatTensor(64, 3, 32, 32).uniform_(0, 1)

    ### WideResNets
    # Notation: W-depth-wideningfactor. Uncomment exactly one configuration;
    # previously four models were built and all but the last were discarded.
    # model = WideResNet(depth=16, num_classes=10, widen_factor=1, dropRate=0.0)
    # model = WideResNet(depth=16, num_classes=10, widen_factor=2, dropRate=0.0)
    # model = WideResNet(depth=16, num_classes=10, widen_factor=8, dropRate=0.0)
    # model = WideResNet(depth=16, num_classes=10, widen_factor=10, dropRate=0.0)
    # model = WideResNet(depth=22, num_classes=10, widen_factor=8, dropRate=0.0)
    # model = WideResNet(depth=34, num_classes=10, widen_factor=2, dropRate=0.0)
    # model = WideResNet(depth=40, num_classes=10, widen_factor=10, dropRate=0.0)
    # model = WideResNet(depth=40, num_classes=10, widen_factor=1, dropRate=0.0)
    model = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.0)
    # model = WideResNet(depth=50, num_classes=10, widen_factor=2, dropRate=0.0)

    t0 = time.time()
    # forward() returns (activation1, activation2, activation3, logits); the
    # logits are the LAST element (the old code printed activation1's shape).
    _, __, ___, output = model(x)
    print("Time taken for forward pass: {} s".format(time.time() - t0))
    print("\nOUTPUT SHAPE: ", output.shape)

    # summary(model, input_size=(3, 32, 32))
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Data:2020/7/14 17:37
3 | # @Author:lyg
4 |
5 | from __future__ import absolute_import
6 |
7 | '''Resnet for cifar dataset.
8 | Ported form
9 | https://github.com/facebook/fb.resnet.torch
10 | and
11 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
12 | (c) YANG, Wei
13 | '''
14 | import torch.nn as nn
15 | import math
16 |
17 | __all__ = ['resnet']
18 |
19 |
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1.

    With padding 1 the spatial size is kept for stride 1 and becomes
    ceil(H / 2) for stride 2.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
24 |
25 |
class BasicBlock(nn.Module):
    """Post-activation residual block.

    (conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN) + shortcut -> ReLU, where
    the shortcut is ``downsample(x)`` if given, else ``x`` itself.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Construction order unchanged (module construction may consume the RNG).
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # optional module reshaping ``x`` to the main branch's output; None = identity
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        # in-place add onto bn2's fresh output; ``x`` is not modified
        out += shortcut
        return self.relu(out)
56 |
57 |
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand x4.

    Each conv is followed by BN; ReLU follows the first two and the residual
    sum. The shortcut is ``downsample(x)`` if given, else ``x``.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        widened = planes * 4
        # Construction order unchanged (module construction may consume the RNG).
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
95 |
96 |
class ResNet(nn.Module):
    """CIFAR-style ResNet: stem conv, three stages of 16/32/64 base channels, avg-pool, linear.

    forward() returns ``(activation1, activation2, activation3, logits)``,
    the activations being the outputs of layer1, layer2 and layer3.
    """

    def __init__(self, depth, num_classes=1000, block_name='BasicBlock'):
        super(ResNet, self).__init__()
        # Blocks per stage follow from the depth and the block type.
        kind = block_name.lower()
        if kind == 'basicblock':
            assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
            blocks_per_stage = (depth - 2) // 6
            block = BasicBlock
        elif kind == 'bottleneck':
            assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
            blocks_per_stage = (depth - 2) // 9
            block = Bottleneck
        else:
            raise ValueError('block_name shoule be Basicblock or Bottleneck')

        # Construction order unchanged (module construction may consume the RNG).
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, blocks_per_stage)
        self.layer2 = self._make_layer(block, 32, blocks_per_stage, stride=2)
        self.layer3 = self._make_layer(block, 64, blocks_per_stage, stride=2)
        # Fixed 8x8 window: the size of layer3's output for a 32x32 input.
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        # He-normal (fan-out) conv init; unit weight / zero bias for BN.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first may change stride or width.

        Updates ``self.inplanes`` to the stage's output width.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # 1x1 projection so the shortcut matches the first block's output
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Spatial sizes below assume a 32x32 input.
        x = self.relu(self.bn1(self.conv1(x)))      # 32x32
        activation1 = self.layer1(x)                # 32x32
        activation2 = self.layer2(activation1)      # 16x16
        activation3 = self.layer3(activation2)      # 8x8

        pooled = self.avgpool(activation3)
        logits = self.fc(pooled.view(pooled.size(0), -1))
        return activation1, activation2, activation3, logits
166 |
167 |
def resnet(**kwargs):
    """Build a ResNet; every keyword argument is forwarded unchanged to ``ResNet``."""
    model = ResNet(**kwargs)
    return model
--------------------------------------------------------------------------------
/models/lenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class View(nn.Module):
    """Module wrapper around ``Tensor.view`` so a reshape can live inside ``nn.Sequential``."""

    def __init__(self, size):
        super(View, self).__init__()
        # target shape, passed verbatim to Tensor.view (may contain -1)
        self.size = size

    def forward(self, tensor):
        target_shape = self.size
        return tensor.view(target_shape)
16 |
class LeNet5(nn.Module):
    """Small conv net for SVHN/CIFAR experiments.

    Two (conv3x3 -> max-pool 2 -> ReLU) stages, then fc 2304->128 with ReLU and
    dropout, then fc 128->n_classes. The flatten size 64*6*6 implies a 32x32
    input. forward() returns ``(activation, logits)``, where ``activation`` is
    the output of the second conv stage.
    """

    def __init__(self, n_classes):
        super(LeNet5, self).__init__()
        self.n_classes = n_classes
        # Construction order unchanged (module construction may consume the RNG).
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64*6*6, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        feats = x
        for conv in (self.conv1, self.conv2):
            feats = F.relu(F.max_pool2d(conv(feats), 2))
        activation = feats
        hidden = F.relu(self.fc1(feats.view(-1, 64*6*6)))
        # default dropout probability; active only in training mode
        hidden = F.dropout(hidden, training=self.training)
        return activation, self.fc2(hidden)
45 |
46 |
47 |
class LeNet7_T(nn.Module):
    """Four-conv network for single-channel inputs (labelled "SVHN/MNIST experiments").

    conv1 takes 1 input channel, and fc1 assumes a 4x4 final map, which is
    what a 28x28 input produces. ``forward`` returns ``(activation, logits)``,
    where ``activation`` is the output of the last conv stage.
    """

    def __init__(self, n_classes):
        super(LeNet7_T, self).__init__()
        self.n_classes = n_classes
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3)
        # 28 -> 26 -> 24 -> pool 12 -> 10 -> 8 -> pool 4, hence 64*4*4.
        self.fc1 = nn.Linear(64 * 4 * 4, 200)
        self.fc2 = nn.Linear(200, n_classes)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(F.max_pool2d(self.conv2(out), 2))
        out = F.relu(self.conv3(out))
        out = F.relu(F.max_pool2d(self.conv4(out), 2))
        activation = out
        # Flatten per sample so a wrong input size fails in fc1 instead of
        # silently changing the batch size.
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.dropout(out, training=self.training)
        out = self.fc2(out)
        return activation, out
78 |
79 |
class LeNet7_S(nn.Module):
    """Wider four-conv network for single-channel inputs (labelled "SVHN/MNIST experiments").

    Same layout as ``LeNet7_T`` with twice the channels. fc1 assumes a 4x4
    final map (a 28x28 input). ``forward`` returns ``(activation, logits)``.
    """

    def __init__(self, n_classes):
        super(LeNet7_S, self).__init__()

        self.n_classes = n_classes
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3)
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, n_classes)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(F.max_pool2d(self.conv2(out), 2))
        out = F.relu(self.conv3(out))
        out = F.relu(F.max_pool2d(self.conv4(out), 2))
        activation = out
        # Flatten per sample so a wrong input size fails in fc1 instead of
        # silently changing the batch size.
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.dropout(out, training=self.training)
        out = self.fc2(out)
        return activation, out
110 |
class trojan_model(nn.Module):
    """Four-conv network used to train a trojaned model on 3-channel inputs.

    fc1 assumes a 5x5 final map, which is what a 32x32 input produces
    (32 -> 30 -> 28 -> pool 14 -> 12 -> 10 -> pool 5). ``forward`` returns
    ``(activation1, activation2, activation3, logits)``, where the activations
    come after the first pooled stage, after conv3, and after the second pooled stage.
    """

    def __init__(self, n_classes):
        super(trojan_model, self).__init__()

        self.n_classes = n_classes
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3)
        self.fc1 = nn.Linear(128 * 5 * 5, 256)
        self.fc2 = nn.Linear(256, n_classes)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(F.max_pool2d(self.conv2(out), 2))
        activation1 = out
        out = F.relu(self.conv3(out))
        activation2 = out
        out = F.relu(F.max_pool2d(self.conv4(out), 2))
        activation3 = out
        # Flatten per sample so a wrong input size fails in fc1 instead of
        # silently changing the batch size.
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.dropout(out, training=self.training)
        out = self.fc2(out)
        return activation1, activation2, activation3, out
143 |
144 | #
145 | # if __name__ == '__main__':
146 | # import random
147 | # import sys
148 | # # from torchsummary import summary
149 | #
150 | # random.seed(1234) # torch transforms use this seed
151 | # torch.manual_seed(1234)
152 | # torch.cuda.manual_seed(1234)
153 | #
154 | # ### LENET5
155 | # x = torch.FloatTensor(64, 3, 32, 32).uniform_(0, 1)
156 | # true_labels = torch.tensor([[2.], [3], [1], [8], [4]], requires_grad=True)
157 | # model = LeNet5(n_classes=10)
158 | # output, act = model(x)
159 | # print("\nOUTPUT SHAPE: ", output.shape)
160 | #
161 | # # summary(model, input_size=(3,32,32))
162 |
163 |
--------------------------------------------------------------------------------
/train_badnet.py:
--------------------------------------------------------------------------------
1 | from models.selector import *
2 | from utils.util import *
3 | from data_loader import get_test_loader, get_backdoor_loader
4 | from config import get_arguments
5 |
6 |
def train_step(opt, train_loader, nets, optimizer, criterions, epoch):
    """Run one epoch of plain cross-entropy training for the student network.

    Args:
        opt: parsed options; reads ``cuda`` and ``print_freq``.
        train_loader: iterable of ``(img, target)`` batches.
        nets: dict holding the model under ``'snet'``.
        optimizer: optimizer over the student's parameters.
        criterions: dict holding the classification loss under ``'criterionCls'``.
        epoch: epoch number, used only in progress messages.
    """
    loss_meter, top1_meter, top5_meter = AverageMeter(), AverageMeter(), AverageMeter()
    model = nets['snet']
    loss_fn = criterions['criterionCls']
    model.train()

    for step, (images, labels) in enumerate(train_loader, start=1):
        if opt.cuda:
            images, labels = images.cuda(), labels.cuda()

        # The model returns three intermediate activations plus the logits.
        _, _, _, logits = model(images)
        loss = loss_fn(logits, labels)

        batch_size = images.size(0)
        prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
        loss_meter.update(loss.item(), batch_size)
        top1_meter.update(prec1.item(), batch_size)
        top5_meter.update(prec5.item(), batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % opt.print_freq == 0:
            print('Epoch[{0}]:[{1:03}/{2:03}] '
                  'cls_loss:{losses.val:.4f}({losses.avg:.4f}) '
                  'prec@1:{top1.val:.2f}({top1.avg:.2f}) '
                  'prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(
                      epoch, step, len(train_loader),
                      losses=loss_meter, top1=top1_meter, top5=top5_meter))
40 |
41 |
def test(opt, test_clean_loader, test_bad_loader, nets, criterions, epoch):
    """Evaluate the student on the clean and the triggered test sets and log the result.

    Args:
        opt: options; reads ``cuda`` and ``log_root``.
        test_clean_loader: batches of untouched test images.
        test_bad_loader: batches of triggered test images with poisoned labels.
        nets: dict holding the student under ``'snet'``.
        criterions: dict holding ``'criterionCls'``.
        epoch: epoch number written to the log.

    Returns:
        ``(acc_clean, acc_bd)``, where ``acc_clean = [top1, top5]`` on clean data
        and ``acc_bd = [top1, top5, mean cls loss]`` on triggered data.

    Side effects:
        Appends one row to ``<opt.log_root>/backdoor_results.csv``. The header
        row is written only when the file does not exist yet.
    """
    import os

    snet = nets['snet']
    criterionCls = criterions['criterionCls']
    snet.eval()

    def _to_device(img, target):
        # Honour opt.cuda (as train_step does) instead of always calling .cuda().
        if opt.cuda:
            return img.cuda(), target.cuda()
        return img, target

    # Clean test set: accuracy only.
    top1 = AverageMeter()
    top5 = AverageMeter()
    for img, target in test_clean_loader:
        img, target = _to_device(img, target)
        with torch.no_grad():
            _, _, _, output_s = snet(img)
        prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
        top1.update(prec1.item(), img.size(0))
        top5.update(prec5.item(), img.size(0))
    acc_clean = [top1.avg, top5.avg]

    # Triggered test set: accuracy against the poisoned labels, plus the loss.
    cls_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    for img, target in test_bad_loader:
        img, target = _to_device(img, target)
        with torch.no_grad():
            _, _, _, output_s = snet(img)
            cls_loss = criterionCls(output_s, target)
        prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
        cls_losses.update(cls_loss.item(), img.size(0))
        top1.update(prec1.item(), img.size(0))
        top5.update(prec5.item(), img.size(0))
    acc_bd = [top1.avg, top5.avg, cls_losses.avg]

    print('[clean]Prec@1: {:.2f}'.format(acc_clean[0]))
    print('[bad]Prec@1: {:.2f}'.format(acc_bd[0]))

    # Save training progress; write the header only once so the CSV stays parseable.
    log_root = opt.log_root + '/backdoor_results.csv'
    df = pd.DataFrame([(epoch, acc_clean[0], acc_bd[0], acc_bd[2])], columns=(
        "epoch", "test_clean_acc", "test_bad_acc", "test_bad_cls_loss"))
    df.to_csv(log_root, mode='a', index=False, encoding='utf-8',
              header=not os.path.exists(log_root))

    return acc_clean, acc_bd
96 |
97 |
def train(opt):
    """Train a (backdoored) student model from scratch on the poisoned training set.

    Runs ``opt.epochs`` epochs, numbered 1..opt.epochs. After each epoch it
    evaluates on the clean and triggered test sets and, if ``opt.save``,
    writes a checkpoint.
    """
    print('----------- Network Initialization --------------')
    student = select_model(dataset=opt.data_name,
                           model_name=opt.s_name,
                           pretrained=False,
                           pretrained_models_path=opt.s_model,
                           n_classes=opt.num_class).to(opt.device)
    print('finished student model init...')

    nets = {'snet': student}

    optimizer = torch.optim.SGD(student.parameters(),
                                lr=opt.lr,
                                momentum=opt.momentum,
                                weight_decay=opt.weight_decay,
                                nesterov=True)

    # Loss functions; built once rather than every epoch.
    criterionCls = nn.CrossEntropyLoss()
    if opt.cuda:
        criterionCls = criterionCls.cuda()
    criterions = {'criterionCls': criterionCls}

    print('----------- DATA Initialization --------------')
    train_loader = get_backdoor_loader(opt)
    test_clean_loader, test_bad_loader = get_test_loader(opt)

    print('----------- Train Initialization --------------')
    # Epochs are 1-based; range(1, opt.epochs) would run only opt.epochs - 1 of them.
    for epoch in range(1, opt.epochs + 1):
        _adjust_learning_rate(optimizer, epoch, opt.lr)

        train_step(opt, train_loader, nets, optimizer, criterions, epoch)

        print('testing the models......')
        acc_clean, acc_bad = test(opt, test_clean_loader, test_bad_loader, nets, criterions, epoch)

        if opt.save:
            # NOTE(review): the threshold only ever decreases (min), so after one
            # low-ASR epoch almost every later epoch is flagged "best". Confirm
            # whether max() (best attack success rate so far) was intended.
            is_best = acc_bad[0] > opt.threshold_bad
            opt.threshold_bad = min(acc_bad[0], opt.threshold_bad)

            s_name = opt.s_name + '-S-model_best.pth'
            save_checkpoint({
                'epoch': epoch,
                'state_dict': student.state_dict(),
                'best_clean_acc': acc_clean[0],
                'best_bad_acc': acc_bad[0],
                'optimizer': optimizer.state_dict(),
            }, is_best, opt.checkpoint_root, s_name)
156 |
157 |
def _adjust_learning_rate(optimizer, epoch, lr):
    """Apply train_badnet's step schedule to every parameter group.

    Epochs below 21 use ``lr``, epochs 21-29 use ``0.01 * lr``, and later
    epochs use the fixed value 0.0009, which does not depend on ``lr``.
    """
    new_lr = lr if epoch < 21 else (0.01 * lr if epoch < 30 else 0.0009)
    print('epoch: {} lr: {:.4f}'.format(epoch, new_lr))
    for group in optimizer.param_groups:
        group['lr'] = new_lr
168 |
def main():
    """Parse the command-line options and run backdoor training."""
    train(get_arguments().parse_args())


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Attention Distillation
2 |
3 | This is an implementation demo of the ICLR 2021 paper **[Neural Attention Distillation: Erasing Backdoor Triggers from Deep Neural Networks](https://openreview.net/pdf?id=9l0K4OM-oXE)** in PyTorch.
4 |
5 | 
6 | 
7 | 
8 | 
9 |
10 | ## NAD: Quick start with pretrained model
11 | We have uploaded the `all2one` pretrained backdoored student model (i.e., a WRN-16-1 attacked with gridTrigger, target label 5) and the clean teacher model (i.e., WRN-16-1) to `./weight/s_net` and `./weight/t_net`, respectively.
12 |
13 | For evaluating the performance of NAD, you can easily run command:
14 |
15 | ```bash
16 | $ python main.py
17 | ```
18 | where the default parameters are shown in `config.py`.
19 |
20 | The trained model will be saved to `weight/erasing_net/<s_name>.tar` (e.g., `weight/erasing_net/WRN-16-1.tar`).
21 |
22 | Please read `main.py` and `config.py` carefully, then adjust the parameters for your experiment.
23 |
24 | ### Erasing Results on BadNets
25 | - Data augmentation used for Finetuning and NAD in the table below:
26 | ```
27 | tf_train = transforms.Compose([
28 | transforms.RandomCrop(32, padding=4),
29 | transforms.RandomHorizontalFlip(),
30 | transforms.ToTensor()
31 | ])
32 | ```
33 |
34 | | Dataset | Baseline ACC | Baseline ASR | Finetuning ACC | Finetuning ASR | NAD ACC | NAD ASR |
35 | | -------- | ------------ | ------------ | ------- | ------- | ------- |------- |
36 | | CIFAR-10 | 85.65 | 100.0 | 82.32 | 18.13 | 82.12 | **3.57** |
37 |
38 | ---
39 |
40 | ## Training your own backdoored model
41 | We provide a `DatasetBD` class in `data_loader.py` for generating training sets for different backdoor attacks.
42 |
43 | To run a backdoor attack (e.g., the GridTrigger attack), use the command below:
44 |
45 | ```bash
46 | $ python train_badnet.py
47 | ```
48 | This command trains the backdoored model and prints the clean accuracy and the attack success rate. You can also select the other backdoor triggers reported in the paper.
49 |
50 | Please read `train_badnet.py` and `config.py` carefully, then adjust the parameters for your experiment.
51 |
52 | ## How to get teacher model?
53 | We obtained the teacher model by finetuning all layers of the backdoored model on 5% of the clean data with data augmentation. In our paper, we finetune the backdoored model for only 5–10 epochs. Please see Section 4.1 and Appendix A for more details of our experimental settings. The finetuning code is easy to obtain: set all the `beta` parameters (`beta1`, `beta2`, `beta3`) to 0, which makes the distillation loss zero during training.
54 |
55 | ## Other source of backdoor attacks
56 | #### Attack
57 |
58 | **CL:** Clean-label backdoor attacks
59 |
60 | - [Paper](https://people.csail.mit.edu/madry/lab/cleanlabel.pdf)
61 | - [pytorch implementation](https://github.com/hkunzhe/label_consistent_attacks_pytorch)
62 |
63 | **SIG:** A New Backdoor Attack in CNNs by Training Set Corruption Without Label Poisoning
64 |
65 | - [Paper](https://ieeexplore.ieee.org/document/8802997/footnotes)
66 |
67 | ```python
68 | ## reference code
def plant_sin_trigger(img, delta=20, f=6, debug=False):
    """
    Implement paper:
    > Barni, M., Kallas, K., & Tondi, B. (2019).
    > A new Backdoor Attack in CNNs by training set corruption without label poisoning.
    > arXiv preprint arXiv:1902.11237
    superimposed sinusoidal backdoor signal with default parameters

    The signal ``delta * sin(2*pi*j*f/m)`` depends only on the column index
    ``j`` (``m`` = image width). It is therefore computed once per column and
    broadcast over rows and channels. Works for (H, W, C) and (H, W) images.

    Args:
        img: image array (e.g. uint8 HWC).
        delta: amplitude of the sinusoid.
        f: number of periods across the image width.
        debug: currently unused.

    Returns:
        uint8 image of the same shape as ``img``.
    """
    alpha = 0.2
    img = np.float32(img)
    m = img.shape[1]
    wave = delta * np.sin(2 * np.pi * np.arange(m) * f / m)
    pattern = np.zeros_like(img)
    # Broadcast the per-column wave over rows (and channels, when present).
    pattern[...] = wave.reshape((1, m) + (1,) * (img.ndim - 2))

    # NOTE(review): alpha = 0.2 keeps 20% of the image and 80% of the pattern,
    # although the docstring describes a *superimposed* signal. Confirm the
    # intended blend (e.g. img + pattern).
    img = alpha * np.uint32(img) + (1 - alpha) * pattern
    img = np.uint8(np.clip(img, 0, 255))

    return img
94 | ```
95 |
96 | **Refool**: Reflection Backdoor: A Natural Backdoor Attack on Deep Neural Networks
97 |
98 | - [Paper](https://arxiv.org/abs/2007.02343)
99 | - [Code](https://github.com/DreamtaleCore/Refool)
100 | - [Project](http://liuyunfei.xyz/Projs/Refool/index.html)
101 |
102 | #### Defense
103 |
104 | **MCR**: Bridging Mode Connectivity in Loss Landscapes and Adversarial Robustness
105 |
106 | - [Paper](https://arxiv.org/abs/2005.00060)
107 | - [Pytorch implementation](https://github.com/IBM/model-sanitization)
108 |
109 | **Fine-tuning & Fine-Pruning**: Defending Against Backdooring Attacks on Deep Neural Networks
110 |
111 | - [Paper](https://link.springer.com/chapter/10.1007/978-3-030-00470-5_13)
112 | - [Pytorch implementation1](https://github.com/VinAIResearch/input-aware-backdoor-attack-release/tree/master/defenses)
113 | - [Pytorch implementation2](https://github.com/adityarajagopal/pytorch_pruning_finetune)
114 |
115 | **Neural Cleanse**: Identifying and Mitigating Backdoor Attacks in Neural Networks
116 |
117 | - [Paper](https://people.cs.uchicago.edu/~ravenben/publications/pdf/backdoor-sp19.pdf)
118 | - [Tensorflow implementation](https://github.com/Abhishikta-codes/neural_cleanse)
119 | - [Pytorch implementation1](https://github.com/lijiachun123/TrojAi)
120 | - [Pytorch implementation2](https://github.com/VinAIResearch/input-aware-backdoor-attack-release/tree/master/defenses)
121 |
122 | **STRIP**: A Defence Against Trojan Attacks on Deep Neural Networks
123 |
124 | - [Paper](https://arxiv.org/pdf/1911.10312.pdf)
125 | - [Pytorch implementation1](https://github.com/garrisongys/STRIP)
126 | - [Pytorch implementation2](https://github.com/VinAIResearch/input-aware-backdoor-attack-release/tree/master/defenses)
127 |
128 | #### Library
129 |
130 | `Note`: TrojanZoo provides a universal PyTorch platform for conducting security research (especially backdoor attacks/defenses) on image classification in deep learning.
131 |
132 | Backdoors 101 is a PyTorch framework for state-of-the-art backdoor attacks and defenses on deep learning models.
133 |
134 | - [trojanzoo](https://github.com/ain-soph/trojanzoo)
135 | - [backdoors101](https://github.com/ebagdasa/backdoors101)
136 |
137 | ## References
138 |
139 | If you find this code useful for your research, please cite our paper:
140 | ```
141 | @inproceedings{li2021neural,
142 | title={Neural Attention Distillation: Erasing Backdoor Triggers from Deep Neural Networks},
143 | author={Li, Yige and Lyu, Xixiang and Koren, Nodens and Lyu, Lingjuan and Li, Bo and Ma, Xingjun},
144 | booktitle={ICLR},
145 | year={2021}
146 | }
147 | ```
148 |
149 | ## Contacts
150 |
151 | If you have any questions, please open an issue on GitHub.
152 |
153 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from models.selector import *
3 | from utils.util import *
4 | from data_loader import get_train_loader, get_test_loader
5 | from at import AT
6 | from config import get_arguments
7 |
8 |
def train_step(opt, train_loader, nets, optimizer, criterions, epoch):
    """Run one NAD training epoch over ``train_loader``.

    For each batch the student is optimised on
    ``beta1*AT1 + beta2*AT2 + beta3*AT3 + CE(logits, target)``, where ``ATk``
    is ``criterionAT`` applied to the k-th student activation and the matching
    (detached) teacher activation.

    Args:
        opt: options; reads ``cuda``, ``beta1``-``beta3`` and ``print_freq``.
        train_loader: iterable of ``(img, target)`` batches.
        nets: dict with the student under ``'snet'`` and the teacher under ``'tnet'``.
        optimizer: optimizer over the student's parameters.
        criterions: dict with ``'criterionCls'`` and ``'criterionAT'``.
        epoch: epoch number shown in progress messages.
    """
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    student = nets['snet']
    teacher = nets['tnet']
    ce = criterions['criterionCls']
    at = criterions['criterionAT']

    student.train()

    for step, (images, labels) in enumerate(train_loader, start=1):
        if opt.cuda:
            images = images.cuda()
            labels = labels.cuda()

        feat1_s, feat2_s, feat3_s, logits = student(images)
        feat1_t, feat2_t, feat3_t, _ = teacher(images)

        ce_loss = ce(logits, labels)
        at3 = at(feat3_s, feat3_t.detach()) * opt.beta3
        at2 = at(feat2_s, feat2_t.detach()) * opt.beta2
        at1 = at(feat1_s, feat1_t.detach()) * opt.beta1
        total_loss = at1 + at2 + at3 + ce_loss

        n = images.size(0)
        prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
        loss_meter.update(total_loss.item(), n)
        top1_meter.update(prec1.item(), n)
        top5_meter.update(prec5.item(), n)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if step % opt.print_freq == 0:
            print('Epoch[{0}]:[{1:03}/{2:03}] '
                  'AT_loss:{losses.val:.4f}({losses.avg:.4f}) '
                  'prec@1:{top1.val:.2f}({top1.avg:.2f}) '
                  'prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(
                      epoch, step, len(train_loader),
                      losses=loss_meter, top1=top1_meter, top5=top5_meter))
50 |
51 |
def test(opt, test_clean_loader, test_bad_loader, nets, criterions, epoch):
    """Evaluate the student on the clean and the triggered test sets and log the result.

    Args:
        opt: options; reads ``cuda``, ``beta1``-``beta3`` and ``log_root``.
        test_clean_loader: batches of untouched test images.
        test_bad_loader: batches of triggered test images with poisoned labels.
        nets: dict with the student under ``'snet'`` and the teacher under ``'tnet'``.
        criterions: dict with ``'criterionCls'`` and ``'criterionAT'``.
        epoch: epoch number written to the log.

    Returns:
        ``(acc_clean, acc_bd)``, where ``acc_clean = [top1, top5]`` on clean data
        and ``acc_bd = [top1, top5, mean cls loss, mean weighted AT loss]`` on
        triggered data.

    Side effects:
        Appends one row to ``<opt.log_root>/results.csv``. The header row is
        written only when the file does not exist yet.
    """
    import os

    snet = nets['snet']
    tnet = nets['tnet']
    criterionCls = criterions['criterionCls']
    criterionAT = criterions['criterionAT']

    snet.eval()

    def _to_device(img, target):
        # Honour opt.cuda (as train_step does) instead of always calling .cuda().
        if opt.cuda:
            return img.cuda(), target.cuda()
        return img, target

    # Clean test set: accuracy only.
    top1 = AverageMeter()
    top5 = AverageMeter()
    for img, target in test_clean_loader:
        img, target = _to_device(img, target)
        with torch.no_grad():
            _, _, _, output_s = snet(img)
        prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
        top1.update(prec1.item(), img.size(0))
        top5.update(prec5.item(), img.size(0))
    acc_clean = [top1.avg, top5.avg]

    # Triggered test set: accuracy against the poisoned labels, CE and AT losses.
    cls_losses = AverageMeter()
    at_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    for img, target in test_bad_loader:
        img, target = _to_device(img, target)
        with torch.no_grad():
            activation1_s, activation2_s, activation3_s, output_s = snet(img)
            activation1_t, activation2_t, activation3_t, _ = tnet(img)

            at3_loss = criterionAT(activation3_s, activation3_t) * opt.beta3
            at2_loss = criterionAT(activation2_s, activation2_t) * opt.beta2
            at1_loss = criterionAT(activation1_s, activation1_t) * opt.beta1
            at_loss = at3_loss + at2_loss + at1_loss
            cls_loss = criterionCls(output_s, target)

        prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
        cls_losses.update(cls_loss.item(), img.size(0))
        at_losses.update(at_loss.item(), img.size(0))
        top1.update(prec1.item(), img.size(0))
        top5.update(prec5.item(), img.size(0))
    acc_bd = [top1.avg, top5.avg, cls_losses.avg, at_losses.avg]

    print('[clean]Prec@1: {:.2f}'.format(acc_clean[0]))
    print('[bad]Prec@1: {:.2f}'.format(acc_bd[0]))

    # Save training progress; write the header only once so the CSV stays parseable.
    log_root = opt.log_root + '/results.csv'
    df = pd.DataFrame([(epoch, acc_clean[0], acc_bd[0], acc_bd[2], acc_bd[3])], columns=(
        "epoch", "test_clean_acc", "test_bad_acc", "test_bad_cls_loss", "test_bad_at_loss"))
    df.to_csv(log_root, mode='a', index=False, encoding='utf-8',
              header=not os.path.exists(log_root))

    return acc_clean, acc_bd
117 |
118 |
def train(opt):
    """Erase a backdoor with Neural Attention Distillation (NAD).

    Loads a pretrained teacher (frozen, eval mode) and a pretrained backdoored
    student. The student is fine-tuned on the clean training subset with
    cross-entropy plus attention-transfer losses. The student is evaluated
    before training (logged as epoch 0) and after every epoch, and optionally
    checkpointed.
    """
    print('----------- Network Initialization --------------')
    teacher = select_model(dataset=opt.data_name,
                           model_name=opt.t_name,
                           pretrained=True,
                           pretrained_models_path=opt.t_model,
                           n_classes=opt.num_class).to(opt.device)
    print('finished teacher model init...')

    student = select_model(dataset=opt.data_name,
                           model_name=opt.s_name,
                           pretrained=True,
                           pretrained_models_path=opt.s_model,
                           n_classes=opt.num_class).to(opt.device)
    print('finished student model init...')
    teacher.eval()

    nets = {'snet': student, 'tnet': teacher}

    # The teacher only supplies target activations; it is never updated.
    for param in teacher.parameters():
        param.requires_grad = False

    optimizer = torch.optim.SGD(student.parameters(),
                                lr=opt.lr,
                                momentum=opt.momentum,
                                weight_decay=opt.weight_decay,
                                nesterov=True)

    # Loss functions; AT does not depend on the device, so build it once.
    criterionCls = nn.CrossEntropyLoss()
    if opt.cuda:
        criterionCls = criterionCls.cuda()
    criterionAT = AT(opt.p)
    criterions = {'criterionCls': criterionCls, 'criterionAT': criterionAT}

    print('----------- DATA Initialization --------------')
    train_loader = get_train_loader(opt)
    test_clean_loader, test_bad_loader = get_test_loader(opt)

    print('----------- Train Initialization --------------')
    for epoch in range(0, opt.epochs):
        adjust_learning_rate(optimizer, epoch, opt.lr)

        if epoch == 0:
            # Baseline evaluation of the untouched student, logged as epoch 0.
            test(opt, test_clean_loader, test_bad_loader, nets, criterions, epoch)

        train_step(opt, train_loader, nets, optimizer, criterions, epoch + 1)

        print('testing the models......')
        acc_clean, acc_bad = test(opt, test_clean_loader, test_bad_loader, nets, criterions, epoch + 1)

        if opt.save:
            # Track the best clean accuracy seen so far. Lowering the threshold to
            # the *backdoor* accuracy (min(acc_bad, ...)) made nearly every epoch "best".
            is_best = acc_clean[0] > opt.threshold_clean
            opt.threshold_clean = max(acc_clean[0], opt.threshold_clean)

            save_checkpoint({
                'epoch': epoch,
                'state_dict': student.state_dict(),
                'best_clean_acc': acc_clean[0],
                'best_bad_acc': acc_bad[0],
                'optimizer': optimizer.state_dict(),
            }, is_best, opt.checkpoint_root, opt.s_name)
196 |
197 |
def main():
    """Parse the command-line options and run NAD training."""
    options = get_arguments().parse_args()
    train(options)


if __name__ == '__main__':
    main()
205 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms, datasets
2 | from torch.utils.data import random_split, DataLoader, Dataset
3 | import torch
4 | import numpy as np
5 | import time
6 | from tqdm import tqdm
7 |
def get_train_loader(opt):
    """Build the shuffled, augmented loader of clean data used for NAD fine-tuning.

    Only a random ``opt.ratio`` share of the CIFAR-10 training set is kept
    (see ``DatasetCL``). Augmentation is random crop with padding, horizontal
    flip, tensor conversion and ``Cutout(1, 3)``.

    Raises:
        Exception: if ``opt.dataset`` is not ``'CIFAR10'``.
    """
    print('==> Preparing train data..')
    if opt.dataset != 'CIFAR10':
        raise Exception('Invalid dataset')

    augment = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        Cutout(1, 3),
    ])
    full_train = datasets.CIFAR10(root='data/CIFAR10', train=True, download=True)
    clean_subset = DatasetCL(opt, full_dataset=full_train, transform=augment)
    return DataLoader(clean_subset, batch_size=opt.batch_size, shuffle=True)
27 |
def get_test_loader(opt):
    """Build the clean and the triggered CIFAR-10 test loaders (neither shuffles).

    Returns:
        ``(test_clean_loader, test_bad_loader)``:
        * clean: ``DatasetBD`` with ``inject_portion=0``, so no trigger is applied;
        * bad: ``DatasetBD`` with ``inject_portion=1``, so every kept sample gets
          the trigger. Depending on ``opt.target_type``, samples already in the
          target class may be dropped (see ``DatasetBD.addTrigger``).

    Raises:
        Exception: if ``opt.dataset`` is not ``'CIFAR10'``.
    """
    print('==> Preparing test data..')
    to_tensor = transforms.Compose([transforms.ToTensor()])
    if opt.dataset != 'CIFAR10':
        raise Exception('Invalid dataset')
    testset = datasets.CIFAR10(root='data/CIFAR10', train=False, download=True)

    def _loader(portion):
        # Wrap the test set with the given trigger fraction in a sequential loader.
        data = DatasetBD(opt, full_dataset=testset, inject_portion=portion,
                         transform=to_tensor, mode='test')
        return DataLoader(dataset=data, batch_size=opt.batch_size, shuffle=False)

    test_clean_loader = _loader(0)
    test_bad_loader = _loader(1)
    return test_clean_loader, test_bad_loader
52 |
53 |
def get_backdoor_loader(opt):
    """Build the poisoned CIFAR-10 training loader used to train a backdoored model.

    A random ``opt.inject_portion`` fraction of the training images receives
    the trigger (see ``DatasetBD``). Images are only converted to tensors, with
    no augmentation.

    Raises:
        Exception: if ``opt.dataset`` is not ``'CIFAR10'``.
    """
    print('==> Preparing train data..')
    tf_train = transforms.Compose([transforms.ToTensor()])
    if opt.dataset != 'CIFAR10':
        raise Exception('Invalid dataset')
    trainset = datasets.CIFAR10(root='data/CIFAR10', train=True, download=True)

    train_data_bad = DatasetBD(opt, full_dataset=trainset, inject_portion=opt.inject_portion,
                               transform=tf_train, mode='train')
    # Reshuffle every epoch, like get_train_loader. With shuffle=False, every
    # epoch trained on identical mini-batches in a fixed order.
    train_bad_loader = DataLoader(dataset=train_data_bad,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  )
    return train_bad_loader
70 |
class Cutout(object):
    """Randomly zero out square patches of a ``(C, H, W)`` tensor image.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): Nominal side length of each patch. A patch spans
            ``[c - length // 2, c + length // 2)`` around a random centre ``c``
            on each axis and is clipped at the borders. An odd ``length``
            therefore yields a side of ``length - 1`` (``Cutout(1, 3)`` removes 2x2 pixels).
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return ``img`` multiplied by a 0/1 mask containing ``n_holes`` zeroed patches."""
        h, w = img.size(1), img.size(2)
        half = self.length // 2
        mask = np.ones((h, w), np.float32)

        for _ in range(self.n_holes):
            cy = np.random.randint(h)
            cx = np.random.randint(w)
            top, bottom = np.clip([cy - half, cy + half], 0, h)
            left, right = np.clip([cx - half, cx + half], 0, w)
            mask[top:bottom, left:right] = 0.

        # The same (H, W) mask is applied to every channel.
        return img * torch.from_numpy(mask).expand_as(img)
109 |
class DatasetCL(Dataset):
    """Random ``opt.ratio`` share of a dataset, with an optional image transform.

    Used by ``get_train_loader`` to build the clean fine-tuning subset.
    """

    def __init__(self, opt, full_dataset=None, transform=None):
        self.dataset = self.random_split(full_dataset=full_dataset, ratio=opt.ratio)
        self.transform = transform
        self.dataLen = len(self.dataset)

    def __getitem__(self, index):
        image = self.dataset[index][0]
        label = self.dataset[index][1]
        return (self.transform(image) if self.transform else image), label

    def __len__(self):
        return self.dataLen

    def random_split(self, full_dataset, ratio):
        """Keep a random ``int(ratio * len)`` samples of ``full_dataset`` and drop the rest."""
        total = len(full_dataset)
        print('full_train:', total)
        keep = int(ratio * total)
        # Module-level torch random_split (this method merely shares its name).
        kept, dropped = random_split(full_dataset, [keep, total - keep])
        print('train_size:', len(kept), 'drop_size:', len(dropped))
        return kept
148 | return img, label
149 |
def __len__(self):
    """Number of stored samples (samples dropped by ``addTrigger`` are not counted)."""
    n_samples = len(self.dataset)
    return n_samples
152 |
def addTrigger(self, dataset, target_label, inject_portion, mode, distance, trig_w, trig_h, trigger_type, target_type):
    """Build a list of ``(np.ndarray image, label)`` pairs with a random subset poisoned.

    A random ``int(len(dataset) * inject_portion)`` indices are selected.
    Selected images get the trigger from ``selectTrigger`` and are relabelled
    according to ``target_type``:

    * ``'all2one'``: poisoned label is ``target_label``.
    * ``'all2all'``: poisoned label is ``self._change_label_next(label)``.
    * ``'cleanLabel'``: in train mode only selected samples already labelled
      ``target_label`` get the trigger, and they keep their label. In any
      other mode it behaves like ``'all2one'``.

    When ``mode != 'train'``, samples already labelled ``target_label`` are
    dropped for ``'all2one'`` and ``'cleanLabel'``.

    Raises:
        ValueError: if ``target_type`` is not one of the three values above.
    """
    if target_type not in ('all2one', 'all2all', 'cleanLabel'):
        # Previously an unknown type silently produced an empty dataset.
        raise ValueError('Unknown target_type: {}'.format(target_type))

    print("Generating " + mode + "bad Imgs")
    perm = np.random.permutation(len(dataset))[0: int(len(dataset) * inject_portion)]
    # Set membership is O(1); `i in perm` scanned the whole array for every sample.
    poisoned_idx = set(perm.tolist())
    is_train = mode == 'train'

    dataset_ = list()
    cnt = 0
    for i in tqdm(range(len(dataset))):
        data = dataset[i]
        label = data[1]

        if not is_train and target_type != 'all2all' and label == target_label:
            continue

        img = np.array(data[0])
        width = img.shape[0]
        height = img.shape[1]

        if i not in poisoned_idx:
            dataset_.append((img, label))
            continue
        if target_type == 'cleanLabel' and is_train and label != target_label:
            # Clean-label poisoning only touches samples already in the target class.
            dataset_.append((img, label))
            continue

        img = self.selectTrigger(img, width, height, distance, trig_w, trig_h, trigger_type)
        if target_type == 'all2all':
            new_label = self._change_label_next(label)
        elif target_type == 'cleanLabel' and is_train:
            new_label = label
        else:
            new_label = target_label
        dataset_.append((img, new_label))
        cnt += 1

    # Presumably lets the tqdm bar finish rendering before the summary line.
    time.sleep(0.01)
    # Count clean images from the samples actually kept (some may have been dropped).
    print("Injecting Over: " + str(cnt) + "Bad Imgs, " + str(len(dataset_) - cnt) + "Clean Imgs")

    return dataset_
266 |
267 |
def _change_label_next(self, label, num_classes=10):
    """Return the all-to-all attack target for ``label``: the next class, wrapping around.

    Args:
        label: Original integer class index.
        num_classes: Size of the label space. Defaults to 10, the value this
            method previously hard-coded, so existing callers are unaffected.

    Returns:
        ``(label + 1) % num_classes``, i.e. the last class maps back to 0.
    """
    return (label + 1) % num_classes
271 |
def selectTrigger(self, img, width, height, distance, trig_w, trig_h, triggerType):
    """Apply the backdoor trigger named ``triggerType`` to ``img``.

    Looks up the helper method on ``self`` for ``triggerType`` and calls it as
    ``helper(img, width, height, distance, trig_w, trig_h)``, returning its
    result unchanged. Depending on the helper, ``img`` may be modified in place
    or a new array may be returned.

    Args:
        img: Image array to poison.
        width, height: Extents of ``img`` along axes 0 and 1.
        distance: Border offset (used by the square trigger).
        trig_w, trig_h: Size of the square trigger.
        triggerType: One of ``'squareTrigger'``, ``'gridTrigger'``,
            ``'fourCornerTrigger'``, ``'randomPixelTrigger'``,
            ``'signalTrigger'``, ``'trojanTrigger'``.

    Returns:
        The triggered image produced by the selected helper.

    Raises:
        NotImplementedError: if ``triggerType`` is not one of the names above.
    """
    # Single source of truth: trigger name -> helper method name on ``self``.
    # (``_gridTriger`` is the actual spelling of that helper in this class.)
    handlers = {
        'squareTrigger': '_squareTrigger',
        'gridTrigger': '_gridTriger',
        'fourCornerTrigger': '_fourCornerTrigger',
        'randomPixelTrigger': '_randomPixelTrigger',
        'signalTrigger': '_signalTrigger',
        'trojanTrigger': '_trojanTrigger',
    }

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, and the name list no longer has to be kept in sync twice.
    method_name = handlers.get(triggerType)
    if method_name is None:
        raise NotImplementedError(
            'Unknown trigger type {!r}; expected one of: {}'.format(
                triggerType, ', '.join(handlers)))

    return getattr(self, method_name)(img, width, height, distance, trig_w, trig_h)
299 |
def _squareTrigger(self, img, width, height, distance, trig_w, trig_h):
    """Paint a solid white ``trig_w`` x ``trig_h`` patch near the far corner of ``img``.

    The patch covers indices ``[width - distance - trig_w, width - distance)``
    on axis 0 and ``[height - distance - trig_h, height - distance)`` on axis 1.
    Every channel of those pixels is set to 255. ``img`` (a NumPy array) is
    modified in place and also returned.

    Raises:
        ValueError: if the non-empty patch does not lie entirely inside ``img``.
            The previous per-pixel loop silently wrapped negative indices
            around to the opposite edge in that case.
    """
    row_stop = width - distance
    col_stop = height - distance
    row_start = row_stop - trig_w
    col_start = col_stop - trig_h

    if trig_w <= 0 or trig_h <= 0:
        return img  # empty patch: nothing to paint (same as the old empty loops)

    if row_start < 0 or col_start < 0 or row_stop > img.shape[0] or col_stop > img.shape[1]:
        raise ValueError(
            'square trigger of size {}x{} at distance {} does not fit in a {}x{} image'.format(
                trig_w, trig_h, distance, img.shape[0], img.shape[1]))

    # One slice assignment instead of a Python double loop over pixels.
    img[row_start:row_stop, col_start:col_stop] = 255.0
    return img
306 |
def _gridTriger(self, img, width, height, distance, trig_w, trig_h):
    """Stamp a fixed 3x3 black/white grid into the far corner of ``img``.

    The grid occupies indices ``width - 3 .. width - 1`` on axis 0 and
    ``height - 3 .. height - 1`` on axis 1. Each written value is assigned to
    the whole pixel, so every channel gets it. ``distance``, ``trig_w`` and
    ``trig_h`` are accepted only so the signature matches the other trigger
    helpers; they are not used. ``img`` is modified in place and returned.
    """
    # grid[r][c] is written to img[width - 1 - r][height - 1 - c].
    grid = ((255, 0, 255),
            (0, 255, 0),
            (255, 0, 0))
    for r, row_values in enumerate(grid):
        for c, value in enumerate(row_values):
            img[width - 1 - r][height - 1 - c] = value
    return img
336 |
def _fourCornerTrigger(self, img, width, height, distance, trig_w, trig_h):
    """Stamp the same 3x3 black/white grid at all four corners of ``img``.

    The "near" corners use indices 1..3 on an axis. The "far" corners use
    ``width - 3 .. width - 1`` (axis 0) or ``height - 3 .. height - 1`` (axis 1).
    Each value is assigned to the whole pixel, so every channel gets it.
    ``distance``, ``trig_w`` and ``trig_h`` are accepted only for signature
    compatibility with the other trigger helpers. ``img`` is modified in place
    and returned.
    """
    # grid[a - 1][b - 1] is the value for the pixel ``a`` steps along axis 0 and
    # ``b`` steps along axis 1 from a corner, counting inward.
    # NOTE(review): near corners are anchored at index 1 (index 0 is left
    # untouched) while far corners sit on the last index, so the stamps are not
    # positionally symmetric. Presumably intentional, since models were trained
    # with this exact trigger; confirm before "fixing".
    grid = ((255, 0, 255),
            (0, 255, 0),
            (255, 0, 0))

    # (mirror axis 0?, mirror axis 1?) in the original stamping order:
    # right bottom, left top, right top, left bottom.
    corners = ((True, True), (False, False), (True, False), (False, True))

    for flip_rows, flip_cols in corners:
        for a, row_values in enumerate(grid, start=1):
            row = width - a if flip_rows else a
            for b, value in enumerate(row_values, start=1):
                col = height - b if flip_cols else b
                img[row][col] = value
    return img
391 |
def _randomPixelTrigger(self, img, width, height, distance, trig_w, trig_h, alpha=0.2):
    """Alpha-blend ``img`` with a freshly drawn uniform-noise mask.

    A ``width`` x ``height`` mask of random integers in [0, 255] is drawn from
    NumPy's global RNG, so ``np.random.seed`` controls it. The mask is broadcast
    over any trailing axes of ``img`` (e.g. channels) and blended as
    ``(1 - alpha) * img + alpha * mask``. The result is clipped to [0, 255]
    and returned as a new ``uint8`` array; ``img`` itself is not modified.

    Args:
        img: Image array whose first two axes have sizes ``width`` and ``height``.
        width, height: Extents of ``img`` along axes 0 and 1.
        distance, trig_w, trig_h: Unused; kept for signature compatibility
            with the other trigger helpers.
        alpha: Blend weight of the noise mask (default 0.2, the previous
            hard-coded value).

    Returns:
        The blended ``uint8`` image, with the same shape as ``img``.
    """
    mask = np.random.randint(low=0, high=256, size=(width, height), dtype=np.uint8)
    # One singleton axis per extra image axis. A (W, H, C) image gets a
    # (W, H, 1) mask as before, and a 2-D grayscale image keeps a 2-D mask.
    # The old fixed (W, H, 1) reshape broadcast a (W, H) image to (W, H, H).
    mask = mask.reshape((width, height) + (1,) * (np.ndim(img) - 2))
    blend_img = (1 - alpha) * img + alpha * mask
    # Clip BEFORE casting. Casting first wraps out-of-range values modulo 256,
    # which made the old post-cast clip a no-op.
    return np.clip(blend_img, 0, 255).astype('uint8')
400 |
def _signalTrigger(self, img, width, height, distance, trig_w, trig_h,
                   alpha=0.2, mask_path='trigger/signal_cifar10_mask.npy'):
    """Alpha-blend ``img`` with a fixed signal mask loaded from disk.

    The mask at ``mask_path`` is reshaped to ``width`` x ``height``, broadcast
    over any trailing axes of ``img`` (e.g. channels), and blended as
    ``(1 - alpha) * img + alpha * mask``. The result is clipped to [0, 255]
    and returned as a new ``uint8`` array; ``img`` itself is not modified.

    Args:
        img: Image array whose first two axes have sizes ``width`` and ``height``.
        width, height: Extents of ``img`` along axes 0 and 1.
        distance, trig_w, trig_h: Unused; kept for signature compatibility
            with the other trigger helpers.
        alpha: Blend weight of the signal mask (default 0.2, the previous
            hard-coded value).
        mask_path: ``.npy`` file holding ``width * height`` mask values.
            Relative paths are resolved against the current working directory.
            The default is the previously hard-coded CIFAR-10 mask.

    Returns:
        The blended ``uint8`` image, with the same shape as ``img``.
    """
    # NOTE(review): the mask is re-read from disk on every call; consider
    # caching it if poisoning a large dataset becomes slow.
    signal_mask = np.load(mask_path)
    # One singleton axis per extra image axis, so 2-D grayscale images stay 2-D
    # (the old fixed (W, H, 1) reshape broadcast them to (W, H, H)).
    signal_mask = signal_mask.reshape((width, height) + (1,) * (np.ndim(img) - 2))
    blend_img = (1 - alpha) * img + alpha * signal_mask
    # Clip BEFORE casting; casting first wraps out-of-range values modulo 256.
    return np.clip(blend_img, 0, 255).astype('uint8')
409 |
def _trojanTrigger(self, img, width, height, distance, trig_w, trig_h,
                   trigger_path='trigger/best_square_trigger_cifar10.npz'):
    """Add a pre-computed Trojan trigger pattern to ``img``.

    The pattern is read from array ``'x'`` of the ``.npz`` archive at
    ``trigger_path`` and transposed so its first axis becomes the last. The
    original note recorded the stored shape as (3, 32, 32), i.e. presumably
    channel-first, which becomes (32, 32, 3). It is added to ``img``, the sum
    is saturated to [0, 255], and the result is returned as a new ``uint8``
    array.

    Args:
        img: Image array, assumed to be channel-last and the same shape as the
            transposed trigger. TODO confirm for non-CIFAR inputs.
        width, height, distance, trig_w, trig_h: Unused; kept for signature
            compatibility with the other trigger helpers.
        trigger_path: ``.npz`` archive containing the trigger under key ``'x'``.
            Relative paths are resolved against the current working directory.
            The default is the previously hard-coded CIFAR-10 trigger.

    Returns:
        The triggered ``uint8`` image.
    """
    # ``np.load`` on an .npz returns an NpzFile that keeps the archive open
    # until closed; the context manager releases the file handle promptly.
    with np.load(trigger_path) as archive:
        trg = archive['x']
    trg = np.transpose(trg, (1, 2, 0))
    # Add in floating point and clip BEFORE casting. The old
    # ``(img + trg).astype('uint8')`` wrapped sums above 255 (or below 0)
    # modulo 256, so its clip never took effect.
    summed = np.asarray(img, dtype=np.float64) + trg
    return np.clip(summed, 0, 255).astype('uint8')
418 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
54 |
55 |
56 |
57 | train_
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 | 1610788725433
220 |
221 |
222 | 1610788725433
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
--------------------------------------------------------------------------------