├── tests
│   ├── __init__.py
│   ├── test_datasets.py
│   ├── test_utils.py
│   └── test_functional.py
├── adversarial
│   ├── __init__.py
│   ├── models.py
│   ├── utils.py
│   ├── attacks.py
│   ├── datasets.py
│   └── functional.py
├── models
│   ├── mnist_natural.pt
│   ├── mnist_pgd_k=40_step=0.1_eps=3.0.pt
│   └── mnist_iterated_fgsm_k=40_step=0.01_eps=0.3.pt
├── assets
│   └── pgd_attack_imagenet_example.png
├── config.py
├── scripts
│   ├── experiments.txt
│   ├── train_natural.py
│   └── train_adversarial.py
├── requirements.txt
├── README.md
└── .gitignore
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/adversarial/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/mnist_natural.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oscarknagg/adversarial/HEAD/models/mnist_natural.pt
--------------------------------------------------------------------------------
/assets/pgd_attack_imagenet_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oscarknagg/adversarial/HEAD/assets/pgd_attack_imagenet_example.png
--------------------------------------------------------------------------------
/models/mnist_pgd_k=40_step=0.1_eps=3.0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oscarknagg/adversarial/HEAD/models/mnist_pgd_k=40_step=0.1_eps=3.0.pt
--------------------------------------------------------------------------------
/models/mnist_iterated_fgsm_k=40_step=0.01_eps=0.3.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oscarknagg/adversarial/HEAD/models/mnist_iterated_fgsm_k=40_step=0.01_eps=0.3.pt
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import os

PATH = os.path.dirname(os.path.realpath(__file__))

DATA_PATH = None

if DATA_PATH is None:
    raise Exception('Configure your data folder location in config.py before continuing!')
--------------------------------------------------------------------------------
/tests/test_datasets.py:
--------------------------------------------------------------------------------
import unittest
from torch.utils.data import DataLoader
from multiprocessing import cpu_count

from adversarial.datasets import RestrictedImageNet


class TestRestrictedImageNet(unittest.TestCase):
    def test_dataset(self):
        data = RestrictedImageNet()

        dataloader = DataLoader(data, batch_size=128, num_workers=cpu_count())

        for batch in dataloader:
            print(batch[0].shape)
            break
--------------------------------------------------------------------------------
/adversarial/models.py:
--------------------------------------------------------------------------------
from torch import nn
from torch.nn import functional as F


class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(1024, 10)
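        # Why 1024 input features: a 28x28 input -> conv(k=5) -> 24x24 -> maxpool(2)
        # -> 12x12 -> conv(k=5) -> 8x8 -> maxpool(2) -> 4x4, with 64 channels,
        # i.e. 64 * 4 * 4 = 1024 features into the final linear layer.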

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 1024)
        x = self.fc1(x)
        return F.log_softmax(x, dim=1)
--------------------------------------------------------------------------------
/scripts/experiments.txt:
--------------------------------------------------------------------------------
python -m scripts.train_adversarial --attack FGSM --eps 0.05
python -m scripts.train_adversarial --attack FGSM --eps 0.1
python -m scripts.train_adversarial --attack FGSM --eps 0.15
python -m scripts.train_adversarial --attack FGSM --eps 0.2
python -m scripts.train_adversarial --attack FGSM --eps 0.25

python -m scripts.eval_adversarial --attack FGSM --model mnist_attack=FGSM_eps=0.05
python -m scripts.eval_adversarial --attack FGSM --model mnist_attack=FGSM_eps=0.1
python -m scripts.eval_adversarial --attack FGSM --model mnist_attack=FGSM_eps=0.15
python -m scripts.eval_adversarial --attack FGSM --model mnist_attack=FGSM_eps=0.2
python -m scripts.eval_adversarial --attack FGSM --model mnist_attack=FGSM_eps=0.25
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
atomicwrites==1.2.1
attrs==18.2.0
backcall==0.1.0
bleach==3.0.2
cycler==0.10.0
decorator==4.3.0
defusedxml==0.5.0
entrypoints==0.2.3
ipykernel==5.1.0
ipython==7.2.0
ipython-genutils==0.2.0
ipywidgets==7.4.2
jedi==0.13.2
Jinja2==2.10
jsonschema==2.6.0
jupyter==1.0.0
jupyter-client==5.2.4
jupyter-console==6.0.0
jupyter-core==4.4.0
kiwisolver==1.0.1
MarkupSafe==1.1.0
matplotlib==3.0.2
mistune==0.8.4
more-itertools==4.3.0
nbconvert==5.4.0
nbformat==4.4.0
notebook==5.7.4
numpy==1.15.4
olympic==0.1.3
pandas==0.23.4
pandocfilters==1.4.2
parso==0.3.1
pexpect==4.6.0
pickleshare==0.7.5
Pillow==5.3.0
pluggy==0.8.0
prometheus-client==0.5.0
prompt-toolkit==2.0.7
ptyprocess==0.6.0
py==1.7.0
Pygments==2.3.1
pyparsing==2.3.0
pytest==4.0.1
python-dateutil==2.7.5
pytz==2018.7
pyzmq==17.1.2
qtconsole==4.4.3
scikit-learn==0.20.2
scipy==1.2.0
Send2Trash==1.5.0
six==1.12.0
terminado==0.8.1
testpath==0.4.2
torch==1.0.0
torchvision==0.2.1
tornado==5.1.1
tqdm==4.28.1
traitlets==4.3.2
wcwidth==0.1.7
webencodings==0.5.1
widgetsnbextension==3.4.2
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# adversarial

This repository contains PyTorch code to create and defend against
adversarial attacks.

See [this Medium article](https://towardsdatascience.com/know-your-enemy-7f7c5038bdf3)
for a discussion on how to use and defend against
the projected gradient attack.

An example adversarial attack created using this repo:

![PGD Attack](https://github.com/oscarknagg/adversarial/blob/master/assets/pgd_attack_imagenet_example.png)


Cool fact - adversarially trained discriminative (_not generative!_)
models can be used to interpolate between classes by creating
large-epsilon adversarial examples against them.

![MNIST Class Interpolation](https://media.giphy.com/media/NlGeQeG4jUViIcZRAD/giphy.gif)

# Contents

- A Jupyter notebook demonstrating how to use and defend against
the projected gradient attack (see `notebooks/`)

- `adversarial.functional` contains functional style implementations of
a few different types of adversarial attacks (a usage sketch follows below)
    - Fast Gradient Sign Method - white box - batch implementation
    - Projected Gradient Descent - white box - batch implementation
    - Local-search attack - black box, score-based - single image
    - Boundary attack - black box, decision-based - single image
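
A minimal sketch of an untargeted PGD attack with the functional API, run from
the repository root (the weights file ships in `models/`; the random input `x`
and label `y` here are placeholder assumptions - substitute a real MNIST batch):

```python
import torch
from adversarial.models import MNISTClassifier
from adversarial.functional import pgd

model = MNISTClassifier()
model.load_state_dict(torch.load('models/mnist_natural.pt'))
model.eval()

x = torch.rand(1, 1, 28, 28)  # stand-in for a batch of MNIST images in [0, 1]
y = torch.tensor([3])         # corresponding labels
x_adv = pgd(model, x, y, torch.nn.CrossEntropyLoss(), k=40, step=0.1, eps=3.0, norm=2)
```

The `k`, `step` and `eps` values mirror those used for the pre-trained
`models/mnist_pgd_k=40_step=0.1_eps=3.0.pt` checkpoint.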


# Setup
## Requirements

Listed in `requirements.txt`. Install with
`pip install -r requirements.txt`, preferably in a virtualenv.

## Tests (optional)

Run `pytest` in the root directory to run all tests.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/scripts/train_natural.py:
--------------------------------------------------------------------------------
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from multiprocessing import cpu_count
from olympic.callbacks import *
import argparse
import olympic

from adversarial.models import MNISTClassifier
from config import PATH

##############
# Parameters #
##############
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='mnist')
parser.add_argument('--device', default='cuda')
args = parser.parse_args()

if args.dataset != 'mnist':
    raise NotImplementedError


transform = transforms.Compose([
    transforms.ToTensor(),
])

train = datasets.MNIST(f'{PATH}/data/', train=True, transform=transform, download=True)
val = datasets.MNIST(f'{PATH}/data/', train=False, transform=transform, download=True)

train_loader = DataLoader(train, batch_size=128, num_workers=cpu_count())
val_loader = DataLoader(val, batch_size=128, num_workers=cpu_count())

model = MNISTClassifier().to(args.device)
optimiser = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

callbacks = [
    Evaluate(val_loader),
    ReduceLROnPlateau(monitor='val_accuracy', patience=5),
    ModelCheckpoint(f'{PATH}/models/{args.dataset}_natural.pt', save_best_only=True, monitor='val_loss', verbose=True),
    CSVLogger(f'{PATH}/logs/{args.dataset}_natural.csv')
]

olympic.fit(
    model,
    optimiser,
    loss_fn,
    dataloader=train_loader,
    epochs=10,
    metrics=['accuracy'],
    callbacks=callbacks,
    prepare_batch=lambda batch: (batch[0].to(args.device), batch[1].to(args.device))
)
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
import unittest
import numpy as np
import torch

from adversarial.utils import *


shape = (1, 28, 28)


class TestProject(unittest.TestCase):
    def test_2_norm(self):
        norm = 2
        eps = 0.5

        # Test batchwise projection
        x = torch.zeros((2, ) + shape)
        y = torch.cat([
            torch.ones((1, ) + shape),
            torch.zeros((1,) + shape)
        ])

        y_proj = project(x, y, norm, eps)

        within_norm_balls = y_proj.view(y_proj.shape[0], -1).norm(norm, dim=-1) <= eps
        self.assertTrue(torch.all(within_norm_balls))

    def test_inf_norm(self):
        norm = 'inf'
        eps = 0.5

        # Test batchwise projection
        x = torch.zeros((2,) + shape)
        y = torch.cat([
            torch.ones((1,) + shape),
            torch.zeros((1,) + shape)
        ])

        y_proj = project(x, y, norm, eps)

        within_norm_balls = y_proj.view(y_proj.shape[0], -1).norm(float(norm), dim=-1) <= eps
        self.assertTrue(torch.all(within_norm_balls))


class TestRandomPerturbation(unittest.TestCase):
    def test_2_norm(self):
        norm = 2
        eps = 0.5

        x = torch.zeros((2,) + shape)
        x_ = random_perturbation(x, norm, eps)
        within_norm_balls = x_.view(x_.shape[0], -1).norm(norm, dim=-1) <= eps
        self.assertTrue(torch.all(within_norm_balls))

    def test_inf_norm(self):
        norm = 'inf'
        eps = 0.5

        x = torch.zeros((2,) + shape)
        x_ = random_perturbation(x, norm, eps)
        within_norm_balls = x_.view(x_.shape[0], -1).norm(float(norm), dim=-1) <= eps
        self.assertTrue(torch.all(within_norm_balls))
--------------------------------------------------------------------------------
/adversarial/utils.py:
--------------------------------------------------------------------------------
from typing import Union, Tuple
from torch.nn import Module
import numpy as np
import torch


def project(x: torch.Tensor, x_adv: torch.Tensor, norm: Union[str, int], eps: float) -> torch.Tensor:
    """Projects x_adv into the l_norm ball around x

    Assumes x and x_adv are 4D Tensors representing batches of images

    Args:
        x: Batch of natural images
        x_adv: Batch of adversarial images
        norm: Norm of ball around x
        eps: Radius of ball

    Returns:
        x_adv: Adversarial examples projected to be at most eps
            distance from x under a certain norm
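
    Example (a sketch: a batch of all-ones perturbations has l2 norm 2
    and is scaled down onto the eps=1 sphere):
        >>> x = torch.zeros(1, 1, 2, 2)
        >>> x_adv = torch.ones(1, 1, 2, 2)
        >>> project(x, x_adv, norm=2, eps=1.).view(-1).norm(2)
        tensor(1.)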
    """
    if x.shape != x_adv.shape:
        raise ValueError('Input Tensors must have the same shape')

    if norm == 'inf':
        # Workaround as PyTorch doesn't have elementwise clip
        x_adv = torch.max(torch.min(x_adv, x + eps), x - eps)
    else:
        delta = x_adv - x

        # Assume x and x_adv are batched tensors where the first dimension is
        # a batch dimension
        mask = delta.view(delta.shape[0], -1).norm(norm, dim=1) <= eps

        scaling_factor = delta.view(delta.shape[0], -1).norm(norm, dim=1)
        scaling_factor[mask] = eps

        # .view() assumes batched images as a 4D Tensor
        delta *= eps / scaling_factor.view(-1, 1, 1, 1)

        x_adv = x + delta

    return x_adv


def random_perturbation(x: torch.Tensor, norm: Union[str, int], eps: float) -> torch.Tensor:
    """Applies a random l_norm bounded perturbation to x

    Assumes x is a 4D Tensor representing a batch of images

    Args:
        x: Batch of images
        norm: Norm to measure size of perturbation
        eps: Size of perturbation

    Returns:
        x_perturbed: Randomly perturbed version of x
    """
    perturbation = torch.normal(torch.zeros_like(x), torch.ones_like(x))
    if norm == 'inf':
        perturbation = torch.sign(perturbation) * eps
    else:
        perturbation = project(torch.zeros_like(x), perturbation, norm, eps)

    return x + perturbation


def generate_misclassified_sample(model: Module,
                                  x: torch.Tensor,
                                  y: torch.Tensor,
                                  clamp: Tuple[float, float] = (0, 1)) -> torch.Tensor:
    """Generates an arbitrary misclassified sample

    Args:
        model: Model that must misclassify
        x: Batch of image data
        y: Corresponding labels
        clamp: Max and minimum values of elements in the samples i.e. (0, 1) for MNIST

    Returns:
        x_misclassified: A sample for the model that is not classified correctly
    """
    while True:
        x_misclassified = torch.empty_like(x).uniform_(*clamp)

        # NB: the elementwise comparison below assumes a single sample (batch size 1)
        if model(x_misclassified).argmax(dim=1) != y:
            return x_misclassified
--------------------------------------------------------------------------------
/scripts/train_adversarial.py:
--------------------------------------------------------------------------------
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, datasets, models
from multiprocessing import cpu_count
from olympic.callbacks import *
import argparse
import numpy as np
import olympic

from adversarial.models import MNISTClassifier
from adversarial.attacks import *
from adversarial.functional import *
from adversarial.datasets import RestrictedImageNet
from config import PATH


torch.backends.cudnn.benchmark = True


##############
# Parameters #
##############
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='mnist')
parser.add_argument('--attack')
parser.add_argument('--eps', type=float)
parser.add_argument('--step', type=float)
parser.add_argument('--k', type=int)
parser.add_argument('--norm', default='inf')
parser.add_argument('--device', default='cuda')
parser.add_argument('--epochs', type=int)
parser.add_argument('--batch-size', type=int)
parser.add_argument('--random-start', type=lambda x: x.lower()[0] == 't', default=True)  # Quick hack to extract boolean
args = parser.parse_args()


if args.norm != 'inf':
    norm = int(args.norm)
else:
    norm = args.norm


########
# Data #
########
if args.dataset == 'mnist':
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train = datasets.MNIST(f'{PATH}/data/', train=True, transform=transform, download=True)
    val = datasets.MNIST(f'{PATH}/data/', train=False, transform=transform, download=True)

    train_loader = DataLoader(train, batch_size=args.batch_size, num_workers=cpu_count())
    val_loader = DataLoader(val, batch_size=args.batch_size, num_workers=cpu_count())
elif args.dataset == 'restricted_imagenet':
    rng = np.random.RandomState(0)
    data = RestrictedImageNet()

    indices = np.array(range(len(data)))
    rng.shuffle(indices)
    train_indices = indices[:int(len(indices)*0.9)]
    val_indices = indices[int(len(indices)*0.9):]

    train = Subset(data, train_indices)
    val = Subset(data, val_indices)

    train_loader = DataLoader(train, batch_size=args.batch_size, num_workers=cpu_count())
    val_loader = DataLoader(val, batch_size=args.batch_size, num_workers=cpu_count())
else:
    raise ValueError('Unsupported dataset')


#########
# Model #
#########
if args.dataset == 'mnist':
    model = MNISTClassifier().to(args.device)
elif args.dataset == 'restricted_imagenet':
    model = models.resnet50(num_classes=RestrictedImageNet().num_classes()).to(args.device)
else:
    raise ValueError('Unsupported dataset')

optimiser = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()


#################
# Training loop #
#################
callbacks = [
    Evaluate(val_loader),
    ReduceLROnPlateau(monitor='val_accuracy', patience=5),
    ModelCheckpoint(
        f'{PATH}/models/{args.dataset}_attack={args.attack}_eps={args.eps}.pt',
        save_best_only=True,
        monitor='val_accuracy',
        verbose=True
    ),
    CSVLogger(f'{PATH}/logs/{args.dataset}_attack={args.attack}_eps={args.eps}.csv')
]


def adversarial_update(model, optimiser, loss_fn, x, y, epoch, eps, step, k, norm, **kwargs):
    """Performs a single update against an adversary"""
    model.train()

    # Adversarial perturbation
    if norm == 'inf':
        x_adv = iterated_fgsm(model, x, y, loss_fn, k=k, step=step, eps=eps, norm='inf', random=args.random_start)
    elif norm == 2:
        x_adv = pgd(model, x, y, loss_fn, k=k, step=step, eps=eps, norm=2, random=args.random_start)
    else:
        raise ValueError('Unsupported norm')

    optimiser.zero_grad()
    y_pred = model(x_adv)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimiser.step()

    return loss, y_pred


olympic.fit(
    model,
    optimiser,
    loss_fn,
    dataloader=train_loader,
    epochs=args.epochs,
    metrics=['accuracy'],
    callbacks=callbacks,
    update_fn=adversarial_update,
    update_fn_kwargs={'eps': args.eps, 'step': args.step, 'norm': norm, 'k': args.k},
    prepare_batch=lambda batch: (batch[0].to(args.device), batch[1].to(args.device))
)
--------------------------------------------------------------------------------
/adversarial/attacks.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from typing import Union, Callable, Tuple
from torch.nn import Module
from collections import deque
import torch

from adversarial.utils import project
from adversarial.functional import *


class Attack(ABC):
    """Base class for adversarial attack methods"""
    @abstractmethod
    def create_adversarial_sample(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError


class FGSM(Attack):
    """Implements the Fast Gradient-Sign Method (FGSM).

    FGSM is a white box attack.
    """
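    # Sketch of the class-based API (model, loss and eps are whatever you trained with):
    #   attack = FGSM(eps=0.25, model=model, loss_fn=torch.nn.CrossEntropyLoss())
    #   x_adv = attack.create_adversarial_sample(x, y)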
    def __init__(self,
                 eps: float,
                 model: Module,
                 loss_fn: Callable):
        super(FGSM, self).__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.eps = eps

    def create_adversarial_sample(self,
                                  x: torch.Tensor,
                                  y: torch.Tensor,
                                  clamp: Tuple[float, float] = (0, 1)):
        """Creates an adversarial sample

        Args:
            x: Batch of samples
            y: Corresponding labels
            clamp: Max and minimum values of elements in the samples i.e. (0, 1) for MNIST

        Returns:
            x_adv: Adversarially perturbed version of x
        """
        return fgsm(self.model, x, y, self.loss_fn, self.eps, clamp)


class IteratedFGSM(Attack):
    """Implements the iterated Fast Gradient-Sign Method"""
    def __init__(self,
                 model: Module,
                 loss_fn: Callable,
                 eps: float,
                 k: int,
                 step: float,
                 norm: Union[str, int] = 'inf'):
        super(IteratedFGSM, self).__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.eps = eps
        self.step = step
        self.k = k
        self.norm = norm

    def create_adversarial_sample(self,
                                  x: torch.Tensor,
                                  y: torch.Tensor,
                                  clamp: Tuple[float, float] = (0, 1)) -> torch.Tensor:
        return iterated_fgsm(self.model, x, y, self.loss_fn, self.k, self.step, self.eps, self.norm)


class PGD(Attack):
    """Implements the Projected Gradient Descent attack"""
    def __init__(self,
                 model: Module,
                 loss_fn: Callable,
                 eps: float,
                 k: int,
                 step: float,
                 norm: Union[str, int] = 'inf'):
        super(PGD, self).__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.eps = eps
        self.step = step
        self.k = k
        self.norm = norm

    def create_adversarial_sample(self,
                                  x: torch.Tensor,
                                  y: torch.Tensor,
                                  clamp: Tuple[float, float] = (0, 1)) -> torch.Tensor:
        return pgd(self.model, x, y, self.loss_fn, self.k, self.step, self.eps, self.norm, clamp)


class Boundary(Attack):
    """Implements the boundary attack

    This is a black box attack that doesn't require knowledge of the model
    structure. It only requires knowledge of the model's final decision
    i.e. the predicted class.

    https://arxiv.org/pdf/1712.04248.pdf

    Args:
        model: Model to be attacked
        k: Number of steps to take
        orthogonal_step: orthogonal step size (delta in paper)
        perpendicular_step: perpendicular step size (epsilon in paper)
    """
    def __init__(self, model: Module, k: int, orthogonal_step: float = 0.1, perpendicular_step: float = 0.1):
        super(Boundary, self).__init__()
        self.model = model
        self.k = k
        self.orthogonal_step = orthogonal_step
        self.perpendicular_step = perpendicular_step

    def create_adversarial_sample(self,
                                  model: Module,
                                  x: torch.Tensor,
                                  y: torch.Tensor,
                                  initial: torch.Tensor = None,
                                  clamp: Tuple[float, float] = (0, 1)) -> torch.Tensor:
        return boundary(model, x, y, self.orthogonal_step, self.perpendicular_step, self.k, initial, clamp)
--------------------------------------------------------------------------------
/adversarial/datasets.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
import os

from config import DATA_PATH


class RestrictedImageNet(Dataset):
    classes = {
        'n01532829': 'BIRD',
        'n01558993': 'BIRD',
        'n01843383': 'BIRD',
        'n01855672': 'BIRD',

        'n02089867': 'DOG',
        'n02091244': 'DOG',
        'n02099601': 'DOG',
        'n02101006': 'DOG',
        'n02105505': 'DOG',
        'n02108551': 'DOG',
        'n02108915': 'DOG',
        'n02110063': 'DOG',
        'n02111277': 'DOG',
        'n02114548': 'DOG',
        'n02091831': 'DOG',
        'n02108089': 'DOG',
        'n02110341': 'DOG',
        'n02113712': 'DOG',

        'n02165456': 'INSECT',
        'n02174001': 'INSECT',
        'n02219486': 'INSECT',
        'n01770081': 'INSECT',

        'n02795169': 'CONTAINER',
        'n03127925': 'CONTAINER',
        'n03337140': 'CONTAINER',
        'n02747177': 'CONTAINER',

        'n03272010': 'INSTRUMENT',
        'n03838899': 'INSTRUMENT',
        'n03854065': 'INSTRUMENT',
        'n04515003': 'INSTRUMENT',

        'n03417042': 'VEHICLE',
        'n04146614': 'VEHICLE',
        'n04389033': 'VEHICLE',

        'n04596742': 'FOOD',
        'n07747607': 'FOOD',
        'n03400231': 'FOOD',
        'n07584110': 'FOOD',
        'n07613480': 'FOOD',

        'n01910747': 'SEA_CREATURE',
        'n01981276': 'SEA_CREATURE',
        'n02074367': 'SEA_CREATURE',
        'n02606052': 'SEA_CREATURE',
    }

    def __init__(self, transform_list=None):
        """Dataset class representing a restricted ImageNet in which
        miniImageNet classes are grouped into coarser superclasses
        (BIRD, DOG, INSECT etc.)
        """
        self.df = pd.DataFrame(self.index_data())

        # Index of dataframe has direct correspondence to item in dataset
        self.df = self.df.assign(id=self.df.index.values)

        # Convert arbitrary class names of dataset to ordered 0-(num_classes - 1) integers
        self.unique_class_names = sorted(self.df['class_name'].unique())
        self.class_name_to_id = {self.unique_class_names[i]: i for i in range(self.num_classes())}
        self.df = self.df.assign(class_id=self.df['class_name'].apply(lambda c: self.class_name_to_id[c]))

        # Create dicts
        self.datasetid_to_filepath = self.df.to_dict()['filepath']
        self.datasetid_to_class_id = self.df.to_dict()['class_id']

        # Setup transforms
        if transform_list is None:
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.75, 1)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ])
        else:
            self.transform = transform_list

    def __getitem__(self, item):
        instance = Image.open(self.datasetid_to_filepath[item])
        instance = self.transform(instance)
        label = self.datasetid_to_class_id[item]
        return instance, label

    def __len__(self):
        return len(self.df)

    def num_classes(self):
        return len(self.df['class_name'].unique())

    @staticmethod
    def index_data():
        """Index the dataset by looping through all of its image files and recording relevant information.

        # Returns
            A list of dicts containing the class name and filepath of every
            image file that belongs to one of the restricted classes
        """
        images = []

        for subset in ('images_background', 'images_evaluation'):
            for root, folders, files in os.walk(f'{DATA_PATH}/miniImageNet/{subset}'):
                if len(files) == 0:
                    continue

                class_name = root.split('/')[-1]

                if class_name not in RestrictedImageNet.classes.keys():
                    continue

                for f in files:
                    images.append({
                        'class_name': RestrictedImageNet.classes[class_name],
                        'filepath': os.path.join(root, f)
                    })

        return images
--------------------------------------------------------------------------------
/tests/test_functional.py:
--------------------------------------------------------------------------------
import unittest
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

from adversarial.functional import *
from adversarial.models import MNISTClassifier
from config import PATH


DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class TestAttacks(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.model = MNISTClassifier()
        cls.model.load_state_dict(torch.load(f'{PATH}/models/mnist_natural.pt'))
        cls.model.to(DEVICE)

        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        cls.val = datasets.MNIST(f'{PATH}/data/', train=False, transform=transform, download=True)
        cls.val_loader = DataLoader(cls.val, batch_size=1, num_workers=0)

        x, y = next(iter(cls.val_loader))
        cls.x = x.to(DEVICE)
        cls.y = y.to(DEVICE)

    def test_fgsm(self):
        eps = 0.25

        x_adv = fgsm(self.model, self.x, self.y, torch.nn.CrossEntropyLoss(), eps)

        # Check that adversarial example is misclassified
        self.assertNotEqual(self.model(x_adv).argmax(dim=1).item(), self.y.item())

        # Check that l_inf distance is as expected
        self.assertTrue(
            np.isclose(torch.norm(self.x - x_adv, float("inf")).item(), eps)
        )

    def test_iterated_fgsm_untargeted(self):
        k = 100
        step = 1
        eps = 0.3
        norm = 'inf'

        x_adv = iterated_fgsm(self.model, self.x, self.y, torch.nn.CrossEntropyLoss(), k=k, step=step, eps=eps, norm=norm)

        # Check that adversarial example is misclassified
        self.assertNotEqual(self.model(x_adv).argmax(dim=1).item(), self.y.item())

        # Assert that distance between adversarial and natural sample is
        # less than specified amount
        adversarial_distance = (self.x - x_adv).norm(float(norm)).item()
        self.assertTrue(
            np.isclose(adversarial_distance, eps) or adversarial_distance < eps
        )

    def test_iterated_fgsm_targeted(self):
        k = 100
        step = 0.1
        eps = 0.3
        norm = 'inf'
        target = torch.Tensor([0]).long().to(DEVICE)

        x_adv = iterated_fgsm(self.model, self.x, self.y, torch.nn.CrossEntropyLoss(), y_target=target, k=k, step=step,
                              eps=eps, norm=norm)

        # Check that adversarial example is classified as the target class
        self.assertEqual(self.model(x_adv).argmax(dim=1).item(), target.item())

        # Assert that distance between adversarial and natural sample is
        # less than specified amount
        adversarial_distance = (self.x - x_adv).norm(float(norm)).item()
        self.assertTrue(
            np.isclose(adversarial_distance, eps) or adversarial_distance < eps
        )

    def test_pgd_untargeted(self):
        k = 40
        eps = 3
        step = 0.1
        norm = 2

        # Perform an untargeted PGD attack on a batch of MNIST images
        x = torch.cat([self.x, self.x])
        y = torch.cat([self.y, self.y])
        x_adv = pgd(self.model, x, y, torch.nn.CrossEntropyLoss(), k, step, eps=eps, norm=norm)

        # Check that adversarial examples are misclassified
        self.assertTrue(torch.all(self.model(x_adv).argmax(dim=1) != self.y))

        # Assert that distance between adversarial and natural samples are less or equal to specified eps
        delta = x - x_adv
        self.assertTrue(torch.all(delta.view(delta.shape[0], -1).norm(norm, dim=-1) < eps))

    def test_pgd_targeted(self):
        k = 10
        eps = 3
        step = 1
        norm = 2
        target = torch.Tensor([0, 0]).long().to(DEVICE)

        x = torch.cat([self.x, self.x])
        y = torch.cat([self.y, self.y])
        x_adv = pgd(self.model, x, y, torch.nn.CrossEntropyLoss(), k, step,
                    y_target=target, eps=eps, norm=norm)

        # Check that adversarial examples are classified as the target class
        self.assertTrue(torch.all(self.model(x_adv).argmax(dim=1) == target))

        # Assert that distance between adversarial and natural samples are less or equal to specified eps
        delta = x - x_adv
        self.assertTrue(torch.all(delta.view(delta.shape[0], -1).norm(norm, dim=-1) < eps))

    def test_boundary_untargeted(self):
        x_adv = boundary(self.model, self.x, self.y, 500)

        # Check that adversarial example is misclassified
        self.assertNotEqual(self.model(x_adv).argmax(dim=1).item(), self.y.item())

        # Assert that distance between adversarial and natural sample is small
        self.assertLess(
            (torch.norm(self.x - x_adv, 2).pow(2) / self.x.numel()).item(),
            0.1
        )

    def test_local_search_untargeted(self):
        x_adv = local_search(self.model, self.x, self.y, 100)

        # Check that adversarial example is misclassified
        self.assertNotEqual(self.model(x_adv).argmax(dim=1).item(), self.y.item())
--------------------------------------------------------------------------------
/adversarial/functional.py:
--------------------------------------------------------------------------------
from typing import Union, Callable, Tuple
from functools import reduce
from collections import deque
from torch.nn import Module
import torch

from adversarial.utils import project, generate_misclassified_sample, random_perturbation


def fgsm(model: Module,
         x: torch.Tensor,
         y: torch.Tensor,
         loss_fn: Callable,
         eps: float,
         clamp: Tuple[float, float] = (0, 1)) -> torch.Tensor:
    """Creates an adversarial sample using the Fast Gradient-Sign Method (FGSM)
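
    The attack takes a single step of size eps in the direction of the sign of
    the gradient of the loss w.r.t. the input, i.e. x_adv = x + eps * sign(grad_x loss).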

    This is a white-box attack.

    Args:
        model: Model
        x: Batch of samples
        y: Corresponding labels
        loss_fn: Loss function to maximise
        eps: Size of adversarial perturbation
        clamp: Max and minimum values of elements in the samples i.e. (0, 1) for MNIST

    Returns:
        x_adv: Adversarially perturbed version of x
    """
    x.requires_grad = True
    model.train()
    prediction = model(x)
    # Minimising the negated loss is equivalent to maximising the loss, so the
    # subtraction below steps in the direction that increases the true loss
    loss = -loss_fn(prediction, y)
    loss.backward()

    x_adv = (x - torch.sign(x.grad) * eps).clamp(*clamp).detach()

    return x_adv


def _iterative_gradient(model: Module,
                        x: torch.Tensor,
                        y: torch.Tensor,
                        loss_fn: Callable,
                        k: int,
                        step: float,
                        eps: float,
                        norm: Union[str, float],
                        step_norm: Union[str, float],
                        y_target: torch.Tensor = None,
                        random: bool = False,
                        clamp: Tuple[float, float] = (0, 1)) -> torch.Tensor:
    """Base function for PGD and iterated FGSM

    Args:
        model: Model
        x: Batch of samples
        y: Corresponding labels
        loss_fn: Loss function to maximise
        k: Number of iterations to make
        step: Size of step to make at each iteration
        eps: Maximum size of adversarial perturbation, larger perturbations will be projected back into the
            L_norm ball
        norm: Type of norm
        step_norm: 2 for PGD, 'inf' for iterated FGSM
        y_target: If specified, the attack is targeted and tries to induce this (incorrect) label
        random: Whether to start the attack from a random point in the l_norm ball
        clamp: Max and minimum values of elements in the samples i.e. (0, 1) for MNIST

    Returns:
        x_adv: Adversarially perturbed version of x
    """
    x_adv = x.clone().detach().requires_grad_(True).to(x.device)
    targeted = y_target is not None

    if random:
        x_adv = random_perturbation(x_adv, norm, eps)

    for i in range(k):
        _x_adv = x_adv.clone().detach().requires_grad_(True)

        prediction = model(_x_adv)
        loss = loss_fn(prediction, y_target if targeted else y)
        loss.backward()

        with torch.no_grad():
            if step_norm == 'inf':
                gradients = _x_adv.grad.sign() * step
            else:
                # .view() assumes batched image data as 4D tensor
                gradients = _x_adv.grad * step / _x_adv.grad.view(_x_adv.shape[0], -1).norm(step_norm, dim=-1)\
                    .view(-1, 1, 1, 1)

            if targeted:
                # Targeted: gradient descent on the loss of the (incorrect) target label
                # w.r.t. the input
                x_adv -= gradients
            else:
                # Untargeted: gradient ascent on the loss of the correct label w.r.t.
                # the input
                x_adv += gradients

        # Project back into l_norm ball and correct range
        x_adv = project(x, x_adv, norm, eps).clamp(*clamp)

    return x_adv.detach()


def iterated_fgsm(model: Module,
                  x: torch.Tensor,
                  y: torch.Tensor,
                  loss_fn: Callable,
                  k: int,
                  step: float,
                  eps: float,
                  norm: Union[str, float],
                  y_target: torch.Tensor = None,
                  random: bool = False,
                  clamp: Tuple[float, float] = (0, 1)) -> torch.Tensor:
    """Creates an adversarial sample using the iterated Fast Gradient-Sign Method

    This is a white-box attack.
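
    Identical to pgd() except that each of the k steps is a signed-gradient
    (l_inf) step of size `step`; after every step the sample is projected back
    into the eps l_norm ball around x.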

    Args:
        model: Model
        x: Batch of samples
        y: Corresponding labels
        loss_fn: Loss function to maximise
        k: Number of iterations to make
        step: Size of step to make at each iteration
        eps: Maximum size of adversarial perturbation, larger perturbations will be projected back into the
            L_norm ball
        norm: Type of norm
        y_target: If specified, the attack is targeted and tries to induce this (incorrect) label
        random: Whether to start Iterated FGSM from a random point in the l_norm ball
        clamp: Max and minimum values of elements in the samples i.e. (0, 1) for MNIST

    Returns:
        x_adv: Adversarially perturbed version of x
    """
    return _iterative_gradient(model=model, x=x, y=y, loss_fn=loss_fn, k=k, eps=eps, norm=norm, step=step,
                               step_norm='inf', y_target=y_target, random=random, clamp=clamp)


def pgd(model: Module,
        x: torch.Tensor,
        y: torch.Tensor,
        loss_fn: Callable,
        k: int,
        step: float,
        eps: float,
        norm: Union[str, float],
        y_target: torch.Tensor = None,
        random: bool = False,
        clamp: Tuple[float, float] = (0, 1)) -> torch.Tensor:
    """Creates an adversarial sample using the Projected Gradient Descent Method

    This is a white-box attack.
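
    Each of the k steps moves a fixed l2 length `step` along the gradient
    direction (step_norm=2) rather than taking a signed-gradient step, and the
    sample is projected back into the eps l_norm ball after every step.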

    Args:
        model: Model
        x: Batch of samples
        y: Corresponding labels
        loss_fn: Loss function to maximise
        k: Number of iterations to make
        step: Size of step (i.e. L2 norm) to make at each iteration
        eps: Maximum size of adversarial perturbation, larger perturbations will be projected back into the
            L_norm ball
        norm: Type of norm
        y_target: If specified, the attack is targeted and tries to induce this (incorrect) label
        random: Whether to start PGD from a random point in the l_norm ball
        clamp: Max and minimum values of elements in the samples i.e. (0, 1) for MNIST

    Returns:
        x_adv: Adversarially perturbed version of x
    """
    return _iterative_gradient(model=model, x=x, y=y, loss_fn=loss_fn, k=k, eps=eps, norm=norm, step=step, step_norm=2,
                               y_target=y_target, random=random, clamp=clamp)


def boundary(model: Module,
             x: torch.Tensor,
             y: torch.Tensor,
             k: int,
             orthogonal_step: float = 1e-2,
             perpendicular_step: float = 1e-2,
             initial: torch.Tensor = None,
             clamp: Tuple[float, float] = (0, 1)) -> torch.Tensor:
    """Implements the boundary attack

    This is a black box attack that doesn't require knowledge of the model
    structure. It only requires knowledge of the model's final decision
    i.e. the predicted class.

    https://arxiv.org/pdf/1712.04248.pdf

    Args:
        model: Model to be attacked
        x: Batched image data
        y: Corresponding labels
        k: Number of steps to take
        orthogonal_step: orthogonal step size (delta in paper)
        perpendicular_step: perpendicular step size (epsilon in paper)
        initial: Initial attack image to start with. If this is None then use random noise
        clamp: Max and minimum values of elements in the samples i.e. (0, 1) for MNIST

    Returns:
        x_adv: Best (i.e. closest to x) adversarial example found for x
    """
    orth_step_stats = deque(maxlen=30)
    perp_step_stats = deque(maxlen=30)
    # Factors to adjust step sizes by
    orth_step_factor = 0.97
    perp_step_factor = 0.97

    def _propose(x: torch.Tensor,
                 x_adv: torch.Tensor,
                 y: torch.Tensor,
                 model: Module,
                 clamp: Tuple[float, float] = (0, 1)) -> torch.Tensor:
        """Generate proposal perturbed sample

        Args:
            x: Original sample
            x_adv: Adversarial sample
            y: Label of original sample
            model: Model to attack
            clamp: Domain (i.e. max/min) of samples
        """
        # Sample from unit Normal distribution with same shape as input
        perturbation = torch.normal(torch.zeros_like(x_adv), torch.ones_like(x_adv))

        # Rescale perturbation so l2 norm is delta
        perturbation = project(torch.zeros_like(perturbation), perturbation, norm=2, eps=orthogonal_step)

        # Apply perturbation and project onto sphere around original sample such that the distance
        # between the perturbed adversarial sample and the original sample is the same as the
        # distance between the unperturbed adversarial sample and the original sample
        # i.e. d(x_adv, x) = d(x_adv + perturbation, x)
        perturbed = x_adv + perturbation
        perturbed = project(x, perturbed, 2, torch.norm(x_adv - x, 2)).clamp(*clamp)

        # Record success/failure of orthogonal step
        orth_step_stats.append(model(perturbed).argmax(dim=1) != y)

        # Make step towards original sample
        step_towards_original = project(torch.zeros_like(perturbation), x - perturbed, norm=2, eps=perpendicular_step)
        perturbed = (perturbed + step_towards_original).clamp(*clamp)

        # Record success/failure of perpendicular step
        perp_step_stats.append(model(perturbed).argmax(dim=1) != y)

        # Clamp to domain of sample
        perturbed = perturbed.clamp(*clamp)

        return perturbed

    if x.size(0) != 1:
        # TODO: Attack a whole batch in parallel
        raise NotImplementedError

    if initial is not None:
        x_adv = initial
    else:
        # Generate initial adversarial sample from uniform distribution
        x_adv = generate_misclassified_sample(model, x, y)

    total_stats = torch.zeros(k)

    for i in range(k):
        # Propose perturbation
        perturbed = _propose(x, x_adv, y, model, clamp)

        # Check if the perturbed input is adversarial i.e. gives the wrong prediction
        perturbed_prediction = model(perturbed).argmax(dim=1)
        total_stats[i] = perturbed_prediction != y
        if perturbed_prediction != y:
            x_adv = perturbed

        # Check success rates and adjust step sizes: a high success rate means
        # larger steps are safe (divide by factor < 1), a low rate shrinks them
        if len(perp_step_stats) == perp_step_stats.maxlen:
            if torch.Tensor(perp_step_stats).mean() > 0.5:
                perpendicular_step /= perp_step_factor
                orthogonal_step /= orth_step_factor
            elif torch.Tensor(perp_step_stats).mean() < 0.2:
                perpendicular_step *= perp_step_factor
                orthogonal_step *= orth_step_factor

        if len(orth_step_stats) == orth_step_stats.maxlen:
            if torch.Tensor(orth_step_stats).mean() > 0.5:
                orthogonal_step /= orth_step_factor
            elif torch.Tensor(orth_step_stats).mean() < 0.2:
                orthogonal_step *= orth_step_factor

    return x_adv


def _perturb(x: torch.Tensor,
             p: float,
             i: int,
             j: int,
             clamp: Tuple[float, float] = (0, 1)) -> torch.Tensor:
    """Perturbs a pixel in an image

    Args:
        x: image
        p: perturbation parameter
        i: row
        j: column
        clamp: Max and minimum values of elements in the samples i.e. (0, 1) for MNIST
    """
    if x.size(0) != 1:
        raise NotImplementedError('Only implemented for single image')

    x[0, :, i, j] = p * torch.sign(x[0, :, i, j])

    return x.clamp(*clamp)


def local_search(model: Module,
                 x: torch.Tensor,
                 y: torch.Tensor,
                 k: int,
                 branching: Union[int, float] = 0.1,
                 p: float = 1.,
                 d: int = None,
                 clamp: Tuple[float, float] = (0, 1)) -> torch.Tensor:
    """Performs the local search attack

    This is a black-box (score based) attack first described in
    https://arxiv.org/pdf/1612.06299.pdf

    Args:
        model: Model to attack
        x: Batched image data
        y: Corresponding labels
        k: Number of rounds of local search to perform
        branching: Either fraction of image pixels to search at each round or
            number of image pixels to search at each round
        p: Size of perturbation
        d: Neighbourhood square half side length
        clamp: Max and minimum values of elements in the samples i.e. (0, 1) for MNIST

    Returns:
        x_adv: Adversarial version of x
    """
    if x.size(0) != 1:
        # TODO: Attack a whole batch at a time
        raise NotImplementedError('Only implemented for single image')

    x_adv = x.clone().detach().requires_grad_(False).to(x.device)
    model.eval()

    data_shape = x_adv.shape[2:]
    if isinstance(branching, float):
        branching = int(reduce(lambda x, y: x*y, data_shape) * branching)

    for _ in range(k):
        # Select pixel locations at random
        perturb_pixels = torch.randperm(reduce(lambda x, y: x*y, data_shape))[:branching]

        perturb_pixels = torch.stack([perturb_pixels // data_shape[0], perturb_pixels % data_shape[1]]).transpose(1, 0)

        # Kinda hacky but works for MNIST (i.e. 1 channel images)
        # TODO: multi channel images
        neighbourhood = x_adv.repeat((branching, 1, 1, 1))
        perturb_pixels = torch.cat([torch.arange(branching).unsqueeze(-1), perturb_pixels], dim=1)
        neighbourhood[perturb_pixels[:, 0], 0, perturb_pixels[:, 1], perturb_pixels[:, 2]] = 1

        predictions = model(neighbourhood).softmax(dim=1)
        scores = predictions[:, y]

        # Select best perturbation and continue
        i_best, j_best = perturb_pixels[scores.argmin(dim=0).item(), 1:]
        x_adv[0, :, i_best, j_best] = 1.
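        # (The assignment above pushes the chosen pixel to the clamp maximum,
        # which, like the neighbourhood construction, assumes 1-channel images.)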
        x_adv.clamp_(*clamp)

        # Early exit if adversarial is found
        worst_prediction = predictions.argmax(dim=1)[scores.argmin(dim=0).item()]
        if worst_prediction.item() != y.item():
            return x_adv

    # Attack failed, return sample with lowest score of correct class
    return x_adv
--------------------------------------------------------------------------------