├── assets
│   ├── arch.png
│   └── comparison.png
├── requirement.txt
├── LICENSE
├── checkpoints
│   └── README.md
├── utils
│   ├── cutout.py
│   ├── setup.py
│   └── defense.py
├── vanilla_clean_acc.py
├── example_cmd.sh
├── misc
│   ├── notes_on_robustness_evaluation.md
│   ├── notes_on_mr.md
│   ├── pc_multiple.py
│   ├── pc_mr.py
│   ├── pc_mr_experimental.py
│   └── reproducibility.md
├── pc_clean_acc.py
├── pc_certification.py
├── README.md
└── train_model.py

--------------------------------------------------------------------------------
/assets/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inspire-group/PatchCleanser/HEAD/assets/arch.png

--------------------------------------------------------------------------------
/assets/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inspire-group/PatchCleanser/HEAD/assets/comparison.png

--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
tqdm        #==4.51.0
torch       #==1.7.0
torchvision #==0.8.1
joblib      #==0.17.0
numpy       #==1.19.2
timm == 0.4.12 # timm is actively evolving, so I pin the version number

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Princeton INSPIRE Research Group

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
## Checkpoints
Model checkpoints used in the paper can be downloaded from this [link](https://drive.google.com/drive/folders/1Ewks-NgJHDlpeAaGInz_jZ6iczcYNDlN?usp=sharing).

Model training should be straightforward with the provided training scripts; see `example_cmd.sh` for examples.

#### Checkpoint name format:

`{model_name}{cutout_setting}_{dataset_name}.pth`

`{model_name}` options:

1. `resnetv2_50x1_bit_distilled`
2. `vit_base_patch16_224`
3. `resmlp_24_distilled_224`

`{cutout_setting}` options:

1. empty
2. `_cutout2_128` (2 cutout squares of size 128px; the default setting used in the paper)

`{dataset_name}` options:

1. `imagenet`
2. `imagenette`
3. `cifar`
4. `cifar100`
5. `flower102`
6. `svhn`

**Note 1:** We do not provide vanilla (non-cutout) weights for ImageNet; those pretrained weights are loaded directly via `timm`.

**Note 2:** `'_{dataset_name}.pth'` will be automatically appended to `args.model` in the scripts `pc_certification.py`, `pc_clean_acc.py`, and `vanilla_clean_acc.py`.

**Update 04/2023:** Added support for MAE. Download checkpoints from its [GitHub repository](https://github.com/facebookresearch/mae) and add the name suffix `_imagenet.pth`.
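To make the naming convention concrete, here is a minimal sketch of how the evaluation scripts assemble a checkpoint path (mirroring the `model_name + '_{}.pth'.format(dataset_name)` logic in `utils/setup.py`); the helper name `checkpoint_path` is ours, not part of the repo:

```python
import os

def checkpoint_path(model_name, dataset_name, model_dir="checkpoints"):
    # '{model_name}{cutout_setting}_{dataset_name}.pth'; the cutout setting
    # is already embedded in model_name (e.g., '..._cutout2_128')
    return os.path.join(model_dir, "{}_{}.pth".format(model_name, dataset_name))

print(checkpoint_path("vit_base_patch16_224_cutout2_128", "imagenette"))
# -> checkpoints/vit_base_patch16_224_cutout2_128_imagenette.pth
```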
--------------------------------------------------------------------------------
/utils/cutout.py:
--------------------------------------------------------------------------------
###############################################
# from https://github.com/uoguelph-mlrg/Cutout
###############################################

import torch
import numpy as np


class Cutout(object):
    """Randomly mask out one or more patches from an image.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes of dimension length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask

        return img
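As a usage sketch (mirroring `train_model.py` in this repo), `Cutout` is appended to a torchvision pipeline after `ToTensor`, since it operates on `(C, H, W)` tensors. The crop size and normalization constants below are placeholders; the repo resolves the real ones from the timm model config:

```python
from torchvision import transforms
from utils.cutout import Cutout

# Cutout must come after ToTensor: it masks (C, H, W) tensors, not PIL images.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    Cutout(n_holes=2, length=128),  # the paper's default masked-training setting
])
```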
--------------------------------------------------------------------------------
/vanilla_clean_acc.py:
--------------------------------------------------------------------------------
import torch
import torch.backends.cudnn as cudnn

import numpy as np
import os
import argparse
import time
from tqdm import tqdm

from utils.setup import get_model,get_data_loader

parser = argparse.ArgumentParser()

parser.add_argument("--model_dir",default='checkpoints',type=str,help="directory of checkpoints")
parser.add_argument('--data_dir', default='data', type=str,help="directory of data")
parser.add_argument('--dataset', default='imagenette',type=str,choices=('imagenette','imagenet','cifar','cifar100','svhn','flower102'),help="dataset")
parser.add_argument("--model",default='vit_base_patch16_224',type=str,help="model name")
parser.add_argument("--num_img",default=-1,type=int,help="number of randomly selected images for this experiment (-1: use all images)")

args = parser.parse_args()

DATASET = args.dataset
MODEL_DIR = os.path.join('.',args.model_dir)
DATA_DIR = os.path.join(args.data_dir,DATASET)
MODEL_NAME = args.model
NUM_IMG = args.num_img

# get model and data loader
model = get_model(MODEL_NAME,DATASET,MODEL_DIR)
val_loader,NUM_IMG,_ = get_data_loader(DATASET,DATA_DIR,model,batch_size=16,num_img=NUM_IMG,train=False)

device = 'cuda'
model = model.to(device)
model.eval()
cudnn.benchmark = True

accuracy_list = []
time_list = []
for data,labels in tqdm(val_loader):
    data,labels = data.to(device),labels.to(device)
    start = time.time()
    output_clean = model(data)
    end = time.time()
    time_list.append(end-start)
    acc_clean = torch.sum(torch.argmax(output_clean, dim=1) == labels).item()
    accuracy_list.append(acc_clean)

print("Test accuracy:",np.sum(accuracy_list)/NUM_IMG)
print('Per-example inference time:',np.sum(time_list)/NUM_IMG)

--------------------------------------------------------------------------------
/example_cmd.sh:
--------------------------------------------------------------------------------
# certified robust accuracy of patchcleanser models
python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img -1 --num_mask 6 --patch_size 32 # a simple usage example

python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenette --num_img -1 --num_mask 6 --patch_size 32 # experiment with a different dataset
python pc_certification.py --model resnetv2_50x1_bit_distilled --dataset imagenette --num_img -1 --num_mask 6 --patch_size 32 # experiment with a different architecture
python pc_certification.py --model resnetv2_50x1_bit_distilled --dataset imagenette --num_img -1 --num_mask 6 --patch_size 64 # experiment with a larger patch
python pc_certification.py --model resnetv2_50x1_bit_distilled --dataset imagenette --num_img 1000 --num_mask 6 --patch_size 32 # experiment with a random subset of images

python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenette --num_img -1 --num_mask 3 --patch_size 32 # adjust the computation budget (number of masks)
python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenette --num_img -1 --mask_stride 32 --patch_size 32 # set mask_stride instead of num_mask
python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenette --num_img -1 --mask_stride 32 --pa 16 --pb 64 # consider a rectangular patch


# clean accuracy of patchcleanser models (same usage as pc_certification.py)
python pc_clean_acc.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 32


# clean accuracy of vanilla undefended models (similar usage)
python vanilla_clean_acc.py --model vit_base_patch16_224 --dataset imagenet


# train models (similar usage for other datasets)
python train_model.py --model vit_base_patch16_224 --dataset imagenette --lr 0.001 --epoch 10
python train_model.py --model vit_base_patch16_224 --dataset imagenette --lr 0.001 --epoch 10 --cutout --cutout_size 128 --n_holes 2
python train_model.py --model resnetv2_50x1_bit_distilled --dataset imagenette --lr 0.01 --epoch 10
python train_model.py --model resnetv2_50x1_bit_distilled --dataset imagenette --lr 0.01 --epoch 10 --cutout --cutout_size 128 --n_holes 2
python train_model.py --model resmlp_24_distilled_224 --dataset imagenette --lr 0.001 --epoch 10
python train_model.py --model resmlp_24_distilled_224 --dataset imagenette --lr 0.001 --epoch 10 --cutout --cutout_size 128 --n_holes 2
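A note on how `--num_mask`, `--mask_stride`, and `--patch_size` interact. The actual mask-set construction lives in `utils/defense.py` (not reproduced in this dump), but the covering condition from the paper pins down the arithmetic; the sketch below is ours and only illustrates that relation:

```python
import math

def mask_budget(img_size=224, patch_size=32, num_mask=6):
    """A mask of size m placed every s pixels occludes every possible
    p-pixel patch location iff m >= p + s - 1. Given a per-axis budget of
    k masks, the largest safe stride is ceil((n - p + 1) / k)."""
    s = math.ceil((img_size - patch_size + 1) / num_mask)  # mask stride
    m = patch_size + s - 1                                 # mask size
    return s, m

print(mask_budget())  # (33, 64): 6 masks per axis (36 total), 64px masks
```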
--------------------------------------------------------------------------------
/misc/notes_on_robustness_evaluation.md:
--------------------------------------------------------------------------------
## Are you looking for "attack code" in this repository?

Unfortunately, there is no attack code here.

Note that PatchCleanser is a certifiably robust defense. Its robustness evaluation is done using the certification procedure (Algorithm 2), instead of any concrete attack code. The proof (Theorem 1) in the paper demonstrates the soundness of our certification algorithm: the certified robust accuracy is a lower bound on model accuracy against any attacker within the threat model.

Please read [this note](https://github.com/xiangchong1/adv-patch-paper-list#empirically-robust-defenses-vs-provablycertifiably-robust-defenses) for more discussion of the differences between certifiably robust defenses and empirical defenses.

## What if I really want to evaluate the empirical robust accuracy of PatchCleanser?

You will probably have to implement the attack yourself.

#### Here are a few notes/suggestions.

Of course, you need to use *adaptive* attacks (the attack algorithm should be targeted at PatchCleanser) to evaluate the empirical robust accuracy. The question is: what is a good adaptive attack against PatchCleanser? Here is one possible strategy (see the sketch at the end of this note):

1. Find a mask location where two-mask correctness is not satisfied, and place a patch there. -> Since there is an incorrect two-mask prediction, PatchCleanser will not return a correct disagreer prediction in the second-round masking (Case II).
2. Optimize the patch content such that the one-mask majority prediction is incorrect. -> Then PatchCleanser will not return a correct majority prediction in Case III, or an agreed prediction via Case I.

#### How hard is this attack?

It shouldn't be too hard.

1. The first step just requires finding a violated two-mask correctness. If there is no violated two-mask correctness, PatchCleanser is certifiably robust, and there is no need to empirically attack it.
2. The second step shouldn't be too hard either. When the first-round mask does not remove the patch, the malicious masked predictions usually do not change (so we have an incorrect majority prediction). Moreover, the patch content can be further optimized for the malicious majority prediction if we do observe inconsistent malicious masked predictions.

#### Additional note

Be careful with the image pixel normalization. Some models in the repo scale pixel values with a mean of `[0.5,0.5,0.5]` instead of `[0.485, 0.456, 0.406]`. Inference on clean images might only be affected slightly, but the robustness can be very different if you use different normalization parameters.
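To make the two-step strategy above concrete, here is a rough, untested sketch of one possible adaptive attack loop. It is not part of this repo; the patch location search (done beforehand, e.g., from the dumped prediction maps), the summed cross-entropy loss, the step size, and the assumption of unnormalized `[0,1]` pixel values are all stand-in choices:

```python
import torch
import torch.nn.functional as F

def adaptive_attack_sketch(model, x, y, mask_list, patch_slice,
                           steps=200, step_size=2/255):
    """Optimize patch content at a location where two-mask correctness is
    violated, pushing the one-mask *majority* prediction away from the true
    label y. patch_slice is a pair of slice objects marking the patch region."""
    patch = torch.rand_like(x[..., patch_slice[0], patch_slice[1]]).requires_grad_()
    for _ in range(steps):
        x_adv = x.clone().detach()
        x_adv[..., patch_slice[0], patch_slice[1]] = patch
        # summed cross-entropy over all one-mask predictions is a simple
        # surrogate for "make the one-mask majority incorrect"
        loss = sum(F.cross_entropy(model(torch.where(m, x_adv, torch.tensor(0., device=x.device))), y)
                   for m in mask_list)
        grad = torch.autograd.grad(loss, patch)[0]
        patch = (patch + step_size * grad.sign()).clamp(0, 1).detach().requires_grad_()
    x_adv = x.clone().detach()
    x_adv[..., patch_slice[0], patch_slice[1]] = patch.detach()
    return x_adv
```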
--------------------------------------------------------------------------------
/misc/notes_on_mr.md:
--------------------------------------------------------------------------------
## Notes on the Minority Reports defense

In `pc_mr.py`, I discarded the voting grid design discussed in the original Minority Reports paper. See the reasons below.

### MR Algorithm

The original [Minority Reports paper](https://arxiv.org/abs/2004.13799) focuses on certified robustness for low-resolution images (e.g., CIFAR-10).

To counter a 5x5 patch on 32x32 CIFAR images, MR

1. places 7x7 masks on the 32x32 image to generate a 25x25 prediction grid (25=32-7+1)
2. considers 3x3 regions on the 25x25 prediction grid to generate a 23x23 voting grid (23=25-3+1)
3. if all voting grids vote for the correct class label, MR has certifiable robustness for attack detection

### Issue

The certification depends on the argument that "wherever the sticker is placed, there will be a 3×3 grid in the prediction grid that is unaffected by the sticker."

However, this argument is actually unsound for the approach described in the original paper. Consider a 5x5 patch placed at the corner of the image, so that the upper left coordinate of the patch is (0,0): there is *only one* unaffected masked prediction. This unaffected masked prediction is given by the mask whose upper left coordinate is (0,0). Any other mask will leave the adversarial *pixel* at (0,0) unmasked. We cannot certify robustness for this corner case.

### Fix

One possible fix is to start the mask locations at (-2,-2) instead of (0,0)!

In `pc_mr_experimental.py`, I implemented this secure version of MR. Additionally, I added a parameter `--mr` to tune the number of predictions that participate in the voting. We will use `(mr+1)x(mr+1)` predictions for voting.

`mr=2` replicates the original MR. `mr=0` replicates the first-round masking of PatchCleanser.

### Observation

My experiments on ImageNet find that, if we use the same mask stride, then

1. setting `mr` to a non-zero value only slightly affects the defense performance.
2. however, if we use the insecure masking strategy discussed in the original paper, a non-zero `mr` can significantly improve the defense performance.
3. I simplified `pc_mr_experimental.py` to `pc_mr.py`, which always uses `mr=0`.

By the way, there are other possible fixes that might improve the robustness with a non-zero `mr`; further discussion is out of the scope of this repo... (A 1-D sketch of the mask-placement fix is given after this note.)

PS. I did not spend much time on Minority Reports. If you have better implementation strategies or different observations, I am happy to discuss!
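A minimal 1-D illustration of the mask-placement fix, following the same conventions as `gen_mask_set_mr` in `pc_mr_experimental.py` (the function name and example values here are for illustration only):

```python
def mask_starts_1d(img=32, mask=7, stride=1, mr=2):
    """1-D mask start positions. The insecure variant starts at 0; the fix
    extends the grid by mr strides beyond both image boundaries so that a
    patch touching a corner still leaves (mr+1) unaffected masks."""
    starts = list(range(0, img - mask + 1, stride))
    if (img - mask) % stride != 0:
        starts.append(img - mask)          # make sure the far edge is covered
    return ([-j * stride for j in range(mr, 0, -1)] + starts
            + [img - mask + j * stride for j in range(1, mr + 1)])

print(mask_starts_1d()[:5])  # [-2, -1, 0, 1, 2] -> handles the (0,0) corner case
```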
--------------------------------------------------------------------------------
/pc_clean_acc.py:
--------------------------------------------------------------------------------
import torch
import torch.backends.cudnn as cudnn

import numpy as np
import os
import argparse
import time
from tqdm import tqdm

from utils.setup import get_model,get_data_loader
from utils.defense import gen_mask_set,double_masking#,challenger_masking

parser = argparse.ArgumentParser()
parser.add_argument("--model_dir",default='checkpoints',type=str,help="directory of checkpoints")
parser.add_argument('--data_dir', default='data', type=str,help="directory of data")
parser.add_argument('--dataset', default='imagenette',type=str,choices=('imagenette','imagenet','cifar','cifar100','svhn','flower102'),help="dataset")
parser.add_argument("--model",default='vit_base_patch16_224',type=str,help="model name")
parser.add_argument("--num_img",default=-1,type=int,help="number of randomly selected images for this experiment (-1: use all images)")
parser.add_argument("--mask_stride",default=-1,type=int,help="mask stride s (square patch; conflicts with num_mask)")
parser.add_argument("--num_mask",default=-1,type=int,help="number of masks along one dimension (square patch; conflicts with mask_stride)")
parser.add_argument("--patch_size",default=32,type=int,help="size of the adversarial patch (square patch)")
parser.add_argument("--pa",default=-1,type=int,help="size of the adversarial patch (first axis; for rectangular patches)")
parser.add_argument("--pb",default=-1,type=int,help="size of the adversarial patch (second axis; for rectangular patches)")

args = parser.parse_args()
DATASET = args.dataset
MODEL_DIR = os.path.join('.',args.model_dir)
DATA_DIR = os.path.join(args.data_dir,DATASET)
MODEL_NAME = args.model
NUM_IMG = args.num_img

# get model and data loader
model = get_model(MODEL_NAME,DATASET,MODEL_DIR)
val_loader,NUM_IMG,ds_config = get_data_loader(DATASET,DATA_DIR,model,batch_size=1,num_img=NUM_IMG,train=False)

device = 'cuda'
model = model.to(device)
model.eval()
cudnn.benchmark = True

# generate the mask set
mask_list,MASK_SIZE,MASK_STRIDE = gen_mask_set(args,ds_config)


clean_corr = 0
time_list = []

for data,labels in tqdm(val_loader):
    data = data.to(device)
    labels = labels.numpy()
    start = time.time()
    preds = double_masking(data,mask_list,model)
    #preds = challenger_masking(data,mask_list,model)
    end = time.time()
    time_list.append(end-start)
    clean_corr += np.sum(preds==labels)


print("Clean accuracy with defense:",clean_corr/NUM_IMG)
print('Per-example inference time:',np.sum(time_list)/NUM_IMG)
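`double_masking` itself is implemented in `utils/defense.py`, which is not reproduced in this dump. For orientation, here is a hedged sketch of the two-round masking logic it implements (following Algorithm 1 in the paper; treat it as an illustration, not the reference implementation):

```python
import torch
from collections import Counter

def double_masking_sketch(data, mask_list, model):
    """Two-round masking for a single image (batch size 1).
    Case I:   all first-round masked predictions agree -> return the agreed label.
    Case II:  a first-round disagreer whose second-round (two-mask) predictions
              are unanimous -> return the disagreer's label.
    Case III: otherwise, fall back to the first-round majority."""
    zero = torch.tensor(0., device=data.device)
    first = [model(torch.where(m, data, zero)).argmax(1).item() for m in mask_list]
    majority = Counter(first).most_common(1)[0][0]
    if all(p == majority for p in first):
        return majority                                   # Case I
    for i, m1 in enumerate(mask_list):
        if first[i] == majority:
            continue                                      # only probe disagreers
        second = [model(torch.where(torch.logical_and(m1, m2), data, zero)).argmax(1).item()
                  for m2 in mask_list]
        if all(p == first[i] for p in second):
            return first[i]                               # Case II
    return majority                                       # Case III
```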
--------------------------------------------------------------------------------
/utils/setup.py:
--------------------------------------------------------------------------------
import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torchvision import datasets
import numpy as np
import os


NUM_CLASSES_DICT = {'imagenette':10,'imagenet':1000,'flower102':102,'cifar':10,'cifar100':100,'svhn':10}

def get_model(model_name,dataset_name,model_dir):

    # build the model and load weights

    '''
    INPUT:
    model_name      str, model name. The name should contain one of ('resnetv2_50x1_bit_distilled', 'vit_base_patch16_224', 'resmlp_24_distilled_224')
    dataset_name    str, dataset name. One of ('imagenette','imagenet','cifar','cifar100','svhn','flower102')
    model_dir       str, the directory of model checkpoints

    OUTPUT:
    model           torch.nn.Module, the PyTorch model with weights loaded
    '''

    timm_pretrained = (dataset_name == 'imagenet') and ('cutout' not in model_name)
    if 'resnetv2_50x1_bit_distilled' in model_name:
        model = timm.create_model('resnetv2_50x1_bit_distilled', pretrained=timm_pretrained)
    elif 'vit_base_patch16_224' in model_name:
        model = timm.create_model('vit_base_patch16_224', pretrained=timm_pretrained)
    elif 'resmlp_24_distilled_224' in model_name:
        model = timm.create_model('resmlp_24_distilled_224', pretrained=timm_pretrained)
    elif 'mae_finetuned_vit_base' in model_name:
        model = timm.create_model('vit_base_patch16_224', pretrained=False,global_pool='avg')
        timm_pretrained = False
        del model.pretrained_cfg['mean']
        del model.pretrained_cfg['std']
        model.pretrained_cfg['crop_pct'] = 224/256


    # modify the classification head and load model weights
    if not timm_pretrained:
        model.reset_classifier(num_classes=NUM_CLASSES_DICT[dataset_name])
        checkpoint_name = model_name + '_{}.pth'.format(dataset_name)
        checkpoint = torch.load(os.path.join(model_dir,checkpoint_name))
        if 'mae' not in model_name:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            msg = model.load_state_dict(checkpoint['model'],strict=False)
            print(msg)


    return model


def get_data_loader(dataset_name,data_dir,model,batch_size=1,num_img=-1,train=False):

    # get the data loader (possibly for only a subset of the dataset)

    '''
    INPUT:
    dataset_name    str, dataset name. One of ('imagenette','imagenet','cifar','cifar100','svhn','flower102')
    data_dir        str, the directory of data
    model           torch.nn.Module / timm.models, the built model returned by get_model(), which has an attribute default_cfg for data preprocessing
    batch_size      int, batch size. The default value is 1 for per-example inference time evaluation. In practice, a larger batch size is preferred
    num_img         int, number of images used to construct a random image subset. If num_img<0, we return a data loader for the entire dataset
    train           bool, whether to return the training data split

    OUTPUT:
    loader          the PyTorch data loader
    len(dataset)    the size of the dataset
    config          data preprocessing configuration dict
    '''

    ### !!!!!ATTN!!!!! Do not pass a DataParallel instance as `model`.
    ### The DataParallel wrapper makes `default_cfg` invisible to `resolve_data_config` and might return an incompatible data preprocessing pipeline.
    ### You can add the DataParallel wrapper after calling `get_data_loader`.
    if isinstance(model,(torch.nn.DataParallel)):
        model = model.module


    # get dataset
    if dataset_name in ['imagenette','imagenet','flower102']:
        # high-resolution images; use the default image preprocessing (all three models use 224x224 inputs)
        config = resolve_data_config({}, model=model)
        print(config)
        ds_transforms = create_transform(**config)
        split = 'train' if train else 'val'
        dataset_ = datasets.ImageFolder(os.path.join(data_dir,split),ds_transforms)
    elif dataset_name in ['cifar','cifar100','svhn']:
        # low-resolution images; resize them to 224x224 without cropping
        config = resolve_data_config({'crop_pct':1}, model=model)
        ds_transforms = create_transform(**config)
        if dataset_name == 'cifar':
            dataset_ = datasets.CIFAR10(root=data_dir, train=train, download=True, transform=ds_transforms)
        elif dataset_name == 'cifar100':
            dataset_ = datasets.CIFAR100(root=data_dir, train=train, download=True, transform=ds_transforms)
        elif dataset_name == 'svhn':
            split = 'train' if train else 'test'
            dataset_ = datasets.SVHN(root=data_dir, split=split, download=True, transform=ds_transforms)

    # select a random set of test images (when args.num_img>0)
    np.random.seed(233333333) # random seed for selecting test images
    idxs = np.arange(len(dataset_))
    np.random.shuffle(idxs)
    if num_img>0:
        idxs = idxs[:num_img]
    dataset = torch.utils.data.Subset(dataset_, idxs)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,shuffle=train,num_workers=2)

    return loader,len(dataset),config
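A minimal usage sketch for the two helpers above; the checkpoint and dataset names are just examples:

```python
import torch
from utils.setup import get_model, get_data_loader

model = get_model('vit_base_patch16_224_cutout2_128', 'imagenette', './checkpoints')
val_loader, num_img, ds_config = get_data_loader('imagenette', 'data/imagenette',
                                                 model, batch_size=16, train=False)
# wrap in DataParallel only *after* building the data loader (see the ATTN note)
model = torch.nn.DataParallel(model).to('cuda')
```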
--------------------------------------------------------------------------------
/pc_certification.py:
--------------------------------------------------------------------------------
import torch
import torch.backends.cudnn as cudnn

import numpy as np
import os
import argparse
import time
from tqdm import tqdm
import joblib

from utils.setup import get_model,get_data_loader
from utils.defense import gen_mask_set,double_masking_precomputed,certify_precomputed

parser = argparse.ArgumentParser()
parser.add_argument("--model_dir",default='checkpoints',type=str,help="directory of checkpoints")
parser.add_argument('--data_dir', default='data', type=str,help="directory of data")
parser.add_argument('--dataset', default='imagenette',type=str,choices=('imagenette','imagenet','cifar','cifar100','svhn','flower102'),help="dataset")
parser.add_argument("--model",default='vit_base_patch16_224',type=str,help="model name")
parser.add_argument("--num_img",default=-1,type=int,help="number of randomly selected images for this experiment (-1: use all images)")
parser.add_argument("--mask_stride",default=-1,type=int,help="mask stride s (square patch; conflicts with num_mask)")
parser.add_argument("--num_mask",default=-1,type=int,help="number of masks along one dimension (square patch; conflicts with mask_stride)")
parser.add_argument("--patch_size",default=32,type=int,help="size of the adversarial patch (square patch)")
parser.add_argument("--pa",default=-1,type=int,help="size of the adversarial patch (first axis; for rectangular patches)")
parser.add_argument("--pb",default=-1,type=int,help="size of the adversarial patch (second axis; for rectangular patches)")
parser.add_argument("--dump_dir",default='dump',type=str,help='directory to dump two-mask predictions')
parser.add_argument("--override",action='store_true',help='override the dumped file')

args = parser.parse_args()
DATASET = args.dataset
MODEL_DIR = os.path.join('.',args.model_dir)
DATA_DIR = os.path.join(args.data_dir,DATASET)
DUMP_DIR = os.path.join('.',args.dump_dir)
if not os.path.exists(DUMP_DIR):
    os.mkdir(DUMP_DIR)

MODEL_NAME = args.model
NUM_IMG = args.num_img

# get model and data loader
model = get_model(MODEL_NAME,DATASET,MODEL_DIR)
val_loader,NUM_IMG,ds_config = get_data_loader(DATASET,DATA_DIR,model,batch_size=16,num_img=NUM_IMG,train=False)

device = 'cuda'
model = model.to(device)
model.eval()
cudnn.benchmark = True

# generate the mask set
mask_list,MASK_SIZE,MASK_STRIDE = gen_mask_set(args,ds_config)

# the computation of two-mask predictions is expensive; we dump (or reuse the dumped) two-mask predictions.
SUFFIX = '_two_mask_{}_{}_m{}_s{}_{}.z'.format(DATASET,MODEL_NAME,MASK_SIZE,MASK_STRIDE,NUM_IMG)
if not args.override and os.path.exists(os.path.join(DUMP_DIR,'prediction_map_list'+SUFFIX)):
    print('loading two-mask predictions')
    prediction_map_list = joblib.load(os.path.join(DUMP_DIR,'prediction_map_list'+SUFFIX))
    orig_prediction_list = joblib.load(os.path.join(DUMP_DIR,'orig_prediction_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))
    label_list = joblib.load(os.path.join(DUMP_DIR,'label_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))
else:
    print('computing two-mask predictions')
    prediction_map_list = []
    confidence_map_list = []
    label_list = []
    orig_prediction_list = []
    for data,labels in tqdm(val_loader):
        data = data.to(device)
        labels = labels.numpy()
        num_img = data.shape[0]
        num_mask = len(mask_list)

        # two-mask predictions
        prediction_map = np.zeros([num_img,num_mask,num_mask],dtype=int)
        confidence_map = np.zeros([num_img,num_mask,num_mask])
        for i,mask in enumerate(mask_list):
            for j in range(i,num_mask):
                mask2 = mask_list[j]
                masked_output = model(torch.where(torch.logical_and(mask,mask2),data,torch.tensor(0.).cuda()))
                masked_output = torch.nn.functional.softmax(masked_output,dim=1)
                masked_conf, masked_pred = masked_output.max(1)
                masked_conf = masked_conf.detach().cpu().numpy()
                confidence_map[:,i,j] = masked_conf
                masked_pred = masked_pred.detach().cpu().numpy()
                prediction_map[:,i,j] = masked_pred

        # vanilla predictions
        clean_output = model(data)
        clean_conf, clean_pred = clean_output.max(1)
        clean_pred = clean_pred.detach().cpu().numpy()
        orig_prediction_list.append(clean_pred)
        prediction_map_list.append(prediction_map)
        confidence_map_list.append(confidence_map)
        label_list.append(labels)

    prediction_map_list = np.concatenate(prediction_map_list)
    confidence_map_list = np.concatenate(confidence_map_list)
    orig_prediction_list = np.concatenate(orig_prediction_list)
    label_list = np.concatenate(label_list)

    joblib.dump(confidence_map_list,os.path.join(DUMP_DIR,'confidence_map_list'+SUFFIX))
    joblib.dump(prediction_map_list,os.path.join(DUMP_DIR,'prediction_map_list'+SUFFIX))
    joblib.dump(orig_prediction_list,os.path.join(DUMP_DIR,'orig_prediction_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))
    joblib.dump(label_list,os.path.join(DUMP_DIR,'label_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))


clean_corr = 0
robust = 0
orig_corr = 0
for i,(prediction_map,label,orig_pred) in enumerate(zip(prediction_map_list,label_list,orig_prediction_list)):
    prediction_map = prediction_map + prediction_map.T - np.diag(np.diag(prediction_map)) # generate a symmetric matrix from the upper-triangular matrix
    robust += certify_precomputed(prediction_map,label)
    clean_corr += double_masking_precomputed(prediction_map) == label
    orig_corr += orig_pred == label

print("------------------------------")
print("Certified robust accuracy:",robust/NUM_IMG)
print("Clean accuracy with defense:",clean_corr/NUM_IMG)
print("Clean accuracy without defense:",orig_corr/NUM_IMG)
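The two precomputed-defense helpers imported from `utils/defense.py` are not shown in this dump. Below is a plausible sketch of what they compute over the symmetrized two-mask prediction matrix (ours, for orientation only):

```python
import numpy as np
from collections import Counter

def certify_precomputed_sketch(prediction_map, label):
    # certified iff *every* two-mask prediction equals the true label
    # (two-mask correctness, per Theorem 1 in the paper)
    return np.all(prediction_map == label)

def double_masking_precomputed_sketch(prediction_map):
    first = np.diag(prediction_map)                  # one-mask predictions
    majority = Counter(first.tolist()).most_common(1)[0][0]
    if np.all(first == majority):
        return majority                              # Case I: unanimous
    for i in np.where(first != majority)[0]:         # probe each disagreer
        if np.all(prediction_map[i] == first[i]):
            return first[i]                          # Case II
    return majority                                  # Case III: majority vote
```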
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PatchCleanser: Certifiably Robust Defense against Adversarial Patches for Any Image Classifier

By [Chong Xiang](http://xiangchong.xyz/), [Saeed Mahloujifar](https://smahloujifar.github.io/), [Prateek Mittal](https://www.princeton.edu/~pmittal/)

Code for "[PatchCleanser: Certifiably Robust Defense against Adversarial Patches for Any Image Classifier](https://arxiv.org/abs/2108.09135)" in USENIX Security Symposium 2022.

Update 04/2023: (1) Check out this [35-min presentation](https://drive.google.com/file/d/1LGBKTxMcapQculKZ543SY-bAQ8O8a8ei/view?usp=sharing) and this [short survey paper](https://www.ndss-symposium.org/ndss-paper/auto-draft-379/) for a quick introduction to certifiable patch defenses. (2) Added support for MAE.

Update 08/2022: added a note on why "there is no attack code in this repo" [here](./misc/notes_on_robustness_evaluation.md).

Update 04/2022: We earned all badges (available, functional, reproduced) in the USENIX artifact evaluation! The camera-ready version is available [here](https://www.usenix.org/conference/usenixsecurity22/presentation/xiang). We released the extended technical report [here](https://arxiv.org/abs/2108.09135).

Update 03/2022: released the [**leaderboard** for certifiably robust image classification against adversarial patches](https://github.com/inspire-group/patch-defense-leaderboard).

Update 02/2022: made some minor changes to the model loading script and updated the [pretrained weights](https://drive.google.com/drive/folders/1Ewks-NgJHDlpeAaGInz_jZ6iczcYNDlN?usp=sharing).

![defense overview pipeline](./assets/arch.png)

**Takeaways**:

1. We design a *certifiably robust defense* against adversarial patches that is *compatible with any state-of-the-art image classifier*.
2. We achieve clean accuracy that is comparable to state-of-the-art image classifiers and improve certified robust accuracy by a large margin.
3. We visualize our defense performance on 1000-class ImageNet below!

![performance comparison](./assets/comparison.png)

#### Check out our [paper list for adversarial patch research](https://github.com/xiangchong1/adv-patch-paper-list) and [leaderboard for certifiably robust image classification](https://github.com/inspire-group/patch-defense-leaderboard) for fun!

## Requirements

Experiments were done with PyTorch 1.7.0 and timm 0.4.12. The complete list of required packages is available in `requirement.txt` and can be installed with `pip install -r requirement.txt`. The code should be compatible with newer versions of these packages. Update 04/2023: tested with `torch==1.13.1` and `timm==0.6.13`; the code should work fine.

## Files

```shell
├── README.md                       # this file
├── requirement.txt                 # required packages
├── example_cmd.sh                  # example commands to run the code
|
├── pc_certification.py             # PatchCleanser: certify robustness via two-mask correctness
├── pc_clean_acc.py                 # PatchCleanser: evaluate clean accuracy and per-example inference time
|
├── vanilla_clean_acc.py            # undefended vanilla models: evaluate clean accuracy and per-example inference time
├── train_model.py                  # train undefended vanilla models for different datasets
|
├── utils
|   ├── setup.py                    # utils for constructing models and data loaders
|   ├── defense.py                  # utils for PatchCleanser defenses
|   └── cutout.py                   # utils for masked model training
|
├── misc
|   ├── reproducibility.md          # detailed instructions for reproducing paper results
|   ├── pc_mr.py                    # script for Minority Reports (Figure 9)
|   └── pc_multiple.py              # script for multiple patch shapes and multiple patches (Table 4)
|
├── data
|   ├── imagenet                    # data directory for imagenet
|   ├── imagenette                  # data directory for imagenette
|   ├── cifar                       # data directory for cifar-10
|   ├── cifar100                    # data directory for cifar-100
|   ├── flower102                   # data directory for flower102
|   └── svhn                        # data directory for svhn
|
└── checkpoints                     # directory for checkpoints
    ├── README.md                   # details of checkpoints
    └── ...                         # model checkpoints
```

## Datasets

- [ImageNet](https://image-net.org/download.php) (ILSVRC2012)
- [ImageNette](https://github.com/fastai/imagenette) ([Full size](https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz))
- [CIFAR-10/CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar.html)
- [Oxford Flower-102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/)
- [SVHN](http://ufldl.stanford.edu/housenumbers/)

## Usage

- See **Files** for details of each file.
- Download data in **Datasets** to `data/`.
- (optional) Download checkpoints from the Google Drive [link](https://drive.google.com/drive/folders/1Ewks-NgJHDlpeAaGInz_jZ6iczcYNDlN?usp=sharing) and move them to `checkpoints`.
- (optional) Download pre-computed two-mask predictions for ImageNet from the Google Drive [link](https://drive.google.com/drive/folders/1Ewks-NgJHDlpeAaGInz_jZ6iczcYNDlN?usp=sharing) and move them to `dump`. Computing two-mask predictions for other datasets shouldn't take too long.
- See [`example_cmd.sh`](example_cmd.sh) for example commands to run the code.
- See [`misc/reproducibility.md`](./misc/reproducibility.md) for instructions to reproduce all results in the main body of the paper.

If anything is unclear, please open an issue or contact Chong Xiang (cxiang@princeton.edu).
## Citations

If you find our work useful in your research, please consider citing:

```tex
@inproceedings{xiang2022patchcleanser,
  title={PatchCleanser: Certifiably Robust Defense against Adversarial Patches for Any Image Classifier},
  author={Xiang, Chong and Mahloujifar, Saeed and Mittal, Prateek},
  booktitle = {31st {USENIX} Security Symposium ({USENIX} Security)},
  year={2022}
}
```

--------------------------------------------------------------------------------
/misc/pc_multiple.py:
--------------------------------------------------------------------------------
import torch
import torch.backends.cudnn as cudnn

import numpy as np
import os
import argparse
import time
from tqdm import tqdm
import joblib

from utils.setup import get_model,get_data_loader
from utils.defense import gen_mask_set,double_masking_precomputed,certify_precomputed

parser = argparse.ArgumentParser()
parser.add_argument("--model_dir",default='checkpoints',type=str,help="directory of checkpoints")
parser.add_argument('--data_dir', default='data', type=str,help="directory of data")
parser.add_argument('--dataset', default='imagenette',type=str,choices=('imagenette','imagenet','cifar','cifar100','svhn','flower102'),help="dataset")
parser.add_argument("--model",default='vit_base_patch16_224',type=str,help="model name")
parser.add_argument("--num_img",default=-1,type=int,help="number of randomly selected images for this experiment (-1: use all images)")
parser.add_argument("--mask_stride",default=-1,type=int,help="mask stride s (square patch; conflicts with num_mask)")
parser.add_argument("--num_mask",default=-1,type=int,help="number of masks along one dimension (square patch; conflicts with mask_stride)")
parser.add_argument("--patch_size",default=32,type=int,help="size of the adversarial patch (square patch)")
parser.add_argument("--pa",default=-1,type=int,help="size of the adversarial patch (first axis; for rectangular patches)")
parser.add_argument("--pb",default=-1,type=int,help="size of the adversarial patch (second axis; for rectangular patches)")
parser.add_argument("--dump_dir",default='dump',type=str,help='directory to dump two-mask predictions')
parser.add_argument("--override",action='store_true',help='override the dumped file')
parser.add_argument("--mode",choices=('shape','twopatch'),type=str,help='either analyze multiple shapes or two patches')

args = parser.parse_args()
DATASET = args.dataset
MODEL_DIR = os.path.join('.',args.model_dir)
DATA_DIR = os.path.join(args.data_dir,DATASET)
DUMP_DIR = os.path.join('.',args.dump_dir)
if not os.path.exists(DUMP_DIR):
    os.mkdir(DUMP_DIR)

MODEL_NAME = args.model
NUM_IMG = args.num_img

# get model and data loader
model = get_model(MODEL_NAME,DATASET,MODEL_DIR)
val_loader,NUM_IMG,ds_config = get_data_loader(DATASET,DATA_DIR,model,batch_size=16,num_img=NUM_IMG,train=False)

device = 'cuda'
model = model.to(device)
model.eval()
cudnn.benchmark = True

#####################################################################################################################################
# generate the mask set
mask_list = []
if args.mode == 'shape':
    for pa,pb in [(5,224),(12,83),(23,38),(39,20),(84,12),(224,5)]:
    #for pa,pb in [(9,224),(16,101),(23,60),(32,42),(42,32),(60,23),(101,16),(224,9)]:
        args.pa = pa
        args.pb = pb
        tmp,MASK_SIZE,MASK_STRIDE = gen_mask_set(args,ds_config)
        mask_list += tmp
    SUFFIX = '_two_mask_shape_set_{}_{}_p{}_s{}_{}.z'.format(DATASET,MODEL_NAME,args.patch_size,MASK_STRIDE,NUM_IMG)
elif args.mode == 'twopatch':
    tmp,MASK_SIZE,MASK_STRIDE = gen_mask_set(args,ds_config)
    mask_list = [torch.logical_and(mask1,mask2) for mask1 in tmp for mask2 in tmp]
    SUFFIX = '_two_2mask_{}_{}_p{}_s{}_{}.z'.format(DATASET,MODEL_NAME,args.patch_size,MASK_STRIDE,NUM_IMG)
else:
    raise NotImplementedError
#####################################################################################################################################

print(len(mask_list))
# the computation of two-mask predictions is expensive; we dump (or reuse the dumped) two-mask predictions.
if not args.override and os.path.exists(os.path.join(DUMP_DIR,'prediction_map_list'+SUFFIX)):
    print('loading two-mask predictions')
    prediction_map_list = joblib.load(os.path.join(DUMP_DIR,'prediction_map_list'+SUFFIX))
    orig_prediction_list = joblib.load(os.path.join(DUMP_DIR,'orig_prediction_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))
    label_list = joblib.load(os.path.join(DUMP_DIR,'label_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))
else:
    print('computing two-mask predictions')
    prediction_map_list = []
    confidence_map_list = []
    label_list = []
    orig_prediction_list = []
    for data,labels in tqdm(val_loader):
        data = data.to(device)
        labels = labels.numpy()
        num_img = data.shape[0]
        num_mask = len(mask_list)

        # two-mask predictions
        prediction_map = np.zeros([num_img,num_mask,num_mask],dtype=int)
        for i,mask in enumerate(mask_list):
            for j in range(i,num_mask):
                mask2 = mask_list[j]
                masked_output = model(torch.where(torch.logical_and(mask,mask2),data,torch.tensor(0.).cuda()))
                _, masked_pred = masked_output.max(1)
                masked_pred = masked_pred.detach().cpu().numpy()
                prediction_map[:,i,j] = masked_pred

        # vanilla predictions
        clean_output = model(data)
        clean_conf, clean_pred = clean_output.max(1)
        clean_pred = clean_pred.detach().cpu().numpy()
        orig_prediction_list.append(clean_pred)
        prediction_map_list.append(prediction_map)
        label_list.append(labels)

    prediction_map_list = np.concatenate(prediction_map_list)
    orig_prediction_list = np.concatenate(orig_prediction_list)
    label_list = np.concatenate(label_list)

    joblib.dump(prediction_map_list,os.path.join(DUMP_DIR,'prediction_map_list'+SUFFIX))
    joblib.dump(orig_prediction_list,os.path.join(DUMP_DIR,'orig_prediction_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))
    joblib.dump(label_list,os.path.join(DUMP_DIR,'label_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))


clean_corr = 0
robust = 0
orig_corr = 0
for i,(prediction_map,label,orig_pred) in enumerate(zip(prediction_map_list,label_list,orig_prediction_list)):
    prediction_map = prediction_map + prediction_map.T - np.diag(np.diag(prediction_map)) # generate a symmetric matrix from the upper-triangular matrix
    robust += certify_precomputed(prediction_map,label)
    clean_corr += double_masking_precomputed(prediction_map) == label
    orig_corr += orig_pred == label

print("------------------------------")
print("Provable robust accuracy:",robust/NUM_IMG)
print("Clean accuracy with defense:",clean_corr/NUM_IMG)
print("Clean accuracy without defense:",orig_corr/NUM_IMG)

--------------------------------------------------------------------------------
/misc/pc_mr.py:
--------------------------------------------------------------------------------
# you are encouraged to read `notes_on_mr.md` first :)
import torch
import torch.backends.cudnn as cudnn

import numpy as np
import os
import argparse
import time
from tqdm import tqdm
import joblib

from utils.setup import get_model,get_data_loader
from utils.defense import gen_mask_set,double_masking_precomputed,certify_precomputed

parser = argparse.ArgumentParser()
parser.add_argument("--model_dir",default='checkpoints',type=str,help="directory of checkpoints")
parser.add_argument('--data_dir', default='data', type=str,help="directory of data")
parser.add_argument('--dataset', default='imagenette',type=str,choices=('imagenette','imagenet','cifar','cifar100','svhn','flower102'),help="dataset")
parser.add_argument("--model",default='vit_base_patch16_224',type=str,help="model name")
parser.add_argument("--num_img",default=-1,type=int,help="number of randomly selected images for this experiment (-1: use all images)")
parser.add_argument("--mask_stride",default=-1,type=int,help="mask stride s (square patch; conflicts with num_mask)")
parser.add_argument("--num_mask",default=-1,type=int,help="number of masks along one dimension (square patch; conflicts with mask_stride)")
parser.add_argument("--patch_size",default=32,type=int,help="size of the adversarial patch (square patch)")
parser.add_argument("--pa",default=-1,type=int,help="size of the adversarial patch (first axis; for rectangular patches)")
parser.add_argument("--pb",default=-1,type=int,help="size of the adversarial patch (second axis; for rectangular patches)")
parser.add_argument("--dump_dir",default='dump',type=str,help='directory to dump masked predictions')
parser.add_argument("--override",action='store_true',help='override the dumped file')

args = parser.parse_args()
DATASET = args.dataset
MODEL_DIR = os.path.join('.',args.model_dir)
DATA_DIR = os.path.join(args.data_dir,DATASET)
DUMP_DIR = os.path.join('.',args.dump_dir)
if not os.path.exists(DUMP_DIR):
    os.mkdir(DUMP_DIR)

MODEL_NAME = args.model
NUM_IMG = args.num_img

# get model and data loader
model = get_model(MODEL_NAME,DATASET,MODEL_DIR)
val_loader,NUM_IMG,ds_config = get_data_loader(DATASET,DATA_DIR,model,batch_size=16,num_img=NUM_IMG,train=False)

device = 'cuda'
model = model.to(device)
model.eval()
cudnn.benchmark = True
num_classes = 1000 if args.dataset == 'imagenet' else 10

# generate the mask set
mask_list,MASK_SIZE,MASK_STRIDE = gen_mask_set(args,ds_config)
#print(len(mask_list))
#args.num_mask = int((len(mask_list))**0.5)
# the computation of masked predictions is expensive; we dump (or reuse previously dumped) one-mask or two-mask predictions.
SUFFIX = '_one_mask_{}_{}_m{}_s{}_{}.z'.format(DATASET,MODEL_NAME,MASK_SIZE,MASK_STRIDE,NUM_IMG)
SUFFIX2 = '_two_mask_{}_{}_m{}_s{}_{}.z'.format(DATASET,MODEL_NAME,MASK_SIZE,MASK_STRIDE,NUM_IMG)
print(os.path.join(DUMP_DIR,'confidence_map_list'+SUFFIX2))
if not args.override and os.path.exists(os.path.join(DUMP_DIR,'prediction_map_list'+SUFFIX)):
    print('loading one-mask predictions')
    confidence_map_list = joblib.load(os.path.join(DUMP_DIR,'confidence_map_list'+SUFFIX))
    prediction_map_list = joblib.load(os.path.join(DUMP_DIR,'prediction_map_list'+SUFFIX))
    orig_prediction_list = joblib.load(os.path.join(DUMP_DIR,'orig_prediction_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))
    label_list = joblib.load(os.path.join(DUMP_DIR,'label_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))
elif not args.override and os.path.exists(os.path.join(DUMP_DIR,'confidence_map_list'+SUFFIX2)):
    print('loading two-mask predictions')
    confidence_map_list = joblib.load(os.path.join(DUMP_DIR,'confidence_map_list'+SUFFIX2))
    prediction_map_list = joblib.load(os.path.join(DUMP_DIR,'prediction_map_list'+SUFFIX2))
    print('converting two-mask predictions to one-mask predictions')
    confidence_map_list = np.diagonal(confidence_map_list,axis1=1,axis2=2)
    prediction_map_list = np.diagonal(prediction_map_list,axis1=1,axis2=2)
    orig_prediction_list = joblib.load(os.path.join(DUMP_DIR,'orig_prediction_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))
    label_list = joblib.load(os.path.join(DUMP_DIR,'label_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))
else:
    print('computing one-mask predictions')
    prediction_map_list = []
    confidence_map_list = []
    label_list = []
    orig_prediction_list = []
    for data,labels in tqdm(val_loader):
        data = data.to(device)
        labels = labels.numpy()
        num_img = data.shape[0]

        # one-mask predictions
        num_mask = len(mask_list)
        prediction_map = np.zeros([num_img,num_mask],dtype=int)
        confidence_map = np.zeros([num_img,num_mask])

        for i,mask in enumerate(mask_list):

            masked_output = model(torch.where(mask,data,torch.tensor(0.).cuda()))
            masked_output = torch.nn.functional.softmax(masked_output,dim=1)
            masked_conf, masked_pred = masked_output.max(1)
            masked_pred = masked_pred.detach().cpu().numpy()
            masked_conf = masked_conf.detach().cpu().numpy()

            prediction_map[:,i] = masked_pred
            confidence_map[:,i] = masked_conf

        # vanilla predictions
        clean_output = model(data)
        clean_conf, clean_pred = clean_output.max(1)
        clean_pred = clean_pred.detach().cpu().numpy()
        orig_prediction_list.append(clean_pred)
        prediction_map_list.append(prediction_map)
        confidence_map_list.append(confidence_map)
        label_list.append(labels)

    prediction_map_list = np.concatenate(prediction_map_list)
    confidence_map_list = np.concatenate(confidence_map_list)
    orig_prediction_list = np.concatenate(orig_prediction_list)
    label_list = np.concatenate(label_list)

    joblib.dump(confidence_map_list,os.path.join(DUMP_DIR,'confidence_map_list'+SUFFIX))
    joblib.dump(prediction_map_list,os.path.join(DUMP_DIR,'prediction_map_list'+SUFFIX))
    joblib.dump(orig_prediction_list,os.path.join(DUMP_DIR,'orig_prediction_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))
    joblib.dump(label_list,os.path.join(DUMP_DIR,'label_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG)))


def provable_detection(prediction_map,confidence_map,label,orig_pred,tau):
    if orig_pred != label: # clean prediction is incorrect
        return 0,0 # 0 for incorrect clean prediction
    clean = 1
    provable = 1
    if np.any(np.logical_or(prediction_map!=label,confidence_map<tau)): # some masked prediction is incorrect or under-confident -> detection cannot be certified
        provable = 0
    if np.any(np.logical_and(prediction_map!=label,confidence_map>tau)): # a confident incorrect masked prediction raises a (false) alert on this clean image
        clean = 0
    return provable,clean


clean_list = []
robust_list = []
for tau in np.arange(0.,1.,0.05):
    print(tau)
    clean_corr = 0
    robust_cnt = 0
    vanilla_corr = 0
    for i,(prediction_map,confidence_map,label,orig_pred) in enumerate(zip(prediction_map_list,confidence_map_list,label_list,orig_prediction_list)):
        #print(confidence_map)
        provable,clean = provable_detection(prediction_map,confidence_map,label,orig_pred,tau)
        robust_cnt += provable
        clean_corr += clean
        vanilla_corr += orig_pred==label
    robust_list.append(robust_cnt/NUM_IMG)
    clean_list.append(clean_corr/NUM_IMG)
    print("Provable robust accuracy ({}):".format(tau),robust_cnt/NUM_IMG)
    print("Clean accuracy with defense:",clean_corr/NUM_IMG)
    print("Clean accuracy without defense:",vanilla_corr/NUM_IMG)
    print()
print('clean_list=', [100*x for x in clean_list])
print('robust_list=',[100*x for x in robust_list])
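A tiny worked example of the `provable_detection` semantics above, with illustrative values (`tau` trades certified detection off against clean accuracy):

```python
import numpy as np

# three one-mask predictions for one image whose true label is 0
prediction_map = np.array([0, 0, 1])        # one masked prediction is wrong...
confidence_map = np.array([0.9, 0.8, 0.4])  # ...but with low confidence

provable, clean = provable_detection(prediction_map, confidence_map,
                                     label=0, orig_pred=0, tau=0.5)
# provable=0: the incorrect prediction breaks certified detection;
# clean=1: its confidence is below tau, so no confident false alert is raised
print(provable, clean)  # 0 1
```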
--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
#######################################################################################
# Adapted from https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
#######################################################################################

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import copy
from tqdm import tqdm
import random
import argparse
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from utils.cutout import Cutout

import timm

parser = argparse.ArgumentParser()
parser.add_argument("--model_dir",default='checkpoints',type=str)
parser.add_argument("--data_dir",default='data',type=str)
parser.add_argument("--dataset",default='imagenette',type=str)
parser.add_argument("--model",default='vit_base_patch16_224',type=str)
parser.add_argument("--epoch",default=10,type=int)
parser.add_argument("--lr",default=0.001,type=float)
parser.add_argument("--cutout_size",default=128,type=int)
parser.add_argument("--resume",action='store_true')
parser.add_argument("--n_holes",default=2,type=int)
parser.add_argument("--cutout",action='store_true')
args = parser.parse_args()

MODEL_DIR = os.path.join('.',args.model_dir)
DATA_DIR = os.path.join(args.data_dir,args.dataset)

if not os.path.exists(MODEL_DIR):
    os.mkdir(MODEL_DIR)

n_holes = args.n_holes
cutout_size = args.cutout_size
if args.cutout:
    model_name = args.model + '_cutout{}_{}_{}.pth'.format(n_holes,cutout_size,args.dataset)
else:
    model_name = args.model + '_{}.pth'.format(args.dataset)

device = 'cuda'

if 'vit_base_patch16_224' in model_name:
    model = timm.create_model('vit_base_patch16_224', pretrained=True)
elif 'resnetv2_50x1_bit_distilled' in model_name:
    model = timm.create_model('resnetv2_50x1_bit_distilled', pretrained=True)
elif 'resmlp_24_distilled_224' in model_name:
    model = timm.create_model('resmlp_24_distilled_224', pretrained=True)


# get data loaders
if args.dataset in ['imagenette','flower102']:
    config = resolve_data_config({}, model=model)
    ds_transforms = create_transform(**config)
    if args.cutout:
        ds_transforms.transforms.append(Cutout(n_holes=n_holes, length=cutout_size))
    train_dataset = datasets.ImageFolder(os.path.join(DATA_DIR,'train'),ds_transforms)
    val_dataset = datasets.ImageFolder(os.path.join(DATA_DIR,'val'),ds_transforms)
    num_classes = 10 if args.dataset=='imagenette' else 102
elif args.dataset in ['cifar','cifar100','svhn']:
    config = resolve_data_config({'crop_pct':1}, model=model) # to decide
    ds_transforms = create_transform(**config)
    if args.cutout:
        ds_transforms.transforms.append(Cutout(n_holes=n_holes, length=cutout_size))
    if args.dataset == 'cifar':
        train_dataset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=ds_transforms)
        val_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=ds_transforms)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(root=DATA_DIR, train=True, download=True, transform=ds_transforms)
        val_dataset = datasets.CIFAR100(root=DATA_DIR, train=False, download=True, transform=ds_transforms)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_dataset = datasets.SVHN(root=DATA_DIR, split='train', download=True, transform=ds_transforms)
        val_dataset = datasets.SVHN(root=DATA_DIR, split='test', download=True, transform=ds_transforms)
        num_classes = 10
print(ds_transforms)


image_datasets = {'train':train_dataset,'val':val_dataset}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}


train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64,shuffle=True,num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64,shuffle=False,num_workers=4)

dataloaders = {'train':train_loader,'val':val_loader}


print('device:',device)

def train_model(model, criterion, optimizer, scheduler, num_epochs=20, mask=False):

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in tqdm(range(num_epochs)):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history only if in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    if isinstance(outputs,tuple):
                        outputs = (outputs[0]+outputs[1])/2
                        #outputs = outputs[0]

                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val':# and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                print('saving...')
                torch.save({
                    'epoch': epoch,
                    'state_dict': best_model_wts,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict()
                    }, os.path.join(MODEL_DIR,model_name))

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


if args.dataset!='imagenet':
    model.reset_classifier(num_classes=num_classes)
model = torch.nn.DataParallel(model)
model = model.to(device)
criterion = nn.CrossEntropyLoss()


optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# https://pytorch.org/tutorials/beginner/saving_loading_models.html
if args.resume:
    print('restoring model from checkpoint...')
    checkpoint = torch.load(os.path.join(MODEL_DIR,model_name))
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)
    # https://discuss.pytorch.org/t/code-that-loads-sgd-fails-to-load-adam-state-to-gpu/61783/3
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    exp_lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])


model = train_model(model, criterion, optimizer,
                    exp_lr_scheduler, num_epochs=args.epoch)

--------------------------------------------------------------------------------
/misc/pc_mr_experimental.py:
--------------------------------------------------------------------------------
# you are encouraged to read `notes_on_mr.md` first :)
import torch
import torch.backends.cudnn as cudnn

import numpy as np
import os
import argparse
import time
from tqdm import tqdm
import joblib

from utils.setup import get_model,get_data_loader

parser = argparse.ArgumentParser()
parser.add_argument("--model_dir",default='checkpoints',type=str,help="directory of checkpoints")
parser.add_argument('--data_dir', default='data', type=str,help="directory of data")
type=str,help="directory of data") 17 | parser.add_argument('--dataset', default='imagenette',type=str,choices=('imagenette','imagenet','cifar','cifar100','svhn','flower102'),help="dataset") 18 | parser.add_argument("--model",default='vit_base_patch16_224',type=str,help="model name") 19 | parser.add_argument("--num_img",default=-1,type=int,help="number of randomly selected images for this experiment (-1: use all images)") 20 | parser.add_argument("--mask_stride",default=-1,type=int,help="mask stride s (square patch; conflicts with num_mask)") 21 | parser.add_argument("--num_mask",default=-1,type=int,help="number of masks in one dimension (square patch; conflicts with mask_stride)") 22 | parser.add_argument("--patch_size",default=32,type=int,help="size of the adversarial patch (square patch)") 23 | parser.add_argument("--pa",default=-1,type=int,help="size of the adversarial patch (first axis; for rectangle patch)") 24 | parser.add_argument("--pb",default=-1,type=int,help="size of the adversarial patch (second axis; for rectangle patch)") 25 | parser.add_argument("--dump_dir",default='dump/dump_mr',type=str,help='directory to dump one-mask predictions') 26 | parser.add_argument("--override",action='store_true',help='override dumped file') 27 | parser.add_argument("--mr",default=2,type=int,help="we will use (mr+1)x(mr+1) predictions to generate one vote (the default is 2 since the original MR uses 3x3 grids)") 28 | 29 | args = parser.parse_args() 30 | DATASET = args.dataset 31 | MODEL_DIR = os.path.join('.',args.model_dir) 32 | DATA_DIR = os.path.join(args.data_dir,DATASET) 33 | DUMP_DIR = os.path.join('.',args.dump_dir) 34 | if not os.path.exists(DUMP_DIR): 35 | os.makedirs(DUMP_DIR) # makedirs (not mkdir): the default dump_dir 'dump/dump_mr' is nested 36 | 37 | MODEL_NAME = args.model 38 | NUM_IMG = args.num_img 39 | 40 | #get model and data loader 41 | model = get_model(MODEL_NAME,DATASET,MODEL_DIR) 42 | val_loader,NUM_IMG,ds_config = get_data_loader(DATASET,DATA_DIR,model,batch_size=16,num_img=NUM_IMG,train=False) 43 | 44 | device = 'cuda' 45 | model = model.to(device) 46 | model.eval() 47 | cudnn.benchmark = True 48 | num_classes = {'imagenet':1000,'cifar100':100,'flower102':102}.get(args.dataset,10) # class count per dataset (the remaining options have 10 classes) 49 | 50 | def gen_mask_set_mr(args,ds_config,mr=2): 51 | 52 | # generate mask set 53 | #assert args.mask_stride * args.num_mask < 0 #can only set either mask_stride or num_mask 54 | assert args.mask_stride > 0 # this experimental script requires an explicit mask stride 55 | IMG_SIZE = (ds_config['input_size'][1],ds_config['input_size'][2]) 56 | 57 | if args.pa>0 and args.pb>0: #rectangle patch 58 | PATCH_SIZE = (args.pa,args.pb) 59 | else: #square patch 60 | PATCH_SIZE = (args.patch_size,args.patch_size) 61 | 62 | if args.mask_stride>0: #use specified mask stride 63 | MASK_STRIDE = (args.mask_stride,args.mask_stride) 64 | 65 | # calculate mask size 66 | 67 | MASK_SIZE = (min(PATCH_SIZE[0]+MASK_STRIDE[0]*(mr+1)-1,IMG_SIZE[0]),min(PATCH_SIZE[1]+MASK_STRIDE[1]*(mr+1)-1,IMG_SIZE[1])) 68 | 69 | mask_list = [] 70 | idx_list1 = list(range(0,IMG_SIZE[0] - MASK_SIZE[0] + 1,MASK_STRIDE[0])) 71 | if (IMG_SIZE[0] - MASK_SIZE[0])%MASK_STRIDE[0]!=0: 72 | idx_list1.append(IMG_SIZE[0] - MASK_SIZE[0]) 73 | 74 | idx_list1 = [-j*MASK_STRIDE[0] for j in range(-mr,0)]+idx_list1+[IMG_SIZE[0] - MASK_SIZE[0]+j*MASK_STRIDE[0] for j in range(1,mr+1)] 75 | 76 | idx_list2 = list(range(0,IMG_SIZE[1] - MASK_SIZE[1] + 1,MASK_STRIDE[1])) 77 | if (IMG_SIZE[1] - MASK_SIZE[1])%MASK_STRIDE[1]!=0: 78 | idx_list2.append(IMG_SIZE[1] - MASK_SIZE[1]) 79 | idx_list2 = [-j*MASK_STRIDE[1] for j in range(-mr,0)]+idx_list2+[IMG_SIZE[1] - MASK_SIZE[1]+j*MASK_STRIDE[1] for j in 
range(1,mr+1)] 80 | 81 | for x in idx_list1: 82 | for y in idx_list2: 83 | mask = torch.ones([1,1,IMG_SIZE[0],IMG_SIZE[1]],dtype=bool).cuda() 84 | mask[...,max(x,0):min(x+MASK_SIZE[0],IMG_SIZE[0]),max(y,0):min(y+MASK_SIZE[1],IMG_SIZE[1])] = False 85 | mask_list.append(mask) 86 | return mask_list,MASK_SIZE,MASK_STRIDE 87 | 88 | 89 | 90 | 91 | # generate the mask set 92 | mask_list,MASK_SIZE,MASK_STRIDE = gen_mask_set_mr(args,ds_config,mr=args.mr) 93 | print(len(mask_list)) 94 | args.num_mask = int((len(mask_list))**0.5) 95 | 96 | 97 | 98 | def process_minority_report(prediction_map,confidence_map,mr=2): 99 | num_img = prediction_map.shape[0] 100 | num_pred = prediction_map.shape[1] 101 | voting_grid_pred = np.zeros([num_img,num_pred-mr,num_pred-mr],dtype=int) 102 | voting_grid_conf = np.zeros([num_img,num_pred-mr,num_pred-mr]) 103 | for a in range(num_img): 104 | for i in range(num_pred-mr): 105 | for j in range(num_pred-mr): 106 | confidence_vec = np.sum(confidence_map[a,i:i+mr+1,j:j+mr+1],axis=(0,1)) 107 | if mr > 0: 108 | confidence_vec -= np.min(confidence_map[a,i:i+mr+1,j:j+mr+1],axis=(0,1)) # discard the per-class minimum over the (mr+1)x(mr+1) window 109 | confidence_vec /= (mr+1)**2 110 | pred = np.argmax(confidence_vec) 111 | conf = confidence_vec[pred] 112 | voting_grid_pred[a,i,j]=pred 113 | voting_grid_conf[a,i,j]=conf 114 | return voting_grid_pred,voting_grid_conf 115 | 116 | # the computation of one-mask predictions is expensive; we dump (or reuse the dumped) one-mask predictions. 117 | SUFFIX = '_mr_one_mask_{}_{}_m{}_s{}_mr{}_{}.z'.format(DATASET,MODEL_NAME,MASK_SIZE,MASK_STRIDE,args.mr,NUM_IMG) 118 | #SUFFIX = '_mr_one_mask_{}_{}_m{}_s{}_{}.z'.format(DATASET,MODEL_NAME,MASK_SIZE,MASK_STRIDE,NUM_IMG) 119 | if not args.override and os.path.exists(os.path.join(DUMP_DIR,'prediction_map_list'+SUFFIX)): 120 | print('loading one-mask predictions') 121 | confidence_map_list = joblib.load(os.path.join(DUMP_DIR,'confidence_map_list'+SUFFIX)) 122 | prediction_map_list = joblib.load(os.path.join(DUMP_DIR,'prediction_map_list'+SUFFIX)) 123 | orig_prediction_list = joblib.load(os.path.join(DUMP_DIR,'orig_prediction_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG))) 124 | label_list = joblib.load(os.path.join(DUMP_DIR,'label_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG))) 125 | else: 126 | print('computing one-mask predictions') 127 | prediction_map_list = [] 128 | confidence_map_list = [] 129 | label_list = [] 130 | orig_prediction_list = [] 131 | for data,labels in tqdm(val_loader): 132 | data=data.to(device) 133 | labels = labels.numpy() 134 | num_img = data.shape[0] 135 | 136 | 137 | #one-mask predictions (one per mask in the mask set) 138 | prediction_map = np.zeros([num_img,args.num_mask,args.num_mask],dtype=int) 139 | confidence_map = np.zeros([num_img,args.num_mask,args.num_mask,num_classes]) 140 | 141 | for i,mask in enumerate(mask_list): 142 | 143 | masked_output = model(torch.where(mask,data,torch.tensor(0.).cuda())) 144 | masked_output = torch.nn.functional.softmax(masked_output,dim=1) 145 | _, masked_pred = masked_output.max(1) 146 | masked_pred = masked_pred.detach().cpu().numpy() 147 | masked_conf = masked_output.detach().cpu().numpy() 148 | 149 | a,b = divmod(i,args.num_mask) 150 | prediction_map[:,a,b] = masked_pred 151 | confidence_map[:,a,b,:] = masked_conf 152 | 153 | prediction_map,confidence_map = process_minority_report(prediction_map,confidence_map,mr=args.mr) 154 | 155 | #vanilla predictions 156 | clean_output = model(data) 157 | clean_conf, clean_pred = clean_output.max(1) 158 | clean_pred = clean_pred.detach().cpu().numpy() 159 | 
orig_prediction_list.append(clean_pred) 160 | prediction_map_list.append(prediction_map) 161 | confidence_map_list.append(confidence_map) 162 | label_list.append(labels) 163 | 164 | prediction_map_list = np.concatenate(prediction_map_list) 165 | confidence_map_list = np.concatenate(confidence_map_list) 166 | orig_prediction_list = np.concatenate(orig_prediction_list) 167 | label_list = np.concatenate(label_list) 168 | 169 | joblib.dump(confidence_map_list,os.path.join(DUMP_DIR,'confidence_map_list'+SUFFIX)) 170 | joblib.dump(prediction_map_list,os.path.join(DUMP_DIR,'prediction_map_list'+SUFFIX)) 171 | joblib.dump(orig_prediction_list,os.path.join(DUMP_DIR,'orig_prediction_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG))) 172 | joblib.dump(label_list,os.path.join(DUMP_DIR,'label_list_{}_{}_{}.z'.format(DATASET,MODEL_NAME,NUM_IMG))) 173 | 174 | 175 | 176 | def provable_detection(prediction_map,confidence_map,label,orig_pred,tau): 177 | if orig_pred != label: # clean prediction is incorrect 178 | return 0,0 # 0 for incorrect clean prediction 179 | clean = 1 180 | provable = 1 181 | if np.any(np.logical_or(prediction_map!=label,confidence_map<tau)): # robustness is certified only if every vote matches the label with confidence at least tau 182 | provable = 0 183 | if np.any(np.logical_and(prediction_map!=label,confidence_map>tau)): # a clean image raises a (false) alert if any vote disagrees with high confidence 184 | clean = 0 185 | return provable,clean 186 | 187 | 188 | clean_list = [] 189 | robust_list = [] 190 | for tau in np.arange(0.,1.,0.1): 191 | print(tau) 192 | clean_corr = 0 193 | robust_cnt = 0 194 | vanilla_corr = 0 195 | for i,(prediction_map,confidence_map,label,orig_pred) in enumerate(zip(prediction_map_list,confidence_map_list,label_list,orig_prediction_list)): 196 | #print(confidence_map) 197 | provable,clean = provable_detection(prediction_map,confidence_map,label,orig_pred,tau) 198 | robust_cnt+=provable 199 | clean_corr+=clean 200 | vanilla_corr +=orig_pred==label 201 | robust_list.append(robust_cnt/NUM_IMG) 202 | clean_list.append(clean_corr/NUM_IMG) 203 | print("Provable robust accuracy ({}):".format(tau),robust_cnt/NUM_IMG) 204 | print("Clean accuracy with defense:",clean_corr/NUM_IMG) 205 | print("Clean accuracy without defense:",vanilla_corr/NUM_IMG) 206 | print() 207 | print('clean_list=', [100*x for x in clean_list]) 208 | print('robust_list=',[100*x for x in robust_list]) -------------------------------------------------------------------------------- /utils/defense.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def gen_mask_set(args,ds_config): 5 | # generate an R-covering mask set 6 | ''' 7 | INPUT: 8 | args argparse.Namespace, the set of arguments/hyperparameters for mask set generation 9 | ds_config dict, data preprocessing dict 10 | 11 | OUTPUT: 12 | mask_list list of torch.tensor, the generated R-covering mask set; the binary masks are moved to CUDA 13 | MASK_SIZE tuple (int,int), the mask size along two axes 14 | MASK_STRIDE tuple (int,int), the mask stride along two axes 15 | ''' 16 | # generate mask set 17 | assert args.mask_stride * args.num_mask < 0 #can only set either mask_stride or num_mask 18 | 19 | IMG_SIZE = (ds_config['input_size'][1],ds_config['input_size'][2]) 20 | 21 | if args.pa>0 and args.pb>0: #rectangle patch 22 | PATCH_SIZE = (args.pa,args.pb) 23 | else: #square patch 24 | PATCH_SIZE = (args.patch_size,args.patch_size) 25 | 26 | if args.mask_stride>0: #use specified mask stride 27 | MASK_STRIDE = (args.mask_stride,args.mask_stride) 28 | else: #calculate mask stride based on the computation budget 29 | MASK_STRIDE = (int(np.ceil((IMG_SIZE[0] - PATCH_SIZE[0] + 1)/args.num_mask)),int(np.ceil((IMG_SIZE[1] - 
PATCH_SIZE[1] + 1)/args.num_mask))) 30 | 31 | # calculate mask size 32 | MASK_SIZE = (min(PATCH_SIZE[0]+MASK_STRIDE[0]-1,IMG_SIZE[0]),min(PATCH_SIZE[1]+MASK_STRIDE[1]-1,IMG_SIZE[1])) 33 | 34 | mask_list = [] 35 | idx_list1 = list(range(0,IMG_SIZE[0] - MASK_SIZE[0] + 1,MASK_STRIDE[0])) 36 | if (IMG_SIZE[0] - MASK_SIZE[0])%MASK_STRIDE[0]!=0: 37 | idx_list1.append(IMG_SIZE[0] - MASK_SIZE[0]) 38 | idx_list2 = list(range(0,IMG_SIZE[1] - MASK_SIZE[1] + 1,MASK_STRIDE[1])) 39 | if (IMG_SIZE[1] - MASK_SIZE[1])%MASK_STRIDE[1]!=0: 40 | idx_list2.append(IMG_SIZE[1] - MASK_SIZE[1]) 41 | 42 | for x in idx_list1: 43 | for y in idx_list2: 44 | mask = torch.ones([1,1,IMG_SIZE[0],IMG_SIZE[1]],dtype=bool).cuda() 45 | mask[...,x:x+MASK_SIZE[0],y:y+MASK_SIZE[1]] = False 46 | mask_list.append(mask) 47 | return mask_list,MASK_SIZE,MASK_STRIDE 48 | 49 | 50 | def double_masking(data,mask_list,model): 51 | # perform double masking inference with the input image, the mask set, and the undefended model 52 | ''' 53 | INPUT: 54 | data torch.Tensor [B,C,H,W], a batch of data 55 | mask_list a list of torch.Tensor, R-covering mask set 56 | model torch.nn.Module, the vanilla undefended model 57 | 58 | OUTPUT: 59 | output_pred numpy.ndarray, the prediction labels 60 | ''' 61 | 62 | # first-round masking 63 | num_img = len(data) 64 | num_mask = len(mask_list) 65 | pred_one_mask_batch = np.zeros([num_img,num_mask],dtype=int) 66 | # compute one-mask predictions in batch 67 | for i,mask in enumerate(mask_list): 68 | masked_output = model(torch.where(mask,data,torch.tensor(0.).cuda())) 69 | _, masked_pred = masked_output.max(1) 70 | masked_pred = masked_pred.detach().cpu().numpy() 71 | pred_one_mask_batch[:,i] = masked_pred 72 | 73 | # determine the prediction label for each image 74 | output_pred = np.zeros([num_img],dtype=int) 75 | for j in range(num_img): 76 | pred_one_mask = pred_one_mask_batch[j] 77 | pred,cnt = np.unique(pred_one_mask,return_counts=True) 78 | 79 | if len(pred)==1: # unanimous agreement in the first-round masking 80 | defense_pred = pred[0] # Case I: agreed prediction 81 | else: 82 | sorted_idx = np.argsort(cnt) 83 | # get majority prediction and disagreer predictions 84 | majority_pred = pred[sorted_idx][-1] 85 | disagreer_pred = pred[sorted_idx][:-1] 86 | 87 | # second-round masking 88 | # get index list of the disagreer masks 89 | tmp = np.zeros_like(pred_one_mask,dtype=bool) 90 | for dis in disagreer_pred: 91 | tmp = np.logical_or(tmp,pred_one_mask==dis) 92 | disagreer_pred_mask_idx = np.where(tmp)[0] 93 | 94 | for i in disagreer_pred_mask_idx: 95 | dis = pred_one_mask[i] 96 | mask = mask_list[i] 97 | flg=True 98 | for mask2 in mask_list: 99 | # evaluate two-mask predictions 100 | masked_output = model(torch.where(torch.logical_and(mask,mask2),data[j],torch.tensor(0.).cuda())) 101 | masked_conf, masked_pred = masked_output.max(1) 102 | masked_pred = masked_pred.item() 103 | if masked_pred!=dis: # disagreement in the second-round masking -> discard the disagreer 104 | flg=False 105 | break 106 | if flg: 107 | defense_pred = dis # Case II: disagreer prediction 108 | break 109 | if not flg: 110 | defense_pred = majority_pred # Case III: majority prediction 111 | output_pred[j] = defense_pred 112 | return output_pred 113 | 114 | 115 | def double_masking_precomputed(prediction_map): 116 | # perform double masking inference with the pre-computed two-mask predictions 117 | ''' 118 | INPUT: 119 | prediction_map numpy.ndarray [num_mask,num_mask], the two-mask prediction map for a single data point 120 | 121 | 
OUTPUT: int, the prediction label 122 | ''' 123 | # first-round masking 124 | pred_one_mask = np.diag(prediction_map) 125 | pred,cnt = np.unique(pred_one_mask,return_counts=True) 126 | 127 | if len(pred) == 1: # unanimous agreement in the first-round masking 128 | return pred[0] # Case I: agreed prediction 129 | 130 | # get majority prediction and disagreer predictions 131 | sorted_idx = np.argsort(cnt) 132 | majority_pred = pred[sorted_idx][-1] 133 | disagreer_pred = pred[sorted_idx][:-1] 134 | 135 | # second-round masking 136 | # get index list of the disagreer masks 137 | tmp = np.zeros_like(pred_one_mask,dtype=bool) 138 | for dis in disagreer_pred: 139 | tmp = np.logical_or(tmp,pred_one_mask==dis) 140 | disagreer_pred_mask_idx = np.where(tmp)[0] 141 | 142 | for i in disagreer_pred_mask_idx: 143 | dis = pred_one_mask[i] 144 | # check all two-mask predictions 145 | tmp = prediction_map[i]==dis 146 | if np.all(tmp): 147 | return dis # Case II: disagreer prediction 148 | 149 | return majority_pred # Case III: majority prediction 150 | 151 | def certify_precomputed(prediction_map,label): 152 | # certify robustness with the pre-computed two-mask predictions 153 | # check for two-mask correctness 154 | return np.all(prediction_map==label) 155 | 156 | 157 | 158 | 159 | def challenger_masking(data,mask_list,model): 160 | # perform challenger masking inference (discussed in the appendix) with the input image, the mask set, and the undefended model 161 | ''' 162 | INPUT: 163 | data torch.Tensor [B,C,H,W], a batch of data 164 | mask_list a list of torch.Tensor, R-covering mask set 165 | model torch.nn.Module, the vanilla undefended model 166 | 167 | OUTPUT: 168 | output_pred numpy.ndarray, the prediction labels 169 | ''' 170 | 171 | # first-round masking 172 | num_img = len(data) 173 | num_mask = len(mask_list) 174 | pred_one_mask_batch = np.zeros([num_img,num_mask],dtype=int) 175 | # compute one-mask predictions in batch 176 | for i,mask in enumerate(mask_list): 177 | masked_output = model(torch.where(mask,data,torch.tensor(0.).cuda())) 178 | _, masked_pred = masked_output.max(1) 179 | masked_pred = masked_pred.detach().cpu().numpy() 180 | pred_one_mask_batch[:,i] = masked_pred 181 | 182 | # determine the prediction label for each image 183 | output_pred = np.zeros([num_img],dtype=int) 184 | for j in range(num_img): 185 | pred_one_mask = pred_one_mask_batch[j] 186 | pred,cnt = np.unique(pred_one_mask,return_counts=True) 187 | 188 | if len(pred)==1: # unanimous agreement in the first-round masking 189 | candidate_label = pred[0] 190 | else: 191 | # second-round masking (challenger game) 192 | candidate = 0 # take index 0 as the initial candidate 193 | candidate_label = pred_one_mask[candidate] 194 | candidate_mask = mask_list[candidate] 195 | used_flg = np.zeros([num_mask],dtype=bool) 196 | #used_flg[candidate_mask]=True 197 | while len(np.unique(pred_one_mask[~used_flg]))>1: 198 | # find a challenger 199 | for challenger in range(0,num_mask): 200 | if used_flg[challenger]: 201 | continue 202 | challenger_label = pred_one_mask[challenger] 203 | if challenger_label==candidate_label: 204 | continue 205 | break 206 | # challenger game 207 | challenger_mask = mask_list[challenger] 208 | masked_output = model(torch.where(torch.logical_and(candidate_mask,challenger_mask),data[j],torch.tensor(0.).cuda())) 209 | _, masked_pred = masked_output.max(1) 210 | masked_pred = masked_pred.item() 211 | if masked_pred == challenger_label: 212 | used_flg[candidate]=True 213 | candidate = challenger 214 | 
candidate_label = challenger_label 215 | candidate_mask = challenger_mask 216 | else: 217 | used_flg[challenger]=True 218 | output_pred[j] = candidate_label 219 | return output_pred 220 | 221 | 222 | def challenger_masking_precomputed(prediction_map): 223 | # perform challenger masking inference with the pre-computed two-mask predictions 224 | ''' 225 | INPUT: 226 | prediction_map numpy.ndarray [num_mask,num_mask], the two-mask prediction map for a single data point 227 | 228 | OUTPUT: int, the prediction label 229 | ''' 230 | # first-round masking 231 | pred_one_mask = np.diag(prediction_map) 232 | pred,cnt = np.unique(pred_one_mask,return_counts=True) 233 | 234 | if len(pred) == 1: # unanimous agreement in the first-round masking 235 | candidate_label = pred[0] 236 | else: 237 | # second-round masking (challenger game) 238 | candidate = 0 # take index 0 as the initial candidate 239 | candidate_label = pred_one_mask[candidate] 240 | num_mask = len(pred_one_mask) 241 | used_flg = np.zeros([num_mask],dtype=bool) 242 | while len(np.unique(pred_one_mask[~used_flg]))>1: 243 | # find a challenger 244 | for challenger in range(0,num_mask): 245 | if used_flg[challenger]: 246 | continue 247 | challenger_label = pred_one_mask[challenger] 248 | if challenger_label==candidate_label: 249 | continue 250 | break 251 | # challenger game 252 | masked_pred = prediction_map[candidate,challenger] 253 | if masked_pred == challenger_label: 254 | used_flg[candidate]=True 255 | candidate = challenger 256 | candidate_label = challenger_label 257 | else: 258 | used_flg[challenger]=True 259 | return candidate_label 260 | -------------------------------------------------------------------------------- /misc/reproducibility.md: -------------------------------------------------------------------------------- 1 | ## Overview 2 | 3 | This document provides a detailed guide to reproduce all experimental results in the main body of our PatchCleanser paper. 4 | 5 | ## Setup 6 | 7 | #### File directory 8 | Below is an overview of relevant files for the artifact evaluation. Please organize the files within the specified structure. 9 | ```shell 10 | ├── requirement.txt #required packages 11 | ├── pc_certification.py #PatchCleanser: certify robustness via two-mask correctness 12 | ├── pc_clean_acc.py #PatchCleanser: evaluate clean accuracy and per-example inference time 13 | | 14 | ├── vanilla_clean_acc.py #undefended vanilla models: evaluate clean accuracy and per-example inference time 15 | | 16 | ├── utils 17 | | ├── setup.py #utils for constructing models and data loaders 18 | | ├── defense.py #utils for PatchCleanser defenses 19 | | └── cutout.py #utils for masked model training 20 | | 21 | ├── misc 22 | | ├── pc_mr.py #script for minority report (Figure 9) 23 | | └── pc_multiple.py #script for multiple patch shapes and multiple patches (Table 4) 24 | | 25 | ├── data 26 | | ├── imagenet #data directory for imagenet 27 | | | └── val #imagenet validation set 28 | | ├── imagenette #data directory for imagenette 29 | | | └── val #imagenette validation set 30 | | └── cifar #data directory for cifar-10 31 | | 32 | └── checkpoints #directory for checkpoints 33 | ├── README.md #details of checkpoints 34 | └── ... #model checkpoints 35 | ``` 36 | #### Dependency 37 | Install [PyTorch](https://pytorch.org/get-started/locally/) with GPU support. 38 | 39 | Install the other packages with `pip install -r requirement.txt`; a minimal setup sketch follows. 
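A minimal environment sketch (the environment name here is arbitrary, and the exact PyTorch install command depends on your CUDA version; see pytorch.org):

```shell
conda create -n patchcleanser python=3.8 -y
conda activate patchcleanser
# install a CUDA-enabled PyTorch build first (command varies by CUDA version)
pip install torch torchvision
pip install -r requirement.txt
```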
40 | 41 | #### Datasets 42 | - [ImageNet](https://image-net.org/download.php) (ILSVRC2012) - requires manual download; also available on [Kaggle](https://www.kaggle.com/c/imagenet-object-localization-challenge/data) 43 | - [ImageNette](https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz) - requires manual download 44 | - [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) - will be downloaded automatically within our code 45 | Move manually downloaded data to the directory `data/`. 46 | 47 | #### Checkpoints 48 | 1. Download the following checkpoints from the Google Drive [link](https://drive.google.com/drive/folders/1Ewks-NgJHDlpeAaGInz_jZ6iczcYNDlN?usp=sharing). 49 | 50 | ``` 51 | resnetv2_50x1_bit_distilled_cutout2_128_imagenet.pth 52 | resnetv2_50x1_bit_distilled_cutout2_128_imagenette.pth 53 | resnetv2_50x1_bit_distilled_cutout2_128_cifar.pth 54 | resmlp_24_distilled_224_cutout2_128_imagenet.pth 55 | resmlp_24_distilled_224_cutout2_128_imagenette.pth 56 | resmlp_24_distilled_224_cutout2_128_cifar.pth 57 | vit_base_patch16_224_cutout2_128_imagenet.pth 58 | vit_base_patch16_224_cutout2_128_imagenette.pth 59 | vit_base_patch16_224_cutout2_128_cifar.pth 60 | ``` 61 | 62 | 2. Move downloaded weights to the directory `checkpoints/`. 63 | 64 | ## Experiments 65 | 66 | In this section, we list all the shell commands used to generate the experimental results for every table and figure. 67 | 68 | 1. Our evaluation metrics are **clean accuracy** and **certified robust accuracy**, which are printed to the console. Below is the expected output after running `python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img -1 --num_mask 6 --patch_size 32`. In this example, the clean accuracy of the PatchCleanser defense is 83.9%, and the certified robust accuracy of the PatchCleanser defense is 62.1%. These numbers match the results reported in Table 2 (ImageNet, PC-ViT, 2%-pixel patch). 69 | 70 | ``` 71 | Certified robust accuracy: 0.6207 72 | Clean accuracy with defense: 0.8394 73 | ``` 74 | 75 | 2. We also specify the estimated runtime (with one GPU) for each experiment below. 76 | - Running experiments for the entire dataset can take a long time. Please feel free to set ``--num_img`` to a small positive integer (e.g., 1000) to run experiments on a subset of the dataset. This gives an approximate evaluation result. 77 | - When ``--num_img`` is set to a negative integer, we will use the entire dataset for experiments; when it is set to a positive integer, we will use a random subset (with ``num_img`` images) for experiments. 78 | 79 | #### Table 2 (and Figure 2): our main evaluation results 80 | 81 | The following scripts allow us to obtain results for our main evaluation in Table 2 and Figure 2. 82 | 83 | ```shell 84 | # please feel free to set --num_img to a small positive integer (e.g., 1000) to reduce runtime; the script will run experiments on a random subset of the dataset to get an approximate result. 
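# for reference, a sketch of the mask-set arithmetic (see gen_mask_set in utils/defense.py):
# with --num_mask 6, a 224x224 input, and a 32px patch,
#   mask stride = ceil((224 - 32 + 1) / 6) = 33 and mask size = 32 + 33 - 1 = 64,
# giving a 6x6 grid of 36 masks per image; certification then checks all two-mask predictions.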
85 | 86 | #### imagenette 87 | # resnet (each takes 1-2hrs) 88 | python pc_certification.py --model resnetv2_50x1_bit_distilled_cutout2_128 --dataset imagenette --num_img -1 --num_mask 6 --patch_size 32 # 2% pixel patch 89 | python pc_certification.py --model resnetv2_50x1_bit_distilled_cutout2_128 --dataset imagenette --num_img -1 --num_mask 6 --patch_size 39 # 3% pixel patch 90 | python pc_certification.py --model resnetv2_50x1_bit_distilled_cutout2_128 --dataset imagenette --num_img -1 --num_mask 6 --patch_size 23 # 1% pixel patch 91 | # mlp (each takes 1-2hrs) 92 | python pc_certification.py --model resmlp_24_distilled_224_cutout2_128 --dataset imagenette --num_img -1 --num_mask 6 --patch_size 32 93 | python pc_certification.py --model resmlp_24_distilled_224_cutout2_128 --dataset imagenette --num_img -1 --num_mask 6 --patch_size 39 94 | python pc_certification.py --model resmlp_24_distilled_224_cutout2_128 --dataset imagenette --num_img -1 --num_mask 6 --patch_size 23 95 | # vit (each takes 3-4hrs) 96 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenette --num_img -1 --num_mask 6 --patch_size 32 97 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenette --num_img -1 --num_mask 6 --patch_size 39 98 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenette --num_img -1 --num_mask 6 --patch_size 23 99 | 100 | #### imagenet 101 | # resnet (each takes 16-17hrs) 102 | python pc_certification.py --model resnetv2_50x1_bit_distilled_cutout2_128 --dataset imagenet --num_img -1 --num_mask 6 --patch_size 32 103 | python pc_certification.py --model resnetv2_50x1_bit_distilled_cutout2_128 --dataset imagenet --num_img -1 --num_mask 6 --patch_size 39 104 | python pc_certification.py --model resnetv2_50x1_bit_distilled_cutout2_128 --dataset imagenet --num_img -1 --num_mask 6 --patch_size 23 105 | # mlp (each takes 16-17hrs) 106 | python pc_certification.py --model resmlp_24_distilled_224_cutout2_128 --dataset imagenet --num_img -1 --num_mask 6 --patch_size 32 107 | python pc_certification.py --model resmlp_24_distilled_224_cutout2_128 --dataset imagenet --num_img -1 --num_mask 6 --patch_size 39 108 | python pc_certification.py --model resmlp_24_distilled_224_cutout2_128 --dataset imagenet --num_img -1 --num_mask 6 --patch_size 23 109 | # vit (each takes 38-40hrs) 110 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img -1 --num_mask 6 --patch_size 32 111 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img -1 --num_mask 6 --patch_size 39 112 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img -1 --num_mask 6 --patch_size 23 113 | 114 | #### cifar 115 | # resnet (each takes 4-5hrs) 116 | python pc_certification.py --model resnetv2_50x1_bit_distilled_cutout2_128 --dataset cifar --num_img -1 --num_mask 6 --patch_size 35 # 2.4% pixel patch 117 | python pc_certification.py --model resnetv2_50x1_bit_distilled_cutout2_128 --dataset cifar --num_img -1 --num_mask 6 --patch_size 14 # 0.4% pixel patch 118 | # mlp (each takes 4-5hrs) 119 | python pc_certification.py --model resmlp_24_distilled_224_cutout2_128 --dataset cifar --num_img -1 --num_mask 6 --patch_size 35 120 | python pc_certification.py --model resmlp_24_distilled_224_cutout2_128 --dataset cifar --num_img -1 --num_mask 6 --patch_size 14 121 | # vit (each takes 11-12hrs) 122 | python pc_certification.py --model 
vit_base_patch16_224_cutout2_128 --dataset cifar --num_img -1 --num_mask 6 --patch_size 35 123 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset cifar --num_img -1 --num_mask 6 --patch_size 14 124 | 125 | ``` 126 | 127 | 128 | 129 | #### Table 3: vanilla undefended models 130 | 131 | The following script is used for Table 3 (clean accuracy of the vanilla undefended models). 132 | 133 | ```shell 134 | # takes a few minutes... 135 | # imagenette 136 | python vanilla_clean_acc.py --model resnetv2_50x1_bit_distilled_cutout2_128 --dataset imagenette --num_img -1 137 | python vanilla_clean_acc.py --model resmlp_24_distilled_224_cutout2_128 --dataset imagenette --num_img -1 138 | python vanilla_clean_acc.py --model vit_base_patch16_224_cutout2_128 --dataset imagenette --num_img -1 139 | 140 | #imagenet 141 | python vanilla_clean_acc.py --model resnetv2_50x1_bit_distilled_cutout2_128 --dataset imagenet --num_img -1 142 | python vanilla_clean_acc.py --model resmlp_24_distilled_224_cutout2_128 --dataset imagenet --num_img -1 143 | python vanilla_clean_acc.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img -1 144 | 145 | #cifar 146 | python vanilla_clean_acc.py --model resnetv2_50x1_bit_distilled_cutout2_128 --dataset cifar --num_img -1 147 | python vanilla_clean_acc.py --model resmlp_24_distilled_224_cutout2_128 --dataset cifar --num_img -1 148 | python vanilla_clean_acc.py --model vit_base_patch16_224_cutout2_128 --dataset cifar --num_img -1 149 | 150 | ``` 151 | 152 | 153 | 154 | #### Figure 4: defense with different numbers of masks $k$ 155 | 156 | The following script is used for Figure 4 (the effect of the number of masks $k$). Note that we choose 5000 random ImageNet images and evaluate 32x32 patches (2% pixels). 157 | 158 | ```shell 159 | # each takes ~0.5hr 160 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 2 --patch_size 32 161 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 3 --patch_size 32 162 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 4 --patch_size 32 163 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 5 --patch_size 32 164 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 32 165 | ``` 166 | 167 | 168 | 169 | #### Figure 5: defense runtime 170 | 171 | The following script evaluates per-example runtime (Figure 5). 172 | 173 | ```shell 174 | # each takes a few minutes 175 | python pc_clean_acc.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 2 --patch_size 32 176 | python pc_clean_acc.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 3 --patch_size 32 177 | python pc_clean_acc.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 4 --patch_size 32 178 | python pc_clean_acc.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 5 --patch_size 32 179 | python pc_clean_acc.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 32 180 | 181 | ``` 182 | 183 | 184 | 185 | #### Figure 6, 7, 8: defense with different (estimated) patch sizes 186 | 187 | The following script is used for evaluating defense 
performance with different patch sizes (or estimated patch sizes); results are plotted in Figures 6, 7, and 8. 188 | 189 | ```shell 190 | # each takes ~0.5hr 191 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 48 192 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 64 193 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 80 194 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 96 195 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 112 196 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 128 197 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 144 198 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 160 199 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 176 200 | 201 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 40 202 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 56 203 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 72 204 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 88 205 | python pc_certification.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img 5000 --num_mask 6 --patch_size 104 206 | ``` 207 | 208 | 209 | 210 | #### Table 4: multiple shapes and multiple patches 211 | 212 | The following script evaluates defense performance for all 1%-pixel rectangular patches and for two 1%-pixel square patches (Table 4). 213 | 214 | ```shell 215 | # move the script to the main directory 216 | mv misc/pc_multiple.py pc_multiple.py 217 | 218 | # multiple shapes for 1%-pixel patch 219 | # each takes 11-12 hrs 220 | python pc_multiple.py --mode shape --dataset imagenet --model vit_base_patch16_224_cutout2_128 --num_img 500 --mask_stride 32 --patch_size 23 221 | python pc_multiple.py --mode shape --dataset imagenette --model vit_base_patch16_224_cutout2_128 --num_img 500 --mask_stride 32 --patch_size 23 222 | python pc_multiple.py --mode shape --dataset cifar --model vit_base_patch16_224_cutout2_128 --num_img 500 --mask_stride 32 --patch_size 23 223 | 224 | # two 1%-pixel patches 225 | # each takes 100hrs; setting --num_mask to a smaller number (e.g., 4) can significantly reduce runtime (at the cost of performance drops) 226 | python pc_multiple.py --mode twopatch --dataset imagenet --model vit_base_patch16_224_cutout2_128 --num_img 500 --num_mask 5 --patch_size 23 227 | python pc_multiple.py --mode twopatch --dataset imagenette --model vit_base_patch16_224_cutout2_128 --num_img 500 --num_mask 5 --patch_size 23 228 | python pc_multiple.py --mode twopatch --dataset cifar --model vit_base_patch16_224_cutout2_128 --num_img 500 --num_mask 5 --patch_size 23 
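# note (roughly): the two-patch case is far more expensive than the single-patch experiments
# because masks must be placed in pairs to cover two patches at once, so the number of masked
# model evaluations grows much faster with --num_mask; this is why a smaller --num_mask trades
# certified performance for runtime.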
229 | ``` 230 | 231 | 232 | 233 | #### Figure 9: Minority Report 234 | 235 | The following script evaluates defense performance for Minority Report using our mask set (Figure 9). 236 | 237 | ```shell 238 | # move the script to the main directory 239 | mv misc/pc_mr.py pc_mr.py 240 | # each takes 3-4 hrs 241 | python pc_mr.py --model vit_base_patch16_224_cutout2_128 --dataset imagenet --num_img -1 --num_mask 6 --patch_size 32 # the number of masks is 6x6=36 242 | ``` 243 | 244 | --------------------------------------------------------------------------------