├── doc
│   └── overview.png
├── requirement.txt
├── utils
│   ├── normalize_utils.py
│   ├── cutout.py
│   ├── progress_bar.py
│   └── defense_utils.py
├── LICENSE
├── checkpoints
│   └── README.md
├── example_cmd.sh
├── misc
│   ├── test_acc.py
│   ├── PatchAttacker.py
│   ├── patch_attack.py
│   ├── train_cifar.py
│   ├── train_imagenette.py
│   └── train_imagenet.py
├── nets
│   ├── dsresnet_cifar.py
│   ├── bagnet.py
│   ├── dsresnet_imgnt.py
│   └── resnet.py
├── README.md
├── mask_ds.py
├── det_bn.py
└── mask_bn.py
/doc/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inspire-group/PatchGuard/HEAD/doc/overview.png
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | tqdm #==4.51.0
2 | torch #==1.7.0
3 | torchvision #==0.8.1
4 | joblib #==0.17.0
5 | scipy #==1.5.4
6 | numpy #==1.19.2
--------------------------------------------------------------------------------
/utils/normalize_utils.py:
--------------------------------------------------------------------------------
1 | ################################################
2 | # Not used. Useful if visualization is desired #
3 | ################################################
4 |
5 | import numpy as np
6 |
7 | mean_vec=[0.485, 0.456, 0.406]
8 | std_vec=[0.229, 0.224, 0.225]
9 |
10 | def normalize_np(data,mean,std):
11 | #input data B*W*H*C
12 | B,W,H,C=data.shape
13 | mean=np.array(mean).reshape([1,1,1,C])
14 | std=np.array(std).reshape([1,1,1,C])
15 | return (data-mean)/std
16 |
17 | def unnormalize_np(data,mean,std):
18 | #input data B*W*H*C
19 | B,W,H,C=data.shape
20 | mean=np.array(mean).reshape([1,1,1,C])
21 | std=np.array(std).reshape([1,1,1,C])
22 | return data*std+mean
--------------------------------------------------------------------------------
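The two helpers above are exact inverses over channel-last (B*W*H*C) arrays. A minimal round-trip sketch for the visualization use case mentioned in the header (the random batch here is purely illustrative):

```python
import numpy as np
from utils.normalize_utils import normalize_np, unnormalize_np, mean_vec, std_vec

batch = np.random.rand(4, 224, 224, 3)               # hypothetical B*W*H*C batch in [0,1]
normed = normalize_np(batch, mean_vec, std_vec)      # standardize each channel
restored = unnormalize_np(normed, mean_vec, std_vec)
assert np.allclose(batch, restored)                  # the two helpers undo each other
```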
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 xiangchong1
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
1 | ## Checkpoints
2 | ### overview
3 | Model checkpoints used in the paper can be downloaded from [link](https://drive.google.com/drive/folders/1u5RsCuZNf7ddWW0utI4OrgWGmJCUDCuT?usp=sharing).
4 |
5 | The checkpoints on Google Drive were obtained with "provable adversarial training" (adding feature masks during training).
6 | 
7 | Training models with the provided training scripts should be straightforward.
8 |
9 | ### checkpoints for bagnet/resnet trained on imagenet
10 | Two model checkpoints trained with "provable adversarial training" are available now! bagnet17_net.pth gives the results reported in our paper. Note: the clean accuracy for resnet50 reported in the paper uses the pretrained weights from torchvision (resnet50 is not used in our defense!).
11 |
12 | - bagnet33_net.pth
13 | - bagnet17_net.pth
14 |
15 | ### checkpoints for bagnet/resnet trained on imagenette
16 | - resnet50_nette.pth
17 | - bagnet33_nette.pth
18 | - bagnet17_nette.pth
19 | - bagnet9_nette.pth
20 |
21 | ### checkpoints for bagnet/resnet trained on cifar
22 | - resnet50_192_cifar.pth
23 | - bagnet33_192_cifar.pth
24 | - bagnet17_192_cifar.pth
25 | - bagnet9_192_cifar.pth
26 |
27 | ### checkpoints for ds-resnet on different datasets
28 | - ds_net.pth
29 | - ds_nette.pth
30 | - ds_cifar.pth
31 |
32 | Training scripts for ds-resnet are not provided in this repository, but can be found in [patchSmoothing](https://github.com/alevine0/patchSmoothing)
33 |
--------------------------------------------------------------------------------
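For reference, the evaluation scripts in this repository load these checkpoints with dataset-dependent state-dict keys: `'model_state_dict'` for the `*_nette.pth` files, `'state_dict'` for the `*_net.pth` files, and `'net'` for the `*_cifar.pth` and ds-resnet files. A minimal loading sketch following `misc/test_acc.py` (the 10-class head matches imagenette):

```python
import torch
import torch.nn as nn
import nets.bagnet

model = nets.bagnet.bagnet17(pretrained=True, clip_range=None, aggregation='mean')
model.fc = nn.Linear(model.fc.in_features, 10)   # imagenette has 10 classes
model = torch.nn.DataParallel(model)             # checkpoints were saved from DataParallel models
checkpoint = torch.load('checkpoints/bagnet17_nette.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```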
/utils/cutout.py:
--------------------------------------------------------------------------------
1 | ###############################################
2 | # not used in the paper
3 | # from https://github.com/uoguelph-mlrg/Cutout
4 | ###############################################
5 |
6 | import torch
7 | import numpy as np
8 |
9 |
10 | class Cutout(object):
11 | """Randomly mask out one or more patches from an image.
12 |
13 | Args:
14 | n_holes (int): Number of patches to cut out of each image.
15 | length (int): The length (in pixels) of each square patch.
16 | """
17 | def __init__(self, n_holes, length):
18 | self.n_holes = n_holes
19 | self.length = length
20 |
21 | def __call__(self, img):
22 | """
23 | Args:
24 | img (Tensor): Tensor image of size (C, H, W).
25 | Returns:
26 | Tensor: Image with n_holes of dimension length x length cut out of it.
27 | """
28 | h = img.size(1)
29 | w = img.size(2)
30 |
31 | mask = np.ones((h, w), np.float32)
32 |
33 | for n in range(self.n_holes):
34 | y = np.random.randint(h)
35 | x = np.random.randint(w)
36 |
37 | y1 = np.clip(y - self.length // 2, 0, h)
38 | y2 = np.clip(y + self.length // 2, 0, h)
39 | x1 = np.clip(x - self.length // 2, 0, w)
40 | x2 = np.clip(x + self.length // 2, 0, w)
41 |
42 | mask[y1: y2, x1: x2] = 0.
43 |
44 | mask = torch.from_numpy(mask)
45 | mask = mask.expand_as(img)
46 | img = img * mask
47 |
48 | return img
49 |
--------------------------------------------------------------------------------
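Though not used in the paper, `Cutout` composes with a standard torchvision pipeline after `ToTensor()`. A minimal sketch (the normalization constants are the CIFAR-10 values used elsewhere in this repository):

```python
import torchvision.transforms as transforms
from utils.cutout import Cutout

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    Cutout(n_holes=1, length=16),  # mask one random 16x16 square per image
])
```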
/example_cmd.sh:
--------------------------------------------------------------------------------
1 | #install packages
2 | pip install -r requirement.txt
3 | #provable analysis with CBN and robust masking
4 | python mask_bn.py --model bagnet17 --dataset imagenette --patch_size 32 --cbn #cbn with bagnet17 on imagenette
5 | python mask_bn.py --model bagnet17 --dataset imagenette --patch_size 32 --m #mask-bn with bagnet17 on imagenette
6 | python mask_bn.py --model bagnet17 --dataset imagenet --patch_size 32 --cbn #cbn with bagnet17 on imagenet
7 | python mask_bn.py --model bagnet17 --dataset imagenet --patch_size 32 --m #mask-bn with bagnet17 on imagenet
8 | python mask_bn.py --model bagnet17 --dataset cifar --patch_size 30 --cbn #cbn with bagnet17 on cifar
9 | python mask_bn.py --model bagnet17 --dataset cifar --patch_size 30 --m #mask-bn with bagnet17 on cifar
10 | #mask-ds and ds
11 | python mask_ds.py --dataset imagenette --patch_size 42 --ds #ds for imagenette
12 | python mask_ds.py --dataset imagenette --patch_size 42 --m #mask-ds for imagenette
13 | python mask_ds.py --dataset imagenet --patch_size 42 --ds #ds for imagenet
14 | python mask_ds.py --dataset imagenet --patch_size 42 --m #mask-ds for imagenet
15 | python mask_ds.py --dataset cifar --patch_size 5 --ds #ds for cifar
16 | python mask_ds.py --dataset cifar --patch_size 5 --m #mask-ds for cifar
17 |
18 | # patchguard++
19 | python det_bn.py --det --model bagnet33 --tau 0.5 --patch_size 32 --dataset imagenette # an example; usage is similar to mask_bn.py and mask_ds.py
20 | python det_bn.py --det --model bagnet33 --tau 0.7 --patch_size 32 --dataset imagenette # you can try different detection thresholds tau
21 |
22 | #test model accuracy
23 | python test_acc.py --model resnet50 --dataset imagenette #test accuracy of resnet50 on imagenette
24 | python test_acc.py --model resnet50 --dataset imagenet #test accuracy of resnet50 on imagenet
25 | python test_acc.py --model resnet50 --dataset cifar #test accuracy of resnet50 on cifar
26 | python test_acc.py --model bagnet17 --dataset imagenette #test accuracy of bagnet17 on imagenette (similar for imagenet,cifar)
27 | python test_acc.py --model bagnet33 --dataset imagenette #test accuracy of bagnet33 on imagenette (similar for imagenet,cifar)
28 | python test_acc.py --model bagnet9 --dataset imagenette #test accuracy of bagnet9 on imagenette (similar for imagenet,cifar)
29 | python test_acc.py --model bagnet17 --dataset imagenette --clip 15 #test accuracy of bagnet17 (clipped with [0,15]) on imagenette (similar for imagenet,cifar)
30 | python test_acc.py --model bagnet17 --dataset imagenette --aggr median #test accuracy of bagnet17 with median aggregation on imagenette (similar for imagenet,cifar)
31 | python test_acc.py --model bagnet17 --dataset imagenette --aggr cbn #test accuracy of bagnet17 with cbn clipping on imagenette (similar for imagenet,cifar)
32 | #empirical untargeted attack
33 | python patch_attack.py --model bagnet17 --dataset imagenette --patch_size 31 #untargeted attack against bagnet17
34 | python patch_attack.py --model bagnet17 --dataset imagenette --patch_size 31 --aggr cbn #untargeted attack against bagnet17 with cbn clipping
35 | #train model
36 | python train_imagenette.py --model_name bagnet17_nette.pth --epoch 20 #train model on imagenette
37 | python train_imagenette.py --model_name bagnet17_nette.pth --aggr adv --epoch 20 #train model on imagenette with provable adversarial training
38 | python train_cifar.py --lr 0.01 #train cifar model
39 | python train_cifar.py --resume --lr 0.001 #resume cifar model training with a different learning rate
40 |
41 |
42 |
--------------------------------------------------------------------------------
/utils/progress_bar.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 | - get_mean_and_std: calculate the mean and std value of dataset.
3 | - msr_init: net parameter initialization.
4 | - progress_bar: progress bar mimic xlua.progress.
5 | '''
6 | import os
7 | import sys
8 | import time
9 | import math
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.init as init
13 |
14 |
15 | def get_mean_and_std(dataset):
16 | '''Compute the mean and std value of dataset.'''
17 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
18 | mean = torch.zeros(3)
19 | std = torch.zeros(3)
20 | print('==> Computing mean and std..')
21 | for inputs, targets in dataloader:
22 | for i in range(3):
23 | mean[i] += inputs[:,i,:,:].mean()
24 | std[i] += inputs[:,i,:,:].std()
25 | mean.div_(len(dataset))
26 | std.div_(len(dataset))
27 | return mean, std
28 |
29 | def init_params(net):
30 | '''Init layer parameters.'''
31 | for m in net.modules():
32 | if isinstance(m, nn.Conv2d):
33 |             init.kaiming_normal_(m.weight, mode='fan_out')
34 |             if m.bias is not None:
35 |                 init.constant_(m.bias, 0)
36 |         elif isinstance(m, nn.BatchNorm2d):
37 |             init.constant_(m.weight, 1)
38 |             init.constant_(m.bias, 0)
39 |         elif isinstance(m, nn.Linear):
40 |             init.normal_(m.weight, std=1e-3)
41 |             if m.bias is not None:
42 |                 init.constant_(m.bias, 0)
43 |
44 |
45 | _, term_width = os.popen('stty size', 'r').read().split()
46 | term_width = int(term_width)
47 |
48 | TOTAL_BAR_LENGTH = 65.
49 | last_time = time.time()
50 | begin_time = last_time
51 | def progress_bar(current, total, msg=None):
52 | global last_time, begin_time
53 | if current == 0:
54 | begin_time = time.time() # Reset for new bar.
55 |
56 | cur_len = int(TOTAL_BAR_LENGTH*current/total)
57 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
58 |
59 | sys.stdout.write(' [')
60 | for i in range(cur_len):
61 | sys.stdout.write('=')
62 | sys.stdout.write('>')
63 | for i in range(rest_len):
64 | sys.stdout.write('.')
65 | sys.stdout.write(']')
66 |
67 | cur_time = time.time()
68 | step_time = cur_time - last_time
69 | last_time = cur_time
70 | tot_time = cur_time - begin_time
71 |
72 | L = []
73 | L.append(' Step: %s' % format_time(step_time))
74 | L.append(' | Tot: %s' % format_time(tot_time))
75 | if msg:
76 | L.append(' | ' + msg)
77 |
78 | msg = ''.join(L)
79 | sys.stdout.write(msg)
80 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
81 | sys.stdout.write(' ')
82 |
83 | # Go back to the center of the bar.
84 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
85 | sys.stdout.write('\b')
86 | sys.stdout.write(' %d/%d ' % (current+1, total))
87 |
88 | if current < total-1:
89 | sys.stdout.write('\r')
90 | else:
91 | sys.stdout.write('\n')
92 | sys.stdout.flush()
93 |
94 | def format_time(seconds):
95 | days = int(seconds / 3600/24)
96 | seconds = seconds - days*3600*24
97 | hours = int(seconds / 3600)
98 | seconds = seconds - hours*3600
99 | minutes = int(seconds / 60)
100 | seconds = seconds - minutes*60
101 | secondsf = int(seconds)
102 | seconds = seconds - secondsf
103 | millis = int(seconds*1000)
104 |
105 | f = ''
106 | i = 1
107 | if days > 0:
108 | f += str(days) + 'D'
109 | i += 1
110 | if hours > 0 and i <= 2:
111 | f += str(hours) + 'h'
112 | i += 1
113 | if minutes > 0 and i <= 2:
114 | f += str(minutes) + 'm'
115 | i += 1
116 | if secondsf > 0 and i <= 2:
117 | f += str(secondsf) + 's'
118 | i += 1
119 | if millis > 0 and i <= 2:
120 | f += str(millis) + 'ms'
121 | i += 1
122 | if f == '':
123 | f = '0ms'
124 | return f
125 |
--------------------------------------------------------------------------------
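A minimal sketch of driving `progress_bar` (mirroring its use in `train_cifar.py`; note that it reads the terminal width via `stty size` at import time, so it must run in a real terminal):

```python
import time
from utils.progress_bar import progress_bar

total_batches = 100  # stand-in for len(dataloader)
for batch_idx in range(total_batches):
    time.sleep(0.01)  # stand-in for one training step
    progress_bar(batch_idx, total_batches, 'Loss: %.3f' % 0.123)
```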
/misc/test_acc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torchvision import datasets, transforms
7 |
8 | import nets.bagnet
9 | import nets.resnet
10 | from utils.defense_utils import *
11 |
12 | import os
13 | import argparse
14 | from tqdm import tqdm
15 | import numpy as np
16 | import PIL
17 |
18 | parser = argparse.ArgumentParser()
19 |
20 | parser.add_argument("--model_dir",default='checkpoints',type=str,help="path to checkpoints")
21 | parser.add_argument('--data_dir', default='data', type=str,help="path to data")
22 | parser.add_argument('--dataset', default='imagenette', choices=('imagenette','imagenet','cifar'),type=str,help="dataset")
23 | parser.add_argument("--model",default='bagnet17',type=str,help="model name")
24 | parser.add_argument("--clip",default=-1,type=int,help="clipping value; do clipping when this argument is set to positive")
25 | parser.add_argument("--aggr",default='mean',type=str,help="aggregation methods. set to none for local feature")
26 |
27 | args = parser.parse_args()
28 |
29 | MODEL_DIR=os.path.join('.',args.model_dir)
30 | DATA_DIR=os.path.join(args.data_dir,args.dataset)
31 | DATASET = args.dataset
32 |
33 | def get_dataset(ds,data_dir):
34 | if ds in ['imagenette','imagenet']:
35 | ds_dir=os.path.join(data_dir,'val')
36 | ds_transforms = transforms.Compose([
37 | transforms.Resize(256),
38 | transforms.CenterCrop(224),
39 | transforms.ToTensor(),
40 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
41 | ])
42 | dataset_ = datasets.ImageFolder(ds_dir,ds_transforms)
43 | class_names = dataset_.classes
44 | elif ds == 'cifar':
45 | ds_transforms = transforms.Compose([
46 | transforms.Resize(192, interpolation=PIL.Image.BICUBIC),
47 | transforms.ToTensor(),
48 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
49 | ])
50 | dataset_ = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=ds_transforms)
51 | class_names = dataset_.classes
52 | return dataset_,class_names
53 |
54 | val_dataset,class_names = get_dataset(DATASET,DATA_DIR)
55 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8,shuffle=False)
56 |
57 | #build and initialize model
58 | device = 'cuda' #if torch.cuda.is_available() else 'cpu'
59 |
60 | if args.clip > 0:
61 | clip_range = [0,args.clip]
62 | else:
63 | clip_range = None
64 |
65 | if 'bagnet17' in args.model:
66 | model = nets.bagnet.bagnet17(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
67 | elif 'bagnet33' in args.model:
68 | model = nets.bagnet.bagnet33(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
69 | elif 'bagnet9' in args.model:
70 | model = nets.bagnet.bagnet9(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
71 | elif 'resnet50' in args.model:
72 | model = nets.resnet.resnet50(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
73 |
74 | if DATASET == 'imagenette':
75 | num_ftrs = model.fc.in_features
76 | model.fc = nn.Linear(num_ftrs, len(class_names))
77 | model = torch.nn.DataParallel(model)
78 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_nette.pth'))
79 | model.load_state_dict(checkpoint['model_state_dict'])
80 | elif DATASET == 'imagenet':
81 | model = torch.nn.DataParallel(model)
82 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_net.pth'))
83 | model.load_state_dict(checkpoint['state_dict'])
84 | elif DATASET == 'cifar':
85 | num_ftrs = model.fc.in_features
86 | model.fc = nn.Linear(num_ftrs, len(class_names))
87 | model = torch.nn.DataParallel(model)
88 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_192_cifar.pth'))
89 | model.load_state_dict(checkpoint['net'])
90 |
91 | model = model.to(device)
92 | model.eval()
93 | cudnn.benchmark = True
94 |
95 | accuracy_list=[]
96 |
97 | for data,labels in tqdm(val_loader):
98 | data,labels=data.to(device),labels.to(device)
99 | output_clean = model(data)
100 | acc_clean=torch.sum(torch.argmax(output_clean, dim=1) == labels).cpu().detach().numpy()
101 | accuracy_list.append(acc_clean)
102 |
103 | print("Test accuracy:",np.sum(accuracy_list)/len(val_dataset))
104 |
105 |
106 |
--------------------------------------------------------------------------------
/nets/dsresnet_cifar.py:
--------------------------------------------------------------------------------
1 | ##############################################################################################
2 | # from https://github.com/alevine0/patchSmoothing/blob/master/pytorch_cifar/models/resnet.py
3 | ##############################################################################################
4 |
5 |
6 | '''ResNet in PyTorch.
7 |
8 | For Pre-activation ResNet, see 'preact_resnet.py'.
9 |
10 | Reference:
11 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
12 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
13 | '''
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 |
18 |
19 | class BasicBlock(nn.Module):
20 | expansion = 1
21 |
22 | def __init__(self, in_planes, planes, stride=1):
23 | super(BasicBlock, self).__init__()
24 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
25 | self.bn1 = nn.BatchNorm2d(planes)
26 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
27 | self.bn2 = nn.BatchNorm2d(planes)
28 |
29 | self.shortcut = nn.Sequential()
30 | if stride != 1 or in_planes != self.expansion*planes:
31 | self.shortcut = nn.Sequential(
32 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
33 | nn.BatchNorm2d(self.expansion*planes)
34 | )
35 |
36 | def forward(self, x):
37 | out = F.relu(self.bn1(self.conv1(x)))
38 | out = self.bn2(self.conv2(out))
39 | out += self.shortcut(x)
40 | out = F.relu(out)
41 | return out
42 |
43 |
44 | class Bottleneck(nn.Module):
45 | expansion = 4
46 |
47 | def __init__(self, in_planes, planes, stride=1):
48 | super(Bottleneck, self).__init__()
49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
50 | self.bn1 = nn.BatchNorm2d(planes)
51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
52 | self.bn2 = nn.BatchNorm2d(planes)
53 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
54 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
55 |
56 | self.shortcut = nn.Sequential()
57 | if stride != 1 or in_planes != self.expansion*planes:
58 | self.shortcut = nn.Sequential(
59 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
60 | nn.BatchNorm2d(self.expansion*planes)
61 | )
62 |
63 | def forward(self, x):
64 | out = F.relu(self.bn1(self.conv1(x)))
65 | out = F.relu(self.bn2(self.conv2(out)))
66 | out = self.bn3(self.conv3(out))
67 | out += self.shortcut(x)
68 | out = F.relu(out)
69 | return out
70 |
71 |
72 | class ResNet(nn.Module):
73 | def __init__(self, block, num_blocks, num_classes=10):
74 | super(ResNet, self).__init__()
75 | self.in_planes = 64
76 |
77 |         self.conv1 = nn.Conv2d(6, 64, kernel_size=3, stride=1, padding=1, bias=False) # 6 input channels: derandomized smoothing feeds an ablated image plus a complement/mask encoding (see patchSmoothing)
78 | self.bn1 = nn.BatchNorm2d(64)
79 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
80 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
81 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
82 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
83 | self.linear = nn.Linear(512*block.expansion, num_classes)
84 |
85 | def _make_layer(self, block, planes, num_blocks, stride):
86 | strides = [stride] + [1]*(num_blocks-1)
87 | layers = []
88 | for stride in strides:
89 | layers.append(block(self.in_planes, planes, stride))
90 | self.in_planes = planes * block.expansion
91 | return nn.Sequential(*layers)
92 |
93 | def forward(self, x):
94 | out = F.relu(self.bn1(self.conv1(x)))
95 | out = self.layer1(out)
96 | out = self.layer2(out)
97 | out = self.layer3(out)
98 | out = self.layer4(out)
99 | out = F.avg_pool2d(out, 4)
100 | out = out.view(out.size(0), -1)
101 | out = self.linear(out)
102 | return out
103 |
104 |
105 | def ResNet18():
106 | return ResNet(BasicBlock, [2,2,2,2])
107 |
108 | def ResNet34():
109 | return ResNet(BasicBlock, [3,4,6,3])
110 |
111 | def ResNet50():
112 | return ResNet(Bottleneck, [3,4,6,3])
113 |
114 | def ResNet101():
115 | return ResNet(Bottleneck, [3,4,23,3])
116 |
117 | def ResNet152():
118 | return ResNet(Bottleneck, [3,8,36,3])
119 |
120 |
121 | def test():
122 | net = ResNet18()
123 |     y = net(torch.randn(1,6,32,32)) # conv1 expects 6 input channels (see note on conv1 above)
124 | print(y.size())
125 |
126 | # test()
127 |
--------------------------------------------------------------------------------
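Note that `conv1` above takes 6 input channels rather than 3: the derandomized-smoothing pipeline from patchSmoothing encodes each RGB channel together with a complement/mask channel so that ablated (zeroed) pixels are distinguishable from genuinely black ones. A shape-only sketch of a compatible input (the actual band-ablation encoding lives in the defense utilities, not here):

```python
import torch
from nets.dsresnet_cifar import ResNet18

net = ResNet18()
x = torch.randn(1, 6, 32, 32)  # 6 channels: 3 (ablated) RGB channels + 3 encoding channels
print(net(x).size())           # torch.Size([1, 10])
```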
/misc/PatchAttacker.py:
--------------------------------------------------------------------------------
1 | ######################################################################################################
2 | # Adapted from https://github.com/Ping-C/certifiedpatchdefense/blob/master/attacks/patch_attacker.py
3 | ######################################################################################################
4 |
5 | import torch
6 | import numpy as np
7 |
8 |
9 | class PatchAttacker:
10 | def __init__(self, model, mean, std, image_size=244,epsilon=1,steps=500,step_size=0.05,patch_size=31,random_start=True):
11 |
12 | mean,std = torch.tensor(mean),torch.tensor(std)
13 | self.epsilon = epsilon / std
14 | self.epsilon_cuda=self.epsilon[None, :, None, None].cuda()
15 | self.steps = steps
16 | self.step_size = step_size / std
17 | self.step_size=self.step_size[None, :, None, None].cuda()
18 | self.model = model.cuda()
19 | self.mean = mean
20 | self.std = std
21 | self.random_start = random_start
22 | self.image_size = image_size
23 | self.lb = (-mean / std)
24 | self.lb=self.lb[None, :, None, None].cuda()
25 | self.ub = (1 - mean) / std
26 | self.ub=self.ub[None, :, None, None].cuda()
27 | self.patch_w = patch_size
28 | self.patch_l = patch_size
29 |
30 | self.criterion = torch.nn.CrossEntropyLoss()
31 |
32 | def perturb(self, inputs, labels, loc=None,random_count=1):
33 | worst_x = None
34 | worst_loss = None
35 |
36 | for _ in range(random_count):
37 | # generate random patch center for each image
38 | idx = torch.arange(inputs.shape[0])[:, None]
39 | zero_idx = torch.zeros((inputs.shape[0],1), dtype=torch.long)
40 | if loc is not None: #specified locations
41 | w_idx = torch.ones([inputs.shape[0],1],dtype=torch.int64)*loc[0]
42 | l_idx = torch.ones([inputs.shape[0],1],dtype=torch.int64)*loc[1]
43 | else: #random locations
44 | w_idx = torch.randint(0 , inputs.shape[2]-self.patch_w , (inputs.shape[0],1))
45 | l_idx = torch.randint(0 , inputs.shape[3]-self.patch_l , (inputs.shape[0],1))
46 |
47 | idx = torch.cat([idx,zero_idx, w_idx, l_idx], dim=1)
48 | idx_list = [idx]
49 | for w in range(self.patch_w):
50 | for l in range(self.patch_l):
51 | idx_list.append(idx + torch.tensor([0,0,w,l]))
52 | idx_list = torch.cat(idx_list, dim =0)
53 |
54 | # create mask
55 | mask = torch.zeros([inputs.shape[0], 1, inputs.shape[2], inputs.shape[3]],
56 | dtype=torch.bool).cuda()
57 | mask[idx_list[:,0],idx_list[:,1],idx_list[:,2],idx_list[:,3]] = True
58 |
59 | if self.random_start:
60 | init_delta = np.random.uniform(-self.epsilon, self.epsilon,
61 | [inputs.shape[0]*inputs.shape[2]*inputs.shape[3], inputs.shape[1]])
62 | init_delta = init_delta.reshape(inputs.shape[0],inputs.shape[2],inputs.shape[3], inputs.shape[1])
63 |             init_delta = init_delta.swapaxes(1,3).swapaxes(2,3) # (B,H,W,C) -> (B,C,H,W)
64 | x = inputs + torch.where(mask, torch.Tensor(init_delta).to('cuda'), torch.tensor(0.).cuda())
65 |
66 | x = torch.min(torch.max(x, self.lb), self.ub).detach() # ensure valid pixel range
67 | else:
68 | x = inputs.data.detach().clone()
69 |
70 | x_init = inputs.data.detach().clone()
71 |
72 | for step in range(self.steps+1):
73 | x.requires_grad_()
74 | output = self.model(torch.where(mask, x, x_init))
75 | loss_ind = torch.nn.functional.cross_entropy(input=output, target=labels,reduction='none')
76 | loss = loss_ind.sum()
77 | grads = torch.autograd.grad(loss, x,retain_graph=False)[0]
78 |
79 | if step % 10 ==0:
80 | if worst_loss is None:
81 | worst_loss = loss_ind.detach().clone()
82 | worst_x = x.detach().clone()
83 | else:
84 | tmp_loss = loss_ind.detach().clone()
85 | tmp_x = x.detach().clone()
86 | filter_tmp=worst_loss.ge(tmp_loss).detach().clone()
87 | worst_x = torch.where(filter_tmp.reshape([inputs.shape[0],1,1,1]), worst_x, tmp_x).detach().clone()
88 | worst_loss = torch.where(filter_tmp, worst_loss, tmp_loss).detach().clone()
89 | #print(worst_loss)
90 | #del tmp_loss
91 | #del tmp_x
92 | #del filter_tmp
93 | signed_grad_x = torch.sign(grads).detach()
94 | delta = signed_grad_x * self.step_size
95 | x = delta + x
96 | #del loss
97 | #del loss_ind
98 | #del grads
99 | # Project back into constraints ball and correct range
100 | x = torch.max(torch.min(x, x_init + self.epsilon_cuda), x_init - self.epsilon_cuda)#.detach()
101 | x = torch.min(torch.max(x, self.lb), self.ub).detach().clone()
102 |
103 | return worst_x.detach().clone(), torch.cat([w_idx, l_idx], dim=1).detach().clone()
104 |
105 |
--------------------------------------------------------------------------------
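`patch_attack.py` below drives this class with random patch locations; `perturb` also accepts a fixed location via `loc`. A minimal sketch (assumes `model`, `data`, and `labels` are an eval-mode CUDA classifier and a normalized batch with its labels, e.g. as set up in `patch_attack.py`; the constants are the CIFAR-10 values used elsewhere in this repository):

```python
from PatchAttacker import PatchAttacker

# model, data, labels are assumed to be defined (see patch_attack.py)
attacker = PatchAttacker(model, mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.2010],
                         patch_size=30, steps=500, step_size=0.05)
adv, patch_loc = attacker.perturb(data, labels, loc=(50, 50))  # pin the patch's top-left corner
```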
/misc/patch_attack.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torchvision import datasets, transforms
7 |
8 | import nets.bagnet
9 | import nets.resnet
10 | from utils.defense_utils import *
11 |
12 | import os
13 | import argparse
14 | from tqdm import tqdm
15 | import numpy as np
16 | import PIL
17 | from PatchAttacker import PatchAttacker
18 | import joblib
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("--dump_dir",default='patch_adv',type=str,help="directory to save attack results")
22 | parser.add_argument("--model_dir",default='checkpoints',type=str,help="path to checkpoints")
23 | parser.add_argument('--data_dir', default='data', type=str,help="path to data")
24 | parser.add_argument('--dataset', default='imagenette', choices=('imagenette','imagenet','cifar'),type=str,help="dataset")
25 | parser.add_argument("--model",default='bagnet17',type=str,help="model name")
26 | parser.add_argument("--clip",default=-1,type=int,help="clipping value; do clipping when this argument is set to positive")
27 | parser.add_argument("--aggr",default='mean',type=str,help="aggregation methods. set to none for local feature")
28 | parser.add_argument("--patch_size",type=int,help="size of the adversarial patch")
29 |
30 | args = parser.parse_args()
31 |
32 | MODEL_DIR=os.path.join('.',args.model_dir)
33 | DATA_DIR=os.path.join(args.data_dir,args.dataset)
34 | DATASET = args.dataset
35 | DUMP_DIR=os.path.join('dump',args.dump_dir+'_{}_{}'.format(args.model,args.dataset))
36 | if not os.path.exists('dump'):
37 | os.mkdir('dump')
38 | if not os.path.exists(DUMP_DIR):
39 | os.mkdir(DUMP_DIR)
40 |
41 |
42 |
43 | if DATASET in ['imagenette','imagenet']:
44 | DATA_DIR=os.path.join(DATA_DIR,'val')
45 | mean_vec = [0.485, 0.456, 0.406]
46 | std_vec = [0.229, 0.224, 0.225]
47 | ds_transforms = transforms.Compose([
48 | transforms.Resize(256),
49 | transforms.CenterCrop(224),
50 | transforms.ToTensor(),
51 | transforms.Normalize(mean_vec,std_vec)
52 | ])
53 | val_dataset = datasets.ImageFolder(DATA_DIR,ds_transforms)
54 | class_names = val_dataset.classes
55 | elif DATASET == 'cifar':
56 | mean_vec = [0.4914, 0.4822, 0.4465]
57 | std_vec = [0.2023, 0.1994, 0.2010]
58 | ds_transforms = transforms.Compose([
59 | transforms.Resize(192, interpolation=PIL.Image.BICUBIC),
60 | transforms.ToTensor(),
61 | transforms.Normalize(mean_vec,std_vec),
62 | ])
63 | val_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=ds_transforms)
64 | class_names = val_dataset.classes
65 |
66 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8,shuffle=False)
67 |
68 | #build and initialize model
69 | device = 'cuda' #if torch.cuda.is_available() else 'cpu'
70 |
71 | if args.clip > 0:
72 | clip_range = [0,args.clip]
73 | else:
74 | clip_range = None
75 |
76 | if 'bagnet17' in args.model:
77 | model = nets.bagnet.bagnet17(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
78 | elif 'bagnet33' in args.model:
79 | model = nets.bagnet.bagnet33(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
80 | elif 'bagnet9' in args.model:
81 | model = nets.bagnet.bagnet9(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
82 | elif 'resnet50' in args.model:
83 | model = nets.resnet.resnet50(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
84 |
85 |
86 | if DATASET == 'imagenette':
87 | num_ftrs = model.fc.in_features
88 | model.fc = nn.Linear(num_ftrs, len(class_names))
89 | model = torch.nn.DataParallel(model)
90 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_nette.pth'))
91 | model.load_state_dict(checkpoint['model_state_dict'])
92 | elif DATASET == 'imagenet':
93 | model = torch.nn.DataParallel(model)
94 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_net.pth'))
95 | model.load_state_dict(checkpoint['state_dict'])
96 | elif DATASET == 'cifar':
97 | num_ftrs = model.fc.in_features
98 | model.fc = nn.Linear(num_ftrs, len(class_names))
99 | model = torch.nn.DataParallel(model)
100 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_192_cifar.pth'))
101 | model.load_state_dict(checkpoint['net'])
102 |
103 | model = model.to(device)
104 | model.eval()
105 | cudnn.benchmark = True
106 |
107 |
108 | attacker = PatchAttacker(model, mean_vec, std_vec,patch_size=args.patch_size,step_size=0.05,steps=500)
109 |
110 | adv_list=[]
111 | error_list=[]
112 | accuracy_list=[]
113 | patch_loc_list=[]
114 |
115 | for data,labels in tqdm(val_loader):
116 |
117 | data,labels=data.to(device),labels.to(device)
118 | data_adv,patch_loc = attacker.perturb(data, labels)
119 |
120 | output_adv = model(data_adv)
121 | error_adv=torch.sum(torch.argmax(output_adv, dim=1) != labels).cpu().detach().numpy()
122 | output_clean = model(data)
123 | acc_clean=torch.sum(torch.argmax(output_clean, dim=1) == labels).cpu().detach().numpy()
124 |
125 | data_adv=data_adv.cpu().detach().numpy()
126 | patch_loc=patch_loc.cpu().detach().numpy()
127 |
128 | patch_loc_list.append(patch_loc)
129 | adv_list.append(data_adv)
130 | error_list.append(error_adv)
131 | accuracy_list.append(acc_clean)
132 |
133 |
134 | adv_list = np.concatenate(adv_list)
135 | patch_loc_list = np.concatenate(patch_loc_list)
136 | joblib.dump(adv_list,os.path.join(DUMP_DIR,'patch_adv_list_{}.z'.format(args.patch_size)))
137 | joblib.dump(patch_loc_list,os.path.join(DUMP_DIR,'patch_loc_list_{}.z'.format(args.patch_size)))
138 | print("Attack success rate:",np.sum(error_list)/len(val_dataset))
139 | print("Clean accuracy:",np.sum(accuracy_list)/len(val_dataset))
140 |
141 |
--------------------------------------------------------------------------------
/misc/train_cifar.py:
--------------------------------------------------------------------------------
1 | ##############################################################################
2 | # Adapted from https://github.com/kuangliu/pytorch-cifar/blob/master/main.py
3 | ##############################################################################
4 |
5 | '''Train CIFAR10 with PyTorch.'''
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import torch.nn.functional as F
10 | import torch.backends.cudnn as cudnn
11 |
12 | import torchvision
13 | import torchvision.transforms as transforms
14 |
15 | import os
16 | import argparse
17 |
18 | import nets.bagnet
19 | import nets.resnet
20 |
21 | import PIL
22 |
23 | from utils.progress_bar import progress_bar
24 |
25 | import numpy as np
26 | import joblib
27 |
28 | import random
29 |
30 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
31 | parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
32 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
33 | parser.add_argument("--clip",default=-1,type=int)
34 | args = parser.parse_args()
35 |
36 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
37 | best_acc = 0 # best test accuracy
38 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch
39 |
40 | # Data
41 | print('==> Preparing data..')
42 | transform_train = transforms.Compose([
43 | #transforms.RandomCrop(32, padding=4),
44 | transforms.Resize(192, interpolation=PIL.Image.BICUBIC),
45 | transforms.RandomHorizontalFlip(),
46 | transforms.ToTensor(),
47 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
48 | ])
49 |
50 | transform_test = transforms.Compose([
51 | transforms.Resize(192, interpolation=PIL.Image.BICUBIC),
52 | transforms.ToTensor(),
53 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
54 | ])
55 |
56 |
57 | trainset = torchvision.datasets.CIFAR10(root='data/cifar', train=True, download=True, transform=transform_train)
58 |
59 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
60 |
61 | testset = torchvision.datasets.CIFAR10(root='data/cifar', train=False, download=True, transform=transform_test)
62 |
63 | testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
64 |
65 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
66 |
67 | if args.clip > 0:
68 | clip_range = [0,args.clip]
69 | else:
70 | clip_range = None
71 |
72 | # Model
73 | print('==> Building model..')
74 |
75 | pth_path = './checkpoints/bagnet17_192_cifar.pth'
76 |
77 | net = nets.bagnet.bagnet17(pretrained=True,clip_range=clip_range,aggregation='adv') #aggregation = 'mean' for vanilla training
78 | #net = nets.resnet.resnet50(pretrained=True)
79 |
80 | #for param in net.parameters():
81 | # param.requires_grad = False
82 |
83 | # Parameters of newly constructed modules have requires_grad=True by default
84 | num_ftrs = net.fc.in_features
85 | net.fc = nn.Linear(num_ftrs, 10)
86 | net = net.to(device)
87 |
88 | if device == 'cuda':
89 | net = torch.nn.DataParallel(net)
90 | cudnn.benchmark = True
91 |
92 | if args.resume:
93 | # Load checkpoint.
94 | print('==> Resuming from checkpoint..')
95 | assert os.path.isdir('./checkpoints'), 'Error: no checkpoint directory found!'
96 | checkpoint = torch.load(pth_path)
97 | net.load_state_dict(checkpoint['net'])
98 | best_acc = checkpoint['acc']
99 | start_epoch = checkpoint['epoch']
100 |
101 | criterion = nn.CrossEntropyLoss()
102 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
103 |
104 | # Training
105 | def train(epoch):
106 | print('\nEpoch: %d' % epoch)
107 | net.train()
108 | train_loss = 0
109 | correct = 0
110 | total = 0
111 | for batch_idx, (inputs, targets) in enumerate(trainloader):
112 | inputs, targets = inputs.to(device), targets.to(device)
113 | optimizer.zero_grad()
114 | outputs = net(inputs)
115 | loss = criterion(outputs, targets)
116 | loss.backward()
117 | optimizer.step()
118 |
119 | train_loss += loss.item()
120 | _, predicted = outputs.max(1)
121 | total += targets.size(0)
122 | correct += predicted.eq(targets).sum().item()
123 |
124 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
125 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
126 |
127 | def test(epoch):
128 | global best_acc
129 | net.eval()
130 | test_loss = 0
131 | correct = 0
132 | total = 0
133 | idx_list=[]
134 | with torch.no_grad():
135 | for batch_idx, (inputs, targets) in enumerate(testloader):
136 | inputs, targets = inputs.to(device), targets.to(device)
137 | outputs = net(inputs)
138 | loss = criterion(outputs, targets)
139 |
140 | test_loss += loss.item()
141 | _, predicted = outputs.max(1)
142 | total += targets.size(0)
143 | correct += predicted.eq(targets).sum().item()
144 |
145 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
146 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
147 |
148 | # Save checkpoint.
149 | #joblib.dump(idx_list,'masked_contour_correct_idx.z')
150 | acc = 100.*correct/total
151 |     if True: #save every epoch; change to acc > best_acc to keep only the best checkpoint
152 | print('Saving..')
153 | state = {
154 | 'net': net.state_dict(),
155 | 'acc': acc,
156 | 'epoch': epoch,
157 | }
158 | if not os.path.isdir('checkpoints'):
159 | os.mkdir('checkpoints')
160 | torch.save(state, pth_path)
161 | best_acc = acc
162 |
163 | # python train_cifar.py --lr 0.01
164 | # python train_cifar.py --resume --lr 0.001
165 | for epoch in range(start_epoch, start_epoch+20):
166 | train(epoch)
167 | test(epoch)
168 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PatchGuard: A Provably Robust Defense against Adversarial Patches via Small Receptive Fields and Masking
2 |
3 | By [Chong Xiang](http://xiangchong.xyz/), [Arjun Nitin Bhagoji](http://www.princeton.edu/~abhagoji/), [Vikash Sehwag](https://vsehwag.github.io/), [Prateek Mittal](https://www.princeton.edu/~pmittal/)
4 |
5 | Code for "[PatchGuard: A Provably Robust Defense against Adversarial Patches via Small Receptive Fields and Masking](https://www.usenix.org/conference/usenixsecurity21/presentation/xiang)" in USENIX Security 2021 ([arXiv technical report](https://arxiv.org/abs/2005.10884))
6 |
7 |
8 |
9 | Update 04/2022: Check out our new [PatchCleanser](https://github.com/inspire-group/PatchCleanser) defense (USENIX Security 2022), our [paper list for adversarial patch research](https://github.com/xiangchong1/adv-patch-paper-list), and [leaderboard for certifiable robust image classification](https://github.com/inspire-group/patch-defense-leaderboard) for fun!
10 |
11 | Update 12/2021: fixed an incorrect lower bound computation for the true class when the detection threshold T>0. Thanks to [Linyi Li](https://github.com/llylly) for pointing that out! The mistake does not affect the main results of the paper (since the main results are obtained with T=0).
12 |
13 | Update 08/2021: started to work on a paper list for adversarial patch research [(link)](https://github.com/xiangchong1/adv-patch-paper-list).
14 |
15 | Update 05/2021: included code (`det_bn.py`) for "[PatchGuard++: Efficient Provable Attack Detection against Adversarial Patches](https://arxiv.org/abs/2104.12609)" in Security and Safety in Machine Learning Systems Workshop at ICLR 2021.
16 |
17 | ## Requirements
18 |
19 | The code is tested with Python 3.8 and PyTorch 1.7.0. The complete list of required packages is available in `requirement.txt` and can be installed with `pip install -r requirement.txt`. The code should be compatible with other package versions.
20 |
21 | ## Files
22 |
23 | ```shell
24 | ├── README.md #this file
25 | ├── requirement.txt #required packages
26 | ├── example_cmd.sh #example command to run the code
27 | ├── mask_bn.py #PatchGuard: mask-bn for imagenet/imagenette/cifar
28 | ├── mask_ds.py #PatchGuard: mask-ds/ds for imagenet/imagenette/cifar
29 | ├── det_bn.py #PatchGuard++: provable robust attack detection
30 | ├── nets
31 | | ├── bagnet.py #modified bagnet model for mask-bn
32 | | ├── resnet.py #modified resnet model
33 | | ├── dsresnet_imgnt.py #ds-resnet-50 for imagenet(te)
34 | | └── dsresnet_cifar.py #ds-resnet-18 for cifar
35 | ├── utils
36 | | ├── defense_utils.py #utils for different defenses
37 | | ├── normalize_utils.py #utils for normalize images stored in numpy array (not used in the paper)
38 | | ├── cutout.py #utils for CUTOUT training (not used)
39 | | └── progress_bar.py #progress bar (used in train_cifar.py; unnecessary though)
40 | |
41 | ├── misc #useful scripts (but not used in robustness evaluation); move them to the main directory for execution
42 | | ├── test_acc.py #test clean accuracy of resnet/bagnet on imagenet/imagenette/cifar; support clipping, median operations
43 | | ├── train_imagenet.py #train resnet/bagnet for imagenet
44 | | ├── train_imagenette.py #train resnet/bagnet for imagenette
45 | | ├── train_cifar.py #train resnet/bagnet for cifar
| #NOTE: The attack scripts below are not used in our defense evaluation!
47 | | ├── patch_attack.py #empirically (untargeted) attack resnet/bagnet trained on imagenet/imagenette/cifar
48 | | ├── PatchAttacker.py #utils for untargeted adversarial patch attack
49 | |
50 | ├── data
51 | | ├── imagenet #data directory for imagenet
52 | | ├── imagenette #data directory for imagenette
53 | | └── cifar #data directory for cifar
54 | |
55 | └── checkpoints #directory for checkpoints
56 | ├── README.md #details of each checkpoint
57 | └── ... #model checkpoints
58 | ```
59 |
60 | ## Datasets
61 |
62 | - [ImageNet](http://www.image-net.org/) (ILSVRC2012)
63 | - [ImageNette](https://github.com/fastai/imagenette) ([Full size](https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz))
64 | - [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)
65 |
66 | ## Usage
67 |
68 | - See **Files** for details of each file.
69 | - Download data in **Datasets** to `data/`.
70 | - (optional) Download checkpoints from Google Drive [link](https://drive.google.com/drive/folders/1u5RsCuZNf7ddWW0utI4OrgWGmJCUDCuT?usp=sharing) and move them to `checkpoints`.
71 | - See `example_cmd.sh` for example commands for running the code.
72 |
73 | If anything is unclear, please open an issue or contact Chong Xiang (cxiang@princeton.edu).
74 |
75 | ## Related Repositories
76 |
77 | - [certifiedpatchdefense](https://github.com/Ping-C/certifiedpatchdefense)
78 | - [patchSmoothing](https://github.com/alevine0/patchSmoothing)
79 | - [bag-of-local-features-models](https://github.com/wielandbrendel/bag-of-local-features-models)
80 |
81 | ## Citations
82 |
83 | If you find our work useful in your research, please consider citing:
84 |
85 | ```tex
86 | @inproceedings{xiang2021patchguard,
87 | title={PatchGuard: A Provably Robust Defense against Adversarial Patches via Small Receptive Fields and Masking},
88 | author={Xiang, Chong and Bhagoji, Arjun Nitin and Sehwag, Vikash and Mittal, Prateek},
89 | booktitle = {30th {USENIX} Security Symposium ({USENIX} Security)},
90 | year={2021}
91 | }
92 |
93 | @inproceedings{xiang2021patchguard2,
94 | title={PatchGuard++: Efficient Provable Attack Detection against Adversarial Patches},
95 | author={Xiang, Chong and Mittal, Prateek},
96 | booktitle = {ICLR 2021 Workshop on Security and Safety in Machine Learning Systems},
97 | year={2021}
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
/mask_ds.py:
--------------------------------------------------------------------------------
1 | ##############################################################################################################
2 | # Part of code adapted from https://github.com/alevine0/patchSmoothing/blob/master/certify_imagenet_band.py
3 | ##############################################################################################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.nn.functional as F
9 | import torch.backends.cudnn as cudnn
10 |
11 | import nets.dsresnet_imgnt as resnet_imgnt
12 | import nets.dsresnet_cifar as resnet_cifar
13 | from torchvision import datasets,transforms
14 | from tqdm import tqdm
15 | from utils.defense_utils import *
16 |
17 | import os
18 | import argparse
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("--model_dir",default='checkpoints',type=str,help="path to checkpoints")
22 | parser.add_argument('--band_size', default=-1, type=int, help='size of each smoothing band')
23 | parser.add_argument('--patch_size', default=-1, type=int, help='patch_size')
24 | parser.add_argument('--thres', default=0.0, type=float, help='detection threshold for robust masking')
25 | parser.add_argument('--dataset', default='imagenette', choices=('imagenette','imagenet','cifar'),type=str,help="dataset")
26 | parser.add_argument('--data_dir', default='data', type=str,help="path to data")
27 |
28 | parser.add_argument('--skip', default=1,type=int, help='evaluate every skip-th image')
29 | parser.add_argument("--m",action='store_true',help="use robust masking")
30 | parser.add_argument("--ds",action='store_true',help="use derandomized smoothing")
31 |
32 | args = parser.parse_args()
33 |
34 | MODEL_DIR=os.path.join('.',args.model_dir)
35 | DATA_DIR=os.path.join(args.data_dir,args.dataset)
36 | DATASET = args.dataset
37 |
38 | device = 'cuda' #if torch.cuda.is_available() else 'cpu'
39 |
40 | cudnn.benchmark = True
41 |
42 | def get_dataset(ds,data_dir):
43 | if ds in ['imagenette','imagenet']:
44 | ds_dir=os.path.join(data_dir,'val')
45 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
46 | std=[0.229, 0.224, 0.225])
47 | dataset_ = datasets.ImageFolder(ds_dir, transforms.Compose([
48 |         transforms.Resize((299,299)), #note that the input size here is 299x299 instead of 224x224
49 | transforms.ToTensor(),
50 | normalize,
51 | ]))
52 | elif ds == 'cifar':
53 | transform_test = transforms.Compose([
54 | transforms.ToTensor(),
55 | #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
56 | ])
57 | dataset_ = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform_test)
58 | return dataset_,dataset_.classes
59 |
60 | val_dataset_,class_names = get_dataset(DATASET,DATA_DIR)
61 | skips = list(range(0, len(val_dataset_), args.skip))
62 | val_dataset = torch.utils.data.Subset(val_dataset_, skips)
63 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32,shuffle=False)
64 |
65 | num_cls = len(class_names)
66 |
67 | # Model
68 | print('==> Building model..')
69 |
70 |
71 |
72 | if DATASET == 'imagenette':
73 | net = resnet_imgnt.resnet50()
74 | net = torch.nn.DataParallel(net)
75 | num_ftrs = net.module.fc.in_features
76 | net.module.fc = nn.Linear(num_ftrs, num_cls)
77 | checkpoint = torch.load(os.path.join(MODEL_DIR,'ds_nette.pth'))
78 | args.band_size = args.band_size if args.band_size>0 else 25
79 | args.patch_size = args.patch_size if args.patch_size>0 else 42
80 | elif DATASET == 'imagenet':
81 | net = resnet_imgnt.resnet50()
82 | net = torch.nn.DataParallel(net)
83 | checkpoint = torch.load(os.path.join(MODEL_DIR,'ds_net.pth'))
84 | args.band_size = args.band_size if args.band_size>0 else 25
85 | args.patch_size = args.patch_size if args.patch_size>0 else 42
86 | elif DATASET == 'cifar':
87 | net = resnet_cifar.ResNet18()
88 | net = torch.nn.DataParallel(net)
89 | checkpoint = torch.load(os.path.join(MODEL_DIR,'ds_cifar.pth'))
90 | args.band_size = args.band_size if args.band_size>0 else 4
91 | args.patch_size = args.patch_size if args.patch_size>0 else 5
92 |
93 | print("band_size",args.band_size,"patch_size",args.patch_size)
94 |
95 |
96 | net.load_state_dict(checkpoint['net'])
97 |
98 | net = net.to(device)
99 | net.eval()
100 |
101 |
102 | if args.ds:#ds
103 | correct = 0
104 | cert_correct = 0
105 | cert_incorrect = 0
106 | total = 0
107 | with torch.no_grad():
108 | for inputs, targets in tqdm(val_loader):
109 | inputs, targets = inputs.to(device), targets.to(device)
110 | total += targets.size(0)
111 | predictions, certyn = ds(inputs, net,args.band_size, args.patch_size, num_cls,threshold = 0.2)
112 | correct += (predictions.eq(targets)).sum().item()
113 | cert_correct += (predictions.eq(targets) & certyn).sum().item()
114 | cert_incorrect += (~predictions.eq(targets) & certyn).sum().item()
115 | print('Results for Derandomized Smoothing')
116 |     print('Using band size ' + str(args.band_size) + ' with threshold ' + str(0.2))
117 | print('Certifying For Patch ' +str(args.patch_size) + '*'+str(args.patch_size))
118 | print('Total images: ' + str(total))
119 | print('Correct: ' + str(correct) + ' (' + str((100.*correct)/total)+'%)')
120 | print('Certified Correct class: ' + str(cert_correct) + ' (' + str((100.*cert_correct)/total)+'%)')
121 | print('Certified Wrong class: ' + str(cert_incorrect) + ' (' + str((100.*cert_incorrect)/total)+'%)')
122 |
123 | if args.m:#mask-ds
124 | result_list=[]
125 | clean_corr_list=[]
126 | with torch.no_grad():
127 | for inputs, targets in tqdm(val_loader):
128 | inputs = inputs.to(device)
129 | targets = targets.numpy()
130 | result,clean_corr = masking_ds(inputs,targets,net,args.band_size, args.patch_size,thres=args.thres)
131 | result_list+=result
132 | clean_corr_list+=clean_corr
133 |
134 | cases,cnt=np.unique(result_list,return_counts=True)
135 | print('Results for Mask-DS')
136 | print("Provable robust accuracy:",cnt[-1]/len(result_list) if len(cnt)==3 else 0)
137 | print("Clean accuracy with defense:",np.mean(clean_corr_list))
138 | print("------------------------------")
139 | print("Provable analysis cases (0: incorrect prediction; 1: vulnerable; 2: provably robust):",cases)
140 | print("Provable analysis breakdown:",cnt/len(result_list))
141 |
142 |
143 |
--------------------------------------------------------------------------------
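The `ds` and `masking_ds` routines used above are implemented in `utils/defense_utils.py` (not shown in this dump). For intuition, the derandomized-smoothing certificate is a counting argument: a patch of width p can intersect at most p + band_size − 1 bands, so a prediction is certifiable when the plurality class's vote count beats the runner-up's by more than twice that number. A hedged sketch of that check (names are illustrative, not the repository's API; the real implementation also handles ties and the per-band confidence threshold):

```python
import numpy as np

def ds_certified_sketch(vote_counts, band_size, patch_size):
    # A patch of width patch_size intersects at most this many bands:
    affected = patch_size + band_size - 1
    top2 = np.sort(vote_counts)[-2:]  # runner-up and plurality counts
    # Even if every affected band's vote moved from the plurality class
    # to the runner-up, the margin must survive; hence the factor of two.
    return top2[1] - top2[0] > 2 * affected

# e.g. 90 of 299 bands vote class 3, 10 vote class 5, patch 42, band 25:
counts = np.zeros(10); counts[3], counts[5] = 90, 10
print(ds_certified_sketch(counts, band_size=25, patch_size=42))  # False: margin 80 <= 2*66
```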
/det_bn.py:
--------------------------------------------------------------------------------
1 | # the code logic is the same as mask_bn.py
2 | # kept as a separate file to distinguish between PatchGuard and PatchGuard++
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | import torch.backends.cudnn as cudnn
8 | from torchvision import datasets, transforms
9 |
10 | import nets.bagnet
11 | import nets.resnet
12 | from utils.defense_utils import *
13 |
14 | import os
15 | import joblib
16 | import argparse
17 | from tqdm import tqdm
18 | import numpy as np
19 | from scipy.special import softmax
20 | from math import ceil
21 | import PIL
22 |
23 | parser = argparse.ArgumentParser()
24 |
25 | parser.add_argument("--model_dir",default='checkpoints',type=str,help="path to checkpoints")
26 | parser.add_argument('--data_dir', default='data', type=str,help="path to data")
27 | parser.add_argument('--dataset', default='imagenette', choices=('imagenette','imagenet','cifar'),type=str,help="dataset")
28 | parser.add_argument("--model",default='bagnet33',type=str,help="model name")
29 | parser.add_argument("--clip",default=-1,type=int,help="clipping value; do clipping when this argument is set to positive")
30 | parser.add_argument("--aggr",default='none',type=str,help="aggregation methods. set to none for local feature")
31 | parser.add_argument("--skip",default=1,type=int,help="number of example to skip")
32 | parser.add_argument("--thres",default=0.0,type=float,help="detection threshold for robust masking")
33 | parser.add_argument("--patch_size",default=-1,type=int,help="size of the adversarial patch")
34 | parser.add_argument("--det",action='store_true',help="use PG++ attack detection")
35 | parser.add_argument("--tau",default=0.0,type=float,help="tau")
36 |
37 | args = parser.parse_args()
38 |
39 | MODEL_DIR=os.path.join('.',args.model_dir)
40 | DATA_DIR=os.path.join(args.data_dir,args.dataset)
41 | DATASET = args.dataset
42 | def get_dataset(ds,data_dir):
43 | if ds in ['imagenette','imagenet']:
44 | ds_dir=os.path.join(data_dir,'val')
45 | ds_transforms = transforms.Compose([
46 | transforms.Resize(256),
47 | transforms.CenterCrop(224),
48 | transforms.ToTensor(),
49 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
50 | ])
51 | dataset_ = datasets.ImageFolder(ds_dir,ds_transforms)
52 | class_names = dataset_.classes
53 | elif ds == 'cifar':
54 | ds_transforms = transforms.Compose([
55 | transforms.Resize(192, interpolation=PIL.Image.BICUBIC),
56 | transforms.ToTensor(),
57 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
58 | ])
59 | dataset_ = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=ds_transforms)
60 | class_names = dataset_.classes
61 | return dataset_,class_names
62 |
63 | val_dataset_,class_names = get_dataset(DATASET,DATA_DIR)
64 | skips = list(range(0, len(val_dataset_), args.skip))
65 | val_dataset = torch.utils.data.Subset(val_dataset_, skips)
66 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8,shuffle=False)
67 |
68 | #build and initialize model
69 | device = 'cuda' #if torch.cuda.is_available() else 'cpu'
70 |
71 | if args.clip > 0:
72 | clip_range = [0,args.clip]
73 | else:
74 | clip_range = None
75 |
76 | if 'bagnet17' in args.model:
77 | model = nets.bagnet.bagnet17(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
78 | rf_size=17
79 | elif 'bagnet33' in args.model:
80 | model = nets.bagnet.bagnet33(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
81 | rf_size=33
82 | elif 'bagnet9' in args.model:
83 | model = nets.bagnet.bagnet9(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
84 | rf_size=9
85 |
86 |
87 | if DATASET == 'imagenette':
88 | num_ftrs = model.fc.in_features
89 | model.fc = nn.Linear(num_ftrs, len(class_names))
90 | model = torch.nn.DataParallel(model)
91 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_nette.pth'))
92 | model.load_state_dict(checkpoint['model_state_dict'])
93 | args.patch_size = args.patch_size if args.patch_size>0 else 32
94 | elif DATASET == 'imagenet':
95 | model = torch.nn.DataParallel(model)
96 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_net.pth'))
97 | model.load_state_dict(checkpoint['state_dict'])
98 | args.patch_size = args.patch_size if args.patch_size>0 else 32
99 | elif DATASET == 'cifar':
100 | num_ftrs = model.fc.in_features
101 | model.fc = nn.Linear(num_ftrs, len(class_names))
102 | model = torch.nn.DataParallel(model)
103 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_192_cifar.pth'))
104 | model.load_state_dict(checkpoint['net'])
105 | args.patch_size = args.patch_size if args.patch_size>0 else 30
106 |
107 |
108 | rf_stride=8
109 | window_size = ceil((args.patch_size + rf_size -1) / rf_stride) # number of feature cells (per axis) that a patch can influence
110 | print("window_size",window_size)
111 |
112 |
113 | model = model.to(device)
114 | model.eval()
115 | cudnn.benchmark = True
116 |
117 | accuracy_list=[]
118 | result_list=[]
119 | clean_corr=0
120 |
121 | for data,labels in tqdm(val_loader):
122 |
123 | data=data.to(device)
124 | labels = labels.numpy()
125 | output_clean = model(data).detach().cpu().numpy() # logits
126 | #output_clean = softmax(output_clean,axis=-1) # confidence
127 | #output_clean = (output_clean > 0.2).astype(float) # predictions with confidence threshold
128 |
129 |     #note: the provable analysis of robust masking is CPU-intensive and can take some time to finish
130 |     #you can dump the local features and do the provable analysis in another script so that GPU memory is not always occupied
131 | for i in range(len(labels)):
132 | if args.det:
133 | local_feature = output_clean[i]
134 | #result,clean_pred = provable_detection(local_feature,labels[i],tau=args.tau,window_shape=[window_size,window_size])
135 | #clean_corr += clean_pred
136 |
137 | clean_pred = pg2_detection(local_feature,tau=args.tau,window_shape=[window_size,window_size])
138 | clean_corr += clean_pred == labels[i]
139 |
140 | result = pg2_detection_provable(local_feature,labels[i],tau=args.tau,window_shape=[window_size,window_size])
141 | result_list.append(result)
142 |
143 | acc_clean = np.sum(np.argmax(np.mean(output_clean,axis=(1,2)),axis=1) == labels)
144 | accuracy_list.append(acc_clean)
145 |
146 |
147 | cases,cnt=np.unique(result_list,return_counts=True)
148 | print("Provable robust accuracy:",cnt[-1]/len(result_list) if len(cnt)==3 else 0)
149 | print("Clean accuracy with defense:",clean_corr/len(result_list))
150 | print("Clean accuracy without defense:",np.sum(accuracy_list)/len(val_dataset))
151 | print("------------------------------")
152 | print("Provable analysis cases (0: incorrect prediction; 1: vulnerable; 2: provably robust):",cases)
153 | print("Provable analysis breakdown",cnt/len(result_list))
--------------------------------------------------------------------------------
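A side note on the comment in the evaluation loop above: because the provable analysis is CPU-bound, the local features can be dumped to disk and analyzed in a separate CPU-only process so the GPU is released early. A minimal sketch of that split, reusing the script's own names (`model`, `val_loader`, `device`, `args.tau`, `window_size`, `pg2_detection_provable`); the dump file name is illustrative:

```python
import joblib
import numpy as np

# pass 1 (GPU): dump local features and labels, then free the GPU
feature_list, label_list = [], []
for data, labels in val_loader:
    feature_list.append(model(data.to(device)).detach().cpu().numpy())
    label_list.append(labels.numpy())
joblib.dump((np.concatenate(feature_list), np.concatenate(label_list)), 'local_features.z')

# pass 2 (CPU-only process): run the provable analysis on the dumped features
features, labels = joblib.load('local_features.z')
result_list = [pg2_detection_provable(f, l, tau=args.tau,
                                      window_shape=[window_size, window_size])
               for f, l in zip(features, labels)]
```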
/mask_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | from torchvision import datasets, transforms
7 |
8 | import nets.bagnet
9 | import nets.resnet
10 | from utils.defense_utils import *
11 |
12 | import os
13 | import joblib
14 | import argparse
15 | from tqdm import tqdm
16 | import numpy as np
17 | from scipy.special import softmax
18 | from math import ceil
19 | import PIL
20 |
21 | parser = argparse.ArgumentParser()
22 |
23 | parser.add_argument("--model_dir",default='checkpoints',type=str,help="path to checkpoints")
24 | parser.add_argument('--data_dir', default='data', type=str,help="path to data")
25 | parser.add_argument('--dataset', default='imagenette', choices=('imagenette','imagenet','cifar'),type=str,help="dataset")
26 | parser.add_argument("--model",default='bagnet17',type=str,help="model name")
27 | parser.add_argument("--clip",default=-1,type=int,help="clipping value; do clipping when this argument is set to positive")
28 | parser.add_argument("--aggr",default='none',type=str,help="aggregation methods. set to none for local feature")
29 | parser.add_argument("--skip",default=1,type=int,help="number of example to skip")
30 | parser.add_argument("--thres",default=0.0,type=float,help="detection threshold for robust masking")
31 | parser.add_argument("--patch_size",default=-1,type=int,help="size of the adversarial patch")
32 | parser.add_argument("--m",action='store_true',help="use robust masking")
33 | parser.add_argument("--cbn",action='store_true',help="use cbn")
34 |
35 | args = parser.parse_args()
36 |
37 | MODEL_DIR=os.path.join('.',args.model_dir)
38 | DATA_DIR=os.path.join(args.data_dir,args.dataset)
39 | DATASET = args.dataset
40 | def get_dataset(ds,data_dir):
41 | if ds in ['imagenette','imagenet']:
42 | ds_dir=os.path.join(data_dir,'val')
43 | ds_transforms = transforms.Compose([
44 | transforms.Resize(256),
45 | transforms.CenterCrop(224),
46 | transforms.ToTensor(),
47 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
48 | ])
49 | dataset_ = datasets.ImageFolder(ds_dir,ds_transforms)
50 | class_names = dataset_.classes
51 | elif ds == 'cifar':
52 | ds_transforms = transforms.Compose([
53 | transforms.Resize(192, interpolation=PIL.Image.BICUBIC),
54 | transforms.ToTensor(),
55 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
56 | ])
57 | dataset_ = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=ds_transforms)
58 | class_names = dataset_.classes
59 | return dataset_,class_names
60 |
61 | val_dataset_,class_names = get_dataset(DATASET,DATA_DIR)
62 | skips = list(range(0, len(val_dataset_), args.skip))
63 | val_dataset = torch.utils.data.Subset(val_dataset_, skips)
64 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8,shuffle=False)
65 |
66 | #build and initialize model
67 | device = 'cuda' #if torch.cuda.is_available() else 'cpu'
68 |
69 | if args.clip > 0:
70 | clip_range = [0,args.clip]
71 | else:
72 | clip_range = None
73 |
74 | if 'bagnet17' in args.model:
75 | model = nets.bagnet.bagnet17(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
76 | rf_size=17
77 | elif 'bagnet33' in args.model:
78 | model = nets.bagnet.bagnet33(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
79 | rf_size=33
80 | elif 'bagnet9' in args.model:
81 | model = nets.bagnet.bagnet9(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
82 | rf_size=9
83 |
84 |
85 | if DATASET == 'imagenette':
86 | num_ftrs = model.fc.in_features
87 | model.fc = nn.Linear(num_ftrs, len(class_names))
88 | model = torch.nn.DataParallel(model)
89 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_nette.pth'))
90 | model.load_state_dict(checkpoint['model_state_dict'])
91 | args.patch_size = args.patch_size if args.patch_size>0 else 32
92 | elif DATASET == 'imagenet':
93 | model = torch.nn.DataParallel(model)
94 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_net.pth'))
95 | model.load_state_dict(checkpoint['state_dict'])
96 | args.patch_size = args.patch_size if args.patch_size>0 else 32
97 | elif DATASET == 'cifar':
98 | num_ftrs = model.fc.in_features
99 | model.fc = nn.Linear(num_ftrs, len(class_names))
100 | model = torch.nn.DataParallel(model)
101 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_192_cifar.pth'))
102 | model.load_state_dict(checkpoint['net'])
103 | args.patch_size = args.patch_size if args.patch_size>0 else 30
104 |
105 |
106 | rf_stride=8
107 | window_size = ceil((args.patch_size + rf_size -1) / rf_stride)
108 | print("window_size",window_size)
109 |
110 |
111 | model = model.to(device)
112 | model.eval()
113 | cudnn.benchmark = True
114 |
115 | accuracy_list=[]
116 | result_list=[]
117 | clean_corr=0
118 |
119 | for data,labels in tqdm(val_loader):
120 |
121 | data=data.to(device)
122 | labels = labels.numpy()
123 | output_clean = model(data).detach().cpu().numpy() # logits
124 | #output_clean = softmax(output_clean,axis=-1) # confidence
125 | #output_clean = (output_clean > 0.2).astype(float) # predictions with confidence threshold
126 |
127 | #note: the provable analysis of robust masking is CPU-intensive and can take some time to finish
128 | #you can dump the local features and run the provable analysis in a separate script so that GPU memory is not occupied the whole time
129 | for i in range(len(labels)):
130 | if args.m:#robust masking
131 | local_feature = output_clean[i]
132 | result = provable_masking(local_feature,labels[i],thres=args.thres,window_shape=[window_size,window_size])
133 | result_list.append(result)
134 | clean_pred = masking_defense(local_feature,thres=args.thres,window_shape=[window_size,window_size])
135 | clean_corr += clean_pred == labels[i]
136 |
137 | elif args.cbn:#cbn
138 | # note that the CBN results reported in the paper are obtained with a vanilla BagNet (without provable adversarial training),
139 | # since provable adversarial training is proposed in our paper. We find that our training technique also benefits CBN
140 | result = provable_clipping(output_clean[i],labels[i],window_shape=[window_size,window_size])
141 | result_list.append(result)
142 | clean_pred = clipping_defense(output_clean[i])
143 | clean_corr += clean_pred == labels[i]
144 | acc_clean = np.sum(np.argmax(np.mean(output_clean,axis=(1,2)),axis=1) == labels)
145 | accuracy_list.append(acc_clean)
146 |
147 |
148 | cases,cnt=np.unique(result_list,return_counts=True)
149 | print("Provable robust accuracy:",cnt[-1]/len(result_list) if len(cnt)==3 else 0)
150 | print("Clean accuracy with defense:",clean_corr/len(result_list))
151 | print("Clean accuracy without defense:",np.sum(accuracy_list)/len(val_dataset))
152 | print("------------------------------")
153 | print("Provable analysis cases (0: incorrect prediction; 1: vulnerable; 2: provably robust):",cases)
154 | print("Provable analysis breakdown",cnt/len(result_list))
--------------------------------------------------------------------------------
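As a sanity check on the `window_size` formula shared by these scripts: a `patch_size`-pixel patch can intersect the receptive fields of at most `ceil((patch_size + rf_size - 1) / rf_stride)` consecutive feature-map locations along each axis. A worked example with the defaults above (bagnet17 on ImageNet/ImageNette):

```python
from math import ceil

rf_size, rf_stride = 17, 8   # bagnet17 receptive field and feature-map stride
patch_size = 32              # default patch size for imagenet/imagenette above
window_size = ceil((patch_size + rf_size - 1) / rf_stride)
assert window_size == 6      # ceil(48 / 8): the patch touches at most a 6x6 feature window
```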
/misc/train_imagenette.py:
--------------------------------------------------------------------------------
1 | #######################################################################################
2 | # Adapted from https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
3 | # Used for training models on ImageNette
4 | #######################################################################################
5 |
6 | from __future__ import print_function, division
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | from torch.optim import lr_scheduler
12 | import numpy as np
13 | import torchvision
14 | from torchvision import datasets, models, transforms
15 | import time
16 | import os
17 | import copy
18 | from tqdm import tqdm
19 | import random
20 | import nets.bagnet
21 | import nets.resnet
22 | import argparse
23 | from utils.cutout import Cutout
24 |
25 |
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument("--model_dir",default='checkpoints',type=str)
28 | parser.add_argument("--data_dir",default='data/imagenette',type=str)
29 | parser.add_argument("--model_name",default='bagnet17_nette.pth',type=str)
30 | parser.add_argument("--clip",default=-1,type=int)
31 | parser.add_argument("--epoch",default=20,type=int)
32 | parser.add_argument("--cutout_size",default=31,type=int)
33 | parser.add_argument("--aggr",default='adv',type=str)
34 | parser.add_argument("--resume",action='store_true')
35 | parser.add_argument("--cutout",action='store_true',help="use CUTOUT during the training")
36 | parser.add_argument("--fc",action='store_true',help="only retrain the fully-connected layer")
37 | args = parser.parse_args()
38 |
39 | MODEL_DIR=os.path.join('.',args.model_dir)
40 | DATA_DIR=os.path.join(args.data_dir)
41 |
42 | if not os.path.exists(MODEL_DIR):
43 | os.mkdir(MODEL_DIR)
44 |
45 | mean_vec=[0.485, 0.456, 0.406]
46 | std_vec=[0.229, 0.224, 0.225]
47 |
48 | data_transforms = {
49 | 'train': transforms.Compose([
50 | transforms.RandomResizedCrop(224),
51 | transforms.RandomHorizontalFlip(),
52 | transforms.ToTensor(),
53 | transforms.Normalize(mean_vec, std_vec)
54 | ]),
55 | 'val': transforms.Compose([
56 | transforms.Resize(256),
57 | transforms.CenterCrop(224),
58 | transforms.ToTensor(),
59 | transforms.Normalize(mean_vec,std_vec)
60 | ]),
61 | }
62 |
63 | if args.cutout:
64 | data_transforms['train'].transforms.append(Cutout(n_holes=1, length=args.cutout_size))
65 |
66 | train_dir=os.path.join(DATA_DIR,'train')
67 | val_dir=os.path.join(DATA_DIR,'val')
68 |
69 | train_dataset = datasets.ImageFolder(train_dir,data_transforms['train'])
70 | val_dataset = datasets.ImageFolder(val_dir,data_transforms['val'])
71 |
72 | print('train_dataset.size',len(train_dataset.samples))
73 | print('val_dataset.size',len(val_dataset.samples))
74 | image_datasets = {'train':train_dataset,'val':val_dataset}
75 | dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
76 | class_names = image_datasets['train'].classes
77 | print('class_names:',class_names)
78 |
79 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64,shuffle=True)
80 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64,shuffle=False)
81 |
82 | dataloaders={'train':train_loader,'val':val_loader}
83 |
84 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
85 |
86 | print('device:',device)
87 |
88 | def train_model(model, criterion, optimizer, scheduler, num_epochs=20, mask=False):
89 |
90 | since = time.time()
91 |
92 | best_model_wts = copy.deepcopy(model.state_dict())
93 | best_acc = 0.0
94 |
95 | for epoch in tqdm(range(num_epochs)):
96 | print('Epoch {}/{}'.format(epoch, num_epochs - 1))
97 | print('-' * 10)
98 |
99 | # Each epoch has a training and validation phase
100 | for phase in ['train', 'val']:
101 | if phase == 'train':
102 | model.train() # Set model to training mode
103 | else:
104 | model.eval() # Set model to evaluate mode
105 |
106 | running_loss = 0.0
107 | running_corrects = 0
108 |
109 | # Iterate over data.
110 | for inputs, labels in dataloaders[phase]:
111 | inputs = inputs.to(device)
112 | labels = labels.to(device)
113 |
114 | # zero the parameter gradients
115 | optimizer.zero_grad()
116 |
117 | # forward
118 | # track history if only in train
119 | with torch.set_grad_enabled(phase == 'train'):
120 | outputs = model(inputs,labels) # labels are consumed by the 'adv' aggregation (see nets/bagnet.py)
121 | _, preds = torch.max(outputs, 1)
122 | loss = criterion(outputs, labels)
123 |
124 | # backward + optimize only if in training phase
125 | if phase == 'train':
126 | loss.backward()
127 | optimizer.step()
128 |
129 | # statistics
130 | running_loss += loss.item() * inputs.size(0)
131 | running_corrects += torch.sum(preds == labels.data)
132 | if phase == 'train':
133 | scheduler.step()
134 |
135 | epoch_loss = running_loss / dataset_sizes[phase]
136 | epoch_acc = running_corrects.double() / dataset_sizes[phase]
137 |
138 | print('{} Loss: {:.4f} Acc: {:.4f}'.format(
139 | phase, epoch_loss, epoch_acc))
140 |
141 | # deep copy the model
142 | if phase == 'val': #and epoch_acc > best_acc: (best-acc check disabled; a checkpoint is saved after every epoch)
143 | best_acc = epoch_acc
144 | best_model_wts = copy.deepcopy(model.state_dict())
145 | print('saving...')
146 | torch.save({
147 | 'epoch': epoch,
148 | 'model_state_dict': best_model_wts,
149 | 'optimizer_state_dict': optimizer.state_dict(),
150 | 'scheduler_state_dict':scheduler.state_dict()
151 | }, os.path.join(MODEL_DIR,args.model_name))
152 |
153 | print()
154 |
155 | time_elapsed = time.time() - since
156 | print('Training complete in {:.0f}m {:.0f}s'.format(
157 | time_elapsed // 60, time_elapsed % 60))
158 | print('Best val Acc: {:.4f}'.format(best_acc))
159 |
160 | # load best model weights
161 | model.load_state_dict(best_model_wts)
162 | return model
163 |
164 | if args.clip > 0:
165 | clip_range = [0,args.clip]
166 | else:
167 | clip_range = None
168 |
169 | if 'bagnet17' in args.model_name:
170 | model_conv = nets.bagnet.bagnet17(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
171 | elif 'bagnet33' in args.model_name:
172 | model_conv = nets.bagnet.bagnet33(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
173 | elif 'bagnet9' in args.model_name:
174 | model_conv = nets.bagnet.bagnet9(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
175 | elif 'resnet50' in args.model_name:
176 | model_conv = nets.resnet.resnet50(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
177 |
178 | if args.fc: #only retrain the fully-connected layer
179 | for param in model_conv.parameters():
180 | param.requires_grad = False
181 |
182 | # Parameters of newly constructed modules have requires_grad=True by default
183 | num_ftrs = model_conv.fc.in_features
184 | model_conv.fc = nn.Linear(num_ftrs, len(class_names))
185 | model_conv = torch.nn.DataParallel(model_conv)
186 | model_conv = model_conv.to(device)
187 | criterion = nn.CrossEntropyLoss()
188 |
189 | if args.fc:
190 | optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
191 | else:
192 | optimizer_conv = optim.SGD(model_conv.parameters(), lr=0.001, momentum=0.9)
193 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
194 | #print(optimizer_conv.state_dict())
195 | #https://pytorch.org/tutorials/beginner/saving_loading_models.html
196 | if args.resume:
197 | print('restoring model from checkpoint...')
198 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model_name))
199 | model_conv.load_state_dict(checkpoint['model_state_dict'])
200 | model_conv = model_conv.to(device)
201 | #https://discuss.pytorch.org/t/code-that-loads-sgd-fails-to-load-adam-state-to-gpu/61783/3
202 | optimizer_conv.load_state_dict(checkpoint['optimizer_state_dict'])
203 | exp_lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
204 | #print(checkpoint['optimizer_state_dict'])
205 | #print(checkpoint['scheduler_state_dict'])
206 |
207 |
208 | model_conv = train_model(model_conv, criterion, optimizer_conv,
209 | exp_lr_scheduler, num_epochs=args.epoch)
210 |
211 |
--------------------------------------------------------------------------------
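`--cutout` appends the Cutout augmentation from `utils/cutout.py` (not reproduced in this listing) to the training transforms. For intuition, standard Cutout (DeVries & Taylor, 2017) zeroes a random square of the image tensor; a simplified sketch for reference, not the repo's implementation:

```python
import torch

def cutout(img: torch.Tensor, length: int) -> torch.Tensor:
    """Zero one random length x length square in a CxHxW tensor (simplified Cutout)."""
    _, h, w = img.shape
    cy = int(torch.randint(h, (1,)))
    cx = int(torch.randint(w, (1,)))
    y1, y2 = max(0, cy - length // 2), min(h, cy + length // 2)
    x1, x2 = max(0, cx - length // 2), min(w, cx + length // 2)
    out = img.clone()
    out[:, y1:y2, x1:x2] = 0.0
    return out
```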
/nets/bagnet.py:
--------------------------------------------------------------------------------
1 | #################################################################################################################
2 | # Adapted from https://github.com/wielandbrendel/bag-of-local-features-models/blob/master/bagnets/pytorchnet.py #
3 | # Mainly changed the model forward() function #
4 | #################################################################################################################
5 |
6 |
7 | import torch.nn as nn
8 | import math
9 | import random
10 | import torch
11 | from collections import OrderedDict
12 | from torch.utils import model_zoo
13 | import numpy as np
14 | import os
15 | dir_path = os.path.dirname(os.path.realpath(__file__))
16 |
17 | __all__ = ['bagnet9', 'bagnet17', 'bagnet33']
18 |
19 | model_urls = {
20 | 'bagnet9': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet8-34f4ccd2.pth.tar',
21 | 'bagnet17': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet16-105524de.pth.tar',
22 | 'bagnet33': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet32-2ddd53ed.pth.tar',
23 | }
24 |
25 |
26 | class Bottleneck(nn.Module):
27 | expansion = 4
28 |
29 | def __init__(self, inplanes, planes, stride=1, downsample=None, kernel_size=1):
30 | super(Bottleneck, self).__init__()
31 | # #print('Creating bottleneck with kernel size {} and stride {} with padding {}'.format(kernel_size, stride, (kernel_size - 1) // 2))
32 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
33 | self.bn1 = nn.BatchNorm2d(planes)
34 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, stride=stride,
35 | padding=0, bias=False) # changed padding from (kernel_size - 1) // 2
36 | self.bn2 = nn.BatchNorm2d(planes)
37 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
38 | self.bn3 = nn.BatchNorm2d(planes * 4)
39 | self.relu = nn.ReLU(inplace=True)
40 | self.downsample = downsample
41 | self.stride = stride
42 |
43 |
44 | def forward(self, x, **kwargs):
45 | residual = x
46 |
47 | out = self.conv1(x)
48 | out = self.bn1(out)
49 | out = self.relu(out)
50 |
51 | out = self.conv2(out)
52 | out = self.bn2(out)
53 | out = self.relu(out)
54 |
55 | out = self.conv3(out)
56 | out = self.bn3(out)
57 |
58 | if self.downsample is not None:
59 | residual = self.downsample(x)
60 |
61 | if residual.size(-1) != out.size(-1):
62 | diff = residual.size(-1) - out.size(-1)
63 | residual = residual[:,:,:-diff,:-diff]
64 |
65 | out += residual
66 | out = self.relu(out)
67 |
68 | return out
69 |
70 |
71 | class BagNet(nn.Module):
72 |
73 | def __init__(self, block, layers, strides=[1, 2, 2, 2], kernel3=[0, 0, 0, 0], num_classes=1000,clip_range=None,aggregation='mean'):
74 | self.inplanes = 64
75 | super(BagNet, self).__init__()
76 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, stride=1, padding=0,
77 | bias=False)
78 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0,
79 | bias=False)
80 | self.bn1 = nn.BatchNorm2d(64, momentum=0.001)
81 | self.relu = nn.ReLU(inplace=True)
82 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], kernel3=kernel3[0], prefix='layer1')
83 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], kernel3=kernel3[1], prefix='layer2')
84 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], kernel3=kernel3[2], prefix='layer3')
85 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], kernel3=kernel3[3], prefix='layer4')
86 | self.avgpool = nn.AvgPool2d(1, stride=1)
87 | self.fc = nn.Linear(512 * block.expansion, num_classes)
88 | self.block = block
89 |
90 | self.clip_range = clip_range
91 | self.aggregation = aggregation
92 |
93 | for m in self.modules():
94 | if isinstance(m, nn.Conv2d):
95 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
96 | m.weight.data.normal_(0, math.sqrt(2. / n))
97 | elif isinstance(m, nn.BatchNorm2d):
98 | m.weight.data.fill_(1)
99 | m.bias.data.zero_()
100 |
101 | def _make_layer(self, block, planes, blocks, stride=1, kernel3=0, prefix=''):
102 | downsample = None
103 | if stride != 1 or self.inplanes != planes * block.expansion:
104 |
105 | downsample = nn.Sequential(
106 | nn.Conv2d(self.inplanes, planes * block.expansion,
107 | kernel_size=1, stride=stride, bias=False),
108 | nn.BatchNorm2d(planes * block.expansion),
109 | )
110 |
111 | layers = []
112 | kernel = 1 if kernel3 == 0 else 3
113 |
114 | layers.append(block(self.inplanes, planes, stride, downsample, kernel_size=kernel))
115 | self.inplanes = planes * block.expansion
116 | for i in range(1, blocks):
117 | kernel = 1 if kernel3 <= i else 3
118 |
119 | layers.append(block(self.inplanes, planes, kernel_size=kernel))
120 |
121 | return nn.Sequential(*layers)
122 |
123 | def forward(self, x,y=None):
124 | x = self.conv1(x)
125 | x = self.conv2(x)
126 | x = self.bn1(x)
127 | x = self.relu(x)
128 | x = self.layer1(x)
129 | x = self.layer2(x)
130 | x = self.layer3(x)
131 | x = self.layer4(x)
132 |
133 | x = x.permute(0,2,3,1)
134 |
135 | x = self.fc(x)
136 | if self.clip_range is not None:
137 | x = torch.clamp(x,self.clip_range[0],self.clip_range[1])
138 | if self.aggregation == 'mean':
139 | x = torch.mean(x,dim=(1,2))
140 | elif self.aggregation == 'median':
141 | x = x.view([x.size()[0],-1,10]) # note: hard-codes 10 classes (ImageNette/CIFAR-10)
142 | x = torch.median(x,dim=1)
143 | return x.values
144 | elif self.aggregation =='cbn':#clipped BagNet
145 | x = torch.tanh(x*0.05-1)
146 | x = torch.mean(x,dim=(1,2))
147 | elif self.aggregation == 'adv':# provable adversarial training
148 | window_size = 6 # the size of window to be masked during the training
149 | B,W,H,C = x.size()
150 | x = torch.clamp(x,min=0) # clip local logits to be non-negative
151 | tmp = x[torch.arange(B),:,:,y] #the feature map for the true class
152 | tmp = tmp.unfold(1,window_size,1).unfold(2,window_size,1) # enumerate all sliding windows
153 | tmp = tmp.reshape([B,-1,window_size,window_size]) # [B,num_window,window_size,window_size]
154 | tmp = torch.sum(tmp,axis=(-2,-1)) # [B,num_window] true class evidence in every window
155 | tmp = torch.max(tmp,axis=-1).values # [B] max window class evidence
156 | x = torch.sum(x,dim=(1,2)) # sum local logits over all spatial locations
157 | x[torch.arange(B),y]-=tmp # subtract the max true-class window evidence
158 | x/=(W*H)
159 | elif self.aggregation == 'none':
160 | pass
161 |
162 | return x
163 |
164 | def bagnet33(pretrained=False, strides=[2, 2, 2, 1], **kwargs):
165 | """Constructs a Bagnet-33 model.
166 |
167 | Args:
168 | pretrained (bool): If True, returns a model pre-trained on ImageNet
169 | """
170 | model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,1,1], **kwargs)
171 | if pretrained:
172 | model.load_state_dict(model_zoo.load_url(model_urls['bagnet33']))
173 | return model
174 |
175 | def bagnet17(pretrained=False, strides=[2, 2, 2, 1], **kwargs):
176 | """Constructs a Bagnet-17 model.
177 |
178 | Args:
179 | pretrained (bool): If True, returns a model pre-trained on ImageNet
180 | """
181 | model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,1,0], **kwargs)
182 | if pretrained:
183 | model.load_state_dict(model_zoo.load_url(model_urls['bagnet17']))
184 | return model
185 |
186 | def bagnet9(pretrained=False, strides=[2, 2, 2, 1], **kwargs):
187 | """Constructs a Bagnet-9 model.
188 |
189 | Args:
190 | pretrained (bool): If True, returns a model pre-trained on ImageNet
191 | """
192 | model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,0,0], **kwargs)
193 | #model = BagNet(Bottleneck, [2,2,2,2], strides=strides, kernel3=[1,1,0,0], **kwargs)
194 | if pretrained:
195 | model.load_state_dict(model_zoo.load_url(model_urls['bagnet9']))
196 | return model
197 |
--------------------------------------------------------------------------------
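The `aggregation == 'adv'` branch above implements the provable adversarial training objective: clip local logits to be non-negative, find the single window contributing the most true-class evidence, and subtract it before averaging, so the model learns not to concentrate class evidence where a patch could mask it. A self-contained toy run of the same tensor operations (shapes are illustrative):

```python
import torch

B, W, H, C, win = 2, 24, 24, 10, 6
x = torch.rand(B, W, H, C)                            # local logits in [B, W, H, C] layout
y = torch.tensor([3, 7])                              # true class per example
x = torch.clamp(x, min=0)                             # clip, as in the forward pass
tmp = x[torch.arange(B), :, :, y]                     # [B, W, H] true-class feature maps
tmp = tmp.unfold(1, win, 1).unfold(2, win, 1)         # all win x win sliding windows
tmp = tmp.reshape(B, -1, win, win).sum(dim=(-2, -1))  # [B, num_windows] evidence per window
tmp = tmp.max(dim=-1).values                          # [B] strongest true-class window
logits = x.sum(dim=(1, 2))                            # [B, C] total evidence per class
logits[torch.arange(B), y] -= tmp                     # remove that window's contribution
logits /= W * H
print(logits.shape)                                   # torch.Size([2, 10])
```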
/nets/dsresnet_imgnt.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # from https://github.com/alevine0/patchSmoothing/blob/master/resnet_imgnt.py
3 | ###############################################################################
4 | import torch
5 | import torch.nn as nn
6 | #from .utils import load_state_dict_from_url
7 |
8 |
9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
11 | 'wide_resnet50_2', 'wide_resnet101_2']
12 |
13 |
14 | model_urls = {
15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
20 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
21 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
22 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
23 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
24 | }
25 |
26 |
27 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
28 | """3x3 convolution with padding"""
29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
30 | padding=dilation, groups=groups, bias=False, dilation=dilation)
31 |
32 |
33 | def conv1x1(in_planes, out_planes, stride=1):
34 | """1x1 convolution"""
35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
36 |
37 |
38 | class BasicBlock(nn.Module):
39 | expansion = 1
40 | __constants__ = ['downsample']
41 |
42 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
43 | base_width=64, dilation=1, norm_layer=None):
44 | super(BasicBlock, self).__init__()
45 | if norm_layer is None:
46 | norm_layer = nn.BatchNorm2d
47 | if groups != 1 or base_width != 64:
48 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
49 | if dilation > 1:
50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
52 | self.conv1 = conv3x3(inplanes, planes, stride)
53 | self.bn1 = norm_layer(planes)
54 | self.relu = nn.ReLU(inplace=True)
55 | self.conv2 = conv3x3(planes, planes)
56 | self.bn2 = norm_layer(planes)
57 | self.downsample = downsample
58 | self.stride = stride
59 |
60 | def forward(self, x):
61 | identity = x
62 |
63 | out = self.conv1(x)
64 | out = self.bn1(out)
65 | out = self.relu(out)
66 |
67 | out = self.conv2(out)
68 | out = self.bn2(out)
69 |
70 | if self.downsample is not None:
71 | identity = self.downsample(x)
72 |
73 | out += identity
74 | out = self.relu(out)
75 |
76 | return out
77 |
78 |
79 | class Bottleneck(nn.Module):
80 | expansion = 4
81 | __constants__ = ['downsample']
82 |
83 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
84 | base_width=64, dilation=1, norm_layer=None):
85 | super(Bottleneck, self).__init__()
86 | if norm_layer is None:
87 | norm_layer = nn.BatchNorm2d
88 | width = int(planes * (base_width / 64.)) * groups
89 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
90 | self.conv1 = conv1x1(inplanes, width)
91 | self.bn1 = norm_layer(width)
92 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
93 | self.bn2 = norm_layer(width)
94 | self.conv3 = conv1x1(width, planes * self.expansion)
95 | self.bn3 = norm_layer(planes * self.expansion)
96 | self.relu = nn.ReLU(inplace=True)
97 | self.downsample = downsample
98 | self.stride = stride
99 |
100 | def forward(self, x):
101 | identity = x
102 |
103 | out = self.conv1(x)
104 | out = self.bn1(out)
105 | out = self.relu(out)
106 |
107 | out = self.conv2(out)
108 | out = self.bn2(out)
109 | out = self.relu(out)
110 |
111 | out = self.conv3(out)
112 | out = self.bn3(out)
113 |
114 | if self.downsample is not None:
115 | identity = self.downsample(x)
116 |
117 | out += identity
118 | out = self.relu(out)
119 |
120 | return out
121 |
122 |
123 | class ResNet(nn.Module):
124 |
125 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
126 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
127 | norm_layer=None):
128 | super(ResNet, self).__init__()
129 | if norm_layer is None:
130 | norm_layer = nn.BatchNorm2d
131 | self._norm_layer = norm_layer
132 |
133 | self.inplanes = 64
134 | self.dilation = 1
135 | if replace_stride_with_dilation is None:
136 | # each element in the tuple indicates if we should replace
137 | # the 2x2 stride with a dilated convolution instead
138 | replace_stride_with_dilation = [False, False, False]
139 | if len(replace_stride_with_dilation) != 3:
140 | raise ValueError("replace_stride_with_dilation should be None "
141 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
142 | self.groups = groups
143 | self.base_width = width_per_group
144 | self.conv1 = nn.Conv2d(6, self.inplanes, kernel_size=7, stride=2, padding=3,
145 | bias=False) # 6 input channels (not 3): matches the ablated-image encoding of derandomized smoothing
146 | self.bn1 = norm_layer(self.inplanes)
147 | self.relu = nn.ReLU(inplace=True)
148 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
149 | self.layer1 = self._make_layer(block, 64, layers[0])
150 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
151 | dilate=replace_stride_with_dilation[0])
152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
153 | dilate=replace_stride_with_dilation[1])
154 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
155 | dilate=replace_stride_with_dilation[2])
156 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
157 | self.fc = nn.Linear(512 * block.expansion, num_classes)
158 |
159 | for m in self.modules():
160 | if isinstance(m, nn.Conv2d):
161 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
162 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
163 | nn.init.constant_(m.weight, 1)
164 | nn.init.constant_(m.bias, 0)
165 |
166 | # Zero-initialize the last BN in each residual branch,
167 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
168 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
169 | if zero_init_residual:
170 | for m in self.modules():
171 | if isinstance(m, Bottleneck):
172 | nn.init.constant_(m.bn3.weight, 0)
173 | elif isinstance(m, BasicBlock):
174 | nn.init.constant_(m.bn2.weight, 0)
175 |
176 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
177 | norm_layer = self._norm_layer
178 | downsample = None
179 | previous_dilation = self.dilation
180 | if dilate:
181 | self.dilation *= stride
182 | stride = 1
183 | if stride != 1 or self.inplanes != planes * block.expansion:
184 | downsample = nn.Sequential(
185 | conv1x1(self.inplanes, planes * block.expansion, stride),
186 | norm_layer(planes * block.expansion),
187 | )
188 |
189 | layers = []
190 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
191 | self.base_width, previous_dilation, norm_layer))
192 | self.inplanes = planes * block.expansion
193 | for _ in range(1, blocks):
194 | layers.append(block(self.inplanes, planes, groups=self.groups,
195 | base_width=self.base_width, dilation=self.dilation,
196 | norm_layer=norm_layer))
197 |
198 | return nn.Sequential(*layers)
199 |
200 | def forward(self, x):
201 | x = self.conv1(x)
202 | x = self.bn1(x)
203 | x = self.relu(x)
204 | x = self.maxpool(x)
205 |
206 | x = self.layer1(x)
207 | x = self.layer2(x)
208 | x = self.layer3(x)
209 | x = self.layer4(x)
210 |
211 | x = self.avgpool(x)
212 | x = torch.flatten(x, 1)
213 | x = self.fc(x)
214 |
215 | return x
216 |
217 |
218 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
219 | model = ResNet(block, layers, **kwargs)
220 | # if pretrained: (disabled: torchvision's 3-channel conv1 weights do not match the 6-channel conv1 here)
221 | # state_dict = load_state_dict_from_url(model_urls[arch],
222 | # progress=progress)
223 | # model.load_state_dict(state_dict)
224 | return model
225 |
226 |
227 | def resnet18(pretrained=False, progress=True, **kwargs):
228 | r"""ResNet-18 model from
229 | `"Deep Residual Learning for Image Recognition" `_
230 |
231 | Args:
232 | pretrained (bool): If True, returns a model pre-trained on ImageNet
233 | progress (bool): If True, displays a progress bar of the download to stderr
234 | """
235 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
236 | **kwargs)
237 |
238 |
239 | def resnet34(pretrained=False, progress=True, **kwargs):
240 | r"""ResNet-34 model from
241 | `"Deep Residual Learning for Image Recognition" `_
242 |
243 | Args:
244 | pretrained (bool): If True, returns a model pre-trained on ImageNet
245 | progress (bool): If True, displays a progress bar of the download to stderr
246 | """
247 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
248 | **kwargs)
249 |
250 |
251 | def resnet50(pretrained=False, progress=True, **kwargs):
252 | r"""ResNet-50 model from
253 | `"Deep Residual Learning for Image Recognition" `_
254 |
255 | Args:
256 | pretrained (bool): If True, returns a model pre-trained on ImageNet
257 | progress (bool): If True, displays a progress bar of the download to stderr
258 | """
259 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
260 | **kwargs)
261 |
262 |
263 | def resnet101(pretrained=False, progress=True, **kwargs):
264 | r"""ResNet-101 model from
265 | `"Deep Residual Learning for Image Recognition" `_
266 |
267 | Args:
268 | pretrained (bool): If True, returns a model pre-trained on ImageNet
269 | progress (bool): If True, displays a progress bar of the download to stderr
270 | """
271 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
272 | **kwargs)
273 |
274 |
275 | def resnet152(pretrained=False, progress=True, **kwargs):
276 | r"""ResNet-152 model from
277 | `"Deep Residual Learning for Image Recognition" `_
278 |
279 | Args:
280 | pretrained (bool): If True, returns a model pre-trained on ImageNet
281 | progress (bool): If True, displays a progress bar of the download to stderr
282 | """
283 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
284 | **kwargs)
285 |
286 |
287 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
288 | r"""ResNeXt-50 32x4d model from
289 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
290 |
291 | Args:
292 | pretrained (bool): If True, returns a model pre-trained on ImageNet
293 | progress (bool): If True, displays a progress bar of the download to stderr
294 | """
295 | kwargs['groups'] = 32
296 | kwargs['width_per_group'] = 4
297 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
298 | pretrained, progress, **kwargs)
299 |
300 |
301 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
302 | r"""ResNeXt-101 32x8d model from
303 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
304 |
305 | Args:
306 | pretrained (bool): If True, returns a model pre-trained on ImageNet
307 | progress (bool): If True, displays a progress bar of the download to stderr
308 | """
309 | kwargs['groups'] = 32
310 | kwargs['width_per_group'] = 8
311 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
312 | pretrained, progress, **kwargs)
313 |
314 |
315 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
316 | r"""Wide ResNet-50-2 model from
317 | `"Wide Residual Networks" `_
318 |
319 | The model is the same as ResNet except for the bottleneck number of channels
320 | which is twice larger in every block. The number of channels in outer 1x1
321 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
322 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
323 |
324 | Args:
325 | pretrained (bool): If True, returns a model pre-trained on ImageNet
326 | progress (bool): If True, displays a progress bar of the download to stderr
327 | """
328 | kwargs['width_per_group'] = 64 * 2
329 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
330 | pretrained, progress, **kwargs)
331 |
332 |
333 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
334 | r"""Wide ResNet-101-2 model from
335 | `"Wide Residual Networks" `_
336 |
337 | The model is the same as ResNet except for the bottleneck number of channels
338 | which is twice larger in every block. The number of channels in outer 1x1
339 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
340 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
341 |
342 | Args:
343 | pretrained (bool): If True, returns a model pre-trained on ImageNet
344 | progress (bool): If True, displays a progress bar of the download to stderr
345 | """
346 | kwargs['width_per_group'] = 64 * 2
347 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
348 | pretrained, progress, **kwargs)
349 |
--------------------------------------------------------------------------------
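Note the one structural departure from the torchvision ResNet: `conv1` takes 6 input channels instead of 3, matching the ablated-image encoding used by the derandomized-smoothing (DS) defense this file was taken from, in which retained pixel values are paired with ablation-mask channels. The exact encoding lives in the patchSmoothing code base; the sketch below only illustrates the shape convention (the band location and value/mask split are assumptions):

```python
import torch

img = torch.rand(1, 3, 224, 224)
mask = torch.zeros(1, 3, 224, 224)
mask[..., :, 60:124] = 1.0                        # retain one vertical band, ablate the rest
ablated = torch.cat([img * mask, mask], dim=1)    # [1, 6, 224, 224]: values + mask channels
```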
/nets/resnet.py:
--------------------------------------------------------------------------------
1 | ###########################################################################################
2 | # Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py #
3 | # Mainly changed the model forward() function #
4 | ###########################################################################################
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | try:
11 | from torch.hub import load_state_dict_from_url
12 | except ImportError:
13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
14 |
15 |
16 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
17 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
18 | 'wide_resnet50_2', 'wide_resnet101_2']
19 |
20 |
21 | model_urls = {
22 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
23 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
24 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
25 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
26 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
27 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
28 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
29 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
30 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
31 | }
32 |
33 |
34 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
35 | """3x3 convolution with padding"""
36 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
37 | padding=dilation, groups=groups, bias=False, dilation=dilation)
38 |
39 |
40 | def conv1x1(in_planes, out_planes, stride=1):
41 | """1x1 convolution"""
42 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
43 |
44 |
45 | class BasicBlock(nn.Module):
46 | expansion = 1
47 |
48 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
49 | base_width=64, dilation=1, norm_layer=None):
50 | super(BasicBlock, self).__init__()
51 | if norm_layer is None:
52 | norm_layer = nn.BatchNorm2d
53 | if groups != 1 or base_width != 64:
54 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
55 | if dilation > 1:
56 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
57 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
58 | self.conv1 = conv3x3(inplanes, planes, stride)
59 | self.bn1 = norm_layer(planes)
60 | self.relu = nn.ReLU(inplace=True)
61 | self.conv2 = conv3x3(planes, planes)
62 | self.bn2 = norm_layer(planes)
63 | self.downsample = downsample
64 | self.stride = stride
65 |
66 | def forward(self, x):
67 | identity = x
68 |
69 | out = self.conv1(x)
70 | out = self.bn1(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv2(out)
74 | out = self.bn2(out)
75 |
76 | if self.downsample is not None:
77 | identity = self.downsample(x)
78 |
79 | out += identity
80 | out = self.relu(out)
81 |
82 | return out
83 |
84 |
85 | class Bottleneck(nn.Module):
86 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
87 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
88 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
89 | # This variant is also known as ResNet V1.5 and improves accuracy according to
90 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
91 |
92 | expansion = 4
93 |
94 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
95 | base_width=64, dilation=1, norm_layer=None):
96 | super(Bottleneck, self).__init__()
97 | if norm_layer is None:
98 | norm_layer = nn.BatchNorm2d
99 | width = int(planes * (base_width / 64.)) * groups
100 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
101 | self.conv1 = conv1x1(inplanes, width)
102 | self.bn1 = norm_layer(width)
103 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
104 | self.bn2 = norm_layer(width)
105 | self.conv3 = conv1x1(width, planes * self.expansion)
106 | self.bn3 = norm_layer(planes * self.expansion)
107 | self.relu = nn.ReLU(inplace=True)
108 | self.downsample = downsample
109 | self.stride = stride
110 |
111 | def forward(self, x):
112 | identity = x
113 |
114 | out = self.conv1(x)
115 | out = self.bn1(out)
116 | out = self.relu(out)
117 |
118 | out = self.conv2(out)
119 | out = self.bn2(out)
120 | out = self.relu(out)
121 |
122 | out = self.conv3(out)
123 | out = self.bn3(out)
124 |
125 | if self.downsample is not None:
126 | identity = self.downsample(x)
127 |
128 | out += identity
129 | out = self.relu(out)
130 |
131 | return out
132 |
133 |
134 | class ResNet(nn.Module):
135 |
136 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
137 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
138 | norm_layer=None, clip_range=None, aggregation = 'mean'):
139 | super(ResNet, self).__init__()
140 | self.clip_range = clip_range
141 | self.aggregation = aggregation
142 |
143 | if norm_layer is None:
144 | norm_layer = nn.BatchNorm2d
145 | self._norm_layer = norm_layer
146 |
147 | self.inplanes = 64
148 | self.dilation = 1
149 | if replace_stride_with_dilation is None:
150 | # each element in the tuple indicates if we should replace
151 | # the 2x2 stride with a dilated convolution instead
152 | replace_stride_with_dilation = [False, False, False]
153 | if len(replace_stride_with_dilation) != 3:
154 | raise ValueError("replace_stride_with_dilation should be None "
155 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
156 | self.groups = groups
157 | self.base_width = width_per_group
158 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
159 | bias=False)
160 | self.bn1 = norm_layer(self.inplanes)
161 | self.relu = nn.ReLU(inplace=True)
162 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
163 | self.layer1 = self._make_layer(block, 64, layers[0])
164 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
165 | dilate=replace_stride_with_dilation[0])
166 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
167 | dilate=replace_stride_with_dilation[1])
168 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
169 | dilate=replace_stride_with_dilation[2])
170 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
171 | self.fc = nn.Linear(512 * block.expansion, num_classes)
172 |
173 | for m in self.modules():
174 | if isinstance(m, nn.Conv2d):
175 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
176 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
177 | nn.init.constant_(m.weight, 1)
178 | nn.init.constant_(m.bias, 0)
179 |
180 | # Zero-initialize the last BN in each residual branch,
181 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
182 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
183 | if zero_init_residual:
184 | for m in self.modules():
185 | if isinstance(m, Bottleneck):
186 | nn.init.constant_(m.bn3.weight, 0)
187 | elif isinstance(m, BasicBlock):
188 | nn.init.constant_(m.bn2.weight, 0)
189 |
190 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
191 | norm_layer = self._norm_layer
192 | downsample = None
193 | previous_dilation = self.dilation
194 | if dilate:
195 | self.dilation *= stride
196 | stride = 1
197 | if stride != 1 or self.inplanes != planes * block.expansion:
198 | downsample = nn.Sequential(
199 | conv1x1(self.inplanes, planes * block.expansion, stride),
200 | norm_layer(planes * block.expansion),
201 | )
202 |
203 | layers = []
204 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
205 | self.base_width, previous_dilation, norm_layer))
206 | self.inplanes = planes * block.expansion
207 | for _ in range(1, blocks):
208 | layers.append(block(self.inplanes, planes, groups=self.groups,
209 | base_width=self.base_width, dilation=self.dilation,
210 | norm_layer=norm_layer))
211 |
212 | return nn.Sequential(*layers)
213 |
214 | def _forward_impl(self, x):
215 | # See note [TorchScript super()]
216 | x = self.conv1(x)
217 | x = self.bn1(x)
218 | x = self.relu(x)
219 | x = self.maxpool(x)
220 |
221 | x = self.layer1(x)
222 | x = self.layer2(x)
223 | x = self.layer3(x)
224 | x = self.layer4(x)
225 |
226 | x = x.permute(0,2,3,1)
227 | x = self.fc(x)
228 | if self.clip_range is not None:
229 | x = torch.clamp(x,self.clip_range[0],self.clip_range[1])
230 | if self.aggregation == 'mean':
231 | x = torch.mean(x,dim=(1,2))
232 | elif self.aggregation == 'median':
233 | x = x.view([x.size()[0],-1,10]) # note: hard-codes 10 classes (ImageNette/CIFAR-10)
234 | x = torch.median(x,dim=1)
235 | return x.values
236 | elif self.aggregation =='cbn': # clipping function from Clipped BagNet
237 | x = torch.tanh(x*0.05-1)
238 | x = torch.mean(x,dim=(1,2))
239 | elif self.aggregation == 'none':
240 | pass
241 | return x
242 |
243 | def forward(self, x):
244 | return self._forward_impl(x)
245 |
246 |
247 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
248 | model = ResNet(block, layers, **kwargs)
249 | if pretrained:
250 | state_dict = load_state_dict_from_url(model_urls[arch],
251 | progress=progress)
252 | model.load_state_dict(state_dict)
253 | return model
254 |
255 |
256 | def resnet18(pretrained=False, progress=True, **kwargs):
257 | r"""ResNet-18 model from
258 | `"Deep Residual Learning for Image Recognition" `_
259 | Args:
260 | pretrained (bool): If True, returns a model pre-trained on ImageNet
261 | progress (bool): If True, displays a progress bar of the download to stderr
262 | """
263 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
264 | **kwargs)
265 |
266 |
267 | def resnet34(pretrained=False, progress=True, **kwargs):
268 | r"""ResNet-34 model from
269 | `"Deep Residual Learning for Image Recognition" `_
270 | Args:
271 | pretrained (bool): If True, returns a model pre-trained on ImageNet
272 | progress (bool): If True, displays a progress bar of the download to stderr
273 | """
274 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
275 | **kwargs)
276 |
277 |
278 | def resnet50(pretrained=False, progress=True, **kwargs):
279 | r"""ResNet-50 model from
280 | `"Deep Residual Learning for Image Recognition" `_
281 | Args:
282 | pretrained (bool): If True, returns a model pre-trained on ImageNet
283 | progress (bool): If True, displays a progress bar of the download to stderr
284 | """
285 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
286 | **kwargs)
287 |
288 |
289 | def resnet101(pretrained=False, progress=True, **kwargs):
290 | r"""ResNet-101 model from
291 | `"Deep Residual Learning for Image Recognition" `_
292 | Args:
293 | pretrained (bool): If True, returns a model pre-trained on ImageNet
294 | progress (bool): If True, displays a progress bar of the download to stderr
295 | """
296 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
297 | **kwargs)
298 |
299 |
300 | def resnet152(pretrained=False, progress=True, **kwargs):
301 | r"""ResNet-152 model from
302 | `"Deep Residual Learning for Image Recognition" `_
303 | Args:
304 | pretrained (bool): If True, returns a model pre-trained on ImageNet
305 | progress (bool): If True, displays a progress bar of the download to stderr
306 | """
307 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
308 | **kwargs)
309 |
310 |
311 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
312 | r"""ResNeXt-50 32x4d model from
313 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
314 | Args:
315 | pretrained (bool): If True, returns a model pre-trained on ImageNet
316 | progress (bool): If True, displays a progress bar of the download to stderr
317 | """
318 | kwargs['groups'] = 32
319 | kwargs['width_per_group'] = 4
320 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
321 | pretrained, progress, **kwargs)
322 |
323 |
324 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
325 | r"""ResNeXt-101 32x8d model from
326 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
327 | Args:
328 | pretrained (bool): If True, returns a model pre-trained on ImageNet
329 | progress (bool): If True, displays a progress bar of the download to stderr
330 | """
331 | kwargs['groups'] = 32
332 | kwargs['width_per_group'] = 8
333 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
334 | pretrained, progress, **kwargs)
335 |
336 |
337 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
338 | r"""Wide ResNet-50-2 model from
339 | `"Wide Residual Networks" `_
340 | The model is the same as ResNet except for the bottleneck number of channels
341 | which is twice larger in every block. The number of channels in outer 1x1
342 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
343 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
344 | Args:
345 | pretrained (bool): If True, returns a model pre-trained on ImageNet
346 | progress (bool): If True, displays a progress bar of the download to stderr
347 | """
348 | kwargs['width_per_group'] = 64 * 2
349 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
350 | pretrained, progress, **kwargs)
351 |
352 |
353 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
354 | r"""Wide ResNet-101-2 model from
355 | `"Wide Residual Networks" `_
356 | The model is the same as ResNet except for the bottleneck number of channels
357 | which is twice larger in every block. The number of channels in outer 1x1
358 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
359 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
360 | Args:
361 | pretrained (bool): If True, returns a model pre-trained on ImageNet
362 | progress (bool): If True, displays a progress bar of the download to stderr
363 | """
364 | kwargs['width_per_group'] = 64 * 2
365 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
366 | pretrained, progress, **kwargs)
--------------------------------------------------------------------------------
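With `aggregation='none'`, the modified `_forward_impl` above skips pooling and applies `fc` location-wise, so the model returns the grid of local logits consumed by `mask_bn.py`. A quick shape check (untrained weights; the output shape is the point):

```python
import torch
import nets.resnet

model = nets.resnet.resnet50(pretrained=False, aggregation='none').eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 7, 7, 1000]), per-location class logits
```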
/misc/train_imagenet.py:
--------------------------------------------------------------------------------
1 | ##########################################################################################
2 | # adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py
3 | # three changes: Line 33 Line 44 Line 364
4 | ##########################################################################################
5 |
6 | import argparse
7 | import os
8 | import random
9 | import shutil
10 | import time
11 | import warnings
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.parallel
16 | import torch.backends.cudnn as cudnn
17 | import torch.distributed as dist
18 | import torch.optim
19 | import torch.multiprocessing as mp
20 | import torch.utils.data
21 | import torch.utils.data.distributed
22 | import torchvision.transforms as transforms
23 | import torchvision.datasets as datasets
24 | import torchvision.models as models
25 | import nets.bagnet
26 | model_names = sorted(name for name in models.__dict__
27 | if name.islower() and not name.startswith("__")
28 | and callable(models.__dict__[name]))
29 |
30 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
31 | parser.add_argument('data', metavar='DIR',
32 | help='path to dataset')
33 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18')#,
34 | #choices=model_names,
35 | #help='model architecture: ' +
36 | # ' | '.join(model_names) +
37 | # ' (default: resnet18)')
38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
39 | help='number of data loading workers (default: 4)')
40 | parser.add_argument('--epochs', default=30, type=int, metavar='N',
41 | help='number of total epochs to run')
42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
43 | help='manual epoch number (useful on restarts)')
44 | parser.add_argument('-b', '--batch-size', default=128, type=int,
45 | metavar='N',
46 | help='mini-batch size (default: 128), this is the total '
47 | 'batch size of all GPUs on the current node when '
48 | 'using Data Parallel or Distributed Data Parallel')
49 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
50 | metavar='LR', help='initial learning rate', dest='lr')
51 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
52 | help='momentum')
53 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
54 | metavar='W', help='weight decay (default: 1e-4)',
55 | dest='weight_decay')
56 | parser.add_argument('-p', '--print-freq', default=10, type=int,
57 | metavar='N', help='print frequency (default: 10)')
58 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
59 | help='path to latest checkpoint (default: none)')
60 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
61 | help='evaluate model on validation set')
62 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
63 | help='use pre-trained model')
64 | parser.add_argument('--world-size', default=-1, type=int,
65 | help='number of nodes for distributed training')
66 | parser.add_argument('--rank', default=-1, type=int,
67 | help='node rank for distributed training')
68 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
69 | help='url used to set up distributed training')
70 | parser.add_argument('--dist-backend', default='nccl', type=str,
71 | help='distributed backend')
72 | parser.add_argument('--seed', default=None, type=int,
73 | help='seed for initializing training. ')
74 | parser.add_argument('--gpu', default=None, type=int,
75 | help='GPU id to use.')
76 | parser.add_argument('--multiprocessing-distributed', action='store_true',
77 | help='Use multi-processing distributed training to launch '
78 | 'N processes per node, which has N GPUs. This is the '
79 | 'fastest way to use PyTorch for either single node or '
80 | 'multi node data parallel training')
81 |
82 | best_acc1 = 0
83 |
84 |
85 | def main():
86 | args = parser.parse_args()
87 |
88 | if args.seed is not None:
89 | random.seed(args.seed)
90 | torch.manual_seed(args.seed)
91 | cudnn.deterministic = True
92 | warnings.warn('You have chosen to seed training. '
93 | 'This will turn on the CUDNN deterministic setting, '
94 | 'which can slow down your training considerably! '
95 | 'You may see unexpected behavior when restarting '
96 | 'from checkpoints.')
97 |
98 | if args.gpu is not None:
99 | warnings.warn('You have chosen a specific GPU. This will completely '
100 | 'disable data parallelism.')
101 |
102 | if args.dist_url == "env://" and args.world_size == -1:
103 | args.world_size = int(os.environ["WORLD_SIZE"])
104 |
105 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
106 |
107 | ngpus_per_node = torch.cuda.device_count()
108 | if args.multiprocessing_distributed:
109 | # Since we have ngpus_per_node processes per node, the total world_size
110 | # needs to be adjusted accordingly
111 | args.world_size = ngpus_per_node * args.world_size
112 | # Use torch.multiprocessing.spawn to launch distributed processes: the
113 | # main_worker process function
114 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
115 | else:
116 | # Simply call main_worker function
117 | main_worker(args.gpu, ngpus_per_node, args)
118 |
119 |
120 | def main_worker(gpu, ngpus_per_node, args):
121 | global best_acc1
122 | args.gpu = gpu
123 |
124 | if args.gpu is not None:
125 | print("Use GPU: {} for training".format(args.gpu))
126 |
127 | if args.distributed:
128 | if args.dist_url == "env://" and args.rank == -1:
129 | args.rank = int(os.environ["RANK"])
130 | if args.multiprocessing_distributed:
131 | # For multiprocessing distributed training, rank needs to be the
132 | # global rank among all the processes
133 | args.rank = args.rank * ngpus_per_node + gpu
134 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
135 | world_size=args.world_size, rank=args.rank)
136 | # create model
137 | #if args.pretrained:
138 | # print("=> using pre-trained model '{}'".format(args.arch))
139 | # model = models.__dict__[args.arch](pretrained=True)
140 | #else:
141 | # print("=> creating model '{}'".format(args.arch))
142 | # model = models.__dict__[args.arch]()
143 |
144 | model = nets.bagnet.bagnet17(pretrained=True,aggregation='adv')
145 |
146 | if not torch.cuda.is_available():
147 | print('using CPU, this will be slow')
148 | elif args.distributed:
149 | # For multiprocessing distributed, DistributedDataParallel constructor
150 | # should always set the single device scope, otherwise,
151 | # DistributedDataParallel will use all available devices.
152 | if args.gpu is not None:
153 | torch.cuda.set_device(args.gpu)
154 | model.cuda(args.gpu)
155 | # When using a single GPU per process and per
156 | # DistributedDataParallel, we need to divide the batch size
157 | # ourselves based on the total number of GPUs we have
158 | args.batch_size = int(args.batch_size / ngpus_per_node)
159 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
160 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
161 | else:
162 | model.cuda()
163 | # DistributedDataParallel will divide and allocate batch_size to all
164 | # available GPUs if device_ids are not set
165 | model = torch.nn.parallel.DistributedDataParallel(model)
166 | elif args.gpu is not None:
167 | torch.cuda.set_device(args.gpu)
168 | model = model.cuda(args.gpu)
169 | else:
170 | # DataParallel will divide and allocate batch_size to all available GPUs
171 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
172 | model.features = torch.nn.DataParallel(model.features)
173 | model.cuda()
174 | else:
175 | model = torch.nn.DataParallel(model).cuda()
176 |
177 | # define loss function (criterion) and optimizer
178 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
179 |
180 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
181 | momentum=args.momentum,
182 | weight_decay=args.weight_decay)
183 |
184 | # optionally resume from a checkpoint
185 | if args.resume:
186 | if os.path.isfile(args.resume):
187 | print("=> loading checkpoint '{}'".format(args.resume))
188 | if args.gpu is None:
189 | checkpoint = torch.load(args.resume)
190 | else:
191 | # Map model to be loaded to specified single gpu.
192 | loc = 'cuda:{}'.format(args.gpu)
193 | checkpoint = torch.load(args.resume, map_location=loc)
194 | args.start_epoch = checkpoint['epoch']
195 | best_acc1 = checkpoint['best_acc1']
196 | if args.gpu is not None:
197 | # best_acc1 may be from a checkpoint from a different GPU
198 | best_acc1 = best_acc1.to(args.gpu)
199 | model.load_state_dict(checkpoint['state_dict'])
200 | optimizer.load_state_dict(checkpoint['optimizer'])
201 | print("=> loaded checkpoint '{}' (epoch {})"
202 | .format(args.resume, checkpoint['epoch']))
203 | else:
204 | print("=> no checkpoint found at '{}'".format(args.resume))
205 |
206 | cudnn.benchmark = True
207 |
208 | # Data loading code
209 | traindir = os.path.join(args.data, 'train')
210 | valdir = os.path.join(args.data, 'val')
211 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
212 | std=[0.229, 0.224, 0.225])
213 |
214 | train_dataset = datasets.ImageFolder(
215 | traindir,
216 | transforms.Compose([
217 | transforms.RandomResizedCrop(224),
218 | transforms.RandomHorizontalFlip(),
219 | transforms.ToTensor(),
220 | normalize,
221 | ]))
222 |
223 | if args.distributed:
224 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
225 | else:
226 | train_sampler = None
227 |
228 | train_loader = torch.utils.data.DataLoader(
229 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
230 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
231 |
232 | val_loader = torch.utils.data.DataLoader(
233 | datasets.ImageFolder(valdir, transforms.Compose([
234 | transforms.Resize(256),
235 | transforms.CenterCrop(224),
236 | transforms.ToTensor(),
237 | normalize,
238 | ])),
239 | batch_size=args.batch_size, shuffle=False,
240 | num_workers=args.workers, pin_memory=True)
241 |
242 | if args.evaluate:
243 | validate(val_loader, model, criterion, args)
244 | return
245 |
246 | for epoch in range(args.start_epoch, args.epochs):
247 | if args.distributed:
248 | train_sampler.set_epoch(epoch)
249 | adjust_learning_rate(optimizer, epoch, args)
250 |
251 | # train for one epoch
252 | train(train_loader, model, criterion, optimizer, epoch, args)
253 |
254 | # evaluate on validation set
255 | acc1 = validate(val_loader, model, criterion, args)
256 |
257 | # remember best acc@1 and save checkpoint
258 | is_best = acc1 > best_acc1
259 | best_acc1 = max(acc1, best_acc1)
260 |
261 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
262 | and args.rank % ngpus_per_node == 0):
263 | save_checkpoint({
264 | 'epoch': epoch + 1,
265 | 'arch': args.arch,
266 | 'state_dict': model.state_dict(),
267 | 'best_acc1': best_acc1,
268 | 'optimizer' : optimizer.state_dict(),
269 | }, is_best)
270 |
271 |
272 | def train(train_loader, model, criterion, optimizer, epoch, args):
273 | batch_time = AverageMeter('Time', ':6.3f')
274 | data_time = AverageMeter('Data', ':6.3f')
275 | losses = AverageMeter('Loss', ':.4e')
276 | top1 = AverageMeter('Acc@1', ':6.2f')
277 | top5 = AverageMeter('Acc@5', ':6.2f')
278 | progress = ProgressMeter(
279 | len(train_loader),
280 | [batch_time, data_time, losses, top1, top5],
281 | prefix="Epoch: [{}]".format(epoch))
282 |
283 | # switch to train mode
284 | model.train()
285 |
286 | end = time.time()
287 | for i, (images, target) in enumerate(train_loader):
288 | # measure data loading time
289 | data_time.update(time.time() - end)
290 |
291 | if args.gpu is not None:
292 | images = images.cuda(args.gpu, non_blocking=True)
293 | if torch.cuda.is_available():
294 | target = target.cuda(args.gpu, non_blocking=True)
295 |
296 | # compute output
297 |         output = model(images,target) # the 'adv' aggregation uses the target to mask features during provable adversarial training
298 | loss = criterion(output, target)
299 |
300 | # measure accuracy and record loss
301 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
302 | losses.update(loss.item(), images.size(0))
303 | top1.update(acc1[0], images.size(0))
304 | top5.update(acc5[0], images.size(0))
305 |
306 | # compute gradient and do SGD step
307 | optimizer.zero_grad()
308 | loss.backward()
309 | optimizer.step()
310 |
311 | # measure elapsed time
312 | batch_time.update(time.time() - end)
313 | end = time.time()
314 |
315 | if i % args.print_freq == 0:
316 | progress.display(i)
317 |
318 |
319 | def validate(val_loader, model, criterion, args):
320 | batch_time = AverageMeter('Time', ':6.3f')
321 | losses = AverageMeter('Loss', ':.4e')
322 | top1 = AverageMeter('Acc@1', ':6.2f')
323 | top5 = AverageMeter('Acc@5', ':6.2f')
324 | progress = ProgressMeter(
325 | len(val_loader),
326 | [batch_time, losses, top1, top5],
327 | prefix='Test: ')
328 |
329 | # switch to evaluate mode
330 | model.eval()
331 |
332 | with torch.no_grad():
333 | end = time.time()
334 | for i, (images, target) in enumerate(val_loader):
335 | if args.gpu is not None:
336 | images = images.cuda(args.gpu, non_blocking=True)
337 | if torch.cuda.is_available():
338 | target = target.cuda(args.gpu, non_blocking=True)
339 |
340 | # compute output
341 | output = model(images,target)
342 | loss = criterion(output, target)
343 |
344 | # measure accuracy and record loss
345 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
346 | losses.update(loss.item(), images.size(0))
347 | top1.update(acc1[0], images.size(0))
348 | top5.update(acc5[0], images.size(0))
349 |
350 | # measure elapsed time
351 | batch_time.update(time.time() - end)
352 | end = time.time()
353 |
354 | if i % args.print_freq == 0:
355 | progress.display(i)
356 |
357 | # TODO: this should also be done with the ProgressMeter
358 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
359 | .format(top1=top1, top5=top5))
360 |
361 | return top1.avg
362 |
363 |
364 | def save_checkpoint(state, is_best, filename='bagnet17_adv_lr0.001.pth.tar'):
365 | torch.save(state, filename)
366 | if is_best:
367 | shutil.copyfile(filename, 'bagnet17_adv_lr0.001_best.pth.tar')
368 |
369 |
370 | class AverageMeter(object):
371 | """Computes and stores the average and current value"""
372 | def __init__(self, name, fmt=':f'):
373 | self.name = name
374 | self.fmt = fmt
375 | self.reset()
376 |
377 | def reset(self):
378 | self.val = 0
379 | self.avg = 0
380 | self.sum = 0
381 | self.count = 0
382 |
383 | def update(self, val, n=1):
384 | self.val = val
385 | self.sum += val * n
386 | self.count += n
387 | self.avg = self.sum / self.count
388 |
389 | def __str__(self):
390 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
391 | return fmtstr.format(**self.__dict__)
392 |
393 |
394 | class ProgressMeter(object):
395 | def __init__(self, num_batches, meters, prefix=""):
396 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
397 | self.meters = meters
398 | self.prefix = prefix
399 |
400 | def display(self, batch):
401 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
402 | entries += [str(meter) for meter in self.meters]
403 | print('\t'.join(entries))
404 |
405 | def _get_batch_fmtstr(self, num_batches):
406 |         num_digits = len(str(num_batches))
407 | fmt = '{:' + str(num_digits) + 'd}'
408 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
409 |
410 |
411 | def adjust_learning_rate(optimizer, epoch, args):
412 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
413 | lr = args.lr * (0.1 ** (epoch // 30))
414 | for param_group in optimizer.param_groups:
415 | param_group['lr'] = lr
416 |
417 |
418 | def accuracy(output, target, topk=(1,)):
419 | """Computes the accuracy over the k top predictions for the specified values of k"""
420 | with torch.no_grad():
421 | maxk = max(topk)
422 | batch_size = target.size(0)
423 |
424 | _, pred = output.topk(maxk, 1, True, True)
425 | pred = pred.t()
426 | correct = pred.eq(target.view(1, -1).expand_as(pred))
427 |
428 | res = []
429 | for k in topk:
430 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) # reshape (not view): correct is non-contiguous after the transpose
431 | res.append(correct_k.mul_(100.0 / batch_size))
432 | return res
433 |
434 |
435 | if __name__ == '__main__':
436 | main()
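# Example invocation (a sketch for illustration; the dataset path and flag
# values below are assumptions, not commands shipped with the repo):
#   python train_imagenet.py --lr 0.001 --batch-size 128 --epochs 30 /path/to/imagenet
# Checkpoints are written to bagnet17_adv_lr0.001.pth.tar (see save_checkpoint above).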
--------------------------------------------------------------------------------
/utils/defense_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy.special import softmax
4 |
5 | # robust masking defense (Algorithm 1 in the paper)
6 | def masking_defense(local_feature,clipping=-1,thres=0.,window_shape=[6,6],ds=False):
7 | '''
8 | local_feature numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
9 | clipping int/float, the positive clipping value ($c_h$ in the paper). If clipping < 0, treat clipping as np.inf
10 | thres float in [0,1], detection threshold. ($T$ in the paper)
11 | window_shape list [int,int], the shape of sliding window
12 | ds boolean, whether it is for mask-ds (windows wrap around the feature map)
13 |
14 | Return int, robust prediction
15 | '''
16 |
17 | feature_size_x,feature_size_y,num_cls = local_feature.shape
18 | window_size_x,window_size_y = window_shape
19 | num_window_x = feature_size_x - window_size_x + 1 if not ds else feature_size_x
20 | num_window_y = feature_size_y - window_size_y + 1 if not ds else feature_size_y
21 |
22 | # clipping
23 | if clipping >0:
24 | local_feature = np.clip(local_feature,0,clipping)
25 | else:
26 | local_feature = np.clip(local_feature,0,np.inf)
27 |
28 |
29 | global_feature = np.sum(local_feature,axis=(0,1))
30 |
31 | # the sum of class evidence within each window
32 | in_window_sum_tensor=np.zeros([num_window_x,num_window_y,num_cls])
33 | for x in range(0,num_window_x):
34 | for y in range(0,num_window_y):
35 | if ds and x + window_size_x > feature_size_x: #only happens when ds is True
36 | in_window_sum_tensor[x,y,:] = np.sum(local_feature[x:,y:y+window_size_y,:],axis=(0,1)) + np.sum(local_feature[:x+window_size_x-feature_size_x,y:y+window_size_y,:],axis=(0,1))
37 | else: # normal case
38 | in_window_sum_tensor[x,y,:] = np.sum(local_feature[x:x+window_size_x,y:y+window_size_y,:],axis=(0,1))
39 |
40 |
41 | # calculate clipped and masked class evidence for each class
42 | for c in range(num_cls):
43 | max_window_sum = np.max(in_window_sum_tensor[:,:,c])
44 | if global_feature[c] > 0 and max_window_sum / global_feature[c] > thres:
45 | global_feature[c]-=max_window_sum
46 |
47 | pred_list = np.argsort(global_feature,kind='stable') # 'stable' sorting is necessary when the features are one-hot local predictions (ties are common)
48 | return pred_list[-1]
49 |
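# Minimal usage sketch (illustrative, not part of the original repo): apply the
# robust masking defense to a random local feature map with 10 classes.
if __name__ == "__main__":
    _rng = np.random.default_rng(0)
    _feat = _rng.random((12, 12, 10))  # [feature_size_x, feature_size_y, num_cls]
    print("robust prediction:", masking_defense(_feat, window_shape=[6, 6], thres=0.0))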
50 |
51 | # provable analysis of robust masking defense (Algorithm 2 in the paper)
52 | def provable_masking(local_feature,label,clipping=-1,thres=0.,window_shape=[6,6],ds=False):
53 | '''
54 | local_feature numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
55 | label int, true label
56 | clipping int/float, the positive clipping value ($c_h$ in the paper). If clipping < 0, treat clipping as np.inf
57 | thres float in [0,1], detection threshold. ($T$ in the paper)
58 | window_shape list [int,int], the shape of sliding window
59 | ds boolean, whether it is for mask-ds (windows wrap around the feature map)
60 |
61 | Return int, provable analysis results (0: incorrect clean prediction; 1: possible attack found; 2: certified robustness )
62 | '''
63 |
64 | feature_size_x,feature_size_y,num_cls = local_feature.shape
65 | window_size_x,window_size_y = window_shape
66 | num_window_x = feature_size_x - window_size_x + 1 if not ds else feature_size_x
67 | num_window_y = feature_size_y - window_size_y + 1 if not ds else feature_size_y
68 |
69 | if clipping > 0:
70 | local_feature = np.clip(local_feature,0,clipping)
71 | else:
72 | local_feature = np.clip(local_feature,0,np.inf)
73 |
74 | global_feature = np.sum(local_feature,axis=(0,1))
75 |
76 | pred_list = np.argsort(global_feature,kind='stable')
77 | global_pred = pred_list[-1]
78 |
79 | if global_pred != label: # clean prediction is incorrect
80 | return 0
81 |
82 | local_feature_pred = local_feature[:,:,global_pred]
83 |
84 | # the sum of class evidence within each window
85 | in_window_sum_tensor = np.zeros([num_window_x,num_window_y,num_cls])
86 |
87 | for x in range(0,num_window_x):
88 | for y in range(0,num_window_y):
89 | if ds and x+window_size_x>feature_size_x: #only happens when ds is True
90 | in_window_sum_tensor[x,y,:] = np.sum(local_feature[x:,y:y+window_size_y,:],axis=(0,1)) + np.sum(local_feature[:x+window_size_x-feature_size_x,y:y+window_size_y,:],axis=(0,1))
91 | else:
92 | in_window_sum_tensor[x,y,:] = np.sum(local_feature[x:x+window_size_x,y:y+window_size_y,:],axis=(0,1))
93 |
94 |
95 | idx = np.ones([num_cls],dtype=bool)
96 | idx[global_pred]=False
97 | for x in range(0,num_window_x):
98 | for y in range(0,num_window_y):
99 |
100 | # determine the upper bound of wrong class evidence
101 | global_feature_masked = global_feature - in_window_sum_tensor[x,y,:] # $t$ in the proof of Lemma 1
102 | global_feature_masked[idx]/=(1 - thres) # $t/(1-T)$, the upper bound of wrong class evidence
103 |
104 | # determine the lower bound of true class evidence
105 | local_feature_pred_masked = local_feature_pred.copy()
106 | if ds and x+window_size_x>feature_size_x:
107 | local_feature_pred_masked[x:,y:y+window_size_y]=0
108 | local_feature_pred_masked[:x+window_size_x-feature_size_x,y:y+window_size_y]=0
109 | else:
110 | local_feature_pred_masked[x:x+window_size_x,y:y+window_size_y]=0 # operation $u\odot(1-w)$
111 |
112 | in_window_sum_pred_masked = in_window_sum_tensor[:,:,global_pred].copy()
113 | overlap_window_max_sum = 0
114 | # only need to recalculate the windows that are partially masked
115 | for xx in range(max(0,x - window_size_x + 1),min(x + window_size_x,num_window_x)):
116 | for yy in range(max(0,y - window_size_y + 1),min(y + window_size_y,num_window_y)):
117 | if ds and xx+window_size_x>feature_size_x:
118 | in_window_sum_pred_masked[xx,yy]=local_feature_pred_masked[xx:,yy:yy+window_size_y].sum()+local_feature_pred_masked[:xx+window_size_x-feature_size_x,yy:yy+window_size_y].sum()
119 | overlap_window_max_sum = max(overlap_window_max_sum,in_window_sum_pred_masked[xx,yy])
120 | else:
121 | in_window_sum_pred_masked[xx,yy]=local_feature_pred_masked[xx:xx+window_size_x,yy:yy+window_size_y].sum()
122 | overlap_window_max_sum = max(overlap_window_max_sum,in_window_sum_pred_masked[xx,yy])
123 | 
124 | max_window_sum_pred = np.max(in_window_sum_pred_masked) # the max true-class evidence the defense might mask
125 | if global_feature_masked[global_pred] > 0 and max_window_sum_pred / global_feature_masked[global_pred] > thres:
126 | global_feature_masked[global_pred]-=max_window_sum_pred
127 | else:
128 | global_feature_masked[global_pred]-=overlap_window_max_sum
129 |
130 |
131 | # determine if an attack is possible
132 | if np.argsort(global_feature_masked,kind='stable')[-1]!=label:
133 | return 1
134 |
135 | return 2 #provable robustness
136 |
137 |
138 |
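# Usage sketch (illustrative): certify a random feature map against a 6x6
# feature-space patch; since the clean prediction is used as the label here,
# only outcomes 1 (attack possible) and 2 (certified) can occur.
if __name__ == "__main__":
    _rng = np.random.default_rng(1)
    _feat = _rng.random((12, 12, 10))
    _label = int(np.argmax(_feat.sum(axis=(0, 1))))
    _result = provable_masking(_feat, _label, window_shape=[6, 6], thres=0.0)
    print({0: "wrong clean prediction", 1: "possible attack", 2: "certified"}[_result])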
139 |
140 | # De-randomized Smoothing
141 | # Adapted from https://github.com/alevine0/patchSmoothing/blob/master/utils_band.py
142 | def ds(inpt,net,block_size, size_to_certify, num_classes, threshold=0.2):
143 | '''
144 | inpt torch.tensor, the input images in BCHW format
145 | net torch.nn.module, the base model whose input is small pixel bands
146 | block_size int, the width of pixel bands
147 | size_to_certify int, the patch size to be certified
148 | num_classes int, number of classes
149 | threshold float, the threshold for prediction, see their original paper for details
150 |
151 | Return [torch.tensor,torch.tensor], the clean predictions and the certificates
152 | '''
153 |
154 | predictions = torch.zeros(inpt.size(0), num_classes).type(torch.int).cuda()
155 | batch = inpt.permute(0,2,3,1) #color channel last
156 | for pos in range(batch.shape[2]):
157 | out_c1 = torch.zeros(batch.shape).cuda()
158 | out_c2 = torch.zeros(batch.shape).cuda()
159 | if (pos+block_size > batch.shape[2]):
160 | out_c1[:,:,pos:] = batch[:,:,pos:]
161 | out_c2[:,:,pos:] = 1. - batch[:,:,pos:]
162 |
163 | out_c1[:,:,:pos+block_size-batch.shape[2]] = batch[:,:,:pos+block_size-batch.shape[2]]
164 | out_c2[:,:,:pos+block_size-batch.shape[2]] = 1. - batch[:,:,:pos+block_size-batch.shape[2]]
165 | else:
166 | out_c1[:,:,pos:pos+block_size] = batch[:,:,pos:pos+block_size]
167 | out_c2[:,:,pos:pos+block_size] = 1. - batch[:,:,pos:pos+block_size]
168 |
169 | out_c1 = out_c1.permute(0,3,1,2)
170 | out_c2 = out_c2.permute(0,3,1,2)
171 | out = torch.cat((out_c1,out_c2), 1)
172 | softmx = torch.nn.functional.softmax(net(out),dim=1)
173 | predictions += (softmx >= threshold).type(torch.int).cuda()
174 |
175 | predictionsnp = predictions.cpu().numpy()
176 | idxsort = np.argsort(-predictionsnp,axis=1,kind='stable')
177 | valsort = -np.sort(-predictionsnp,axis=1,kind='stable')
178 | val = valsort[:,0]
179 | idx = idxsort[:,0]
180 | valsecond = valsort[:,1]
181 | idxsecond = idxsort[:,1]
182 | num_affected_classifications=(size_to_certify + block_size -1)
183 | cert = torch.tensor((val - valsecond > 2*num_affected_classifications) | ((val - valsecond == 2*num_affected_classifications) & (idx < idxsecond))).cuda()
184 | return torch.tensor(idx).cuda(), cert
185 |
186 |
187 | # mask-ds
188 | def masking_ds(inpt,labels,net,block_size,size_to_certify,thres=0.0):
189 | '''
190 | inpt torch.tensor, the input images in BCHW format
191 | labels numpy.ndarray, the list of labels
192 | net torch.nn.module, the base model whose input is small pixel bands
193 | block_size int, the width of pixel bands
194 | size_to_certify int, the patch size to be certified
195 | thres float, the detection threshold ($T$). Note it is not the same as `threshold` in ds()
196 |
197 | Return: [list,list], a list of provable analysis results and a list of clean-prediction correctness flags
198 | '''
199 | logits_list=[]
200 | cnf_list=[]
201 | pred_list=[]
202 | batch = inpt.permute(0,2,3,1) #color channel last
203 | for pos in range(batch.shape[2]):
204 | out_c1 = torch.zeros(batch.shape).cuda()
205 | out_c2 = torch.zeros(batch.shape).cuda()
206 | if (pos+block_size > batch.shape[2]):
207 | out_c1[:,:,pos:] = batch[:,:,pos:]
208 | out_c2[:,:,pos:] = 1. - batch[:,:,pos:]
209 |
210 | out_c1[:,:,:pos+block_size-batch.shape[2]] = batch[:,:,:pos+block_size-batch.shape[2]]
211 | out_c2[:,:,:pos+block_size-batch.shape[2]] = 1. - batch[:,:,:pos+block_size-batch.shape[2]]
212 | else:
213 | out_c1[:,:,pos:pos+block_size] = batch[:,:,pos:pos+block_size]
214 | out_c2[:,:,pos:pos+block_size] = 1. - batch[:,:,pos:pos+block_size]
215 |
216 | out_c1 = out_c1.permute(0,3,1,2)
217 | out_c2 = out_c2.permute(0,3,1,2)
218 | out = torch.cat((out_c1,out_c2), 1)
219 | logits_tmp = net(out).detach().cpu().numpy()
220 | cnf_tmp = softmax(logits_tmp,axis=-1)
221 | pred_tmp = (cnf_tmp > 0.2).astype(float)
222 | logits_list.append(logits_tmp)
223 | cnf_list.append(cnf_tmp)
224 | pred_list.append(pred_tmp)
225 |
226 | #output_list = np.stack(logits_list,axis=1)
227 | output_list = np.stack(cnf_list,axis=1)
228 | #output_list = np.stack(pred_list,axis=1)
229 |
230 | B,W,C=output_list.shape
231 | result_list=[]
232 | clean_corr_list=[]
233 | window_size = block_size + size_to_certify -1
234 |
235 | for i in range(len(labels)):
236 | local_feature = output_list[i].reshape([W,1,C])
237 | result=provable_masking(local_feature,labels[i],window_shape=[window_size,1],thres=thres,ds=True)
238 | clean_pred=masking_defense(local_feature,window_shape=[window_size,1],thres=thres,ds=True)
239 | result_list.append(result)
240 | clean_corr_list.append(clean_pred == labels[i])
241 |
242 | return result_list,clean_corr_list
243 |
244 | ##################################################################################################################################
245 |
246 | # an extended version of provable_masking() that supports a mask window larger than the patch window
247 | def provable_masking_large_mask(local_feature,label,clipping=-1,thres=0.,window_shape=[6,6],mask_shape=None):
248 | '''
249 | local_feature numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
250 | label int, true label
251 | clipping int/float, the positive clipping value ($c_h$ in the paper). If clipping < 0, treat clipping as np.inf
252 | thres float in [0,1], detection threshold. ($T$ in the paper)
253 | window_shape list [int,int], the shape of the malicious patch window
254 | mask_shape list [int,int], the shape of the mask window. If None, defaults to window_shape
255 |
256 | Return int, provable analysis results (0: incorrect clean prediction; 1: possible attack found; 2: certified robustness )
257 | '''
258 | feature_size_x,feature_size_y,num_cls = local_feature.shape
259 |
260 | patch_size_x,patch_size_y = window_shape
261 | num_patch_x = feature_size_x - patch_size_x + 1
262 | num_patch_y = feature_size_y - patch_size_y + 1
263 |
264 | if mask_shape is None:
265 | mask_shape = window_shape
266 | mask_size_x,mask_size_y = mask_shape
267 | num_mask_x = feature_size_x - mask_size_x + 1
268 | num_mask_y = feature_size_y - mask_size_y + 1
269 |
270 | if clipping > 0:
271 | local_feature = np.clip(local_feature,0,clipping)
272 | else:
273 | local_feature = np.clip(local_feature,0,np.inf)
274 |
275 | global_feature = np.sum(local_feature,axis=(0,1))
276 |
277 | pred_list = np.argsort(global_feature,kind='stable')
278 | global_pred = pred_list[-1]
279 |
280 | if global_pred != label: #clean prediction is incorrect
281 | return 0
282 |
283 | # the sum of class evidence within mask window
284 | in_mask_sum_tensor = np.zeros([num_mask_x,num_mask_y,num_cls])
285 | for x in range(0,num_mask_x):
286 | for y in range(0,num_mask_y):
287 | in_mask_sum_tensor[x,y] = np.sum(local_feature[x:x+mask_size_x,y:y+mask_size_y,:],axis=(0,1))
288 |
289 |
290 | # the sum of class evidence within each possible malicious window
291 | in_patch_sum_tensor = np.zeros([num_patch_x,num_patch_y,num_cls])
292 | for x in range(0,num_patch_x):
293 | for y in range(0,num_patch_y):
294 | in_patch_sum_tensor[x,y,:] = np.sum(local_feature[x:x+patch_size_x,y:y+patch_size_y,:],axis=(0,1))
295 |
296 | #out_patch_sum_tensor = global_feature.reshape([1,1,num_cls]) - in_patch_sum_tensor
297 |
298 | idx = np.ones([num_cls],dtype=bool)
299 | idx[global_pred]=False
300 |
301 | for x in range(0,num_patch_x):
302 | for y in range(0,num_patch_y):
303 |
304 | # determine the upper bound of wrong class evidence
305 | cover_patch_mask_sum_tensor = in_mask_sum_tensor[max(0,x + patch_size_x - mask_size_x):min(x+1,num_mask_x),max(0,y + patch_size_y - mask_size_y):min(y+1,num_mask_y)]
306 | max_cover_patch_mask_sum = np.max(cover_patch_mask_sum_tensor,axis=(0,1))
307 | global_feature_patched = global_feature - max_cover_patch_mask_sum # $t-k$ in the proof of Lemma 2
308 | global_feature_patched[idx]/=(1 - thres) # $(t-k)/(1-T)$ in the proof of Lemma 2
309 | overlap_window_max_sum = 0
310 | # determine the lower bound of true class evidence
311 | local_feature_pred_masked = local_feature[:,:,global_pred].copy()
312 | local_feature_pred_masked[x:x+patch_size_x,y:y+patch_size_y]=0
313 | in_mask_sum_pred_masked = in_mask_sum_tensor[:,:,global_pred].copy()
314 | # only need to recalculate the windows that are partially masked
315 | for xx in range(max(0,x - mask_size_x + 1),min(x + patch_size_x,num_mask_x)):
316 | for yy in range(max(0,y - mask_size_y + 1),min(y + patch_size_y,num_mask_y)):
317 | in_mask_sum_pred_masked[xx,yy]=local_feature_pred_masked[xx:xx+mask_size_x,yy:yy+mask_size_y].sum()
318 | overlap_window_max_sum = max(overlap_window_max_sum,in_mask_sum_pred_masked[xx,yy])
319 | 
320 | max_mask_sum_pred = np.max(in_mask_sum_pred_masked) # the max true-class evidence the defense might mask
321 | 
322 | if global_feature_patched[global_pred] > 0 and max_mask_sum_pred / global_feature_patched[global_pred] > thres:
323 | global_feature_patched[global_pred]-=max_mask_sum_pred
324 | else:
325 | global_feature_patched[global_pred]-=overlap_window_max_sum
326 |
327 | # determine if an attack is possible
328 | if np.argsort(global_feature_patched,kind='stable')[-1]!=label:
329 | return 1
330 | return 2 #provable robustness
331 |
332 |
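# Usage sketch (illustrative): certify with a mask window larger than the patch
# window, e.g. a 6x6 patch analyzed with an 8x8 mask.
if __name__ == "__main__":
    _rng = np.random.default_rng(2)
    _feat = _rng.random((14, 14, 10))
    _label = int(np.argmax(_feat.sum(axis=(0, 1))))
    print("certification code:", provable_masking_large_mask(_feat, _label, window_shape=[6, 6], mask_shape=[8, 8]))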
333 | # clipping based defense
334 | def clipping_defense(local_feature,clipping=-1):
335 | '''
336 | local_feature numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
337 | clipping int/float, clipping value. If clipping < 0, use cbn clipping
338 |
339 | Return int, the predicted class label
340 | '''
341 | if clipping > 0:
342 | local_feature = np.clip(local_feature,0,clipping) #clipped with [0,clipping]
343 | else:
344 | local_feature = np.tanh(local_feature*0.05-1) # clipped with tanh (CBN)
345 | global_feature = np.mean(local_feature,axis=(0,1))
346 | global_pred = np.argmax(global_feature)
347 |
348 | return global_pred
349 |
350 | # provable analysis for clipping based defense
351 | def provable_clipping(local_feature,label,clipping=-1,window_shape=[6,6]):
352 |
353 | '''
354 | local_feature numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
355 | label int, true label
356 | clipping int/float, clipping value. If clipping < 0, use cbn clipping
357 |
358 | window_shape list [int,int], the shape of sliding window
359 |
360 | Return int, provable analysis results (0: incorrect clean prediction; 1: possible attack found; 2: certified robustness )
361 | '''
362 | feature_size_x,feature_size_y,num_cls = local_feature.shape
363 | window_size_x,window_size_y = window_shape
364 | num_window_x = feature_size_x - window_size_x + 1
365 | num_window_y = feature_size_y - window_size_y + 1
366 |
367 | if clipping > 0:
368 | local_feature = np.clip(local_feature,0,clipping) #clipped with [0,clipping]
369 | max_increase = window_size_x * window_size_y * clipping
370 | else:
371 | local_feature = np.tanh(local_feature*0.05-1) # clipped with tanh (CBN)
372 | max_increase = window_size_x * window_size_y * 2
373 |
374 | local_pred = np.argmax(local_feature,axis=-1)
375 | global_feature = np.mean(local_feature,axis=(0,1))
376 | pred_list = np.argsort(global_feature)
377 | global_pred = pred_list[-1]
378 | if global_pred != label: #clean prediction is incorrect
379 | return 0
380 | local_feature_pred = local_feature[:,:,global_pred]
381 |
382 |
383 | target_cls = pred_list[-2] #second prediction
384 |
385 | local_feature_target = local_feature[:,:,target_cls]
386 | diff_feature = local_feature_pred - local_feature_target
387 |
388 | for x in range(0,num_window_x):
389 | for y in range(0,num_window_y):
390 | diff_feature_masked = diff_feature.copy()
391 | diff_feature_masked[x:x+window_size_x,y:y+window_size_y]=0
392 | diff = diff_feature_masked.sum()
393 | if diff < max_increase:
394 | return 1
395 | return 2 # provable robustness
396 |
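# Usage sketch (illustrative): the clipping-based defense and its certification,
# using CBN-style tanh clipping (clipping=-1).
if __name__ == "__main__":
    _rng = np.random.default_rng(3)
    _feat = _rng.random((12, 12, 10)) * 10
    _pred = clipping_defense(_feat, clipping=-1)
    print("prediction:", _pred, "| certification code:", provable_clipping(_feat, int(_pred), clipping=-1, window_shape=[6, 6]))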
397 |
398 |
399 | ##################################################################################################################################
400 |
401 |
402 | # for PatchGuard++
403 |
404 |
405 |
406 | def pg2_detection(local_feature,tau,window_shape=[6,6]):
407 | '''
408 | local_feature numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
409 | tau float in [0,1], detection threshold. $\tau$ in the paper
410 | window_shape list [int,int], the shape of sliding window
411 |
412 | Return int, class label or -1 for alert
413 | '''
414 | feature_size_x,feature_size_y,num_cls = local_feature.shape
415 | window_size_x,window_size_y = window_shape
416 | num_window_x = feature_size_x - window_size_x + 1
417 | num_window_y = feature_size_y - window_size_y + 1
418 |
419 | global_feature = np.mean(local_feature,axis=(0,1))
420 | pred_list = np.argsort(global_feature,kind='stable')
421 | global_pred = pred_list[-1]
422 |
423 | in_window_sum_tensor=np.zeros([num_window_x,num_window_y,num_cls])
424 | for x in range(0,num_window_x):
425 | for y in range(0,num_window_y):
426 | in_window_sum_tensor[x,y,:] = np.sum(local_feature[x:x+window_size_x,y:y+window_size_y,:],axis=(0,1))
427 | in_window_sum_tensor = in_window_sum_tensor/(feature_size_x*feature_size_y)
428 |
429 | for x in range(0,num_window_x):
430 | for y in range(0,num_window_y):
431 | global_feature_masked = global_feature - in_window_sum_tensor[x,y]
432 | global_feature_masked = softmax(global_feature_masked)
433 | masked_pred = np.argmax(global_feature_masked)
434 | masked_conf = np.max(global_feature_masked)
435 | if masked_pred != global_pred and masked_conf>tau:
436 | return -1
437 | return global_pred
438 |
439 |
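# Usage sketch (illustrative): PatchGuard++ detection returns a class label for
# an input it considers benign and -1 when a possible attack is flagged.
if __name__ == "__main__":
    _rng = np.random.default_rng(4)
    _feat = _rng.random((12, 12, 10))
    _out = pg2_detection(_feat, tau=0.5, window_shape=[6, 6])
    print("alert" if _out == -1 else "predicted class %d" % _out)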
440 |
441 |
442 | def pg2_detection_provable(local_feature,label,tau,window_shape=[6,6]):
443 | '''
444 | local_feature numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
445 | label int, the ground-truth class label
446 | tau float in [0,1], detection threshold. $\tau$ in the paper
447 | window_shape list [int,int], the shape of sliding window
448 |
449 | Return int, provable analysis results (0: incorrect clean prediction; 1: possible attack found; 2: certified robustness )
450 | '''
451 | feature_size_x,feature_size_y,num_cls = local_feature.shape
452 | window_size_x,window_size_y = window_shape
453 | num_window_x = feature_size_x - window_size_x + 1
454 | num_window_y = feature_size_y - window_size_y + 1
455 |
456 | global_feature = np.mean(local_feature,axis=(0,1))
457 | pred_list = np.argsort(global_feature,kind='stable')
458 | global_pred = pred_list[-1]
459 |
460 | in_window_sum_tensor=np.zeros([num_window_x,num_window_y,num_cls])
461 | for x in range(0,num_window_x):
462 | for y in range(0,num_window_y):
463 | in_window_sum_tensor[x,y,:] = np.sum(local_feature[x:x+window_size_x,y:y+window_size_y,:],axis=(0,1))
464 | in_window_sum_tensor = in_window_sum_tensor/(feature_size_x*feature_size_y)
465 |
466 | if global_pred != label: # clean prediction is incorrect
467 | return 0
468 |
469 | for x in range(0,num_window_x):
470 | for y in range(0,num_window_y):
471 | global_feature_masked = global_feature - in_window_sum_tensor[x,y]
472 | global_feature_masked = softmax(global_feature_masked)
473 | masked_pred = np.argmax(global_feature_masked)
474 | masked_conf = np.max(global_feature_masked)
475 | if masked_pred != label or masked_conf<tau:
476 | return 1
477 | return 2 # provable robustness
478 | 
479 | """
480 | clean = 0
481 | if provable == 1 and clean ==0:
482 | return provable,clean
483 | return provable,clean
484 | """
--------------------------------------------------------------------------------