├── .gitmodules
├── LICENSE
├── README.md
├── run_simba.py
├── run_simba_cifar.py
├── simba.py
└── utils.py

/.gitmodules:
--------------------------------------------------------------------------------
[submodule "pytorch-cifar"]
    path = pytorch-cifar
    url = https://github.com/kuangliu/pytorch-cifar.git
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 cg563

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
This repository contains code for the ICML 2019 paper:

Chuan Guo, Jacob R. Gardner, Yurong You, Andrew G. Wilson, Kilian Q. Weinberger. Simple Black-box Adversarial Attacks.
https://arxiv.org/abs/1905.07121

Our code uses PyTorch (pytorch >= 0.4.1, torchvision >= 0.2.1) with CUDA 9.0 and Python 3.5. The script run_simba.py contains code to run SimBA and SimBA-DCT with various options.

To run SimBA (pixel attack):
```
python run_simba.py --data_root <imagenet_root> --num_iters 10000 --pixel_attack --freq_dims 224
```
To run SimBA-DCT (low frequency attack):
```
python run_simba.py --data_root <imagenet_root> --num_iters 10000 --freq_dims 28 --order strided --stride 7
```
For a targeted attack, add the flag ```--targeted``` and change ```--num_iters``` to 30000.

For the Inception-v3 model, we used ```--freq_dims 38``` and ```--stride 9``` due to the larger input size.

**Update 2020/01/09**: Due to changes in the underlying Google Cloud Vision models, our attack no longer works against them.

**Update 2020/06/22**: Added an L_inf-bounded SimBA-DCT attack. To run with L_inf bound 0.05:
```
python run_simba.py --data_root <imagenet_root> --num_iters 10000 --freq_dims 224 --order rand --linf_bound 0.05
```

**Update 2021/03/10**: Added code for running SimBA on CIFAR-10 using models from [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar).
To run a targeted attack against a trained ResNet-18 model:
```
python run_simba_cifar.py --data_root <cifar_root> --model_ckpt <checkpoint_path> --model ResNet18 --targeted
```
--------------------------------------------------------------------------------
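The commands above drive everything from the shell; the SimBA class in simba.py (listed below) can also be called directly from Python. A minimal sketch, assuming a pretrained torchvision ResNet-50 and a local image file ('example.jpg' is a placeholder), that attacks the model's own predicted label:
```
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import utils
from simba import SimBA

model = models.resnet50(pretrained=True).cuda().eval()
attacker = SimBA(model, 'imagenet', 224)

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()])  # pixel values in [0, 1]
x = preprocess(Image.open('example.jpg').convert('RGB'))  # placeholder image path

# attack the model's own predicted label (an illustrative choice)
with torch.no_grad():
    logits = model(utils.apply_normalization(x.unsqueeze(0), 'imagenet').cuda())
y = logits.argmax(1).cpu()

adv = attacker.simba_single(x, y, num_iters=10000, epsilon=0.2)  # pixel-space SimBA
```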
/run_simba.py:
--------------------------------------------------------------------------------
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.models as models
import utils
import math
import random
import argparse
import os
from simba import SimBA

parser = argparse.ArgumentParser(description='Runs SimBA on a set of images')
parser.add_argument('--data_root', type=str, required=True, help='root directory of imagenet data')
parser.add_argument('--result_dir', type=str, default='save', help='directory for saving results')
parser.add_argument('--sampled_image_dir', type=str, default='save', help='directory to cache sampled images')
parser.add_argument('--model', type=str, default='resnet50', help='type of base model to use')
parser.add_argument('--num_runs', type=int, default=1000, help='number of image samples')
parser.add_argument('--batch_size', type=int, default=50, help='batch size for parallel runs')
parser.add_argument('--num_iters', type=int, default=0, help='maximum number of iterations, 0 for unlimited')
parser.add_argument('--log_every', type=int, default=10, help='log every n iterations')
parser.add_argument('--epsilon', type=float, default=0.2, help='step size per iteration')
parser.add_argument('--linf_bound', type=float, default=0.0, help='L_inf bound for frequency space attack')
parser.add_argument('--freq_dims', type=int, default=14, help='dimensionality of 2D frequency space')
parser.add_argument('--order', type=str, default='rand', help='(random) order of coordinate selection')
parser.add_argument('--stride', type=int, default=7, help='stride for block order')
parser.add_argument('--targeted', action='store_true', help='perform targeted attack')
parser.add_argument('--pixel_attack', action='store_true', help='attack in pixel space')
parser.add_argument('--save_suffix', type=str, default='', help='suffix appended to save file')
args = parser.parse_args()

if not os.path.exists(args.result_dir):
    os.mkdir(args.result_dir)
if not os.path.exists(args.sampled_image_dir):
    os.mkdir(args.sampled_image_dir)

# load model and dataset
model = getattr(models, args.model)(pretrained=True).cuda()
model.eval()
if args.model.startswith('inception'):
    image_size = 299
    testset = dset.ImageFolder(args.data_root + '/val', utils.INCEPTION_TRANSFORM)
else:
    image_size = 224
    testset = dset.ImageFolder(args.data_root + '/val', utils.IMAGENET_TRANSFORM)
attacker = SimBA(model, 'imagenet', image_size)

# load sampled images or sample new ones
# this is to ensure all attacks are run on the same set of correctly classified images
batchfile = '%s/images_%s_%d.pth' % (args.sampled_image_dir, args.model, args.num_runs)
if os.path.isfile(batchfile):
    checkpoint = torch.load(batchfile)
    images = checkpoint['images']
    labels = checkpoint['labels']
else:
    images = torch.zeros(args.num_runs, 3, image_size, image_size)
    labels = torch.zeros(args.num_runs).long()
    preds = labels + 1
    # resample any image the model misclassifies until all are correctly classified
    while preds.ne(labels).sum() > 0:
        idx = torch.arange(0, images.size(0)).long()[preds.ne(labels)]
        for i in list(idx):
            images[i], labels[i] = testset[random.randint(0, len(testset) - 1)]
        preds[idx], _ = utils.get_preds(model, images[idx], 'imagenet', batch_size=args.batch_size)
    torch.save({'images': images, 'labels': labels}, batchfile)

if args.order == 'rand':
    n_dims = 3 * args.freq_dims * args.freq_dims
else:
    n_dims = 3 * image_size * image_size
if args.num_iters > 0:
    max_iters = int(min(n_dims, args.num_iters))
else:
    max_iters = int(n_dims)
# number of batches (ceil so a final partial batch is included)
N = int(math.ceil(float(args.num_runs) / float(args.batch_size)))
for i in range(N):
    upper = min((i + 1) * args.batch_size, args.num_runs)
    images_batch = images[(i * args.batch_size):upper]
    labels_batch = labels[(i * args.batch_size):upper]
    # replace true labels with random target labels in case of targeted attack
    if args.targeted:
        labels_targeted = labels_batch.clone()
        while labels_targeted.eq(labels_batch).sum() > 0:
            labels_targeted = torch.floor(1000 * torch.rand(labels_batch.size())).long()
        labels_batch = labels_targeted
    adv, probs, succs, queries, l2_norms, linf_norms = attacker.simba_batch(
        images_batch, labels_batch, max_iters, args.freq_dims, args.stride, args.epsilon, linf_bound=args.linf_bound,
        order=args.order, targeted=args.targeted, pixel_attack=args.pixel_attack, log_every=args.log_every)
    if i == 0:
        all_adv = adv
        all_probs = probs
        all_succs = succs
        all_queries = queries
        all_l2_norms = l2_norms
        all_linf_norms = linf_norms
    else:
        all_adv = torch.cat([all_adv, adv], dim=0)
        all_probs = torch.cat([all_probs, probs], dim=0)
        all_succs = torch.cat([all_succs, succs], dim=0)
        all_queries = torch.cat([all_queries, queries], dim=0)
        all_l2_norms = torch.cat([all_l2_norms, l2_norms], dim=0)
        all_linf_norms = torch.cat([all_linf_norms, linf_norms], dim=0)
    if args.pixel_attack:
        prefix = 'pixel'
    else:
        prefix = 'dct'
    if args.targeted:
        prefix += '_targeted'
    savefile = '%s/%s_%s_%d_%d_%d_%.4f_%s%s.pth' % (
        args.result_dir, prefix, args.model, args.num_runs, args.num_iters, args.freq_dims, args.epsilon, args.order, args.save_suffix)
    torch.save({'adv': all_adv, 'probs': all_probs, 'succs': all_succs, 'queries': all_queries,
                'l2_norms': all_l2_norms, 'linf_norms': all_linf_norms}, savefile)
--------------------------------------------------------------------------------
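The script saves all logging tensors to a single .pth file. A short sketch of loading such a file and summarizing the run; the file name below is hypothetical and depends on the arguments used:
```
import torch

# hypothetical file name following the pattern '%s/%s_%s_%d_%d_%d_%.4f_%s%s.pth'
results = torch.load('save/dct_resnet50_1000_10000_14_0.2000_rand.pth')
succs = results['succs']      # (num_images, max_iters) success indicators
queries = results['queries']  # per-iteration query counts

success_rate = succs[:, -1].float().mean().item()
avg_queries = queries.sum(1)[succs[:, -1].bool()].mean().item()
print('success rate: %.4f, avg queries over successful runs: %.1f' % (success_rate, avg_queries))
```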
/run_simba_cifar.py:
--------------------------------------------------------------------------------
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import utils
import math
import random
import argparse
import os
import sys
sys.path.append('pytorch-cifar')
import models
from simba import SimBA

parser = argparse.ArgumentParser(description='Runs SimBA on a set of images')
parser.add_argument('--data_root', type=str, required=True, help='root directory of CIFAR-10 data')
parser.add_argument('--result_dir', type=str, default='save_cifar', help='directory for saving results')
parser.add_argument('--sampled_image_dir', type=str, default='save_cifar', help='directory to cache sampled images')
parser.add_argument('--model', type=str, default='ResNet18', help='type of base model to use')
parser.add_argument('--model_ckpt', type=str, required=True, help='model checkpoint location')
parser.add_argument('--num_runs', type=int, default=1000, help='number of image samples')
parser.add_argument('--batch_size', type=int, default=50, help='batch size for parallel runs')
parser.add_argument('--num_iters', type=int, default=0, help='maximum number of iterations, 0 for unlimited')
parser.add_argument('--log_every', type=int, default=10, help='log every n iterations')
parser.add_argument('--epsilon', type=float, default=0.2, help='step size per iteration')
parser.add_argument('--linf_bound', type=float, default=0.0, help='L_inf bound for frequency space attack')
parser.add_argument('--freq_dims', type=int, default=32, help='dimensionality of 2D frequency space')
parser.add_argument('--order', type=str, default='rand', help='(random) order of coordinate selection')
parser.add_argument('--stride', type=int, default=7, help='stride for block order')
parser.add_argument('--targeted', action='store_true', help='perform targeted attack')
parser.add_argument('--pixel_attack', action='store_true', help='attack in pixel space')
parser.add_argument('--save_suffix', type=str, default='', help='suffix appended to save file')
args = parser.parse_args()

if not os.path.exists(args.result_dir):
    os.mkdir(args.result_dir)
if not os.path.exists(args.sampled_image_dir):
    os.mkdir(args.sampled_image_dir)

# load model and dataset
model = getattr(models, args.model)().cuda()
model = torch.nn.DataParallel(model)
checkpoint = torch.load(args.model_ckpt)
model.load_state_dict(checkpoint['net'])
model.eval()
image_size = 32
testset = dset.CIFAR10(root=args.data_root, train=False, download=True, transform=utils.CIFAR_TRANSFORM)
attacker = SimBA(model, 'cifar', image_size)

# load sampled images or sample new ones
# this is to ensure all attacks are run on the same set of correctly classified images
batchfile = '%s/images_%s_%d.pth' % (args.sampled_image_dir, args.model, args.num_runs)
if os.path.isfile(batchfile):
    checkpoint = torch.load(batchfile)
    images = checkpoint['images']
    labels = checkpoint['labels']
else:
    images = torch.zeros(args.num_runs, 3, image_size, image_size)
    labels = torch.zeros(args.num_runs).long()
    preds = labels + 1
    # resample any image the model misclassifies until all are correctly classified
    while preds.ne(labels).sum() > 0:
        idx = torch.arange(0, images.size(0)).long()[preds.ne(labels)]
        for i in list(idx):
            images[i], labels[i] = testset[random.randint(0, len(testset) - 1)]
        preds[idx], _ = utils.get_preds(model, images[idx], 'cifar', batch_size=args.batch_size)
    torch.save({'images': images, 'labels': labels}, batchfile)

if args.order == 'rand':
    n_dims = 3 * args.freq_dims * args.freq_dims
else:
    n_dims = 3 * image_size * image_size
if args.num_iters > 0:
    max_iters = int(min(n_dims, args.num_iters))
else:
    max_iters = int(n_dims)
# number of batches (ceil so a final partial batch is included)
N = int(math.ceil(float(args.num_runs) / float(args.batch_size)))
for i in range(N):
    upper = min((i + 1) * args.batch_size, args.num_runs)
    images_batch = images[(i * args.batch_size):upper]
    labels_batch = labels[(i * args.batch_size):upper]
    # replace true labels with random target labels in case of targeted attack
    if args.targeted:
        labels_targeted = labels_batch.clone()
        while labels_targeted.eq(labels_batch).sum() > 0:
            labels_targeted = torch.floor(10 * torch.rand(labels_batch.size())).long()
        labels_batch = labels_targeted
    adv, probs, succs, queries, l2_norms, linf_norms = attacker.simba_batch(
        images_batch, labels_batch, max_iters, args.freq_dims, args.stride, args.epsilon, linf_bound=args.linf_bound,
        order=args.order, targeted=args.targeted, pixel_attack=args.pixel_attack, log_every=args.log_every)
    if i == 0:
        all_adv = adv
        all_probs = probs
        all_succs = succs
        all_queries = queries
        all_l2_norms = l2_norms
        all_linf_norms = linf_norms
    else:
        all_adv = torch.cat([all_adv, adv], dim=0)
        all_probs = torch.cat([all_probs, probs], dim=0)
        all_succs = torch.cat([all_succs, succs], dim=0)
        all_queries = torch.cat([all_queries, queries], dim=0)
        all_l2_norms = torch.cat([all_l2_norms, l2_norms], dim=0)
        all_linf_norms = torch.cat([all_linf_norms, linf_norms], dim=0)
    if args.pixel_attack:
        prefix = 'pixel'
    else:
        prefix = 'dct'
    if args.targeted:
        prefix += '_targeted'
    savefile = '%s/%s_%s_%d_%d_%d_%.4f_%s%s.pth' % (
        args.result_dir, prefix, args.model, args.num_runs, args.num_iters, args.freq_dims, args.epsilon, args.order, args.save_suffix)
    torch.save({'adv': all_adv, 'probs': all_probs, 'succs': all_succs, 'queries': all_queries,
                'l2_norms': all_l2_norms, 'linf_norms': all_linf_norms}, savefile)
--------------------------------------------------------------------------------
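Both driver scripts draw random target labels by rejection sampling: the whole batch is redrawn until no target coincides with its true label. The same step in isolation (num_classes is 1000 for ImageNet, 10 for CIFAR-10):
```
import torch

def random_targets(labels, num_classes):
    # redraw the whole batch until every target differs from its true label,
    # mirroring the loop in run_simba.py / run_simba_cifar.py
    targets = labels.clone()
    while targets.eq(labels).sum() > 0:
        targets = torch.floor(num_classes * torch.rand(labels.size())).long()
    return targets

labels = torch.LongTensor([3, 1, 4, 1, 5])
print(random_targets(labels, num_classes=10))
```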
/simba.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import utils


class SimBA:

    def __init__(self, model, dataset, image_size):
        self.model = model
        self.dataset = dataset
        self.image_size = image_size
        self.model.eval()

    # expands a flattened vector of size 3 * size * size into an image-sized
    # tensor, placing the entries in the top-left corner (the low frequencies
    # in the DCT case)
    def expand_vector(self, x, size):
        batch_size = x.size(0)
        x = x.view(-1, 3, size, size)
        z = torch.zeros(batch_size, 3, self.image_size, self.image_size)
        z[:, :, :size, :size] = x
        return z

    def normalize(self, x):
        return utils.apply_normalization(x, self.dataset)

    def get_probs(self, x, y):
        output = self.model(self.normalize(x.cuda())).cpu()
        probs = torch.index_select(F.softmax(output, dim=-1).data, 1, y)
        return torch.diag(probs)

    def get_preds(self, x):
        output = self.model(self.normalize(x.cuda())).cpu()
        _, preds = output.data.max(1)
        return preds

    # 20-line implementation of SimBA for single image input
    def simba_single(self, x, y, num_iters=10000, epsilon=0.2, targeted=False):
        n_dims = x.view(1, -1).size(1)
        perm = torch.randperm(n_dims)
        x = x.unsqueeze(0)
        last_prob = self.get_probs(x, y)
        for i in range(num_iters):
            diff = torch.zeros(n_dims)
            diff[perm[i]] = epsilon
            # accept a step if the target-class probability moves in the right
            # direction: down for untargeted attacks, up for targeted ones
            left_prob = self.get_probs((x - diff.view(x.size())).clamp(0, 1), y)
            if targeted != (left_prob < last_prob):
                x = (x - diff.view(x.size())).clamp(0, 1)
                last_prob = left_prob
            else:
                right_prob = self.get_probs((x + diff.view(x.size())).clamp(0, 1), y)
                if targeted != (right_prob < last_prob):
                    x = (x + diff.view(x.size())).clamp(0, 1)
                    last_prob = right_prob
            if i % 10 == 0:
                print(last_prob)
        return x.squeeze()

    # runs SimBA on a batch of images with true labels (for untargeted attack)
    # or target labels (for targeted attack)
    def simba_batch(self, images_batch, labels_batch, max_iters, freq_dims, stride, epsilon, linf_bound=0.0,
                    order='rand', targeted=False, pixel_attack=False, log_every=1):
        batch_size = images_batch.size(0)
        image_size = images_batch.size(2)
        assert self.image_size == image_size
        # sample a random ordering for coordinates (the same ordering is shared
        # across the batch)
        if order == 'rand':
            indices = torch.randperm(3 * freq_dims * freq_dims)[:max_iters]
        elif order == 'diag':
            indices = utils.diagonal_order(image_size, 3)[:max_iters]
        elif order == 'strided':
            indices = utils.block_order(image_size, 3, initial_size=freq_dims, stride=stride)[:max_iters]
        else:
            indices = utils.block_order(image_size, 3)[:max_iters]
        if order == 'rand':
            expand_dims = freq_dims
        else:
            expand_dims = image_size
        n_dims = 3 * expand_dims * expand_dims
        x = torch.zeros(batch_size, n_dims)
        # logging tensors
        probs = torch.zeros(batch_size, max_iters)
        succs = torch.zeros(batch_size, max_iters)
        queries = torch.zeros(batch_size, max_iters)
        l2_norms = torch.zeros(batch_size, max_iters)
        linf_norms = torch.zeros(batch_size, max_iters)
        prev_probs = self.get_probs(images_batch, labels_batch)
        preds = self.get_preds(images_batch)
        if pixel_attack:
            trans = lambda z: z
        else:
            trans = lambda z: utils.block_idct(z, block_size=image_size, linf_bound=linf_bound)
        remaining_indices = torch.arange(0, batch_size).long()
        for k in range(max_iters):
            dim = indices[k]
            expanded = (images_batch[remaining_indices] + trans(self.expand_vector(x[remaining_indices], expand_dims))).clamp(0, 1)
            perturbation = trans(self.expand_vector(x, expand_dims))
            l2_norms[:, k] = perturbation.view(batch_size, -1).norm(2, 1)
            linf_norms[:, k] = perturbation.view(batch_size, -1).abs().max(1)[0]
            preds_next = self.get_preds(expanded)
            preds[remaining_indices] = preds_next
            if targeted:
                remaining = preds.ne(labels_batch)
            else:
                remaining = preds.eq(labels_batch)
            # check if all images are misclassified and stop early
            if remaining.sum() == 0:
                adv = (images_batch + trans(self.expand_vector(x, expand_dims))).clamp(0, 1)
                probs_k = self.get_probs(adv, labels_batch)
                probs[:, k:] = probs_k.unsqueeze(1).repeat(1, max_iters - k)
                succs[:, k:] = torch.ones(batch_size, max_iters - k)
                queries[:, k:] = torch.zeros(batch_size, max_iters - k)
                break
            remaining_indices = torch.arange(0, batch_size)[remaining].long()
            if k > 0:
                succs[:, k] = ~remaining
            diff = torch.zeros(int(remaining.sum()), n_dims)
            diff[:, dim] = epsilon
            left_vec = x[remaining_indices] - diff
            right_vec = x[remaining_indices] + diff
            # try negative direction
            adv = (images_batch[remaining_indices] + trans(self.expand_vector(left_vec, expand_dims))).clamp(0, 1)
            left_probs = self.get_probs(adv, labels_batch[remaining_indices])
            queries_k = torch.zeros(batch_size)
            # increase query count for all remaining images
            queries_k[remaining_indices] += 1
            if targeted:
                improved = left_probs.gt(prev_probs[remaining_indices])
            else:
                improved = left_probs.lt(prev_probs[remaining_indices])
            # only increase query count further by 1 for images that did not improve in adversarial loss
            if improved.sum() < remaining_indices.size(0):
                queries_k[remaining_indices[~improved]] += 1
            # try positive direction
            adv = (images_batch[remaining_indices] + trans(self.expand_vector(right_vec, expand_dims))).clamp(0, 1)
            right_probs = self.get_probs(adv, labels_batch[remaining_indices])
            if targeted:
                right_improved = right_probs.gt(torch.max(prev_probs[remaining_indices], left_probs))
            else:
                right_improved = right_probs.lt(torch.min(prev_probs[remaining_indices], left_probs))
            probs_k = prev_probs.clone()
            # update x depending on which direction improved
            if improved.sum() > 0:
                left_indices = remaining_indices[improved]
                left_mask_remaining = improved.unsqueeze(1).repeat(1, n_dims)
                x[left_indices] = left_vec[left_mask_remaining].view(-1, n_dims)
                probs_k[left_indices] = left_probs[improved]
            if right_improved.sum() > 0:
                right_indices = remaining_indices[right_improved]
                right_mask_remaining = right_improved.unsqueeze(1).repeat(1, n_dims)
                x[right_indices] = right_vec[right_mask_remaining].view(-1, n_dims)
                probs_k[right_indices] = right_probs[right_improved]
            probs[:, k] = probs_k
            queries[:, k] = queries_k
            prev_probs = probs[:, k]
            if (k + 1) % log_every == 0 or k == max_iters - 1:
                print('Iteration %d: queries = %.4f, prob = %.4f, remaining = %.4f' % (
                    k + 1, queries.sum(1).mean(), probs[:, k].mean(), remaining.float().mean()))
        expanded = (images_batch + trans(self.expand_vector(x, expand_dims))).clamp(0, 1)
        preds = self.get_preds(expanded)
        if targeted:
            remaining = preds.ne(labels_batch)
        else:
            remaining = preds.eq(labels_batch)
        succs[:, max_iters - 1] = ~remaining
        return expanded, probs, succs, queries, l2_norms, linf_norms
--------------------------------------------------------------------------------
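simba_batch is the workhorse behind both driver scripts. A direct-call sketch, assuming a pretrained torchvision ResNet-50 and using random stand-in images (the scripts use correctly classified dataset samples instead):
```
import torch
import torchvision.models as models
from simba import SimBA

model = models.resnet50(pretrained=True).cuda().eval()
attacker = SimBA(model, 'imagenet', 224)

images = torch.rand(4, 3, 224, 224)      # stand-in batch in [0, 1]
labels = torch.LongTensor([0, 1, 2, 3])  # stand-in labels

# low-frequency SimBA-DCT with a random coordinate order
adv, probs, succs, queries, l2_norms, linf_norms = attacker.simba_batch(
    images, labels, max_iters=100, freq_dims=14, stride=7, epsilon=0.2,
    order='rand', targeted=False, pixel_attack=False, log_every=10)
print('fraction misclassified:', succs[:, -1].mean().item())
```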
/utils.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import torchvision.transforms as trans
import math
from scipy.fftpack import dct, idct

# mean and std for different datasets
IMAGENET_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGENET_TRANSFORM = trans.Compose([
    trans.Resize(256),
    trans.CenterCrop(224),
    trans.ToTensor()])

INCEPTION_SIZE = 299
INCEPTION_TRANSFORM = trans.Compose([
    trans.Resize(342),
    trans.CenterCrop(299),
    trans.ToTensor()])

CIFAR_SIZE = 32
CIFAR_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR_STD = [0.2023, 0.1994, 0.2010]
CIFAR_TRANSFORM = trans.Compose([
    trans.ToTensor()])

MNIST_SIZE = 28
MNIST_MEAN = [0.5]
MNIST_STD = [1.0]
MNIST_TRANSFORM = trans.Compose([
    trans.ToTensor()])


# reverses the normalization transformation
def invert_normalization(imgs, dataset):
    if dataset == 'imagenet':
        mean = IMAGENET_MEAN
        std = IMAGENET_STD
    elif dataset == 'cifar':
        mean = CIFAR_MEAN
        std = CIFAR_STD
    elif dataset == 'mnist':
        mean = MNIST_MEAN
        std = MNIST_STD
    else:
        # fall back to identity normalization for unknown datasets,
        # matching apply_normalization below
        mean = [0, 0, 0]
        std = [1, 1, 1]
    imgs_trans = imgs.clone()
    if len(imgs.size()) == 3:
        for i in range(imgs.size(0)):
            imgs_trans[i, :, :] = imgs_trans[i, :, :] * std[i] + mean[i]
    else:
        for i in range(imgs.size(1)):
            imgs_trans[:, i, :, :] = imgs_trans[:, i, :, :] * std[i] + mean[i]
    return imgs_trans


# applies the normalization transformation channel-wise
def apply_normalization(imgs, dataset):
    if dataset == 'imagenet':
        mean = IMAGENET_MEAN
        std = IMAGENET_STD
    elif dataset == 'cifar':
        mean = CIFAR_MEAN
        std = CIFAR_STD
    elif dataset == 'mnist':
        mean = MNIST_MEAN
        std = MNIST_STD
    else:
        mean = [0, 0, 0]
        std = [1, 1, 1]
    imgs_tensor = imgs.clone()
    if dataset == 'mnist':
        imgs_tensor = (imgs_tensor - mean[0]) / std[0]
    else:
        if imgs.dim() == 3:
            for i in range(imgs_tensor.size(0)):
                imgs_tensor[i, :, :] = (imgs_tensor[i, :, :] - mean[i]) / std[i]
        else:
            for i in range(imgs_tensor.size(1)):
                imgs_tensor[:, i, :, :] = (imgs_tensor[:, i, :, :] - mean[i]) / std[i]
    return imgs_tensor
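(Aside, not part of utils.py: apply_normalization and invert_normalization should be mutual inverses; a quick sketch that checks this on a random CIFAR-shaped batch.)
```
import torch
import utils

x = torch.rand(8, 3, 32, 32)  # random batch with values in [0, 1]
x_norm = utils.apply_normalization(x, 'cifar')
x_back = utils.invert_normalization(x_norm, 'cifar')
print(torch.allclose(x, x_back, atol=1e-6))  # expected: True
```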

# get most likely predictions and probabilities for a set of inputs
def get_preds(model, inputs, dataset_name, correct_class=None, batch_size=25, return_cpu=True):
    num_batches = int(math.ceil(inputs.size(0) / float(batch_size)))
    softmax = torch.nn.Softmax(dim=1)
    all_preds, all_probs = None, None
    for i in range(num_batches):
        upper = min((i + 1) * batch_size, inputs.size(0))
        input = apply_normalization(inputs[(i * batch_size):upper], dataset_name)
        input_var = torch.autograd.Variable(input.cuda(), volatile=True)
        output = softmax.forward(model.forward(input_var))
        if correct_class is None:
            prob, pred = output.max(1)
        else:
            prob, pred = output[:, correct_class], torch.autograd.Variable(torch.ones(output.size(0)).long() * correct_class)
        if return_cpu:
            prob = prob.data.cpu()
            pred = pred.data.cpu()
        else:
            prob = prob.data
            pred = pred.data
        if i == 0:
            all_probs = prob
            all_preds = pred
        else:
            all_probs = torch.cat((all_probs, prob), 0)
            all_preds = torch.cat((all_preds, pred), 0)
    return all_preds, all_probs


# get least likely predictions and probabilities for a set of inputs
def get_least_likely(model, inputs, dataset_name, batch_size=25, return_cpu=True):
    num_batches = int(math.ceil(inputs.size(0) / float(batch_size)))
    softmax = torch.nn.Softmax(dim=1)
    all_preds, all_probs = None, None
    for i in range(num_batches):
        upper = min((i + 1) * batch_size, inputs.size(0))
        input = apply_normalization(inputs[(i * batch_size):upper], dataset_name)
        input_var = torch.autograd.Variable(input.cuda(), volatile=True)
        output = softmax.forward(model.forward(input_var))
        prob, pred = output.min(1)
        if return_cpu:
            prob = prob.data.cpu()
            pred = pred.data.cpu()
        else:
            prob = prob.data
            pred = pred.data
        if i == 0:
            all_probs = prob
            all_preds = pred
        else:
            all_probs = torch.cat((all_probs, prob), 0)
            all_preds = torch.cat((all_preds, pred), 0)
    return all_preds, all_probs
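(Aside, not part of utils.py: a usage sketch for the two prediction helpers, assuming a CUDA-capable machine and a pretrained torchvision model; get_least_likely is a natural source of target labels for targeted attacks.)
```
import torch
import torchvision.models as models
import utils

model = models.resnet50(pretrained=True).cuda().eval()
inputs = torch.rand(100, 3, 224, 224)  # unnormalized images in [0, 1]

preds, probs = utils.get_preds(model, inputs, 'imagenet', batch_size=25)
targets, _ = utils.get_least_likely(model, inputs, 'imagenet')
print(preds.size(), probs.size())  # torch.Size([100]) torch.Size([100])
```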

# defines a diagonal order for traversing pixel coordinates
# diagonals are visited in a fixed order and the channels are interleaved, e.g.
# [1, 2, 5]
# [3, 4, 8]
# [6, 7, 9]
def diagonal_order(image_size, channels):
    x = torch.arange(0, image_size).cumsum(0)
    order = torch.zeros(image_size, image_size)
    for i in range(image_size):
        order[i, :(image_size - i)] = i + x[i:]
    for i in range(1, image_size):
        reverse = order[image_size - i - 1].index_select(0, torch.LongTensor([i for i in range(i - 1, -1, -1)]))
        order[i, (image_size - i):] = image_size * image_size - 1 - reverse
    if channels > 1:
        order_2d = order
        order = torch.zeros(channels, image_size, image_size)
        for i in range(channels):
            order[i, :, :] = channels * order_2d + i
    return order.view(1, -1).squeeze().long().sort()[1]


# defines a block order, starting with the top-left (initial_size x initial_size) submatrix
# and expanding by stride rows and columns whenever exhausted
# randomized within the block and across channels
# e.g. (initial_size=2, stride=1)
# [1, 3, 6]
# [2, 4, 9]
# [5, 7, 8]
def block_order(image_size, channels, initial_size=1, stride=1):
    order = torch.zeros(channels, image_size, image_size)
    total_elems = channels * initial_size * initial_size
    perm = torch.randperm(total_elems)
    order[:, :initial_size, :initial_size] = perm.view(channels, initial_size, initial_size)
    for i in range(initial_size, image_size, stride):
        num_elems = channels * (2 * stride * i + stride * stride)
        perm = torch.randperm(num_elems) + total_elems
        num_first = channels * stride * (stride + i)
        order[:, :(i + stride), i:(i + stride)] = perm[:num_first].view(channels, -1, stride)
        order[:, i:(i + stride), :i] = perm[num_first:].view(channels, stride, -1)
        total_elems += num_elems
    return order.view(1, -1).squeeze().long().sort()[1]


# zeros all elements outside of the top-left (block_size * ratio) x (block_size * ratio)
# submatrix of every block
def block_zero(x, block_size=8, ratio=0.5):
    z = torch.zeros(x.size())
    num_blocks = int(x.size(2) / block_size)
    mask = torch.zeros(x.size(0), x.size(1), block_size, block_size)
    mask[:, :, :int(block_size * ratio), :int(block_size * ratio)] = 1
    for i in range(num_blocks):
        for j in range(num_blocks):
            z[:, :, (i * block_size):((i + 1) * block_size), (j * block_size):((j + 1) * block_size)] = x[:, :, (i * block_size):((i + 1) * block_size), (j * block_size):((j + 1) * block_size)] * mask
    return z


# applies DCT to each block of size block_size
def block_dct(x, block_size=8, masked=False, ratio=0.5):
    z = torch.zeros(x.size())
    num_blocks = int(x.size(2) / block_size)
    mask = np.zeros((x.size(0), x.size(1), block_size, block_size))
    mask[:, :, :int(block_size * ratio), :int(block_size * ratio)] = 1
    for i in range(num_blocks):
        for j in range(num_blocks):
            submat = x[:, :, (i * block_size):((i + 1) * block_size), (j * block_size):((j + 1) * block_size)].numpy()
            submat_dct = dct(dct(submat, axis=2, norm='ortho'), axis=3, norm='ortho')
            if masked:
                submat_dct = submat_dct * mask
            submat_dct = torch.from_numpy(submat_dct)
            z[:, :, (i * block_size):((i + 1) * block_size), (j * block_size):((j + 1) * block_size)] = submat_dct
    return z
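(Aside, not part of utils.py: a tiny sketch of block_order on a 3x3 single-channel image with initial_size=2 and stride=1, matching the docstring's example up to the random permutation within each block.)
```
import torch
import utils

torch.manual_seed(0)
order = utils.block_order(3, 1, initial_size=2, stride=1)
# 'order' lists flattened (channel, row, col) indices in visit order: the first
# 4 entries all lie in the top-left 2x2 block (flat indices 0, 1, 3, 4), the
# remaining 5 in the newly added row and column
print(order)
```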

# applies IDCT to each block of size block_size
def block_idct(x, block_size=8, masked=False, ratio=0.5, linf_bound=0.0):
    z = torch.zeros(x.size())
    num_blocks = int(x.size(2) / block_size)
    mask = np.zeros((x.size(0), x.size(1), block_size, block_size))
    if type(ratio) != float:
        for i in range(x.size(0)):
            mask[i, :, :int(block_size * ratio[i]), :int(block_size * ratio[i])] = 1
    else:
        mask[:, :, :int(block_size * ratio), :int(block_size * ratio)] = 1
    for i in range(num_blocks):
        for j in range(num_blocks):
            submat = x[:, :, (i * block_size):((i + 1) * block_size), (j * block_size):((j + 1) * block_size)].numpy()
            if masked:
                submat = submat * mask
            z[:, :, (i * block_size):((i + 1) * block_size), (j * block_size):((j + 1) * block_size)] = torch.from_numpy(idct(idct(submat, axis=3, norm='ortho'), axis=2, norm='ortho'))
    if linf_bound > 0:
        return z.clamp(-linf_bound, linf_bound)
    else:
        return z
--------------------------------------------------------------------------------
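Finally, a sketch verifying that block_dct and block_idct invert each other when unmasked, which is what allows SimBA-DCT to parameterize its perturbation in frequency space:
```
import torch
import utils

x = torch.rand(2, 3, 32, 32)
coeffs = utils.block_dct(x, block_size=32)        # one 32x32 DCT block per channel
x_back = utils.block_idct(coeffs, block_size=32)  # inverse transform
print(torch.allclose(x, x_back, atol=1e-5))       # expected: True
```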