├── .gitignore
├── LICENSE
├── README.md
├── eval.py
├── img
├── algorithm_sparse_rs.png
├── frames_adversarial_examples.png
├── illustrations_figure1.png
├── l0_adversarial_examples_targeted.png
├── l0_adversarial_examples_untargeted.png
├── patches_adversarial_examples.png
├── table_frames.png
├── table_l0_bb_wb.png
├── table_patches.png
└── universal_patches_frames.png
├── rs_attacks.py
├── utils.py
└── vis_images.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *.pyc
3 |
4 | .idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Francesco Croce, Maksym Andriushchenko
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sparse-RS: a versatile framework for query-efficient sparse black-box adversarial attacks
2 | **Francesco Croce, Maksym Andriushchenko, Naman D. Singh, Nicolas Flammarion, Matthias Hein**
3 |
4 | **University of Tübingen and EPFL**
5 |
6 | **Paper:** [https://arxiv.org/abs/2006.12834](https://arxiv.org/abs/2006.12834)
7 |
8 | **AAAI 2022**
9 |
10 |
11 | ## Abstract
12 | Sparse adversarial perturbations have received much less attention in the literature compared to L2- and Linf-attacks.
13 | However, it is equally important to accurately assess the robustness of a model against sparse perturbations. Motivated by this goal,
14 | we propose a versatile framework based on random search, **Sparse-RS**, for score-based sparse targeted and untargeted attacks in
15 | the black-box setting. **Sparse-RS** does not rely on substitute models and achieves state-of-the-art success rate and query efficiency
16 | for multiple sparse attack models: L0-bounded perturbations, adversarial patches, and adversarial frames. Unlike existing methods, the
17 | L0-version of untargeted **Sparse-RS** achieves almost 100% success rate on ImageNet by perturbing *only* 0.1% of the total
18 | number of pixels, outperforming all existing white-box attacks including L0-PGD. Moreover, our untargeted **Sparse-RS** achieves very
19 | high success rates even for the challenging settings of 20x20 adversarial patches and 2-pixel wide adversarial frames for 224x224
20 | images. Finally, we show that **Sparse-RS** can be applied for universal adversarial patches where it significantly outperforms transfer-based approaches.
21 |

22 |
23 |
24 | ## About the paper
25 | Our proposed **Sparse-RS** framework is based on random search. Its main advantages are its simplicity and its wide applicability
26 | to multiple threat models:
27 | 
28 |
29 | We illustrate the versatility of the **Sparse-RS** framework by generating various sparse perturbations: L0-bounded, adversarial patches, and adversarial frames:
30 | 

31 | 
32 | 
33 |
34 | **Sparse-RS** can also successfully generate black-box **universal attacks** in sparse threat models without requiring a surrogate model:
35 | 
36 |
37 | In all these threat models, **Sparse-RS** improves over the existing approaches:
38 | 
39 | 
40 |
41 | Moreover, for L0-bounded perturbations **Sparse-RS** can even outperform existing **white-box** methods such as L0-PGD.
42 | 
43 |
44 |
45 |
46 | ## Code of Sparse-RS
47 | The code is tested under Python 3.8.5 and PyTorch 1.8.0. It automatically downloads the pretrained models (either VGG-16-BN or ResNet-50) and requires access to the ImageNet validation set.
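
The `--data_path` argument used in the examples below should point to a directory in the standard `torchvision.datasets.ImageFolder` layout (one subdirectory per class). For reference, `eval.py` reads the images roughly as in the following sketch (the path is a placeholder):

```
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# same preprocessing as eval.py: resize and center-crop to 224x224, values kept in [0, 1]
imagenet = datasets.ImageFolder('/path/to/validation/set',
                                transforms.Compose([
                                    transforms.Resize(224),
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor()
                                ]))
```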
48 |
49 | The following are examples of how to run the attacks in the different threat models.
50 |
51 | ### L0-bounded (pixel and feature space)
52 | In this case `k` represents the number of *pixels* to modify. For untargeted attacks
53 | ```
54 | CUDA_VISIBLE_DEVICES=0 python eval.py --norm=L0 \
55 | --model=[pt_vgg | pt_resnet] --n_queries=10000 --alpha_init=0.3 \
56 | --data_path=/path/to/validation/set --k=150 --n_ex=500
57 | ```
58 | and for targeted attacks please use `--targeted --n_queries=100000 --alpha_init=0.1`. The target class is randomly chosen for each point.
59 |
60 | To use an attack in the *feature* space please add `--use_feature_space` (in this case `k` indicates the number of features to modify).
61 |
62 | As additional options, the flag `--constant_schedule` uses a constant schedule for `alpha` instead of the piecewise constant decreasing one, and `--seed=N` sets a custom random seed.
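
Besides the command line, the attack can also be called directly from Python. Below is a minimal sketch (not a script shipped with the repository) that mirrors what `eval.py` does for the untargeted L0 attack with the hyperparameters above; it assumes a CUDA device, and the random batch is only a placeholder for properly loaded ImageNet images:

```
import torch
from torchvision import models as torch_models
from rs_attacks import RSAttack

class WrappedModel():
    """ normalizes inputs before the forward pass, as in eval.py """
    def __init__(self):
        self.model = torch_models.vgg16_bn(pretrained=True).cuda().eval()
        self.mu = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
        self.sigma = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()

    def __call__(self, x):
        return self.model((x - self.mu) / self.sigma)

model = WrappedModel()
x = torch.rand(4, 3, 224, 224).cuda()  # placeholder for a batch of ImageNet images in [0, 1]
y = model(x).max(1)[1]                 # labels to attack (here simply the model's predictions)

adversary = RSAttack(model, norm='L0', eps=150, n_queries=10000, p_init=0.3,
                     targeted=False, loss='margin', verbose=False)
with torch.no_grad():
    queries, x_adv = adversary.perturb(x, y)  # queries used and perturbed images
```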
63 |
64 | ### Image-specific patches and frames
65 | For untargeted image- and location-specific patches of size 20x20 (with `k=400`)
66 | ```
67 | CUDA_VISIBLE_DEVICES=0 python eval.py --norm=patches \
68 | --model=[pt_vgg | pt_resnet] --n_queries=10000 --alpha_init=0.4 \
69 | --data_path=/path/to/validation/set --k=400 --n_ex=100
70 | ```
71 |
72 | For targeted patches (size 40x40) please use `--targeted --n_queries=50000 --alpha_init=0.1 --k=1600`. The target class is randomly chosen for each point.
73 |
74 | For untargeted image-specific frames of width 2 pixels (with `k=2`)
75 | ```
76 | CUDA_VISIBLE_DEVICES=0 python eval.py --norm=frames \
77 | --model=[pt_vgg | pt_resnet] --n_queries=10000 --alpha_init=0.5 \
78 | --data_path=/path/to/validation/set --k=2 --n_ex=100
79 | ```
80 |
81 | For targeted frames (width of 3 pixels) please use `--targeted --n_queries=50000 --alpha_init=0.5 --k=3`. The target class is randomly chosen for each point.
82 |
83 | ### Universal patches and frames
84 | For targeted universal patches of size 50x50 (with `k=2500`)
85 | ```
86 | CUDA_VISIBLE_DEVICES=0 python eval.py \
87 | --norm=patches_universal --model=[pt_vgg | pt_resnet] \
88 | --n_queries=100000 --alpha_init=0.3 \
89 | --data_path=/path/to/validation/set --k=2500 \
90 | --n_ex=30 --targeted --target_class=530
91 | ```
92 |
93 | and for targeted universal frames of width 6 pixels (`k=6`)
94 | ```
95 | CUDA_VISIBLE_DEVICES=0 python eval.py \
96 | --norm=frames_universal --model=[pt_vgg | pt_resnet] \
97 | --n_queries=100000 --alpha_init=1.667 \
98 | --data_path=/path/to/validation/set --k=6 \
99 | --n_ex=30 --targeted --target_class=530
100 | ```
101 | The argument `--target_class` specifies the index of the target class. To generate universal attacks we use batches of 30 images, resampled every 10000 queries.
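
For universal attacks, `eval.py` additionally saves the perturbation itself: the `patch` tensor for `patches_universal`, and the `frame` tensor together with the indices `ind` of the perturbed pixels for `frames_universal`, so that it can be applied to unseen images. A minimal sketch for a saved frame (the result path and the batch `x` are placeholders):

```
import torch

data = torch.load('/path/to/saved/results.pth')  # output of a frames_universal run
frame, ind = data['frame'], data['ind']          # frame: [3, n_frame_pixels], ind: [n_frame_pixels, 2]

x = torch.rand(8, 3, 224, 224)                   # placeholder batch of new images in [0, 1]
x_framed = x.clone()
x_framed[:, :, ind[:, 0], ind[:, 1]] = frame     # overwrite the frame pixels of every image
```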
102 |
103 | ## Visualizing resulting images
104 | We provide a script `vis_images.py` to visualize the images produced by the attacks. To use it please run
105 |
106 | ```python vis_images.py --path_data=/path/to/saved/results```
107 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import torch.nn as nn
5 |
6 | import torchvision.datasets as datasets
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 | from torchvision import models as torch_models
10 |
11 | import sys
12 | import time
13 | from datetime import datetime
14 |
15 | from utils import SingleChannelModel
16 |
17 | model_class_dict = {'pt_vgg': torch_models.vgg16_bn,
18 | 'pt_resnet': torch_models.resnet50,
19 | }
20 |
21 | class PretrainedModel():
22 | def __init__(self, modelname):
23 | model_pt = model_class_dict[modelname](pretrained=True)
24 | #model.eval()
25 | self.model = nn.DataParallel(model_pt.cuda())
26 | self.model.eval()
27 | self.mu = torch.Tensor([0.485, 0.456, 0.406]).float().view(1, 3, 1, 1).cuda()
28 | self.sigma = torch.Tensor([0.229, 0.224, 0.225]).float().view(1, 3, 1, 1).cuda()
29 |
30 | def predict(self, x):
31 | out = (x - self.mu) / self.sigma
32 | return self.model(out)
33 |
34 | def forward(self, x):
35 | out = (x - self.mu) / self.sigma
36 | return self.model(out)
37 |
38 | def __call__(self, x):
39 | return self.predict(x)
40 |
41 | def random_target_classes(y_pred, n_classes):
42 | y = torch.zeros_like(y_pred)
43 | for counter in range(y_pred.shape[0]):
44 | l = list(range(n_classes))
45 | l.remove(y_pred[counter])
46 | t = torch.randint(0, len(l), size=[1])
47 | y[counter] = l[t] + 0
48 |
49 | return y.long()
50 |
51 | if __name__ == '__main__':
52 |
53 | parser = argparse.ArgumentParser()
54 |
55 | parser.add_argument('--dataset', type=str, default='ImageNet')
56 | parser.add_argument('--data_path', type=str)
57 | parser.add_argument('--norm', type=str, default='L0')
58 | parser.add_argument('--k', default=150., type=float)
59 | parser.add_argument('--n_restarts', type=int, default=1)
60 | parser.add_argument('--loss', type=str, default='margin')
61 | parser.add_argument('--model', default='pt_vgg', type=str)
62 | parser.add_argument('--n_ex', type=int, default=1000)
63 | parser.add_argument('--attack', type=str, default='rs_attack')
64 | parser.add_argument('--n_queries', type=int, default=1000)
65 | parser.add_argument('--targeted', action='store_true')
66 | parser.add_argument('--target_class', type=int)
67 | parser.add_argument('--seed', type=int, default=0)
68 | parser.add_argument('--constant_schedule', action='store_true')
69 | parser.add_argument('--save_dir', type=str, default='./results')
70 | parser.add_argument('--use_feature_space', action='store_true')
71 |
72 | # Sparse-RS parameters
73 | parser.add_argument('--alpha_init', type=float, default=.3)
74 | parser.add_argument('--resample_period_univ', type=int)
75 | parser.add_argument('--loc_update_period', type=int)
76 |
77 | args = parser.parse_args()
78 |
79 | if args.data_path is None:
80 | args.data_path = "/scratch/datasets/imagenet/val"
81 |
82 | args.eps = args.k + 0
83 | args.bs = args.n_ex + 0
84 | args.p_init = args.alpha_init + 0.
85 | args.resample_loc = args.resample_period_univ
86 | args.update_loc_period = args.loc_update_period
87 |
88 | if args.dataset == 'ImageNet':
89 | # load pretrained model
90 | model = PretrainedModel(args.model)
91 | assert not model.model.training
92 | print(model.model.training)
93 |
94 | # load data
95 | IMAGENET_SL = 224
96 | IMAGENET_PATH = args.data_path
97 | imagenet = datasets.ImageFolder(IMAGENET_PATH,
98 | transforms.Compose([
99 | transforms.Resize(IMAGENET_SL),
100 | transforms.CenterCrop(IMAGENET_SL),
101 | transforms.ToTensor()
102 | ]))
103 | torch.manual_seed(0)
104 |
105 | test_loader = data.DataLoader(imagenet, batch_size=args.bs, shuffle=True, num_workers=0)
106 |
107 | testiter = iter(test_loader)
108 | x_test, y_test = next(testiter)
109 |
110 | if args.attack in ['rs_attack']:
111 | # run Sparse-RS attacks
112 | logsdir = '{}/logs_{}_{}'.format(args.save_dir, args.attack, args.norm)
113 | savedir = '{}/{}_{}'.format(args.save_dir, args.attack, args.norm)
114 | if not os.path.exists(savedir):
115 | os.makedirs(savedir)
116 | if not os.path.exists(logsdir):
117 | os.makedirs(logsdir)
118 |
119 | if args.targeted or 'universal' in args.norm:
120 | args.loss = 'ce'
121 | data_loader = testiter if 'universal' in args.norm else None
122 | if args.use_feature_space:
123 | # reshape images to single color channel to perturb them individually
124 | assert args.norm == 'L0'
125 | bs, c, h, w = x_test.shape
126 | x_test = x_test.view(bs, 1, h, w * c)
127 | model = SingleChannelModel(model)
128 | str_space = 'feature space'
129 | else:
130 | str_space = 'pixel space'
131 |
132 | param_run = '{}_{}_{}_1_{}_nqueries_{:.0f}_pinit_{:.2f}_loss_{}_eps_{:.0f}_targeted_{}_targetclass_{}_seed_{:.0f}'.format(
133 | args.attack, args.norm, args.model, args.n_ex, args.n_queries, args.p_init,
134 | args.loss, args.eps, args.targeted, args.target_class, args.seed)
135 | if args.constant_schedule:
136 | param_run += '_constantpinit'
137 | if args.use_feature_space:
138 | param_run += '_featurespace'
139 |
140 | from rs_attacks import RSAttack
141 | adversary = RSAttack(model, norm=args.norm, eps=int(args.eps), verbose=True, n_queries=args.n_queries,
142 | p_init=args.p_init, log_path='{}/log_run_{}_{}.txt'.format(logsdir, str(datetime.now())[:-7], param_run),
143 | loss=args.loss, targeted=args.targeted, seed=args.seed, constant_schedule=args.constant_schedule,
144 | data_loader=data_loader, resample_loc=args.resample_loc)
145 |
146 | # set target classes
147 | if args.targeted and 'universal' in args.norm:
148 | if args.target_class is None:
149 | y_test = torch.ones_like(y_test) * torch.randint(1000, size=[1]).to(y_test.device)
150 | else:
151 | y_test = torch.ones_like(y_test) * args.target_class
152 | print('target labels', y_test)
153 |
154 | elif args.targeted:
155 | y_test = random_target_classes(y_test, 1000)
156 | print('target labels', y_test)
157 |
158 | bs = min(args.bs, 500)
159 | assert args.n_ex % args.bs == 0
160 | adv_complete = x_test.clone()
161 | qr_complete = torch.zeros([x_test.shape[0]]).cpu()
162 | pred = torch.zeros([0]).float().cpu()
163 | with torch.no_grad():
164 | # find points originally correctly classified
165 | for counter in range(x_test.shape[0] // bs):
166 | x_curr = x_test[counter * bs:(counter + 1) * bs].cuda()
167 | y_curr = y_test[counter * bs:(counter + 1) * bs].cuda()
168 | output = model(x_curr)
169 | if not args.targeted:
170 | pred = torch.cat((pred, (output.max(1)[1] == y_curr).float().cpu()), dim=0)
171 | else:
172 | pred = torch.cat((pred, (output.max(1)[1] != y_curr).float().cpu()), dim=0)
173 |
174 | adversary.logger.log('clean accuracy {:.2%}'.format(pred.mean()))
175 |
176 | n_batches = pred.sum() // bs + 1 if pred.sum() % bs != 0 else pred.sum() // bs
177 | n_batches = n_batches.long().item()
178 | ind_to_fool = (pred == 1).nonzero().squeeze()
179 |
180 | # run the attack
181 | pred_adv = pred.clone()
182 | for counter in range(n_batches):
183 | x_curr = x_test[ind_to_fool[counter * bs:(counter + 1) * bs]].cuda()
184 | y_curr = y_test[ind_to_fool[counter * bs:(counter + 1) * bs]].cuda()
185 | qr_curr, adv = adversary.perturb(x_curr, y_curr)
186 |
187 | output = model(adv.cuda())
188 | if not args.targeted:
189 | acc_curr = (output.max(1)[1] == y_curr).float().cpu()
190 | else:
191 | acc_curr = (output.max(1)[1] != y_curr).float().cpu()
192 | pred_adv[ind_to_fool[counter * bs:(counter + 1) * bs]] = acc_curr.clone()
193 | adv_complete[ind_to_fool[counter * bs:(counter + 1) * bs]] = adv.cpu().clone()
194 | qr_complete[ind_to_fool[counter * bs:(counter + 1) * bs]] = qr_curr.cpu().clone()
195 |
196 | print('batch {}/{} - {:.0f} of {} successfully perturbed'.format(
197 | counter + 1, n_batches, x_curr.shape[0] - acc_curr.sum(), x_curr.shape[0]))
198 |
199 | adversary.logger.log('robust accuracy {:.2%}'.format(pred_adv.float().mean()))
200 |
201 | # check robust accuracy and other statistics
202 | acc = 0.
203 | for counter in range(x_test.shape[0] // bs):
204 | x_curr = adv_complete[counter * bs:(counter + 1) * bs].cuda()
205 | y_curr = y_test[counter * bs:(counter + 1) * bs].cuda()
206 | output = model(x_curr)
207 | if not args.targeted:
208 | acc += (output.max(1)[1] == y_curr).float().sum().item()
209 | else:
210 | acc += (output.max(1)[1] != y_curr).float().sum().item()
211 |
212 | adversary.logger.log('robust accuracy {:.2%}'.format(acc / args.n_ex))
213 |
214 | res = (adv_complete - x_test != 0.).max(dim=1)[0].sum(dim=(1, 2))
215 | adversary.logger.log('max L0 perturbation ({}) {:.0f} - nan in img {} - max img {:.5f} - min img {:.5f}'.format(
216 | str_space, res.max(), (adv_complete != adv_complete).sum(), adv_complete.max(), adv_complete.min()))
217 |
218 | ind_corrcl = pred == 1.
219 | ind_succ = (pred_adv == 0.) * (pred == 1.)
220 |
221 | str_stats = 'success rate={:.0f}/{:.0f} ({:.2%}) \n'.format(
222 | pred.sum() - pred_adv.sum(), pred.sum(), (pred.sum() - pred_adv.sum()).float() / pred.sum()) +\
223 | '[successful points] avg # queries {:.1f} - med # queries {:.1f}\n'.format(
224 | qr_complete[ind_succ].float().mean(), torch.median(qr_complete[ind_succ].float()))
225 | qr_complete[~ind_succ] = args.n_queries + 0
226 | str_stats += '[correctly classified points] avg # queries {:.1f} - med # queries {:.1f}\n'.format(
227 | qr_complete[ind_corrcl].float().mean(), torch.median(qr_complete[ind_corrcl].float()))
228 | adversary.logger.log(str_stats)
229 |
230 | # save results depending on the threat model
231 | if args.norm in ['L0', 'patches', 'frames']:
232 | if args.use_feature_space:
233 | # reshape perturbed images to original rgb format
234 | bs, _, h, w = adv_complete.shape
235 | adv_complete = adv_complete.view(bs, 3, h, w // 3)
236 | torch.save({'adv': adv_complete, 'qr': qr_complete},
237 | '{}/{}.pth'.format(savedir, param_run))
238 |
239 | elif args.norm in ['patches_universal']:
240 | # extract and save patch
241 | ind = (res > 0).nonzero().squeeze()[0]
242 | ind_patch = (((adv_complete[ind] - x_test[ind]).abs() > 0).max(0)[0] > 0).nonzero().squeeze()
243 | t = [ind_patch[:, 0].min().item(), ind_patch[:, 0].max().item(), ind_patch[:, 1].min().item(), ind_patch[:, 1].max().item()]
244 | loc = torch.tensor([t[0], t[2]])
245 | s = t[1] - t[0] + 1
246 | patch = adv_complete[ind, :, loc[0]:loc[0] + s, loc[1]:loc[1] + s].unsqueeze(0)
247 |
248 | torch.save({'adv': adv_complete, 'patch': patch},
249 | '{}/{}.pth'.format(savedir, param_run))
250 |
251 | elif args.norm in ['frames_universal']:
252 |             # extract and save the frame and the indices of the perturbed pixels
253 |             # so that the frame can easily be applied to new images
254 | ind_img = (res > 0).nonzero().squeeze()[0]
255 | mask = torch.zeros(x_test.shape[-2:])
256 | s = int(args.eps)
257 | mask[:s] = 1.
258 | mask[-s:] = 1.
259 | mask[:, :s] = 1.
260 | mask[:, -s:] = 1.
261 | ind = (mask == 1.).nonzero().squeeze()
262 | frame = adv_complete[ind_img, :, ind[:, 0], ind[:, 1]]
263 |
264 | torch.save({'adv': adv_complete, 'frame': frame, 'ind': ind},
265 | '{}/{}.pth'.format(savedir, param_run))
266 |
267 |
--------------------------------------------------------------------------------
/img/algorithm_sparse_rs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/algorithm_sparse_rs.png
--------------------------------------------------------------------------------
/img/frames_adversarial_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/frames_adversarial_examples.png
--------------------------------------------------------------------------------
/img/illustrations_figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/illustrations_figure1.png
--------------------------------------------------------------------------------
/img/l0_adversarial_examples_targeted.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/l0_adversarial_examples_targeted.png
--------------------------------------------------------------------------------
/img/l0_adversarial_examples_untargeted.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/l0_adversarial_examples_untargeted.png
--------------------------------------------------------------------------------
/img/patches_adversarial_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/patches_adversarial_examples.png
--------------------------------------------------------------------------------
/img/table_frames.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/table_frames.png
--------------------------------------------------------------------------------
/img/table_l0_bb_wb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/table_l0_bb_wb.png
--------------------------------------------------------------------------------
/img/table_patches.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/table_patches.png
--------------------------------------------------------------------------------
/img/universal_patches_frames.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fra31/sparse-rs/21d875969a1455e4d5b26dcf32c843e6262d1f9c/img/universal_patches_frames.png
--------------------------------------------------------------------------------
/rs_attacks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 | from __future__ import unicode_literals
12 |
13 | import torch
14 | import time
15 | import math
16 | import torch.nn.functional as F
17 |
18 | import numpy as np
19 | import copy
20 | import sys
21 | from utils import Logger
22 | import os
23 |
24 |
25 | class RSAttack():
26 | """
27 | Sparse-RS attacks
28 |
29 | :param predict: forward pass function
30 | :param norm: type of the attack
31 | :param n_restarts: number of random restarts
32 | :param n_queries: max number of queries (each restart)
33 | :param eps: bound on the sparsity of perturbations
34 | :param seed: random seed for the starting point
35 |     :param alpha_init: parameter to control the schedule alpha_i
36 |     :param loss: loss function optimized ('margin', 'ce' supported)
37 |     :param resc_schedule: adapt the schedule of alpha_i to n_queries
38 |     :param device: device to use
39 |     :param log_path: path to save the log file
40 |     :param constant_schedule: use a constant alpha_i
41 |     :param targeted: perform targeted attacks
42 |     :param init_patches: initialization for patches
43 |     :param resample_loc: period in queries of resampling images and
44 |         locations for universal attacks
45 |     :param data_loader: loader to get new images for resampling
46 |     :param update_loc_period: period in queries of updates of the location
47 |         for image-specific patches
48 | """
49 |
50 | def __init__(
51 | self,
52 | predict,
53 | norm='L0',
54 | n_queries=5000,
55 | eps=None,
56 | p_init=.8,
57 | n_restarts=1,
58 | seed=0,
59 | verbose=True,
60 | targeted=False,
61 | loss='margin',
62 | resc_schedule=True,
63 | device=None,
64 | log_path=None,
65 | constant_schedule=False,
66 | init_patches='random_squares',
67 | resample_loc=None,
68 | data_loader=None,
69 | update_loc_period=None):
70 | """
71 | Sparse-RS implementation in PyTorch
72 | """
73 |
74 | self.predict = predict
75 | self.norm = norm
76 | self.n_queries = n_queries
77 | self.eps = eps
78 | self.p_init = p_init
79 | self.n_restarts = n_restarts
80 | self.seed = seed
81 | self.verbose = verbose
82 | self.targeted = targeted
83 | self.loss = loss
84 | self.rescale_schedule = resc_schedule
85 | self.device = device
86 | self.logger = Logger(log_path)
87 | self.constant_schedule = constant_schedule
88 | self.init_patches = init_patches
89 | self.resample_loc = n_queries // 10 if resample_loc is None else resample_loc
90 | self.data_loader = data_loader
91 |         self.update_loc_period = update_loc_period if update_loc_period is not None else (4 if not targeted else 10)
92 |
93 |
94 | def margin_and_loss(self, x, y):
95 | """
96 | :param y: correct labels if untargeted else target labels
97 | """
98 |
99 | logits = self.predict(x)
100 | xent = F.cross_entropy(logits, y, reduction='none')
101 | u = torch.arange(x.shape[0])
102 | y_corr = logits[u, y].clone()
103 | logits[u, y] = -float('inf')
104 | y_others = logits.max(dim=-1)[0]
105 |
106 | if not self.targeted:
107 | if self.loss == 'ce':
108 | return y_corr - y_others, -1. * xent
109 | elif self.loss == 'margin':
110 | return y_corr - y_others, y_corr - y_others
111 | else:
112 | return y_others - y_corr, xent
113 |
114 | def init_hyperparam(self, x):
115 | assert self.norm in ['L0', 'patches', 'frames',
116 | 'patches_universal', 'frames_universal']
117 | assert not self.eps is None
118 | assert self.loss in ['ce', 'margin']
119 |
120 | if self.device is None:
121 | self.device = x.device
122 | self.orig_dim = list(x.shape[1:])
123 | self.ndims = len(self.orig_dim)
124 | if self.seed is None:
125 | self.seed = time.time()
126 | if self.targeted:
127 | self.loss = 'ce'
128 |
129 | def random_target_classes(self, y_pred, n_classes):
130 | y = torch.zeros_like(y_pred)
131 | for counter in range(y_pred.shape[0]):
132 | l = list(range(n_classes))
133 | l.remove(y_pred[counter])
134 | t = self.random_int(0, len(l))
135 | y[counter] = l[t]
136 |
137 | return y.long().to(self.device)
138 |
139 | def check_shape(self, x):
140 | return x if len(x.shape) == (self.ndims + 1) else x.unsqueeze(0)
141 |
142 | def random_choice(self, shape):
143 | t = 2 * torch.rand(shape).to(self.device) - 1
144 | return torch.sign(t)
145 |
146 | def random_int(self, low=0, high=1, shape=[1]):
147 | t = low + (high - low) * torch.rand(shape).to(self.device)
148 | return t.long()
149 |
150 | def normalize(self, x):
151 | if self.norm == 'Linf':
152 | t = x.abs().view(x.shape[0], -1).max(1)[0]
153 | return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)
154 |
155 | elif self.norm == 'L2':
156 | t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
157 | return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)
158 |
159 | def lp_norm(self, x):
160 | if self.norm == 'L2':
161 | t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
162 | return t.view(-1, *([1] * self.ndims))
163 |
164 | def p_selection(self, it):
165 | """ schedule to decrease the parameter p """
166 |
167 | if self.rescale_schedule:
168 | it = int(it / self.n_queries * 10000)
169 |
170 | if 'patches' in self.norm:
171 | if 10 < it <= 50:
172 | p = self.p_init / 2
173 | elif 50 < it <= 200:
174 | p = self.p_init / 4
175 | elif 200 < it <= 500:
176 | p = self.p_init / 8
177 | elif 500 < it <= 1000:
178 | p = self.p_init / 16
179 | elif 1000 < it <= 2000:
180 | p = self.p_init / 32
181 | elif 2000 < it <= 4000:
182 | p = self.p_init / 64
183 | elif 4000 < it <= 6000:
184 | p = self.p_init / 128
185 | elif 6000 < it <= 8000:
186 | p = self.p_init / 256
187 | elif 8000 < it:
188 | p = self.p_init / 512
189 | else:
190 | p = self.p_init
191 |
192 | elif 'frames' in self.norm:
193 |             if 'universal' not in self.norm:
194 | tot_qr = 10000 if self.rescale_schedule else self.n_queries
195 | p = max((float(tot_qr - it) / tot_qr - .5) * self.p_init * self.eps ** 2, 0.)
196 | return 3. * math.ceil(p)
197 |
198 | else:
199 | assert self.rescale_schedule
200 | its = [200, 600, 1200, 1800, 2500, 10000, 100000]
201 | resc_factors = [1., .8, .6, .4, .2, .1, 0.]
202 | c = 0
203 | while it >= its[c]:
204 | c += 1
205 | return resc_factors[c] * self.p_init
206 |
207 | elif 'L0' in self.norm:
208 | if 0 < it <= 50:
209 | p = self.p_init / 2
210 | elif 50 < it <= 200:
211 | p = self.p_init / 4
212 | elif 200 < it <= 500:
213 | p = self.p_init / 5
214 | elif 500 < it <= 1000:
215 | p = self.p_init / 6
216 | elif 1000 < it <= 2000:
217 | p = self.p_init / 8
218 | elif 2000 < it <= 4000:
219 | p = self.p_init / 10
220 | elif 4000 < it <= 6000:
221 | p = self.p_init / 12
222 | elif 6000 < it <= 8000:
223 | p = self.p_init / 15
224 | elif 8000 < it:
225 | p = self.p_init / 20
226 | else:
227 | p = self.p_init
228 |
229 | if self.constant_schedule:
230 | p = self.p_init / 2
231 |
232 | return p
233 |
234 | def sh_selection(self, it):
235 | """ schedule to decrease the parameter p """
236 |
237 | t = max((float(self.n_queries - it) / self.n_queries - .0) ** 1., 0) * .75
238 |
239 | return t
240 |
241 | def get_init_patch(self, c, s, n_iter=1000):
242 | if self.init_patches == 'stripes':
243 | patch_univ = torch.zeros([1, c, s, s]).to(self.device) + self.random_choice(
244 | [1, c, 1, s]).clamp(0., 1.)
245 | elif self.init_patches == 'uniform':
246 | patch_univ = torch.zeros([1, c, s, s]).to(self.device) + self.random_choice(
247 | [1, c, 1, 1]).clamp(0., 1.)
248 | elif self.init_patches == 'random':
249 | patch_univ = self.random_choice([1, c, s, s]).clamp(0., 1.)
250 | elif self.init_patches == 'random_squares':
251 | patch_univ = torch.zeros([1, c, s, s]).to(self.device)
252 | for _ in range(n_iter):
253 | size_init = torch.randint(low=1, high=math.ceil(s ** .5), size=[1]).item()
254 | loc_init = torch.randint(s - size_init + 1, size=[2])
255 | patch_univ[0, :, loc_init[0]:loc_init[0] + size_init, loc_init[1]:loc_init[1] + size_init] = 0.
256 | patch_univ[0, :, loc_init[0]:loc_init[0] + size_init, loc_init[1]:loc_init[1] + size_init
257 | ] += self.random_choice([c, 1, 1]).clamp(0., 1.)
258 | elif self.init_patches == 'sh':
259 | patch_univ = torch.ones([1, c, s, s]).to(self.device)
260 |
261 | return patch_univ.clamp(0., 1.)
262 |
263 | def attack_single_run(self, x, y):
264 | with torch.no_grad():
265 | adv = x.clone()
266 | c, h, w = x.shape[1:]
267 | n_features = c * h * w
268 | n_ex_total = x.shape[0]
269 |
270 | if self.norm == 'L0':
271 | eps = self.eps
272 |
273 | x_best = x.clone()
274 | n_pixels = h * w
275 | b_all, be_all = torch.zeros([x.shape[0], eps]).long(), torch.zeros([x.shape[0], n_pixels - eps]).long()
276 | for img in range(x.shape[0]):
277 | ind_all = torch.randperm(n_pixels)
278 | ind_p = ind_all[:eps]
279 | ind_np = ind_all[eps:]
280 | x_best[img, :, ind_p // w, ind_p % w] = self.random_choice([c, eps]).clamp(0., 1.)
281 | b_all[img] = ind_p.clone()
282 | be_all[img] = ind_np.clone()
283 |
284 | margin_min, loss_min = self.margin_and_loss(x_best, y)
285 | n_queries = torch.ones(x.shape[0]).to(self.device)
286 |
287 | for it in range(1, self.n_queries):
288 | # check points still to fool
289 | idx_to_fool = (margin_min > 0.).nonzero().squeeze()
290 | x_curr = self.check_shape(x[idx_to_fool])
291 | x_best_curr = self.check_shape(x_best[idx_to_fool])
292 | y_curr = y[idx_to_fool]
293 | margin_min_curr = margin_min[idx_to_fool]
294 | loss_min_curr = loss_min[idx_to_fool]
295 | b_curr, be_curr = b_all[idx_to_fool], be_all[idx_to_fool]
296 | if len(y_curr.shape) == 0:
297 | y_curr.unsqueeze_(0)
298 | margin_min_curr.unsqueeze_(0)
299 | loss_min_curr.unsqueeze_(0)
300 | b_curr.unsqueeze_(0)
301 | be_curr.unsqueeze_(0)
302 | idx_to_fool.unsqueeze_(0)
303 |
304 | # build new candidate
305 | x_new = x_best_curr.clone()
306 | eps_it = max(int(self.p_selection(it) * eps), 1)
307 | ind_p = torch.randperm(eps)[:eps_it]
308 | ind_np = torch.randperm(n_pixels - eps)[:eps_it]
309 |
310 | for img in range(x_new.shape[0]):
311 | p_set = b_curr[img, ind_p]
312 | np_set = be_curr[img, ind_np]
313 | x_new[img, :, p_set // w, p_set % w] = x_curr[img, :, p_set // w, p_set % w].clone()
314 | if eps_it > 1:
315 | x_new[img, :, np_set // w, np_set % w] = self.random_choice([c, eps_it]).clamp(0., 1.)
316 | else:
317 | # if update is 1x1 make sure the sampled color is different from the current one
318 | old_clr = x_new[img, :, np_set // w, np_set % w].clone()
319 | assert old_clr.shape == (c, 1), print(old_clr)
320 | new_clr = old_clr.clone()
321 | while (new_clr == old_clr).all().item():
322 | new_clr = self.random_choice([c, 1]).clone().clamp(0., 1.)
323 | x_new[img, :, np_set // w, np_set % w] = new_clr.clone()
324 |
325 | # compute loss of the new candidates
326 | margin, loss = self.margin_and_loss(x_new, y_curr)
327 | n_queries[idx_to_fool] += 1
328 |
329 | # update best solution
330 | idx_improved = (loss < loss_min_curr).float()
331 | idx_to_update = (idx_improved > 0.).nonzero().squeeze()
332 | loss_min[idx_to_fool[idx_to_update]] = loss[idx_to_update]
333 |
334 | idx_miscl = (margin < -1e-6).float()
335 | idx_improved = torch.max(idx_improved, idx_miscl)
336 | nimpr = idx_improved.sum().item()
337 | if nimpr > 0.:
338 | idx_improved = (idx_improved.view(-1) > 0).nonzero().squeeze()
339 | margin_min[idx_to_fool[idx_improved]] = margin[idx_improved].clone()
340 | x_best[idx_to_fool[idx_improved]] = x_new[idx_improved].clone()
341 | t = b_curr[idx_improved].clone()
342 | te = be_curr[idx_improved].clone()
343 |
344 | if nimpr > 1:
345 | t[:, ind_p] = be_curr[idx_improved][:, ind_np] + 0
346 | te[:, ind_np] = b_curr[idx_improved][:, ind_p] + 0
347 | else:
348 | t[ind_p] = be_curr[idx_improved][ind_np] + 0
349 | te[ind_np] = b_curr[idx_improved][ind_p] + 0
350 |
351 | b_all[idx_to_fool[idx_improved]] = t.clone()
352 | be_all[idx_to_fool[idx_improved]] = te.clone()
353 |
354 | # log results current iteration
355 | ind_succ = (margin_min <= 0.).nonzero().squeeze()
356 | if self.verbose and ind_succ.numel() != 0:
357 | self.logger.log(' '.join(['{}'.format(it + 1),
358 | '- success rate={}/{} ({:.2%})'.format(
359 | ind_succ.numel(), n_ex_total,
360 | float(ind_succ.numel()) / n_ex_total),
361 | '- avg # queries={:.1f}'.format(
362 | n_queries[ind_succ].mean().item()),
363 | '- med # queries={:.1f}'.format(
364 | n_queries[ind_succ].median().item()),
365 | '- loss={:.3f}'.format(loss_min.mean()),
366 | '- max pert={:.0f}'.format(((x_new - x_curr).abs() > 0
367 | ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
368 | '- epsit={:.0f}'.format(eps_it),
369 | ]))
370 |
371 | if ind_succ.numel() == n_ex_total:
372 | break
373 |
374 | elif self.norm == 'patches':
375 | ''' assumes square images and patches '''
376 |
377 | s = int(math.ceil(self.eps ** .5))
378 | x_best = x.clone()
379 | x_new = x.clone()
380 | loc = torch.randint(h - s, size=[x.shape[0], 2])
381 | patches_coll = torch.zeros([x.shape[0], c, s, s]).to(self.device)
382 | assert abs(self.update_loc_period) > 1
383 | loc_t = abs(self.update_loc_period)
384 |
385 | # set when to start single channel updates
386 | it_start_cu = None
387 | for it in range(0, self.n_queries):
388 | s_it = int(max(self.p_selection(it) ** .5 * s, 1))
389 | if s_it == 1:
390 | break
391 | it_start_cu = it + (self.n_queries - it) // 2
392 | if self.verbose:
393 | self.logger.log('starting single channel updates at query {}'.format(
394 | it_start_cu))
395 |
396 | # initialize patches
397 | if self.verbose:
398 | self.logger.log('using {} initialization'.format(self.init_patches))
399 | for counter in range(x.shape[0]):
400 | patches_coll[counter] += self.get_init_patch(c, s).squeeze().clamp(0., 1.)
401 | x_new[counter, :, loc[counter, 0]:loc[counter, 0] + s,
402 | loc[counter, 1]:loc[counter, 1] + s] = patches_coll[counter].clone()
403 |
404 | margin_min, loss_min = self.margin_and_loss(x_new, y)
405 | n_queries = torch.ones(x.shape[0]).to(self.device)
406 |
407 | for it in range(1, self.n_queries):
408 | # check points still to fool
409 | idx_to_fool = (margin_min > -1e-6).nonzero().squeeze()
410 | x_curr = self.check_shape(x[idx_to_fool])
411 | patches_curr = self.check_shape(patches_coll[idx_to_fool])
412 | y_curr = y[idx_to_fool]
413 | margin_min_curr = margin_min[idx_to_fool]
414 | loss_min_curr = loss_min[idx_to_fool]
415 | loc_curr = loc[idx_to_fool]
416 | if len(y_curr.shape) == 0:
417 | y_curr.unsqueeze_(0)
418 | margin_min_curr.unsqueeze_(0)
419 | loss_min_curr.unsqueeze_(0)
420 |
421 | loc_curr.unsqueeze_(0)
422 | idx_to_fool.unsqueeze_(0)
423 |
424 | # sample update
425 | s_it = int(max(self.p_selection(it) ** .5 * s, 1))
426 | p_it = torch.randint(s - s_it + 1, size=[2])
427 | sh_it = int(max(self.sh_selection(it) * h, 0))
428 | patches_new = patches_curr.clone()
429 | x_new = x_curr.clone()
430 | loc_new = loc_curr.clone()
431 | update_loc = int((it % loc_t == 0) and (sh_it > 0))
432 | update_patch = 1. - update_loc
433 | if self.update_loc_period < 0 and sh_it > 0:
434 | update_loc = 1. - update_loc
435 | update_patch = 1. - update_patch
436 | for counter in range(x_curr.shape[0]):
437 | if update_patch == 1.:
438 | # update patch
439 | if it < it_start_cu:
440 | if s_it > 1:
441 | patches_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] += self.random_choice([c, 1, 1])
442 | else:
443 | # make sure to sample a different color
444 | old_clr = patches_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone()
445 | new_clr = old_clr.clone()
446 | while (new_clr == old_clr).all().item():
447 | new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.)
448 | patches_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone()
449 | else:
450 | assert s_it == 1
451 | assert it >= it_start_cu
452 | # single channel updates
453 | new_ch = self.random_int(low=0, high=3, shape=[1])
454 | patches_new[counter, new_ch, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = 1. - patches_new[
455 | counter, new_ch, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it]
456 |
457 | patches_new[counter].clamp_(0., 1.)
458 | if update_loc == 1:
459 | # update location
460 | loc_new[counter] += (torch.randint(low=-sh_it, high=sh_it + 1, size=[2]))
461 | loc_new[counter].clamp_(0, h - s)
462 |
463 | x_new[counter, :, loc_new[counter, 0]:loc_new[counter, 0] + s,
464 | loc_new[counter, 1]:loc_new[counter, 1] + s] = patches_new[counter].clone()
465 |
466 | # check loss of new candidate
467 | margin, loss = self.margin_and_loss(x_new, y_curr)
468 | n_queries[idx_to_fool]+= 1
469 |
470 | # update best solution
471 | idx_improved = (loss < loss_min_curr).float()
472 | idx_to_update = (idx_improved > 0.).nonzero().squeeze()
473 | loss_min[idx_to_fool[idx_to_update]] = loss[idx_to_update]
474 |
475 | idx_miscl = (margin < -1e-6).float()
476 | idx_improved = torch.max(idx_improved, idx_miscl)
477 | nimpr = idx_improved.sum().item()
478 | if nimpr > 0.:
479 | idx_improved = (idx_improved.view(-1) > 0).nonzero().squeeze()
480 | margin_min[idx_to_fool[idx_improved]] = margin[idx_improved].clone()
481 | patches_coll[idx_to_fool[idx_improved]] = patches_new[idx_improved].clone()
482 | loc[idx_to_fool[idx_improved]] = loc_new[idx_improved].clone()
483 |
484 | # log results current iteration
485 | ind_succ = (margin_min <= 0.).nonzero().squeeze()
486 | if self.verbose and ind_succ.numel() != 0:
487 | self.logger.log(' '.join(['{}'.format(it + 1),
488 | '- success rate={}/{} ({:.2%})'.format(
489 | ind_succ.numel(), n_ex_total,
490 | float(ind_succ.numel()) / n_ex_total),
491 | '- avg # queries={:.1f}'.format(
492 | n_queries[ind_succ].mean().item()),
493 | '- med # queries={:.1f}'.format(
494 | n_queries[ind_succ].median().item()),
495 | '- loss={:.3f}'.format(loss_min.mean()),
496 | '- max pert={:.0f}'.format(((x_new - x_curr).abs() > 0
497 | ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
498 | #'- sit={:.0f} - sh={:.0f}'.format(s_it, sh_it),
499 | '{}'.format(' - loc' if update_loc == 1. else ''),
500 | ]))
501 |
502 | if ind_succ.numel() == n_ex_total:
503 | break
504 |
505 | # apply patches
506 | for counter in range(x.shape[0]):
507 | x_best[counter, :, loc[counter, 0]:loc[counter, 0] + s,
508 | loc[counter, 1]:loc[counter, 1] + s] = patches_coll[counter].clone()
509 |
510 | elif self.norm == 'patches_universal':
511 | ''' assumes square images and patches '''
512 |
513 | s = int(math.ceil(self.eps ** .5))
514 | x_best = x.clone()
515 | self.n_imgs = x.shape[0]
516 | x_new = x.clone()
517 | loc = torch.randint(h - s + 1, size=[x.shape[0], 2])
518 |
519 | # set when to start single channel updates
520 | it_start_cu = None
521 | for it in range(0, self.n_queries):
522 | s_it = int(max(self.p_selection(it) ** .5 * s, 1))
523 | if s_it == 1:
524 | break
525 | it_start_cu = it + (self.n_queries - it) // 2
526 | if self.verbose:
527 | self.logger.log('starting single channel updates at query {}'.format(
528 | it_start_cu))
529 |
530 | # initialize patch
531 | if self.verbose:
532 | self.logger.log('using {} initialization'.format(self.init_patches))
533 | patch_univ = self.get_init_patch(c, s)
534 | it_init = 0
535 |
536 | loss_batch = float(1e10)
537 | n_succs = 0
538 | n_iter = self.n_queries
539 |
540 | # init update batch
541 | assert not self.data_loader is None
542 | assert not self.resample_loc is None
543 | assert self.targeted
544 | new_train_imgs = []
545 | n_newimgs = self.n_imgs + 0
546 | n_imgsneeded = math.ceil(self.n_queries / self.resample_loc) * n_newimgs
547 | tot_imgs = 0
548 | if self.verbose:
549 | self.logger.log('imgs updated={}, imgs needed={}'.format(
550 | n_newimgs, n_imgsneeded))
551 | while tot_imgs < min(100000, n_imgsneeded):
552 | x_toupdatetrain, _ = next(self.data_loader)
553 | new_train_imgs.append(x_toupdatetrain)
554 | tot_imgs += x_toupdatetrain.shape[0]
555 | newimgstoadd = torch.cat(new_train_imgs, axis=0)
556 | counter_resamplingimgs = 0
557 |
558 | for it in range(it_init, n_iter):
559 | # sample size and location of the update
560 | s_it = int(max(self.p_selection(it) ** .5 * s, 1))
561 | p_it = torch.randint(s - s_it + 1, size=[2])
562 |
563 | patch_new = patch_univ.clone()
564 |
565 | if s_it > 1:
566 | patch_new[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] += self.random_choice([c, 1, 1])
567 | else:
568 | old_clr = patch_new[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone()
569 | new_clr = old_clr.clone()
570 | if it < it_start_cu:
571 | while (new_clr == old_clr).all().item():
572 | new_clr = self.random_choice(new_clr).clone().clamp(0., 1.)
573 | else:
574 | # single channel update
575 | new_ch = self.random_int(low=0, high=3, shape=[1])
576 | new_clr[new_ch] = 1. - new_clr[new_ch]
577 |
578 | patch_new[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone()
579 |
580 | patch_new.clamp_(0., 1.)
581 |
582 | # compute loss for new candidate
583 | x_new = x.clone()
584 |
585 | for counter in range(x.shape[0]):
586 | loc_new = loc[counter]
587 | x_new[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] = 0.
588 | x_new[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] += patch_new[0]
589 |
590 | margin_run, loss_run = self.margin_and_loss(x_new, y)
591 | if self.loss == 'ce':
592 | loss_run += x_new.shape[0]
593 | loss_new = loss_run.sum()
594 | n_succs_new = (margin_run < -1e-6).sum().item()
595 |
596 | # accept candidate if loss improves
597 | if loss_new < loss_batch:
598 | is_accepted = True
599 | loss_batch = loss_new + 0.
600 | patch_univ = patch_new.clone()
601 | n_succs = n_succs_new + 0
602 | else:
603 | is_accepted = False
604 |
605 | # sample new locations and images
606 | if (it + 1) % self.resample_loc == 0:
607 | newimgstoadd_it = newimgstoadd[counter_resamplingimgs * n_newimgs:(
608 | counter_resamplingimgs + 1) * n_newimgs].clone().cuda()
609 | new_batch = [x[n_newimgs:].clone(), newimgstoadd_it.clone()]
610 | x = torch.cat(new_batch, dim=0)
611 | assert x.shape[0] == self.n_imgs
612 |
613 | loc = torch.randint(h - s + 1, size=[self.n_imgs, 2])
614 | assert loc.shape == (self.n_imgs, 2)
615 |
616 | loss_batch = loss_batch * 0. + 1e6
617 | counter_resamplingimgs += 1
618 |
619 | # logging current iteration
620 | if self.verbose:
621 | self.logger.log(' '.join(['{}'.format(it + 1),
622 | '- success rate={}/{} ({:.2%})'.format(
623 | n_succs, n_ex_total,
624 | float(n_succs) / n_ex_total),
625 | '- loss={:.3f}'.format(loss_batch),
626 | '- max pert={:.0f}'.format(((x_new - x).abs() > 0
627 | ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
628 | ]))
629 |
630 | # apply patches on the initial images
631 | for counter in range(x_best.shape[0]):
632 | loc_new = loc[counter]
633 | x_best[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] = 0.
634 | x_best[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] += patch_univ[0]
635 |
636 | elif self.norm == 'frames':
637 | # set width and indices of frames
638 | mask = torch.zeros(x.shape[-2:])
639 | s = self.eps + 0
640 | mask[:s] = 1.
641 | mask[-s:] = 1.
642 | mask[:, :s] = 1.
643 | mask[:, -s:] = 1.
644 | ind = (mask == 1.).nonzero().squeeze()
645 | eps = ind.shape[0]
646 | x_best = x.clone()
647 | x_new = x.clone()
648 | mask = mask.view(1, 1, h, w).to(self.device)
649 | mask_frame = torch.ones([1, c, h, w], device=x.device) * mask
650 | #
651 |
652 | # set when starting single channel updates
653 | it_start_cu = None
654 | for it in range(0, self.n_queries):
655 | s_it = int(max(self.p_selection(it), 1))
656 | if s_it == 1:
657 | break
658 | it_start_cu = it + (self.n_queries - it) // 2
659 | #it_start_cu = 10000
660 | if self.verbose:
661 | self.logger.log('starting single channel updates at query {}'.format(
662 | it_start_cu))
663 |
664 | # initialize frames
665 | x_best[:, :, ind[:, 0], ind[:, 1]] = self.random_choice(
666 | [x.shape[0], c, eps]).clamp(0., 1.)
667 |
668 | margin_min, loss_min = self.margin_and_loss(x_best, y)
669 | n_queries = torch.ones(x.shape[0]).to(self.device)
670 |
671 | for it in range(1, self.n_queries):
672 | # check points still to fool
673 | idx_to_fool = (margin_min > -1e-6).nonzero().squeeze()
674 | x_curr = self.check_shape(x[idx_to_fool])
675 | x_best_curr = self.check_shape(x_best[idx_to_fool])
676 | y_curr = y[idx_to_fool]
677 | margin_min_curr = margin_min[idx_to_fool]
678 | loss_min_curr = loss_min[idx_to_fool]
679 |
680 | if len(y_curr.shape) == 0:
681 | y_curr.unsqueeze_(0)
682 | margin_min_curr.unsqueeze_(0)
683 | loss_min_curr.unsqueeze_(0)
684 | idx_to_fool.unsqueeze_(0)
685 |
686 | # sample update
687 | s_it = max(int(self.p_selection(it)), 1)
688 | ind_it = torch.randperm(eps)[0]
689 |
690 | x_new = x_best_curr.clone()
691 | if s_it > 1:
692 | dir_h = self.random_choice([1]).long().cpu()
693 | dir_w = self.random_choice([1]).long().cpu()
694 | new_clr = self.random_choice([c, 1]).clamp(0., 1.)
695 |
696 | for counter in range(x_curr.shape[0]):
697 | if s_it > 1:
698 | for counter_h in range(s_it):
699 | for counter_w in range(s_it):
700 | x_new[counter, :, (ind[ind_it, 0] + dir_h * counter_h).clamp(0, h - 1),
701 | (ind[ind_it, 1] + dir_w * counter_w).clamp(0, w - 1)] = new_clr.clone()
702 | else:
703 | p_it = ind[ind_it].clone()
704 | old_clr = x_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone()
705 | new_clr = old_clr.clone()
706 | if it < it_start_cu:
707 | while (new_clr == old_clr).all().item():
708 | new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.)
709 | else:
710 | # single channel update
711 | new_ch = self.random_int(low=0, high=3, shape=[1])
712 | new_clr[new_ch] = 1. - new_clr[new_ch]
713 | x_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone()
714 |
715 | x_new.clamp_(0., 1.)
716 | x_new = (x_new - x_curr) * mask_frame + x_curr
717 |
718 | # check loss of new candidate
719 | margin, loss = self.margin_and_loss(x_new, y_curr)
720 | n_queries[idx_to_fool]+= 1
721 |
722 | # update best solution
723 | idx_improved = (loss < loss_min_curr).float()
724 | idx_to_update = (idx_improved > 0.).nonzero().squeeze()
725 | loss_min[idx_to_fool[idx_to_update]] = loss[idx_to_update]
726 |
727 | idx_miscl = (margin < -1e-6).float()
728 | idx_improved = torch.max(idx_improved, idx_miscl)
729 | nimpr = idx_improved.sum().item()
730 | if nimpr > 0.:
731 | idx_improved = (idx_improved.view(-1) > 0).nonzero().squeeze()
732 | margin_min[idx_to_fool[idx_improved]] = margin[idx_improved].clone()
733 | x_best[idx_to_fool[idx_improved]] = x_new[idx_improved].clone()
734 |
735 | # log results current iteration
736 | ind_succ = (margin_min <= 0.).nonzero().squeeze()
737 | if self.verbose and ind_succ.numel() != 0:
738 | self.logger.log(' '.join(['{}'.format(it + 1),
739 | '- success rate={}/{} ({:.2%})'.format(
740 | ind_succ.numel(), n_ex_total,
741 | float(ind_succ.numel()) / n_ex_total),
742 | '- avg # queries={:.1f}'.format(
743 | n_queries[ind_succ].mean().item()),
744 | '- med # queries={:.1f}'.format(
745 | n_queries[ind_succ].median().item()),
746 | '- loss={:.3f}'.format(loss_min.mean()),
747 | '- max pert={:.0f}'.format(((x_new - x_curr).abs() > 0
748 | ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
749 | #'- min pert={:.0f}'.format(((x_new - x_curr).abs() > 0
750 | #).max(1)[0].view(x_new.shape[0], -1).sum(-1).min()),
751 | #'- sit={:.0f} - indit={}'.format(s_it, ind_it.item()),
752 | ]))
753 |
754 | if ind_succ.numel() == n_ex_total:
755 | break
756 |
757 |
758 | elif self.norm == 'frames_universal':
759 | # set width and indices of frames
760 | mask = torch.zeros(x.shape[-2:])
761 | s = self.eps + 0
762 | mask[:s] = 1.
763 | mask[-s:] = 1.
764 | mask[:, :s] = 1.
765 | mask[:, -s:] = 1.
766 | ind = (mask == 1.).nonzero().squeeze()
767 | eps = ind.shape[0]
768 | x_best = x.clone()
769 | x_new = x.clone()
770 | mask = mask.view(1, 1, h, w).to(self.device)
771 | mask_frame = torch.ones([1, c, h, w], device=x.device) * mask
772 | frame_univ = self.random_choice([1, c, eps]).clamp(0., 1.)
773 |
774 | # set when to start single channel updates
775 | it_start_cu = None
776 | for it in range(0, self.n_queries):
777 | s_it = int(max(self.p_selection(it) * s, 1))
778 | if s_it == 1:
779 | break
780 | it_start_cu = it + (self.n_queries - it) // 2
781 | if self.verbose:
782 | self.logger.log('starting single channel updates at query {}'.format(
783 | it_start_cu))
784 |
785 | self.n_imgs = x.shape[0]
786 | loss_batch = float(1e10)
787 | n_queries = torch.ones_like(y).float()
788 |
789 | # init update batch
790 | assert not self.data_loader is None
791 | assert not self.resample_loc is None
792 | assert self.targeted
793 | new_train_imgs = []
794 | n_newimgs = self.n_imgs + 0
795 | n_imgsneeded = math.ceil(self.n_queries / self.resample_loc) * n_newimgs
796 | tot_imgs = 0
797 | if self.verbose:
798 | self.logger.log('imgs updated={}, imgs needed={}'.format(
799 | n_newimgs, n_imgsneeded))
800 | while tot_imgs < min(100000, n_imgsneeded):
801 | x_toupdatetrain, _ = next(self.data_loader)
802 | new_train_imgs.append(x_toupdatetrain)
803 | tot_imgs += x_toupdatetrain.shape[0]
804 | newimgstoadd = torch.cat(new_train_imgs, axis=0)
805 | counter_resamplingimgs = 0
806 |
807 | for it in range(self.n_queries):
808 | # sample update
809 | s_it = max(int(self.p_selection(it) * self.eps), 1)
810 | ind_it = torch.randperm(eps)[0]
811 |
812 | mask_frame[:, :, ind[:, 0], ind[:, 1]] = 0
813 | mask_frame[:, :, ind[:, 0], ind[:, 1]] += frame_univ
814 |
815 | if s_it > 1:
816 | dir_h = self.random_choice([1]).long().cpu()
817 | dir_w = self.random_choice([1]).long().cpu()
818 | new_clr = self.random_choice([c, 1]).clamp(0., 1.)
819 |
820 | for counter_h in range(s_it):
821 | for counter_w in range(s_it):
822 | mask_frame[0, :, (ind[ind_it, 0] + dir_h * counter_h).clamp(0, h - 1),
823 | (ind[ind_it, 1] + dir_w * counter_w).clamp(0, w - 1)] = new_clr.clone()
824 | else:
825 | p_it = ind[ind_it]
826 | old_clr = mask_frame[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone()
827 | new_clr = old_clr.clone()
828 | if it < it_start_cu:
829 | while (new_clr == old_clr).all().item():
830 | new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.)
831 | else:
832 | # single channel update
833 | new_ch = self.random_int(low=0, high=3, shape=[1])
834 | new_clr[new_ch] = 1. - new_clr[new_ch]
835 | mask_frame[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone()
836 |
837 | frame_new = mask_frame[:, :, ind[:, 0], ind[:, 1]].clone()
838 | frame_new.clamp_(0., 1.)
839 | if len(frame_new.shape) == 2:
840 | frame_new.unsqueeze_(0)
841 |
842 | x_new[:, :, ind[:, 0], ind[:, 1]] = 0.
843 | x_new[:, :, ind[:, 0], ind[:, 1]] += frame_new
844 |
845 | margin_run, loss_run = self.margin_and_loss(x_new, y)
846 | if self.loss == 'ce':
847 | loss_run += x_new.shape[0]
848 | loss_new = loss_run.sum()
849 | n_succs_new = (margin_run < -1e-6).sum().item()
850 |
851 | # accept candidate if loss improves
852 | if loss_new < loss_batch:
853 | #is_accepted = True
854 | loss_batch = loss_new + 0.
855 | frame_univ = frame_new.clone()
856 | n_succs = n_succs_new + 0
857 |
858 | # sample new images
859 | if (it + 1) % self.resample_loc == 0:
860 | newimgstoadd_it = newimgstoadd[counter_resamplingimgs * n_newimgs:(
861 | counter_resamplingimgs + 1) * n_newimgs].clone().cuda()
862 | new_batch = [x[n_newimgs:].clone(), newimgstoadd_it.clone()]
863 | x = torch.cat(new_batch, dim=0)
864 | assert x.shape[0] == self.n_imgs
865 |
866 | loss_batch = loss_batch * 0. + 1e6
867 | x_new = x.clone()
868 | counter_resamplingimgs += 1
869 |
870 |                     # logging current iteration
871 | if self.verbose:
872 | self.logger.log(' '.join(['{}'.format(it + 1),
873 | '- success rate={}/{} ({:.2%})'.format(
874 | n_succs, n_ex_total,
875 | float(n_succs) / n_ex_total),
876 | '- loss={:.3f}'.format(loss_batch),
877 | '- max pert={:.0f}'.format(((x_new - x).abs() > 0
878 | ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
879 | ]))
880 |
881 | # apply frame on initial images
882 | x_best[:, :, ind[:, 0], ind[:, 1]] = 0.
883 | x_best[:, :, ind[:, 0], ind[:, 1]] += frame_univ
884 |
885 | return n_queries, x_best
886 |
887 | def perturb(self, x, y=None):
888 | """
889 | :param x: clean images
890 | :param y: untargeted attack -> clean labels,
891 | if None we use the predicted labels
892 | targeted attack -> target labels, if None random classes,
893 | different from the predicted ones, are sampled
894 | """
895 |
896 | self.init_hyperparam(x)
897 |
898 | adv = x.clone()
899 | qr = torch.zeros([x.shape[0]]).to(self.device)
900 | if y is None:
901 | if not self.targeted:
902 | with torch.no_grad():
903 | output = self.predict(x)
904 | y_pred = output.max(1)[1]
905 | y = y_pred.detach().clone().long().to(self.device)
906 | else:
907 | with torch.no_grad():
908 | output = self.predict(x)
909 | n_classes = output.shape[-1]
910 | y_pred = output.max(1)[1]
911 | y = self.random_target_classes(y_pred, n_classes)
912 | else:
913 | y = y.detach().clone().long().to(self.device)
914 |
915 | if not self.targeted:
916 | acc = self.predict(x).max(1)[1] == y
917 | else:
918 | acc = self.predict(x).max(1)[1] != y
919 |
920 | startt = time.time()
921 |
922 | torch.random.manual_seed(self.seed)
923 | torch.cuda.random.manual_seed(self.seed)
924 | np.random.seed(self.seed)
925 |
926 | for counter in range(self.n_restarts):
927 | ind_to_fool = acc.nonzero().squeeze()
928 | if len(ind_to_fool.shape) == 0:
929 | ind_to_fool = ind_to_fool.unsqueeze(0)
930 | if ind_to_fool.numel() != 0:
931 | x_to_fool = x[ind_to_fool].clone()
932 | y_to_fool = y[ind_to_fool].clone()
933 |
934 | qr_curr, adv_curr = self.attack_single_run(x_to_fool, y_to_fool)
935 |
936 | output_curr = self.predict(adv_curr)
937 | if not self.targeted:
938 | acc_curr = output_curr.max(1)[1] == y_to_fool
939 | else:
940 | acc_curr = output_curr.max(1)[1] != y_to_fool
941 | ind_curr = (acc_curr == 0).nonzero().squeeze()
942 |
943 | acc[ind_to_fool[ind_curr]] = 0
944 | adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone()
945 | qr[ind_to_fool[ind_curr]] = qr_curr[ind_curr].clone()
946 | if self.verbose:
947 | print('restart {} - robust accuracy: {:.2%}'.format(
948 | counter, acc.float().mean()),
949 | '- cum. time: {:.1f} s'.format(
950 | time.time() - startt))
951 |
952 | return qr, adv
953 |
954 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Logger():
6 | def __init__(self, log_path):
7 | self.log_path = log_path
8 |
9 | def log(self, str_to_log):
10 | print(str_to_log)
11 | if not self.log_path is None:
12 | with open(self.log_path, 'a') as f:
13 | f.write(str_to_log + '\n')
14 | f.flush()
15 |
16 |
17 | class SingleChannelModel():
18 | """ reshapes images to rgb before classification
19 | i.e. [N, 1, H, W x 3] -> [N, 3, H, W]
20 | """
21 | def __init__(self, model):
22 | if isinstance(model, nn.Module):
23 | assert not model.training
24 | self.model = model
25 |
26 | def __call__(self, x):
27 | return self.model(x.view(x.shape[0], 3, x.shape[2], x.shape[3] // 3))
28 |
29 |
--------------------------------------------------------------------------------
/vis_images.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import os
5 | import argparse
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--path_data', type=str)
9 |
10 | args = parser.parse_args()
11 |
12 | if args.path_data is None:
13 | path_data = './results/ImageNet/sparse-rs_L0_pt_vgg_1_50_nqueries_1000_alphainit_0.30_loss_margin_eps_150_targeted_False_seed_0.pth'
14 | else:
15 | path_data = args.path_data
16 |
17 | data = torch.load(path_data)
18 | if 'qr' in list(data.keys()):
19 | imgs, qr = data['adv'].cpu(), data['qr'].cpu()
20 | else:
21 |     imgs, qr = data['adv'].cpu(), torch.arange(1, data['adv'].shape[0] + 1)
22 |
23 | nqueries = 100000
24 | ind = ((qr > 0) * (qr < nqueries)).nonzero().squeeze()
25 | imgs_to_show = imgs[ind].permute(0, 2, 3, 1).cpu().numpy()
26 | if imgs_to_show.shape[-1] == 1:
27 | imgs_to_show = np.tile(imgs_to_show, (1, 1, 1, 3))
28 | qr_to_show = qr[ind]
29 |
30 | qr_to_show, ind = qr_to_show.sort(descending=True)
31 | imgs_to_show = imgs_to_show[ind]
32 |
33 | w = 10
34 | h = 10
35 | fig = plt.figure(figsize=(20, 12))
36 |
37 | columns = 10
38 | rows = 5
39 |
40 | # ax enables access to manipulate each of subplots
41 | ax = []
42 |
43 | init_pos = 0
44 | if 'patch' in list(data.keys()):
45 | ax.append( fig.add_subplot(rows, columns, 1) )
46 | ax[-1].set_title('patch')
47 | ax[-1].get_xaxis().set_ticks([])
48 | ax[-1].get_yaxis().set_ticks([])
49 | ax[-1].axis('off')
50 | print(data['patch'].shape)
51 | s = int(float(path_data.split('eps_')[1].split('_')[0]) ** .5)
52 | patch = data['patch'].squeeze().view(-1, s, s)
53 | plt.imshow(patch.permute(1, 2, 0).cpu().numpy(), interpolation='none')
54 | init_pos = 1
55 |
56 | for i in range(init_pos, columns*rows):
57 | if i < imgs_to_show.shape[0]:
58 | ax.append( fig.add_subplot(rows, columns, i+1) )
59 | ax[-1].set_title('qr = {:.0f}'.format(qr_to_show[i].item()))
60 | ax[-1].get_xaxis().set_ticks([])
61 | ax[-1].get_yaxis().set_ticks([])
62 | ax[-1].axis('off')
63 | plt.imshow(imgs_to_show[i], interpolation='none')
64 |
65 | if not os.path.exists('./results/plots/'):
66 | os.makedirs('./results/plots/')
67 |
68 | #plt.show()
69 | plt.savefig('./results/plots/pl_{}.pdf'.format(path_data.split('/')[-1][:-4]))
70 |
--------------------------------------------------------------------------------