├── .github └── workflows │ ├── ci.yml │ └── pypi.yml ├── .gitignore ├── LICENSE.txt ├── README.md ├── adv_train.py ├── evaluate_common_corruptions.py ├── evaluate_distances.py ├── evaluate_trained_model.py ├── generate_examples.py ├── getting_started.ipynb ├── mypy.ini ├── perceptual_advex ├── __init__.py ├── attacks.py ├── datasets.py ├── distances.py ├── evaluation.py ├── models.py ├── perceptual_attacks.py ├── py.typed ├── trades_wrn.py └── utilities.py ├── requirements.txt ├── setup.cfg └── setup.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - '**' 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v2 14 | 15 | - name: Install Python 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.7.6 19 | 20 | - name: Cache pip 21 | uses: actions/cache@v1 22 | with: 23 | path: ~/.cache/pip # This path is specific to Ubuntu 24 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} 25 | restore-keys: | 26 | ${{ runner.os }}-pip- 27 | 28 | - name: Install pip requirements 29 | timeout-minutes: 2 30 | run: | 31 | # Now, install dependencies. 32 | pip install --upgrade pip setuptools wheel 33 | pip install -r requirements.txt 34 | pip install mypy 35 | 36 | - name: Check types 37 | run: mypy perceptual_advex/*.py *.py 38 | if: ${{ always() }} 39 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | 8 | jobs: 9 | publish: 10 | name: Publish to PyPI 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | 16 | - name: Install Python 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: 3.7.6 20 | 21 | - name: Cache pip 22 | uses: actions/cache@v1 23 | with: 24 | path: ~/.cache/pip # This path is specific to Ubuntu 25 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} 26 | restore-keys: | 27 | ${{ runner.os }}-pip- 28 | 29 | - name: Install pypa/build 30 | run: | 31 | python -m pip install build --user 32 | - name: Build a binary wheel and a source tarball 33 | run: | 34 | python -m build --sdist --wheel --outdir dist/ . 
35 | - name: Publish distribution to PyPI 36 | if: startsWith(github.ref, 'refs/tags') 37 | uses: pypa/gh-action-pypi-publish@master 38 | with: 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.csv 2 | *.prof 3 | *.png 4 | *.txt 5 | *.pyc 6 | data 7 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | Copyright (c) 2018 YOUR NAME 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | The above copyright notice and this permission notice shall be included in all 10 | copies or substantial portions of the Software. 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 17 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Perceptual Adversarial Robustness 2 | This repository contains code and data for the ICLR 2021 paper ["Perceptual Adversarial Robustness: Defense Against Unseen Threat Models"](https://arxiv.org/abs/2006.12655). 3 | 4 | ## Installation 5 | 6 | The code can be downloaded as this GitHub repository, which includes the scripts for running all experiments in the paper. Alternatively, it can be installed as a pip package, which includes the models, attacks, and other utilities. 7 | 8 | ### As a repository 9 | 10 | 1. Install [Python 3](https://www.python.org/). 11 | 2. Clone the repository: 12 | 13 | git clone https://github.com/cassidylaidlaw/perceptual-advex.git 14 | cd perceptual-advex 15 | 16 | 3. Install pip requirements: 17 | 18 | pip install -r requirements.txt 19 | 20 | ### As a package 21 | 22 | 1. Install [Python 3](https://www.python.org/). 23 | 2. Install from PyPI: 24 | 25 | pip install perceptual-advex 26 | 27 | 3. (Optional) Install AutoAttack if you want to use it with the package: 28 | 29 | pip install git+git://github.com/fra31/auto-attack#egg=autoattack 30 | 31 | 4. Import the package as follows: 32 | 33 | from perceptual_advex.perceptual_attacks import FastLagrangePerceptualAttack 34 | 35 | See [getting_started.ipynb](getting_started.ipynb) or the Colab notebook below for examples of how to use the package. 36 | 37 | ## Data and Pretrained Models 38 | 39 | Download pretrained models from [here](https://perceptual-advex.s3.us-east-2.amazonaws.com/perceptual-advex-checkpoints.zip).
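For example, the archive can be downloaded and unpacked directly from Python; this is a minimal sketch, and the local filename and `data/checkpoints` destination are arbitrary choices here, since the evaluation scripts take an explicit `--checkpoint` path:

    import urllib.request
    import zipfile

    # Download the archive of pretrained checkpoints (may take a while).
    url = 'https://perceptual-advex.s3.us-east-2.amazonaws.com/perceptual-advex-checkpoints.zip'
    urllib.request.urlretrieve(url, 'perceptual-advex-checkpoints.zip')

    # Unpack the archive; the extracted .pt files can then be passed to the
    # training and evaluation scripts via the --checkpoint argument.
    with zipfile.ZipFile('perceptual-advex-checkpoints.zip') as archive:
        archive.extractall('data/checkpoints')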
40 | 41 | Download perceptual study data from [here](https://perceptual-advex.s3.us-east-2.amazonaws.com/perceptual-advex-perceptual-study-data.zip). 42 | 43 | ## Usage 44 | 45 | This section explains how to get started with using the code and includes information about how to run all the experiments. 46 | 47 | ### Getting Started 48 | 49 | The [getting_started.ipynb](getting_started.ipynb) notebook shows how to load a pretrained model and construct perceptual adversarial examples for it. It is also available on Google Colab via the link below. 50 | 51 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cassidylaidlaw/perceptual-advex/blob/master/getting_started.ipynb) 52 | 53 | ### Perceptual Adversarial Training (PAT) 54 | 55 | The script `adv_train.py` can be used to perform Perceptual Adversarial Training (PAT) or to perform regular adversarial training. To train a ResNet-50 with self-bounded PAT on CIFAR-10: 56 | 57 | python adv_train.py --batch_size 50 --arch resnet50 --dataset cifar --attack "FastLagrangePerceptualAttack(model, bound=0.5, num_iterations=10)" --only_attack_correct 58 | 59 | This will create a directory `data/logs`, which will contain [TensorBoard](https://www.tensorflow.org/tensorboard) logs and checkpoints for each epoch. 60 | 61 | To train a ResNet-50 with self-bounded PAT on ImageNet-100: 62 | 63 | python adv_train.py --parallel 4 --batch_size 128 --dataset imagenet100 --dataset_path /path/to/ILSVRC2012 --arch resnet50 --attack "FastLagrangePerceptualAttack(model, bound=0.25, num_iterations=10)" --only_attack_correct 64 | 65 | This assumes 4 GPUs are available; you can change the number with the `--parallel` argument. To train a ResNet-50 with AlexNet-bounded PAT on ImageNet-100, run: 66 | 67 | python adv_train.py --parallel 4 --batch_size 128 --dataset imagenet100 --dataset_path /path/to/ILSVRC2012 --arch resnet50 --attack "FastLagrangePerceptualAttack(model, bound=0.25, num_iterations=10, lpips_model='alexnet')" --only_attack_correct 68 | 69 | ### Generating Perceptual Adversarial Attacks 70 | 71 | The script `generate_examples.py` will generate adversarially attacked images. For instance, to generate adversarial examples with the Perceptual Projected Gradient Descent (PPGD) and Lagrange Perceptual Attack (LPA) attacks on ImageNet, run: 72 | 73 | python generate_examples.py --dataset imagenet --arch resnet50 --checkpoint pretrained --batch_size 20 --shuffle --layout horizontal_alternate --dataset_path /path/to/ILSVRC2012 --output examples.png \ 74 | "PerceptualPGDAttack(model, bound=0.5, num_iterations=40, lpips_model='alexnet')" \ 75 | "LagrangePerceptualAttack(model, bound=0.5, num_iterations=40, lpips_model='alexnet')" 76 | 77 | This will create an image called `examples.png` with three columns. The first is the unmodified original images from the ImageNet test set. The second and third contain adversarial attacks and magnified difference from the originals for the PPGD and LPA attacks, respectively. 78 | 79 | #### Arguments 80 | 81 | - `--dataset` can be `cifar` for CIFAR-10, `imagenet100` for ImageNet-100, or `imagenet` for full ImageNet. 82 | - `--arch` can be `resnet50` (or `resnet34`, etc.) or `alexnet`. 83 | - `--checkpoint` can be `pretrained` to use the pretrained `torchvision` model. Otherwise, it should be a path to a pretrained model, such as those from the [robustness](https://github.com/MadryLab/robustness) library. 84 | - `--batch_size` indicates how many images to attack. 
85 | - `--layout` controls the layout of the resulting image. It can be `vertical`, `vertical_alternate`, or `horizontal_alternate`. 86 | - `--output` specifies where the resulting image should be stored. 87 | - The remainder of the arguments specify attacks using Python expressions. See the `perceptual_advex.attacks` and `perceptual_advex.perceptual_attacks` modules for a full list of available attacks and arguments for those attacks. 88 | 89 | ### Evaluation 90 | 91 | The script `evaluate_trained_model.py` evaluates a model against a set of attacks. The arguments are similar to `generate_examples.py` (see above). For instance, to evaluate the torchvision pretrained ResNet-50 against PPGD and LPA, run: 92 | 93 | python evaluate_trained_model.py --dataset imagenet --arch resnet50 --checkpoint pretrained --batch_size 50 --dataset_path /path/to/ILSVRC2012 --output evaluation.csv \ 94 | "PerceptualPGDAttack(model, bound=0.5, num_iterations=40, lpips_model='alexnet')" \ 95 | "LagrangePerceptualAttack(model, bound=0.5, num_iterations=40, lpips_model='alexnet')" 96 | 97 | #### CIFAR-10 98 | 99 | The following command was used to evaluate CIFAR-10 classifiers for Tables 2, 6, 7, 8, and 9 in the paper: 100 | 101 | python evaluate_trained_model.py --dataset cifar --checkpoint /path/to/checkpoint.pt --arch resnet50 --batch_size 100 --output evaluation.csv \ 102 | "NoAttack()" \ 103 | "AutoLinfAttack(model, 'cifar', bound=8/255)" \ 104 | "AutoL2Attack(model, 'cifar', bound=1)" \ 105 | "StAdvAttack(model, num_iterations=100)" \ 106 | "ReColorAdvAttack(model, num_iterations=100)" \ 107 | "PerceptualPGDAttack(model, num_iterations=40, bound=0.5, lpips_model='alexnet_cifar', projection='newtons')" \ 108 | "LagrangePerceptualAttack(model, num_iterations=40, bound=0.5, lpips_model='alexnet_cifar', projection='newtons')" 109 | 110 | #### ImageNet-100 111 | 112 | The following command was used to evaluate ImageNet-100 classifiers for Table 3 in the paper, which shows the robustness of various models against several attacks at the medium perceptibility bound: 113 | 114 | python evaluate_trained_model.py --dataset imagenet100 --dataset_path /path/to/ILSVRC2012 --checkpoint /path/to/checkpoint.pt --arch resnet50 --batch_size 50 --output evaluation.csv \ 115 | "NoAttack()" \ 116 | "AutoLinfAttack(model, 'imagenet100', bound=4/255)" \ 117 | "AutoL2Attack(model, 'imagenet100', bound=1200/255)" \ 118 | "JPEGLinfAttack(model, 'imagenet100', bound=0.125, num_iterations=200)" \ 119 | "StAdvAttack(model, bound=0.05, num_iterations=200)" \ 120 | "ReColorAdvAttack(model, bound=0.06, num_iterations=200)" \ 121 | "PerceptualPGDAttack(model, bound=0.5, lpips_model='alexnet', num_iterations=40)" \ 122 | "LagrangePerceptualAttack(model, bound=0.5, lpips_model='alexnet', num_iterations=40)" 123 | 124 | ## Citation 125 | 126 | If you find this repository useful for your research, please cite our paper as follows: 127 | 128 | @inproceedings{laidlaw2021perceptual, 129 | title={Perceptual Adversarial Robustness: Defense Against Unseen Threat Models}, 130 | author={Laidlaw, Cassidy and Singla, Sahil and Feizi, Soheil}, 131 | booktitle={ICLR}, 132 | year={2021} 133 | } 134 | 135 | ## Contact 136 | 137 | For questions about the paper or code, please contact claidlaw@umd.edu. 
-------------------------------------------------------------------------------- /adv_train.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Optional, cast 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import shutil 6 | import glob 7 | import time 8 | import random 9 | import os 10 | from torch import nn 11 | from tensorboardX import SummaryWriter 12 | 13 | from perceptual_advex import evaluation 14 | from perceptual_advex.utilities import add_dataset_model_arguments, \ 15 | get_dataset_model, calculate_accuracy 16 | from perceptual_advex.attacks import * 17 | from perceptual_advex.models import FeatureModel 18 | 19 | VAL_ITERS = 100 20 | 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser() 24 | 25 | add_dataset_model_arguments(parser) 26 | 27 | parser.add_argument('--num_epochs', type=int, required=False, 28 | help='number of epochs trained') 29 | parser.add_argument('--batch_size', type=int, default=100, 30 | help='number of examples/minibatch') 31 | parser.add_argument('--val_batches', type=int, default=10, 32 | help='number of batches to validate on') 33 | parser.add_argument('--log_dir', type=str, default='data/logs') 34 | parser.add_argument('--parallel', type=int, default=1, 35 | help='number of GPUs to train on') 36 | 37 | parser.add_argument('--lpips_model', type=str, required=False, 38 | help='model to use for LPIPS distance') 39 | parser.add_argument('--only_attack_correct', action='store_true', 40 | default=False, help='only attack examples that ' 41 | 'are classified correctly') 42 | parser.add_argument('--randomize_attack', action='store_true', 43 | default=False, 44 | help='randomly choose an attack at each step') 45 | parser.add_argument('--maximize_attack', action='store_true', 46 | default=False, 47 | help='choose the attack with maximum loss') 48 | 49 | parser.add_argument('--seed', type=int, default=0, help='RNG seed') 50 | parser.add_argument('--continue', default=False, action='store_true', 51 | help='continue previous training') 52 | parser.add_argument('--keep_every', type=int, default=1, 53 | help='only keep a checkpoint every X epochs') 54 | 55 | parser.add_argument('--optim', type=str, default='sgd') 56 | parser.add_argument('--lr', type=float, metavar='LR', required=False, 57 | help='learning rate') 58 | parser.add_argument('--lr_schedule', type=str, required=False, 59 | help='comma-separated list of epochs when learning ' 60 | 'rate should drop') 61 | parser.add_argument('--clip_grad', type=float, default=1.0, 62 | help='clip gradients to this value') 63 | 64 | parser.add_argument('--attack', type=str, action='append', 65 | help='attack(s) to harden against') 66 | 67 | args = parser.parse_args() 68 | 69 | if args.optim == 'adam': 70 | if args.lr is None: 71 | args.lr = 1e-3 72 | if args.lr_schedule is None: 73 | args.lr_schedule = '120' 74 | if args.num_epochs is None: 75 | args.num_epochs = 100 76 | elif args.optim == 'sgd': 77 | if args.dataset.startswith('cifar'): 78 | if args.lr is None: 79 | args.lr = 1e-1 80 | if args.lr_schedule is None: 81 | args.lr_schedule = '75,90,100' 82 | if args.num_epochs is None: 83 | args.num_epochs = 100 84 | elif ( 85 | args.dataset.startswith('imagenet') 86 | or args.dataset == 'bird_or_bicycle' 87 | ): 88 | if args.lr is None: 89 | args.lr = 1e-1 90 | if args.lr_schedule is None: 91 | args.lr_schedule = '30,60,80' 92 | if args.num_epochs is None: 93 | args.num_epochs = 90 94 | 95 | 
torch.manual_seed(args.seed) 96 | np.random.seed(args.seed) 97 | random.seed(args.seed) 98 | 99 | dataset, model = get_dataset_model(args) 100 | if isinstance(model, FeatureModel): 101 | model.allow_train() 102 | if torch.cuda.is_available(): 103 | model.cuda() 104 | 105 | if args.lpips_model is not None: 106 | _, lpips_model = get_dataset_model( 107 | args, checkpoint_fname=args.lpips_model) 108 | if torch.cuda.is_available(): 109 | lpips_model.cuda() 110 | 111 | train_loader, val_loader = dataset.make_loaders( 112 | workers=4, batch_size=args.batch_size) 113 | 114 | attacks = [eval(attack_str) for attack_str in args.attack] 115 | validation_attacks = [ 116 | NoAttack(), 117 | LinfAttack(model, dataset_name=args.dataset, 118 | num_iterations=VAL_ITERS), 119 | L2Attack(model, dataset_name=args.dataset, 120 | num_iterations=VAL_ITERS), 121 | JPEGLinfAttack(model, dataset_name=args.dataset, 122 | num_iterations=VAL_ITERS), 123 | FogAttack(model, dataset_name=args.dataset, 124 | num_iterations=VAL_ITERS), 125 | StAdvAttack(model, num_iterations=VAL_ITERS), 126 | ReColorAdvAttack(model, num_iterations=VAL_ITERS), 127 | LagrangePerceptualAttack(model, num_iterations=30), 128 | ] 129 | 130 | flags = [] 131 | if args.only_attack_correct: 132 | flags.append('only_attack_correct') 133 | if args.randomize_attack: 134 | flags.append('random') 135 | if args.maximize_attack: 136 | flags.append('maximum') 137 | if args.lpips_model: 138 | lpips_model_name, _ = os.path.splitext(os.path.basename( 139 | args.lpips_model)) 140 | flags.append(lpips_model_name) 141 | 142 | experiment_path_parts = [args.dataset, args.arch] 143 | if args.optim != 'sgd': 144 | experiment_path_parts.append(args.optim) 145 | attacks_part = '-'.join(args.attack + flags) 146 | if len(attacks_part) > 255: 147 | attacks_part = ( 148 | attacks_part 149 | .replace('model, ', '') 150 | .replace("'imagenet100', ", '') 151 | .replace("'cifar', ", '') 152 | .replace(", num_iterations=10", '') 153 | ) 154 | experiment_path_parts.append(attacks_part) 155 | experiment_path = os.path.join(*experiment_path_parts) 156 | 157 | iteration = 0 158 | log_dir = os.path.join(args.log_dir, experiment_path) 159 | if os.path.exists(log_dir): 160 | print(f'The log directory {log_dir} exists, delete? 
(y/N) ', end='') 161 | if not vars(args)['continue'] and input().strip() == 'y': 162 | shutil.rmtree(log_dir) 163 | # sleep necessary to prevent weird bug where directory isn't 164 | # actually deleted 165 | time.sleep(5) 166 | writer = SummaryWriter(log_dir) 167 | 168 | # optimizer 169 | optimizer: optim.Optimizer 170 | if args.optim == 'sgd': 171 | weight_decay = 1e-4 if ( 172 | args.dataset.startswith('imagenet') 173 | or args.dataset == 'bird_or_bicycle' 174 | ) else 2e-4 175 | optimizer = optim.SGD(model.parameters(), 176 | lr=args.lr, 177 | momentum=0.9, 178 | weight_decay=weight_decay) 179 | elif args.optim == 'adam': 180 | optimizer = optim.Adam(model.parameters()) 181 | else: 182 | raise ValueError(f'invalid optimizer {args.optim}') 183 | 184 | lr_drop_epochs = [int(epoch_str) for epoch_str in 185 | args.lr_schedule.split(',')] 186 | 187 | # check for checkpoints 188 | def get_checkpoint_fnames(): 189 | for checkpoint_fname in glob.glob(os.path.join(glob.escape(log_dir), 190 | '*.ckpt.pth')): 191 | epoch = int(os.path.basename(checkpoint_fname).split('.')[0]) 192 | if epoch < args.num_epochs: 193 | yield epoch, checkpoint_fname 194 | 195 | start_epoch = 0 196 | latest_checkpoint_epoch = -1 197 | latest_checkpoint_fname = None 198 | for epoch, checkpoint_fname in get_checkpoint_fnames(): 199 | if epoch > latest_checkpoint_epoch: 200 | latest_checkpoint_epoch = epoch 201 | latest_checkpoint_fname = checkpoint_fname 202 | if latest_checkpoint_fname is not None: 203 | print(f'Load checkpoint {latest_checkpoint_fname}? (Y/n) ', end='') 204 | if vars(args)['continue'] or input().strip() != 'n': 205 | state = torch.load(latest_checkpoint_fname) 206 | if 'iteration' in state: 207 | iteration = state['iteration'] 208 | if isinstance(model, FeatureModel): 209 | model.model.load_state_dict(state['model']) 210 | else: 211 | model.load_state_dict(state['model']) 212 | if 'optimizer' in state: 213 | optimizer.load_state_dict(state['optimizer']) 214 | start_epoch = latest_checkpoint_epoch + 1 215 | adaptive_eps = state.get('adaptive_eps', {}) 216 | 217 | # parallelize 218 | if torch.cuda.is_available(): 219 | device_ids = list(range(args.parallel)) 220 | model = nn.DataParallel(model, device_ids) 221 | attacks = [nn.DataParallel(attack, device_ids) for attack in attacks] 222 | validation_attacks = [nn.DataParallel(attack, device_ids) 223 | for attack in validation_attacks] 224 | 225 | # necessary to put training loop in a function because otherwise we get 226 | # huge memory leaks 227 | def run_iter( 228 | inputs: torch.Tensor, 229 | labels: torch.Tensor, 230 | iteration: int, 231 | train: bool = True, 232 | log_fn: Optional[Callable[[str, Any], Any]] = None, 233 | ): 234 | prefix = 'train' if train else 'val' 235 | if log_fn is None: 236 | log_fn = lambda tag, value: writer.add_scalar( 237 | f'{prefix}/{tag}', value, iteration) 238 | 239 | model.eval() # set model to eval to generate adversarial examples 240 | 241 | if torch.cuda.is_available(): 242 | inputs = inputs.cuda() 243 | labels = labels.cuda() 244 | 245 | if args.only_attack_correct: 246 | with torch.no_grad(): 247 | orig_logits = model(inputs) 248 | to_attack = orig_logits.argmax(1) == labels 249 | else: 250 | to_attack = torch.ones_like(labels).bool() 251 | 252 | if args.randomize_attack: 253 | step_attacks = [random.choice(attacks)] 254 | else: 255 | step_attacks = attacks 256 | 257 | adv_inputs_list: List[torch.Tensor] = [] 258 | for attack in step_attacks: 259 | attack_adv_inputs = inputs.clone() 260 | if to_attack.sum() > 0: 261 
| attack_adv_inputs[to_attack] = attack(inputs[to_attack], 262 | labels[to_attack]) 263 | adv_inputs_list.append(attack_adv_inputs) 264 | adv_inputs: torch.Tensor = torch.cat(adv_inputs_list) 265 | 266 | all_labels = torch.cat([labels for attack in step_attacks]) 267 | 268 | # FORWARD PASS 269 | if train: 270 | optimizer.zero_grad() 271 | model.train() # now we set the model to train mode 272 | 273 | logits = model(adv_inputs) 274 | 275 | # CONSTRUCT LOSS 276 | loss = F.cross_entropy(logits, all_labels, reduction='none') 277 | if args.maximize_attack: 278 | loss, _ = loss.resize(len(step_attacks), inputs.size()[0]).max(0) 279 | loss = loss.mean() 280 | 281 | # LOGGING 282 | accuracy = calculate_accuracy(logits, all_labels) 283 | log_fn('loss', loss.item()) 284 | log_fn('accuracy', accuracy.item()) 285 | 286 | with torch.no_grad(): 287 | for attack_index, attack in enumerate(step_attacks): 288 | if isinstance(attack, nn.DataParallel): 289 | attack_name = attack.module.__class__.__name__ 290 | else: 291 | attack_name = attack.__class__.__name__ 292 | attack_logits = logits[ 293 | attack_index * inputs.size()[0]: 294 | (attack_index + 1) * inputs.size()[0] 295 | ] 296 | log_fn(f'loss/{attack_name}', 297 | F.cross_entropy(attack_logits, labels).item()) 298 | log_fn(f'accuracy/{attack_name}', 299 | calculate_accuracy(attack_logits, labels).item()) 300 | 301 | if train: 302 | print(f'ITER {iteration:06d}', 303 | f'accuracy: {accuracy.item() * 100:5.1f}%', 304 | f'loss: {loss.item():.2f}', 305 | sep='\t') 306 | 307 | # OPTIMIZATION 308 | if train: 309 | loss.backward() 310 | 311 | # clip gradients and optimize 312 | nn.utils.clip_grad_value_(model.parameters(), args.clip_grad) 313 | optimizer.step() 314 | 315 | for epoch in range(start_epoch, args.num_epochs): 316 | lr = args.lr 317 | for lr_drop_epoch in lr_drop_epochs: 318 | if epoch >= lr_drop_epoch: 319 | lr *= 0.1 320 | 321 | print(f'START EPOCH {epoch:04d} (lr={lr:.0e})') 322 | for batch_index, (inputs, labels) in enumerate(train_loader): 323 | # ramp-up learning rate for SGD 324 | if epoch < 5 and args.optim == 'sgd' and args.lr >= 0.1: 325 | lr = (iteration + 1) / (5 * len(train_loader)) * args.lr 326 | for param_group in optimizer.param_groups: 327 | param_group['lr'] = lr 328 | 329 | run_iter(inputs, labels, iteration) 330 | iteration += 1 331 | print(f'END EPOCH {epoch:04d}') 332 | 333 | if torch.cuda.is_available(): 334 | torch.cuda.empty_cache() 335 | 336 | # VALIDATION 337 | print('BEGIN VALIDATION') 338 | model.eval() 339 | 340 | evaluation.evaluate_against_attacks( 341 | model, validation_attacks, val_loader, parallel=args.parallel, 342 | writer=writer, iteration=iteration, num_batches=args.val_batches, 343 | ) 344 | 345 | checkpoint_fname = os.path.join(log_dir, f'{epoch:04d}.ckpt.pth') 346 | print(f'CHECKPOINT {checkpoint_fname}') 347 | checkpoint_model = model 348 | if isinstance(checkpoint_model, nn.DataParallel): 349 | checkpoint_model = checkpoint_model.module 350 | if isinstance(checkpoint_model, FeatureModel): 351 | checkpoint_model = checkpoint_model.model 352 | state = { 353 | 'model': checkpoint_model.state_dict(), 354 | 'optimizer': optimizer.state_dict(), 355 | 'iteration': iteration, 356 | 'arch': args.arch, 357 | } 358 | torch.save(state, checkpoint_fname) 359 | 360 | # delete extraneous checkpoints 361 | last_keep_checkpoint = (epoch // args.keep_every) * args.keep_every 362 | for epoch, checkpoint_fname in get_checkpoint_fnames(): 363 | if epoch < last_keep_checkpoint and epoch % args.keep_every != 0: 364 | 
print(f' remove {checkpoint_fname}') 365 | os.remove(checkpoint_fname) 366 | 367 | print('BEGIN EVALUATION') 368 | model.eval() 369 | 370 | evaluation.evaluate_against_attacks( 371 | model, validation_attacks, val_loader, parallel=args.parallel, 372 | ) 373 | print('END EVALUATION') 374 | -------------------------------------------------------------------------------- /evaluate_common_corruptions.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import csv 4 | import argparse 5 | import copy 6 | from typing import List 7 | 8 | from torch.hub import load_state_dict_from_url 9 | from torch import Tensor 10 | from torchvision.models import AlexNet 11 | from robustness.datasets import DATASETS 12 | 13 | from perceptual_advex.utilities import add_dataset_model_arguments, \ 14 | get_dataset_model 15 | from perceptual_advex.datasets import ImageNet100C 16 | 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser( 20 | description='Common corruptions evaluation') 21 | 22 | add_dataset_model_arguments(parser, include_checkpoint=True) 23 | parser.add_argument('--batch_size', type=int, default=100) 24 | parser.add_argument('--num_batches', type=int, required=False, 25 | help='number of batches (default entire dataset)') 26 | parser.add_argument('--output', type=str, 27 | help='output CSV') 28 | 29 | args = parser.parse_args() 30 | 31 | _, model = get_dataset_model(args) 32 | dataset_cls = DATASETS[args.dataset] 33 | 34 | alexnet_args = copy.deepcopy(args) 35 | alexnet_args.arch = 'alexnet' 36 | alexnet_args.checkpoint = None 37 | if args.dataset == 'cifar10c': 38 | alexnet_checkpoint_fname = 'data/checkpoints/alexnet_cifar.pt' 39 | elif args.dataset == 'imagenet100c': 40 | alexnet_checkpoint_fname = 'data/checkpoints/alexnet_imagenet100.pt' 41 | else: 42 | raise ValueError(f'Invalid dataset "{args.dataset}"') 43 | _, alexnet = get_dataset_model( 44 | alexnet_args, checkpoint_fname=alexnet_checkpoint_fname) 45 | 46 | model.eval() 47 | alexnet.eval() 48 | if torch.cuda.is_available(): 49 | model.cuda() 50 | alexnet.cuda() 51 | 52 | with open(args.output, 'w') as output_file: 53 | output_csv = csv.writer(output_file) 54 | output_csv.writerow([ 55 | 'corruption_type', 'severity', 'model_error', 'alexnet_error', 56 | ]) 57 | 58 | for corruption_type in [ 59 | 'gaussian_noise', 'shot_noise', 'impulse_noise', 60 | 'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur', 61 | 'snow', 'frost', 'fog', 'brightness', 62 | 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression', 63 | ]: 64 | model_errors: List[float] = [] 65 | alexnet_errors: List[float] = [] 66 | 67 | for severity in range(1, 6): 68 | print(f'CORRUPTION\t{corruption_type}\tseverity = {severity}') 69 | 70 | dataset = dataset_cls( 71 | args.dataset_path, corruption_type, severity) 72 | _, val_loader = dataset.make_loaders( 73 | 4, args.batch_size, only_val=True) 74 | 75 | batches_correct: List[Tensor] = [] 76 | alexnet_batches_correct: List[Tensor] = [] 77 | for batch_index, (inputs, labels) in enumerate(val_loader): 78 | if ( 79 | args.num_batches is not None and 80 | batch_index >= args.num_batches 81 | ): 82 | break 83 | 84 | if torch.cuda.is_available(): 85 | inputs = inputs.cuda() 86 | labels = labels.cuda() 87 | 88 | with torch.no_grad(): 89 | logits = model(inputs) 90 | batches_correct.append( 91 | (logits.argmax(1) == labels).detach()) 92 | 93 | alexnet_logits = alexnet(inputs) 94 | alexnet_batches_correct.append( 95 | (alexnet_logits.argmax(1) == 
labels).detach()) 96 | 97 | accuracy = torch.cat(batches_correct).float().mean().item() 98 | alexnet_accuracy = torch.cat( 99 | alexnet_batches_correct).float().mean().item() 100 | print('OVERALL\t', 101 | f'accuracy = {accuracy * 100:.1f}', 102 | f'AlexNet accuracy = {alexnet_accuracy * 100:.1f}', 103 | sep='\t') 104 | 105 | model_errors.append(1 - accuracy) 106 | alexnet_errors.append(1 - alexnet_accuracy) 107 | 108 | output_csv.writerow([ 109 | corruption_type, severity, 110 | 1 - accuracy, 1 - alexnet_accuracy, 111 | ]) 112 | 113 | ce = sum(model_errors) / sum(alexnet_errors) 114 | output_csv.writerow([corruption_type, 'ce', ce, 1]) 115 | -------------------------------------------------------------------------------- /evaluate_distances.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | import torch 5 | import csv 6 | from torch import nn 7 | from typing import List, Tuple 8 | from typing_extensions import Literal 9 | 10 | from perceptual_advex.utilities import add_dataset_model_arguments, \ 11 | get_dataset_model 12 | from perceptual_advex.distances import LPIPSDistance, LinfDistance, SSIM, \ 13 | L2Distance 14 | from perceptual_advex.models import FeatureModel 15 | from perceptual_advex.perceptual_attacks import get_lpips_model 16 | from perceptual_advex.perceptual_attacks import * 17 | from perceptual_advex.attacks import * 18 | 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser( 22 | description='Distance measure analysis') 23 | 24 | add_dataset_model_arguments(parser, include_checkpoint=True) 25 | parser.add_argument('--batch_size', type=int, default=50) 26 | parser.add_argument('--num_batches', type=int, required=False, 27 | help='number of batches (default entire dataset)') 28 | parser.add_argument('--per_example', action='store_true', default=False, 29 | help='output per-example accuracy') 30 | parser.add_argument('--output', type=str, help='output CSV') 31 | parser.add_argument('attacks', metavar='attack', type=str, nargs='+', 32 | help='attack names') 33 | 34 | args = parser.parse_args() 35 | 36 | dist_models: List[Tuple[str, nn.Module]] = [ 37 | ('l2', L2Distance()), 38 | ('linf', LinfDistance()), 39 | ('ssim', SSIM()), 40 | ] 41 | 42 | dataset, model = get_dataset_model(args) 43 | if not isinstance(model, FeatureModel): 44 | raise TypeError('model must be a FeatureModel') 45 | dist_models.append(('lpips_self', LPIPSDistance(model))) 46 | 47 | alexnet_model_name: Literal['alexnet_cifar', 'alexnet'] 48 | if args.dataset.startswith('cifar'): 49 | alexnet_model_name = 'alexnet_cifar' 50 | else: 51 | alexnet_model_name = 'alexnet' 52 | dist_models.append(( 53 | 'lpips_alexnet', 54 | LPIPSDistance(get_lpips_model(alexnet_model_name, model)), 55 | )) 56 | 57 | for _, dist_model in dist_models: 58 | dist_model.eval() 59 | if torch.cuda.is_available(): 60 | dist_model.cuda() 61 | 62 | _, val_loader = dataset.make_loaders(1, args.batch_size, only_val=True) 63 | 64 | model.eval() 65 | if torch.cuda.is_available(): 66 | model.cuda() 67 | 68 | attack_names: List[str] = args.attacks 69 | 70 | with open(args.output, 'w') as out_file: 71 | out_csv = csv.writer(out_file) 72 | out_csv.writerow([ 73 | attack_name for attack_name in attack_names 74 | for _ in dist_models 75 | ]) 76 | out_csv.writerow([ 77 | dist_model_name for _ in attack_names 78 | for dist_model_name, _ in dist_models 79 | ]) 80 | 81 | for batch_index, (inputs, labels) in enumerate(val_loader): 82 | if ( 83 | 
args.num_batches is not None and 84 | batch_index >= args.num_batches 85 | ): 86 | break 87 | 88 | print(f'BATCH\t{batch_index:05d}') 89 | 90 | if torch.cuda.is_available(): 91 | inputs = inputs.cuda() 92 | labels = labels.cuda() 93 | 94 | batch_distances = np.zeros(( 95 | inputs.shape[0], 96 | len(attack_names) * len(dist_models), 97 | )) 98 | 99 | for attack_index, attack_name in enumerate(attack_names): 100 | print(f'ATTACK {attack_name}') 101 | attack = eval(attack_name) 102 | 103 | adv_inputs = attack(inputs, labels) 104 | with torch.no_grad(): 105 | for dist_model_index, (_, dist_model) in \ 106 | enumerate(dist_models): 107 | batch_distances[ 108 | :, 109 | attack_index * len(dist_models) + dist_model_index 110 | ] = dist_model( 111 | inputs, 112 | adv_inputs, 113 | ).detach().cpu().numpy() 114 | 115 | for row in batch_distances: 116 | out_csv.writerow(row.tolist()) 117 | -------------------------------------------------------------------------------- /evaluate_trained_model.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Dict, List 3 | import torch 4 | import csv 5 | import argparse 6 | 7 | from perceptual_advex.utilities import add_dataset_model_arguments, \ 8 | get_dataset_model 9 | from perceptual_advex.attacks import * 10 | 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser( 14 | description='Adversarial training evaluation') 15 | 16 | add_dataset_model_arguments(parser, include_checkpoint=True) 17 | parser.add_argument('attacks', metavar='attack', type=str, nargs='+', 18 | help='attack names') 19 | parser.add_argument('--batch_size', type=int, default=100, 20 | help='number of examples/minibatch') 21 | parser.add_argument('--parallel', type=int, default=1, 22 | help='number of GPUs to train on') 23 | parser.add_argument('--num_batches', type=int, required=False, 24 | help='number of batches (default entire dataset)') 25 | parser.add_argument('--per_example', action='store_true', default=False, 26 | help='output per-example accuracy') 27 | parser.add_argument('--output', type=str, help='output CSV') 28 | 29 | args = parser.parse_args() 30 | 31 | dataset, model = get_dataset_model(args) 32 | _, val_loader = dataset.make_loaders(1, args.batch_size, only_val=True) 33 | 34 | model.eval() 35 | if torch.cuda.is_available(): 36 | model.cuda() 37 | 38 | attack_names: List[str] = args.attacks 39 | attacks = [eval(attack_name) for attack_name in attack_names] 40 | 41 | # Parallelize 42 | if torch.cuda.is_available(): 43 | device_ids = list(range(args.parallel)) 44 | model = nn.DataParallel(model, device_ids) 45 | attacks = [nn.DataParallel(attack, device_ids) for attack in attacks] 46 | 47 | batches_correct: Dict[str, List[torch.Tensor]] = \ 48 | {attack_name: [] for attack_name in attack_names} 49 | 50 | for batch_index, (inputs, labels) in enumerate(val_loader): 51 | print(f'BATCH {batch_index:05d}') 52 | 53 | if ( 54 | args.num_batches is not None and 55 | batch_index >= args.num_batches 56 | ): 57 | break 58 | 59 | if torch.cuda.is_available(): 60 | inputs = inputs.cuda() 61 | labels = labels.cuda() 62 | 63 | for attack_name, attack in zip(attack_names, attacks): 64 | adv_inputs = attack(inputs, labels) 65 | with torch.no_grad(): 66 | adv_logits = model(adv_inputs) 67 | batch_correct = (adv_logits.argmax(1) == labels).detach() 68 | 69 | batch_accuracy = batch_correct.float().mean().item() 70 | print(f'ATTACK {attack_name}', 71 | f'accuracy = {batch_accuracy * 100:.1f}', 72 | sep='\t') 73 | 
batches_correct[attack_name].append(batch_correct) 74 | 75 | print('OVERALL') 76 | accuracies = [] 77 | attacks_correct: Dict[str, torch.Tensor] = {} 78 | for attack_name in attack_names: 79 | attacks_correct[attack_name] = torch.cat(batches_correct[attack_name]) 80 | accuracy = attacks_correct[attack_name].float().mean().item() 81 | print(f'ATTACK {attack_name}', 82 | f'accuracy = {accuracy * 100:.1f}', 83 | sep='\t') 84 | accuracies.append(accuracy) 85 | 86 | with open(args.output, 'w') as out_file: 87 | out_csv = csv.writer(out_file) 88 | out_csv.writerow(attack_names) 89 | if args.per_example: 90 | for example_correct in zip(*[ 91 | attacks_correct[attack_name] for attack_name in attack_names 92 | ]): 93 | out_csv.writerow( 94 | [int(attack_correct.item()) for attack_correct 95 | in example_correct]) 96 | out_csv.writerow(accuracies) 97 | -------------------------------------------------------------------------------- /generate_examples.py: -------------------------------------------------------------------------------- 1 | """ 2 | Scripts that generates a number of adversarial examples for each of several 3 | attacks against a particular network. 4 | """ 5 | 6 | import torch 7 | import argparse 8 | import numpy as np 9 | import itertools 10 | from torchvision.utils import save_image 11 | 12 | from perceptual_advex.attacks import * 13 | from perceptual_advex.utilities import add_dataset_model_arguments, \ 14 | get_dataset_model 15 | 16 | 17 | def tile_images(images): 18 | """ 19 | Given a numpy array of shape r x c x C x W x H, where r and c are rows and 20 | columns in a grid of images, tiles the images into a numpy array 21 | C x (W * c) x (H * r). 22 | """ 23 | 24 | return np.concatenate(np.concatenate(images, axis=2), axis=2) 25 | 26 | 27 | if __name__ == '__main__': 28 | parser = argparse.ArgumentParser( 29 | description='Adversarial example generation') 30 | 31 | parser.add_argument('attacks', metavar='attack', type=str, nargs='+', 32 | help='attack names') 33 | 34 | add_dataset_model_arguments(parser, include_checkpoint=True) 35 | 36 | parser.add_argument('--batch_size', type=int, default=16, 37 | help='number of examples to generate ' 38 | 'adversarial examples for') 39 | parser.add_argument('--batch_index', type=int, default=0, 40 | help='batch index to generate adversarial examples ' 41 | 'for') 42 | parser.add_argument('--shuffle', default=False, action='store_true', 43 | help="Shuffle dataset before choosing a batch") 44 | parser.add_argument('--layout', type=str, default='vertical', 45 | help='lay out the same images on the same row ' 46 | '(horizontal) or column (vertical)') 47 | parser.add_argument('--only_successful', action='store_true', 48 | default=False, 49 | help='only show images where adversarial example ' 50 | 'was generated for all attacks') 51 | parser.add_argument('--output', type=str, 52 | help='output PNG file') 53 | parser.add_argument('--random_seed', type=int, default=None, 54 | help='seed for the Torch RNG') 55 | 56 | args = parser.parse_args() 57 | 58 | if args.random_seed is not None: 59 | torch.manual_seed(args.random_seed) 60 | 61 | dataset, model = get_dataset_model(args) 62 | _, val_loader = dataset.make_loaders(1, args.batch_size, only_val=True, 63 | shuffle_val=args.shuffle) 64 | model.eval() 65 | 66 | inputs, labels = next(itertools.islice( 67 | val_loader, args.batch_index, None)) 68 | if torch.cuda.is_available(): 69 | model.cuda() 70 | inputs = inputs.cuda() 71 | labels = labels.cuda() 72 | N, C, H, W = inputs.size() 73 | 74 | attacks = 
[None] + args.attacks 75 | out_advs = np.ones((len(attacks), N, C, H, W)) 76 | out_diffs = np.ones_like(out_advs) 77 | 78 | orig_labels = model(inputs).argmax(1) 79 | all_successful = np.ones(N, dtype=bool) 80 | all_labels = np.zeros((len(attacks), len(orig_labels)), dtype=int) 81 | all_labels[0] = orig_labels.cpu().detach().numpy() 82 | 83 | for attack_index, attack_name in enumerate(attacks): 84 | print(f'generating examples for {attack_name or "no"} attack') 85 | 86 | attack_params = None 87 | if attack_name is None: 88 | out_advs[attack_index] = inputs.cpu().numpy() 89 | out_diffs[attack_index] = 0 90 | else: 91 | attack = eval(attack_name) 92 | 93 | advs = attack(inputs, labels) 94 | adv_labels = model(advs).argmax(1) 95 | successful = (adv_labels != labels).cpu().detach().numpy() \ 96 | .astype(bool) 97 | 98 | print(f'accuracy = {np.mean(1 - successful) * 100:.1f}') 99 | diff = (advs - inputs).cpu().detach().numpy() 100 | advs = advs.cpu().detach().numpy() 101 | out_advs[attack_index, successful] = advs[successful] 102 | out_diffs[attack_index, successful] = diff[successful] 103 | 104 | all_labels[attack_index] = adv_labels.cpu().detach().numpy() 105 | 106 | all_successful[(adv_labels == orig_labels).cpu().detach().numpy() 107 | .astype(bool)] = False 108 | # mark examples that changed by less than 1/1000 as not successful 109 | all_successful[np.all(np.abs(diff) < 1e-3, 110 | axis=(1, 2, 3))] = False 111 | 112 | if args.only_successful: 113 | out_advs = out_advs[:, all_successful] 114 | out_diffs = out_diffs[:, all_successful] 115 | all_labels = all_labels[:, all_successful] 116 | 117 | for image_index in range(all_labels.shape[1]): 118 | print( 119 | f'image {image_index} labels:', 120 | ' '.join(map(str, all_labels[:, image_index])), 121 | ) 122 | 123 | out_diffs = np.clip(out_diffs * 3 + 0.5, 0, 1) 124 | 125 | combined_image: np.ndarray 126 | if args.layout == 'vertical': 127 | if len(attacks) == 2: 128 | combined_grid = np.concatenate([ 129 | out_advs, 130 | np.clip(out_diffs[1:2], 0, 1), 131 | ], axis=0) 132 | else: 133 | combined_grid = np.concatenate([ 134 | out_advs, 135 | np.ones((len(attacks), 1, C, H, W)), 136 | out_diffs, 137 | ], axis=1) 138 | combined_image = tile_images(combined_grid) 139 | elif args.layout == 'horizontal_alternate': 140 | rows = [] 141 | for i in range(out_advs.shape[1]): 142 | row = [] 143 | row.append(out_advs[0, i]) 144 | for adv, diff in zip(out_advs[1:, i], out_diffs[1:, i]): 145 | row.append(np.ones((C, H, W // 4))) 146 | row.append(adv) 147 | row.append(diff) 148 | rows.append(np.concatenate(row, axis=2)) 149 | combined_image = np.concatenate(rows, axis=1) 150 | elif args.layout == 'vertical_alternate': 151 | rows = [] 152 | for i in range(out_advs.shape[0]): 153 | row = [] 154 | for adv, diff in zip(out_advs[i], out_diffs[i]): 155 | row.append(np.ones((C, H, W // 4))) 156 | row.append(adv) 157 | row.append(diff) 158 | rows.append(np.concatenate(row[1:], axis=2)) 159 | combined_image = np.concatenate(rows, axis=1) 160 | else: 161 | raise ValueError(f'Unknown layout "{args.layout}"') 162 | save_image(torch.from_numpy(combined_image), args.output) 163 | -------------------------------------------------------------------------------- /getting_started.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Getting started with Perceptual Adversarial Robustness\n", 8 | "\n", 9 | "This notebook contains examples of how 
to load a pretrained model, measure LPIPS distance, and construct perceptual and non-perceptual attacks.\n", 10 | "\n", 11 | "If you are running this notebook in Google Colab, it is recommended to use a GPU. You can enable GPU acceleration by going to **Runtime** > **Change runtime type** and selecting **GPU** from the dropdown." 12 | ] 13 | }, 14 | { 15 | "source": [ 16 | "First, make sure you have installed the `perceptual_advex` package, either from GitHub or PyPI:" 17 | ], 18 | "cell_type": "markdown", 19 | "metadata": {} 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "try:\n", 28 | " import perceptual_advex\n", 29 | "except ImportError:\n", 30 | " !pip install perceptual-advex" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "## Loading a pretrained model\n", 38 | "First, let's load the CIFAR-10 dataset along with a pretrained model. The following code will download a model checkpoint and load it, but you can change the `checkpoint_name` parameter to load a different checkpoint. The checkpoint we're downloading here is trained against $L_2$ adversarial attacks with bound $\\epsilon = 1$." 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "import subprocess\n", 48 | "import os\n", 49 | "\n", 50 | "if not os.path.exists('data/checkpoints/cifar_pgd_l2_1.pt'):\n", 51 | " !mkdir -p data/checkpoints\n", 52 | " !curl -o data/checkpoints/cifar_pgd_l2_1.pt https://perceptual-advex.s3.us-east-2.amazonaws.com/cifar_pgd_l2_1_cpu.pt" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "from perceptual_advex.utilities import get_dataset_model\n", 62 | "\n", 63 | "dataset, model = get_dataset_model(\n", 64 | " dataset='cifar',\n", 65 | " arch='resnet50',\n", 66 | " checkpoint_fname='data/checkpoints/cifar_pgd_l2_1.pt',\n", 67 | ")" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "If you want to experiment with ImageNet-100 instead, just change the above to\n", 75 | "\n", 76 | " dataset, model = get_dataset_model(\n", 77 | " dataset='imagenet100',\n", 78 | " # Change this to where ImageNet is downloaded.\n", 79 | " dataset_path='/path/to/imagenet',\n", 80 | " arch='resnet50',\n", 81 | " # Change this to a pretrained checkpoint path.\n", 82 | " checkpoint_fname='/path/to/checkpoint',\n", 83 | " )" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "## Viewing images in the dataset\n", 91 | "\n", 92 | "Now that we have a dataset and model loaded, we can view some images in the dataset." 
93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "import torchvision\n", 102 | "import numpy as np\n", 103 | "import matplotlib.pyplot as plt\n", 104 | "\n", 105 | "# We'll use this helper function to show images in the Jupyter notebook.\n", 106 | "%matplotlib inline\n", 107 | "def show(img):\n", 108 | " if len(img.size()) == 4:\n", 109 | " img = torchvision.utils.make_grid(img, nrow=10, padding=0)\n", 110 | " npimg = img.detach().cpu().numpy()\n", 111 | " plt.figure(figsize=(18,16), dpi=80, facecolor='w', edgecolor='k')\n", 112 | " plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "import torch\n", 122 | "\n", 123 | "# Create a validation set loader.\n", 124 | "batch_size = 10\n", 125 | "_, val_loader = dataset.make_loaders(1, batch_size, only_val=True)\n", 126 | "\n", 127 | "# Get a batch from the validation set.\n", 128 | "inputs, labels = next(iter(val_loader))\n", 129 | "\n", 130 | "# If we have a GPU, let's convert everything to CUDA so it's quicker.\n", 131 | "if torch.cuda.is_available():\n", 132 | " inputs = inputs.cuda()\n", 133 | " labels = labels.cuda()\n", 134 | " model.cuda()\n", 135 | "\n", 136 | "# Show the batch!\n", 137 | "show(inputs)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "We can also test the accuracy of the model on this set of inputs by comparing the model output to the ground-truth labels." 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "pred_labels = model(inputs).argmax(1)\n", 154 | "print('Natural accuracy is', (labels == pred_labels).float().mean().item())" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "If the natural accuracy is very low on this batch of images, you might want to load a new set by re-running the two cells above." 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "## Generating perceptual adversarial examples\n", 169 | "\n", 170 | "Next, let's generate some perceptual adversarial examples using Lagrange perceptual attack (LPA) with AlexNet bound $\\epsilon = 0.5$. Other perceptual attacks (PPGD and Fast-LPA) are also found in the `perceptual_advex.perceptual_attacks` module, and they mostly share the same options." 
171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "from perceptual_advex.perceptual_attacks import LagrangePerceptualAttack\n", 180 | "\n", 181 | "attack = LagrangePerceptualAttack(\n", 182 | " model,\n", 183 | " num_iterations=10,\n", 184 | " # The LPIPS distance bound on the adversarial examples.\n", 185 | " bound=0.5,\n", 186 | " # The model to use for calculate LPIPS; here we use AlexNet.\n", 187 | " # You can also use 'self' to perform a self-bounded attack.\n", 188 | " lpips_model='alexnet_cifar',\n", 189 | ")\n", 190 | "adv_inputs = attack(inputs, labels)\n", 191 | "\n", 192 | "# Show the adversarial examples.\n", 193 | "show(adv_inputs)\n", 194 | "\n", 195 | "# Show the magnified difference between the adversarial examples and unperturbed inputs.\n", 196 | "show((adv_inputs - inputs) * 5 + 0.5)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": {}, 202 | "source": [ 203 | "Note that while the perturbations are sometimes large, the adversarial examples are still recognizable as the original image and do not appear too different perceptually.\n", 204 | "\n", 205 | "We can calculate the accuracy of the classifier on the adversarial examples:" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "adv_pred_labels = model(adv_inputs).argmax(1)\n", 215 | "print('Adversarial accuracy is', (labels == adv_pred_labels).float().mean().item())" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "Even though this network has been trained to be robust to $L_2$ perturbations, there are still imperceptible perturbations found using LPA that fool it almost every time!" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "## Calculating LPIPS distance\n", 230 | "\n", 231 | "Next, let's calculate the LPIPS distance between the adversarial examples we generated and the original inputs:" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "from perceptual_advex.distances import LPIPSDistance\n", 241 | "from perceptual_advex.perceptual_attacks import get_lpips_model\n", 242 | "\n", 243 | "# LPIPS is based on the activations of a classifier, so we need to first\n", 244 | "# load the classifier we'll use.\n", 245 | "lpips_model = get_lpips_model('alexnet_cifar')\n", 246 | "if torch.cuda.is_available():\n", 247 | " lpips_model.cuda()\n", 248 | "\n", 249 | "# Now we can define a distance based on the model we loaded.\n", 250 | "# We could also do LPIPSDistance(model) for self-bounded LPIPS.\n", 251 | "lpips_distance = LPIPSDistance(lpips_model)\n", 252 | "\n", 253 | "# Finally, let's calculate the distance between the inputs and adversarial examples.\n", 254 | "print(lpips_distance(inputs, adv_inputs))" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "Note that all the distances are within the bound of 0.5! At this bound, the adversarial perturbations should all have a similar level of perceptibility to the human eye.\n", 262 | "\n", 263 | "Other distance measures between images are also defined in the `perceptual_advex.distances` package, including $L_\\infty$, $L_2$, and SSIM." 
264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "## Generating non-perceptual adversarial examples\n", 271 | "\n", 272 | "The `perceptual_advex` package also includes code to perform attacks based on other, narrower threat models like $L_\\infty$ or $L_2$ distance and spatial transformations. The non-perceptual attacks are all in the `perceptual_advex.attacks` module. First, let's try an $L_2$ attack:" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "from perceptual_advex.attacks import L2Attack\n", 282 | "\n", 283 | "attack = L2Attack(\n", 284 | " model,\n", 285 | " 'cifar',\n", 286 | " # The bound is divided by 255, so this is equivalent to eps=1.\n", 287 | " bound=255,\n", 288 | ")\n", 289 | "l2_adv_inputs = attack(inputs, labels)\n", 290 | "\n", 291 | "show(l2_adv_inputs)\n", 292 | "show((l2_adv_inputs - inputs) * 5 + 0.5)\n", 293 | "\n", 294 | "l2_adv_pred_labels = model(l2_adv_inputs).argmax(1)\n", 295 | "print('L2 adversarial accuracy is', (labels == l2_adv_pred_labels).float().mean().item())" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "Here's an example of a spatial attack (StAdv):" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "from perceptual_advex.attacks import StAdvAttack\n", 312 | "\n", 313 | "attack = StAdvAttack(\n", 314 | " model,\n", 315 | " bound=0.02,\n", 316 | ")\n", 317 | "spatial_adv_inputs = attack(inputs, labels)\n", 318 | "\n", 319 | "show(spatial_adv_inputs)\n", 320 | "show((spatial_adv_inputs - inputs) * 5 + 0.5)\n", 321 | "\n", 322 | "spatial_adv_pred_labels = model(spatial_adv_inputs).argmax(1)\n", 323 | "print('Spatial adversarial accuracy is', (labels == spatial_adv_pred_labels).float().mean().item())" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "## Conclusion\n", 331 | "\n", 332 | "That's pretty much it for how to use the package! As a final note, here is an overview of what each module contains:\n", 333 | "\n", 334 | " * `perceptual_advex.attacks`: non-perceptual attacks (e.g. $L_2$, $L_\\infty$, spatial, recoloring, JPEG, etc.)\n", 335 | " * `perceptual_advex.datasets`: datasets (e.g. ImageNet-100, CIFAR-10, etc.)\n", 336 | " * `perceptual_advex.distances`: distance measures between images (e.g. LPIPS, SSIM, $L_2$)\n", 337 | " * `perceptual_advex.evaluation`: functions used for evaluating a trained model against attacks\n", 338 | " * `perceptual_advex.models`: classifier architectures (e.g. ResNet, AlexNet, etc.)\n", 339 | " * `perceptual_advex.perceptual_attacks`: perceptual attacks (e.g. 
LPA, PPGD, Fast-LPA)\n", 340 | " * `perceptual_advex.trades_wrn`: classifier architecture used by the TRADES defense (Zhang et al.)\n", 341 | " * `perceptual_advex.utilites`: various utilites, including `get_dataset_model` function to load a dataset and model" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [] 350 | } 351 | ], 352 | "metadata": { 353 | "kernelspec": { 354 | "display_name": "Python 3", 355 | "language": "python", 356 | "name": "python3" 357 | }, 358 | "language_info": { 359 | "codemirror_mode": { 360 | "name": "ipython", 361 | "version": 3 362 | }, 363 | "file_extension": ".py", 364 | "mimetype": "text/x-python", 365 | "name": "python", 366 | "nbconvert_exporter": "python", 367 | "pygments_lexer": "ipython3", 368 | "version": "3.7.9-final" 369 | } 370 | }, 371 | "nbformat": 4, 372 | "nbformat_minor": 2 373 | } -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.7 3 | warn_return_any = True 4 | warn_redundant_casts = True 5 | warn_unused_ignores = True 6 | warn_unused_configs = True 7 | check_untyped_defs = True 8 | 9 | [mypy-torchvision.*] 10 | ignore_missing_imports = True 11 | 12 | [mypy-tensorboardX.*] 13 | ignore_missing_imports = True 14 | 15 | [mypy-robustness.*] 16 | ignore_missing_imports = True 17 | 18 | [mypy-advex_uar.*] 19 | ignore_missing_imports = True 20 | 21 | [mypy-autoattack.*] 22 | ignore_missing_imports = True 23 | 24 | [mypy-bird_or_bicycle.*] 25 | ignore_missing_imports = True 26 | 27 | [mypy-unrestricted_advex.*] 28 | ignore_missing_imports = True 29 | 30 | [mypy-recoloradv.*] 31 | ignore_missing_imports = True 32 | 33 | [mypy-tqdm.*] 34 | ignore_missing_imports = True 35 | 36 | [mypy-PIL.*] 37 | ignore_missing_imports = True 38 | 39 | [mypy-statsmodels.*] 40 | ignore_missing_imports = True 41 | 42 | [mypy-scipy.*] 43 | ignore_missing_imports = True 44 | 45 | [mypy-sklearn.*] 46 | ignore_missing_imports = True 47 | 48 | [mypy-boto3] 49 | ignore_missing_imports = True 50 | -------------------------------------------------------------------------------- /perceptual_advex/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cassidylaidlaw/perceptual-advex/65c1bc3aabe1b9a475ee0edb7606aee44896b685/perceptual_advex/__init__.py -------------------------------------------------------------------------------- /perceptual_advex/attacks.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import functools 4 | from torch import nn 5 | from operator import mul 6 | from torch import optim 7 | from advex_uar.common.pyt_common import get_attack as get_uar_attack 8 | from advex_uar.attacks.attacks import InverseImagenetTransform 9 | 10 | from .perceptual_attacks import * 11 | from .utilities import LambdaLayer 12 | from . 
import utilities 13 | 14 | # mister_ed 15 | from recoloradv.mister_ed import loss_functions as lf 16 | from recoloradv.mister_ed import adversarial_training as advtrain 17 | from recoloradv.mister_ed import adversarial_perturbations as ap 18 | from recoloradv.mister_ed import adversarial_attacks as aa 19 | from recoloradv.mister_ed import spatial_transformers as st 20 | 21 | # ReColorAdv 22 | from recoloradv import perturbations as pt 23 | from recoloradv import color_transformers as ct 24 | from recoloradv import color_spaces as cs 25 | 26 | 27 | PGD_ITERS = 20 28 | DATASET_NUM_CLASSES = { 29 | 'cifar': 10, 30 | 'imagenet100': 100, 31 | 'imagenet': 1000, 32 | 'bird_or_bicycle': 2, 33 | } 34 | 35 | 36 | class NoAttack(nn.Module): 37 | """ 38 | Attack that does nothing. 39 | """ 40 | 41 | def __init__(self, model=None): 42 | super().__init__() 43 | self.model = model 44 | 45 | def forward(self, inputs, labels): 46 | return inputs 47 | 48 | 49 | class MisterEdAttack(nn.Module): 50 | """ 51 | Base class for attacks using the mister_ed library. 52 | """ 53 | 54 | def __init__(self, model, threat_model, randomize=False, 55 | perturbation_norm_loss=False, lr=0.001, random_targets=False, 56 | num_classes=None, **kwargs): 57 | super().__init__() 58 | 59 | self.model = model 60 | self.normalizer = nn.Identity() 61 | 62 | self.threat_model = threat_model 63 | self.randomize = randomize 64 | self.perturbation_norm_loss = perturbation_norm_loss 65 | self.attack_kwargs = kwargs 66 | self.lr = lr 67 | self.random_targets = random_targets 68 | self.num_classes = num_classes 69 | 70 | self.attack = None 71 | 72 | def _setup_attack(self): 73 | cw_loss = lf.CWLossF6(self.model, self.normalizer, kappa=float('inf')) 74 | if self.random_targets: 75 | cw_loss.forward = functools.partial(cw_loss.forward, targeted=True) 76 | perturbation_loss = lf.PerturbationNormLoss(lp=2) 77 | pert_factor = 0.0 78 | if self.perturbation_norm_loss is True: 79 | pert_factor = 0.05 80 | elif type(self.perturbation_norm_loss) is float: 81 | pert_factor = self.perturbation_norm_loss 82 | adv_loss = lf.RegularizedLoss({ 83 | 'cw': cw_loss, 84 | 'pert': perturbation_loss, 85 | }, { 86 | 'cw': 1.0, 87 | 'pert': pert_factor, 88 | }, negate=True) 89 | 90 | self.pgd_attack = aa.PGD(self.model, self.normalizer, 91 | self.threat_model(), adv_loss) 92 | 93 | attack_params = { 94 | 'optimizer': optim.Adam, 95 | 'optimizer_kwargs': {'lr': self.lr}, 96 | 'signed': False, 97 | 'verbose': False, 98 | 'num_iterations': 0 if self.randomize else PGD_ITERS, 99 | 'random_init': self.randomize, 100 | } 101 | attack_params.update(self.attack_kwargs) 102 | 103 | self.attack = advtrain.AdversarialAttackParameters( 104 | self.pgd_attack, 105 | 1.0, 106 | attack_specific_params={'attack_kwargs': attack_params}, 107 | ) 108 | self.attack.set_gpu(False) 109 | 110 | def forward(self, inputs, labels): 111 | if self.attack is None: 112 | self._setup_attack() 113 | assert self.attack is not None 114 | 115 | if self.random_targets: 116 | return utilities.run_attack_with_random_targets( 117 | lambda inputs, labels: self.attack.attack(inputs, labels)[0], 118 | self.model, 119 | inputs, 120 | labels, 121 | num_classes=self.num_classes, 122 | ) 123 | else: 124 | return self.attack.attack(inputs, labels)[0] 125 | 126 | 127 | class UARModel(nn.Module): 128 | def __init__(self, model): 129 | super().__init__() 130 | self.model = model 131 | 132 | def forward(self, x): 133 | inverse_transform = InverseImagenetTransform(x.size()[-1]) 134 | return 
self.model(inverse_transform(x) / 255) 135 | 136 | 137 | class UARAttack(nn.Module): 138 | """ 139 | One of the attacks from the paper "Testing Robustness Against Unforeseen 140 | Adversaries". 141 | """ 142 | 143 | def __init__(self, model, dataset_name, attack_name, bound, 144 | num_iterations=PGD_ITERS, step=None, random_targets=False, 145 | randomize=False): 146 | super().__init__() 147 | 148 | assert randomize is False 149 | 150 | if step is None: 151 | step = bound / (num_iterations ** 0.5) 152 | 153 | self.random_targets = random_targets 154 | self.num_classes = DATASET_NUM_CLASSES[dataset_name] 155 | if ( 156 | dataset_name.startswith('imagenet') 157 | or dataset_name == 'bird_or_bicycle' 158 | ): 159 | dataset_name = 'imagenet' 160 | elif dataset_name == 'cifar': 161 | dataset_name = 'cifar-10' 162 | 163 | self.model = model 164 | self.uar_model = UARModel(model) 165 | self.attack_name = attack_name 166 | self.bound = bound 167 | self.attack_fn = get_uar_attack(dataset_name, attack_name, eps=bound, 168 | n_iters=num_iterations, 169 | step_size=step, scale_each=1) 170 | self.attack = None 171 | 172 | def threat_model_contains(self, inputs, adv_inputs): 173 | """ 174 | Returns a boolean tensor which indicates if each of the given 175 | adversarial examples given is within this attack's threat model for 176 | the given natural input. 177 | """ 178 | 179 | if self.attack_name == 'pgd_linf': 180 | dist = (inputs - adv_inputs).reshape(inputs.size()[0], -1) \ 181 | .abs().max(1)[0] * 255 182 | elif self.attack_name == 'pgd_l2': 183 | dist = ( 184 | (inputs - adv_inputs).reshape(inputs.size()[0], -1) 185 | ** 2 186 | ).sum(1).sqrt() * 255 187 | elif self.attack_name == 'fw_l1': 188 | dist = ( 189 | (inputs - adv_inputs).reshape(inputs.size()[0], -1) 190 | .abs().sum(1) 191 | * 255 / functools.reduce(mul, inputs.size()[1:]) 192 | ) 193 | else: 194 | raise NotImplementedError() 195 | 196 | return dist <= self.bound 197 | 198 | def forward(self, inputs, labels): 199 | self.uar_model.training = self.model.training 200 | 201 | if self.attack is None: 202 | self.attack = self.attack_fn() 203 | self.attack.transform = LambdaLayer(lambda x: x / 255) 204 | self.attack.inverse_transform = LambdaLayer(lambda x: x * 255) 205 | 206 | if self.random_targets: 207 | attack = lambda inputs, targets: self.attack( 208 | self.uar_model, 209 | inputs, 210 | targets, 211 | avoid_target=False, 212 | scale_eps=False, 213 | ) 214 | adv_examples = utilities.run_attack_with_random_targets( 215 | attack, self.model, inputs, labels, self.num_classes, 216 | ) 217 | else: 218 | adv_examples = self.attack(self.uar_model, inputs, labels, 219 | scale_eps=False, avoid_target=True) 220 | 221 | # Some UAR attacks produce NaNs, so try to get rid of them here. 
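        # (Working on the perturbation rather than on the adversarial image
        # directly lets any NaN entries be zeroed out element-wise, so those
        # pixels simply fall back to their clean values.)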
222 | perturbations = adv_examples - inputs 223 | perturbations[torch.isnan(perturbations)] = 0 224 | return (inputs + perturbations).detach() 225 | 226 | 227 | class LinfAttack(UARAttack): 228 | def __init__(self, model, dataset_name, bound=None, **kwargs): 229 | if bound is None: 230 | bound = { 231 | 'cifar': 8, 232 | 'imagenet100': 8, 233 | 'imagenet': 8, 234 | 'bird_or_bicycle': 16, 235 | }[dataset_name] 236 | 237 | super().__init__( 238 | model, 239 | dataset_name=dataset_name, 240 | attack_name='pgd_linf', 241 | bound=bound, 242 | **kwargs, 243 | ) 244 | 245 | 246 | class L2Attack(UARAttack): 247 | def __init__(self, model, dataset_name, bound=None, **kwargs): 248 | if bound is None: 249 | bound = { 250 | 'cifar': 255, 251 | 'imagenet100': 3 * 255, 252 | 'imagenet': 3 * 255, 253 | 'bird_or_bicycle': 10 * 255, 254 | }[dataset_name] 255 | 256 | super().__init__( 257 | model, 258 | dataset_name=dataset_name, 259 | attack_name='pgd_l2', 260 | bound=bound, 261 | **kwargs, 262 | ) 263 | 264 | 265 | class L1Attack(UARAttack): 266 | def __init__(self, model, dataset_name, bound=None, **kwargs): 267 | if bound is None: 268 | bound = { 269 | 'cifar': 0.5078125, 270 | 'imagenet100': 1.016422, 271 | 'imagenet': 1.016422, 272 | 'bird_or_bicycle': 1.016422, 273 | }[dataset_name] 274 | 275 | super().__init__( 276 | model, 277 | dataset_name=dataset_name, 278 | attack_name='fw_l1', 279 | bound=bound, 280 | **kwargs, 281 | ) 282 | 283 | 284 | class JPEGLinfAttack(UARAttack): 285 | def __init__(self, model, dataset_name, bound=None, **kwargs): 286 | if bound is None: 287 | bound = { 288 | 'cifar': 0.25, 289 | 'imagenet100': 0.5, 290 | 'imagenet': 0.5, 291 | 'bird_or_bicycle': 0.5, 292 | }[dataset_name] 293 | 294 | super().__init__( 295 | model, 296 | dataset_name=dataset_name, 297 | attack_name='jpeg_linf', 298 | bound=bound, 299 | **kwargs, 300 | ) 301 | 302 | 303 | class FogAttack(UARAttack): 304 | def __init__(self, model, dataset_name, bound=512, **kwargs): 305 | super().__init__( 306 | model, 307 | dataset_name=dataset_name, 308 | attack_name='fog', 309 | bound=bound, 310 | **kwargs, 311 | ) 312 | 313 | 314 | class StAdvAttack(MisterEdAttack): 315 | def __init__(self, model, bound=0.05, **kwargs): 316 | kwargs.setdefault('lr', 0.01) 317 | super().__init__( 318 | model, 319 | threat_model=lambda: ap.ThreatModel(ap.ParameterizedXformAdv, { 320 | 'lp_style': 'inf', 321 | 'lp_bound': bound, 322 | 'xform_class': st.FullSpatial, 323 | 'use_stadv': True, 324 | }), 325 | perturbation_norm_loss=0.0025 / bound, 326 | **kwargs, 327 | ) 328 | 329 | 330 | class ReColorAdvAttack(MisterEdAttack): 331 | def __init__(self, model, bound=0.06, **kwargs): 332 | super().__init__( 333 | model, 334 | threat_model=lambda: ap.ThreatModel(pt.ReColorAdv, { 335 | 'xform_class': ct.FullSpatial, 336 | 'cspace': cs.CIELUVColorSpace(), 337 | 'lp_style': 'inf', 338 | 'lp_bound': bound, 339 | 'xform_params': { 340 | 'resolution_x': 16, 341 | 'resolution_y': 32, 342 | 'resolution_z': 32, 343 | }, 344 | 'use_smooth_loss': True, 345 | }), 346 | perturbation_norm_loss=0.0036 / bound, 347 | **kwargs, 348 | ) 349 | 350 | 351 | class AutoAttack(nn.Module): 352 | def __init__(self, model, **kwargs): 353 | super().__init__() 354 | 355 | kwargs.setdefault('verbose', False) 356 | self.model = model 357 | self.kwargs = kwargs 358 | self.attack = None 359 | 360 | def forward(self, inputs, labels): 361 | # Necessary to initialize attack here because for parallelization 362 | # across multiple GPUs. 
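        # (Constructing the attack lazily with device=inputs.device ties it to
        # whichever GPU this copy of the module receives its batch on.)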
363 | if self.attack is None: 364 | try: 365 | import autoattack 366 | except ImportError: 367 | raise RuntimeError( 368 | 'Error: unable to import autoattack. Please install the ' 369 | 'package by running ' 370 | '"pip install git+git://github.com/fra31/auto-attack#egg=autoattack".' 371 | ) 372 | self.attack = autoattack.AutoAttack( 373 | self.model, device=inputs.device, **self.kwargs) 374 | 375 | return self.attack.run_standard_evaluation(inputs, labels) 376 | 377 | 378 | class AutoLinfAttack(AutoAttack): 379 | def __init__(self, model, dataset_name, bound=None, **kwargs): 380 | if bound is None: 381 | bound = { 382 | 'cifar': 8/255, 383 | 'imagenet100': 8/255, 384 | 'imagenet': 8/255, 385 | 'bird_or_bicycle': 16/255, 386 | }[dataset_name] 387 | 388 | super().__init__( 389 | model, 390 | norm='Linf', 391 | eps=bound, 392 | **kwargs, 393 | ) 394 | 395 | 396 | class AutoL2Attack(AutoAttack): 397 | def __init__(self, model, dataset_name, bound=None, **kwargs): 398 | if bound is None: 399 | bound = { 400 | 'cifar': 1, 401 | 'imagenet100': 3, 402 | 'imagenet': 3, 403 | 'bird_or_bicycle': 10, 404 | }[dataset_name] 405 | 406 | super().__init__( 407 | model, 408 | norm='L2', 409 | eps=bound, 410 | **kwargs, 411 | ) -------------------------------------------------------------------------------- /perceptual_advex/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tempfile 3 | import os 4 | import numpy as np 5 | 6 | from torchvision.datasets import CIFAR10 7 | 8 | from robustness.datasets import CIFAR, DATASETS, DataSet, CustomImageNet 9 | from robustness.data_augmentation import TRAIN_TRANSFORMS_IMAGENET, \ 10 | TEST_TRANSFORMS_IMAGENET 11 | from robustness import data_augmentation 12 | from torchvision.datasets.vision import VisionDataset 13 | 14 | 15 | class ImageNet100(CustomImageNet): 16 | def __init__(self, data_path, **kwargs): 17 | 18 | super().__init__( 19 | data_path=data_path, 20 | custom_grouping=[[label] for label in range(0, 1000, 10)], 21 | **kwargs, 22 | ) 23 | 24 | 25 | class ImageNet100A(CustomImageNet): 26 | def __init__(self, data_path, **kwargs): 27 | super().__init__( 28 | data_path=data_path, 29 | custom_grouping=[ 30 | [], 31 | [], 32 | [], 33 | [8], 34 | [], 35 | [13], 36 | [], 37 | [15], 38 | [], 39 | [20], 40 | [], 41 | [28], 42 | [], 43 | [32], 44 | [], 45 | [36], 46 | [], 47 | [], 48 | [], 49 | [], 50 | [], 51 | [], 52 | [], 53 | [], 54 | [], 55 | [], 56 | [], 57 | [], 58 | [], 59 | [], 60 | [], 61 | [53], 62 | [], 63 | [64], 64 | [], 65 | [], 66 | [], 67 | [], 68 | [], 69 | [], 70 | [75], 71 | [], 72 | [83], 73 | [86], 74 | [], 75 | [], 76 | [], 77 | [94], 78 | [], 79 | [], 80 | [], 81 | [], 82 | [], 83 | [104], 84 | [], 85 | [], 86 | [], 87 | [], 88 | [], 89 | [], 90 | [], 91 | [], 92 | [], 93 | [], 94 | [125], 95 | [], 96 | [], 97 | [], 98 | [], 99 | [], 100 | [], 101 | [], 102 | [], 103 | [], 104 | [], 105 | [], 106 | [], 107 | [], 108 | [150], 109 | [], 110 | [], 111 | [], 112 | [159], 113 | [], 114 | [], 115 | [167], 116 | [], 117 | [170], 118 | [172], 119 | [174], 120 | [176], 121 | [], 122 | [], 123 | [], 124 | [], 125 | [], 126 | [], 127 | [], 128 | [194], 129 | [], 130 | ], 131 | **kwargs, 132 | ) 133 | 134 | 135 | class ImageNet100C(CustomImageNet): 136 | """ 137 | ImageNet-C, but restricted to the ImageNet-100 classes. 
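    A rough usage sketch (the path below is an assumption; data_path should
    contain the usual ImageNet-C layout of <corruption_type>/<severity>
    folders):

        dataset = ImageNet100C(
            '/path/to/ImageNet-C',
            corruption_type='fog',
            severity=3,
        )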
138 | """ 139 | 140 | def __init__( 141 | self, 142 | data_path, 143 | corruption_type: str = 'gaussian_noise', 144 | severity: int = 1, 145 | **kwargs, 146 | ): 147 | # Need to create a temporary directory to act as the dataset because 148 | # the robustness library expects a particular directory structure. 149 | tmp_data_path = tempfile.mkdtemp() 150 | os.symlink(os.path.join(data_path, corruption_type, str(severity)), 151 | os.path.join(tmp_data_path, 'test')) 152 | 153 | super().__init__( 154 | data_path=tmp_data_path, 155 | custom_grouping=[[label] for label in range(0, 1000, 10)], 156 | **kwargs, 157 | ) 158 | 159 | 160 | class CIFAR10C(CIFAR): 161 | """ 162 | CIFAR-10-C from https://github.com/hendrycks/robustness. 163 | """ 164 | 165 | def __init__( 166 | self, 167 | data_path, 168 | corruption_type: str = 'gaussian_noise', 169 | severity: int = 1, 170 | **kwargs, 171 | ): 172 | class CustomCIFAR10(CIFAR10): 173 | def __init__(self, root, train=True, transform=None, 174 | target_transform=None, download=False): 175 | VisionDataset.__init__(self, root, transform=transform, 176 | target_transform=target_transform) 177 | 178 | if train: 179 | raise NotImplementedError( 180 | 'No train dataset for CIFAR-10-C') 181 | if download and not os.path.exists(root): 182 | raise NotImplementedError( 183 | 'Downloading CIFAR-10-C has not been implemented') 184 | 185 | all_data = np.load( 186 | os.path.join(root, f'{corruption_type}.npy')) 187 | all_labels = np.load(os.path.join(root, f'labels.npy')) 188 | 189 | severity_slice = slice( 190 | (severity - 1) * 10000, 191 | severity * 10000, 192 | ) 193 | 194 | self.data = all_data[severity_slice] 195 | self.targets = all_labels[severity_slice] 196 | 197 | DataSet.__init__( 198 | self, 199 | 'cifar10c', 200 | data_path, 201 | num_classes=10, 202 | mean=torch.tensor([0.4914, 0.4822, 0.4465]), 203 | std=torch.tensor([0.2023, 0.1994, 0.2010]), 204 | custom_class=CustomCIFAR10, 205 | label_mapping=None, 206 | transform_train=data_augmentation.TRAIN_TRANSFORMS_DEFAULT(32), 207 | transform_test=data_augmentation.TEST_TRANSFORMS_DEFAULT(32) 208 | ) 209 | 210 | 211 | class BirdOrBicycle(DataSet): 212 | """ 213 | Bird-or-bicycle dataset. 214 | https://github.com/google/unrestricted-adversarial-examples/tree/master/bird-or-bicycle 215 | """ 216 | 217 | def __init__(self, data_path=None, **kwargs): 218 | ds_name = 'bird_or_bicycle' 219 | import bird_or_bicycle 220 | 221 | # Need to create a temporary directory to act as the dataset because 222 | # the robustness library expects a particular directory structure. 
223 | data_path = tempfile.mkdtemp() 224 | os.symlink(bird_or_bicycle.get_dataset('extras'), 225 | os.path.join(data_path, 'train')) 226 | os.symlink(bird_or_bicycle.get_dataset('test'), 227 | os.path.join(data_path, 'test')) 228 | 229 | ds_kwargs = { 230 | 'num_classes': 2, 231 | 'mean': torch.tensor([0.4717, 0.4499, 0.3837]), 232 | 'std': torch.tensor([0.2600, 0.2516, 0.2575]), 233 | 'custom_class': None, 234 | 'label_mapping': None, 235 | 'transform_train': TRAIN_TRANSFORMS_IMAGENET, 236 | 'transform_test': TEST_TRANSFORMS_IMAGENET, 237 | } 238 | super().__init__(ds_name, data_path, **ds_kwargs) 239 | 240 | 241 | DATASETS['imagenet100'] = ImageNet100 242 | DATASETS['imagenet100a'] = ImageNet100A 243 | DATASETS['imagenet100c'] = ImageNet100C 244 | DATASETS['cifar10c'] = CIFAR10C 245 | DATASETS['bird_or_bicycle'] = BirdOrBicycle 246 | -------------------------------------------------------------------------------- /perceptual_advex/distances.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union, cast 2 | import torch 3 | import torch.autograd 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import torchvision.models as torchvision_models 7 | from torch.autograd import Variable 8 | from math import exp 9 | from torch import nn 10 | from typing_extensions import Literal 11 | 12 | from .models import AlexNetFeatureModel, FeatureModel 13 | 14 | 15 | class LPIPSDistance(nn.Module): 16 | """ 17 | Calculates the square root of the Learned Perceptual Image Patch Similarity 18 | (LPIPS) between two images, using a given neural network. 19 | """ 20 | 21 | model: FeatureModel 22 | 23 | def __init__( 24 | self, 25 | model: Optional[Union[FeatureModel, nn.DataParallel]] = None, 26 | activation_distance: Literal['l2'] = 'l2', 27 | include_image_as_activation: bool = False, 28 | ): 29 | """ 30 | Constructs an LPIPS distance metric. The given network should return a 31 | tuple of (activations, logits). If a network is not specified, AlexNet 32 | will be used. activation_distance can be 'l2' or 'cw_ssim'. 33 | """ 34 | 35 | super().__init__() 36 | 37 | if model is None: 38 | alexnet_model = torchvision_models.alexnet(pretrained=True) 39 | self.model = AlexNetFeatureModel(alexnet_model) 40 | elif isinstance(model, nn.DataParallel): 41 | self.model = cast(FeatureModel, model.module) 42 | else: 43 | self.model = model 44 | 45 | self.activation_distance = activation_distance 46 | self.include_image_as_activation = include_image_as_activation 47 | 48 | self.eval() 49 | 50 | def features(self, image: torch.Tensor) -> Tuple[torch.Tensor, ...]: 51 | features = self.model.features(image) 52 | if self.include_image_as_activation: 53 | features = (image,) + features 54 | return features 55 | 56 | def forward(self, image1, image2): 57 | features1 = self.features(image1) 58 | features2 = self.features(image2) 59 | 60 | if self.activation_distance == 'l2': 61 | return ( 62 | normalize_flatten_features(features1) - 63 | normalize_flatten_features(features2) 64 | ).norm(dim=1) 65 | else: 66 | raise ValueError( 67 | f'Invalid activation_distance "{self.activation_distance}"') 68 | 69 | 70 | class LinearizedLPIPSDistance(LPIPSDistance): 71 | """ 72 | An approximation of the LPIPS distance using the Jacobian of the feature 73 | network, i.e. d(x1, x2) = || D phi(x1) (x2 - x1) ||_2. 
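    A small usage sketch (img1 and img2 are assumed to be batches of images
    in [0, 1] with shape (N, 3, H, W)):

        distance = LinearizedLPIPSDistance()  # defaults to AlexNet features
        dists = distance(img1, img2)          # tensor of shape (N,)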
74 | """ 75 | 76 | def __init__(self, *args, **kwargs): 77 | super().__init__(*args, **kwargs) 78 | if self.activation_distance != 'l2': 79 | raise ValueError( 80 | f'Invalid activation_distance "{self.activation_distance}"' 81 | ) 82 | 83 | def forward(self, image1, image2): 84 | # Use the double-autograd trick for forward derivatives from 85 | # https://j-towns.github.io/2017/06/12/A-new-trick.html 86 | # and https://github.com/pytorch/pytorch/issues/10223#issuecomment-547104071 87 | 88 | image1 = image1.detach().requires_grad_() 89 | diff = image2 - image1 90 | features1 = normalize_flatten_features(self.features(image1)) 91 | v = torch.ones_like(features1, requires_grad=True) 92 | vjp, = torch.autograd.grad( 93 | features1, 94 | image1, 95 | grad_outputs=v, 96 | create_graph=True, 97 | ) 98 | output, = torch.autograd.grad(vjp, v, grad_outputs=diff) 99 | return output.norm(dim=1) 100 | 101 | def normalize_flatten_features( 102 | features: Tuple[torch.Tensor, ...], 103 | eps=1e-10, 104 | ) -> torch.Tensor: 105 | """ 106 | Given a tuple of features (layer1, layer2, layer3, ...) from a network, 107 | flattens those features into a single vector per batch input. The 108 | features are also scaled such that the L2 distance between features 109 | for two different inputs is the LPIPS distance between those inputs. 110 | """ 111 | 112 | normalized_features: List[torch.Tensor] = [] 113 | for feature_layer in features: 114 | norm_factor = torch.sqrt( 115 | torch.sum(feature_layer ** 2, dim=1, keepdim=True)) + eps 116 | normalized_features.append( 117 | (feature_layer / (norm_factor * 118 | np.sqrt(feature_layer.size()[2] * 119 | feature_layer.size()[3]))) 120 | .view(feature_layer.size()[0], -1) 121 | ) 122 | return torch.cat(normalized_features, dim=1) 123 | 124 | 125 | def gaussian(window_size, sigma): 126 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 127 | return gauss/gauss.sum() 128 | 129 | def create_window(window_size, channel): 130 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 131 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 132 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 133 | return window 134 | 135 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 136 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 137 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 138 | 139 | mu1_sq = mu1.pow(2) 140 | mu2_sq = mu2.pow(2) 141 | mu1_mu2 = mu1*mu2 142 | 143 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 144 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 145 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 146 | 147 | C1 = 0.01**2 148 | C2 = 0.03**2 149 | 150 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 151 | 152 | if size_average: 153 | return ssim_map.mean() 154 | else: 155 | return ssim_map.mean(1).mean(1).mean(1) 156 | 157 | 158 | class SSIM(nn.Module): 159 | """ 160 | Copied from https://github.com/Po-Hsun-Su/pytorch-ssim 161 | """ 162 | 163 | def __init__(self, window_size=11, size_average=True, dissimilarity=False): 164 | super(SSIM, self).__init__() 165 | self.window_size = window_size 166 | self.size_average = size_average 167 | self.channel = 1 168 | self.window = 
create_window(window_size, self.channel) 169 | self.dissimilarity = dissimilarity 170 | 171 | def forward(self, imgs1, imgs2): 172 | (_, channel, _, _) = imgs1.size() 173 | 174 | if channel == self.channel and self.window.data.type() == imgs1.data.type(): 175 | window = self.window 176 | else: 177 | window = create_window(self.window_size, channel) 178 | 179 | if imgs1.is_cuda: 180 | window = window.cuda(imgs1.get_device()) 181 | window = window.type_as(imgs1) 182 | 183 | self.window = window 184 | self.channel = channel 185 | 186 | sim = torch.tensor([ 187 | _ssim(img1[None], img2[None], window, self.window_size, channel, self.size_average) 188 | for img1, img2 in zip(imgs1, imgs2) 189 | ]) 190 | return 1 - sim if self.dissimilarity else sim 191 | 192 | 193 | class L2Distance(nn.Module): 194 | def forward(self, img1, img2): 195 | return (img1 - img2).reshape(img1.shape[0], -1).norm(dim=1) 196 | 197 | 198 | class LinfDistance(nn.Module): 199 | def forward(self, img1, img2): 200 | return (img1 - img2).reshape(img1.shape[0], -1).abs().max(dim=1)[0] 201 | -------------------------------------------------------------------------------- /perceptual_advex/evaluation.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import random 4 | import copy 5 | from torch import nn 6 | 7 | from .distances import LPIPSDistance 8 | 9 | 10 | def evaluate_against_attacks(model, attacks, val_loader, parallel=1, 11 | writer=None, iteration=None, num_batches=None): 12 | """ 13 | Evaluates a model against the given attacks, printing the output and 14 | optionally writing it to a tensorboardX summary writer. 15 | """ 16 | 17 | model_lpips_model: nn.Module = LPIPSDistance(model) 18 | alexnet_lpips_model: nn.Module = LPIPSDistance() 19 | 20 | if torch.cuda.is_available(): 21 | model_lpips_model.cuda() 22 | alexnet_lpips_model.cuda() 23 | 24 | device_ids = list(range(parallel)) 25 | model_lpips_model = nn.DataParallel(model_lpips_model, device_ids) 26 | alexnet_lpips_model = nn.DataParallel(alexnet_lpips_model, device_ids) 27 | 28 | model_state_dict = copy.deepcopy(model.state_dict()) 29 | 30 | for attack in attacks: 31 | if isinstance(attack, nn.DataParallel): 32 | attack_name = attack.module.__class__.__name__ 33 | else: 34 | attack_name = attack.__class__.__name__ 35 | 36 | batches_correct = [] 37 | successful_attacks = [] 38 | successful_model_lpips = [] 39 | successful_alexnet_lpips = [] 40 | for batch_index, (inputs, labels) in enumerate(val_loader): 41 | if num_batches is not None and batch_index >= num_batches: 42 | break 43 | 44 | if torch.cuda.is_available(): 45 | inputs = inputs.cuda() 46 | labels = labels.cuda() 47 | 48 | adv_inputs = attack(inputs, labels) 49 | 50 | with torch.no_grad(): 51 | logits = model(inputs) 52 | adv_logits = model(adv_inputs) 53 | batches_correct.append((adv_logits.argmax(1) == labels).detach()) 54 | 55 | success = ( 56 | (logits.argmax(1) == labels) & # was classified correctly 57 | (adv_logits.argmax(1) != labels) # and now is not 58 | ) 59 | 60 | inputs_success = inputs[success] 61 | adv_inputs_success = adv_inputs[success] 62 | num_samples = min(len(inputs_success), 1) 63 | adv_indices = random.sample(range(len(inputs_success)), 64 | num_samples) 65 | for adv_index in adv_indices: 66 | successful_attacks.append(torch.cat([ 67 | inputs_success[adv_index], 68 | adv_inputs_success[adv_index], 69 | torch.clamp((adv_inputs_success[adv_index] - 70 | inputs_success[adv_index]) * 3 + 0.5, 71 | 0, 1), 72 | ], dim=1).detach()) 73 
| 74 | if success.sum() > 0: 75 | successful_model_lpips.extend(model_lpips_model( 76 | inputs_success, 77 | adv_inputs_success, 78 | ).detach()) 79 | successful_alexnet_lpips.extend(alexnet_lpips_model( 80 | inputs_success, 81 | adv_inputs_success, 82 | ).detach()) 83 | print_cols = [f'ATTACK {attack_name}'] 84 | 85 | correct = torch.cat(batches_correct) 86 | accuracy = correct.float().mean() 87 | if writer is not None: 88 | writer.add_scalar(f'val/{attack_name}/accuracy', 89 | accuracy.item(), 90 | iteration) 91 | print_cols.append(f'accuracy: {accuracy.item() * 100:.1f}%') 92 | 93 | print(*print_cols, sep='\t') 94 | 95 | for lpips_name, successful_lpips in [ 96 | ('alexnet', successful_alexnet_lpips), 97 | ('model', successful_model_lpips), 98 | ]: 99 | if len(successful_lpips) > 0 and writer is not None: 100 | writer.add_histogram(f'val/{attack_name}/lpips/{lpips_name}', 101 | torch.stack(successful_lpips) 102 | .cpu().detach().numpy(), 103 | iteration) 104 | 105 | if len(successful_attacks) > 0 and writer is not None: 106 | writer.add_image(f'val/{attack_name}/images', 107 | torch.cat(successful_attacks, dim=2), 108 | iteration) 109 | 110 | new_model_state_dict = copy.deepcopy(model.state_dict()) 111 | for key in model_state_dict: 112 | old_tensor = model_state_dict[key] 113 | new_tensor = new_model_state_dict[key] 114 | max_diff = (old_tensor - new_tensor).abs().max().item() 115 | if max_diff > 1e-8: 116 | print(f'max difference for {key} = {max_diff}') 117 | -------------------------------------------------------------------------------- /perceptual_advex/models.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torchvision import transforms 6 | from torchvision import models as torchvision_models 7 | from robustness.cifar_models.resnet import ResNet 8 | 9 | from .trades_wrn import TradesWideResNet 10 | 11 | 12 | class FeatureModel(nn.Module): 13 | """ 14 | A classifier model which can produce layer features, output logits, or 15 | both. 16 | """ 17 | 18 | normalizer: nn.Module 19 | model: nn.Module 20 | 21 | def __init__(self): 22 | super().__init__() 23 | self._allow_training = False 24 | self.eval() 25 | 26 | def features(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]: 27 | """ 28 | Should return a tuple of features (layer1, layer2, ...). 29 | """ 30 | 31 | raise NotImplementedError() 32 | 33 | def classifier(self, last_layer: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Given the final activation, returns the output logits. 36 | """ 37 | 38 | raise NotImplementedError() 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | """ 42 | Returns logits for the given inputs. 43 | """ 44 | 45 | return self.classifier(self.features(x)[-1]) 46 | 47 | def features_logits( 48 | self, 49 | x: torch.Tensor, 50 | ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor]: 51 | """ 52 | Returns a tuple (features, logits) for the given inputs. 
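        For example, with feature_model an instance of a FeatureModel subclass
        and x a batch of images:

            features, logits = feature_model.features_logits(x)
            # features: tuple of intermediate activations (layer1, layer2, ...)
            # logits:   tensor of shape (batch_size, num_classes)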
53 | """ 54 | 55 | features = self.features(x) 56 | logits = self.classifier(features[-1]) 57 | return features, logits 58 | 59 | def allow_train(self): 60 | self._allow_training = True 61 | 62 | def train(self, mode=True): 63 | if mode is True and not self._allow_training: 64 | raise RuntimeError('should not be in train mode') 65 | super().train(mode) 66 | 67 | 68 | class ImageNetNormalizer(nn.Module): 69 | def __init__(self, mean=[0.485, 0.456, 0.406], 70 | std=[0.229, 0.224, 0.225]): 71 | super().__init__() 72 | self.mean = mean 73 | self.std = std 74 | 75 | def forward(self, x): 76 | mean = torch.tensor(self.mean, device=x.device) 77 | std = torch.tensor(self.std, device=x.device) 78 | 79 | return ( 80 | (x - mean[None, :, None, None]) / 81 | std[None, :, None, None] 82 | ) 83 | 84 | 85 | class AlexNetFeatureModel(FeatureModel): 86 | model: torchvision_models.AlexNet 87 | 88 | def __init__(self, alexnet_model: torchvision_models.AlexNet): 89 | super().__init__() 90 | self.normalizer = ImageNetNormalizer() 91 | self.model = alexnet_model.eval() 92 | 93 | assert len(self.model.features) == 13 94 | self.layer1 = nn.Sequential(self.model.features[:2]) 95 | self.layer2 = nn.Sequential(self.model.features[2:5]) 96 | self.layer3 = nn.Sequential(self.model.features[5:8]) 97 | self.layer4 = nn.Sequential(self.model.features[8:10]) 98 | self.layer5 = nn.Sequential(self.model.features[10:12]) 99 | self.layer6 = self.model.features[12] 100 | 101 | def features(self, x): 102 | x = self.normalizer(x) 103 | 104 | x_layer1 = self.layer1(x) 105 | x_layer2 = self.layer2(x_layer1) 106 | x_layer3 = self.layer3(x_layer2) 107 | x_layer4 = self.layer4(x_layer3) 108 | x_layer5 = self.layer5(x_layer4) 109 | 110 | return (x_layer1, x_layer2, x_layer3, x_layer4, x_layer5) 111 | 112 | def classifier(self, last_layer): 113 | x = self.layer6(last_layer) 114 | if isinstance(self.model, CifarAlexNet): 115 | x = x.view(x.size(0), 256 * 2 * 2) 116 | else: 117 | x = self.model.avgpool(x) 118 | x = torch.flatten(x, 1) 119 | x = self.model.classifier(x) 120 | return x 121 | 122 | 123 | class VGG16FeatureModel(FeatureModel): 124 | model: torchvision_models.VGG 125 | 126 | def __init__(self, vgg_model: torchvision_models.VGG): 127 | super().__init__() 128 | 129 | self.normalizer = ImageNetNormalizer() 130 | self.model = vgg_model.eval() 131 | 132 | self.layer1 = nn.Sequential(self.model.features[:4]) 133 | self.layer2 = nn.Sequential(self.model.features[4:9]) 134 | self.layer3 = nn.Sequential(self.model.features[9:16]) 135 | self.layer4 = nn.Sequential(self.model.features[16:23]) 136 | self.layer5 = nn.Sequential(self.model.features[23:30]) 137 | 138 | def features(self, x): 139 | x = self.normalizer(x) 140 | 141 | x_layer1 = self.layer1(x) 142 | x_layer2 = self.layer2(x_layer1) 143 | x_layer3 = self.layer3(x_layer2) 144 | x_layer4 = self.layer4(x_layer3) 145 | x_layer5 = self.layer5(x_layer4) 146 | 147 | return (x_layer1, x_layer2, x_layer3, x_layer4, x_layer5) 148 | 149 | 150 | class CifarResNetFeatureModel(FeatureModel): 151 | model: ResNet 152 | 153 | def __init__(self, attacker_model): 154 | super().__init__() 155 | self.normalizer = attacker_model.normalizer 156 | self.model = attacker_model.model 157 | 158 | def features(self, x): 159 | x = self.normalizer(x) 160 | 161 | x = F.relu(self.model.bn1(self.model.conv1(x))) 162 | 163 | x = self.model.layer1(x) 164 | x_layer1 = x 165 | x = self.model.layer2(x) 166 | x_layer2 = x 167 | x = self.model.layer3(x) 168 | x_layer3 = x 169 | x = self.model.layer4(x, 
fake_relu=False) 170 | x_layer4 = x 171 | 172 | return (x_layer1, x_layer2, x_layer3, x_layer4) 173 | 174 | def classifier(self, last_layer): 175 | x = F.avg_pool2d(last_layer, 4) 176 | x = x.view(x.size(0), -1) 177 | x = self.model.linear(x) 178 | return x 179 | 180 | 181 | class ImageNetResNetFeatureModel(FeatureModel): 182 | model: torchvision_models.ResNet 183 | 184 | def __init__(self, attacker_model: torchvision_models.ResNet): 185 | super().__init__() 186 | self.normalizer = attacker_model.normalizer 187 | self.model = attacker_model.model 188 | 189 | def features(self, x): 190 | x = self.normalizer(x) 191 | 192 | x = self.model.conv1(x) 193 | x = self.model.bn1(x) 194 | x = self.model.relu(x) 195 | x = self.model.maxpool(x) 196 | 197 | x = self.model.layer1(x) 198 | x_layer1 = x 199 | x = self.model.layer2(x) 200 | x_layer2 = x 201 | x = self.model.layer3(x) 202 | x_layer3 = x 203 | x = self.model.layer4(x) 204 | x_layer4 = x 205 | 206 | return (x_layer1, x_layer2, x_layer3, x_layer4) 207 | 208 | def classifier(self, last_layer): 209 | x = self.model.avgpool(last_layer) 210 | x = x.view(x.size(0), -1) 211 | x = self.model.fc(x) 212 | return x 213 | 214 | 215 | class CifarAlexNet(nn.Module): 216 | def __init__(self, num_classes=10): 217 | super().__init__() 218 | self.features = nn.Sequential( 219 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 220 | nn.ReLU(inplace=True), 221 | nn.MaxPool2d(kernel_size=2), 222 | nn.Conv2d(64, 192, kernel_size=3, padding=1), 223 | nn.ReLU(inplace=True), 224 | nn.MaxPool2d(kernel_size=2), 225 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 226 | nn.ReLU(inplace=True), 227 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 228 | nn.ReLU(inplace=True), 229 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 230 | nn.ReLU(inplace=True), 231 | nn.MaxPool2d(kernel_size=2), 232 | ) 233 | self.classifier = nn.Sequential( 234 | nn.Dropout(), 235 | nn.Linear(256 * 2 * 2, 4096), 236 | nn.ReLU(inplace=True), 237 | nn.Dropout(), 238 | nn.Linear(4096, 4096), 239 | nn.ReLU(inplace=True), 240 | nn.Linear(4096, num_classes), 241 | ) 242 | 243 | def forward(self, x): 244 | x = self.features(x) 245 | x = x.view(x.size(0), 256 * 2 * 2) 246 | x = self.classifier(x) 247 | return x 248 | -------------------------------------------------------------------------------- /perceptual_advex/perceptual_attacks.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | import torch 3 | import torchvision.models as torchvision_models 4 | from torch.hub import load_state_dict_from_url 5 | import math 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from typing_extensions import Literal 9 | 10 | from .distances import normalize_flatten_features, LPIPSDistance 11 | from .utilities import MarginLoss 12 | from .models import AlexNetFeatureModel, CifarAlexNet, FeatureModel 13 | from . 
import utilities 14 | 15 | 16 | _cached_alexnet: Optional[AlexNetFeatureModel] = None 17 | _cached_alexnet_cifar: Optional[AlexNetFeatureModel] = None 18 | 19 | 20 | def get_lpips_model( 21 | lpips_model_spec: Union[ 22 | Literal['self', 'alexnet', 'alexnet_cifar'], 23 | FeatureModel, 24 | ], 25 | model: Optional[FeatureModel] = None, 26 | ) -> FeatureModel: 27 | global _cached_alexnet, _cached_alexnet_cifar 28 | 29 | lpips_model: FeatureModel 30 | 31 | if lpips_model_spec == 'self': 32 | if model is None: 33 | raise ValueError( 34 | 'Specified "self" for LPIPS model but no model passed' 35 | ) 36 | return model 37 | elif lpips_model_spec == 'alexnet': 38 | if _cached_alexnet is None: 39 | alexnet_model = torchvision_models.alexnet(pretrained=True) 40 | _cached_alexnet = AlexNetFeatureModel(alexnet_model) 41 | lpips_model = _cached_alexnet 42 | if torch.cuda.is_available(): 43 | lpips_model.cuda() 44 | elif lpips_model_spec == 'alexnet_cifar': 45 | if _cached_alexnet_cifar is None: 46 | alexnet_model = CifarAlexNet() 47 | _cached_alexnet_cifar = AlexNetFeatureModel(alexnet_model) 48 | lpips_model = _cached_alexnet_cifar 49 | if torch.cuda.is_available(): 50 | lpips_model.cuda() 51 | try: 52 | state = torch.load('data/checkpoints/alexnet_cifar.pt') 53 | except FileNotFoundError: 54 | state = load_state_dict_from_url( 55 | 'https://perceptual-advex.s3.us-east-2.amazonaws.com/' 56 | 'alexnet_cifar.pt', 57 | progress=True, 58 | ) 59 | lpips_model.load_state_dict(state['model']) 60 | elif isinstance(lpips_model_spec, str): 61 | raise ValueError(f'Invalid LPIPS model "{lpips_model_spec}"') 62 | else: 63 | lpips_model = lpips_model_spec 64 | 65 | lpips_model.eval() 66 | return lpips_model 67 | 68 | 69 | class FastLagrangePerceptualAttack(nn.Module): 70 | def __init__(self, model, bound=0.5, step=None, num_iterations=20, 71 | lam=10, h=1e-1, lpips_model='self', decay_step_size=True, 72 | increase_lambda=True, projection='none', kappa=math.inf, 73 | include_image_as_activation=False, randomize=False): 74 | """ 75 | Perceptual attack using a Lagrangian relaxation of the 76 | LPIPS-constrainted optimization problem. 77 | 78 | bound is the (soft) bound on the LPIPS distance. 79 | step is the LPIPS step size. 80 | num_iterations is the number of steps to take. 81 | lam is the lambda value multiplied by the regularization term. 82 | h is the step size to use for finite-difference calculation. 
83 | lpips_model is the model to use to calculate LPIPS or 'self' or 84 | 'alexnet' 85 | """ 86 | 87 | super().__init__() 88 | 89 | assert randomize is False 90 | 91 | self.model = model 92 | self.bound = bound 93 | if step is None: 94 | self.step = self.bound 95 | else: 96 | self.step = step 97 | self.num_iterations = num_iterations 98 | self.lam = lam 99 | self.h = h 100 | self.decay_step_size = decay_step_size 101 | self.increase_lambda = increase_lambda 102 | 103 | self.lpips_model = get_lpips_model(lpips_model, model) 104 | self.lpips_distance = LPIPSDistance( 105 | self.lpips_model, 106 | include_image_as_activation=include_image_as_activation, 107 | ) 108 | self.projection = PROJECTIONS[projection](self.bound, self.lpips_model) 109 | self.loss = MarginLoss(kappa=kappa) 110 | 111 | def _get_features(self, inputs: torch.Tensor) -> torch.Tensor: 112 | return normalize_flatten_features(self.lpips_model.features(inputs)) 113 | 114 | def _get_features_logits( 115 | self, inputs: torch.Tensor 116 | ) -> Tuple[torch.Tensor, torch.Tensor]: 117 | features, logits = self.lpips_model.features_logits(inputs) 118 | return normalize_flatten_features(features), logits 119 | 120 | def forward(self, inputs, labels): 121 | perturbations = torch.zeros_like(inputs) 122 | perturbations.normal_(0, 0.01) 123 | 124 | perturbations.requires_grad = True 125 | 126 | step_size = self.step 127 | lam = self.lam 128 | 129 | input_features = self._get_features(inputs).detach() 130 | 131 | for attack_iter in range(self.num_iterations): 132 | # Decay step size, but increase lambda over time. 133 | if self.decay_step_size: 134 | step_size = \ 135 | self.step * 0.1 ** (attack_iter / self.num_iterations) 136 | if self.increase_lambda: 137 | lam = \ 138 | self.lam * 0.1 ** (1 - attack_iter / self.num_iterations) 139 | 140 | if perturbations.grad is not None: 141 | perturbations.grad.data.zero_() 142 | 143 | adv_inputs = inputs + perturbations 144 | 145 | if self.model == self.lpips_model: 146 | adv_features, adv_logits = \ 147 | self._get_features_logits(adv_inputs) 148 | else: 149 | adv_features = self._get_features(adv_inputs) 150 | adv_logits = self.model(adv_inputs) 151 | 152 | adv_loss = self.loss(adv_logits, labels) 153 | 154 | lpips_dists = (adv_features - input_features).norm(dim=1) 155 | 156 | loss = -adv_loss + lam * F.relu(lpips_dists - self.bound) 157 | loss.sum().backward() 158 | 159 | grad = perturbations.grad.data 160 | grad_normed = grad / \ 161 | (grad.reshape(grad.size()[0], -1).norm(dim=1) 162 | [:, None, None, None] + 1e-8) 163 | 164 | dist_grads = ( 165 | adv_features - self._get_features( 166 | inputs + perturbations - grad_normed * self.h) 167 | ).norm(dim=1) / 0.1 168 | 169 | perturbation_updates = -grad_normed * ( 170 | step_size / (dist_grads + 1e-4) 171 | )[:, None, None, None] 172 | 173 | perturbations.data = ( 174 | (inputs + perturbations + perturbation_updates).clamp(0, 1) - 175 | inputs 176 | ).detach() 177 | 178 | adv_inputs = (inputs + perturbations).detach() 179 | return self.projection(inputs, adv_inputs, input_features).detach() 180 | 181 | 182 | class NoProjection(nn.Module): 183 | def __init__(self, bound, lpips_model): 184 | super().__init__() 185 | 186 | def forward(self, inputs, adv_inputs, input_features=None): 187 | return adv_inputs 188 | 189 | 190 | class BisectionPerceptualProjection(nn.Module): 191 | def __init__(self, bound, lpips_model, num_steps=10): 192 | super().__init__() 193 | 194 | self.bound = bound 195 | self.lpips_model = lpips_model 196 | self.num_steps = 
num_steps 197 | 198 | def forward(self, inputs, adv_inputs, input_features=None): 199 | batch_size = inputs.shape[0] 200 | if input_features is None: 201 | input_features = normalize_flatten_features( 202 | self.lpips_model.features(inputs)) 203 | 204 | lam_min = torch.zeros(batch_size, device=inputs.device) 205 | lam_max = torch.ones(batch_size, device=inputs.device) 206 | lam = 0.5 * torch.ones(batch_size, device=inputs.device) 207 | 208 | for _ in range(self.num_steps): 209 | projected_adv_inputs = ( 210 | inputs * (1 - lam[:, None, None, None]) + 211 | adv_inputs * lam[:, None, None, None] 212 | ) 213 | adv_features = self.lpips_model.features(projected_adv_inputs) 214 | adv_features = normalize_flatten_features(adv_features).detach() 215 | diff_features = adv_features - input_features 216 | norm_diff_features = torch.norm(diff_features, dim=1) 217 | 218 | lam_max[norm_diff_features > self.bound] = \ 219 | lam[norm_diff_features > self.bound] 220 | lam_min[norm_diff_features <= self.bound] = \ 221 | lam[norm_diff_features <= self.bound] 222 | lam = 0.5*(lam_min + lam_max) 223 | return projected_adv_inputs.detach() 224 | 225 | 226 | class NewtonsPerceptualProjection(nn.Module): 227 | def __init__(self, bound, lpips_model, projection_overshoot=1e-1, 228 | max_iterations=10): 229 | super().__init__() 230 | 231 | self.bound = bound 232 | self.lpips_model = lpips_model 233 | self.projection_overshoot = projection_overshoot 234 | self.max_iterations = max_iterations 235 | self.bisection_projection = BisectionPerceptualProjection( 236 | bound, lpips_model) 237 | 238 | def forward(self, inputs, adv_inputs, input_features=None): 239 | original_adv_inputs = adv_inputs 240 | if input_features is None: 241 | input_features = normalize_flatten_features( 242 | self.lpips_model.features(inputs)) 243 | 244 | needs_projection = torch.ones_like(adv_inputs[:, 0, 0, 0]) \ 245 | .bool() 246 | 247 | needs_projection.requires_grad = False 248 | iteration = 0 249 | while needs_projection.sum() > 0 and iteration < self.max_iterations: 250 | adv_inputs.requires_grad = True 251 | adv_features = normalize_flatten_features( 252 | self.lpips_model.features(adv_inputs[needs_projection])) 253 | adv_lpips = (input_features[needs_projection] - 254 | adv_features).norm(dim=1) 255 | adv_lpips.sum().backward() 256 | 257 | projection_step_size = (adv_lpips - self.bound) \ 258 | .clamp(min=0) 259 | projection_step_size[projection_step_size > 0] += \ 260 | self.projection_overshoot 261 | 262 | grad_norm = adv_inputs.grad.data[needs_projection] \ 263 | .view(needs_projection.sum(), -1).norm(dim=1) 264 | inverse_grad = adv_inputs.grad.data[needs_projection] / \ 265 | grad_norm[:, None, None, None] ** 2 266 | 267 | adv_inputs.data[needs_projection] = ( 268 | adv_inputs.data[needs_projection] - 269 | projection_step_size[:, None, None, None] * 270 | (1 + self.projection_overshoot) * 271 | inverse_grad 272 | ).clamp(0, 1).detach() 273 | 274 | needs_projection[needs_projection.clone()] = \ 275 | projection_step_size > 0 276 | iteration += 1 277 | 278 | if needs_projection.sum() > 0: 279 | # If we still haven't projected all inputs after max_iterations, 280 | # just use the bisection method. 
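            # (The bisection projection interpolates between the original and
            # adversarial inputs, shrinking the interpolation factor until the
            # LPIPS constraint is approximately satisfied.)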
281 | adv_inputs = self.bisection_projection( 282 | inputs, original_adv_inputs, input_features) 283 | 284 | return adv_inputs.detach() 285 | 286 | 287 | PROJECTIONS = { 288 | 'none': NoProjection, 289 | 'linesearch': BisectionPerceptualProjection, 290 | 'bisection': BisectionPerceptualProjection, 291 | 'gradient': NewtonsPerceptualProjection, 292 | 'newtons': NewtonsPerceptualProjection, 293 | } 294 | 295 | 296 | class FirstOrderStepPerceptualAttack(nn.Module): 297 | def __init__(self, model, bound=0.5, num_iterations=5, 298 | h=1e-3, kappa=1, lpips_model='self', 299 | targeted=False, randomize=False, 300 | include_image_as_activation=False): 301 | """ 302 | Perceptual attack using conjugate gradient to solve the constrained 303 | optimization problem. 304 | 305 | bound is the (approximate) bound on the LPIPS distance. 306 | num_iterations is the number of CG iterations to take. 307 | h is the step size to use for finite-difference calculation. 308 | """ 309 | 310 | super().__init__() 311 | 312 | assert randomize is False 313 | 314 | self.model = model 315 | self.bound = bound 316 | self.num_iterations = num_iterations 317 | self.h = h 318 | 319 | self.lpips_model = get_lpips_model(lpips_model, model) 320 | self.lpips_distance = LPIPSDistance( 321 | self.lpips_model, 322 | include_image_as_activation=include_image_as_activation, 323 | ) 324 | self.loss = MarginLoss(kappa=kappa, targeted=targeted) 325 | 326 | def _multiply_matrix(self, v): 327 | """ 328 | If (D phi) is the Jacobian of the features function for the model 329 | at inputs, then approximately calculates 330 | (D phi)T (D phi) v 331 | """ 332 | 333 | self.inputs.grad.data.zero_() 334 | 335 | with torch.no_grad(): 336 | v_features = self.lpips_model.features(self.inputs.detach() + 337 | self.h * v) 338 | D_phi_v = ( 339 | normalize_flatten_features(v_features) - 340 | self.input_features 341 | ) / self.h 342 | 343 | torch.sum(self.input_features * D_phi_v).backward(retain_graph=True) 344 | 345 | return self.inputs.grad.data.clone() 346 | 347 | def forward(self, inputs, labels): 348 | self.inputs = inputs 349 | 350 | inputs.requires_grad = True 351 | if self.model == self.lpips_model: 352 | input_features, orig_logits = self.model.features_logits(inputs) 353 | else: 354 | input_features = self.lpips_model.features(inputs) 355 | orig_logits = self.model(inputs) 356 | self.input_features = normalize_flatten_features(input_features) 357 | 358 | loss = self.loss(orig_logits, labels) 359 | loss.sum().backward(retain_graph=True) 360 | 361 | inputs_grad = inputs.grad.data.clone() 362 | if inputs_grad.abs().max() < 1e-4: 363 | return inputs 364 | 365 | # Variable names are from 366 | # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_resulting_algorithm 367 | x = torch.zeros_like(inputs) 368 | r = inputs_grad - self._multiply_matrix(x) 369 | p = r 370 | 371 | for cg_iter in range(self.num_iterations): 372 | r_last = r 373 | p_last = p 374 | x_last = x 375 | del r, p, x 376 | 377 | r_T_r = (r_last ** 2).sum(dim=[1, 2, 3]) 378 | if r_T_r.max() < 1e-1 and cg_iter > 0: 379 | # If the residual is small enough, just stop the algorithm. 380 | x = x_last 381 | break 382 | 383 | A_p_last = self._multiply_matrix(p_last) 384 | 385 | # print('|r|^2 =', ' '.join(f'{z:.2f}' for z in r_T_r)) 386 | alpha = ( 387 | r_T_r / 388 | (p_last * A_p_last).sum(dim=[1, 2, 3]) 389 | )[:, None, None, None] 390 | x = x_last + alpha * p_last 391 | 392 | # These calculations aren't necessary on the last iteration. 
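            # (r is the residual b - Ax of the linear system and p the next
            # conjugate search direction; both only feed into the following
            # iteration, so they are skipped once the final x has been formed.)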
393 | if cg_iter < self.num_iterations - 1: 394 | r = r_last - alpha * A_p_last 395 | 396 | beta = ( 397 | (r ** 2).sum(dim=[1, 2, 3]) / 398 | r_T_r 399 | )[:, None, None, None] 400 | p = r + beta * p_last 401 | 402 | x_features = self.lpips_model.features(self.inputs.detach() + 403 | self.h * x) 404 | D_phi_x = ( 405 | normalize_flatten_features(x_features) - 406 | self.input_features 407 | ) / self.h 408 | 409 | lam = (self.bound / D_phi_x.norm(dim=1))[:, None, None, None] 410 | 411 | inputs_grad_norm = inputs_grad.reshape( 412 | inputs_grad.size()[0], -1).norm(dim=1) 413 | # If the grad is basically 0, don't perturb that input. It's likely 414 | # already misclassified, and trying to perturb it further leads to 415 | # numerical instability. 416 | lam[inputs_grad_norm < 1e-4] = 0 417 | x[inputs_grad_norm < 1e-4] = 0 418 | 419 | # print('LPIPS', self.lpips_distance( 420 | # inputs, 421 | # inputs + lam * x, 422 | # )) 423 | 424 | return (inputs + lam * x).clamp(0, 1).detach() 425 | 426 | 427 | class PerceptualPGDAttack(nn.Module): 428 | def __init__(self, model, bound=0.5, step=None, num_iterations=5, 429 | cg_iterations=5, h=1e-3, lpips_model='self', 430 | decay_step_size=False, kappa=1, 431 | projection='newtons', randomize=False, 432 | random_targets=False, num_classes=None, 433 | include_image_as_activation=False): 434 | """ 435 | Iterated version of the conjugate gradient attack. 436 | 437 | step_size is the step size in LPIPS distance. 438 | num_iterations is the number of steps to take. 439 | cg_iterations is the conjugate gradient iterations per step. 440 | h is the step size to use for finite-difference calculation. 441 | project is whether or not to project the perturbation into the LPIPS 442 | ball after each step. 443 | """ 444 | 445 | super().__init__() 446 | 447 | assert randomize is False 448 | 449 | self.model = model 450 | self.bound = bound 451 | self.num_iterations = num_iterations 452 | self.decay_step_size = decay_step_size 453 | self.step = step 454 | self.random_targets = random_targets 455 | self.num_classes = num_classes 456 | 457 | if self.step is None: 458 | if self.decay_step_size: 459 | self.step = self.bound 460 | else: 461 | self.step = 2 * self.bound / self.num_iterations 462 | 463 | self.lpips_model = get_lpips_model(lpips_model, model) 464 | self.first_order_step = FirstOrderStepPerceptualAttack( 465 | model, bound=self.step, num_iterations=cg_iterations, h=h, 466 | kappa=kappa, lpips_model=self.lpips_model, 467 | include_image_as_activation=include_image_as_activation, 468 | targeted=self.random_targets) 469 | self.projection = PROJECTIONS[projection](self.bound, self.lpips_model) 470 | 471 | def _attack(self, inputs, labels): 472 | with torch.no_grad(): 473 | input_features = normalize_flatten_features( 474 | self.lpips_model.features(inputs)) 475 | 476 | start_perturbations = torch.zeros_like(inputs) 477 | start_perturbations.normal_(0, 0.01) 478 | adv_inputs = inputs + start_perturbations 479 | for attack_iter in range(self.num_iterations): 480 | if self.decay_step_size: 481 | step_size = self.step * \ 482 | 0.1 ** (attack_iter / self.num_iterations) 483 | self.first_order_step.bound = step_size 484 | adv_inputs = self.first_order_step(adv_inputs, labels) 485 | adv_inputs = self.projection(inputs, adv_inputs, input_features) 486 | 487 | # print('LPIPS', self.first_order_step.lpips_distance( 488 | # inputs, 489 | # adv_inputs, 490 | # )) 491 | 492 | return adv_inputs 493 | 494 | def forward(self, inputs, labels): 495 | if self.random_targets: 496 | 
return utilities.run_attack_with_random_targets( 497 | self._attack, 498 | self.model, 499 | inputs, 500 | labels, 501 | self.num_classes, 502 | ) 503 | else: 504 | return self._attack(inputs, labels) 505 | 506 | class LagrangePerceptualAttack(nn.Module): 507 | def __init__(self, model, bound=0.5, step=None, num_iterations=20, 508 | binary_steps=5, h=0.1, kappa=1, lpips_model='self', 509 | projection='newtons', decay_step_size=True, 510 | num_classes=None, 511 | include_image_as_activation=False, 512 | randomize=False, random_targets=False): 513 | """ 514 | Perceptual attack using a Lagrangian relaxation of the 515 | LPIPS-constrainted optimization problem. 516 | bound is the (soft) bound on the LPIPS distance. 517 | step is the LPIPS step size. 518 | num_iterations is the number of steps to take. 519 | lam is the lambda value multiplied by the regularization term. 520 | h is the step size to use for finite-difference calculation. 521 | lpips_model is the model to use to calculate LPIPS or 'self' or 522 | 'alexnet' 523 | """ 524 | 525 | super().__init__() 526 | 527 | assert randomize is False 528 | 529 | self.model = model 530 | self.bound = bound 531 | self.decay_step_size = decay_step_size 532 | self.num_iterations = num_iterations 533 | if step is None: 534 | if self.decay_step_size: 535 | self.step = self.bound 536 | else: 537 | self.step = self.bound * 2 / self.num_iterations 538 | else: 539 | self.step = step 540 | self.binary_steps = binary_steps 541 | self.h = h 542 | self.random_targets = random_targets 543 | self.num_classes = num_classes 544 | 545 | self.lpips_model = get_lpips_model(lpips_model, model) 546 | self.lpips_distance = LPIPSDistance( 547 | self.lpips_model, 548 | include_image_as_activation=include_image_as_activation, 549 | ) 550 | self.loss = MarginLoss(kappa=kappa, targeted=self.random_targets) 551 | self.projection = PROJECTIONS[projection](self.bound, self.lpips_model) 552 | 553 | def threat_model_contains(self, inputs, adv_inputs): 554 | """ 555 | Returns a boolean tensor which indicates if each of the given 556 | adversarial examples given is within this attack's threat model for 557 | the given natural input. 
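        For example, with attack an instance of this class and inputs, labels
        a batch of images and labels:

            adv_inputs = attack(inputs, labels)
            in_threat_model = attack.threat_model_contains(inputs, adv_inputs)
            # in_threat_model: boolean tensor of shape (batch_size,)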
558 | """ 559 | 560 | return self.lpips_distance(inputs, adv_inputs) <= self.bound 561 | 562 | def _attack(self, inputs, labels): 563 | perturbations = torch.zeros_like(inputs) 564 | perturbations.normal_(0, 0.01) 565 | perturbations.requires_grad = True 566 | 567 | batch_size = inputs.shape[0] 568 | step_size = self.step 569 | 570 | lam = 0.01 * torch.ones(batch_size, device=inputs.device) 571 | 572 | input_features = normalize_flatten_features( 573 | self.lpips_model.features(inputs)).detach() 574 | 575 | live = torch.ones(batch_size, device=inputs.device, dtype=torch.bool) 576 | 577 | for binary_iter in range(self.binary_steps): 578 | for attack_iter in range(self.num_iterations): 579 | if self.decay_step_size: 580 | step_size = self.step * \ 581 | (0.1 ** (attack_iter / self.num_iterations)) 582 | else: 583 | step_size = self.step 584 | 585 | if perturbations.grad is not None: 586 | perturbations.grad.data.zero_() 587 | 588 | adv_inputs = (inputs + perturbations)[live] 589 | 590 | if self.model == self.lpips_model: 591 | adv_features, adv_logits = \ 592 | self.model.features_logits(adv_inputs) 593 | else: 594 | adv_features = self.lpips_model.features(adv_inputs) 595 | adv_logits = self.model(adv_inputs) 596 | 597 | adv_labels = adv_logits.argmax(1) 598 | adv_loss = self.loss(adv_logits, labels[live]) 599 | adv_features = normalize_flatten_features(adv_features) 600 | lpips_dists = (adv_features - input_features[live]).norm(dim=1) 601 | all_lpips_dists = torch.zeros(batch_size, device=inputs.device) 602 | all_lpips_dists[live] = lpips_dists 603 | 604 | loss = -adv_loss + lam[live] * F.relu(lpips_dists - self.bound) 605 | loss.sum().backward() 606 | 607 | grad = perturbations.grad.data[live] 608 | grad_normed = grad / \ 609 | (grad.reshape(grad.size()[0], -1).norm(dim=1) 610 | [:, None, None, None] + 1e-8) 611 | 612 | dist_grads = ( 613 | adv_features - 614 | normalize_flatten_features(self.lpips_model.features( 615 | adv_inputs - grad_normed * self.h)) 616 | ).norm(dim=1) / self.h 617 | 618 | updates = -grad_normed * ( 619 | step_size / (dist_grads + 1e-8) 620 | )[:, None, None, None] 621 | 622 | perturbations.data[live] = ( 623 | (inputs[live] + perturbations[live] + 624 | updates).clamp(0, 1) - 625 | inputs[live] 626 | ).detach() 627 | 628 | if self.random_targets: 629 | live[live.clone()] = (adv_labels != labels[live]) | (lpips_dists > self.bound) 630 | else: 631 | live[live.clone()] = (adv_labels == labels[live]) | (lpips_dists > self.bound) 632 | if live.sum() == 0: 633 | break 634 | 635 | lam[all_lpips_dists >= self.bound] *= 10 636 | if live.sum() == 0: 637 | break 638 | 639 | adv_inputs = (inputs + perturbations).detach() 640 | adv_inputs = self.projection(inputs, adv_inputs, input_features) 641 | return adv_inputs 642 | 643 | def forward(self, inputs, labels): 644 | if self.random_targets: 645 | return utilities.run_attack_with_random_targets( 646 | self._attack, 647 | self.model, 648 | inputs, 649 | labels, 650 | self.num_classes, 651 | ) 652 | else: 653 | return self._attack(inputs, labels) 654 | -------------------------------------------------------------------------------- /perceptual_advex/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cassidylaidlaw/perceptual-advex/65c1bc3aabe1b9a475ee0edb7606aee44896b685/perceptual_advex/py.typed -------------------------------------------------------------------------------- /perceptual_advex/trades_wrn.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add( 33 | x if self.equalInOut or self.convShortcut is None 34 | else self.convShortcut(x), 35 | out, 36 | ) 37 | 38 | 39 | class NetworkBlock(nn.Module): 40 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 41 | super(NetworkBlock, self).__init__() 42 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 43 | 44 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 45 | layers = [] 46 | for i in range(int(nb_layers)): 47 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 48 | return nn.Sequential(*layers) 49 | 50 | def forward(self, x): 51 | return self.layer(x) 52 | 53 | 54 | class TradesWideResNet(nn.Module): 55 | def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0.0): 56 | super().__init__() 57 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 58 | assert ((depth - 4) % 6 == 0) 59 | n = (depth - 4) / 6 60 | block = BasicBlock 61 | # 1st conv before any network block 62 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 63 | padding=1, bias=False) 64 | # 1st block 65 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 66 | # 1st sub-block 67 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 68 | # 2nd block 69 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 70 | # 3rd block 71 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 72 | # global average pooling and classifier 73 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.fc = nn.Linear(nChannels[3], num_classes) 76 | self.nChannels = nChannels[3] 77 | 78 | for m in self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 81 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
82 |             elif isinstance(m, nn.BatchNorm2d):
83 |                 m.weight.data.fill_(1)
84 |                 m.bias.data.zero_()
85 |             elif isinstance(m, nn.Linear):
86 |                 m.bias.data.zero_()
87 | 
88 |     def forward(self, x):
89 |         out = self.conv1(x)
90 |         out = self.block1(out)
91 |         out = self.block2(out)
92 |         out = self.block3(out)
93 |         out = self.relu(self.bn1(out))
94 |         out = F.avg_pool2d(out, 8)
95 |         out = out.view(-1, self.nChannels)
96 |         return self.fc(out)
97 | 
--------------------------------------------------------------------------------
/perceptual_advex/utilities.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import torch
3 | import os
4 | import copy
5 | import torchvision.models as torchvision_models
6 | from torch import nn
7 | from robustness.datasets import DATASETS, DataSet
8 | from robustness.model_utils import make_and_restore_model
9 | from robustness.attacker import AttackerModel
10 | from advex_uar.common import pyt_common as uar_common
11 | 
12 | from .models import CifarResNetFeatureModel, ImageNetResNetFeatureModel, \
13 |     AlexNetFeatureModel, CifarAlexNet, VGG16FeatureModel, TradesWideResNet
14 | from . import datasets
15 | 
16 | 
17 | class MarginLoss(nn.Module):
18 |     """
19 |     Calculates the margin loss min(kappa, (max_{k != y} z_k(x)) - z_y(x)),
20 |     a clamped variant of the f6 loss used by the Carlini & Wagner attack.
21 |     """
22 | 
23 |     def __init__(self, kappa=float('inf'), targeted=False):
24 |         super().__init__()
25 |         self.kappa = kappa
26 |         self.targeted = targeted
27 | 
28 |     def forward(self, logits, labels):
29 |         correct_logits = torch.gather(logits, 1, labels.view(-1, 1))
30 | 
31 |         max_2_logits, argmax_2_logits = torch.topk(logits, 2, dim=1)
32 |         top_max, second_max = max_2_logits.chunk(2, dim=1)
33 |         top_argmax, _ = argmax_2_logits.chunk(2, dim=1)
34 |         labels_eq_max = top_argmax.squeeze().eq(labels).float().view(-1, 1)
35 |         labels_ne_max = top_argmax.squeeze().ne(labels).float().view(-1, 1)
36 |         max_incorrect_logits = labels_eq_max * second_max + labels_ne_max * top_max
37 | 
38 |         if self.targeted:
39 |             return (correct_logits - max_incorrect_logits) \
40 |                 .clamp(max=self.kappa).squeeze()
41 |         else:
42 |             return (max_incorrect_logits - correct_logits) \
43 |                 .clamp(max=self.kappa).squeeze()
44 | 
45 | 
46 | def add_dataset_model_arguments(parser, include_checkpoint=False):
47 |     """
48 |     Adds the argparse arguments to the given parser necessary for calling the
49 |     get_dataset_model function.
50 |     """
51 | 
52 |     if include_checkpoint:
53 |         parser.add_argument('--checkpoint', type=str, help='checkpoint path')
54 | 
55 |     parser.add_argument('--arch', type=str, default='resnet50',
56 |                         help='model architecture')
57 |     parser.add_argument('--dataset', type=str, default='cifar',
58 |                         help='dataset name')
59 |     parser.add_argument('--dataset_path', type=str, default='~/datasets',
60 |                         help='path to datasets directory')
61 | 
62 | 
63 | def get_dataset_model(
64 |     args=None,
65 |     dataset_path: Optional[str] = None,
66 |     arch: Optional[str] = None,
67 |     checkpoint_fname: Optional[str] = None,
68 |     **kwargs,
69 | ) -> Tuple[DataSet, nn.Module]:
70 |     """
71 |     Given an argparse namespace with certain parameters, or those parameters
72 |     as keyword arguments, returns a tuple (dataset, model) with a robustness
73 |     dataset and a FeatureModel.
74 | """ 75 | 76 | if dataset_path is None: 77 | if args is None: 78 | dataset_path = '~/datasets' 79 | else: 80 | dataset_path = args.dataset_path 81 | dataset_path = os.path.expandvars(dataset_path) 82 | 83 | dataset_name = kwargs.get('dataset') or args.dataset 84 | dataset = DATASETS[dataset_name](dataset_path) 85 | 86 | checkpoint_is_feature_model = False 87 | 88 | if checkpoint_fname is None: 89 | checkpoint_fname = getattr(args, 'checkpoint', None) 90 | if arch is None: 91 | arch = args.arch 92 | 93 | if arch.startswith('rob-') or ( 94 | dataset_name.startswith('cifar') and 95 | 'resnet' in arch 96 | ): 97 | if arch.startswith('rob-'): 98 | arch = arch[4:] 99 | if checkpoint_fname == 'pretrained': 100 | pytorch_pretrained = True 101 | checkpoint_fname = None 102 | else: 103 | pytorch_pretrained = False 104 | try: 105 | model, _ = make_and_restore_model( 106 | arch=arch, 107 | dataset=dataset, 108 | resume_path=checkpoint_fname, 109 | pytorch_pretrained=pytorch_pretrained, 110 | parallel=False, 111 | ) 112 | except RuntimeError as error: 113 | if 'state_dict' in str(error): 114 | model, _ = make_and_restore_model( 115 | arch=arch, 116 | dataset=dataset, 117 | parallel=False, 118 | ) 119 | try: 120 | state = torch.load(checkpoint_fname) 121 | model.model.load_state_dict(state['model']) 122 | except RuntimeError as error: 123 | if 'state_dict' in str(error): 124 | checkpoint_is_feature_model = True 125 | else: 126 | raise error 127 | else: 128 | raise error # type: ignore 129 | elif arch == 'trades-wrn': 130 | model = TradesWideResNet() 131 | if checkpoint_fname is not None: 132 | state = torch.load(checkpoint_fname) 133 | model.load_state_dict(state) 134 | elif hasattr(torchvision_models, arch): 135 | if ( 136 | arch == 'alexnet' and 137 | dataset_name.startswith('cifar') and 138 | checkpoint_fname != 'pretrained' 139 | ): 140 | model = CifarAlexNet(num_classes=dataset.num_classes) 141 | else: 142 | if checkpoint_fname == 'pretrained': 143 | model = getattr(torchvision_models, arch)(pretrained=True) 144 | else: 145 | model = getattr(torchvision_models, arch)( 146 | num_classes=dataset.num_classes) 147 | 148 | if checkpoint_fname is not None and checkpoint_fname != 'pretrained': 149 | try: 150 | state = torch.load(checkpoint_fname) 151 | model.load_state_dict(state['model']) 152 | except RuntimeError as error: 153 | if 'state_dict' in str(error): 154 | checkpoint_is_feature_model = True 155 | else: 156 | raise error 157 | else: 158 | raise RuntimeError(f'Unsupported architecture {arch}.') 159 | 160 | if 'alexnet' in arch: 161 | model = AlexNetFeatureModel(model) 162 | elif 'vgg16' in arch: 163 | model = VGG16FeatureModel(model) 164 | elif 'resnet' in arch: 165 | if not isinstance(model, AttackerModel): 166 | model = AttackerModel(model, dataset) 167 | if dataset_name.startswith('cifar'): 168 | model = CifarResNetFeatureModel(model) 169 | elif ( 170 | dataset_name.startswith('imagenet') 171 | or dataset_name == 'bird_or_bicycle' 172 | ): 173 | model = ImageNetResNetFeatureModel(model) 174 | else: 175 | raise RuntimeError('Unsupported dataset.') 176 | elif arch == 'trades-wrn': 177 | pass # We can't use this as a FeatureModel yet. 
178 | else: 179 | raise RuntimeError(f'Unsupported architecture {arch}.') 180 | 181 | if checkpoint_is_feature_model: 182 | model.load_state_dict(state['model']) 183 | 184 | return dataset, model 185 | 186 | 187 | def calculate_accuracy(logits, labels): 188 | correct = logits.argmax(1) == labels 189 | return correct.float().mean() 190 | 191 | 192 | class LambdaLayer(nn.Module): 193 | def __init__(self, lambd): 194 | super().__init__() 195 | self.lambd = lambd 196 | 197 | def forward(self, x): 198 | return self.lambd(x) 199 | 200 | 201 | def run_attack_with_random_targets(attack, model, inputs, labels, num_classes): 202 | """ 203 | Runs an attack with targets randomly selected from all classes besides the 204 | correct one. The attack should be a function from (inputs, labels) to 205 | adversarial examples. 206 | """ 207 | 208 | rand_targets = torch.randint( 209 | 0, num_classes - 1, labels.size(), 210 | dtype=labels.dtype, device=labels.device, 211 | ) 212 | targets = torch.remainder(labels + rand_targets + 1, num_classes) 213 | 214 | adv_inputs = attack(inputs, targets) 215 | adv_labels = model(adv_inputs).argmax(1) 216 | unsuccessful = adv_labels != targets 217 | adv_inputs[unsuccessful] = inputs[unsuccessful] 218 | 219 | return adv_inputs 220 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4.0 2 | git+https://github.com/MadryLab/robustness#egg=robustness 3 | numpy>=1.18.2 4 | torchvision>=0.5.0 5 | tensorboardX>=2.0 6 | advex-uar>=0.0.5.dev0 7 | git+https://github.com/fra31/auto-attack#egg=autoattack 8 | recoloradv==0.0.1 9 | git+https://github.com/numpy/numpy-stubs#egg=numpy-stubs 10 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | setup( 3 | name='perceptual-advex', 4 | packages=[ 5 | 'perceptual_advex', 6 | ], 7 | package_data={'perceptual_advex': ['py.typed']}, 8 | version='0.2.6', 9 | license='MIT', 10 | description='Code for the ICLR 2021 paper "Perceptual Adversarial Robustness: Defense Against Unseen Threat Models"', 11 | author='Cassidy Laidlaw', 12 | author_email='claidlaw@umd.edu', 13 | url='https://github.com/cassidylaidlaw/perceptual-advex', 14 | download_url='https://github.com/cassidylaidlaw/perceptual-advex/archive/TODO.tar.gz', 15 | keywords=['adversarial examples', 'machine learning'], 16 | install_requires=[ 17 | 'torch>=1.4.0', 18 | 'robustness>=1.1.post2', 19 | 'numpy>=1.18.2', 20 | 'torchvision>=0.5.0', 21 | 'PyWavelets>=1.0.0', 22 | 'advex-uar>=0.0.5.dev0', 23 | 'statsmodels==0.11.1', 24 | 'recoloradv==0.0.1', 25 | ], 26 | classifiers=[ 27 | 'Development Status :: 3 - Alpha', 28 | 'Intended Audience :: Science/Research', 29 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 30 | 'Topic :: Scientific/Engineering :: Image Recognition', 31 | 'License :: OSI Approved :: MIT License', 32 | 'Programming Language :: Python :: 3', 33 | 'Programming Language :: Python :: 3.6', 34 | 'Programming Language :: Python :: 3.7', 35 | 'Programming Language :: Python :: 3.8', 36 | 'Programming Language :: Python :: 3.9', 37 | ], 38 | ) 39 | 
--------------------------------------------------------------------------------
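As a quick numeric check of the MarginLoss definition in perceptual_advex/utilities.py: the snippet below is an illustrative sketch, not part of the repository, and assumes the package is installed as described in the README. It confirms that the untargeted loss is max_{k != y} z_k(x) - z_y(x), clamped from above at kappa.

    import torch
    from perceptual_advex.utilities import MarginLoss

    # Two examples with three classes each. In the first, the true class (0)
    # wins by a margin of 2; in the second, class 2 beats the true class by 5.
    logits = torch.tensor([[3.0, 1.0, 0.0],
                           [0.0, 1.0, 5.0]])
    labels = torch.tensor([0, 0])

    print(MarginLoss()(logits, labels))           # tensor([-2., 5.])  (kappa = inf)
    print(MarginLoss(kappa=1.0)(logits, labels))  # tensor([-2., 1.])  (clamped at kappa)

With kappa set to a finite value, as in LagrangePerceptualAttack, the loss stops growing once an example is misclassified by that margin.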
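The attack and model-loading utilities above compose in the same way as the examples in getting_started.ipynb. The sketch below is a minimal, illustrative outline rather than repository code: it assumes the default '~/datasets' path and an untrained CIFAR ResNet-50, and it substitutes a random batch for real CIFAR-10 images, so the outputs are only meaningful as a smoke test.

    import torch

    from perceptual_advex.utilities import get_dataset_model
    from perceptual_advex.perceptual_attacks import LagrangePerceptualAttack

    # Build a CIFAR dataset object and a feature-model wrapper around a
    # ResNet-50 (no checkpoint is passed, so the weights are untrained).
    dataset, model = get_dataset_model(dataset='cifar', arch='resnet50')
    model.eval()
    device = next(model.parameters()).device

    # A dummy batch in [0, 1] with CIFAR-10 shape; the "labels" are just the
    # model's current predictions so the attack has something to move away from.
    inputs = torch.rand(4, 3, 32, 32, device=device)
    with torch.no_grad():
        labels = model(inputs).argmax(1)

    attack = LagrangePerceptualAttack(model, bound=0.5, num_iterations=20)
    adv_inputs = attack(inputs, labels)

    # The attack exposes its LPIPS distance and a threat-model membership check.
    print(attack.lpips_distance(inputs, adv_inputs))
    print(attack.threat_model_contains(inputs, adv_inputs))

For real experiments, the returned robustness dataset object can build the actual CIFAR-10 loaders, and a trained checkpoint (see the pretrained models linked in the README) should be passed via checkpoint_fname.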