├── attacks
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── BPDA.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   └── l2_attack.cpython-36.pyc
│   ├── BPDA.py
│   └── l2_attack.py
├── examples
│   ├── bpda.jpg
│   ├── clean.jpg
│   └── l2_attack.jpg
├── __pycache__
│   └── util.cpython-36.pyc
├── README.md
├── util.py
├── train.py
└── test_BPDA.py

/attacks/__init__.py:
--------------------------------------------------------------------------------
from attacks.BPDA import BPDAattack
from attacks.l2_attack import CarliniL2
--------------------------------------------------------------------------------
/examples/bpda.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Annonymous-repos/attacks-in-pytorch/HEAD/examples/bpda.jpg
--------------------------------------------------------------------------------
/examples/clean.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Annonymous-repos/attacks-in-pytorch/HEAD/examples/clean.jpg
--------------------------------------------------------------------------------
/examples/l2_attack.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Annonymous-repos/attacks-in-pytorch/HEAD/examples/l2_attack.jpg
--------------------------------------------------------------------------------
/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Annonymous-repos/attacks-in-pytorch/HEAD/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/attacks/__pycache__/BPDA.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Annonymous-repos/attacks-in-pytorch/HEAD/attacks/__pycache__/BPDA.cpython-36.pyc
--------------------------------------------------------------------------------
/attacks/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Annonymous-repos/attacks-in-pytorch/HEAD/attacks/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/attacks/__pycache__/l2_attack.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Annonymous-repos/attacks-in-pytorch/HEAD/attacks/__pycache__/l2_attack.cpython-36.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# attacks-in-pytorch
Reproduces the BPDA attack in PyTorch.


PyTorch implementation of two attack methods from the paper [Obfuscated Gradients Give a False Sense of Security: Circumventing Defenses to Adversarial Examples](https://arxiv.org/abs/1802.00420).

## Environment

- python=3.6.8
- pytorch=1.1.0
- numpy=1.13.3
- [advertorch](https://github.com/BorealisAI/advertorch)

## Acknowledgement
This repository builds on the source code released with "Obfuscated Gradients Give a False Sense of Security: Circumventing Defenses to Adversarial Examples":

- [obfuscated-gradients](https://github.com/anishathalye/obfuscated-gradients)

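## Usage

A minimal run, based on the argument parsers in `train.py` and `test_BPDA.py` below (the flag values shown are the script defaults; adjust the GPU index to your machine):

```bash
# train the LeNet oracle on MNIST and save it to LeNet.pth
python train.py -d 0

# run the BPDA attack against the JPEG / bit-squeezing / median-filter defense
python test_BPDA.py --device 0 --resume_oracle LeNet.pth --results_dir results
```

The evaluation script prints the oracle's accuracy on the adversarial images with and without the defense applied at test time, and writes sample images to the results directory (the `examples` folder shows such samples).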
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import os
import torch
from torchvision.utils import save_image
import math


def accu(output, target):
    """Top-1 accuracy of a batch of logits `output` against integer labels `target`."""
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        assert pred.shape[0] == len(target)
        correct = torch.sum(pred == target).item()
        return correct / len(target)


def ensure_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def _save_image(img_dir, image, name):
    file_name = img_dir + '/' + name + '.jpg'
    nrow = int(math.sqrt(image.shape[0]))
    save_image(image.cpu(), file_name, nrow=nrow, padding=2, normalize=True)
--------------------------------------------------------------------------------
/attacks/BPDA.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np


class BPDAattack(object):
    def __init__(self, model=None, defense=None, device=None, epsilon=None, learning_rate=0.5,
                 max_iterations=100, clip_min=0, clip_max=1):
        self.model = model
        self.epsilon = epsilon
        self.loss_fn = nn.CrossEntropyLoss(reduction='sum')
        self.defense = defense
        self.clip_min = clip_min
        self.clip_max = clip_max

        self.LEARNING_RATE = learning_rate
        self.MAX_ITERATIONS = max_iterations
        self.device = device

    def generate(self, x, y):
        """
        Given clean examples (x, y), return their adversarial counterparts,
        constrained to an L-infinity ball of radius epsilon around x.
        """
        adv = x.detach().clone()

        # valid pixel range for the perturbed images
        lower = np.clip(x.detach().cpu().numpy() - self.epsilon, self.clip_min, self.clip_max)
        upper = np.clip(x.detach().cpu().numpy() + self.epsilon, self.clip_min, self.clip_max)

        for i in range(self.MAX_ITERATIONS):
            # Forward pass through the (possibly non-differentiable) defense.
            # BPDA approximates the defense's backward pass with the identity,
            # so the gradient is taken with respect to the purified input.
            adv_purified = self.defense(adv)
            adv_purified.requires_grad_()
            adv_purified.retain_grad()

            scores = self.model(adv_purified)
            loss = self.loss_fn(scores, y)
            loss.backward()

            grad_sign = adv_purified.grad.data.sign()

            # early stop, only for batch_size = 1
            # p = torch.argmax(F.softmax(scores, dim=1), 1)
            # if y != p:
            #     break

            # gradient-sign step, then projection back onto the epsilon ball
            adv += self.LEARNING_RATE * grad_sign
            adv_img = np.clip(adv.detach().cpu().numpy(), lower, upper)
            adv = torch.Tensor(adv_img).to(self.device)
        return adv
--------------------------------------------------------------------------------
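A minimal sketch of how `BPDAattack` is wired up, assuming a median-filter defense from advertorch (already a dependency of this repository). The `TinyNet` classifier and the random batch are placeholders, not part of the repository; in practice the trained LeNet from `train.py` and real MNIST batches are used, as in `test_BPDA.py` below.

```python
import torch
import torch.nn as nn
from advertorch.defenses import MedianSmoothing2D
from attacks import BPDAattack


class TinyNet(nn.Module):
    """Placeholder classifier standing in for the trained LeNet oracle."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(1, 8, 3, padding=1)
        self.fc = nn.Linear(8 * 28 * 28, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        return self.fc(x.view(x.shape[0], -1))


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = TinyNet().to(device).eval()
defense = MedianSmoothing2D(kernel_size=3)   # preprocessing defense, treated as a black box by BPDA

adversary = BPDAattack(model, defense, device,
                       epsilon=0.3, learning_rate=0.5, max_iterations=100)

x = torch.rand(4, 1, 28, 28, device=device)      # stand-in MNIST batch in [0, 1]
y = torch.randint(0, 10, (4,), device=device)    # stand-in labels
adv = adversary.generate(x, y)                   # perturbed images within [x - 0.3, x + 0.3]
```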
/train.py:
--------------------------------------------------------------------------------
import argparse
import torch
from torchvision import transforms
import torchvision.datasets as datasets
import torch.nn as nn

MNIST_PATH = 'data/'


def get_mnist_train_loader(batch_size, shuffle=True):
    return torch.utils.data.DataLoader(
        datasets.MNIST(MNIST_PATH, train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=batch_size, shuffle=shuffle)


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.Block = nn.Sequential(
            nn.Conv2d(1, 20, 5),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(20, 50, 5),
            nn.MaxPool2d(kernel_size=2),
        )
        self.linear1 = nn.Sequential(
            nn.Linear(50 * 16, 500),
            nn.ReLU())
        self.linear2 = nn.Linear(500, 10)

    def forward(self, x):
        x = self.Block(x)
        x = x.reshape(x.shape[0], -1)
        x = self.linear1(x)
        logits = self.linear2(x)
        return logits


def main(args):
    device = torch.device('cuda:' + args.device)
    data_loader = get_mnist_train_loader(100)

    loss_fn = nn.CrossEntropyLoss()

    model = LeNet()
    model.to(device)

    optimiser = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(10):
        for batch_idx, (gt_image, label) in enumerate(data_loader):
            gt_image, label = gt_image.to(device), label.to(device)
            optimiser.zero_grad()
            logits = model(gt_image)
            loss = loss_fn(logits, label)
            loss.backward()
            optimiser.step()
    torch.save(model.state_dict(), 'LeNet.pth')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train a LeNet classifier on MNIST')
    parser.add_argument('-d', '--device', default='0', type=str,
                        help='the CUDA device index used for training (a GPU is required)')

    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
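The `50 * 16` input size of `linear1` in the LeNet above follows from the MNIST input shape: 28 becomes 24 after the first 5x5 convolution, 12 after pooling, 8 after the second convolution, and 4 after the final pooling, giving 50 x 4 x 4 = 800 features. A quick shape check one could run against `train.py` (a hypothetical snippet, not part of the repository):

```python
import torch
from train import LeNet  # LeNet as defined in train.py above

# 28 -> 24 (5x5 conv) -> 12 (pool) -> 8 (5x5 conv) -> 4 (pool), with 50 channels
feats = LeNet().Block(torch.zeros(1, 1, 28, 28))
print(feats.shape)                    # torch.Size([1, 50, 4, 4])
print(feats.reshape(1, -1).shape[1])  # 800 == 50 * 16, the input size of linear1
```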
/test_BPDA.py:
--------------------------------------------------------------------------------
import argparse
import torch
from tqdm import tqdm
import torch.nn as nn
from attacks import BPDAattack
from util import _save_image, ensure_dir, accu
from advertorch.defenses import MedianSmoothing2D, BitSqueezing, JPEGFilter
from torchvision import transforms
import torchvision.datasets as datasets

MNIST_PATH = 'data/'


def get_mnist_test_loader(batch_size, shuffle=False):
    return torch.utils.data.DataLoader(
        datasets.MNIST(MNIST_PATH, train=False, download=True,
                       transform=transforms.ToTensor()),
        batch_size=batch_size, shuffle=shuffle)


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.Block = nn.Sequential(
            nn.Conv2d(1, 20, 5),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(20, 50, 5),
            nn.MaxPool2d(kernel_size=2),
        )
        self.linear1 = nn.Sequential(
            nn.Linear(50 * 16, 500),
            nn.ReLU())
        self.linear2 = nn.Linear(500, 10)

    def forward(self, x):
        x = self.Block(x)
        x = x.reshape(x.shape[0], -1)
        x = self.linear1(x)
        logits = self.linear2(x)
        return logits


class whitebox(object):

    def __init__(self, args):
        self.args = args

        # setup data_loader instances
        self.data_loader = get_mnist_test_loader(100, shuffle=False)

        # setup device
        self.device = torch.device('cuda:' + args.device)

        # build defense architecture
        bits_squeezing = BitSqueezing(bit_depth=5)
        median_filter = MedianSmoothing2D(kernel_size=3)
        jpeg_filter = JPEGFilter(10)

        self.defense = nn.Sequential(
            jpeg_filter,
            bits_squeezing,
            median_filter,
        )
        # build classifier architecture
        self.oracle = LeNet()
        self.oracle = self.oracle.to(self.device)
        self.oracle.load_state_dict(torch.load(args.resume_oracle))
        self.oracle.eval()

        self.adversary = BPDAattack(self.oracle, self.defense, self.device,
                                    epsilon=0.3,
                                    learning_rate=0.5,
                                    max_iterations=100)

    def eval_(self):
        """
        Run the BPDA attack over the test set and return the oracle's accuracy
        on the adversarial images, both without and with the defense applied at
        test time, plus the last batch of images for visualisation.
        """
        total_metrics = 0
        defense_metrics = 0
        for batch_idx, (gt_image, label) in enumerate(tqdm(self.data_loader)):
            gt_image, label = gt_image.to(self.device), label.to(self.device)

            adv = self.adversary.generate(gt_image, label)
            adv = adv.detach()

            logits = self.oracle(adv)
            total_metrics += accu(logits, label)

            reformed = self.defense(adv)
            logits = self.oracle(reformed)
            defense_metrics += accu(logits, label)

        total_metrics /= len(self.data_loader)
        defense_metrics /= len(self.data_loader)

        return total_metrics, defense_metrics, adv, gt_image, reformed


def main(args):
    ensure_dir(args.results_dir)
    wbx = whitebox(args)

    total_accu, defense_accu, adv, gt, reform = wbx.eval_()
    _save_image(args.results_dir, adv, 'adv')
    _save_image(args.results_dir, gt, 'gt')
    _save_image(args.results_dir, reform, 'reform')
    print('accuracy on adversarial examples (no defense):', total_accu)
    print('accuracy on adversarial examples (with defense):', defense_accu)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate the BPDA attack against an input-transformation defense on MNIST')

    parser.add_argument('--resume_oracle', default='LeNet.pth', type=str,
                        help='path to the trained oracle checkpoint (default: LeNet.pth)')

    parser.add_argument('--results_dir', default='results', type=str,
                        help='output directory')

    parser.add_argument('--device', default='0', type=str,
                        help='the CUDA device index used for computing (a GPU is required)')

    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
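`test_BPDA.py` reports accuracy only on adversarial images, with and without the defense at test time. For context, a clean-accuracy baseline can be computed from the same pieces; the helper below is a hypothetical addition, not part of the script:

```python
import torch


def clean_accuracy(oracle, data_loader, device):
    """Accuracy of the oracle on unperturbed test images, as a baseline for the attack numbers."""
    oracle.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for image, label in data_loader:
            image, label = image.to(device), label.to(device)
            pred = torch.argmax(oracle(image), dim=1)
            correct += (pred == label).sum().item()
            total += label.numel()
    return correct / total

# e.g. inside whitebox: clean_accuracy(self.oracle, self.data_loader, self.device)
```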
/attacks/l2_attack.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class CarliniL2:
    def __init__(self, model, gan, device, confidence=0, targeted=False, learning_rate=1e-1,
                 binary_search_steps=5, max_iterations=10000, abort_early=False, initial_const=1,
                 clip_min=0, clip_max=1):
        self.TARGETED = targeted
        self.LEARNING_RATE = learning_rate
        self.MAX_ITERATIONS = max_iterations
        self.BINARY_SEARCH_STEPS = binary_search_steps
        self.ABORT_EARLY = abort_early
        self.CONFIDENCE = confidence
        self.initial_const = initial_const
        self.clip_min = clip_min
        self.clip_max = clip_max
        self.model = model
        self.device = device
        self.gan = gan
        self.learning_rate = learning_rate
        self.repeat = binary_search_steps >= 10

    def get_or_guess_labels(self, x, y=None):
        """
        Get the labels to use when generating adversarial examples for x.
        If y is given, use it as the label; otherwise use the model's own
        predictions as the labels.
        """
        if y is not None:
            labels = y
        else:
            preds = F.softmax(self.model(x), dim=1)
            preds_max = torch.max(preds, 1, keepdim=True)[0]
            original_predictions = (preds == preds_max)
            labels = original_predictions
            del preds
        return labels.float()

    def atanh(self, x):
        return 0.5 * torch.log((1 + x) / (1 - x))

    def to_one_hot(self, x):
        one_hot = torch.FloatTensor(x.shape[0], 10).to(x.get_device())
        one_hot.zero_()
        x = x.unsqueeze(1)
        one_hot = one_hot.scatter_(1, x, 1)
        return one_hot

    def generate(self, imgs, y, start):

        batch_size = imgs.shape[0]
        labs = self.get_or_guess_labels(imgs, y)

        def compare(x, y):
            if self.TARGETED is None:
                return True

            if sum(x.shape) != 0:
                x = x.clone()
                if self.TARGETED:
                    x[y] -= self.CONFIDENCE
                else:
                    x[y] += self.CONFIDENCE
                x = torch.argmax(x)
            if self.TARGETED:
                return x == y
            else:
                return x != y

        # set the lower and upper bounds accordingly
        lower_bound = torch.zeros(batch_size).to(self.device)
        CONST = torch.ones(batch_size).to(self.device) * self.initial_const
        upper_bound = (torch.ones(batch_size) * 1e10).to(self.device)

        # the best l2, score, and image attack
        o_bestl2 = [1e10] * batch_size
        o_bestscore = [-1] * batch_size
        o_bestattack = self.gan(start)

        # if the labels are not already one-hot, convert them into one-hot vectors
        if len(labs.shape) == 1:
            tlabs = self.to_one_hot(labs.long())
        else:
            tlabs = labs

        for outer_step in range(self.BINARY_SEARCH_STEPS):
            # completely reset adam's internal state.
            modifier = nn.Parameter(start)
            optimizer = torch.optim.Adam([modifier, ], lr=self.learning_rate)

            bestl2 = [1e10] * batch_size
            bestscore = -1 * torch.ones(batch_size, dtype=torch.float32).to(self.device)

            # The last outer step (if we run many steps) repeats the search once
            # with the largest constant found so far.
            if self.repeat and outer_step == self.BINARY_SEARCH_STEPS - 1:
                CONST = upper_bound
            prev = 1e6

            for i in range(self.MAX_ITERATIONS):
                optimizer.zero_grad()
                nimgs = self.gan(modifier.to(self.device))

                # distance to the input data
                l2dist = torch.sum(torch.sum(torch.sum((nimgs - imgs) ** 2, 1), 1), 1)
                loss2 = torch.sum(l2dist)

                # pre-softmax predictions (logits) of the model
                scores = self.model(nimgs)

                # logit of the label class and the maximum logit among the other classes
                other = torch.max(((1 - tlabs) * scores - tlabs * 10000), 1)[0]
                real = torch.sum(tlabs * scores, 1)

                if self.TARGETED:
                    # if targeted, optimize for making the target class most likely
                    loss1 = torch.max(torch.zeros_like(other), other - real + self.CONFIDENCE)
                else:
                    # if untargeted, optimize for making this class least likely.
                    loss1 = torch.max(torch.zeros_like(other), real - other + self.CONFIDENCE)

                # sum up the losses
                loss1 = torch.sum(CONST * loss1)
                loss = loss1 + loss2

                # update the modifier
                loss.backward()
                optimizer.step()

                # check if we should abort the search because we're getting nowhere
                if self.ABORT_EARLY and i % ((self.MAX_ITERATIONS // 10) or 1) == 0:
                    if loss > prev * .9999:
                        # print('Stop early')
                        break
                    prev = loss

                # adjust the best result found so far
                for e, (l2, sc, ii) in enumerate(zip(l2dist, scores, nimgs)):
                    lab = torch.argmax(tlabs[e])

                    if l2 < bestl2[e] and compare(sc, lab):
                        bestl2[e] = l2
                        bestscore[e] = torch.argmax(sc)

                    if l2 < o_bestl2[e] and compare(sc, lab):
                        o_bestl2[e] = l2
                        o_bestscore[e] = torch.argmax(sc)
                        o_bestattack[e] = ii

            # adjust the constant as needed
            for e in range(batch_size):
                if compare(bestscore[e], torch.argmax(tlabs[e]).float()) and \
                        bestscore[e] != -1:
                    # success, divide CONST by two
                    upper_bound[e] = min(upper_bound[e], CONST[e])
                    if upper_bound[e] < 1e9:
                        CONST[e] = (lower_bound[e] + upper_bound[e]) / 2
                else:
                    # failure, either multiply by 10 if no solution found yet
                    # or do binary search with the known upper bound
                    lower_bound[e] = max(lower_bound[e], CONST[e])
                    if upper_bound[e] < 1e9:
                        CONST[e] = (lower_bound[e] + upper_bound[e]) / 2
                    else:
                        CONST[e] *= 10

        # return the best solution found
        o_bestl2 = np.array(o_bestl2)
        return o_bestattack
--------------------------------------------------------------------------------
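`CarliniL2` here optimizes the generator's input (`start`) rather than the image pixels, so it needs both a classifier and a generator. A minimal wiring sketch with toy stand-ins: `ToyGenerator` and `ToyClassifier` below are placeholders for a trained GAN generator and the LeNet oracle, the random batch is not real data, and a CUDA device is assumed, as in the rest of the repository.

```python
import torch
import torch.nn as nn
from attacks import CarliniL2


class ToyGenerator(nn.Module):
    """Placeholder for a trained GAN generator: latent code -> 28x28 image."""
    def __init__(self, latent_dim=64):
        super(ToyGenerator, self).__init__()
        self.fc = nn.Linear(latent_dim, 28 * 28)

    def forward(self, z):
        return torch.sigmoid(self.fc(z)).view(-1, 1, 28, 28)


class ToyClassifier(nn.Module):
    """Placeholder for the trained classifier (returns pre-softmax logits)."""
    def __init__(self):
        super(ToyClassifier, self).__init__()
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.fc(x.view(x.shape[0], -1))


device = torch.device('cuda:0')  # a GPU is assumed, as in the repository's scripts
gan = ToyGenerator().to(device)
model = ToyClassifier().to(device)

attacker = CarliniL2(model, gan, device, max_iterations=200)

imgs = torch.rand(4, 1, 28, 28, device=device)      # images to attack (stand-in batch)
labels = torch.randint(0, 10, (4,), device=device)  # their labels
start = torch.randn(4, 64, device=device)           # initial latent codes for the generator
adv = attacker.generate(imgs, labels, start)        # best generator outputs found by the attack
```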