├── .gitignore ├── LICENSE ├── README.md ├── eval.py ├── img ├── algorithm_sparse_rs.png ├── frames_adversarial_examples.png ├── illustrations_figure1.png ├── l0_adversarial_examples_targeted.png ├── l0_adversarial_examples_untargeted.png ├── patches_adversarial_examples.png ├── table_frames.png ├── table_l0_bb_wb.png ├── table_patches.png └── universal_patches_frames.png ├── rs_attacks.py ├── utils.py └── vis_images.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pyc 3 | 4 | .idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Francesco Croce, Maksym Andriushchenko 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse-RS: a versatile framework for query-efficient sparse black-box adversarial attacks 2 | **Francesco Croce, Maksym Andriushchenko, Naman D. Singh, Nicolas Flammarion, Matthias Hein** 3 | 4 | **University of Tübingen and EPFL** 5 | 6 | **Paper:** [https://arxiv.org/abs/2006.12834](https://arxiv.org/abs/2006.12834) 7 | 8 | **AAAI 2022** 9 | 10 | 11 | ## Abstract 12 | Sparse adversarial perturbations received much less attention in the literature compared to L2- and Linf-attacks. 13 | However, it is equally important to accurately assess the robustness of a model against sparse perturbations. Motivated by this goal, 14 | we propose a versatile framework based on random search, **Sparse-RS**, for score-based sparse targeted and untargeted attacks in 15 | the black-box setting. **Sparse-RS** does not rely on substitute models and achieves state-of-the-art success rate and query efficiency 16 | for multiple sparse attack models: L0-bounded perturbations, adversarial patches, and adversarial frames. Unlike existing methods, the 17 | L0-version of untargeted **Sparse-RS** achieves almost 100% success rate on ImageNet by perturbing *only* 0.1% of the total 18 | number of pixels, outperforming all existing white-box attacks including L0-PGD. 
Moreover, our untargeted **Sparse-RS** achieves very 19 | high success rates even for the challenging settings of 20x20 adversarial patches and 2-pixel wide adversarial frames for 224x224 20 | images. Finally, we show that **Sparse-RS** can be applied for universal adversarial patches where it significantly outperforms transfer-based approaches. 21 |

22 | 23 | 24 | ## About the paper 25 | Our proposed **Sparse-RS** framework is based on random search. Its main advantages are its simplicity and its wide applicability 26 | to multiple threat models: 27 |

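In code form, the core of the framework is a greedy random-search loop over candidate perturbations that stay inside the chosen threat model. The following is a simplified, hypothetical sketch (`loss_fn`, `init_fn` and `sample_fn` are placeholder callables; the actual per-threat-model proposals and sampling schedules are implemented in `rs_attacks.py`):

```
# Simplified, hypothetical sketch of the random-search loop underlying Sparse-RS.
# The real implementation (rs_attacks.py) adds schedules for the update size,
# threat-model-specific proposals, and batched early stopping.

def random_search(loss_fn, init_fn, sample_fn, n_queries):
    """Greedy score-based random search: keep the best candidate found so far and
    accept a random modification only if it decreases the (black-box) loss."""
    x_best = init_fn()                    # random perturbation inside the threat model
    loss_best = loss_fn(x_best)           # one query to the target model
    for _ in range(n_queries - 1):
        x_new = sample_fn(x_best)         # random local modification of the current best
        loss_new = loss_fn(x_new)         # one more query
        if loss_new < loss_best:          # greedy acceptance
            x_best, loss_best = x_new, loss_new
    return x_best
```

Each query only evaluates the model's output scores, which is what makes the attack black-box and score-based.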
28 | 29 | We illustrate the versatility of the **Sparse-RS** framework by generating various sparse perturbations: L0-bounded, adversarial patches, and adversarial frames: 30 |

31 |

32 |

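To make the three threat models concrete, the sparsity budget `k` is interpreted differently in each case (budgets below are the ones used in the usage examples further down, for 224x224 ImageNet inputs):
- **L0-bounded:** at most `k` *pixels* are modified (all color channels of each selected pixel); e.g. `k=150` out of 224x224 = 50176 pixels, i.e. roughly 0.3% of the image;
- **patches:** the perturbation is a contiguous square of side `ceil(sqrt(k))` pixels whose content (and, for image-specific attacks, location) is optimized; e.g. `k=400` gives a 20x20 patch;
- **frames:** the perturbation is a frame of width `k` pixels along the image border; e.g. `k=2` perturbs 50176 - 48400 = 1776 pixels of a 224x224 image.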
33 | 34 | **Sparse-RS** can also successfully generate black-box **universal attacks** in sparse threat models without requiring a surrogate model: 35 |

36 | 37 | In all these threat models, **Sparse-RS** improves over the existing approaches: 38 |

39 |

40 | 41 | Moreover, for L0-perturbations, **Sparse-RS** can even outperform existing **white-box** methods such as L0-PGD. 42 |

43 | 44 | 45 | 46 | ## Code of Sparse-RS 47 | The code is tested under Python 3.8.5 and PyTorch 1.8.0. It automatically downloads the pretrained models (either VGG-16-BN or ResNet-50) and requires access to ImageNet validation set. 48 | 49 | The following are examples of how to run the attacks in the different threat models. 50 | 51 | ### L0-bounded (pixel and feature space) 52 | In this case `k` represents the number of *pixels* to modify. For untargeted attacks 53 | ``` 54 | CUDA_VISIBLE_DEVICES=0 python eval.py --norm=L0 \ 55 | --model=[pt_vgg | pt_resnet] --n_queries=10000 --alpha_init=0.3 \ 56 | --data_path=/path/to/validation/set --k=150 --n_ex=500 57 | ``` 58 | and for targeted attacks please use `--targeted --n_queries=100000 --alpha_init=0.1`. The target class is randomly chosen for each point. 59 | 60 | To use an attack in the *feature* space please add `--use_feature_space` (in this case `k` indicates the number of features to modify). 61 | 62 | As additional options the flag `--constant_schedule` uses a constant schedule for `alpha` instead of the piecewise constant decreasing one, while with `--seed=N` it is possible to set a custom random seed. 63 | 64 | ### Image-specific patches and frames 65 | For untargeted image- and location-specific patches of size 20x20 (with `k=400`) 66 | ``` 67 | CUDA_VISIBLE_DEVICES=0 python eval.py --norm=patches \ 68 | --model=[pt_vgg | pt_resnet] --n_queries=10000 --alpha_init=0.4 \ 69 | --data_path=/path/to/validation/set --k=400 --n_ex=100 70 | ``` 71 | 72 | For targeted patches (size 40x40) please use `--targeted --n_queries=50000 --alpha_init=0.1 --k=1600`. The target class is randomly chosen for each point. 73 | 74 | For untargeted image-specific frames of width 2 pixels (with `k=2`) 75 | ``` 76 | CUDA_VISIBLE_DEVICES=0 python eval.py --norm=frames \ 77 | --model=[pt_vgg | pt_resnet] --n_queries=10000 --alpha_init=0.5 \ 78 | --data_path=/path/to/validation/set --k=2 --n_ex=100 79 | ``` 80 | 81 | For targeted frames (width of 3 pixels) please use `--targeted --n_queries=50000 --alpha_init=0.5 --k=3`. The target class is randomly chosen for each point. 82 | 83 | ### Universal patches and frames 84 | For targeted universal patches of size 50x50 (with `k=2500`) 85 | ``` 86 | CUDA_VISIBLE_DEVICES=0 python eval.py \ 87 | --norm=patches_universal --model=[pt_vgg | pt_resnet] \ 88 | --n_queries=100000 --alpha_init=0.3 \ 89 | --data_path=/path/to/validation/set --k=2500 \ 90 | --n_ex=30 --targeted --target_class=530 91 | ``` 92 | 93 | and for targeted universal frames of width 6 pixels (`k=6`) 94 | ``` 95 | CUDA_VISIBLE_DEVICES=0 python eval.py \ 96 | --norm=frames_universal --model=[pt_vgg | pt_resnet] \ 97 | --n_queries=100000 --alpha_init=1.667 \ 98 | --data_path=/path/to/validation/set --k=6 \ 99 | --n_ex=30 --targeted --target_class=530 100 | ``` 101 | The argument `--target_class` specifies the number corresponding to the target label. To generate universal attacks we use batches of 30 images resampled every 10000 queries. 102 | 103 | ## Visualizing resulting images 104 | We provide a script `vis_images.py` to visualize the images produced by the attacks. 
To use it please run 105 | 106 | ```python vis_images --path_data=/path/to/saved/results``` 107 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | 6 | import torchvision.datasets as datasets 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from torchvision import models as torch_models 10 | 11 | import sys 12 | import time 13 | from datetime import datetime 14 | 15 | from utils import SingleChannelModel 16 | 17 | model_class_dict = {'pt_vgg': torch_models.vgg16_bn, 18 | 'pt_resnet': torch_models.resnet50, 19 | } 20 | 21 | class PretrainedModel(): 22 | def __init__(self, modelname): 23 | model_pt = model_class_dict[modelname](pretrained=True) 24 | #model.eval() 25 | self.model = nn.DataParallel(model_pt.cuda()) 26 | self.model.eval() 27 | self.mu = torch.Tensor([0.485, 0.456, 0.406]).float().view(1, 3, 1, 1).cuda() 28 | self.sigma = torch.Tensor([0.229, 0.224, 0.225]).float().view(1, 3, 1, 1).cuda() 29 | 30 | def predict(self, x): 31 | out = (x - self.mu) / self.sigma 32 | return self.model(out) 33 | 34 | def forward(self, x): 35 | out = (x - self.mu) / self.sigma 36 | return self.model(out) 37 | 38 | def __call__(self, x): 39 | return self.predict(x) 40 | 41 | def random_target_classes(y_pred, n_classes): 42 | y = torch.zeros_like(y_pred) 43 | for counter in range(y_pred.shape[0]): 44 | l = list(range(n_classes)) 45 | l.remove(y_pred[counter]) 46 | t = torch.randint(0, len(l), size=[1]) 47 | y[counter] = l[t] + 0 48 | 49 | return y.long() 50 | 51 | if __name__ == '__main__': 52 | 53 | parser = argparse.ArgumentParser() 54 | 55 | parser.add_argument('--dataset', type=str, default='ImageNet') 56 | parser.add_argument('--data_path', type=str) 57 | parser.add_argument('--norm', type=str, default='L0') 58 | parser.add_argument('--k', default=150., type=float) 59 | parser.add_argument('--n_restarts', type=int, default=1) 60 | parser.add_argument('--loss', type=str, default='margin') 61 | parser.add_argument('--model', default='pt_vgg', type=str) 62 | parser.add_argument('--n_ex', type=int, default=1000) 63 | parser.add_argument('--attack', type=str, default='rs_attack') 64 | parser.add_argument('--n_queries', type=int, default=1000) 65 | parser.add_argument('--targeted', action='store_true') 66 | parser.add_argument('--target_class', type=int) 67 | parser.add_argument('--seed', type=int, default=0) 68 | parser.add_argument('--constant_schedule', action='store_true') 69 | parser.add_argument('--save_dir', type=str, default='./results') 70 | parser.add_argument('--use_feature_space', action='store_true') 71 | 72 | # Sparse-RS parameter 73 | parser.add_argument('--alpha_init', type=float, default=.3) 74 | parser.add_argument('--resample_period_univ', type=int) 75 | parser.add_argument('--loc_update_period', type=int) 76 | 77 | args = parser.parse_args() 78 | 79 | if args.data_path is None: 80 | args.data_path = "/scratch/datasets/imagenet/val" 81 | 82 | args.eps = args.k + 0 83 | args.bs = args.n_ex + 0 84 | args.p_init = args.alpha_init + 0. 
85 | args.resample_loc = args.resample_period_univ 86 | args.update_loc_period = args.loc_update_period 87 | 88 | if args.dataset == 'ImageNet': 89 | # load pretrained model 90 | model = PretrainedModel(args.model) 91 | assert not model.model.training 92 | print(model.model.training) 93 | 94 | # load data 95 | IMAGENET_SL = 224 96 | IMAGENET_PATH = args.data_path 97 | imagenet = datasets.ImageFolder(IMAGENET_PATH, 98 | transforms.Compose([ 99 | transforms.Resize(IMAGENET_SL), 100 | transforms.CenterCrop(IMAGENET_SL), 101 | transforms.ToTensor() 102 | ])) 103 | torch.manual_seed(0) 104 | 105 | test_loader = data.DataLoader(imagenet, batch_size=args.bs, shuffle=True, num_workers=0) 106 | 107 | testiter = iter(test_loader) 108 | x_test, y_test = next(testiter) 109 | 110 | if args.attack in ['rs_attack']: 111 | # run Sparse-RS attacks 112 | logsdir = '{}/logs_{}_{}'.format(args.save_dir, args.attack, args.norm) 113 | savedir = '{}/{}_{}'.format(args.save_dir, args.attack, args.norm) 114 | if not os.path.exists(savedir): 115 | os.makedirs(savedir) 116 | if not os.path.exists(logsdir): 117 | os.makedirs(logsdir) 118 | 119 | if args.targeted or 'universal' in args.norm: 120 | args.loss = 'ce' 121 | data_loader = testiter if 'universal' in args.norm else None 122 | if args.use_feature_space: 123 | # reshape images to single color channel to perturb them individually 124 | assert args.norm == 'L0' 125 | bs, c, h, w = x_test.shape 126 | x_test = x_test.view(bs, 1, h, w * c) 127 | model = SingleChannelModel(model) 128 | str_space = 'feature space' 129 | else: 130 | str_space = 'pixel space' 131 | 132 | param_run = '{}_{}_{}_1_{}_nqueries_{:.0f}_pinit_{:.2f}_loss_{}_eps_{:.0f}_targeted_{}_targetclass_{}_seed_{:.0f}'.format( 133 | args.attack, args.norm, args.model, args.n_ex, args.n_queries, args.p_init, 134 | args.loss, args.eps, args.targeted, args.target_class, args.seed) 135 | if args.constant_schedule: 136 | param_run += '_constantpinit' 137 | if args.use_feature_space: 138 | param_run += '_featurespace' 139 | 140 | from rs_attacks import RSAttack 141 | adversary = RSAttack(model, norm=args.norm, eps=int(args.eps), verbose=True, n_queries=args.n_queries, 142 | p_init=args.p_init, log_path='{}/log_run_{}_{}.txt'.format(logsdir, str(datetime.now())[:-7], param_run), 143 | loss=args.loss, targeted=args.targeted, seed=args.seed, constant_schedule=args.constant_schedule, 144 | data_loader=data_loader, resample_loc=args.resample_loc) 145 | 146 | # set target classes 147 | if args.targeted and 'universal' in args.norm: 148 | if args.target_class is None: 149 | y_test = torch.ones_like(y_test) * torch.randint(1000, size=[1]).to(y_test.device) 150 | else: 151 | y_test = torch.ones_like(y_test) * args.target_class 152 | print('target labels', y_test) 153 | 154 | elif args.targeted: 155 | y_test = random_target_classes(y_test, 1000) 156 | print('target labels', y_test) 157 | 158 | bs = min(args.bs, 500) 159 | assert args.n_ex % args.bs == 0 160 | adv_complete = x_test.clone() 161 | qr_complete = torch.zeros([x_test.shape[0]]).cpu() 162 | pred = torch.zeros([0]).float().cpu() 163 | with torch.no_grad(): 164 | # find points originally correctly classified 165 | for counter in range(x_test.shape[0] // bs): 166 | x_curr = x_test[counter * bs:(counter + 1) * bs].cuda() 167 | y_curr = y_test[counter * bs:(counter + 1) * bs].cuda() 168 | output = model(x_curr) 169 | if not args.targeted: 170 | pred = torch.cat((pred, (output.max(1)[1] == y_curr).float().cpu()), dim=0) 171 | else: 172 | pred = torch.cat((pred, 
(output.max(1)[1] != y_curr).float().cpu()), dim=0) 173 | 174 | adversary.logger.log('clean accuracy {:.2%}'.format(pred.mean())) 175 | 176 | n_batches = pred.sum() // bs + 1 if pred.sum() % bs != 0 else pred.sum() // bs 177 | n_batches = n_batches.long().item() 178 | ind_to_fool = (pred == 1).nonzero().squeeze() 179 | 180 | # run the attack 181 | pred_adv = pred.clone() 182 | for counter in range(n_batches): 183 | x_curr = x_test[ind_to_fool[counter * bs:(counter + 1) * bs]].cuda() 184 | y_curr = y_test[ind_to_fool[counter * bs:(counter + 1) * bs]].cuda() 185 | qr_curr, adv = adversary.perturb(x_curr, y_curr) 186 | 187 | output = model(adv.cuda()) 188 | if not args.targeted: 189 | acc_curr = (output.max(1)[1] == y_curr).float().cpu() 190 | else: 191 | acc_curr = (output.max(1)[1] != y_curr).float().cpu() 192 | pred_adv[ind_to_fool[counter * bs:(counter + 1) * bs]] = acc_curr.clone() 193 | adv_complete[ind_to_fool[counter * bs:(counter + 1) * bs]] = adv.cpu().clone() 194 | qr_complete[ind_to_fool[counter * bs:(counter + 1) * bs]] = qr_curr.cpu().clone() 195 | 196 | print('batch {}/{} - {:.0f} of {} successfully perturbed'.format( 197 | counter + 1, n_batches, x_curr.shape[0] - acc_curr.sum(), x_curr.shape[0])) 198 | 199 | adversary.logger.log('robust accuracy {:.2%}'.format(pred_adv.float().mean())) 200 | 201 | # check robust accuracy and other statistics 202 | acc = 0. 203 | for counter in range(x_test.shape[0] // bs): 204 | x_curr = adv_complete[counter * bs:(counter + 1) * bs].cuda() 205 | y_curr = y_test[counter * bs:(counter + 1) * bs].cuda() 206 | output = model(x_curr) 207 | if not args.targeted: 208 | acc += (output.max(1)[1] == y_curr).float().sum().item() 209 | else: 210 | acc += (output.max(1)[1] != y_curr).float().sum().item() 211 | 212 | adversary.logger.log('robust accuracy {:.2%}'.format(acc / args.n_ex)) 213 | 214 | res = (adv_complete - x_test != 0.).max(dim=1)[0].sum(dim=(1, 2)) 215 | adversary.logger.log('max L0 perturbation ({}) {:.0f} - nan in img {} - max img {:.5f} - min img {:.5f}'.format( 216 | str_space, res.max(), (adv_complete != adv_complete).sum(), adv_complete.max(), adv_complete.min())) 217 | 218 | ind_corrcl = pred == 1. 219 | ind_succ = (pred_adv == 0.) * (pred == 1.) 
220 | 221 | str_stats = 'success rate={:.0f}/{:.0f} ({:.2%}) \n'.format( 222 | pred.sum() - pred_adv.sum(), pred.sum(), (pred.sum() - pred_adv.sum()).float() / pred.sum()) +\ 223 | '[successful points] avg # queries {:.1f} - med # queries {:.1f}\n'.format( 224 | qr_complete[ind_succ].float().mean(), torch.median(qr_complete[ind_succ].float())) 225 | qr_complete[~ind_succ] = args.n_queries + 0 226 | str_stats += '[correctly classified points] avg # queries {:.1f} - med # queries {:.1f}\n'.format( 227 | qr_complete[ind_corrcl].float().mean(), torch.median(qr_complete[ind_corrcl].float())) 228 | adversary.logger.log(str_stats) 229 | 230 | # save results depending on the threat model 231 | if args.norm in ['L0', 'patches', 'frames']: 232 | if args.use_feature_space: 233 | # reshape perturbed images to original rgb format 234 | bs, _, h, w = adv_complete.shape 235 | adv_complete = adv_complete.view(bs, 3, h, w // 3) 236 | torch.save({'adv': adv_complete, 'qr': qr_complete}, 237 | '{}/{}.pth'.format(savedir, param_run)) 238 | 239 | elif args.norm in ['patches_universal']: 240 | # extract and save patch 241 | ind = (res > 0).nonzero().squeeze()[0] 242 | ind_patch = (((adv_complete[ind] - x_test[ind]).abs() > 0).max(0)[0] > 0).nonzero().squeeze() 243 | t = [ind_patch[:, 0].min().item(), ind_patch[:, 0].max().item(), ind_patch[:, 1].min().item(), ind_patch[:, 1].max().item()] 244 | loc = torch.tensor([t[0], t[2]]) 245 | s = t[1] - t[0] + 1 246 | patch = adv_complete[ind, :, loc[0]:loc[0] + s, loc[1]:loc[1] + s].unsqueeze(0) 247 | 248 | torch.save({'adv': adv_complete, 'patch': patch}, 249 | '{}/{}.pth'.format(savedir, param_run)) 250 | 251 | elif args.norm in ['frames_universal']: 252 | # extract and save frame and indeces of the perturbed pixels 253 | # to easily apply the frame to new images 254 | ind_img = (res > 0).nonzero().squeeze()[0] 255 | mask = torch.zeros(x_test.shape[-2:]) 256 | s = int(args.eps) 257 | mask[:s] = 1. 258 | mask[-s:] = 1. 259 | mask[:, :s] = 1. 260 | mask[:, -s:] = 1. 
261 | ind = (mask == 1.).nonzero().squeeze() 262 | frame = adv_complete[ind_img, :, ind[:, 0], ind[:, 1]] 263 | 264 | torch.save({'adv': adv_complete, 'frame': frame, 'ind': ind}, 265 | '{}/{}.pth'.format(savedir, param_run)) 266 | 267 | -------------------------------------------------------------------------------- /img/algorithm_sparse_rs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/algorithm_sparse_rs.png -------------------------------------------------------------------------------- /img/frames_adversarial_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/frames_adversarial_examples.png -------------------------------------------------------------------------------- /img/illustrations_figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/illustrations_figure1.png -------------------------------------------------------------------------------- /img/l0_adversarial_examples_targeted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/l0_adversarial_examples_targeted.png -------------------------------------------------------------------------------- /img/l0_adversarial_examples_untargeted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/l0_adversarial_examples_untargeted.png -------------------------------------------------------------------------------- /img/patches_adversarial_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/patches_adversarial_examples.png -------------------------------------------------------------------------------- /img/table_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/table_frames.png -------------------------------------------------------------------------------- /img/table_l0_bb_wb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/table_l0_bb_wb.png -------------------------------------------------------------------------------- /img/table_patches.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/table_patches.png -------------------------------------------------------------------------------- /img/universal_patches_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/universal_patches_frames.png -------------------------------------------------------------------------------- /rs_attacks.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import torch 14 | import time 15 | import math 16 | import torch.nn.functional as F 17 | 18 | import numpy as np 19 | import copy 20 | import sys 21 | from utils import Logger 22 | import os 23 | 24 | 25 | class RSAttack(): 26 | """ 27 | Sparse-RS attacks 28 | 29 | :param predict: forward pass function 30 | :param norm: type of the attack 31 | :param n_restarts: number of random restarts 32 | :param n_queries: max number of queries (each restart) 33 | :param eps: bound on the sparsity of perturbations 34 | :param seed: random seed for the starting point 35 | :param alpha_init: parameter to control alphai 36 | :param loss: loss function optimized ('margin', 'ce' supported) 37 | :param resc_schedule adapt schedule of alphai to n_queries 38 | :param device specify device to use 39 | :param log_path path to save logfile.txt 40 | :param constant_schedule use constant alphai 41 | :param targeted perform targeted attacks 42 | :param init_patches initialization for patches 43 | :param resample_loc period in queries of resampling images and 44 | locations for universal attacks 45 | :param data_loader loader to get new images for resampling 46 | :param update_loc_period period in queries of updates of the location 47 | for image-specific patches 48 | """ 49 | 50 | def __init__( 51 | self, 52 | predict, 53 | norm='L0', 54 | n_queries=5000, 55 | eps=None, 56 | p_init=.8, 57 | n_restarts=1, 58 | seed=0, 59 | verbose=True, 60 | targeted=False, 61 | loss='margin', 62 | resc_schedule=True, 63 | device=None, 64 | log_path=None, 65 | constant_schedule=False, 66 | init_patches='random_squares', 67 | resample_loc=None, 68 | data_loader=None, 69 | update_loc_period=None): 70 | """ 71 | Sparse-RS implementation in PyTorch 72 | """ 73 | 74 | self.predict = predict 75 | self.norm = norm 76 | self.n_queries = n_queries 77 | self.eps = eps 78 | self.p_init = p_init 79 | self.n_restarts = n_restarts 80 | self.seed = seed 81 | self.verbose = verbose 82 | self.targeted = targeted 83 | self.loss = loss 84 | self.rescale_schedule = resc_schedule 85 | self.device = device 86 | self.logger = Logger(log_path) 87 | self.constant_schedule = constant_schedule 88 | self.init_patches = init_patches 89 | self.resample_loc = n_queries // 10 if resample_loc is None else resample_loc 90 | self.data_loader = data_loader 91 | self.update_loc_period = update_loc_period if not update_loc_period is None else 4 if not targeted else 10 92 | 93 | 94 | def margin_and_loss(self, x, y): 95 | """ 96 | :param y: correct labels if untargeted else target labels 97 | """ 98 | 99 | logits = self.predict(x) 100 | xent = F.cross_entropy(logits, y, reduction='none') 101 | u = torch.arange(x.shape[0]) 102 | y_corr = logits[u, y].clone() 103 | logits[u, y] = -float('inf') 104 | y_others = logits.max(dim=-1)[0] 105 | 106 | if not self.targeted: 107 | if self.loss == 'ce': 108 | return y_corr - y_others, -1. 
* xent 109 | elif self.loss == 'margin': 110 | return y_corr - y_others, y_corr - y_others 111 | else: 112 | return y_others - y_corr, xent 113 | 114 | def init_hyperparam(self, x): 115 | assert self.norm in ['L0', 'patches', 'frames', 116 | 'patches_universal', 'frames_universal'] 117 | assert not self.eps is None 118 | assert self.loss in ['ce', 'margin'] 119 | 120 | if self.device is None: 121 | self.device = x.device 122 | self.orig_dim = list(x.shape[1:]) 123 | self.ndims = len(self.orig_dim) 124 | if self.seed is None: 125 | self.seed = time.time() 126 | if self.targeted: 127 | self.loss = 'ce' 128 | 129 | def random_target_classes(self, y_pred, n_classes): 130 | y = torch.zeros_like(y_pred) 131 | for counter in range(y_pred.shape[0]): 132 | l = list(range(n_classes)) 133 | l.remove(y_pred[counter]) 134 | t = self.random_int(0, len(l)) 135 | y[counter] = l[t] 136 | 137 | return y.long().to(self.device) 138 | 139 | def check_shape(self, x): 140 | return x if len(x.shape) == (self.ndims + 1) else x.unsqueeze(0) 141 | 142 | def random_choice(self, shape): 143 | t = 2 * torch.rand(shape).to(self.device) - 1 144 | return torch.sign(t) 145 | 146 | def random_int(self, low=0, high=1, shape=[1]): 147 | t = low + (high - low) * torch.rand(shape).to(self.device) 148 | return t.long() 149 | 150 | def normalize(self, x): 151 | if self.norm == 'Linf': 152 | t = x.abs().view(x.shape[0], -1).max(1)[0] 153 | return x / (t.view(-1, *([1] * self.ndims)) + 1e-12) 154 | 155 | elif self.norm == 'L2': 156 | t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt() 157 | return x / (t.view(-1, *([1] * self.ndims)) + 1e-12) 158 | 159 | def lp_norm(self, x): 160 | if self.norm == 'L2': 161 | t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt() 162 | return t.view(-1, *([1] * self.ndims)) 163 | 164 | def p_selection(self, it): 165 | """ schedule to decrease the parameter p """ 166 | 167 | if self.rescale_schedule: 168 | it = int(it / self.n_queries * 10000) 169 | 170 | if 'patches' in self.norm: 171 | if 10 < it <= 50: 172 | p = self.p_init / 2 173 | elif 50 < it <= 200: 174 | p = self.p_init / 4 175 | elif 200 < it <= 500: 176 | p = self.p_init / 8 177 | elif 500 < it <= 1000: 178 | p = self.p_init / 16 179 | elif 1000 < it <= 2000: 180 | p = self.p_init / 32 181 | elif 2000 < it <= 4000: 182 | p = self.p_init / 64 183 | elif 4000 < it <= 6000: 184 | p = self.p_init / 128 185 | elif 6000 < it <= 8000: 186 | p = self.p_init / 256 187 | elif 8000 < it: 188 | p = self.p_init / 512 189 | else: 190 | p = self.p_init 191 | 192 | elif 'frames' in self.norm: 193 | if not 'universal' in self.norm : 194 | tot_qr = 10000 if self.rescale_schedule else self.n_queries 195 | p = max((float(tot_qr - it) / tot_qr - .5) * self.p_init * self.eps ** 2, 0.) 196 | return 3. * math.ceil(p) 197 | 198 | else: 199 | assert self.rescale_schedule 200 | its = [200, 600, 1200, 1800, 2500, 10000, 100000] 201 | resc_factors = [1., .8, .6, .4, .2, .1, 0.] 
202 | c = 0 203 | while it >= its[c]: 204 | c += 1 205 | return resc_factors[c] * self.p_init 206 | 207 | elif 'L0' in self.norm: 208 | if 0 < it <= 50: 209 | p = self.p_init / 2 210 | elif 50 < it <= 200: 211 | p = self.p_init / 4 212 | elif 200 < it <= 500: 213 | p = self.p_init / 5 214 | elif 500 < it <= 1000: 215 | p = self.p_init / 6 216 | elif 1000 < it <= 2000: 217 | p = self.p_init / 8 218 | elif 2000 < it <= 4000: 219 | p = self.p_init / 10 220 | elif 4000 < it <= 6000: 221 | p = self.p_init / 12 222 | elif 6000 < it <= 8000: 223 | p = self.p_init / 15 224 | elif 8000 < it: 225 | p = self.p_init / 20 226 | else: 227 | p = self.p_init 228 | 229 | if self.constant_schedule: 230 | p = self.p_init / 2 231 | 232 | return p 233 | 234 | def sh_selection(self, it): 235 | """ schedule to decrease the parameter p """ 236 | 237 | t = max((float(self.n_queries - it) / self.n_queries - .0) ** 1., 0) * .75 238 | 239 | return t 240 | 241 | def get_init_patch(self, c, s, n_iter=1000): 242 | if self.init_patches == 'stripes': 243 | patch_univ = torch.zeros([1, c, s, s]).to(self.device) + self.random_choice( 244 | [1, c, 1, s]).clamp(0., 1.) 245 | elif self.init_patches == 'uniform': 246 | patch_univ = torch.zeros([1, c, s, s]).to(self.device) + self.random_choice( 247 | [1, c, 1, 1]).clamp(0., 1.) 248 | elif self.init_patches == 'random': 249 | patch_univ = self.random_choice([1, c, s, s]).clamp(0., 1.) 250 | elif self.init_patches == 'random_squares': 251 | patch_univ = torch.zeros([1, c, s, s]).to(self.device) 252 | for _ in range(n_iter): 253 | size_init = torch.randint(low=1, high=math.ceil(s ** .5), size=[1]).item() 254 | loc_init = torch.randint(s - size_init + 1, size=[2]) 255 | patch_univ[0, :, loc_init[0]:loc_init[0] + size_init, loc_init[1]:loc_init[1] + size_init] = 0. 256 | patch_univ[0, :, loc_init[0]:loc_init[0] + size_init, loc_init[1]:loc_init[1] + size_init 257 | ] += self.random_choice([c, 1, 1]).clamp(0., 1.) 258 | elif self.init_patches == 'sh': 259 | patch_univ = torch.ones([1, c, s, s]).to(self.device) 260 | 261 | return patch_univ.clamp(0., 1.) 262 | 263 | def attack_single_run(self, x, y): 264 | with torch.no_grad(): 265 | adv = x.clone() 266 | c, h, w = x.shape[1:] 267 | n_features = c * h * w 268 | n_ex_total = x.shape[0] 269 | 270 | if self.norm == 'L0': 271 | eps = self.eps 272 | 273 | x_best = x.clone() 274 | n_pixels = h * w 275 | b_all, be_all = torch.zeros([x.shape[0], eps]).long(), torch.zeros([x.shape[0], n_pixels - eps]).long() 276 | for img in range(x.shape[0]): 277 | ind_all = torch.randperm(n_pixels) 278 | ind_p = ind_all[:eps] 279 | ind_np = ind_all[eps:] 280 | x_best[img, :, ind_p // w, ind_p % w] = self.random_choice([c, eps]).clamp(0., 1.) 
281 | b_all[img] = ind_p.clone() 282 | be_all[img] = ind_np.clone() 283 | 284 | margin_min, loss_min = self.margin_and_loss(x_best, y) 285 | n_queries = torch.ones(x.shape[0]).to(self.device) 286 | 287 | for it in range(1, self.n_queries): 288 | # check points still to fool 289 | idx_to_fool = (margin_min > 0.).nonzero().squeeze() 290 | x_curr = self.check_shape(x[idx_to_fool]) 291 | x_best_curr = self.check_shape(x_best[idx_to_fool]) 292 | y_curr = y[idx_to_fool] 293 | margin_min_curr = margin_min[idx_to_fool] 294 | loss_min_curr = loss_min[idx_to_fool] 295 | b_curr, be_curr = b_all[idx_to_fool], be_all[idx_to_fool] 296 | if len(y_curr.shape) == 0: 297 | y_curr.unsqueeze_(0) 298 | margin_min_curr.unsqueeze_(0) 299 | loss_min_curr.unsqueeze_(0) 300 | b_curr.unsqueeze_(0) 301 | be_curr.unsqueeze_(0) 302 | idx_to_fool.unsqueeze_(0) 303 | 304 | # build new candidate 305 | x_new = x_best_curr.clone() 306 | eps_it = max(int(self.p_selection(it) * eps), 1) 307 | ind_p = torch.randperm(eps)[:eps_it] 308 | ind_np = torch.randperm(n_pixels - eps)[:eps_it] 309 | 310 | for img in range(x_new.shape[0]): 311 | p_set = b_curr[img, ind_p] 312 | np_set = be_curr[img, ind_np] 313 | x_new[img, :, p_set // w, p_set % w] = x_curr[img, :, p_set // w, p_set % w].clone() 314 | if eps_it > 1: 315 | x_new[img, :, np_set // w, np_set % w] = self.random_choice([c, eps_it]).clamp(0., 1.) 316 | else: 317 | # if update is 1x1 make sure the sampled color is different from the current one 318 | old_clr = x_new[img, :, np_set // w, np_set % w].clone() 319 | assert old_clr.shape == (c, 1), print(old_clr) 320 | new_clr = old_clr.clone() 321 | while (new_clr == old_clr).all().item(): 322 | new_clr = self.random_choice([c, 1]).clone().clamp(0., 1.) 323 | x_new[img, :, np_set // w, np_set % w] = new_clr.clone() 324 | 325 | # compute loss of the new candidates 326 | margin, loss = self.margin_and_loss(x_new, y_curr) 327 | n_queries[idx_to_fool] += 1 328 | 329 | # update best solution 330 | idx_improved = (loss < loss_min_curr).float() 331 | idx_to_update = (idx_improved > 0.).nonzero().squeeze() 332 | loss_min[idx_to_fool[idx_to_update]] = loss[idx_to_update] 333 | 334 | idx_miscl = (margin < -1e-6).float() 335 | idx_improved = torch.max(idx_improved, idx_miscl) 336 | nimpr = idx_improved.sum().item() 337 | if nimpr > 0.: 338 | idx_improved = (idx_improved.view(-1) > 0).nonzero().squeeze() 339 | margin_min[idx_to_fool[idx_improved]] = margin[idx_improved].clone() 340 | x_best[idx_to_fool[idx_improved]] = x_new[idx_improved].clone() 341 | t = b_curr[idx_improved].clone() 342 | te = be_curr[idx_improved].clone() 343 | 344 | if nimpr > 1: 345 | t[:, ind_p] = be_curr[idx_improved][:, ind_np] + 0 346 | te[:, ind_np] = b_curr[idx_improved][:, ind_p] + 0 347 | else: 348 | t[ind_p] = be_curr[idx_improved][ind_np] + 0 349 | te[ind_np] = b_curr[idx_improved][ind_p] + 0 350 | 351 | b_all[idx_to_fool[idx_improved]] = t.clone() 352 | be_all[idx_to_fool[idx_improved]] = te.clone() 353 | 354 | # log results current iteration 355 | ind_succ = (margin_min <= 0.).nonzero().squeeze() 356 | if self.verbose and ind_succ.numel() != 0: 357 | self.logger.log(' '.join(['{}'.format(it + 1), 358 | '- success rate={}/{} ({:.2%})'.format( 359 | ind_succ.numel(), n_ex_total, 360 | float(ind_succ.numel()) / n_ex_total), 361 | '- avg # queries={:.1f}'.format( 362 | n_queries[ind_succ].mean().item()), 363 | '- med # queries={:.1f}'.format( 364 | n_queries[ind_succ].median().item()), 365 | '- loss={:.3f}'.format(loss_min.mean()), 366 | '- max 
pert={:.0f}'.format(((x_new - x_curr).abs() > 0 367 | ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()), 368 | '- epsit={:.0f}'.format(eps_it), 369 | ])) 370 | 371 | if ind_succ.numel() == n_ex_total: 372 | break 373 | 374 | elif self.norm == 'patches': 375 | ''' assumes square images and patches ''' 376 | 377 | s = int(math.ceil(self.eps ** .5)) 378 | x_best = x.clone() 379 | x_new = x.clone() 380 | loc = torch.randint(h - s, size=[x.shape[0], 2]) 381 | patches_coll = torch.zeros([x.shape[0], c, s, s]).to(self.device) 382 | assert abs(self.update_loc_period) > 1 383 | loc_t = abs(self.update_loc_period) 384 | 385 | # set when to start single channel updates 386 | it_start_cu = None 387 | for it in range(0, self.n_queries): 388 | s_it = int(max(self.p_selection(it) ** .5 * s, 1)) 389 | if s_it == 1: 390 | break 391 | it_start_cu = it + (self.n_queries - it) // 2 392 | if self.verbose: 393 | self.logger.log('starting single channel updates at query {}'.format( 394 | it_start_cu)) 395 | 396 | # initialize patches 397 | if self.verbose: 398 | self.logger.log('using {} initialization'.format(self.init_patches)) 399 | for counter in range(x.shape[0]): 400 | patches_coll[counter] += self.get_init_patch(c, s).squeeze().clamp(0., 1.) 401 | x_new[counter, :, loc[counter, 0]:loc[counter, 0] + s, 402 | loc[counter, 1]:loc[counter, 1] + s] = patches_coll[counter].clone() 403 | 404 | margin_min, loss_min = self.margin_and_loss(x_new, y) 405 | n_queries = torch.ones(x.shape[0]).to(self.device) 406 | 407 | for it in range(1, self.n_queries): 408 | # check points still to fool 409 | idx_to_fool = (margin_min > -1e-6).nonzero().squeeze() 410 | x_curr = self.check_shape(x[idx_to_fool]) 411 | patches_curr = self.check_shape(patches_coll[idx_to_fool]) 412 | y_curr = y[idx_to_fool] 413 | margin_min_curr = margin_min[idx_to_fool] 414 | loss_min_curr = loss_min[idx_to_fool] 415 | loc_curr = loc[idx_to_fool] 416 | if len(y_curr.shape) == 0: 417 | y_curr.unsqueeze_(0) 418 | margin_min_curr.unsqueeze_(0) 419 | loss_min_curr.unsqueeze_(0) 420 | 421 | loc_curr.unsqueeze_(0) 422 | idx_to_fool.unsqueeze_(0) 423 | 424 | # sample update 425 | s_it = int(max(self.p_selection(it) ** .5 * s, 1)) 426 | p_it = torch.randint(s - s_it + 1, size=[2]) 427 | sh_it = int(max(self.sh_selection(it) * h, 0)) 428 | patches_new = patches_curr.clone() 429 | x_new = x_curr.clone() 430 | loc_new = loc_curr.clone() 431 | update_loc = int((it % loc_t == 0) and (sh_it > 0)) 432 | update_patch = 1. - update_loc 433 | if self.update_loc_period < 0 and sh_it > 0: 434 | update_loc = 1. - update_loc 435 | update_patch = 1. - update_patch 436 | for counter in range(x_curr.shape[0]): 437 | if update_patch == 1.: 438 | # update patch 439 | if it < it_start_cu: 440 | if s_it > 1: 441 | patches_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] += self.random_choice([c, 1, 1]) 442 | else: 443 | # make sure to sample a different color 444 | old_clr = patches_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone() 445 | new_clr = old_clr.clone() 446 | while (new_clr == old_clr).all().item(): 447 | new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.) 448 | patches_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone() 449 | else: 450 | assert s_it == 1 451 | assert it >= it_start_cu 452 | # single channel updates 453 | new_ch = self.random_int(low=0, high=3, shape=[1]) 454 | patches_new[counter, new_ch, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = 1. 
- patches_new[ 455 | counter, new_ch, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] 456 | 457 | patches_new[counter].clamp_(0., 1.) 458 | if update_loc == 1: 459 | # update location 460 | loc_new[counter] += (torch.randint(low=-sh_it, high=sh_it + 1, size=[2])) 461 | loc_new[counter].clamp_(0, h - s) 462 | 463 | x_new[counter, :, loc_new[counter, 0]:loc_new[counter, 0] + s, 464 | loc_new[counter, 1]:loc_new[counter, 1] + s] = patches_new[counter].clone() 465 | 466 | # check loss of new candidate 467 | margin, loss = self.margin_and_loss(x_new, y_curr) 468 | n_queries[idx_to_fool]+= 1 469 | 470 | # update best solution 471 | idx_improved = (loss < loss_min_curr).float() 472 | idx_to_update = (idx_improved > 0.).nonzero().squeeze() 473 | loss_min[idx_to_fool[idx_to_update]] = loss[idx_to_update] 474 | 475 | idx_miscl = (margin < -1e-6).float() 476 | idx_improved = torch.max(idx_improved, idx_miscl) 477 | nimpr = idx_improved.sum().item() 478 | if nimpr > 0.: 479 | idx_improved = (idx_improved.view(-1) > 0).nonzero().squeeze() 480 | margin_min[idx_to_fool[idx_improved]] = margin[idx_improved].clone() 481 | patches_coll[idx_to_fool[idx_improved]] = patches_new[idx_improved].clone() 482 | loc[idx_to_fool[idx_improved]] = loc_new[idx_improved].clone() 483 | 484 | # log results current iteration 485 | ind_succ = (margin_min <= 0.).nonzero().squeeze() 486 | if self.verbose and ind_succ.numel() != 0: 487 | self.logger.log(' '.join(['{}'.format(it + 1), 488 | '- success rate={}/{} ({:.2%})'.format( 489 | ind_succ.numel(), n_ex_total, 490 | float(ind_succ.numel()) / n_ex_total), 491 | '- avg # queries={:.1f}'.format( 492 | n_queries[ind_succ].mean().item()), 493 | '- med # queries={:.1f}'.format( 494 | n_queries[ind_succ].median().item()), 495 | '- loss={:.3f}'.format(loss_min.mean()), 496 | '- max pert={:.0f}'.format(((x_new - x_curr).abs() > 0 497 | ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()), 498 | #'- sit={:.0f} - sh={:.0f}'.format(s_it, sh_it), 499 | '{}'.format(' - loc' if update_loc == 1. 
else ''), 500 | ])) 501 | 502 | if ind_succ.numel() == n_ex_total: 503 | break 504 | 505 | # apply patches 506 | for counter in range(x.shape[0]): 507 | x_best[counter, :, loc[counter, 0]:loc[counter, 0] + s, 508 | loc[counter, 1]:loc[counter, 1] + s] = patches_coll[counter].clone() 509 | 510 | elif self.norm == 'patches_universal': 511 | ''' assumes square images and patches ''' 512 | 513 | s = int(math.ceil(self.eps ** .5)) 514 | x_best = x.clone() 515 | self.n_imgs = x.shape[0] 516 | x_new = x.clone() 517 | loc = torch.randint(h - s + 1, size=[x.shape[0], 2]) 518 | 519 | # set when to start single channel updates 520 | it_start_cu = None 521 | for it in range(0, self.n_queries): 522 | s_it = int(max(self.p_selection(it) ** .5 * s, 1)) 523 | if s_it == 1: 524 | break 525 | it_start_cu = it + (self.n_queries - it) // 2 526 | if self.verbose: 527 | self.logger.log('starting single channel updates at query {}'.format( 528 | it_start_cu)) 529 | 530 | # initialize patch 531 | if self.verbose: 532 | self.logger.log('using {} initialization'.format(self.init_patches)) 533 | patch_univ = self.get_init_patch(c, s) 534 | it_init = 0 535 | 536 | loss_batch = float(1e10) 537 | n_succs = 0 538 | n_iter = self.n_queries 539 | 540 | # init update batch 541 | assert not self.data_loader is None 542 | assert not self.resample_loc is None 543 | assert self.targeted 544 | new_train_imgs = [] 545 | n_newimgs = self.n_imgs + 0 546 | n_imgsneeded = math.ceil(self.n_queries / self.resample_loc) * n_newimgs 547 | tot_imgs = 0 548 | if self.verbose: 549 | self.logger.log('imgs updated={}, imgs needed={}'.format( 550 | n_newimgs, n_imgsneeded)) 551 | while tot_imgs < min(100000, n_imgsneeded): 552 | x_toupdatetrain, _ = next(self.data_loader) 553 | new_train_imgs.append(x_toupdatetrain) 554 | tot_imgs += x_toupdatetrain.shape[0] 555 | newimgstoadd = torch.cat(new_train_imgs, axis=0) 556 | counter_resamplingimgs = 0 557 | 558 | for it in range(it_init, n_iter): 559 | # sample size and location of the update 560 | s_it = int(max(self.p_selection(it) ** .5 * s, 1)) 561 | p_it = torch.randint(s - s_it + 1, size=[2]) 562 | 563 | patch_new = patch_univ.clone() 564 | 565 | if s_it > 1: 566 | patch_new[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] += self.random_choice([c, 1, 1]) 567 | else: 568 | old_clr = patch_new[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone() 569 | new_clr = old_clr.clone() 570 | if it < it_start_cu: 571 | while (new_clr == old_clr).all().item(): 572 | new_clr = self.random_choice(new_clr).clone().clamp(0., 1.) 573 | else: 574 | # single channel update 575 | new_ch = self.random_int(low=0, high=3, shape=[1]) 576 | new_clr[new_ch] = 1. - new_clr[new_ch] 577 | 578 | patch_new[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone() 579 | 580 | patch_new.clamp_(0., 1.) 581 | 582 | # compute loss for new candidate 583 | x_new = x.clone() 584 | 585 | for counter in range(x.shape[0]): 586 | loc_new = loc[counter] 587 | x_new[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] = 0. 588 | x_new[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] += patch_new[0] 589 | 590 | margin_run, loss_run = self.margin_and_loss(x_new, y) 591 | if self.loss == 'ce': 592 | loss_run += x_new.shape[0] 593 | loss_new = loss_run.sum() 594 | n_succs_new = (margin_run < -1e-6).sum().item() 595 | 596 | # accept candidate if loss improves 597 | if loss_new < loss_batch: 598 | is_accepted = True 599 | loss_batch = loss_new + 0. 
600 | patch_univ = patch_new.clone() 601 | n_succs = n_succs_new + 0 602 | else: 603 | is_accepted = False 604 | 605 | # sample new locations and images 606 | if (it + 1) % self.resample_loc == 0: 607 | newimgstoadd_it = newimgstoadd[counter_resamplingimgs * n_newimgs:( 608 | counter_resamplingimgs + 1) * n_newimgs].clone().cuda() 609 | new_batch = [x[n_newimgs:].clone(), newimgstoadd_it.clone()] 610 | x = torch.cat(new_batch, dim=0) 611 | assert x.shape[0] == self.n_imgs 612 | 613 | loc = torch.randint(h - s + 1, size=[self.n_imgs, 2]) 614 | assert loc.shape == (self.n_imgs, 2) 615 | 616 | loss_batch = loss_batch * 0. + 1e6 617 | counter_resamplingimgs += 1 618 | 619 | # logging current iteration 620 | if self.verbose: 621 | self.logger.log(' '.join(['{}'.format(it + 1), 622 | '- success rate={}/{} ({:.2%})'.format( 623 | n_succs, n_ex_total, 624 | float(n_succs) / n_ex_total), 625 | '- loss={:.3f}'.format(loss_batch), 626 | '- max pert={:.0f}'.format(((x_new - x).abs() > 0 627 | ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()), 628 | ])) 629 | 630 | # apply patches on the initial images 631 | for counter in range(x_best.shape[0]): 632 | loc_new = loc[counter] 633 | x_best[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] = 0. 634 | x_best[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] += patch_univ[0] 635 | 636 | elif self.norm == 'frames': 637 | # set width and indices of frames 638 | mask = torch.zeros(x.shape[-2:]) 639 | s = self.eps + 0 640 | mask[:s] = 1. 641 | mask[-s:] = 1. 642 | mask[:, :s] = 1. 643 | mask[:, -s:] = 1. 644 | ind = (mask == 1.).nonzero().squeeze() 645 | eps = ind.shape[0] 646 | x_best = x.clone() 647 | x_new = x.clone() 648 | mask = mask.view(1, 1, h, w).to(self.device) 649 | mask_frame = torch.ones([1, c, h, w], device=x.device) * mask 650 | # 651 | 652 | # set when starting single channel updates 653 | it_start_cu = None 654 | for it in range(0, self.n_queries): 655 | s_it = int(max(self.p_selection(it), 1)) 656 | if s_it == 1: 657 | break 658 | it_start_cu = it + (self.n_queries - it) // 2 659 | #it_start_cu = 10000 660 | if self.verbose: 661 | self.logger.log('starting single channel updates at query {}'.format( 662 | it_start_cu)) 663 | 664 | # initialize frames 665 | x_best[:, :, ind[:, 0], ind[:, 1]] = self.random_choice( 666 | [x.shape[0], c, eps]).clamp(0., 1.) 667 | 668 | margin_min, loss_min = self.margin_and_loss(x_best, y) 669 | n_queries = torch.ones(x.shape[0]).to(self.device) 670 | 671 | for it in range(1, self.n_queries): 672 | # check points still to fool 673 | idx_to_fool = (margin_min > -1e-6).nonzero().squeeze() 674 | x_curr = self.check_shape(x[idx_to_fool]) 675 | x_best_curr = self.check_shape(x_best[idx_to_fool]) 676 | y_curr = y[idx_to_fool] 677 | margin_min_curr = margin_min[idx_to_fool] 678 | loss_min_curr = loss_min[idx_to_fool] 679 | 680 | if len(y_curr.shape) == 0: 681 | y_curr.unsqueeze_(0) 682 | margin_min_curr.unsqueeze_(0) 683 | loss_min_curr.unsqueeze_(0) 684 | idx_to_fool.unsqueeze_(0) 685 | 686 | # sample update 687 | s_it = max(int(self.p_selection(it)), 1) 688 | ind_it = torch.randperm(eps)[0] 689 | 690 | x_new = x_best_curr.clone() 691 | if s_it > 1: 692 | dir_h = self.random_choice([1]).long().cpu() 693 | dir_w = self.random_choice([1]).long().cpu() 694 | new_clr = self.random_choice([c, 1]).clamp(0., 1.) 
695 | 696 | for counter in range(x_curr.shape[0]): 697 | if s_it > 1: 698 | for counter_h in range(s_it): 699 | for counter_w in range(s_it): 700 | x_new[counter, :, (ind[ind_it, 0] + dir_h * counter_h).clamp(0, h - 1), 701 | (ind[ind_it, 1] + dir_w * counter_w).clamp(0, w - 1)] = new_clr.clone() 702 | else: 703 | p_it = ind[ind_it].clone() 704 | old_clr = x_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone() 705 | new_clr = old_clr.clone() 706 | if it < it_start_cu: 707 | while (new_clr == old_clr).all().item(): 708 | new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.) 709 | else: 710 | # single channel update 711 | new_ch = self.random_int(low=0, high=3, shape=[1]) 712 | new_clr[new_ch] = 1. - new_clr[new_ch] 713 | x_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone() 714 | 715 | x_new.clamp_(0., 1.) 716 | x_new = (x_new - x_curr) * mask_frame + x_curr 717 | 718 | # check loss of new candidate 719 | margin, loss = self.margin_and_loss(x_new, y_curr) 720 | n_queries[idx_to_fool]+= 1 721 | 722 | # update best solution 723 | idx_improved = (loss < loss_min_curr).float() 724 | idx_to_update = (idx_improved > 0.).nonzero().squeeze() 725 | loss_min[idx_to_fool[idx_to_update]] = loss[idx_to_update] 726 | 727 | idx_miscl = (margin < -1e-6).float() 728 | idx_improved = torch.max(idx_improved, idx_miscl) 729 | nimpr = idx_improved.sum().item() 730 | if nimpr > 0.: 731 | idx_improved = (idx_improved.view(-1) > 0).nonzero().squeeze() 732 | margin_min[idx_to_fool[idx_improved]] = margin[idx_improved].clone() 733 | x_best[idx_to_fool[idx_improved]] = x_new[idx_improved].clone() 734 | 735 | # log results current iteration 736 | ind_succ = (margin_min <= 0.).nonzero().squeeze() 737 | if self.verbose and ind_succ.numel() != 0: 738 | self.logger.log(' '.join(['{}'.format(it + 1), 739 | '- success rate={}/{} ({:.2%})'.format( 740 | ind_succ.numel(), n_ex_total, 741 | float(ind_succ.numel()) / n_ex_total), 742 | '- avg # queries={:.1f}'.format( 743 | n_queries[ind_succ].mean().item()), 744 | '- med # queries={:.1f}'.format( 745 | n_queries[ind_succ].median().item()), 746 | '- loss={:.3f}'.format(loss_min.mean()), 747 | '- max pert={:.0f}'.format(((x_new - x_curr).abs() > 0 748 | ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()), 749 | #'- min pert={:.0f}'.format(((x_new - x_curr).abs() > 0 750 | #).max(1)[0].view(x_new.shape[0], -1).sum(-1).min()), 751 | #'- sit={:.0f} - indit={}'.format(s_it, ind_it.item()), 752 | ])) 753 | 754 | if ind_succ.numel() == n_ex_total: 755 | break 756 | 757 | 758 | elif self.norm == 'frames_universal': 759 | # set width and indices of frames 760 | mask = torch.zeros(x.shape[-2:]) 761 | s = self.eps + 0 762 | mask[:s] = 1. 763 | mask[-s:] = 1. 764 | mask[:, :s] = 1. 765 | mask[:, -s:] = 1. 766 | ind = (mask == 1.).nonzero().squeeze() 767 | eps = ind.shape[0] 768 | x_best = x.clone() 769 | x_new = x.clone() 770 | mask = mask.view(1, 1, h, w).to(self.device) 771 | mask_frame = torch.ones([1, c, h, w], device=x.device) * mask 772 | frame_univ = self.random_choice([1, c, eps]).clamp(0., 1.) 
773 | 774 | # set when to start single channel updates 775 | it_start_cu = None 776 | for it in range(0, self.n_queries): 777 | s_it = int(max(self.p_selection(it) * s, 1)) 778 | if s_it == 1: 779 | break 780 | it_start_cu = it + (self.n_queries - it) // 2 781 | if self.verbose: 782 | self.logger.log('starting single channel updates at query {}'.format( 783 | it_start_cu)) 784 | 785 | self.n_imgs = x.shape[0] 786 | loss_batch = float(1e10) 787 | n_queries = torch.ones_like(y).float() 788 | 789 | # init update batch 790 | assert not self.data_loader is None 791 | assert not self.resample_loc is None 792 | assert self.targeted 793 | new_train_imgs = [] 794 | n_newimgs = self.n_imgs + 0 795 | n_imgsneeded = math.ceil(self.n_queries / self.resample_loc) * n_newimgs 796 | tot_imgs = 0 797 | if self.verbose: 798 | self.logger.log('imgs updated={}, imgs needed={}'.format( 799 | n_newimgs, n_imgsneeded)) 800 | while tot_imgs < min(100000, n_imgsneeded): 801 | x_toupdatetrain, _ = next(self.data_loader) 802 | new_train_imgs.append(x_toupdatetrain) 803 | tot_imgs += x_toupdatetrain.shape[0] 804 | newimgstoadd = torch.cat(new_train_imgs, axis=0) 805 | counter_resamplingimgs = 0 806 | 807 | for it in range(self.n_queries): 808 | # sample update 809 | s_it = max(int(self.p_selection(it) * self.eps), 1) 810 | ind_it = torch.randperm(eps)[0] 811 | 812 | mask_frame[:, :, ind[:, 0], ind[:, 1]] = 0 813 | mask_frame[:, :, ind[:, 0], ind[:, 1]] += frame_univ 814 | 815 | if s_it > 1: 816 | dir_h = self.random_choice([1]).long().cpu() 817 | dir_w = self.random_choice([1]).long().cpu() 818 | new_clr = self.random_choice([c, 1]).clamp(0., 1.) 819 | 820 | for counter_h in range(s_it): 821 | for counter_w in range(s_it): 822 | mask_frame[0, :, (ind[ind_it, 0] + dir_h * counter_h).clamp(0, h - 1), 823 | (ind[ind_it, 1] + dir_w * counter_w).clamp(0, w - 1)] = new_clr.clone() 824 | else: 825 | p_it = ind[ind_it] 826 | old_clr = mask_frame[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone() 827 | new_clr = old_clr.clone() 828 | if it < it_start_cu: 829 | while (new_clr == old_clr).all().item(): 830 | new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.) 831 | else: 832 | # single channel update 833 | new_ch = self.random_int(low=0, high=3, shape=[1]) 834 | new_clr[new_ch] = 1. - new_clr[new_ch] 835 | mask_frame[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone() 836 | 837 | frame_new = mask_frame[:, :, ind[:, 0], ind[:, 1]].clone() 838 | frame_new.clamp_(0., 1.) 839 | if len(frame_new.shape) == 2: 840 | frame_new.unsqueeze_(0) 841 | 842 | x_new[:, :, ind[:, 0], ind[:, 1]] = 0. 843 | x_new[:, :, ind[:, 0], ind[:, 1]] += frame_new 844 | 845 | margin_run, loss_run = self.margin_and_loss(x_new, y) 846 | if self.loss == 'ce': 847 | loss_run += x_new.shape[0] 848 | loss_new = loss_run.sum() 849 | n_succs_new = (margin_run < -1e-6).sum().item() 850 | 851 | # accept candidate if loss improves 852 | if loss_new < loss_batch: 853 | #is_accepted = True 854 | loss_batch = loss_new + 0. 855 | frame_univ = frame_new.clone() 856 | n_succs = n_succs_new + 0 857 | 858 | # sample new images 859 | if (it + 1) % self.resample_loc == 0: 860 | newimgstoadd_it = newimgstoadd[counter_resamplingimgs * n_newimgs:( 861 | counter_resamplingimgs + 1) * n_newimgs].clone().cuda() 862 | new_batch = [x[n_newimgs:].clone(), newimgstoadd_it.clone()] 863 | x = torch.cat(new_batch, dim=0) 864 | assert x.shape[0] == self.n_imgs 865 | 866 | loss_batch = loss_batch * 0. 
+ 1e6 867 | x_new = x.clone() 868 | counter_resamplingimgs += 1 869 | 870 | # loggin current iteration 871 | if self.verbose: 872 | self.logger.log(' '.join(['{}'.format(it + 1), 873 | '- success rate={}/{} ({:.2%})'.format( 874 | n_succs, n_ex_total, 875 | float(n_succs) / n_ex_total), 876 | '- loss={:.3f}'.format(loss_batch), 877 | '- max pert={:.0f}'.format(((x_new - x).abs() > 0 878 | ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()), 879 | ])) 880 | 881 | # apply frame on initial images 882 | x_best[:, :, ind[:, 0], ind[:, 1]] = 0. 883 | x_best[:, :, ind[:, 0], ind[:, 1]] += frame_univ 884 | 885 | return n_queries, x_best 886 | 887 | def perturb(self, x, y=None): 888 | """ 889 | :param x: clean images 890 | :param y: untargeted attack -> clean labels, 891 | if None we use the predicted labels 892 | targeted attack -> target labels, if None random classes, 893 | different from the predicted ones, are sampled 894 | """ 895 | 896 | self.init_hyperparam(x) 897 | 898 | adv = x.clone() 899 | qr = torch.zeros([x.shape[0]]).to(self.device) 900 | if y is None: 901 | if not self.targeted: 902 | with torch.no_grad(): 903 | output = self.predict(x) 904 | y_pred = output.max(1)[1] 905 | y = y_pred.detach().clone().long().to(self.device) 906 | else: 907 | with torch.no_grad(): 908 | output = self.predict(x) 909 | n_classes = output.shape[-1] 910 | y_pred = output.max(1)[1] 911 | y = self.random_target_classes(y_pred, n_classes) 912 | else: 913 | y = y.detach().clone().long().to(self.device) 914 | 915 | if not self.targeted: 916 | acc = self.predict(x).max(1)[1] == y 917 | else: 918 | acc = self.predict(x).max(1)[1] != y 919 | 920 | startt = time.time() 921 | 922 | torch.random.manual_seed(self.seed) 923 | torch.cuda.random.manual_seed(self.seed) 924 | np.random.seed(self.seed) 925 | 926 | for counter in range(self.n_restarts): 927 | ind_to_fool = acc.nonzero().squeeze() 928 | if len(ind_to_fool.shape) == 0: 929 | ind_to_fool = ind_to_fool.unsqueeze(0) 930 | if ind_to_fool.numel() != 0: 931 | x_to_fool = x[ind_to_fool].clone() 932 | y_to_fool = y[ind_to_fool].clone() 933 | 934 | qr_curr, adv_curr = self.attack_single_run(x_to_fool, y_to_fool) 935 | 936 | output_curr = self.predict(adv_curr) 937 | if not self.targeted: 938 | acc_curr = output_curr.max(1)[1] == y_to_fool 939 | else: 940 | acc_curr = output_curr.max(1)[1] != y_to_fool 941 | ind_curr = (acc_curr == 0).nonzero().squeeze() 942 | 943 | acc[ind_to_fool[ind_curr]] = 0 944 | adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone() 945 | qr[ind_to_fool[ind_curr]] = qr_curr[ind_curr].clone() 946 | if self.verbose: 947 | print('restart {} - robust accuracy: {:.2%}'.format( 948 | counter, acc.float().mean()), 949 | '- cum. time: {:.1f} s'.format( 950 | time.time() - startt)) 951 | 952 | return qr, adv 953 | 954 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Logger(): 6 | def __init__(self, log_path): 7 | self.log_path = log_path 8 | 9 | def log(self, str_to_log): 10 | print(str_to_log) 11 | if not self.log_path is None: 12 | with open(self.log_path, 'a') as f: 13 | f.write(str_to_log + '\n') 14 | f.flush() 15 | 16 | 17 | class SingleChannelModel(): 18 | """ reshapes images to rgb before classification 19 | i.e. 
[N, 1, H, W x 3] -> [N, 3, H, W] 20 | """ 21 | def __init__(self, model): 22 | if isinstance(model, nn.Module): 23 | assert not model.training 24 | self.model = model 25 | 26 | def __call__(self, x): 27 | return self.model(x.view(x.shape[0], 3, x.shape[2], x.shape[3] // 3)) 28 | 29 | -------------------------------------------------------------------------------- /vis_images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import os 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--path_data', type=str) 9 | 10 | args = parser.parse_args() 11 | 12 | if args.path_data is None: 13 | path_data = './results/ImageNet/sparse-rs_L0_pt_vgg_1_50_nqueries_1000_alphainit_0.30_loss_margin_eps_150_targeted_False_seed_0.pth' 14 | else: 15 | path_data = args.path_data 16 | 17 | data = torch.load(path_data) 18 | if 'qr' in list(data.keys()): 19 | imgs, qr = data['adv'].cpu(), data['qr'].cpu() 20 | else: 21 | imgs, qr = data['adv'].cpu(), torch.arange(1, data['adv'].shape[0]) 22 | 23 | nqueries = 100000 24 | ind = ((qr > 0) * (qr < nqueries)).nonzero().squeeze() 25 | imgs_to_show = imgs[ind].permute(0, 2, 3, 1).cpu().numpy() 26 | if imgs_to_show.shape[-1] == 1: 27 | imgs_to_show = np.tile(imgs_to_show, (1, 1, 1, 3)) 28 | qr_to_show = qr[ind] 29 | 30 | qr_to_show, ind = qr_to_show.sort(descending=True) 31 | imgs_to_show = imgs_to_show[ind] 32 | 33 | w = 10 34 | h = 10 35 | fig = plt.figure(figsize=(20, 12)) 36 | 37 | columns = 10 38 | rows = 5 39 | 40 | # ax enables access to manipulate each of subplots 41 | ax = [] 42 | 43 | init_pos = 0 44 | if 'patch' in list(data.keys()): 45 | ax.append( fig.add_subplot(rows, columns, 1) ) 46 | ax[-1].set_title('patch') 47 | ax[-1].get_xaxis().set_ticks([]) 48 | ax[-1].get_yaxis().set_ticks([]) 49 | ax[-1].axis('off') 50 | print(data['patch'].shape) 51 | s = int(float(path_data.split('eps_')[1].split('_')[0]) ** .5) 52 | patch = data['patch'].squeeze().view(-1, s, s) 53 | plt.imshow(patch.permute(1, 2, 0).cpu().numpy(), interpolation='none') 54 | init_pos = 1 55 | 56 | for i in range(init_pos, columns*rows): 57 | if i < imgs_to_show.shape[0]: 58 | ax.append( fig.add_subplot(rows, columns, i+1) ) 59 | ax[-1].set_title('qr = {:.0f}'.format(qr_to_show[i].item())) 60 | ax[-1].get_xaxis().set_ticks([]) 61 | ax[-1].get_yaxis().set_ticks([]) 62 | ax[-1].axis('off') 63 | plt.imshow(imgs_to_show[i], interpolation='none') 64 | 65 | if not os.path.exists('./results/plots/'): 66 | os.makedirs('./results/plots/') 67 | 68 | #plt.show() 69 | plt.savefig('./results/plots/pl_{}.pdf'.format(path_data.split('/')[-1][:-4])) 70 | --------------------------------------------------------------------------------
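For completeness, a minimal, hypothetical sketch of how the pieces above fit together when calling the attack directly from Python rather than through the `eval.py` command line (the random batch is a placeholder for real ImageNet images, and a CUDA device is assumed, as in the rest of the code):

```
import torch

from eval import PretrainedModel
from rs_attacks import RSAttack

model = PretrainedModel('pt_resnet')           # pretrained ResNet-50, expects inputs in [0, 1]
x = torch.rand(8, 3, 224, 224).cuda()          # placeholder batch; use real ImageNet images in practice
y = model(x).max(1)[1]                         # labels to attack (here: the model's own predictions)

attack = RSAttack(model, norm='L0', eps=150,   # eps = number of pixels allowed to change
                  n_queries=10000, p_init=0.3, targeted=False)
queries, x_adv = attack.perturb(x, y)          # per-image query counts and adversarial images
print('success rate: {:.2%}'.format((model(x_adv).max(1)[1] != y).float().mean().item()))
```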