├── doc
│   └── overview.png
├── requirement.txt
├── utils
│   ├── normalize_utils.py
│   ├── cutout.py
│   ├── progress_bar.py
│   └── defense_utils.py
├── LICENSE
├── checkpoints
│   └── README.md
├── example_cmd.sh
├── misc
│   ├── test_acc.py
│   ├── PatchAttacker.py
│   ├── patch_attack.py
│   ├── train_cifar.py
│   ├── train_imagenette.py
│   └── train_imagenet.py
├── nets
│   ├── dsresnet_cifar.py
│   ├── bagnet.py
│   ├── dsresnet_imgnt.py
│   └── resnet.py
├── README.md
├── mask_ds.py
├── det_bn.py
└── mask_bn.py

/doc/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inspire-group/PatchGuard/HEAD/doc/overview.png
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
tqdm #==4.51.0
torch #==1.7.0
torchvision #==0.8.1
joblib #==0.17.0
scipy #==1.5.4
numpy #==1.19.2
--------------------------------------------------------------------------------
/utils/normalize_utils.py:
--------------------------------------------------------------------------------
################################################
# Not used. Useful if visualization is desired #
################################################

import numpy as np

mean_vec=[0.485, 0.456, 0.406]
std_vec=[0.229, 0.224, 0.225]

def normalize_np(data,mean,std):
    #input data B*W*H*C
    B,W,H,C=data.shape
    mean=np.array(mean).reshape([1,1,1,C])
    std=np.array(std).reshape([1,1,1,C])
    return (data-mean)/std

def unnormalize_np(data,mean,std):
    #input data B*W*H*C
    B,W,H,C=data.shape
    mean=np.array(mean).reshape([1,1,1,C])
    std=np.array(std).reshape([1,1,1,C])
    return data*std+mean
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 xiangchong1

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
## Checkpoints
### overview
Model checkpoints used in the paper can be downloaded from [link](https://drive.google.com/drive/folders/1u5RsCuZNf7ddWW0utI4OrgWGmJCUDCuT?usp=sharing).
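All checkpoints are standard PyTorch files. As a minimal loading sketch (mirroring `misc/test_acc.py`; note that the state-dict key differs by model: `model_state_dict` for imagenette models, `state_dict` for imagenet models, and `net` for cifar and ds-resnet models):

```python
import torch

checkpoint = torch.load('checkpoints/bagnet17_nette.pth')
# `model` is a DataParallel-wrapped BagNet built as in the evaluation scripts
model.load_state_dict(checkpoint['model_state_dict'])
```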
The checkpoints from Google Drive were obtained with "provable adversarial training" (i.e., adding feature masks during training).

Model training should be very easy with the provided training scripts.

### checkpoints for bagnet/resnet trained on imagenet
Two model checkpoints trained with "provable adversarial training" are available now! bagnet17_net.pth gives the results reported in our paper. PS: the clean accuracy for resnet50 (note that resnet50 is not used in our defense!) reported in the paper uses the pretrained weights from torchvision.

- bagnet33_net.pth
- bagnet17_net.pth

### checkpoints for bagnet/resnet trained on imagenette
- resnet50_nette.pth
- bagnet33_nette.pth
- bagnet17_nette.pth
- bagnet9_nette.pth

### checkpoints for bagnet/resnet trained on cifar
- resnet50_192_cifar.pth
- bagnet33_192_cifar.pth
- bagnet17_192_cifar.pth
- bagnet9_192_cifar.pth

### checkpoints for ds-resnet on different datasets
- ds_net.pth
- ds_nette.pth
- ds_cifar.pth

Training scripts for ds-resnet are not provided in this repository, but can be found in [patchSmoothing](https://github.com/alevine0/patchSmoothing).
--------------------------------------------------------------------------------
/utils/cutout.py:
--------------------------------------------------------------------------------
###############################################
# not used in the paper
# from https://github.com/uoguelph-mlrg/Cutout
###############################################

import torch
import numpy as np


class Cutout(object):
    """Randomly mask out one or more patches from an image.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes of dimension length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.
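
        # broadcast the single-channel mask across all image channels and zero
        # out the selected squares (np.clip above keeps each hole inside the
        # image, so holes centered near the border are partially cut off)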
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask

        return img
--------------------------------------------------------------------------------
/example_cmd.sh:
--------------------------------------------------------------------------------
#install packages
pip install -r requirement.txt
#provable analysis with CBN and robust masking
python mask_bn.py --model bagnet17 --dataset imagenette --patch_size 32 --cbn #cbn with bagnet17 on imagenette
python mask_bn.py --model bagnet17 --dataset imagenette --patch_size 32 --m #mask-bn with bagnet17 on imagenette
python mask_bn.py --model bagnet17 --dataset imagenet --patch_size 32 --cbn #cbn with bagnet17 on imagenet
python mask_bn.py --model bagnet17 --dataset imagenet --patch_size 32 --m #mask-bn with bagnet17 on imagenet
python mask_bn.py --model bagnet17 --dataset cifar --patch_size 30 --cbn #cbn with bagnet17 on cifar
python mask_bn.py --model bagnet17 --dataset cifar --patch_size 30 --m #mask-bn with bagnet17 on cifar
#mask-ds and ds
python mask_ds.py --dataset imagenette --patch_size 42 --ds #ds for imagenette
python mask_ds.py --dataset imagenette --patch_size 42 --m #mask-ds for imagenette
python mask_ds.py --dataset imagenet --patch_size 42 --ds #ds for imagenet
python mask_ds.py --dataset imagenet --patch_size 42 --m #mask-ds for imagenet
python mask_ds.py --dataset cifar --patch_size 5 --ds #ds for cifar
python mask_ds.py --dataset cifar --patch_size 5 --m #mask-ds for cifar

# patchguard++
python det_bn.py --det --model bagnet33 --tau 0.5 --patch_size 32 --dataset imagenette # an example; the usage is similar to mask_bn.py and mask_ds.py
python det_bn.py --det --model bagnet33 --tau 0.7 --patch_size 32 --dataset imagenette # you can try different values of the threshold tau

#test model accuracy
python test_acc.py --model resnet50 --dataset imagenette #test accuracy of resnet50 on imagenette
python test_acc.py --model resnet50 --dataset imagenet #test accuracy of resnet50 on imagenet
python test_acc.py --model resnet50 --dataset cifar #test accuracy of resnet50 on cifar
python test_acc.py --model bagnet17 --dataset imagenette #test accuracy of bagnet17 on imagenette (similar for imagenet,cifar)
python test_acc.py --model bagnet33 --dataset imagenette #test accuracy of bagnet33 on imagenette (similar for imagenet,cifar)
python test_acc.py --model bagnet9 --dataset imagenette #test accuracy of bagnet9 on imagenette (similar for imagenet,cifar)
python test_acc.py --model bagnet17 --dataset imagenette --clip 15 #test accuracy of bagnet17 (clipped with [0,15]) on imagenette (similar for imagenet,cifar)
python test_acc.py --model bagnet17 --dataset imagenette --aggr median #test accuracy of bagnet17 with median aggregation on imagenette (similar for imagenet,cifar)
python test_acc.py --model bagnet17 --dataset imagenette --aggr cbn #test accuracy of bagnet17 with cbn clipping on imagenette (similar for imagenet,cifar)
#empirical untargeted attack
python patch_attack.py --model bagnet17 --dataset imagenette --patch_size 31 #untargeted attack against bagnet17
python patch_attack.py --model bagnet17 --dataset imagenette --patch_size 31 --aggr cbn #untargeted attack against bagnet17 with cbn clipping
#train model
python train_imagenette.py --model_name bagnet17_nette.pth --epoch 20 #train model on imagenette
python train_imagenette.py --model_name bagnet17_nette.pth --aggr adv --epoch 20 #train model on imagenette with provable adversarial training
python train_cifar.py --lr 0.01 #train cifar model
python train_cifar.py --resume --lr 0.001 #resume cifar model training with a different learning rate


--------------------------------------------------------------------------------
/utils/progress_bar.py:
--------------------------------------------------------------------------------
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - msr_init: net parameter initialization.
    - progress_bar: progress bar mimic xlua.progress.
'''
import os
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.init as init


def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal(m.weight, mode='fan_out')
            if m.bias is not None:
                init.constant(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant(m.weight, 1)
            init.constant(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant(m.bias, 0)


_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
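    # (the '\b' backspaces below move the cursor so the step counter is
    # printed over the middle of the bar)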
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
--------------------------------------------------------------------------------
/misc/test_acc.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms

import nets.bagnet
import nets.resnet
from utils.defense_utils import *

import os
import argparse
from tqdm import tqdm
import numpy as np
import PIL

parser = argparse.ArgumentParser()

parser.add_argument("--model_dir",default='checkpoints',type=str,help="path to checkpoints")
parser.add_argument('--data_dir', default='data', type=str,help="path to data")
parser.add_argument('--dataset', default='imagenette', choices=('imagenette','imagenet','cifar'),type=str,help="dataset")
parser.add_argument("--model",default='bagnet17',type=str,help="model name")
parser.add_argument("--clip",default=-1,type=int,help="clipping value; do clipping when this argument is set to positive")
parser.add_argument("--aggr",default='mean',type=str,help="aggregation methods. set to none for local feature")

args = parser.parse_args()

MODEL_DIR=os.path.join('.',args.model_dir)
DATA_DIR=os.path.join(args.data_dir,args.dataset)
DATASET = args.dataset

def get_dataset(ds,data_dir):
    if ds in ['imagenette','imagenet']:
        ds_dir=os.path.join(data_dir,'val')
        ds_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        dataset_ = datasets.ImageFolder(ds_dir,ds_transforms)
        class_names = dataset_.classes
    elif ds == 'cifar':
        ds_transforms = transforms.Compose([
            transforms.Resize(192, interpolation=PIL.Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        dataset_ = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=ds_transforms)
        class_names = dataset_.classes
    return dataset_,class_names

val_dataset,class_names = get_dataset(DATASET,DATA_DIR)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8,shuffle=False)

#build and initialize model
device = 'cuda' #if torch.cuda.is_available() else 'cpu'

if args.clip > 0:
    clip_range = [0,args.clip]
else:
    clip_range = None

if 'bagnet17' in args.model:
    model = nets.bagnet.bagnet17(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
elif 'bagnet33' in args.model:
    model = nets.bagnet.bagnet33(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
elif 'bagnet9' in args.model:
    model = nets.bagnet.bagnet9(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
elif 'resnet50' in args.model:
    model = nets.resnet.resnet50(pretrained=True,clip_range=clip_range,aggregation=args.aggr)

if DATASET == 'imagenette':
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, len(class_names))
    model = torch.nn.DataParallel(model)
    checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_nette.pth'))
    model.load_state_dict(checkpoint['model_state_dict'])
elif DATASET == 'imagenet':
    model = torch.nn.DataParallel(model)
    checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_net.pth'))
    model.load_state_dict(checkpoint['state_dict'])
elif DATASET == 'cifar':
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, len(class_names))
    model = torch.nn.DataParallel(model)
    checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_192_cifar.pth'))
    model.load_state_dict(checkpoint['net'])

model = model.to(device)
model.eval()
cudnn.benchmark = True

accuracy_list=[]

for data,labels in tqdm(val_loader):
    data,labels=data.to(device),labels.to(device)
    output_clean = model(data)
    acc_clean=torch.sum(torch.argmax(output_clean, dim=1) == labels).cpu().detach().numpy()
    accuracy_list.append(acc_clean)

print("Test accuracy:",np.sum(accuracy_list)/len(val_dataset))


--------------------------------------------------------------------------------
/nets/dsresnet_cifar.py:
--------------------------------------------------------------------------------
##############################################################################################
# from https://github.com/alevine0/patchSmoothing/blob/master/pytorch_cifar/models/resnet.py
##############################################################################################


'''ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(6, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])


def test():
    net = ResNet18()
    # note: conv1 expects 6 input channels (image bands plus masks for
    # derandomized smoothing), so the test input has 6 channels
    y = net(torch.randn(1,6,32,32))
    print(y.size())

# test()
--------------------------------------------------------------------------------
/misc/PatchAttacker.py:
--------------------------------------------------------------------------------
######################################################################################################
# Adapted from https://github.com/Ping-C/certifiedpatchdefense/blob/master/attacks/patch_attacker.py
######################################################################################################

import torch
import numpy as np


class PatchAttacker:
    def __init__(self, model, mean, std, image_size=244,epsilon=1,steps=500,step_size=0.05,patch_size=31,random_start=True):

        mean,std = torch.tensor(mean),torch.tensor(std)
        self.epsilon = epsilon / std
        self.epsilon_cuda=self.epsilon[None, :, None, None].cuda()
        self.steps = steps
        self.step_size = step_size / std
        self.step_size=self.step_size[None, :, None, None].cuda()
        self.model = model.cuda()
        self.mean = mean
        self.std = std
        self.random_start = random_start
        self.image_size = image_size
        self.lb = (-mean / std)
        self.lb=self.lb[None, :, None, None].cuda()
        self.ub = (1 - mean) / std
        self.ub=self.ub[None, :, None, None].cuda()
        self.patch_w = patch_size
        self.patch_l = patch_size

        self.criterion = torch.nn.CrossEntropyLoss()

    def perturb(self, inputs, labels, loc=None,random_count=1):
        worst_x = None
        worst_loss = None

        for _ in range(random_count):
            # generate random patch center for each image
            idx = torch.arange(inputs.shape[0])[:, None]
            zero_idx = torch.zeros((inputs.shape[0],1), dtype=torch.long)
            if loc is not None: #specified locations
                w_idx = torch.ones([inputs.shape[0],1],dtype=torch.int64)*loc[0]
                l_idx = torch.ones([inputs.shape[0],1],dtype=torch.int64)*loc[1]
            else: #random locations
                w_idx = torch.randint(0 , inputs.shape[2]-self.patch_w , (inputs.shape[0],1))
                l_idx = torch.randint(0 , inputs.shape[3]-self.patch_l , (inputs.shape[0],1))

            idx = torch.cat([idx,zero_idx, w_idx, l_idx], dim=1)
            idx_list = [idx]
            for w in range(self.patch_w):
                for l in range(self.patch_l):
                    idx_list.append(idx + torch.tensor([0,0,w,l]))
            idx_list = torch.cat(idx_list, dim =0)

            # create mask
            mask = torch.zeros([inputs.shape[0], 1, inputs.shape[2], inputs.shape[3]],
                               dtype=torch.bool).cuda()
            mask[idx_list[:,0],idx_list[:,1],idx_list[:,2],idx_list[:,3]] = True

            if self.random_start:
                init_delta = np.random.uniform(-self.epsilon, self.epsilon,
                                               [inputs.shape[0]*inputs.shape[2]*inputs.shape[3], inputs.shape[1]])
                init_delta = init_delta.reshape(inputs.shape[0],inputs.shape[2],inputs.shape[3], inputs.shape[1])
                init_delta = init_delta.swapaxes(1,3).swapaxes(2,3)
                x = inputs + torch.where(mask, torch.Tensor(init_delta).to('cuda'), torch.tensor(0.).cuda())

                x = torch.min(torch.max(x, self.lb), self.ub).detach() # ensure valid pixel range
            else:
                x = inputs.data.detach().clone()

            x_init = inputs.data.detach().clone()

            for step in range(self.steps+1):
                x.requires_grad_()
                output = self.model(torch.where(mask, x, x_init))
                loss_ind = torch.nn.functional.cross_entropy(input=output, target=labels,reduction='none')
                loss = loss_ind.sum()
                grads = torch.autograd.grad(loss, x,retain_graph=False)[0]

                if step % 10 == 0:
                    if worst_loss is None:
                        worst_loss = loss_ind.detach().clone()
                        worst_x = x.detach().clone()
                    else:
                        tmp_loss = loss_ind.detach().clone()
                        tmp_x = x.detach().clone()
                        filter_tmp=worst_loss.ge(tmp_loss).detach().clone()
                        worst_x = torch.where(filter_tmp.reshape([inputs.shape[0],1,1,1]), worst_x, tmp_x).detach().clone()
                        worst_loss = torch.where(filter_tmp, worst_loss, tmp_loss).detach().clone()
                        #print(worst_loss)
                        #del tmp_loss
                        #del tmp_x
                        #del filter_tmp
                signed_grad_x = torch.sign(grads).detach()
                delta = signed_grad_x * self.step_size
                x = delta + x
                #del loss
                #del loss_ind
                #del grads
                # Project back into constraints ball and correct range
                x = torch.max(torch.min(x, x_init + self.epsilon_cuda), x_init - self.epsilon_cuda)#.detach()
                x = torch.min(torch.max(x, self.lb), self.ub).detach().clone()

        return worst_x.detach().clone(), torch.cat([w_idx, l_idx], dim=1).detach().clone()
--------------------------------------------------------------------------------
/misc/patch_attack.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms

import nets.bagnet
import nets.resnet
from utils.defense_utils import *

import os
import argparse
from tqdm import tqdm
import numpy as np
import PIL
from PatchAttacker import PatchAttacker
import joblib

parser = argparse.ArgumentParser()
parser.add_argument("--dump_dir",default='patch_adv',type=str,help="directory to save attack results")
parser.add_argument("--model_dir",default='checkpoints',type=str,help="path to checkpoints")
parser.add_argument('--data_dir', default='data', type=str,help="path to data")
parser.add_argument('--dataset', default='imagenette', choices=('imagenette','imagenet','cifar'),type=str,help="dataset")
parser.add_argument("--model",default='bagnet17',type=str,help="model name")
parser.add_argument("--clip",default=-1,type=int,help="clipping value; do clipping when this argument is set to positive")
parser.add_argument("--aggr",default='mean',type=str,help="aggregation methods. set to none for local feature")
parser.add_argument("--patch_size",type=int,help="size of the adversarial patch")

args = parser.parse_args()

MODEL_DIR=os.path.join('.',args.model_dir)
DATA_DIR=os.path.join(args.data_dir,args.dataset)
DATASET = args.dataset
DUMP_DIR=os.path.join('dump',args.dump_dir+'_{}_{}'.format(args.model,args.dataset))
if not os.path.exists('dump'):
    os.mkdir('dump')
if not os.path.exists(DUMP_DIR):
    os.mkdir(DUMP_DIR)



if DATASET in ['imagenette','imagenet']:
    DATA_DIR=os.path.join(DATA_DIR,'val')
    mean_vec = [0.485, 0.456, 0.406]
    std_vec = [0.229, 0.224, 0.225]
    ds_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean_vec,std_vec)
    ])
    val_dataset = datasets.ImageFolder(DATA_DIR,ds_transforms)
    class_names = val_dataset.classes
elif DATASET == 'cifar':
    mean_vec = [0.4914, 0.4822, 0.4465]
    std_vec = [0.2023, 0.1994, 0.2010]
    ds_transforms = transforms.Compose([
        transforms.Resize(192, interpolation=PIL.Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean_vec,std_vec),
    ])
    val_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=ds_transforms)
    class_names = val_dataset.classes

val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8,shuffle=False)

#build and initialize model
device = 'cuda' #if torch.cuda.is_available() else 'cpu'

if args.clip > 0:
    clip_range = [0,args.clip]
else:
    clip_range = None

if 'bagnet17' in args.model:
    model = nets.bagnet.bagnet17(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
elif 'bagnet33' in args.model:
    model = nets.bagnet.bagnet33(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
elif 'bagnet9' in args.model:
    model = nets.bagnet.bagnet9(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
elif 'resnet50' in args.model:
    model = nets.resnet.resnet50(pretrained=True,clip_range=clip_range,aggregation=args.aggr)


if DATASET == 'imagenette':
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, len(class_names))
    model = torch.nn.DataParallel(model)
    checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_nette.pth'))
    model.load_state_dict(checkpoint['model_state_dict'])
elif DATASET == 'imagenet':
    model = torch.nn.DataParallel(model)
    checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_net.pth'))
    model.load_state_dict(checkpoint['state_dict'])
elif DATASET == 'cifar':
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, len(class_names))
    model = torch.nn.DataParallel(model)
    checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_192_cifar.pth'))
    model.load_state_dict(checkpoint['net'])

model = model.to(device)
model.eval()
cudnn.benchmark = True


attacker = PatchAttacker(model, mean_vec, std_vec,patch_size=args.patch_size,step_size=0.05,steps=500)

adv_list=[]
error_list=[]
accuracy_list=[]
patch_loc_list=[]

for data,labels in tqdm(val_loader):

    data,labels=data.to(device),labels.to(device)
    data_adv,patch_loc = attacker.perturb(data, labels)

    output_adv = model(data_adv)
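    # count adversarial misclassifications and clean correct predictions in this batch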
    error_adv=torch.sum(torch.argmax(output_adv, dim=1) != labels).cpu().detach().numpy()
    output_clean = model(data)
    acc_clean=torch.sum(torch.argmax(output_clean, dim=1) == labels).cpu().detach().numpy()

    data_adv=data_adv.cpu().detach().numpy()
    patch_loc=patch_loc.cpu().detach().numpy()

    patch_loc_list.append(patch_loc)
    adv_list.append(data_adv)
    error_list.append(error_adv)
    accuracy_list.append(acc_clean)


adv_list = np.concatenate(adv_list)
patch_loc_list = np.concatenate(patch_loc_list)
joblib.dump(adv_list,os.path.join(DUMP_DIR,'patch_adv_list_{}.z'.format(args.patch_size)))
joblib.dump(patch_loc_list,os.path.join(DUMP_DIR,'patch_loc_list_{}.z'.format(args.patch_size)))
print("Attack success rate:",np.sum(error_list)/len(val_dataset))
print("Clean accuracy:",np.sum(accuracy_list)/len(val_dataset))


--------------------------------------------------------------------------------
/misc/train_cifar.py:
--------------------------------------------------------------------------------
##############################################################################
# Adapted from https://github.com/kuangliu/pytorch-cifar/blob/master/main.py
##############################################################################

'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

import nets.bagnet
import nets.resnet

import PIL

from utils.progress_bar import progress_bar

import numpy as np
import joblib

import random

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument("--clip",default=-1,type=int)
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    #transforms.RandomCrop(32, padding=4),
    transforms.Resize(192, interpolation=PIL.Image.BICUBIC),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(192, interpolation=PIL.Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


trainset = torchvision.datasets.CIFAR10(root='data/cifar', train=True, download=True, transform=transform_train)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='data/cifar', train=False, download=True, transform=transform_test)

testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

if args.clip > 0:
    clip_range = [0,args.clip]
else:
    clip_range = None

# Model
print('==> Building model..')

pth_path = './checkpoints/bagnet17_192_cifar.pth'

net = nets.bagnet.bagnet17(pretrained=True,clip_range=clip_range,aggregation='adv') #aggregation = 'mean' for vanilla training
#net = nets.resnet.resnet50(pretrained=True)

#for param in net.parameters():
#    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 10)
net = net.to(device)

if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./checkpoints'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load(pth_path)
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    idx_list=[]
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    #joblib.dump(idx_list,'masked_contour_correct_idx.z')
    acc = 100.*correct/total
    if True: #acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoints'):
            os.mkdir('checkpoints')
        torch.save(state, pth_path)
        best_acc = acc

# python train_cifar.py --lr 0.01
# python train_cifar.py --resume --lr 0.001
for epoch in range(start_epoch, start_epoch+20):
    train(epoch)
    test(epoch)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PatchGuard: A Provably Robust Defense against Adversarial Patches via Small Receptive Fields and Masking

By [Chong Xiang](http://xiangchong.xyz/), [Arjun Nitin Bhagoji](http://www.princeton.edu/~abhagoji/), [Vikash Sehwag](https://vsehwag.github.io/), [Prateek Mittal](https://www.princeton.edu/~pmittal/)

Code for "[PatchGuard: A Provably Robust Defense against Adversarial Patches via Small Receptive Fields and Masking](https://www.usenix.org/conference/usenixsecurity21/presentation/xiang)" in USENIX Security 2021 ([arXiv technical report](https://arxiv.org/abs/2005.10884))

![defense overview pipeline](doc/overview.png)

Update 04/2022: Check out our new [PatchCleanser](https://github.com/inspire-group/PatchCleanser) defense (USENIX Security 2022), our [paper list for adversarial patch research](https://github.com/xiangchong1/adv-patch-paper-list), and the [leaderboard for certifiably robust image classification](https://github.com/inspire-group/patch-defense-leaderboard) for fun!

Update 12/2021: fixed an incorrect lower bound computation for the true class when the detection threshold T>0. Thanks to [Linyi Li](https://github.com/llylly) for pointing that out! The mistake does not affect the main results of the paper (since the main results are obtained with T=0).

Update 08/2021: started to work on a paper list for adversarial patch research [(link)](https://github.com/xiangchong1/adv-patch-paper-list).

Update 05/2021: included code (`det_bn.py`) for "[PatchGuard++: Efficient Provable Attack Detection against Adversarial Patches](https://arxiv.org/abs/2104.12609)" in the Security and Safety in Machine Learning Systems Workshop at ICLR 2021.

## Requirements

The code is tested with Python 3.8 and PyTorch 1.7.0. The complete list of required packages is available in `requirement.txt` and can be installed with `pip install -r requirement.txt`. The code should be compatible with other versions of these packages.
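A quick, optional sanity check of the environment (`__version__` is a standard attribute of both packages):

```python
import torch, torchvision
print(torch.__version__, torchvision.__version__)  # tested with 1.7.0 / 0.8.1
```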
## Files

```shell
├── README.md #this file
├── requirement.txt #required packages
├── example_cmd.sh #example commands to run the code
├── mask_bn.py #PatchGuard: mask-bn for imagenet/imagenette/cifar
├── mask_ds.py #PatchGuard: mask-ds/ds for imagenet/imagenette/cifar
├── det_bn.py #PatchGuard++: provable robust attack detection
├── nets
|   ├── bagnet.py #modified bagnet model for mask-bn
|   ├── resnet.py #modified resnet model
|   ├── dsresnet_imgnt.py #ds-resnet-50 for imagenet(te)
|   └── dsresnet_cifar.py #ds-resnet-18 for cifar
├── utils
|   ├── defense_utils.py #utils for different defenses
|   ├── normalize_utils.py #utils for normalizing images stored in numpy arrays (not used in the paper)
|   ├── cutout.py #utils for CUTOUT training (not used)
|   └── progress_bar.py #progress bar (used in train_cifar.py; unnecessary though)
|
├── misc #useful scripts (but not used in robustness evaluation); move them to the main directory for execution
|   ├── test_acc.py #test clean accuracy of resnet/bagnet on imagenet/imagenette/cifar; supports clipping and median operations
|   ├── train_imagenet.py #train resnet/bagnet for imagenet
|   ├── train_imagenette.py #train resnet/bagnet for imagenette
|   ├── train_cifar.py #train resnet/bagnet for cifar
#NOTE: The attack scripts are not used in our defense evaluation!
|   ├── patch_attack.py #empirical (untargeted) attacks against resnet/bagnet trained on imagenet/imagenette/cifar
|   ├── PatchAttacker.py #utils for untargeted adversarial patch attacks
|
├── data
|   ├── imagenet #data directory for imagenet
|   ├── imagenette #data directory for imagenette
|   └── cifar #data directory for cifar
|
└── checkpoints #directory for checkpoints
    ├── README.md #details of each checkpoint
    └── ... #model checkpoints
```

## Datasets

- [ImageNet](http://www.image-net.org/) (ILSVRC2012)
- [ImageNette](https://github.com/fastai/imagenette) ([Full size](https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz))
- [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)

## Usage

- See **Files** for details of each file.
- Download the data in **Datasets** to `data/`.
- (optional) Download checkpoints from the Google Drive [link](https://drive.google.com/drive/folders/1u5RsCuZNf7ddWW0utI4OrgWGmJCUDCuT?usp=sharing) and move them to `checkpoints/`.
- See `example_cmd.sh` for example commands for running the code; a few representative commands are also shown below.

If anything is unclear, please open an issue or contact Chong Xiang (cxiang@princeton.edu).
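For example, the following commands (taken from `example_cmd.sh`) run the provable analysis of Mask-BN, Mask-DS, and PatchGuard++ on ImageNette:

```shell
python mask_bn.py --model bagnet17 --dataset imagenette --patch_size 32 --m  #mask-bn with bagnet17
python mask_ds.py --dataset imagenette --patch_size 42 --m  #mask-ds
python det_bn.py --det --model bagnet33 --tau 0.5 --patch_size 32 --dataset imagenette  #patchguard++
```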
## Related Repositories

- [certifiedpatchdefense](https://github.com/Ping-C/certifiedpatchdefense)
- [patchSmoothing](https://github.com/alevine0/patchSmoothing)
- [bag-of-local-features-models](https://github.com/wielandbrendel/bag-of-local-features-models)

## Citations

If you find our work useful in your research, please consider citing:

```tex
@inproceedings{xiang2021patchguard,
  title={PatchGuard: A Provably Robust Defense against Adversarial Patches via Small Receptive Fields and Masking},
  author={Xiang, Chong and Bhagoji, Arjun Nitin and Sehwag, Vikash and Mittal, Prateek},
  booktitle = {30th {USENIX} Security Symposium ({USENIX} Security)},
  year={2021}
}

@inproceedings{xiang2021patchguard2,
  title={PatchGuard++: Efficient Provable Attack Detection against Adversarial Patches},
  author={Xiang, Chong and Mittal, Prateek},
  booktitle = {ICLR 2021 Workshop on Security and Safety in Machine Learning Systems},
  year={2021}
}
```
--------------------------------------------------------------------------------
/mask_ds.py:
--------------------------------------------------------------------------------
##############################################################################################################
# Part of code adapted from https://github.com/alevine0/patchSmoothing/blob/master/certify_imagenet_band.py
##############################################################################################################

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import nets.dsresnet_imgnt as resnet_imgnt
import nets.dsresnet_cifar as resnet_cifar
from torchvision import datasets,transforms
from tqdm import tqdm
from utils.defense_utils import *

import os
import argparse
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument("--model_dir",default='checkpoints',type=str,help="path to checkpoints")
parser.add_argument('--band_size', default=-1, type=int, help='size of each smoothing band')
parser.add_argument('--patch_size', default=-1, type=int, help='patch_size')
parser.add_argument('--thres', default=0.0, type=float, help='detection threshold for robust masking')
parser.add_argument('--dataset', default='imagenette', choices=('imagenette','imagenet','cifar'),type=str,help="dataset")
parser.add_argument('--data_dir', default='data', type=str,help="path to data")

parser.add_argument('--skip', default=1,type=int, help='Number of images to skip')
parser.add_argument("--m",action='store_true',help="use robust masking")
parser.add_argument("--ds",action='store_true',help="use derandomized smoothing")

args = parser.parse_args()

MODEL_DIR=os.path.join('.',args.model_dir)
DATA_DIR=os.path.join(args.data_dir,args.dataset)
DATASET = args.dataset

device = 'cuda' #if torch.cuda.is_available() else 'cpu'

cudnn.benchmark = True

def get_dataset(ds,data_dir):
    if ds in ['imagenette','imagenet']:
        ds_dir=os.path.join(data_dir,'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        dataset_ = datasets.ImageFolder(ds_dir, transforms.Compose([
            transforms.Resize((299,299)), #note that the input size here is 299x299 instead of 224x224
            transforms.ToTensor(),
            normalize,
        ]))
    elif ds == 'cifar':
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        dataset_ = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform_test)
    return dataset_,dataset_.classes

val_dataset_,class_names = get_dataset(DATASET,DATA_DIR)
skips = list(range(0, len(val_dataset_), args.skip))
val_dataset = torch.utils.data.Subset(val_dataset_, skips)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32,shuffle=False)

num_cls = len(class_names)

# Model
print('==> Building model..')



if DATASET == 'imagenette':
    net = resnet_imgnt.resnet50()
    net = torch.nn.DataParallel(net)
    num_ftrs = net.module.fc.in_features
    net.module.fc = nn.Linear(num_ftrs, num_cls)
    checkpoint = torch.load(os.path.join(MODEL_DIR,'ds_nette.pth'))
    args.band_size = args.band_size if args.band_size>0 else 25
    args.patch_size = args.patch_size if args.patch_size>0 else 42
elif DATASET == 'imagenet':
    net = resnet_imgnt.resnet50()
    net = torch.nn.DataParallel(net)
    checkpoint = torch.load(os.path.join(MODEL_DIR,'ds_net.pth'))
    args.band_size = args.band_size if args.band_size>0 else 25
    args.patch_size = args.patch_size if args.patch_size>0 else 42
elif DATASET == 'cifar':
    net = resnet_cifar.ResNet18()
    net = torch.nn.DataParallel(net)
    checkpoint = torch.load(os.path.join(MODEL_DIR,'ds_cifar.pth'))
    args.band_size = args.band_size if args.band_size>0 else 4
    args.patch_size = args.patch_size if args.patch_size>0 else 5

print(args.band_size,args.patch_size)


net.load_state_dict(checkpoint['net'])

net = net.to(device)
net.eval()


if args.ds: #ds
    correct = 0
    cert_correct = 0
    cert_incorrect = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            total += targets.size(0)
            predictions, certyn = ds(inputs, net,args.band_size, args.patch_size, num_cls,threshold = 0.2)
            correct += (predictions.eq(targets)).sum().item()
            cert_correct += (predictions.eq(targets) & certyn).sum().item()
            cert_incorrect += (~predictions.eq(targets) & certyn).sum().item()
    print('Results for Derandomized Smoothing')
    print('Using band size ' + str(args.band_size) + ' with threshold ' + str(0.2))
    print('Certifying For Patch ' +str(args.patch_size) + '*'+str(args.patch_size))
    print('Total images: ' + str(total))
    print('Correct: ' + str(correct) + ' (' + str((100.*correct)/total)+'%)')
    print('Certified Correct class: ' + str(cert_correct) + ' (' + str((100.*cert_correct)/total)+'%)')
    print('Certified Wrong class: ' + str(cert_incorrect) + ' (' + str((100.*cert_incorrect)/total)+'%)')

if args.m: #mask-ds
    result_list=[]
    clean_corr_list=[]
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader):
            inputs = inputs.to(device)
            targets = targets.numpy()
            result,clean_corr = masking_ds(inputs,targets,net,args.band_size, args.patch_size,thres=args.thres)
            result_list+=result
            clean_corr_list+=clean_corr

    cases,cnt=np.unique(result_list,return_counts=True)
    print('Results for Mask-DS')
    print("Provable robust accuracy:",cnt[-1]/len(result_list) if len(cnt)==3 else 0)
    print("Clean accuracy with defense:",np.mean(clean_corr_list))
    print("------------------------------")
    print("Provable analysis cases (0: incorrect prediction; 1: vulnerable; 2: provably robust):",cases)
    print("Provable analysis breakdown:",cnt/len(result_list))


--------------------------------------------------------------------------------
/det_bn.py:
--------------------------------------------------------------------------------
# the code logic is the same as mask_bn.py
# kept as a separate file to distinguish between PatchGuard and PatchGuard++
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms

import nets.bagnet
import nets.resnet
from utils.defense_utils import *

import os
import joblib
import argparse
from tqdm import tqdm
import numpy as np
from scipy.special import softmax
from math import ceil
import PIL

parser = argparse.ArgumentParser()

parser.add_argument("--model_dir",default='checkpoints',type=str,help="path to checkpoints")
parser.add_argument('--data_dir', default='data', type=str,help="path to data")
parser.add_argument('--dataset', default='imagenette', choices=('imagenette','imagenet','cifar'),type=str,help="dataset")
parser.add_argument("--model",default='bagnet33',type=str,help="model name")
parser.add_argument("--clip",default=-1,type=int,help="clipping value; do clipping when this argument is set to positive")
parser.add_argument("--aggr",default='none',type=str,help="aggregation methods. set to none for local feature")
parser.add_argument("--skip",default=1,type=int,help="number of examples to skip")
parser.add_argument("--thres",default=0.0,type=float,help="detection threshold for robust masking")
parser.add_argument("--patch_size",default=-1,type=int,help="size of the adversarial patch")
parser.add_argument("--det",action='store_true',help="use PG++ attack detection")
parser.add_argument("--tau",default=0.0,type=float,help="detection threshold tau for PatchGuard++")

args = parser.parse_args()

MODEL_DIR=os.path.join('.',args.model_dir)
DATA_DIR=os.path.join(args.data_dir,args.dataset)
DATASET = args.dataset
def get_dataset(ds,data_dir):
    if ds in ['imagenette','imagenet']:
        ds_dir=os.path.join(data_dir,'val')
        ds_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        dataset_ = datasets.ImageFolder(ds_dir,ds_transforms)
        class_names = dataset_.classes
    elif ds == 'cifar':
        ds_transforms = transforms.Compose([
            transforms.Resize(192, interpolation=PIL.Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        dataset_ = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=ds_transforms)
        class_names = dataset_.classes
    return dataset_,class_names

val_dataset_,class_names = get_dataset(DATASET,DATA_DIR)
skips = list(range(0, len(val_dataset_), args.skip))
val_dataset = torch.utils.data.Subset(val_dataset_, skips)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8,shuffle=False)

#build and initialize model
device = 'cuda' #if torch.cuda.is_available() else 'cpu'

if args.clip > 0:
    clip_range = [0,args.clip]
else:
    clip_range = None

if 'bagnet17' in args.model:
    model = nets.bagnet.bagnet17(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
    rf_size=17
elif 'bagnet33' in args.model:
    model = nets.bagnet.bagnet33(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
    rf_size=33
elif 'bagnet9' in args.model:
    model = nets.bagnet.bagnet9(pretrained=True,clip_range=clip_range,aggregation=args.aggr)
    rf_size=9


if DATASET == 'imagenette':
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, len(class_names))
    model = torch.nn.DataParallel(model)
    checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_nette.pth'))
    model.load_state_dict(checkpoint['model_state_dict'])
    args.patch_size = args.patch_size if args.patch_size>0 else 32
elif DATASET == 'imagenet':
    model = torch.nn.DataParallel(model)
    checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_net.pth'))
    model.load_state_dict(checkpoint['state_dict'])
    args.patch_size = args.patch_size if args.patch_size>0 else 32
elif DATASET == 'cifar':
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, len(class_names))
    model = torch.nn.DataParallel(model)
    checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_192_cifar.pth'))
    model.load_state_dict(checkpoint['net'])
    args.patch_size = args.patch_size if args.patch_size>0 else 30


rf_stride=8
window_size = ceil((args.patch_size + rf_size -1) / rf_stride)
print("window_size",window_size)
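# a patch of size p can intersect at most ceil((p + rf_size - 1) / rf_stride)
# receptive fields along each axis of the feature map; e.g., for bagnet33 on
# imagenette with the default 32x32 patch: ceil((32 + 33 - 1) / 8) = 8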


model = model.to(device)
model.eval()
cudnn.benchmark = True

accuracy_list=[]
result_list=[]
clean_corr=0

for data,labels in tqdm(val_loader):

    data=data.to(device)
    labels = labels.numpy()
    output_clean = model(data).detach().cpu().numpy() # logits
    #output_clean = softmax(output_clean,axis=-1) # confidence
    #output_clean = (output_clean > 0.2).astype(float) # predictions with confidence threshold

    #note: the provable analysis is cpu-intensive and can take some time to finish
    #you can dump the local features and do the provable analysis with another script so that GPU memory is not always occupied
    for i in range(len(labels)):
        if args.det:
            local_feature = output_clean[i]
            #result,clean_pred = provable_detection(local_feature,labels[i],tau=args.tau,window_shape=[window_size,window_size])
            #clean_corr += clean_pred

            clean_pred = pg2_detection(local_feature,tau=args.tau,window_shape=[window_size,window_size])
            clean_corr += clean_pred == labels[i]

            result = pg2_detection_provable(local_feature,labels[i],tau=args.tau,window_shape=[window_size,window_size])
            result_list.append(result)

    acc_clean = np.sum(np.argmax(np.mean(output_clean,axis=(1,2)),axis=1) == labels)
    accuracy_list.append(acc_clean)


cases,cnt=np.unique(result_list,return_counts=True)
print("Provable robust accuracy:",cnt[-1]/len(result_list) if len(cnt)==3 else 0)
print("Clean accuracy with defense:",clean_corr/len(result_list))
print("Clean accuracy without defense:",np.sum(accuracy_list)/len(val_dataset))
print("------------------------------")
print("Provable analysis cases (0: incorrect prediction; 1: vulnerable; 2: provably robust):",cases)
print("Provable analysis breakdown:",cnt/len(result_list))
--------------------------------------------------------------------------------
/mask_bn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms

import nets.bagnet
import nets.resnet
from utils.defense_utils import *

import os
import joblib
import argparse
from tqdm import tqdm
import numpy as np
from scipy.special import softmax
from math import ceil
import PIL

parser = argparse.ArgumentParser()

parser.add_argument("--model_dir",default='checkpoints',type=str,help="path to checkpoints")
parser.add_argument('--data_dir', default='data', type=str,help="path to data")
parser.add_argument('--dataset', default='imagenette', choices=('imagenette','imagenet','cifar'),type=str,help="dataset")
parser.add_argument("--model",default='bagnet17',type=str,help="model name")
parser.add_argument("--clip",default=-1,type=int,help="clipping value; do clipping when this argument is set to positive")
parser.add_argument("--aggr",default='none',type=str,help="aggregation methods. set to none for local feature")
29 | parser.add_argument("--skip",default=1,type=int,help="number of examples to skip") 30 | parser.add_argument("--thres",default=0.0,type=float,help="detection threshold for robust masking") 31 | parser.add_argument("--patch_size",default=-1,type=int,help="size of the adversarial patch") 32 | parser.add_argument("--m",action='store_true',help="use robust masking") 33 | parser.add_argument("--cbn",action='store_true',help="use CBN (clipped BagNet)") 34 | 35 | args = parser.parse_args() 36 | 37 | MODEL_DIR=os.path.join('.',args.model_dir) 38 | DATA_DIR=os.path.join(args.data_dir,args.dataset) 39 | DATASET = args.dataset 40 | def get_dataset(ds,data_dir): 41 | if ds in ['imagenette','imagenet']: 42 | ds_dir=os.path.join(data_dir,'val') 43 | ds_transforms = transforms.Compose([ 44 | transforms.Resize(256), 45 | transforms.CenterCrop(224), 46 | transforms.ToTensor(), 47 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 48 | ]) 49 | dataset_ = datasets.ImageFolder(ds_dir,ds_transforms) 50 | class_names = dataset_.classes 51 | elif ds == 'cifar': 52 | ds_transforms = transforms.Compose([ 53 | transforms.Resize(192, interpolation=PIL.Image.BICUBIC), 54 | transforms.ToTensor(), 55 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 56 | ]) 57 | dataset_ = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=ds_transforms) 58 | class_names = dataset_.classes 59 | return dataset_,class_names 60 | 61 | val_dataset_,class_names = get_dataset(DATASET,DATA_DIR) 62 | skips = list(range(0, len(val_dataset_), args.skip)) 63 | val_dataset = torch.utils.data.Subset(val_dataset_, skips) 64 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8,shuffle=False) 65 | 66 | #build and initialize model 67 | device = 'cuda' #if torch.cuda.is_available() else 'cpu' 68 | 69 | if args.clip > 0: 70 | clip_range = [0,args.clip] 71 | else: 72 | clip_range = None 73 | 74 | if 'bagnet17' in args.model: 75 | model = nets.bagnet.bagnet17(pretrained=True,clip_range=clip_range,aggregation=args.aggr) 76 | rf_size=17 77 | elif 'bagnet33' in args.model: 78 | model = nets.bagnet.bagnet33(pretrained=True,clip_range=clip_range,aggregation=args.aggr) 79 | rf_size=33 80 | elif 'bagnet9' in args.model: 81 | model = nets.bagnet.bagnet9(pretrained=True,clip_range=clip_range,aggregation=args.aggr) 82 | rf_size=9 83 | 84 | 85 | if DATASET == 'imagenette': 86 | num_ftrs = model.fc.in_features 87 | model.fc = nn.Linear(num_ftrs, len(class_names)) 88 | model = torch.nn.DataParallel(model) 89 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_nette.pth')) 90 | model.load_state_dict(checkpoint['model_state_dict']) 91 | args.patch_size = args.patch_size if args.patch_size>0 else 32 92 | elif DATASET == 'imagenet': 93 | model = torch.nn.DataParallel(model) 94 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_net.pth')) 95 | model.load_state_dict(checkpoint['state_dict']) 96 | args.patch_size = args.patch_size if args.patch_size>0 else 32 97 | elif DATASET == 'cifar': 98 | num_ftrs = model.fc.in_features 99 | model.fc = nn.Linear(num_ftrs, len(class_names)) 100 | model = torch.nn.DataParallel(model) 101 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model+'_192_cifar.pth')) 102 | model.load_state_dict(checkpoint['net']) 103 | args.patch_size = args.patch_size if args.patch_size>0 else 30 104 | 105 | 106 | rf_stride=8 107 | window_size = ceil((args.patch_size + rf_size -1) / rf_stride) 108 | print("window_size",window_size) 109 |
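# (editorial sketch) the comments further below suggest dumping the local features and
# running the CPU-heavy provable analysis from a separate script; a minimal illustration,
# assuming the model has already been moved to the GPU and using a hypothetical dump
# file name 'feature_dump.z' (not part of this repo):
#   feature_list, label_list = [], []
#   for data, labels in tqdm(val_loader):
#       feature_list.append(model(data.to(device)).detach().cpu().numpy())
#       label_list.append(labels.numpy())
#   joblib.dump((np.concatenate(feature_list), np.concatenate(label_list)), 'feature_dump.z')
#   # a separate script can then joblib.load('feature_dump.z') and call provable_masking()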
110 | 111 | model = model.to(device) 112 | model.eval() 113 | cudnn.benchmark = True 114 | 115 | accuracy_list=[] 116 | result_list=[] 117 | clean_corr=0 118 | 119 | for data,labels in tqdm(val_loader): 120 | 121 | data=data.to(device) 122 | labels = labels.numpy() 123 | output_clean = model(data).detach().cpu().numpy() # logits 124 | #output_clean = softmax(output_clean,axis=-1) # confidence 125 | #output_clean = (output_clean > 0.2).astype(float) # predictions with confidence threshold 126 | 127 | #note: the provable analysis of robust masking is CPU-intensive and can take some time to finish 128 | #you can dump the local features and do the provable analysis with another script so that GPU memory is not always occupied 129 | for i in range(len(labels)): 130 | if args.m:#robust masking 131 | local_feature = output_clean[i] 132 | result = provable_masking(local_feature,labels[i],thres=args.thres,window_shape=[window_size,window_size]) 133 | result_list.append(result) 134 | clean_pred = masking_defense(local_feature,thres=args.thres,window_shape=[window_size,window_size]) 135 | clean_corr += clean_pred == labels[i] 136 | 137 | elif args.cbn:#cbn 138 | # note that the CBN results reported in the paper are obtained with vanilla BagNet (without provable adversarial training), since 139 | # the provable adversarial training is proposed in our paper. We find that our training technique also benefits CBN 140 | result = provable_clipping(output_clean[i],labels[i],window_shape=[window_size,window_size]) 141 | result_list.append(result) 142 | clean_pred = clipping_defense(output_clean[i]) 143 | clean_corr += clean_pred == labels[i] 144 | acc_clean = np.sum(np.argmax(np.mean(output_clean,axis=(1,2)),axis=1) == labels) 145 | accuracy_list.append(acc_clean) 146 | 147 | 148 | cases,cnt=np.unique(result_list,return_counts=True) 149 | print("Provable robust accuracy:",cnt[-1]/len(result_list) if len(cnt)==3 else 0) 150 | print("Clean accuracy with defense:",clean_corr/len(result_list)) 151 | print("Clean accuracy without defense:",np.sum(accuracy_list)/len(val_dataset)) 152 | print("------------------------------") 153 | print("Provable analysis cases (0: incorrect prediction; 1: vulnerable; 2: provably robust):",cases) 154 | print("Provable analysis breakdown",cnt/len(result_list)) -------------------------------------------------------------------------------- /misc/train_imagenette.py: -------------------------------------------------------------------------------- 1 | ####################################################################################### 2 | # Adapted from https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html 3 | # Used for training models on ImageNette 4 | ####################################################################################### 5 | 6 | from __future__ import print_function, division 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from torch.optim import lr_scheduler 12 | import numpy as np 13 | import torchvision 14 | from torchvision import datasets, models, transforms 15 | import time 16 | import os 17 | import copy 18 | from tqdm import tqdm 19 | import random 20 | import nets.bagnet 21 | import nets.resnet 22 | import argparse 23 | from utils.cutout import Cutout 24 | 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--model_dir",default='checkpoints',type=str) 28 | parser.add_argument("--data_dir",default='data/imagenette',type=str) 29 |
parser.add_argument("--model_name",default='bagnet17_nette.pth',type=str) 30 | parser.add_argument("--clip",default=-1,type=int) 31 | parser.add_argument("--epoch",default=20,type=int) 32 | parser.add_argument("--cutout_size",default=31,type=int) 33 | parser.add_argument("--aggr",default='adv',type=str) 34 | parser.add_argument("--resume",action='store_true') 35 | parser.add_argument("--cutout",action='store_true',help="use CUTOUT during the training") 36 | parser.add_argument("--fc",action='store_true',help="only retrain the fully-connected layer") 37 | args = parser.parse_args() 38 | 39 | MODEL_DIR=os.path.join('.',args.model_dir) 40 | DATA_DIR=os.path.join(args.data_dir) 41 | 42 | if not os.path.exists(MODEL_DIR): 43 | os.mkdir(MODEL_DIR) 44 | 45 | mean_vec=[0.485, 0.456, 0.406] 46 | std_vec=[0.229, 0.224, 0.225] 47 | 48 | data_transforms = { 49 | 'train': transforms.Compose([ 50 | transforms.RandomResizedCrop(224), 51 | transforms.RandomHorizontalFlip(), 52 | transforms.ToTensor(), 53 | transforms.Normalize(mean_vec, std_vec) 54 | ]), 55 | 'val': transforms.Compose([ 56 | transforms.Resize(256), 57 | transforms.CenterCrop(224), 58 | transforms.ToTensor(), 59 | transforms.Normalize(mean_vec,std_vec) 60 | ]), 61 | } 62 | 63 | if args.cutout: 64 | data_transforms['train'].transforms.append(Cutout(n_holes=1, length=args.cutout_size)) 65 | 66 | train_dir=os.path.join(DATA_DIR,'train') 67 | val_dir=os.path.join(DATA_DIR,'val') 68 | 69 | train_dataset = datasets.ImageFolder(train_dir,data_transforms['train']) 70 | val_dataset = datasets.ImageFolder(val_dir,data_transforms['val']) 71 | 72 | print('train_dataset.size',len(train_dataset.samples)) 73 | print('val_dataset.size',len(val_dataset.samples)) 74 | image_datasets = {'train':train_dataset,'val':val_dataset} 75 | dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} 76 | class_names = image_datasets['train'].classes 77 | print('class_names:',class_names) 78 | 79 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64,shuffle=True) 80 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64,shuffle=False) 81 | 82 | dataloaders={'train':train_loader,'val':val_loader} 83 | 84 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 85 | 86 | print('device:',device) 87 | 88 | def train_model(model, criterion, optimizer, scheduler, num_epochs=20 ,mask=False): 89 | 90 | since = time.time() 91 | 92 | best_model_wts = copy.deepcopy(model.state_dict()) 93 | best_acc = 0.0 94 | 95 | for epoch in tqdm(range(num_epochs)): 96 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 97 | print('-' * 10) 98 | 99 | # Each epoch has a training and validation phase 100 | for phase in ['train', 'val']: 101 | if phase == 'train': 102 | model.train() # Set model to training mode 103 | else: 104 | model.eval() # Set model to evaluate mode 105 | 106 | running_loss = 0.0 107 | running_corrects = 0 108 | 109 | # Iterate over data. 
110 | for inputs, labels in dataloaders[phase]: 111 | inputs = inputs.to(device) 112 | labels = labels.to(device) 113 | 114 | # zero the parameter gradients 115 | optimizer.zero_grad() 116 | 117 | # forward 118 | # track history if only in train 119 | with torch.set_grad_enabled(phase == 'train'): 120 | outputs = model(inputs,labels) 121 | _, preds = torch.max(outputs, 1) 122 | loss = criterion(outputs, labels) 123 | 124 | # backward + optimize only if in training phase 125 | if phase == 'train': 126 | loss.backward() 127 | optimizer.step() 128 | 129 | # statistics 130 | running_loss += loss.item() * inputs.size(0) 131 | running_corrects += torch.sum(preds == labels.data) 132 | if phase == 'train': 133 | scheduler.step() 134 | 135 | epoch_loss = running_loss / dataset_sizes[phase] 136 | epoch_acc = running_corrects.double() / dataset_sizes[phase] 137 | 138 | print('{} Loss: {:.4f} Acc: {:.4f}'.format( 139 | phase, epoch_loss, epoch_acc)) 140 | 141 | # deep copy the model 142 | if phase == 'val' :#and epoch_acc > best_acc: 143 | best_acc = epoch_acc 144 | best_model_wts = copy.deepcopy(model.state_dict()) 145 | print('saving...') 146 | torch.save({ 147 | 'epoch': epoch, 148 | 'model_state_dict': best_model_wts, 149 | 'optimizer_state_dict': optimizer.state_dict(), 150 | 'scheduler_state_dict':scheduler.state_dict() 151 | }, os.path.join(MODEL_DIR,args.model_name)) 152 | 153 | print() 154 | 155 | time_elapsed = time.time() - since 156 | print('Training complete in {:.0f}m {:.0f}s'.format( 157 | time_elapsed // 60, time_elapsed % 60)) 158 | print('Best val Acc: {:4f}'.format(best_acc)) 159 | 160 | # load best model weights 161 | model.load_state_dict(best_model_wts) 162 | return model 163 | 164 | if args.clip > 0: 165 | clip_range = [0,args.clip] 166 | else: 167 | clip_range = None 168 | 169 | if 'bagnet17' in args.model_name: 170 | model_conv = nets.bagnet.bagnet17(pretrained=True,clip_range=clip_range,aggregation=args.aggr) 171 | elif 'bagnet33' in args.model_name: 172 | model_conv = nets.bagnet.bagnet33(pretrained=True,clip_range=clip_range,aggregation=args.aggr) 173 | elif 'bagnet9' in args.model_name: 174 | model_conv = nets.bagnet.bagnet9(pretrained=True,clip_range=clip_range,aggregation=args.aggr) 175 | elif 'resnet50' in args.model_name: 176 | model_conv = nets.resnet.resnet50(pretrained=True,clip_range=clip_range,aggregation=args.aggr) 177 | 178 | if args.fc: #only retrain the fully-connected layer 179 | for param in model_conv.parameters(): 180 | param.requires_grad = False 181 | 182 | # Parameters of newly constructed modules have requires_grad=True by default 183 | num_ftrs = model_conv.fc.in_features 184 | model_conv.fc = nn.Linear(num_ftrs, len(class_names)) 185 | model_conv = torch.nn.DataParallel(model_conv) 186 | model_conv = model_conv.to(device) 187 | criterion = nn.CrossEntropyLoss() 188 | 189 | if args.fc: 190 | optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9) 191 | else: 192 | optimizer_conv = optim.SGD(model_conv.parameters(), lr=0.001, momentum=0.9) 193 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1) 194 | #print(optimizer_conv.state_dict()) 195 | #https://pytorch.org/tutorials/beginner/saving_loading_models.html 196 | if args.resume: 197 | print('restoring model from checkpoint...') 198 | checkpoint = torch.load(os.path.join(MODEL_DIR,args.model_name)) 199 | model_conv.load_state_dict(checkpoint['model_state_dict']) 200 | model_conv = model_conv.to(device) 201 | 
#https://discuss.pytorch.org/t/code-that-loads-sgd-fails-to-load-adam-state-to-gpu/61783/3 202 | optimizer_conv.load_state_dict(checkpoint['optimizer_state_dict']) 203 | exp_lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 204 | #print(checkpoint['optimizer_state_dict']) 205 | #print(checkpoint['scheduler_state_dict']) 206 | 207 | 208 | model_conv = train_model(model_conv, criterion, optimizer_conv, 209 | exp_lr_scheduler, num_epochs=args.epoch) 210 | 211 | -------------------------------------------------------------------------------- /nets/bagnet.py: -------------------------------------------------------------------------------- 1 | ################################################################################################################# 2 | # Adapted from https://github.com/wielandbrendel/bag-of-local-features-models/blob/master/bagnets/pytorchnet.py # 3 | # Mainly changed the model forward() function # 4 | ################################################################################################################# 5 | 6 | 7 | import torch.nn as nn 8 | import math 9 | import random 10 | import torch 11 | from collections import OrderedDict 12 | from torch.utils import model_zoo 13 | import numpy as np 14 | import os 15 | dir_path = os.path.dirname(os.path.realpath(__file__)) 16 | 17 | __all__ = ['bagnet9', 'bagnet17', 'bagnet33'] 18 | 19 | model_urls = { 20 | 'bagnet9': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet8-34f4ccd2.pth.tar', 21 | 'bagnet17': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet16-105524de.pth.tar', 22 | 'bagnet33': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet32-2ddd53ed.pth.tar', 23 | } 24 | 25 | 26 | class Bottleneck(nn.Module): 27 | expansion = 4 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None, kernel_size=1): 30 | super(Bottleneck, self).__init__() 31 | # #print('Creating bottleneck with kernel size {} and stride {} with padding {}'.format(kernel_size, stride, (kernel_size - 1) // 2)) 32 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, stride=stride, 35 | padding=0, bias=False) # changed padding from (kernel_size - 1) // 2 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 38 | self.bn3 = nn.BatchNorm2d(planes * 4) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | 44 | def forward(self, x, **kwargs): 45 | residual = x 46 | 47 | out = self.conv1(x) 48 | out = self.bn1(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.bn2(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv3(out) 56 | out = self.bn3(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | if residual.size(-1) != out.size(-1): 62 | diff = residual.size(-1) - out.size(-1) 63 | residual = residual[:,:,:-diff,:-diff] 64 | 65 | out += residual 66 | out = self.relu(out) 67 | 68 | return out 69 | 70 | 71 | class BagNet(nn.Module): 72 | 73 | def __init__(self, block, layers, strides=[1, 2, 2, 2], kernel3=[0, 0, 0, 0], num_classes=1000,clip_range=None,aggregation='mean'): 74 | self.inplanes = 64 75 | super(BagNet, 
self).__init__() 76 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, stride=1, padding=0, 77 | bias=False) 78 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0, 79 | bias=False) 80 | self.bn1 = nn.BatchNorm2d(64, momentum=0.001) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], kernel3=kernel3[0], prefix='layer1') 83 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], kernel3=kernel3[1], prefix='layer2') 84 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], kernel3=kernel3[2], prefix='layer3') 85 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], kernel3=kernel3[3], prefix='layer4') 86 | self.avgpool = nn.AvgPool2d(1, stride=1) 87 | self.fc = nn.Linear(512 * block.expansion, num_classes) 88 | self.block = block 89 | 90 | self.clip_range = clip_range 91 | self.aggregation = aggregation 92 | 93 | for m in self.modules(): 94 | if isinstance(m, nn.Conv2d): 95 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 96 | m.weight.data.normal_(0, math.sqrt(2. / n)) 97 | elif isinstance(m, nn.BatchNorm2d): 98 | m.weight.data.fill_(1) 99 | m.bias.data.zero_() 100 | 101 | def _make_layer(self, block, planes, blocks, stride=1, kernel3=0, prefix=''): 102 | downsample = None 103 | if stride != 1 or self.inplanes != planes * block.expansion: 104 | 105 | downsample = nn.Sequential( 106 | nn.Conv2d(self.inplanes, planes * block.expansion, 107 | kernel_size=1, stride=stride, bias=False), 108 | nn.BatchNorm2d(planes * block.expansion), 109 | ) 110 | 111 | layers = [] 112 | kernel = 1 if kernel3 == 0 else 3 113 | 114 | layers.append(block(self.inplanes, planes, stride, downsample, kernel_size=kernel)) 115 | self.inplanes = planes * block.expansion 116 | for i in range(1, blocks): 117 | kernel = 1 if kernel3 <= i else 3 118 | 119 | layers.append(block(self.inplanes, planes, kernel_size=kernel)) 120 | 121 | return nn.Sequential(*layers) 122 | 123 | def forward(self, x,y=None): 124 | x = self.conv1(x) 125 | x = self.conv2(x) 126 | x = self.bn1(x) 127 | x = self.relu(x) 128 | x = self.layer1(x) 129 | x = self.layer2(x) 130 | x = self.layer3(x) 131 | x = self.layer4(x) 132 | 133 | x = x.permute(0,2,3,1) 134 | 135 | x = self.fc(x) 136 | if self.clip_range is not None: 137 | x = torch.clamp(x,self.clip_range[0],self.clip_range[1]) 138 | if self.aggregation == 'mean': 139 | x = torch.mean(x,dim=(1,2)) 140 | elif self.aggregation == 'median': 141 | x = x.view([x.size()[0],-1,10]) 142 | x = torch.median(x,dim=1) 143 | return x.values 144 | elif self.aggregation =='cbn':#clipped BagNet 145 | x = torch.tanh(x*0.05-1) 146 | x = torch.mean(x,dim=(1,2)) 147 | elif self.aggregation == 'adv':# provable adversarial training 148 | window_size = 6 # the size of the window to be masked during training 149 | B,W,H,C = x.size() 150 | x = torch.clamp(x,0,torch.tensor(float('inf'))) #clip 151 | tmp = x[torch.arange(B),:,:,y] #the feature map for the true class 152 | tmp = tmp.unfold(1,window_size,1).unfold(2,window_size,1) #unfold 153 | tmp = tmp.reshape([B,-1,window_size,window_size]) # [B,num_window,window_size,window_size] 154 | tmp = torch.sum(tmp,axis=(-2,-1)) # [B,num_window] true class evidence in every window 155 | tmp = torch.max(tmp,axis=-1).values # [B] max window class evidence 156 | x = torch.sum(x,dim=(1,2)) # [B,C] total class evidence over all feature locations 157 | x[torch.arange(B),y]-=tmp # subtract the max true window class evidence 158 | x/=(W*H) 159 | elif self.aggregation == 'none': 160 | pass 161 |
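        # (editorial shape walk-through of the 'adv' branch above; numbers are
        # illustrative) for a 224x224 input, BagNet-17 yields local logits x of
        # shape [B,26,26,1000]; with window_size=6, unfold produces [B,21,21,6,6]
        # candidate windows, reshaped to [B,441,6,6]; summing each window and
        # taking the max gives the largest true-class evidence any single patch
        # location could contribute, which is subtracted before averaging.
162 | 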
return x 163 | 164 | def bagnet33(pretrained=False, strides=[2, 2, 2, 1], **kwargs): 165 | """Constructs a Bagnet-33 model. 166 | 167 | Args: 168 | pretrained (bool): If True, returns a model pre-trained on ImageNet 169 | """ 170 | model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,1,1], **kwargs) 171 | if pretrained: 172 | model.load_state_dict(model_zoo.load_url(model_urls['bagnet33'])) 173 | return model 174 | 175 | def bagnet17(pretrained=False, strides=[2, 2, 2, 1], **kwargs): 176 | """Constructs a Bagnet-17 model. 177 | 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | """ 181 | model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,1,0], **kwargs) 182 | if pretrained: 183 | model.load_state_dict(model_zoo.load_url(model_urls['bagnet17'])) 184 | return model 185 | 186 | def bagnet9(pretrained=False, strides=[2, 2, 2, 1], **kwargs): 187 | """Constructs a Bagnet-9 model. 188 | 189 | Args: 190 | pretrained (bool): If True, returns a model pre-trained on ImageNet 191 | """ 192 | model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,0,0], **kwargs) 193 | #model = BagNet(Bottleneck, [2,2,2,2], strides=strides, kernel3=[1,1,0,0], **kwargs) 194 | if pretrained: 195 | model.load_state_dict(model_zoo.load_url(model_urls['bagnet9'])) 196 | return model 197 | -------------------------------------------------------------------------------- /nets/dsresnet_imgnt.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # from https://github.com/alevine0/patchSmoothing/blob/master/resnet_imgnt.py 3 | ############################################################################### 4 | import torch 5 | import torch.nn as nn 6 | #from .utils import load_state_dict_from_url 7 | 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 11 | 'wide_resnet50_2', 'wide_resnet101_2'] 12 | 13 | 14 | model_urls = { 15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 20 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 21 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 22 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 23 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 24 | } 25 | 26 | 27 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=dilation, groups=groups, bias=False, dilation=dilation) 31 | 32 | 33 | def conv1x1(in_planes, out_planes, stride=1): 34 | """1x1 convolution""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion = 1 40 | __constants__ = ['downsample'] 41 | 42 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 43 | base_width=64, dilation=1, 
norm_layer=None): 44 | super(BasicBlock, self).__init__() 45 | if norm_layer is None: 46 | norm_layer = nn.BatchNorm2d 47 | if groups != 1 or base_width != 64: 48 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 49 | if dilation > 1: 50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 52 | self.conv1 = conv3x3(inplanes, planes, stride) 53 | self.bn1 = norm_layer(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(planes, planes) 56 | self.bn2 = norm_layer(planes) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x): 61 | identity = x 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.relu(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | 70 | if self.downsample is not None: 71 | identity = self.downsample(x) 72 | 73 | out += identity 74 | out = self.relu(out) 75 | 76 | return out 77 | 78 | 79 | class Bottleneck(nn.Module): 80 | expansion = 4 81 | __constants__ = ['downsample'] 82 | 83 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 84 | base_width=64, dilation=1, norm_layer=None): 85 | super(Bottleneck, self).__init__() 86 | if norm_layer is None: 87 | norm_layer = nn.BatchNorm2d 88 | width = int(planes * (base_width / 64.)) * groups 89 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 90 | self.conv1 = conv1x1(inplanes, width) 91 | self.bn1 = norm_layer(width) 92 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 93 | self.bn2 = norm_layer(width) 94 | self.conv3 = conv1x1(width, planes * self.expansion) 95 | self.bn3 = norm_layer(planes * self.expansion) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.downsample = downsample 98 | self.stride = stride 99 | 100 | def forward(self, x): 101 | identity = x 102 | 103 | out = self.conv1(x) 104 | out = self.bn1(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv2(out) 108 | out = self.bn2(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv3(out) 112 | out = self.bn3(out) 113 | 114 | if self.downsample is not None: 115 | identity = self.downsample(x) 116 | 117 | out += identity 118 | out = self.relu(out) 119 | 120 | return out 121 | 122 | 123 | class ResNet(nn.Module): 124 | 125 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 126 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 127 | norm_layer=None): 128 | super(ResNet, self).__init__() 129 | if norm_layer is None: 130 | norm_layer = nn.BatchNorm2d 131 | self._norm_layer = norm_layer 132 | 133 | self.inplanes = 64 134 | self.dilation = 1 135 | if replace_stride_with_dilation is None: 136 | # each element in the tuple indicates if we should replace 137 | # the 2x2 stride with a dilated convolution instead 138 | replace_stride_with_dilation = [False, False, False] 139 | if len(replace_stride_with_dilation) != 3: 140 | raise ValueError("replace_stride_with_dilation should be None " 141 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 142 | self.groups = groups 143 | self.base_width = width_per_group 144 | self.conv1 = nn.Conv2d(6, self.inplanes, kernel_size=7, stride=2, padding=3, 145 | bias=False) 146 | self.bn1 = norm_layer(self.inplanes) 147 | self.relu = nn.ReLU(inplace=True) 148 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 149 | self.layer1 = self._make_layer(block, 64, layers[0]) 150 
| self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 151 | dilate=replace_stride_with_dilation[0]) 152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 153 | dilate=replace_stride_with_dilation[1]) 154 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 155 | dilate=replace_stride_with_dilation[2]) 156 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 157 | self.fc = nn.Linear(512 * block.expansion, num_classes) 158 | 159 | for m in self.modules(): 160 | if isinstance(m, nn.Conv2d): 161 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 162 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 163 | nn.init.constant_(m.weight, 1) 164 | nn.init.constant_(m.bias, 0) 165 | 166 | # Zero-initialize the last BN in each residual branch, 167 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 168 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 169 | if zero_init_residual: 170 | for m in self.modules(): 171 | if isinstance(m, Bottleneck): 172 | nn.init.constant_(m.bn3.weight, 0) 173 | elif isinstance(m, BasicBlock): 174 | nn.init.constant_(m.bn2.weight, 0) 175 | 176 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 177 | norm_layer = self._norm_layer 178 | downsample = None 179 | previous_dilation = self.dilation 180 | if dilate: 181 | self.dilation *= stride 182 | stride = 1 183 | if stride != 1 or self.inplanes != planes * block.expansion: 184 | downsample = nn.Sequential( 185 | conv1x1(self.inplanes, planes * block.expansion, stride), 186 | norm_layer(planes * block.expansion), 187 | ) 188 | 189 | layers = [] 190 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 191 | self.base_width, previous_dilation, norm_layer)) 192 | self.inplanes = planes * block.expansion 193 | for _ in range(1, blocks): 194 | layers.append(block(self.inplanes, planes, groups=self.groups, 195 | base_width=self.base_width, dilation=self.dilation, 196 | norm_layer=norm_layer)) 197 | 198 | return nn.Sequential(*layers) 199 | 200 | def forward(self, x): 201 | x = self.conv1(x) 202 | x = self.bn1(x) 203 | x = self.relu(x) 204 | x = self.maxpool(x) 205 | 206 | x = self.layer1(x) 207 | x = self.layer2(x) 208 | x = self.layer3(x) 209 | x = self.layer4(x) 210 | 211 | x = self.avgpool(x) 212 | x = torch.flatten(x, 1) 213 | x = self.fc(x) 214 | 215 | return x 216 | 217 | 218 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 219 | model = ResNet(block, layers, **kwargs) 220 | # if pretrained: 221 | # state_dict = load_state_dict_from_url(model_urls[arch], 222 | # progress=progress) 223 | # model.load_state_dict(state_dict) 224 | return model 225 | 226 | 227 | def resnet18(pretrained=False, progress=True, **kwargs): 228 | r"""ResNet-18 model from 229 | `"Deep Residual Learning for Image Recognition" `_ 230 | 231 | Args: 232 | pretrained (bool): If True, returns a model pre-trained on ImageNet 233 | progress (bool): If True, displays a progress bar of the download to stderr 234 | """ 235 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 236 | **kwargs) 237 | 238 | 239 | def resnet34(pretrained=False, progress=True, **kwargs): 240 | r"""ResNet-34 model from 241 | `"Deep Residual Learning for Image Recognition" `_ 242 | 243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | progress (bool): If True, displays a progress bar of the download to stderr 246 | """ 247 
| return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 248 | **kwargs) 249 | 250 | 251 | def resnet50(pretrained=False, progress=True, **kwargs): 252 | r"""ResNet-50 model from 253 | `"Deep Residual Learning for Image Recognition" `_ 254 | 255 | Args: 256 | pretrained (bool): If True, returns a model pre-trained on ImageNet 257 | progress (bool): If True, displays a progress bar of the download to stderr 258 | """ 259 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 260 | **kwargs) 261 | 262 | 263 | def resnet101(pretrained=False, progress=True, **kwargs): 264 | r"""ResNet-101 model from 265 | `"Deep Residual Learning for Image Recognition" `_ 266 | 267 | Args: 268 | pretrained (bool): If True, returns a model pre-trained on ImageNet 269 | progress (bool): If True, displays a progress bar of the download to stderr 270 | """ 271 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 272 | **kwargs) 273 | 274 | 275 | def resnet152(pretrained=False, progress=True, **kwargs): 276 | r"""ResNet-152 model from 277 | `"Deep Residual Learning for Image Recognition" `_ 278 | 279 | Args: 280 | pretrained (bool): If True, returns a model pre-trained on ImageNet 281 | progress (bool): If True, displays a progress bar of the download to stderr 282 | """ 283 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 284 | **kwargs) 285 | 286 | 287 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 288 | r"""ResNeXt-50 32x4d model from 289 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 290 | 291 | Args: 292 | pretrained (bool): If True, returns a model pre-trained on ImageNet 293 | progress (bool): If True, displays a progress bar of the download to stderr 294 | """ 295 | kwargs['groups'] = 32 296 | kwargs['width_per_group'] = 4 297 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 298 | pretrained, progress, **kwargs) 299 | 300 | 301 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 302 | r"""ResNeXt-101 32x8d model from 303 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 304 | 305 | Args: 306 | pretrained (bool): If True, returns a model pre-trained on ImageNet 307 | progress (bool): If True, displays a progress bar of the download to stderr 308 | """ 309 | kwargs['groups'] = 32 310 | kwargs['width_per_group'] = 8 311 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 312 | pretrained, progress, **kwargs) 313 | 314 | 315 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 316 | r"""Wide ResNet-50-2 model from 317 | `"Wide Residual Networks" `_ 318 | 319 | The model is the same as ResNet except for the bottleneck number of channels 320 | which is twice larger in every block. The number of channels in outer 1x1 321 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 322 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
323 | 324 | Args: 325 | pretrained (bool): If True, returns a model pre-trained on ImageNet 326 | progress (bool): If True, displays a progress bar of the download to stderr 327 | """ 328 | kwargs['width_per_group'] = 64 * 2 329 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 330 | pretrained, progress, **kwargs) 331 | 332 | 333 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 334 | r"""Wide ResNet-101-2 model from 335 | `"Wide Residual Networks" `_ 336 | 337 | The model is the same as ResNet except for the bottleneck number of channels 338 | which is twice larger in every block. The number of channels in outer 1x1 339 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 340 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 341 | 342 | Args: 343 | pretrained (bool): If True, returns a model pre-trained on ImageNet 344 | progress (bool): If True, displays a progress bar of the download to stderr 345 | """ 346 | kwargs['width_per_group'] = 64 * 2 347 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 348 | pretrained, progress, **kwargs) 349 | -------------------------------------------------------------------------------- /nets/resnet.py: -------------------------------------------------------------------------------- 1 | ########################################################################################### 2 | # Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py # 3 | # Mainly changed the model forward() function # 4 | ########################################################################################### 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | try: 11 | from torch.hub import load_state_dict_from_url 12 | except ImportError: 13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 14 | 15 | 16 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 17 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 18 | 'wide_resnet50_2', 'wide_resnet101_2'] 19 | 20 | 21 | model_urls = { 22 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 23 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 24 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 25 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 26 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 27 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 28 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 29 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 30 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 31 | } 32 | 33 | 34 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 35 | """3x3 convolution with padding""" 36 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 37 | padding=dilation, groups=groups, bias=False, dilation=dilation) 38 | 39 | 40 | def conv1x1(in_planes, out_planes, stride=1): 41 | """1x1 convolution""" 42 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 43 | 44 | 45 | class BasicBlock(nn.Module): 46 | expansion = 1 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 49 | base_width=64, dilation=1, norm_layer=None): 50 | super(BasicBlock, self).__init__() 51 | if norm_layer 
is None: 52 | norm_layer = nn.BatchNorm2d 53 | if groups != 1 or base_width != 64: 54 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 55 | if dilation > 1: 56 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 57 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 58 | self.conv1 = conv3x3(inplanes, planes, stride) 59 | self.bn1 = norm_layer(planes) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.conv2 = conv3x3(planes, planes) 62 | self.bn2 = norm_layer(planes) 63 | self.downsample = downsample 64 | self.stride = stride 65 | 66 | def forward(self, x): 67 | identity = x 68 | 69 | out = self.conv1(x) 70 | out = self.bn1(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv2(out) 74 | out = self.bn2(out) 75 | 76 | if self.downsample is not None: 77 | identity = self.downsample(x) 78 | 79 | out += identity 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | 85 | class Bottleneck(nn.Module): 86 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 87 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 88 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 89 | # This variant is also known as ResNet V1.5 and improves accuracy according to 90 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 91 | 92 | expansion = 4 93 | 94 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 95 | base_width=64, dilation=1, norm_layer=None): 96 | super(Bottleneck, self).__init__() 97 | if norm_layer is None: 98 | norm_layer = nn.BatchNorm2d 99 | width = int(planes * (base_width / 64.)) * groups 100 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 101 | self.conv1 = conv1x1(inplanes, width) 102 | self.bn1 = norm_layer(width) 103 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 104 | self.bn2 = norm_layer(width) 105 | self.conv3 = conv1x1(width, planes * self.expansion) 106 | self.bn3 = norm_layer(planes * self.expansion) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.downsample = downsample 109 | self.stride = stride 110 | 111 | def forward(self, x): 112 | identity = x 113 | 114 | out = self.conv1(x) 115 | out = self.bn1(out) 116 | out = self.relu(out) 117 | 118 | out = self.conv2(out) 119 | out = self.bn2(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv3(out) 123 | out = self.bn3(out) 124 | 125 | if self.downsample is not None: 126 | identity = self.downsample(x) 127 | 128 | out += identity 129 | out = self.relu(out) 130 | 131 | return out 132 | 133 | 134 | class ResNet(nn.Module): 135 | 136 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 137 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 138 | norm_layer=None, clip_range=None, aggregation = 'mean'): 139 | super(ResNet, self).__init__() 140 | self.clip_range = clip_range 141 | self.aggregation = aggregation 142 | 143 | if norm_layer is None: 144 | norm_layer = nn.BatchNorm2d 145 | self._norm_layer = norm_layer 146 | 147 | self.inplanes = 64 148 | self.dilation = 1 149 | if replace_stride_with_dilation is None: 150 | # each element in the tuple indicates if we should replace 151 | # the 2x2 stride with a dilated convolution instead 152 | replace_stride_with_dilation = [False, False, False] 153 | if len(replace_stride_with_dilation) != 3: 154 | raise 
ValueError("replace_stride_with_dilation should be None " 155 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 156 | self.groups = groups 157 | self.base_width = width_per_group 158 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 159 | bias=False) 160 | self.bn1 = norm_layer(self.inplanes) 161 | self.relu = nn.ReLU(inplace=True) 162 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 163 | self.layer1 = self._make_layer(block, 64, layers[0]) 164 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 165 | dilate=replace_stride_with_dilation[0]) 166 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 167 | dilate=replace_stride_with_dilation[1]) 168 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 169 | dilate=replace_stride_with_dilation[2]) 170 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 171 | self.fc = nn.Linear(512 * block.expansion, num_classes) 172 | 173 | for m in self.modules(): 174 | if isinstance(m, nn.Conv2d): 175 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 176 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 177 | nn.init.constant_(m.weight, 1) 178 | nn.init.constant_(m.bias, 0) 179 | 180 | # Zero-initialize the last BN in each residual branch, 181 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 182 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 183 | if zero_init_residual: 184 | for m in self.modules(): 185 | if isinstance(m, Bottleneck): 186 | nn.init.constant_(m.bn3.weight, 0) 187 | elif isinstance(m, BasicBlock): 188 | nn.init.constant_(m.bn2.weight, 0) 189 | 190 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 191 | norm_layer = self._norm_layer 192 | downsample = None 193 | previous_dilation = self.dilation 194 | if dilate: 195 | self.dilation *= stride 196 | stride = 1 197 | if stride != 1 or self.inplanes != planes * block.expansion: 198 | downsample = nn.Sequential( 199 | conv1x1(self.inplanes, planes * block.expansion, stride), 200 | norm_layer(planes * block.expansion), 201 | ) 202 | 203 | layers = [] 204 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 205 | self.base_width, previous_dilation, norm_layer)) 206 | self.inplanes = planes * block.expansion 207 | for _ in range(1, blocks): 208 | layers.append(block(self.inplanes, planes, groups=self.groups, 209 | base_width=self.base_width, dilation=self.dilation, 210 | norm_layer=norm_layer)) 211 | 212 | return nn.Sequential(*layers) 213 | 214 | def _forward_impl(self, x): 215 | # See note [TorchScript super()] 216 | x = self.conv1(x) 217 | x = self.bn1(x) 218 | x = self.relu(x) 219 | x = self.maxpool(x) 220 | 221 | x = self.layer1(x) 222 | x = self.layer2(x) 223 | x = self.layer3(x) 224 | x = self.layer4(x) 225 | 226 | x = x.permute(0,2,3,1) 227 | x = self.fc(x) 228 | if self.clip_range is not None: 229 | x = torch.clamp(x,self.clip_range[0],self.clip_range[1]) 230 | if self.aggregation == 'mean': 231 | x = torch.mean(x,dim=(1,2)) 232 | elif self.aggregation == 'median': 233 | x = x.view([x.size()[0],-1,10]) 234 | x = torch.median(x,dim=1) 235 | return x.values 236 | elif self.aggregation =='cbn': # clipping function from Clipped BagNet 237 | x = torch.tanh(x*0.05-1) 238 | x = torch.mean(x,dim=(1,2)) 239 | elif self.aggregation == 'none': 240 | pass 241 | return x 242 | 243 | def forward(self, x): 244 | return self._forward_impl(x) 245 
| 246 | 247 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 248 | model = ResNet(block, layers, **kwargs) 249 | if pretrained: 250 | state_dict = load_state_dict_from_url(model_urls[arch], 251 | progress=progress) 252 | model.load_state_dict(state_dict) 253 | return model 254 | 255 | 256 | def resnet18(pretrained=False, progress=True, **kwargs): 257 | r"""ResNet-18 model from 258 | `"Deep Residual Learning for Image Recognition" `_ 259 | Args: 260 | pretrained (bool): If True, returns a model pre-trained on ImageNet 261 | progress (bool): If True, displays a progress bar of the download to stderr 262 | """ 263 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 264 | **kwargs) 265 | 266 | 267 | def resnet34(pretrained=False, progress=True, **kwargs): 268 | r"""ResNet-34 model from 269 | `"Deep Residual Learning for Image Recognition" `_ 270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | progress (bool): If True, displays a progress bar of the download to stderr 273 | """ 274 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 275 | **kwargs) 276 | 277 | 278 | def resnet50(pretrained=False, progress=True, **kwargs): 279 | r"""ResNet-50 model from 280 | `"Deep Residual Learning for Image Recognition" `_ 281 | Args: 282 | pretrained (bool): If True, returns a model pre-trained on ImageNet 283 | progress (bool): If True, displays a progress bar of the download to stderr 284 | """ 285 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 286 | **kwargs) 287 | 288 | 289 | def resnet101(pretrained=False, progress=True, **kwargs): 290 | r"""ResNet-101 model from 291 | `"Deep Residual Learning for Image Recognition" `_ 292 | Args: 293 | pretrained (bool): If True, returns a model pre-trained on ImageNet 294 | progress (bool): If True, displays a progress bar of the download to stderr 295 | """ 296 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 297 | **kwargs) 298 | 299 | 300 | def resnet152(pretrained=False, progress=True, **kwargs): 301 | r"""ResNet-152 model from 302 | `"Deep Residual Learning for Image Recognition" `_ 303 | Args: 304 | pretrained (bool): If True, returns a model pre-trained on ImageNet 305 | progress (bool): If True, displays a progress bar of the download to stderr 306 | """ 307 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 308 | **kwargs) 309 | 310 | 311 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 312 | r"""ResNeXt-50 32x4d model from 313 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 314 | Args: 315 | pretrained (bool): If True, returns a model pre-trained on ImageNet 316 | progress (bool): If True, displays a progress bar of the download to stderr 317 | """ 318 | kwargs['groups'] = 32 319 | kwargs['width_per_group'] = 4 320 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 321 | pretrained, progress, **kwargs) 322 | 323 | 324 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 325 | r"""ResNeXt-101 32x8d model from 326 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 327 | Args: 328 | pretrained (bool): If True, returns a model pre-trained on ImageNet 329 | progress (bool): If True, displays a progress bar of the download to stderr 330 | """ 331 | kwargs['groups'] = 32 332 | kwargs['width_per_group'] = 8 333 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 334 | pretrained, 
progress, **kwargs) 335 | 336 | 337 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 338 | r"""Wide ResNet-50-2 model from 339 | `"Wide Residual Networks" `_ 340 | The model is the same as ResNet except for the bottleneck number of channels 341 | which is twice larger in every block. The number of channels in outer 1x1 342 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 343 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 344 | Args: 345 | pretrained (bool): If True, returns a model pre-trained on ImageNet 346 | progress (bool): If True, displays a progress bar of the download to stderr 347 | """ 348 | kwargs['width_per_group'] = 64 * 2 349 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 350 | pretrained, progress, **kwargs) 351 | 352 | 353 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 354 | r"""Wide ResNet-101-2 model from 355 | `"Wide Residual Networks" `_ 356 | The model is the same as ResNet except for the bottleneck number of channels 357 | which is twice larger in every block. The number of channels in outer 1x1 358 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 359 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 360 | Args: 361 | pretrained (bool): If True, returns a model pre-trained on ImageNet 362 | progress (bool): If True, displays a progress bar of the download to stderr 363 | """ 364 | kwargs['width_per_group'] = 64 * 2 365 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 366 | pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /misc/train_imagenet.py: -------------------------------------------------------------------------------- 1 | ########################################################################################## 2 | # adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py 3 | # three changes: Line 33 Line 44 Line 364 4 | ########################################################################################## 5 | 6 | import argparse 7 | import os 8 | import random 9 | import shutil 10 | import time 11 | import warnings 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.parallel 16 | import torch.backends.cudnn as cudnn 17 | import torch.distributed as dist 18 | import torch.optim 19 | import torch.multiprocessing as mp 20 | import torch.utils.data 21 | import torch.utils.data.distributed 22 | import torchvision.transforms as transforms 23 | import torchvision.datasets as datasets 24 | import torchvision.models as models 25 | import nets.bagnet 26 | model_names = sorted(name for name in models.__dict__ 27 | if name.islower() and not name.startswith("__") 28 | and callable(models.__dict__[name])) 29 | 30 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 31 | parser.add_argument('data', metavar='DIR', 32 | help='path to dataset') 33 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18')#, 34 | #choices=model_names, 35 | #help='model architecture: ' + 36 | # ' | '.join(model_names) + 37 | # ' (default: resnet18)') 38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 39 | help='number of data loading workers (default: 4)') 40 | parser.add_argument('--epochs', default=30, type=int, metavar='N', 41 | help='number of total epochs to run') 42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 43 | help='manual epoch number (useful on 
restarts)') 44 | parser.add_argument('-b', '--batch-size', default=128, type=int, 45 | metavar='N', 46 | help='mini-batch size (default: 256), this is the total ' 47 | 'batch size of all GPUs on the current node when ' 48 | 'using Data Parallel or Distributed Data Parallel') 49 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 50 | metavar='LR', help='initial learning rate', dest='lr') 51 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 52 | help='momentum') 53 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 54 | metavar='W', help='weight decay (default: 1e-4)', 55 | dest='weight_decay') 56 | parser.add_argument('-p', '--print-freq', default=10, type=int, 57 | metavar='N', help='print frequency (default: 10)') 58 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 59 | help='path to latest checkpoint (default: none)') 60 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 61 | help='evaluate model on validation set') 62 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 63 | help='use pre-trained model') 64 | parser.add_argument('--world-size', default=-1, type=int, 65 | help='number of nodes for distributed training') 66 | parser.add_argument('--rank', default=-1, type=int, 67 | help='node rank for distributed training') 68 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 69 | help='url used to set up distributed training') 70 | parser.add_argument('--dist-backend', default='nccl', type=str, 71 | help='distributed backend') 72 | parser.add_argument('--seed', default=None, type=int, 73 | help='seed for initializing training. ') 74 | parser.add_argument('--gpu', default=None, type=int, 75 | help='GPU id to use.') 76 | parser.add_argument('--multiprocessing-distributed', action='store_true', 77 | help='Use multi-processing distributed training to launch ' 78 | 'N processes per node, which has N GPUs. This is the ' 79 | 'fastest way to use PyTorch for either single node or ' 80 | 'multi node data parallel training') 81 | 82 | best_acc1 = 0 83 | 84 | 85 | def main(): 86 | args = parser.parse_args() 87 | 88 | if args.seed is not None: 89 | random.seed(args.seed) 90 | torch.manual_seed(args.seed) 91 | cudnn.deterministic = True 92 | warnings.warn('You have chosen to seed training. ' 93 | 'This will turn on the CUDNN deterministic setting, ' 94 | 'which can slow down your training considerably! ' 95 | 'You may see unexpected behavior when restarting ' 96 | 'from checkpoints.') 97 | 98 | if args.gpu is not None: 99 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 100 | 'disable data parallelism.') 101 | 102 | if args.dist_url == "env://" and args.world_size == -1: 103 | args.world_size = int(os.environ["WORLD_SIZE"]) 104 | 105 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 106 | 107 | ngpus_per_node = torch.cuda.device_count() 108 | if args.multiprocessing_distributed: 109 | # Since we have ngpus_per_node processes per node, the total world_size 110 | # needs to be adjusted accordingly 111 | args.world_size = ngpus_per_node * args.world_size 112 | # Use torch.multiprocessing.spawn to launch distributed processes: the 113 | # main_worker process function 114 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 115 | else: 116 | # Simply call main_worker function 117 | main_worker(args.gpu, ngpus_per_node, args) 118 | 119 | 120 | def main_worker(gpu, ngpus_per_node, args): 121 | global best_acc1 122 | args.gpu = gpu 123 | 124 | if args.gpu is not None: 125 | print("Use GPU: {} for training".format(args.gpu)) 126 | 127 | if args.distributed: 128 | if args.dist_url == "env://" and args.rank == -1: 129 | args.rank = int(os.environ["RANK"]) 130 | if args.multiprocessing_distributed: 131 | # For multiprocessing distributed training, rank needs to be the 132 | # global rank among all the processes 133 | args.rank = args.rank * ngpus_per_node + gpu 134 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 135 | world_size=args.world_size, rank=args.rank) 136 | # create model 137 | #if args.pretrained: 138 | # print("=> using pre-trained model '{}'".format(args.arch)) 139 | # model = models.__dict__[args.arch](pretrained=True) 140 | #else: 141 | # print("=> creating model '{}'".format(args.arch)) 142 | # model = models.__dict__[args.arch]() 143 | 144 | model = nets.bagnet.bagnet17(pretrained=True,aggregation='adv') 145 | 146 | if not torch.cuda.is_available(): 147 | print('using CPU, this will be slow') 148 | elif args.distributed: 149 | # For multiprocessing distributed, DistributedDataParallel constructor 150 | # should always set the single device scope, otherwise, 151 | # DistributedDataParallel will use all available devices. 
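        # (editorial example) with this repo's default --batch-size 128 on a
        # 4-GPU node, the division in the branch below gives each DDP process a
        # per-process batch of 128 / 4 = 32 and ceil(4 / 4) = 1 data-loading worker.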
152 | if args.gpu is not None: 153 | torch.cuda.set_device(args.gpu) 154 | model.cuda(args.gpu) 155 | # When using a single GPU per process and per 156 | # DistributedDataParallel, we need to divide the batch size 157 | # ourselves based on the total number of GPUs we have 158 | args.batch_size = int(args.batch_size / ngpus_per_node) 159 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 160 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 161 | else: 162 | model.cuda() 163 | # DistributedDataParallel will divide and allocate batch_size to all 164 | # available GPUs if device_ids are not set 165 | model = torch.nn.parallel.DistributedDataParallel(model) 166 | elif args.gpu is not None: 167 | torch.cuda.set_device(args.gpu) 168 | model = model.cuda(args.gpu) 169 | else: 170 | # DataParallel will divide and allocate batch_size to all available GPUs 171 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 172 | model.features = torch.nn.DataParallel(model.features) 173 | model.cuda() 174 | else: 175 | model = torch.nn.DataParallel(model).cuda() 176 | 177 | # define loss function (criterion) and optimizer 178 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 179 | 180 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 181 | momentum=args.momentum, 182 | weight_decay=args.weight_decay) 183 | 184 | # optionally resume from a checkpoint 185 | if args.resume: 186 | if os.path.isfile(args.resume): 187 | print("=> loading checkpoint '{}'".format(args.resume)) 188 | if args.gpu is None: 189 | checkpoint = torch.load(args.resume) 190 | else: 191 | # Map model to be loaded to specified single gpu. 192 | loc = 'cuda:{}'.format(args.gpu) 193 | checkpoint = torch.load(args.resume, map_location=loc) 194 | args.start_epoch = checkpoint['epoch'] 195 | best_acc1 = checkpoint['best_acc1'] 196 | if args.gpu is not None: 197 | # best_acc1 may be from a checkpoint from a different GPU 198 | best_acc1 = best_acc1.to(args.gpu) 199 | model.load_state_dict(checkpoint['state_dict']) 200 | optimizer.load_state_dict(checkpoint['optimizer']) 201 | print("=> loaded checkpoint '{}' (epoch {})" 202 | .format(args.resume, checkpoint['epoch'])) 203 | else: 204 | print("=> no checkpoint found at '{}'".format(args.resume)) 205 | 206 | cudnn.benchmark = True 207 | 208 | # Data loading code 209 | traindir = os.path.join(args.data, 'train') 210 | valdir = os.path.join(args.data, 'val') 211 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 212 | std=[0.229, 0.224, 0.225]) 213 | 214 | train_dataset = datasets.ImageFolder( 215 | traindir, 216 | transforms.Compose([ 217 | transforms.RandomResizedCrop(224), 218 | transforms.RandomHorizontalFlip(), 219 | transforms.ToTensor(), 220 | normalize, 221 | ])) 222 | 223 | if args.distributed: 224 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 225 | else: 226 | train_sampler = None 227 | 228 | train_loader = torch.utils.data.DataLoader( 229 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 230 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 231 | 232 | val_loader = torch.utils.data.DataLoader( 233 | datasets.ImageFolder(valdir, transforms.Compose([ 234 | transforms.Resize(256), 235 | transforms.CenterCrop(224), 236 | transforms.ToTensor(), 237 | normalize, 238 | ])), 239 | batch_size=args.batch_size, shuffle=False, 240 | num_workers=args.workers, pin_memory=True) 241 | 242 | if args.evaluate: 243 | validate(val_loader, 
model, criterion, args) 244 | return 245 | 246 | for epoch in range(args.start_epoch, args.epochs): 247 | if args.distributed: 248 | train_sampler.set_epoch(epoch) 249 | adjust_learning_rate(optimizer, epoch, args) 250 | 251 | # train for one epoch 252 | train(train_loader, model, criterion, optimizer, epoch, args) 253 | 254 | # evaluate on validation set 255 | acc1 = validate(val_loader, model, criterion, args) 256 | 257 | # remember best acc@1 and save checkpoint 258 | is_best = acc1 > best_acc1 259 | best_acc1 = max(acc1, best_acc1) 260 | 261 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 262 | and args.rank % ngpus_per_node == 0): 263 | save_checkpoint({ 264 | 'epoch': epoch + 1, 265 | 'arch': args.arch, 266 | 'state_dict': model.state_dict(), 267 | 'best_acc1': best_acc1, 268 | 'optimizer' : optimizer.state_dict(), 269 | }, is_best) 270 | 271 | 272 | def train(train_loader, model, criterion, optimizer, epoch, args): 273 | batch_time = AverageMeter('Time', ':6.3f') 274 | data_time = AverageMeter('Data', ':6.3f') 275 | losses = AverageMeter('Loss', ':.4e') 276 | top1 = AverageMeter('Acc@1', ':6.2f') 277 | top5 = AverageMeter('Acc@5', ':6.2f') 278 | progress = ProgressMeter( 279 | len(train_loader), 280 | [batch_time, data_time, losses, top1, top5], 281 | prefix="Epoch: [{}]".format(epoch)) 282 | 283 | # switch to train mode 284 | model.train() 285 | 286 | end = time.time() 287 | for i, (images, target) in enumerate(train_loader): 288 | # measure data loading time 289 | data_time.update(time.time() - end) 290 | 291 | if args.gpu is not None: 292 | images = images.cuda(args.gpu, non_blocking=True) 293 | if torch.cuda.is_available(): 294 | target = target.cuda(args.gpu, non_blocking=True) 295 | 296 | # compute output 297 | output = model(images,target) 298 | loss = criterion(output, target) 299 | 300 | # measure accuracy and record loss 301 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 302 | losses.update(loss.item(), images.size(0)) 303 | top1.update(acc1[0], images.size(0)) 304 | top5.update(acc5[0], images.size(0)) 305 | 306 | # compute gradient and do SGD step 307 | optimizer.zero_grad() 308 | loss.backward() 309 | optimizer.step() 310 | 311 | # measure elapsed time 312 | batch_time.update(time.time() - end) 313 | end = time.time() 314 | 315 | if i % args.print_freq == 0: 316 | progress.display(i) 317 | 318 | 319 | def validate(val_loader, model, criterion, args): 320 | batch_time = AverageMeter('Time', ':6.3f') 321 | losses = AverageMeter('Loss', ':.4e') 322 | top1 = AverageMeter('Acc@1', ':6.2f') 323 | top5 = AverageMeter('Acc@5', ':6.2f') 324 | progress = ProgressMeter( 325 | len(val_loader), 326 | [batch_time, losses, top1, top5], 327 | prefix='Test: ') 328 | 329 | # switch to evaluate mode 330 | model.eval() 331 | 332 | with torch.no_grad(): 333 | end = time.time() 334 | for i, (images, target) in enumerate(val_loader): 335 | if args.gpu is not None: 336 | images = images.cuda(args.gpu, non_blocking=True) 337 | if torch.cuda.is_available(): 338 | target = target.cuda(args.gpu, non_blocking=True) 339 | 340 | # compute output 341 | output = model(images,target) 342 | loss = criterion(output, target) 343 | 344 | # measure accuracy and record loss 345 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 346 | losses.update(loss.item(), images.size(0)) 347 | top1.update(acc1[0], images.size(0)) 348 | top5.update(acc5[0], images.size(0)) 349 | 350 | # measure elapsed time 351 | batch_time.update(time.time() - end) 352 | end = time.time() 353 | 354 
| if i % args.print_freq == 0:
355 |                 progress.display(i)
356 | 
357 |         # TODO: this should also be done with the ProgressMeter
358 |         print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
359 |               .format(top1=top1, top5=top5))
360 | 
361 |     return top1.avg
362 | 
363 | 
364 | def save_checkpoint(state, is_best, filename='bagnet17_adv_lr0.001.pth.tar'):
365 |     torch.save(state, filename)
366 |     if is_best:
367 |         shutil.copyfile(filename, 'bagnet17_adv_lr0.001_best.pth.tar')
368 | 
369 | 
370 | class AverageMeter(object):
371 |     """Computes and stores the average and current value"""
372 |     def __init__(self, name, fmt=':f'):
373 |         self.name = name
374 |         self.fmt = fmt
375 |         self.reset()
376 | 
377 |     def reset(self):
378 |         self.val = 0
379 |         self.avg = 0
380 |         self.sum = 0
381 |         self.count = 0
382 | 
383 |     def update(self, val, n=1):
384 |         self.val = val
385 |         self.sum += val * n
386 |         self.count += n
387 |         self.avg = self.sum / self.count
388 | 
389 |     def __str__(self):
390 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
391 |         return fmtstr.format(**self.__dict__)
392 | 
393 | 
394 | class ProgressMeter(object):
395 |     def __init__(self, num_batches, meters, prefix=""):
396 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
397 |         self.meters = meters
398 |         self.prefix = prefix
399 | 
400 |     def display(self, batch):
401 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
402 |         entries += [str(meter) for meter in self.meters]
403 |         print('\t'.join(entries))
404 | 
405 |     def _get_batch_fmtstr(self, num_batches):
406 |         num_digits = len(str(num_batches))
407 |         fmt = '{:' + str(num_digits) + 'd}'
408 |         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
409 | 
410 | 
411 | def adjust_learning_rate(optimizer, epoch, args):
412 |     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
413 |     lr = args.lr * (0.1 ** (epoch // 30))
414 |     for param_group in optimizer.param_groups:
415 |         param_group['lr'] = lr
416 | 
417 | 
418 | def accuracy(output, target, topk=(1,)):
419 |     """Computes the accuracy over the k top predictions for the specified values of k"""
420 |     with torch.no_grad():
421 |         maxk = max(topk)
422 |         batch_size = target.size(0)
423 | 
424 |         _, pred = output.topk(maxk, 1, True, True)
425 |         pred = pred.t()
426 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
427 | 
428 |         res = []
429 |         for k in topk:
430 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape, not view: the slice of the transposed tensor is non-contiguous
431 |             res.append(correct_k.mul_(100.0 / batch_size))
432 |         return res
433 | 
434 | 
435 | if __name__ == '__main__':
436 |     main()
--------------------------------------------------------------------------------
/utils/defense_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy.special import softmax
4 | 
5 | # robust masking defense (Algorithm 1 in the paper)
6 | def masking_defense(local_feature,clipping=-1,thres=0.,window_shape=[6,6],ds=False):
7 |     '''
8 |     local_feature   numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
9 |     clipping        int/float, the positive clipping value ($c_h$ in the paper). If clipping < 0, treat clipping as np.inf
10 |     thres           float in [0,1], detection threshold ($T$ in the paper)
11 |     window_shape    list [int,int], the shape of the sliding window
12 |     ds              boolean, whether this is for mask-ds
13 | 
14 |     Return          int, robust prediction
15 |     '''
16 | 
17 |     feature_size_x,feature_size_y,num_cls = local_feature.shape
18 |     window_size_x,window_size_y = window_shape
19 |     num_window_x = feature_size_x - window_size_x + 1 if not ds else feature_size_x
20 |     num_window_y = feature_size_y - window_size_y + 1 if not ds else feature_size_y
21 | 
22 |     # clipping
23 |     if clipping > 0:
24 |         local_feature = np.clip(local_feature,0,clipping)
25 |     else:
26 |         local_feature = np.clip(local_feature,0,np.inf)
27 | 
28 | 
29 |     global_feature = np.sum(local_feature,axis=(0,1))
30 | 
31 |     # the sum of class evidence within each window
32 |     in_window_sum_tensor=np.zeros([num_window_x,num_window_y,num_cls])
33 |     for x in range(0,num_window_x):
34 |         for y in range(0,num_window_y):
35 |             if ds and x + window_size_x > feature_size_x: # wrap-around, only happens when ds is True
36 |                 in_window_sum_tensor[x,y,:] = np.sum(local_feature[x:,y:y+window_size_y,:],axis=(0,1)) + np.sum(local_feature[:x+window_size_x-feature_size_x,y:y+window_size_y,:],axis=(0,1))
37 |             else: # normal case
38 |                 in_window_sum_tensor[x,y,:] = np.sum(local_feature[x:x+window_size_x,y:y+window_size_y,:],axis=(0,1))
39 | 
40 | 
41 |     # calculate clipped and masked class evidence for each class
42 |     for c in range(num_cls):
43 |         max_window_sum = np.max(in_window_sum_tensor[:,:,c])
44 |         if global_feature[c] > 0 and max_window_sum / global_feature[c] > thres:
45 |             global_feature[c]-=max_window_sum
46 | 
47 |     pred_list = np.argsort(global_feature,kind='stable') # "stable" is necessary when the features are one-hot local predictions
48 |     return pred_list[-1]
49 | 
50 | 
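# A minimal usage sketch for masking_defense() (illustrative only: the 26x26x1000 shape
# is a hypothetical BagNet-style local feature map, and thres=0.0 matches the function's
# default):
#
#   feat = np.random.rand(26, 26, 1000)  # [feature_size_x, feature_size_y, num_cls]
#   robust_pred = masking_defense(feat, clipping=-1, thres=0.0, window_shape=[6, 6])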
51 | # provable analysis of robust masking defense (Algorithm 2 in the paper)
52 | def provable_masking(local_feature,label,clipping=-1,thres=0.,window_shape=[6,6],ds=False):
53 |     '''
54 |     local_feature   numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
55 |     label           int, true label
56 |     clipping        int/float, the positive clipping value ($c_h$ in the paper). If clipping < 0, treat clipping as np.inf
57 |     thres           float in [0,1], detection threshold ($T$ in the paper)
58 |     window_shape    list [int,int], the shape of the sliding window
59 |     ds              boolean, whether this is for mask-ds
60 | 
61 |     Return          int, provable analysis result (0: incorrect clean prediction; 1: possible attack found; 2: certified robustness)
62 |     '''
63 | 
64 |     feature_size_x,feature_size_y,num_cls = local_feature.shape
65 |     window_size_x,window_size_y = window_shape
66 |     num_window_x = feature_size_x - window_size_x + 1 if not ds else feature_size_x
67 |     num_window_y = feature_size_y - window_size_y + 1 if not ds else feature_size_y
68 | 
69 |     if clipping > 0:
70 |         local_feature = np.clip(local_feature,0,clipping)
71 |     else:
72 |         local_feature = np.clip(local_feature,0,np.inf)
73 | 
74 |     global_feature = np.sum(local_feature,axis=(0,1))
75 | 
76 |     pred_list = np.argsort(global_feature,kind='stable')
77 |     global_pred = pred_list[-1]
78 | 
79 |     if global_pred != label: # clean prediction is incorrect
80 |         return 0
81 | 
82 |     local_feature_pred = local_feature[:,:,global_pred]
83 | 
84 |     # the sum of class evidence within each window
85 |     in_window_sum_tensor = np.zeros([num_window_x,num_window_y,num_cls])
86 | 
87 |     for x in range(0,num_window_x):
88 |         for y in range(0,num_window_y):
89 |             if ds and x+window_size_x>feature_size_x: # wrap-around, only happens when ds is True
90 |                 in_window_sum_tensor[x,y,:] = np.sum(local_feature[x:,y:y+window_size_y,:],axis=(0,1)) + np.sum(local_feature[:x+window_size_x-feature_size_x,y:y+window_size_y,:],axis=(0,1))
91 |             else:
92 |                 in_window_sum_tensor[x,y,:] = np.sum(local_feature[x:x+window_size_x,y:y+window_size_y,:],axis=(0,1))
93 | 
94 | 
95 |     idx = np.ones([num_cls],dtype=bool)
96 |     idx[global_pred]=False
97 |     for x in range(0,num_window_x):
98 |         for y in range(0,num_window_y):
99 | 
100 |             # determine the upper bound of wrong class evidence
101 |             global_feature_masked = global_feature - in_window_sum_tensor[x,y,:] # $t$ in the proof of Lemma 1
102 |             global_feature_masked[idx]/=(1 - thres) # $t/(1-T)$, the upper bound of wrong class evidence
103 | 
104 |             # determine the lower bound of true class evidence
105 |             local_feature_pred_masked = local_feature_pred.copy()
106 |             if ds and x+window_size_x>feature_size_x:
107 |                 local_feature_pred_masked[x:,y:y+window_size_y]=0
108 |                 local_feature_pred_masked[:x+window_size_x-feature_size_x,y:y+window_size_y]=0
109 |             else:
110 |                 local_feature_pred_masked[x:x+window_size_x,y:y+window_size_y]=0 # operation $u\odot(1-w)$
111 | 
112 |             in_window_sum_pred_masked = in_window_sum_tensor[:,:,global_pred].copy()
113 |             overlap_window_max_sum = 0
114 |             # only need to recalculate the windows that are partially masked
115 |             for xx in range(max(0,x - window_size_x + 1),min(x + window_size_x,num_window_x)):
116 |                 for yy in range(max(0,y - window_size_y + 1),min(y + window_size_y,num_window_y)):
117 |                     if ds and xx+window_size_x>feature_size_x:
118 |                         in_window_sum_pred_masked[xx,yy]=local_feature_pred_masked[xx:,yy:yy+window_size_y].sum()+local_feature_pred_masked[:xx+window_size_x-feature_size_x,yy:yy+window_size_y].sum()
119 |                     else:
120 |                         in_window_sum_pred_masked[xx,yy]=local_feature_pred_masked[xx:xx+window_size_x,yy:yy+window_size_y].sum()
121 |                     overlap_window_max_sum = max(overlap_window_max_sum,in_window_sum_pred_masked[xx,yy])
122 | 
123 |             max_window_sum_pred = np.max(in_window_sum_pred_masked) # the largest remaining window sum of true class evidence
124 |             # mimic the defense: the largest true-class window is masked only if it exceeds the detection threshold
125 |             if max_window_sum_pred / global_feature_masked[global_pred] > thres:
126 |                 global_feature_masked[global_pred]-=max_window_sum_pred
127 |             else:
128 |                 global_feature_masked[global_pred]-=overlap_window_max_sum
129 | 
130 | 
131 |             # determine if an attack is possible
132 |             if np.argsort(global_feature_masked,kind='stable')[-1]!=label:
133 |                 return 1
134 | 
135 |     return 2 # provable robustness
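# Sketch of the typical aggregation over a dataset (illustrative; `feature_list` and
# `label_list` are assumed to hold per-image local feature maps and integer labels,
# e.g. as collected in mask_bn.py):
#
#   results = [provable_masking(f, y, thres=0.0, window_shape=[6, 6])
#              for f, y in zip(feature_list, label_list)]
#   clean_acc     = np.mean([r != 0 for r in results])  # 0 marks an incorrect clean prediction
#   certified_acc = np.mean([r == 2 for r in results])  # 2 marks certified robustness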
136 | 
137 | 
138 | 
139 | 
140 | # De-randomized Smoothing
141 | # Adapted from https://github.com/alevine0/patchSmoothing/blob/master/utils_band.py
142 | def ds(inpt,net,block_size, size_to_certify, num_classes, threshold=0.2):
143 |     '''
144 |     inpt             torch.tensor, the input images in BCHW format
145 |     net              torch.nn.Module, the base model whose input is small pixel bands
146 |     block_size       int, the width of the pixel bands
147 |     size_to_certify  int, the patch size to be certified
148 |     num_classes      int, number of classes
149 |     threshold        float, the threshold for counting a band prediction; see the original paper for details
150 | 
151 |     Return           [torch.tensor,torch.tensor], the clean predictions and the certificates
152 |     '''
153 | 
154 |     predictions = torch.zeros(inpt.size(0), num_classes).type(torch.int).cuda()
155 |     batch = inpt.permute(0,2,3,1) # color channel last
156 |     for pos in range(batch.shape[2]):
157 |         out_c1 = torch.zeros(batch.shape).cuda()
158 |         out_c2 = torch.zeros(batch.shape).cuda()
159 |         if (pos+block_size > batch.shape[2]): # wrap-around band
160 |             out_c1[:,:,pos:] = batch[:,:,pos:]
161 |             out_c2[:,:,pos:] = 1. - batch[:,:,pos:]
162 | 
163 |             out_c1[:,:,:pos+block_size-batch.shape[2]] = batch[:,:,:pos+block_size-batch.shape[2]]
164 |             out_c2[:,:,:pos+block_size-batch.shape[2]] = 1. - batch[:,:,:pos+block_size-batch.shape[2]]
165 |         else:
166 |             out_c1[:,:,pos:pos+block_size] = batch[:,:,pos:pos+block_size]
167 |             out_c2[:,:,pos:pos+block_size] = 1. - batch[:,:,pos:pos+block_size]
168 | 
169 |         out_c1 = out_c1.permute(0,3,1,2)
170 |         out_c2 = out_c2.permute(0,3,1,2)
171 |         out = torch.cat((out_c1,out_c2), 1)
172 |         softmx = torch.nn.functional.softmax(net(out),dim=1)
173 |         predictions += (softmx >= threshold).type(torch.int).cuda()
174 | 
175 |     predictions_np = predictions.cpu().numpy()
176 |     idxsort = np.argsort(-predictions_np,axis=1,kind='stable')
177 |     valsort = -np.sort(-predictions_np,axis=1,kind='stable')
178 |     val = valsort[:,0]
179 |     idx = idxsort[:,0]
180 |     valsecond = valsort[:,1]
181 |     idxsecond = idxsort[:,1]
182 |     num_affected_classifications=(size_to_certify + block_size -1) # number of bands a patch can intersect
183 |     cert = torch.tensor(((val - valsecond > 2*num_affected_classifications) | ((val - valsecond == 2*num_affected_classifications)&(idx < idxsecond)))).cuda()
184 |     return torch.tensor(idx).cuda(), cert
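# Usage sketch (illustrative; `net` is assumed to be a ds-resnet checkpoint such as
# ds_cifar.pth, which takes the 2*C-channel band encoding built above; the CIFAR-style
# sizes are hypothetical):
#
#   preds, certs = ds(images.cuda(), net, block_size=4, size_to_certify=5, num_classes=10)
#
# The certificate condition follows the original analysis: a patch of width
# size_to_certify can intersect at most size_to_certify + block_size - 1 bands, and each
# affected band can at worst remove one count from the top class and add one to another,
# hence the required margin of 2*num_affected_classifications.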
185 | 
186 | 
187 | # mask-ds
188 | def masking_ds(inpt,labels,net,block_size,size_to_certify,thres=0.0):
189 |     '''
190 |     inpt             torch.tensor, the input images in BCHW format
191 |     labels           numpy.ndarray, the list of labels
192 |     net              torch.nn.Module, the base model whose input is small pixel bands
193 |     block_size       int, the width of the pixel bands
194 |     size_to_certify  int, the patch size to be certified
195 |     thres            float, the detection threshold ($T$). Note it is not `threshold` in ds()
196 | 
197 |     Return:          [list,list], a list of provable analysis results and a list of clean prediction correctnesses
198 |     '''
199 |     logits_list=[]
200 |     cnf_list=[]
201 |     pred_list=[]
202 |     batch = inpt.permute(0,2,3,1) # color channel last
203 |     for pos in range(batch.shape[2]):
204 |         out_c1 = torch.zeros(batch.shape).cuda()
205 |         out_c2 = torch.zeros(batch.shape).cuda()
206 |         if (pos+block_size > batch.shape[2]): # wrap-around band
207 |             out_c1[:,:,pos:] = batch[:,:,pos:]
208 |             out_c2[:,:,pos:] = 1. - batch[:,:,pos:]
209 | 
210 |             out_c1[:,:,:pos+block_size-batch.shape[2]] = batch[:,:,:pos+block_size-batch.shape[2]]
211 |             out_c2[:,:,:pos+block_size-batch.shape[2]] = 1. - batch[:,:,:pos+block_size-batch.shape[2]]
212 |         else:
213 |             out_c1[:,:,pos:pos+block_size] = batch[:,:,pos:pos+block_size]
214 |             out_c2[:,:,pos:pos+block_size] = 1. - batch[:,:,pos:pos+block_size]
215 | 
216 |         out_c1 = out_c1.permute(0,3,1,2)
217 |         out_c2 = out_c2.permute(0,3,1,2)
218 |         out = torch.cat((out_c1,out_c2), 1)
219 |         logits_tmp = net(out).detach().cpu().numpy()
220 |         cnf_tmp = softmax(logits_tmp,axis=-1)
221 |         pred_tmp = (cnf_tmp > 0.2).astype(float)
222 |         logits_list.append(logits_tmp)
223 |         cnf_list.append(cnf_tmp)
224 |         pred_list.append(pred_tmp)
225 | 
226 |     #output_list = np.stack(logits_list,axis=1)
227 |     output_list = np.stack(cnf_list,axis=1) # use softmax confidences as local features
228 |     #output_list = np.stack(pred_list,axis=1)
229 | 
230 |     B,W,C=output_list.shape
231 |     result_list=[]
232 |     clean_corr_list=[]
233 |     window_size = block_size + size_to_certify -1 # number of bands a patch can intersect
234 | 
235 |     for i in range(len(labels)):
236 |         local_feature = output_list[i].reshape([W,1,C])
237 |         result=provable_masking(local_feature,labels[i],window_shape=[window_size,1],thres=thres,ds=True)
238 |         clean_pred=masking_defense(local_feature,window_shape=[window_size,1],thres=thres,ds=True)
239 |         result_list.append(result)
240 |         clean_corr_list.append(clean_pred == labels[i])
241 | 
242 |     return result_list,clean_corr_list
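# Usage sketch (illustrative; mirrors how the mask-ds evaluation script is expected to
# call this function, with hypothetical CIFAR-style arguments):
#
#   results, clean_corr = masking_ds(images.cuda(), labels.numpy(), net,
#                                    block_size=4, size_to_certify=5, thres=0.0)
#   certified_acc = np.mean([r == 2 for r in results])
#   clean_acc     = np.mean(clean_corr)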
243 | 
244 | ##################################################################################################################################
245 | 
246 | # an extended version of provable_masking()
247 | def provable_masking_large_mask(local_feature,label,clipping=-1,thres=0.,window_shape=[6,6],mask_shape=None):
248 |     '''
249 |     local_feature   numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
250 |     label           int, true label
251 |     clipping        int/float, the positive clipping value ($c_h$ in the paper). If clipping < 0, treat clipping as np.inf
252 |     thres           float in [0,1], detection threshold ($T$ in the paper)
253 |     window_shape    list [int,int], the shape of the malicious window
254 |     mask_shape      list [int,int], the shape of the mask window. If set to None, takes the same value as window_shape
255 | 
256 |     Return          int, provable analysis result (0: incorrect clean prediction; 1: possible attack found; 2: certified robustness)
257 |     '''
258 |     feature_size_x,feature_size_y,num_cls = local_feature.shape
259 | 
260 |     patch_size_x,patch_size_y = window_shape
261 |     num_patch_x = feature_size_x - patch_size_x + 1
262 |     num_patch_y = feature_size_y - patch_size_y + 1
263 | 
264 |     if mask_shape is None:
265 |         mask_shape = window_shape
266 |     mask_size_x,mask_size_y = mask_shape
267 |     num_mask_x = feature_size_x - mask_size_x + 1
268 |     num_mask_y = feature_size_y - mask_size_y + 1
269 | 
270 |     if clipping > 0:
271 |         local_feature = np.clip(local_feature,0,clipping)
272 |     else:
273 |         local_feature = np.clip(local_feature,0,np.inf)
274 | 
275 |     global_feature = np.sum(local_feature,axis=(0,1))
276 | 
277 |     pred_list = np.argsort(global_feature,kind='stable')
278 |     global_pred = pred_list[-1]
279 | 
280 |     if global_pred != label: # clean prediction is incorrect
281 |         return 0
282 | 
283 |     # the sum of class evidence within each mask window
284 |     in_mask_sum_tensor = np.zeros([num_mask_x,num_mask_y,num_cls])
285 |     for x in range(0,num_mask_x):
286 |         for y in range(0,num_mask_y):
287 |             in_mask_sum_tensor[x,y] = np.sum(local_feature[x:x+mask_size_x,y:y+mask_size_y,:],axis=(0,1))
288 | 
289 | 
290 |     # the sum of class evidence within each possible malicious window
291 |     in_patch_sum_tensor = np.zeros([num_patch_x,num_patch_y,num_cls])
292 |     for x in range(0,num_patch_x):
293 |         for y in range(0,num_patch_y):
294 |             in_patch_sum_tensor[x,y,:] = np.sum(local_feature[x:x+patch_size_x,y:y+patch_size_y,:],axis=(0,1))
295 | 
296 |     #out_patch_sum_tensor = global_feature.reshape([1,1,num_cls]) - in_patch_sum_tensor
297 | 
298 |     idx = np.ones([num_cls],dtype=bool)
299 |     idx[global_pred]=False
300 | 
301 |     for x in range(0,num_patch_x):
302 |         for y in range(0,num_patch_y):
303 | 
304 |             # determine the upper bound of wrong class evidence
305 |             cover_patch_mask_sum_tensor = in_mask_sum_tensor[max(0,x + patch_size_x - mask_size_x):min(x+1,num_mask_x),max(0,y + patch_size_y - mask_size_y):min(y+1,num_mask_y)]
306 |             max_cover_patch_mask_sum = np.max(cover_patch_mask_sum_tensor,axis=(0,1))
307 |             global_feature_patched = global_feature - max_cover_patch_mask_sum # $t-k$ in the proof of Lemma 2
308 |             global_feature_patched[idx]/=(1 - thres) # $(t-k)/(1-T)$ in the proof of Lemma 2
309 |             overlap_window_max_sum = 0
310 |             # determine the lower bound of true class evidence
311 |             local_feature_pred_masked = local_feature[:,:,global_pred].copy()
312 |             local_feature_pred_masked[x:x+patch_size_x,y:y+patch_size_y]=0
313 |             in_mask_sum_pred_masked = in_mask_sum_tensor[:,:,global_pred].copy()
314 |             # only need to recalculate the mask windows that are partially masked
315 |             for xx in range(max(0,x - mask_size_x + 1),min(x + patch_size_x,num_mask_x)):
316 |                 for yy in range(max(0,y - mask_size_y + 1),min(y + patch_size_y,num_mask_y)):
317 |                     in_mask_sum_pred_masked[xx,yy]=local_feature_pred_masked[xx:xx+mask_size_x,yy:yy+mask_size_y].sum()
318 |                     overlap_window_max_sum = max(overlap_window_max_sum,in_mask_sum_pred_masked[xx,yy])
319 | 
320 |             max_mask_sum_pred = np.max(in_mask_sum_pred_masked) # the largest remaining mask-window sum of true class evidence
321 |             # mimic the defense: the largest true-class mask window is removed only if it exceeds the detection threshold
322 |             if max_mask_sum_pred / global_feature_patched[global_pred] > thres:
323 |                 global_feature_patched[global_pred]-=max_mask_sum_pred
324 |             else:
325 |                 global_feature_patched[global_pred]-=overlap_window_max_sum
326 | 
327 |             # determine if an attack is possible
328 |             if np.argsort(global_feature_patched,kind='stable')[-1]!=label:
329 |                 return 1
330 |     return 2 # provable robustness
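# Illustrative call (hypothetical feature map `feat` and label `y`): certify a 6x6
# malicious window while the defense masks a larger 8x8 region. The analysis implicitly
# assumes mask_shape is at least as large as window_shape, so the mask can cover the patch:
#
#   result = provable_masking_large_mask(feat, y, thres=0.0,
#                                        window_shape=[6, 6], mask_shape=[8, 8])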
331 | 
332 | 
333 | # clipping based defense
334 | def clipping_defense(local_feature,clipping=-1):
335 |     '''
336 |     local_feature   numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
337 |     clipping        int/float, clipping value. If clipping < 0, use CBN clipping
338 | 
339 |     Return          int, the prediction label
340 |     '''
341 |     if clipping > 0:
342 |         local_feature = np.clip(local_feature,0,clipping) # clipped to [0,clipping]
343 |     else:
344 |         local_feature = np.tanh(local_feature*0.05-1) # clipped with tanh (CBN)
345 |     global_feature = np.mean(local_feature,axis=(0,1))
346 |     global_pred = np.argmax(global_feature)
347 | 
348 |     return global_pred
349 | 
350 | # provable analysis for clipping based defense
351 | def provable_clipping(local_feature,label,clipping=-1,window_shape=[6,6]):
352 | 
353 |     '''
354 |     local_feature   numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
355 |     label           int, true label
356 |     clipping        int/float, clipping value. If clipping < 0, use CBN clipping
357 | 
358 |     window_shape    list [int,int], the shape of the sliding window
359 | 
360 |     Return          int, provable analysis result (0: incorrect clean prediction; 1: possible attack found; 2: certified robustness)
361 |     '''
362 |     feature_size_x,feature_size_y,num_cls = local_feature.shape
363 |     window_size_x,window_size_y = window_shape
364 |     num_window_x = feature_size_x - window_size_x + 1
365 |     num_window_y = feature_size_y - window_size_y + 1
366 | 
367 |     if clipping > 0:
368 |         local_feature = np.clip(local_feature,0,clipping) # clipped to [0,clipping]
369 |         max_increase = window_size_x * window_size_y * clipping # the patch can shift in-window evidence by at most this much
370 |     else:
371 |         local_feature = np.tanh(local_feature*0.05-1) # clipped with tanh (CBN)
372 |         max_increase = window_size_x * window_size_y * 2 # tanh output lies in (-1,1)
373 | 
374 |     local_pred = np.argmax(local_feature,axis=-1)
375 |     global_feature = np.mean(local_feature,axis=(0,1))
376 |     pred_list = np.argsort(global_feature,kind='stable')
377 |     global_pred = pred_list[-1]
378 |     if global_pred != label: # clean prediction is incorrect
379 |         return 0
380 |     local_feature_pred = local_feature[:,:,global_pred]
381 | 
382 | 
383 |     target_cls = pred_list[-2] # runner-up prediction
384 | 
385 |     local_feature_target = local_feature[:,:,target_cls]
386 |     diff_feature = local_feature_pred - local_feature_target
387 | 
388 |     for x in range(0,num_window_x):
389 |         for y in range(0,num_window_y):
390 |             diff_feature_masked = diff_feature.copy()
391 |             diff_feature_masked[x:x+window_size_x,y:y+window_size_y]=0
392 |             diff = diff_feature_masked.sum()
393 |             if diff < max_increase: # the out-of-window margin cannot absorb the worst-case in-window change
394 |                 return 1
395 |     return 2 # provable robustness
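# Usage sketch (illustrative; `feat` and `y` are a hypothetical local feature map and
# its label): clipped aggregation and its provable analysis. clipping=-1 selects the
# tanh-based CBN clipping in both functions:
#
#   pred   = clipping_defense(feat, clipping=-1)
#   result = provable_clipping(feat, y, clipping=-1, window_shape=[6, 6])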
396 | 
397 | 
398 | 
399 | ##################################################################################################################################
400 | 
401 | 
402 | # for PatchGuard++
403 | 
404 | 
405 | 
406 | def pg2_detection(local_feature,tau,window_shape=[6,6]):
407 |     '''
408 |     local_feature   numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
409 |     tau             float in [0,1], detection threshold ($\tau$ in the paper)
410 |     window_shape    list [int,int], the shape of the sliding window
411 | 
412 |     Return          int, class label, or -1 for an attack alert
413 |     '''
414 |     feature_size_x,feature_size_y,num_cls = local_feature.shape
415 |     window_size_x,window_size_y = window_shape
416 |     num_window_x = feature_size_x - window_size_x + 1
417 |     num_window_y = feature_size_y - window_size_y + 1
418 | 
419 |     global_feature = np.mean(local_feature,axis=(0,1))
420 |     pred_list = np.argsort(global_feature,kind='stable')
421 |     global_pred = pred_list[-1]
422 | 
423 |     in_window_sum_tensor=np.zeros([num_window_x,num_window_y,num_cls])
424 |     for x in range(0,num_window_x):
425 |         for y in range(0,num_window_y):
426 |             in_window_sum_tensor[x,y,:] = np.sum(local_feature[x:x+window_size_x,y:y+window_size_y,:],axis=(0,1))
427 |     in_window_sum_tensor = in_window_sum_tensor/(feature_size_x*feature_size_y) # scale window sums to match the mean-pooled global feature
428 | 
429 |     for x in range(0,num_window_x):
430 |         for y in range(0,num_window_y):
431 |             global_feature_masked = global_feature - in_window_sum_tensor[x,y]
432 |             global_feature_masked = softmax(global_feature_masked)
433 |             masked_pred = np.argmax(global_feature_masked)
434 |             masked_conf = np.max(global_feature_masked)
435 |             if masked_pred != global_pred and masked_conf>tau:
436 |                 return -1 # a masked prediction disagrees with high confidence -> alert
437 |     return global_pred
438 | 
439 | 
440 | 
441 | 
442 | def pg2_detection_provable(local_feature,label,tau,window_shape=[6,6]):
443 |     '''
444 |     local_feature   numpy.ndarray, feature tensor in the shape of [feature_size_x,feature_size_y,num_cls]
445 |     label           int, the ground-truth class label
446 |     tau             float in [0,1], detection threshold ($\tau$ in the paper)
447 |     window_shape    list [int,int], the shape of the sliding window
448 | 
449 |     Return          int, provable analysis result (0: incorrect clean prediction; 1: possible attack found; 2: certified robustness)
450 |     '''
451 |     feature_size_x,feature_size_y,num_cls = local_feature.shape
452 |     window_size_x,window_size_y = window_shape
453 |     num_window_x = feature_size_x - window_size_x + 1
454 |     num_window_y = feature_size_y - window_size_y + 1
455 | 
456 |     global_feature = np.mean(local_feature,axis=(0,1))
457 |     pred_list = np.argsort(global_feature,kind='stable')
458 |     global_pred = pred_list[-1]
459 | 
460 |     in_window_sum_tensor=np.zeros([num_window_x,num_window_y,num_cls])
461 |     for x in range(0,num_window_x):
462 |         for y in range(0,num_window_y):
463 |             in_window_sum_tensor[x,y,:] = np.sum(local_feature[x:x+window_size_x,y:y+window_size_y,:],axis=(0,1))
464 |     in_window_sum_tensor = in_window_sum_tensor/(feature_size_x*feature_size_y)
465 | 
466 |     if global_pred != label: # clean prediction is incorrect
467 |         return 0
468 | 
469 |     for x in range(0,num_window_x):
470 |         for y in range(0,num_window_y):
471 |             global_feature_masked = global_feature - in_window_sum_tensor[x,y]
472 |             global_feature_masked = softmax(global_feature_masked)
473 |             masked_pred = np.argmax(global_feature_masked)
474 |             masked_conf = np.max(global_feature_masked)
475 |             if masked_pred != label or masked_conf<tau:
476 |                 return 1 # some masked prediction is wrong or under-confident, so robustness cannot be certified
477 |     return 2 # certified robustness
--------------------------------------------------------------------------------