├── LICENSE
├── MANIFEST.in
├── README.md
├── adversarial
│   ├── __init__.py
│   ├── classify_images.py
│   ├── examples
│   │   └── demo.py
│   ├── gen_adversarial_images.py
│   ├── gen_tar_index.py
│   ├── gen_transformed_images.py
│   ├── index_patches.py
│   ├── lib
│   │   ├── __init__.py
│   │   ├── adversary.py
│   │   ├── constants.py
│   │   ├── convnet.py
│   │   ├── dataset.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── dataset_classes_folder.py
│   │   │   ├── sub_dataset_folder.py
│   │   │   ├── sub_dataset_tarfolder.py
│   │   │   ├── tar_metadata.py
│   │   │   ├── tarfolder.py
│   │   │   └── transform_dataset.py
│   │   ├── defenses.py
│   │   ├── model.py
│   │   ├── models
│   │   │   └── __init__.py
│   │   ├── opts.py
│   │   ├── path_config.json
│   │   ├── paths.py
│   │   ├── transformations
│   │   │   ├── __init__.py
│   │   │   ├── _tv_bregman.patch
│   │   │   ├── findseam.cpp
│   │   │   ├── findseam.h
│   │   │   ├── quilting.cpp
│   │   │   ├── quilting.h
│   │   │   ├── quilting.py
│   │   │   ├── quilting_fast.py
│   │   │   ├── transformation_helper.py
│   │   │   ├── transforms.py
│   │   │   └── tvm.py
│   │   └── util.py
│   ├── test
│   │   └── images
│   │       ├── sample
│   │       │   ├── lena_quilting.png
│   │       │   └── lena_tvm.png
│   │       ├── train
│   │       │   └── class_name
│   │       │       ├── lena-copy.png
│   │       │       └── lena.png
│   │       └── val
│   │           └── class_name
│   │               ├── lena-copy.png
│   │               └── lena.png
│   └── train_model.py
└── setup.py

--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include setup.py
3 | include MANIFEST.in
4 | 
5 | include adversarial/*.py
6 | include adversarial/lib/*.py
7 | include adversarial/lib/datasets/*.py
8 | include adversarial/lib/models/*.py
9 | include adversarial/lib/path_config.json
10 | exclude adversarial/test/*.py
11 | 
12 | recursive-include adversarial/lib/transformations *.py *.h *.cpp *.patch
13 | recursive-include adversarial/examples *
14 | recursive-include adversarial/test *.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Countering Adversarial Images Using Input Transformations
2 | 
3 | # Overview
4 | This package implements the experiments described in the paper [Countering Adversarial Images Using Input Transformations](https://arxiv.org/pdf/1711.00117.pdf).
5 | It contains implementations for [adversarial attacks](#adversarial_attack), [defenses based on image transformations](#image_transformation), [training](#training), and [testing](#classify) convolutional networks under adversarial attacks using our defenses. We also provide [pre-trained models](#pretrained).
6 | 
7 | If you use this code, please cite our paper:
8 | 
9 | - Chuan Guo, Mayank Rana, Moustapha Cisse, and Laurens van der Maaten. **Countering Adversarial Images using Input Transformations**. arXiv 1711.00117, 2017. [[PDF](https://arxiv.org/pdf/1711.00117.pdf)]
10 | 
11 | ## Adversarial Defenses
12 | The code implements the following four defenses against adversarial images, all of which are based on image transformations:
13 | - Image quilting
14 | - Total variation minimization
15 | - JPEG compression
16 | - Pixel quantization
17 | 
18 | Please refer to the paper for details on these defenses. A detailed description of the original image quilting algorithm can be found [here](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/papers/efros-siggraph01.pdf); a detailed description of our solver for total variation minimization can be found [here](ftp://ftp.math.ucla.edu/pub/camreport/cam08-29.pdf).
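The JPEG and quantization defenses are simple enough to sketch directly. The snippet below is illustrative only and is not the package implementation (the package selects its defenses via `get_defense` in `lib/defenses.py`); the function names and parameters here are our own:

```python
import io

import numpy as np
from PIL import Image


def jpeg_defense(img, quality=75):
    """Round-trip a PIL image through lossy JPEG compression."""
    buffer = io.BytesIO()
    img.save(buffer, format='JPEG', quality=quality)
    buffer.seek(0)
    return Image.open(buffer).convert('RGB')


def quantize_defense(img, depth=8):
    """Reduce each color channel to 2**depth levels."""
    levels = 2 ** depth - 1
    x = np.asarray(img, dtype=np.float32) / 255.0
    x = np.round(x * levels) / levels
    return Image.fromarray(np.uint8(x * 255))
```

The idea in both cases is to destroy the small perturbation an adversary adds while preserving the image content.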
19 | 
20 | ## Adversarial Attacks
21 | 
22 | The code implements the following four approaches to generating adversarial images:
23 | - [Fast gradient sign method (FGSM)](https://arxiv.org/abs/1412.6572)
24 | - [Iterative FGSM](https://arxiv.org/abs/1611.01236)
25 | - [DeepFool](https://arxiv.org/abs/1511.04599)
26 | - [Carlini-Wagner attack](https://arxiv.org/abs/1608.04644)
27 | 
28 | 
29 | # Installation
30 | To use this code, first install Python, [PyTorch](http://pytorch.org), and [Faiss](https://github.com/facebookresearch/faiss) (to perform image quilting). We tested the code using Python 2.7, PyTorch v0.2.0, and scikit-image 0.11; your mileage may vary when using other versions.
31 | 
32 | PyTorch can be installed using the instructions [here](http://pytorch.org/). Faiss is required to run the image quilting algorithm; it is not installed automatically because faiss does not have pip support and because it requires configuring BLAS and LAPACK flags, as described [here](https://github.com/facebookresearch/faiss/blob/master/INSTALL.md). Please install faiss using the instructions given [here](https://github.com/facebookresearch/faiss).
33 | 
34 | The code uses several other external dependencies (for training Inception models, performing Bregman iteration, etc.). These dependencies are automatically downloaded and installed when you install this package via `pip`:
35 | ```bash
36 | # Install from source
37 | cd adversarial_image_defenses
38 | pip install .
39 | 
40 | ```
41 | 
42 | # Usage
43 | 
44 | To import the package in Python:
45 | ```python
46 | import adversarial
47 | ```
48 | 
49 | The functionality implemented in this package is demonstrated in [this example](https://github.com/facebookresearch/adversarial_image_defenses/blob/master/adversarial/examples/demo.py). Run the example via:
50 | ```bash
51 | python adversarial/examples/demo.py
52 | ```
53 | 
54 | 
55 | ## API
56 | The full functionality of the package is exposed via several runnable Python scripts. All these scripts require the user to specify the path to the ImageNet dataset, the path to pre-trained models, and the path to quilted images (once they have been computed) in `lib/path_config.json`. Alternatively, the paths can be passed as input arguments to the scripts.
57 | 
58 | 
59 | ### Generate quilting patches
60 | [`index_patches.py`](adversarial/index_patches.py) creates a faiss index of image patches. This index can be used to perform image quilting.
61 | 
62 | Code example:
63 | ```python
64 | import adversarial
65 | from index_patches import create_faiss_patches, parse_args
66 | 
67 | args = parse_args()
68 | # Update args if needed
69 | args.patch_size = 5
70 | create_faiss_patches(args)
71 | ```
72 | 
73 | Alternatively, run `python index_patches.py`. The following arguments are supported (a sketch of how the resulting index is used at quilting time follows the list):
74 | - `--patch_size` Patch size (square) that will be used in quilting (default: 5).
75 | - `--num_patches` Number of patches to generate (default: 1000000).
76 | - `--pca_dims` PCA dimension for faiss (default: 64).
77 | - `--patches_file` File in which patches are saved.
78 | - `--index_file` File in which the faiss index of patches is saved.
79 | 
80 | 
81 | 
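At quilting time, the patch database and faiss index produced by `index_patches.py` are queried for nearest neighbors (the package's quilting code lives in `lib/transformations/quilting.py`). The following standalone sketch shows such a lookup; the file names are illustrative:

```python
import pickle

import faiss
import numpy as np

# load the flattened patch database and its faiss index
# (produced by index_patches.py; file names are illustrative)
with open('patches_5.pickle', 'rb') as f:
    patches = pickle.load(f)  # 2-D tensor: num_patches x (3 * 5 * 5)
faiss_index = faiss.read_index('index_5.faiss')

# find the nearest database patch for a query patch
query = np.ascontiguousarray(patches[:1].numpy())  # float32, shape 1 x 75
distances, neighbors = faiss_index.search(query, 1)
nearest_patch = patches[int(neighbors[0][0])].view(3, 5, 5)
```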
82 | ### Image transformations
83 | [`gen_transformed_images.py`](adversarial/gen_transformed_images.py) applies an image transformation to (adversarial or non-adversarial) ImageNet images and saves the results to disk. Image transformations such as image quilting are too computationally intensive to be performed on-the-fly during network training, which is why we precompute the transformed images.
84 | 
85 | Code example:
86 | ```python
87 | import adversarial
88 | from gen_transformed_images import generate_transformed_images
89 | from lib import opts
90 | # load default args for transformation functions
91 | args = opts.parse_args(opts.OptType.TRANSFORMATION)
92 | args.operation = "transformation_on_raw"
93 | args.defenses = ["tvm"]
94 | args.partition_size = 1  # Number of samples to generate
95 | 
96 | generate_transformed_images(args)
97 | ```
98 | 
99 | Alternatively, run `python gen_transformed_images.py`. In addition to the [common arguments](#common_args) and [adversarial arguments](#adversarial_args), the following arguments are supported:
100 | - `--operation` Operation to run. Supported operations are:
101 |   `transformation_on_raw`: Apply transformations on raw images.
102 |   `transformation_on_adv`: Apply transformations on adversarial images.
103 |   `cat_data`: Concatenate output from distributed `transformation_on_adv`.
104 | - `--data_type` Data type (`train` or `raw`) for `transformation_on_raw` (default: `train`).
105 | - `--out_dir` Directory path for output of `cat_data`.
106 | - `--partition_dir` Directory path to output transformed data.
107 | - `--data_batches` Number of data batches to generate. Used for random crops for ensembling.
108 | - `--partition` Distributed data partition (default: 0).
109 | - `--partition_size` The size of each data partition.
110 |   For `transformation_on_raw`, `partition_size` is the number of classes per process.
111 |   For `transformation_on_adv`, `partition_size` is the number of images per process.
112 | - `--n_threads` Number of threads for `transformation_on_raw`.
113 | 
114 | 
115 | ### Generate TAR data index
116 | Many file systems perform poorly when dealing with millions of small files (such as images). Therefore, we generally TAR our image datasets (obtained by running `generate_transformed_images`). Next, we use
117 | [`gen_tar_index.py`](adversarial/gen_tar_index.py) to generate a file index for the TAR file. The file index facilitates fast, random-access reading of the TAR file; it is much faster and requires less memory than untarring the data or using the `tarfile` package.
118 | 
119 | Code example:
120 | ```python
121 | import adversarial
122 | from gen_tar_index import generate_tar_index, parse_args
123 | 
124 | args = parse_args()
125 | generate_tar_index(args)
126 | ```
127 | 
128 | Alternatively, run `python gen_tar_index.py`. The following arguments are supported (a sketch of the indexing idea follows the list):
129 | - `--tar_path` Path for TAR file or directory.
130 | - `--index_root` Directory in which to store the TAR index file.
131 | - `--path_prefix` Prefix to identify TAR member names to be indexed.
132 | 
133 | 
134 | 
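The principle behind the index is simple: record, for every TAR member, the byte offset and size of its data, so that a member can later be read with a single `seek` on the raw archive. The sketch below demonstrates the idea with the standard `tarfile` module; the index format actually used by `lib/datasets/tarfolder.py` may differ:

```python
import tarfile


def build_tar_index(tar_path):
    """Scan the archive once, mapping member name -> (data offset, size)."""
    index = {}
    with tarfile.open(tar_path, 'r:') as tar:
        for member in tar:
            if member.isfile():
                index[member.name] = (member.offset_data, member.size)
    return index


def read_member(tar_path, index, name):
    """Random-access read of one member; no re-scan, no extraction."""
    offset, size = index[name]
    with open(tar_path, 'rb') as f:
        f.seek(offset)
        return f.read(size)
```

The index itself is small enough to keep in memory (or pickle to disk), which is what makes random access over millions of images practical.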
135 | ### Adversarial Attacks
136 | [`gen_adversarial_images.py`](adversarial/gen_adversarial_images.py) implements the generation of adversarial images for the ImageNet dataset.
137 | 
138 | Code example:
139 | ```python
140 | import adversarial
141 | from gen_adversarial_images import generate_adversarial_images
142 | from lib import opts
143 | # load default args for adversary functions
144 | args = opts.parse_args(opts.OptType.ADVERSARIAL)
145 | args.model = "resnet50"
146 | args.adversary_to_generate = "fgs"
147 | args.partition_size = 1  # Number of samples to generate
148 | args.data_type = "val"  # input dataset type
149 | args.normalize = True  # apply normalization on input data
150 | args.attack_type = "blackbox"  # For attack, use transformed models
151 | args.pretrained = True  # Use pretrained model from model-zoo
152 | 
153 | generate_adversarial_images(args)
154 | ```
155 | 
156 | Alternatively, run `python gen_adversarial_images.py`. For a list of the supported arguments, see [common arguments](#common_args) and [adversarial arguments](#adversarial_args).
157 | 
158 | 
159 | 
160 | ### Training
161 | [`train_model.py`](adversarial/train_model.py) implements the training of convolutional networks on (transformed or non-transformed) ImageNet images.
162 | 
163 | Code example:
164 | ```python
165 | import adversarial
166 | from train_model import train_model
167 | from lib import opts
168 | # load default args
169 | args = opts.parse_args(opts.OptType.TRAIN)
170 | args.defenses = None  # defense=<(raw, tvm, quilting, jpeg, quantization)>
171 | args.model = "resnet50"
172 | args.normalize = True  # apply normalization on input data
173 | 
174 | train_model(args)
175 | ```
176 | 
177 | Alternatively, run `python train_model.py`. In addition to the [common arguments](#common_args), the following arguments are supported:
178 | - `--resume` Resume training from checkpoint (if available).
179 | - `--lr` Initial learning rate, defined in `constants.py` (0.045 for Inception-v4, 0.1 for other models).
180 | - `--lr_decay` Exponential learning rate decay, defined in `constants.py` (0.94 for Inception-v4, 0.1 for other models).
181 | - `--lr_decay_stepsize` Number of epochs after which the learning rate is decayed, defined in `constants.py`.
182 | - `--momentum` Momentum (default: 0.9).
183 | - `--weight_decay` Amount of weight decay (default: 1e-4).
184 | - `--start_epoch` Index of first epoch (default: 0).
185 | - `--end_epoch` Index of last epoch (default: 90).
186 | - `--preprocessed_epoch_data` Augmented and transformed data for each epoch is pre-generated (default: `False`).
187 | 
188 | 
189 | ### Testing
190 | [`classify_images.py`](adversarial/classify_images.py) implements the testing of a trained convolutional network on a dataset of (adversarial or non-adversarial, transformed or non-transformed) ImageNet images.
191 | 
192 | Code example:
193 | ```python
194 | import adversarial
195 | from classify_images import classify_images
196 | from lib import opts
197 | # load default args
198 | args = opts.parse_args(opts.OptType.CLASSIFY)
199 | 
200 | classify_images(args)
201 | ```
202 | 
203 | Alternatively, run `python classify_images.py`. In addition to the [common arguments](#common_args), the following arguments are supported (a sketch of the ensembling step follows the list):
204 | - `--ensemble` Ensembling type: `None`, `avg`, or `max` (default: `None`).
205 | - `--ncrops` List of the number of crops to use per defense for ensembling (default: `None`).
206 | - `--crop_frac` List of crop fractions, one per defense, to use for ensembling (default: `None`).
207 | - `--crop_type` List of crop types (`center`, `random`, or `sliding`; `sliding` is hard-set to 9 crops), one per defense, to use for ensembling (default: `None`).
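For reference, the ensembling step in `classify_images.py` reduces per-crop class probabilities of shape `(ncrops, num_images, num_classes)` as follows; this standalone sketch mirrors that logic:

```python
import torch


def ensemble_probs(probs, ensemble=None):
    """Combine class probabilities of shape (ncrops, num_images, num_classes)."""
    if ensemble == 'avg':
        return torch.mean(probs, dim=0)    # average over crops
    elif ensemble == 'max':
        return torch.max(probs, dim=0)[0]  # element-wise max over crops
    assert probs.size(0) == 1              # no ensembling: expect a single crop
    return probs[0]


# example: 10 crops of 4 images, 1000 classes
predictions = ensemble_probs(torch.rand(10, 4, 1000), 'avg').max(dim=1)[1]
```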
208 | 
209 | 
210 | ### Pre-trained models
211 | We provide pre-trained models that were trained on ImageNet images processed using total variation minimization (TVM) or image quilting. They can be downloaded from the following links (set the `models_root` argument to the path that contains these model files):
212 | 
213 | - [ResNet-50_model trained on quilted images]
214 | - [ResNet-50_model trained on TVM images]
215 | - [ResNet-101_model trained on quilted images]
216 | - [ResNet-101_model trained on TVM images]
217 | - [DenseNet-169_model trained on quilted images]
218 | - [DenseNet-169_model trained on TVM images]
219 | - [Inception-v4_model trained on quilted images]
220 | - [Inception-v4_model trained on TVM images]
221 | 
222 | 
223 | 
224 | ### Common arguments
225 | 
226 | The following arguments are used by multiple scripts, including
227 | `generate_transformed_images`, `train_model`, and `classify_images`:
228 | 
229 | #### Paths
230 | - `--data_root` Main data directory to save and read data.
231 | - `--models_root` Directory path to store/load models.
232 | - `--tar_dir` Directory path for transformed images (train/val) stored in TAR files.
233 | - `--tar_index_dir` Directory path for index files for transformed images in TAR files.
234 | - `--quilting_index_root` Directory path for quilting index files.
235 | - `--quilting_patch_root` Directory path for quilting patch files.
236 | 
237 | #### Train/Classifier params
238 | - `--model` Model to use (default: `resnet50`).
239 | - `--device` Device to use: cpu or gpu (default: `gpu`).
240 | - `--normalize` Normalize image data.
241 | - `--batchsize` Batch size for training and testing (default: 256).
242 | - `--preprocessed_data` Transformations/defenses are already applied to the saved images (default: `False`).
243 | - `--defenses` List of defenses to apply: `raw` (no defense), `tvm`, `quilting`, `jpeg`, `quantization` (default: `None`).
244 | - `--pretrained` Use pretrained model from the PyTorch model zoo (default: `False`).
245 | 
246 | #### Transformation params
247 | - `--tvm_weight` Regularization weight for total variation minimization (TVM).
248 | - `--pixel_drop_rate` Pixel drop rate to use in TVM.
249 | - `--tvm_method` Reconstruction method to use in TVM (default: `bregman`).
250 | - `--quilting_patch_size` Patch size to use in image quilting.
251 | - `--quilting_neighbors` Number of nearest patches to sample from in image quilting (default: 1).
252 | - `--quantize_depth` Bit depth for the quantization defense (default: 8).
253 | 
254 | 
255 | 
256 | #### Adversarial arguments
257 | The following arguments are used when generating adversarial images with `gen_transformed_images.py` (a sketch of the basic `fgs` update follows the list):
258 | 
259 | - `--n_samples` Maximum number of samples to test on.
260 | - `--attack_type` Attack type: `None` (no attack), `blackbox`, `whitebox` (default: `None`).
261 | - `--adversary` Adversary to use: `fgs`, `ifgs`, `cwl2`, `deepfool` (default: `None`).
262 | - `--adversary_model` Model to use for generating adversarial images (default: `resnet50`).
263 | - `--learning_rate` Learning rate for iterative adversarial attacks (default: read from constants).
264 | - `--adv_strength` Adversarial strength for non-iterative adversarial attacks (default: read from constants).
265 | - `--adversarial_root` Path containing adversarial images.
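For intuition about the `fgs` adversary and the adversarial strength parameter: the fast gradient sign attack takes a single step `x_adv = x + epsilon * sign(grad_x loss(x, y))`. The package's implementation is `fgs` in `lib/adversary.py`; the standalone sketch below expresses the same update against a current PyTorch API (rather than the v0.2 `Variable` API this repository uses), with `epsilon` playing the role of the adversarial strength:

```python
import torch
import torch.nn.functional as F


def fgs_attack(model, x, y, epsilon):
    """One-step fast gradient sign: x_adv = x + epsilon * sign(grad_x loss)."""
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    return (x + epsilon * x.grad.sign()).detach()
```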
266 | -------------------------------------------------------------------------------- /adversarial/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import os 14 | import sys 15 | 16 | # add adversarial dir to python path to import sub packages 17 | sys.path.append(os.path.dirname(__file__)) 18 | -------------------------------------------------------------------------------- /adversarial/classify_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import torch 14 | import torch.nn.parallel 15 | import torch.optim 16 | import torch.utils.data 17 | 18 | from lib.convnet import get_prob 19 | from lib.dataset import load_dataset, get_data_loader 20 | from lib.defenses import get_defense 21 | import lib.opts as opts 22 | from lib.model import get_model 23 | import lib.transformations.transforms as transforms 24 | from lib.transformations.transformation_helper import update_dataset_transformation 25 | import lib.constants as constants 26 | from lib.constants import DefenseType 27 | from lib.util import accuracy 28 | 29 | ENSEMBLE_TYPE = ['max', 'avg'] 30 | 31 | 32 | # Test and ensemble image crops 33 | def _eval_crops(args, dataset, model, defense, crop, ncrops, crop_type): 34 | 35 | # assertions 36 | assert dataset is not None, "dataset expected" 37 | assert model is not None, "model expected" 38 | assert crop_type is None or isinstance(crop_type, str) 39 | if crop is not None: 40 | assert callable(crop) 41 | assert type(ncrops) == int 42 | 43 | probs = None 44 | 45 | for crop_num in range(ncrops): 46 | 47 | # For sliding crop update crop function in dataset 48 | if crop_type == 'sliding': 49 | crop.update_sliding_position(crop_num) 50 | dataset = update_dataset_transformation( 51 | dataset, args, 'valid', defense, crop) 52 | 53 | # set up dataloader: 54 | print('| set up data loader...') 55 | data_loader = get_data_loader( 56 | dataset, 57 | batchsize=args.batchsize, 58 | device=args.device, 59 | shuffle=False, 60 | ) 61 | 62 | # test 63 | prob, targets = get_prob(model, data_loader) 64 | # collect prob for each run 65 | if probs is None: 66 | probs = torch.zeros(ncrops, len(dataset), prob.size(1)) 67 | probs[crop_num, :, :] = prob 68 | 69 | # measure and print accuracy 70 | _, _prob = prob.topk(5, dim=1) 71 | _correct = _prob.eq(targets.view(-1, 1).expand_as(_prob)) 72 | _top1 = _correct.select(1, 0).float().mean() * 100 73 | defense_name = "no defense" if defense is None else defense.get_name() 74 | print('| crop[%d]: top1 acc for %s = %f' % (crop_num, defense_name, _top1)) 75 | 76 | data_loader = None 77 | 78 | return probs, targets 79 | 80 | 81 | def classify_images(args): 82 | 83 | # assertions 84 | assert args.ensemble is 
None or args.ensemble in ENSEMBLE_TYPE, \ 85 | "{} not a supported type. Only supported ensembling are {}".format( 86 | args.ensemble, ENSEMBLE_TYPE) 87 | if not args.ensemble: 88 | assert args.ncrops is None or ( 89 | len(args.ncrops) == 1 and args.ncrops[0] == 1) 90 | if args.defenses is not None: 91 | for d in args.defenses: 92 | assert DefenseType.has_value(d), \ 93 | "\"{}\" defense not defined".format(d) 94 | # crops expected for each defense 95 | assert (args.ncrops is None or 96 | len(args.ncrops) == len(args.defenses)), ( 97 | "Number of crops for each defense is expected") 98 | assert (args.crop_type is None or 99 | len(args.crop_type) == len(args.defenses)), ( 100 | "crop_type for each defense is expected") 101 | # assert (len(args.crop_frac) == len(args.defenses)), ( 102 | # "crop_frac for each defense is expected") 103 | elif args.ncrops is not None: 104 | # no crop ensembling when defense is None 105 | assert len(args.ncrops) == 1 106 | assert args.crop_frac is not None and len(args.crop_frac) == 1, \ 107 | "Only one crop_frac is expected as there is no defense" 108 | assert args.crop_type is not None and len(args.crop_type) == 1, \ 109 | "Only one crop_type is expected as there is no defense" 110 | 111 | if args.defenses is None or len(args.defenses) == 0: 112 | defenses = [None] 113 | else: 114 | defenses = args.defenses 115 | 116 | all_defense_probs = None 117 | for idx, defense_name in enumerate(defenses): 118 | # initialize dataset 119 | defense = get_defense(defense_name, args) 120 | # Read preset params for adversary based on args 121 | adv_params = constants.get_adv_params(args, idx) 122 | print("| adv_params: ", adv_params) 123 | # setup crop 124 | ncrops = 1 125 | crop_type = None 126 | crop_frac = 1.0 127 | if args.ncrops: 128 | crop_type = args.crop_type[idx] 129 | crop_frac = args.crop_frac[idx] 130 | if crop_type == 'sliding': 131 | ncrops = 9 132 | else: 133 | ncrops = args.ncrops[idx] 134 | # Init custom crop function 135 | crop = transforms.Crop(crop_type, crop_frac) 136 | # initialize dataset 137 | dataset = load_dataset(args, 'valid', defense, adv_params, crop) 138 | # load model 139 | model, _, _ = get_model(args, load_checkpoint=True, defense_name=defense_name) 140 | 141 | # get crop probabilities for crops for current defense 142 | probs, targets = _eval_crops(args, dataset, model, defense, 143 | crop, ncrops, crop_type) 144 | 145 | if all_defense_probs is None: 146 | all_defense_probs = torch.zeros(len(defenses), 147 | len(dataset), 148 | probs.size(2)) 149 | # Ensemble crop probabilities 150 | if args.ensemble == 'max': 151 | probs = torch.max(probs, dim=0)[0] 152 | elif args.ensemble == 'avg': # for average ensembling 153 | probs = torch.mean(probs, dim=0) 154 | else: # for no ensembling 155 | assert all_defense_probs.size(0) == 1 156 | probs = probs[0] 157 | all_defense_probs[idx, :, :] = probs 158 | 159 | # free memory 160 | dataset = None 161 | model = None 162 | 163 | # Ensemble defense probabilities 164 | if args.ensemble == 'max': 165 | all_defense_probs = torch.max(all_defense_probs, dim=0)[0] 166 | elif args.ensemble == 'avg': # for average ensembling 167 | all_defense_probs = torch.mean(all_defense_probs, dim=0) 168 | else: # for no ensembling 169 | assert all_defense_probs.size(0) == 1 170 | all_defense_probs = all_defense_probs[0] 171 | # Calculate top1 and top5 accuracy 172 | prec1, prec5 = accuracy(all_defense_probs, targets, topk=(1, 5)) 173 | print('=' * 60) 174 | print('| Results for model={}, attack={}, ensemble_type={} '.format( 175 
| args.model, args.adversary, args.ensemble)) 176 | prec1 = prec1[0] 177 | prec5 = prec5[0] 178 | print('| classification accuracy @1: %2.5f' % (prec1)) 179 | print('| classification accuracy @5: %2.5f' % (prec5)) 180 | print('| classification error @1: %2.5f' % (100. - prec1)) 181 | print('| classification error @5: %2.5f' % (100. - prec5)) 182 | print('| done.') 183 | 184 | 185 | # run: 186 | if __name__ == '__main__': 187 | # parse input arguments 188 | args = opts.parse_args(opts.OptType.CLASSIFY) 189 | classify_images(args) 190 | -------------------------------------------------------------------------------- /adversarial/examples/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import pkgutil 14 | if pkgutil.find_loader("adversarial") is not None: 15 | # If module is installed using "pip install ." 16 | from adversarial.gen_adversarial_images import generate_adversarial_images 17 | from adversarial.gen_transformed_images import generate_transformed_images 18 | from adversarial.classify_images import classify_images 19 | from adversarial.train_model import train_model 20 | from adversarial.index_patches import create_faiss_patches 21 | else: 22 | from gen_adversarial_images import generate_adversarial_images 23 | from gen_transformed_images import generate_transformed_images 24 | from classify_images import classify_images 25 | from train_model import train_model 26 | from index_patches import create_faiss_patches 27 | 28 | from lib import opts 29 | import os 30 | 31 | 32 | # Generate adversarial images 33 | def _generate_adversarial_images(): 34 | print("=" * 30 + "GENERATING ADVERSARIAL IMAGES" + "=" * 30) 35 | # load default args for adversary functions 36 | args = opts.parse_args(opts.OptType.ADVERSARIAL) 37 | # edit default args 38 | args.operation = "generate_adversarial" 39 | args.model = "resnet50" 40 | args.adversary_to_generate = "fgs" 41 | args.defenses = None 42 | args.partition_size = 1 # Number of samples to generate 43 | args.n_samples = 10000 # Total samples in input data 44 | args.data_type = "val" # input dataset type 45 | args.normalize = True # apply normalization on input data 46 | args.attack_type = "blackbox" # For attack, use transformed models 47 | args.pretrained = True # Use pretrained model from model-zoo 48 | 49 | generate_adversarial_images(args) 50 | 51 | 52 | # Apply transformations 53 | def _generate_transformed_images(): 54 | print("=" * 30 + "GENERATING IMAGE TRANSFORMATIONS" + "=" * 30) 55 | # load default args for transformation functions 56 | args = opts.parse_args(opts.OptType.TRANSFORMATION) 57 | # edit default args 58 | # Apply transformations on raw images, 59 | # for adversarial images use "transformation_on_adv" 60 | args.operation = "transformation_on_raw" 61 | args.adversary = None # update to adversary for operation "transformation_on_adv" 62 | # For quilting expects patches data at QUILTING_ROOT (defined in 63 | # path_config.json or passed in args) 64 | args.defenses = ["quilting"] # <"tvm", "quilting", "jpeg", quantize> 65 | args.partition_size = 1 # Number of samples to generate 66 | args.data_type = 
"val" # input dataset type 67 | 68 | # args.n_samples = 50000 # Total samples in input data when reading from .pth files 69 | # args.attack_type = "blackbox" # Used for file paths for "transformation_on_adv" 70 | 71 | generate_transformed_images(args) 72 | print('Transformed images saved at {}'.format( 73 | os.path.join(args.partition_dir, args.defenses[0]))) 74 | 75 | 76 | def _classify_images(): 77 | print("=" * 30 + "CLASSIFYING" + "=" * 30) 78 | # classify images without any attack or defense 79 | args = opts.parse_args(opts.OptType.CLASSIFY) 80 | 81 | # edit default args 82 | args.n_samples = 1 # Total samples in input data 83 | args.normalize = True # apply normalization on input data 84 | args.pretrained = True # Use pretrained model from model-zoo 85 | # To classify transformed images using transformed model update defenses to 86 | # 87 | args.defenses = None 88 | # To classify attack images update attack_type to 89 | args.attack_type = None 90 | # To classify attack images update adversary to 91 | args.adversary = None 92 | 93 | classify_images(args) 94 | 95 | 96 | def _train_model(): 97 | print("=" * 30 + "TRAINING" + "=" * 30) 98 | args = opts.parse_args(opts.OptType.TRAIN) 99 | 100 | # edit default args 101 | # To classify transformed images using transformed model update defenses to 102 | # 103 | args.defenses = None # defense=<(raw, tvm, quilting, jpeg, quantization)> 104 | args.model = "resnet50" 105 | args.normalize = True # apply normalization on input data 106 | args.resume = True # Resume training from checkpoint if available 107 | args.end_epoch = 10 108 | 109 | train_model(args) 110 | 111 | 112 | # Generate patches and index for quilting 113 | def _index_patches(): 114 | print("=" * 30 + "GENERATING QUILTING PATCHES AND INDICES" + "=" * 30) 115 | args = opts.parse_args(opts.OptType.QUILTING_PATCHES) 116 | args.num_patches = 10000 117 | args.quilting_patch_size = 5 118 | args.index_file = str(os.path.join(args.quilting_index_root, 119 | "index_{}.faiss".format(args.quilting_patch_size))) 120 | args.patches_file = str(os.path.join(args.quilting_patch_root, 121 | "patches_{}.pickle".format(args.quilting_patch_size))) 122 | create_faiss_patches(args) 123 | 124 | 125 | if __name__ == '__main__': 126 | _generate_adversarial_images() 127 | _index_patches() 128 | _generate_transformed_images() 129 | _train_model() 130 | _classify_images() 131 | -------------------------------------------------------------------------------- /adversarial/gen_adversarial_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import progressbar 14 | import torch 15 | import lib.opts as opts 16 | from lib.dataset import load_dataset, get_data_loader 17 | import lib.adversary as adversary 18 | from lib.model import get_model 19 | import lib.constants as constants 20 | from lib.constants import AdversaryType 21 | from lib.paths import get_adversarial_file_path 22 | from lib.transformations.transforms import Unnormalize, Normalize 23 | import os 24 | from enum import Enum 25 | 26 | 27 | class OperationType(Enum): 28 | GENERATE_ADVERSARIAL = 'generate_adversarial' 29 | CONCAT_ADVERSARIAL = 'concat_adversarial' 30 | COMPUTE_STATS = 'compute_adversarial_stats' 31 | 32 | @classmethod 33 | def has_value(cls, value): 34 | return (any(value == item.value for item in cls)) 35 | 36 | def __str__(self): 37 | return str(self.value) 38 | 39 | 40 | def _get_data_indices(args): 41 | assert 'partition' in args, \ 42 | 'partition argumenet is expected but not present in args' 43 | assert 'partition_size' in args, \ 44 | 'partition_size argumenet is expected but not present in args' 45 | 46 | data_indices = {} 47 | data_indices['start_idx'] = args.partition * args.partition_size 48 | data_indices['end_idx'] = (args.partition + 1) * args.partition_size 49 | return data_indices 50 | 51 | 52 | # Concat adversarial data generated from batches 53 | def concat_adversarial(args): 54 | assert not args.partition_size == 0, \ 55 | "partition_size can't be zero" 56 | assert 'learning_rate' in args, \ 57 | "adv_params are not provided" 58 | assert len(args.learning_rate) == 1, \ 59 | "adv_params are not provided" 60 | 61 | defense_name = None if not args.defenses else args.defenses[0] 62 | adv_params = { 63 | 'learning_rate': args.learning_rate[0], 64 | 'adv_strength': None 65 | } 66 | 67 | end_idx = args.n_samples 68 | nfiles = end_idx // args.partition_size 69 | for i in range(nfiles): 70 | start_idx = (i * args.partition_size) + 1 71 | partition_end = (i + 1) * args.partition_size 72 | partition_file = get_adversarial_file_path( 73 | args, args.adversarial_root, defense_name, adv_params, partition_end, 74 | start_idx, with_defense=False) 75 | 76 | assert os.path.isfile(partition_file), \ 77 | "No file found at " + partition_file 78 | print('| Reading file ' + partition_file) 79 | result = torch.load(partition_file) 80 | inputs = result['all_inputs'] 81 | outputs = result['all_outputs'] 82 | targets = result['all_targets'] 83 | status = result['status'] 84 | targets = torch.LongTensor(targets) 85 | if i == 0: 86 | all_inputs = inputs 87 | all_outputs = outputs 88 | all_targets = targets 89 | all_status = status 90 | else: 91 | all_inputs = torch.cat((all_inputs, inputs), 0) 92 | all_outputs = torch.cat((all_outputs, outputs), 0) 93 | all_status = torch.cat((all_status, status), 0) 94 | all_targets = torch.cat((all_targets, targets), 0) 95 | # print(all_inputs.size()) 96 | 97 | out_file = get_adversarial_file_path(args, args.adversarial_root, 98 | defense_name, adv_params, 99 | nfiles * args.partition_size, 100 | args.partition + 1, 101 | with_defense=False) 102 | 103 | if not os.path.isdir(args.adversarial_root): 104 | os.mkdir(args.adversarial_root) 105 | print('| Writing concatenated adversarial data to ' + out_file) 106 | torch.save({'status': all_status, 'all_inputs': all_inputs, 107 | 'all_outputs': all_outputs, 'all_targets': all_targets}, 108 | 
out_file) 109 | 110 | 111 | def compute_stats(args): 112 | assert not args.partition_size == 0, \ 113 | "partition_size can't be zero" 114 | assert 'learning_rate' in args and 'adv_strength' in args, \ 115 | "adv_params are not provided" 116 | 117 | defense_name = None if not args.defenses else args.defenses[0] 118 | adv_params = constants.get_adv_params(args) 119 | print('| adv_params:', adv_params) 120 | start_idx = 0 121 | end_idx = args.n_samples 122 | in_file = get_adversarial_file_path( 123 | args, args.adversarial_root, defense_name, adv_params, end_idx, 124 | start_idx, with_defense=True) 125 | assert os.path.isfile(in_file), \ 126 | "No file found at " + in_file 127 | print('| Reading file ' + in_file) 128 | result = torch.load(in_file) 129 | all_inputs = result['all_inputs'] 130 | all_outputs = result['all_outputs'] 131 | 132 | normalize = Normalize(args.data_params['MEAN_STD']['MEAN'], 133 | args.data_params['MEAN_STD']['STD']) 134 | all_inputs = normalize(all_inputs) 135 | all_outputs = normalize(all_outputs) 136 | rb, _ssim, sc = adversary.compute_stats( 137 | all_inputs, all_outputs, result['status']) 138 | print('average robustness = ' + str(rb)) 139 | print('success rate = ' + str(sc)) 140 | 141 | 142 | def generate_adversarial_images(args): 143 | # assertions 144 | assert args.adversary_to_generate is not None, \ 145 | "adversary_to_generate can't be None" 146 | assert AdversaryType.has_value(args.adversary_to_generate), \ 147 | "\"{}\" adversary_to_generate not defined".format(args.adversary_to_generate) 148 | 149 | defense_name = None if not args.defenses else args.defenses[0] 150 | data_indices = _get_data_indices(args) 151 | data_type = args.data_type if args.data_type == "train" else "valid" 152 | dataset = load_dataset(args, data_type, None, data_indices=data_indices) 153 | data_loader = get_data_loader( 154 | dataset, 155 | batchsize=args.batchsize, 156 | device=args.device, 157 | shuffle=False) 158 | 159 | model, _, _ = get_model(args, load_checkpoint=True, defense_name=defense_name) 160 | 161 | adv_params = constants.get_adv_params(args) 162 | print('| adv_params:', adv_params) 163 | status = None 164 | all_inputs = None 165 | all_outputs = None 166 | all_targets = None 167 | bar = progressbar.ProgressBar(len(data_loader)) 168 | bar.start() 169 | for batch_num, (imgs, targets) in enumerate(data_loader): 170 | if args.adversary_to_generate == str(AdversaryType.DEEPFOOL): 171 | assert adv_params['learning_rate'] is not None 172 | s, r = adversary.deepfool( 173 | model, imgs, targets, args.data_params['NUM_CLASSES'], 174 | train_mode=(args.data_type == 'train'), max_iter=args.max_adv_iter, 175 | step_size=adv_params['learning_rate'], batch_size=args.batchsize, 176 | labels=dataset.get_classes()) 177 | elif args.adversary_to_generate == str(AdversaryType.FGS): 178 | s, r = adversary.fgs( 179 | model, imgs, targets, train_mode=(args.data_type == 'train'), 180 | mode=args.fgs_mode) 181 | elif args.adversary_to_generate == str(AdversaryType.IFGS): 182 | assert adv_params['learning_rate'] is not None 183 | s, r = adversary.ifgs( 184 | model, imgs, targets, 185 | train_mode=(args.data_type == 'train'), max_iter=args.max_adv_iter, 186 | step_size=adv_params['learning_rate'], mode=args.fgs_mode) 187 | elif args.adversary_to_generate == str(AdversaryType.CWL2): 188 | assert args.adv_strength is not None and len(args.adv_strength) == 1 189 | if len(args.crop_frac) == 1: 190 | crop_frac = args.crop_frac[0] 191 | else: 192 | crop_frac = 1.0 193 | s, r = adversary.cw( 194 | 
model, imgs, targets, args.adv_strength[0], 'l2', 195 | tv_weight=args.tvm_weight, 196 | train_mode=(args.data_type == 'train'), max_iter=args.max_adv_iter, 197 | drop_rate=args.pixel_drop_rate, crop_frac=crop_frac, 198 | kappa=args.margin) 199 | elif args.adversary_to_generate == str(AdversaryType.CWLINF): 200 | assert args.adv_strength is not None and len(args.adv_strength) == 1 201 | s, r = adversary.cw( 202 | model, imgs, targets, args.adv_strength[0], 'linf', 203 | bound=args.adv_bound, 204 | tv_weight=args.tvm_weight, 205 | train_mode=(args.data_type == 'train'), max_iter=args.max_adv_iter, 206 | drop_rate=args.pixel_drop_rate, crop_frac=args.crop_frac, 207 | kappa=args.margin) 208 | 209 | if status is None: 210 | status = s.clone() 211 | all_inputs = imgs.clone() 212 | all_outputs = imgs + r 213 | all_targets = targets.clone() 214 | else: 215 | status = torch.cat((status, s), 0) 216 | all_inputs = torch.cat((all_inputs, imgs), 0) 217 | all_outputs = torch.cat((all_outputs, imgs + r), 0) 218 | all_targets = torch.cat((all_targets, targets), 0) 219 | bar.update(batch_num) 220 | 221 | print("| computing adversarial stats...") 222 | if args.compute_stats: 223 | rb, ssim, sc = adversary.compute_stats(all_inputs, all_outputs, status) 224 | print('| average robustness = ' + str(rb)) 225 | print('| average SSIM = ' + str(ssim)) 226 | print('| success rate = ' + str(sc)) 227 | 228 | # Unnormalize before saving 229 | unnormalize = Unnormalize(args.data_params['MEAN_STD']['MEAN'], 230 | args.data_params['MEAN_STD']['STD']) 231 | all_inputs = unnormalize(all_inputs) 232 | all_outputs = unnormalize(all_outputs) 233 | # save output 234 | output_file = get_adversarial_file_path( 235 | args, args.adversarial_root, defense_name, adv_params, 236 | data_indices['end_idx'], start_idx=data_indices['start_idx'], 237 | with_defense=False) 238 | print("| Saving adversarial data at " + output_file) 239 | if not os.path.isdir(args.adversarial_root): 240 | os.makedirs(args.adversarial_root) 241 | torch.save({'status': status, 'all_inputs': all_inputs, 242 | 'all_outputs': all_outputs, 'all_targets': all_targets}, 243 | output_file) 244 | 245 | 246 | def main(): 247 | # parse input arguments: 248 | args = opts.parse_args(opts.OptType.ADVERSARIAL) 249 | 250 | # Only runs one method at a time 251 | assert args.operation is not None, \ 252 | "operation to run can't be None" 253 | assert OperationType.has_value(args.operation), \ 254 | "\"{}\" operation not defined".format(args.operation) 255 | if args.attack_type == str(constants.AttackType.WHITEBOX): 256 | assert args.defenses is not None, \ 257 | "For whitebox attacks, atleast one defense is required" 258 | elif args.defenses is not None: 259 | print("Warning: Defenses will be unused for non whitebox attacks") 260 | 261 | if args.operation == str(OperationType.GENERATE_ADVERSARIAL): 262 | generate_adversarial_images(args) 263 | elif args.operation == str(OperationType.CONCAT_ADVERSARIAL): 264 | concat_adversarial(args) 265 | elif args.operation == str(OperationType.COMPUTE_STATS): 266 | compute_stats(args) 267 | 268 | 269 | # run: 270 | if __name__ == '__main__': 271 | main() 272 | -------------------------------------------------------------------------------- /adversarial/gen_tar_index.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | import argparse 13 | from lib.datasets.tarfolder import gen_tar_index 14 | import sys 15 | import os 16 | 17 | 18 | def parse_args(args): 19 | parser = argparse.ArgumentParser(description='Generate tar indices') 20 | parser.add_argument('--tar_path', 21 | default=None, 22 | type=str, metavar='N', 23 | help='Path for tar file or directory') 24 | parser.add_argument('--index_root', 25 | default=None, 26 | type=str, metavar='N', 27 | help='Directory path to store tar index object') 28 | parser.add_argument('--path_prefix', 29 | default='', 30 | type=str, metavar='N', 31 | help='prefix in member name') 32 | 33 | args = parser.parse_args(args) 34 | return args 35 | 36 | 37 | def generate_tar_index(args): 38 | assert args.tar_path is not None 39 | assert args.index_root is not None 40 | if not os.path.isdir(args.index_root): 41 | os.mkdir(args.index_root) 42 | 43 | gen_tar_index(args.tar_path, args.index_root, args.path_prefix) 44 | 45 | 46 | if __name__ == '__main__': 47 | args = parse_args(sys.argv[1:]) 48 | generate_tar_index(args) 49 | -------------------------------------------------------------------------------- /adversarial/gen_transformed_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | from lib.dataset import load_dataset 14 | from lib.defenses import get_defense 15 | import torch 16 | import torchvision.transforms as trans 17 | import os 18 | import lib.paths as paths 19 | import lib.opts as opts 20 | from enum import Enum 21 | import lib.constants as constants 22 | from multiprocessing.pool import ThreadPool 23 | from torchvision.transforms import ToPILImage 24 | 25 | 26 | class OperationType(Enum): 27 | TRANSFORM_ADVERSARIAL = 'transformation_on_adv' 28 | TRANSFORM_RAW = 'transformation_on_raw' 29 | CAT_DATA = 'concatenate_data' 30 | SAVE_SAMPLES = 'save_samples' 31 | 32 | @classmethod 33 | def has_value(cls, value): 34 | return (any(value == item.value for item in cls)) 35 | 36 | def __str__(self): 37 | return str(self.value) 38 | 39 | 40 | def _get_start_end_index(args): 41 | assert 'partition' in args, \ 42 | 'partition argumenet is expected but not present in args' 43 | assert 'partition_size' in args, \ 44 | 'partition_size argumenet is expected but not present in args' 45 | 46 | start_idx = args.partition * args.partition_size 47 | end_idx = (args.partition + 1) * args.partition_size 48 | return start_idx, end_idx 49 | 50 | 51 | def _get_out_file(output_root, defense_name, target, file_name): 52 | output_dir = '{root}/{transfomation}/{target}'.format( 53 | root=output_root, transfomation=defense_name, target=target) 54 | if not os.path.isdir(output_dir): 55 | try: 56 | os.makedirs(output_dir) 57 | except OSError as exception: 58 | import errno 59 | if exception.errno != errno.EEXIST: 60 | raise 61 | output_filepath = 
'{path}/{fname}'.format(path=output_dir, fname=file_name)
62 |     return output_filepath
63 | 
64 | 
65 | # setup partial dataset
66 | def _load_partial_dataset(args, data_type, defense, adv_params):
67 |     start_idx, end_idx = _get_start_end_index(args)
68 |     data_indices = {'start_idx': start_idx, 'end_idx': end_idx}
69 |     dataset = load_dataset(args, data_type, defense, adv_params,
70 |                            data_indices=data_indices)
71 |     return dataset
72 | 
73 | 
74 | # Concat data generated from batches
75 | def concatenate_data(args, defense_name, adv_params, data_batch_idx=None):
76 |     assert not args.partition_size == 0, \
77 |         "partition_size can't be zero"
78 | 
79 |     end_idx = args.n_samples
80 |     nfiles = end_idx // args.partition_size
81 |     for i in range(nfiles):
82 |         start_idx = (i * args.partition_size) + 1
83 |         partition_end = (i + 1) * args.partition_size
84 |         partition_file = paths.get_adversarial_file_path(
85 |             args, args.partition_dir, defense_name, adv_params, partition_end,
86 |             start_idx, data_batch_idx)
87 | 
88 |         assert os.path.isfile(partition_file), \
89 |             "No file found at " + partition_file
90 |         print('| Reading file ' + partition_file)
91 |         result = torch.load(partition_file)
92 |         inputs = result['all_outputs']
93 |         targets = result['all_targets']
94 |         targets = torch.LongTensor(targets)
95 |         if i == 0:
96 |             all_imgs = inputs
97 |             all_targets = targets
98 |         else:
99 |             all_imgs = torch.cat((all_imgs, inputs), 0)
100 |             all_targets = torch.cat((all_targets, targets), 0)
101 | 
102 |     out_file = paths.get_adversarial_file_path(args, args.out_dir,
103 |                                                defense_name, adv_params,
104 |                                                nfiles * args.partition_size,
105 |                                                args.partition + 1)
106 | 
107 |     if not os.path.isdir(args.out_dir):
108 |         os.mkdir(args.out_dir)
109 |     print('| Writing concatenated data to ' + out_file)
110 |     torch.save({'all_outputs': all_imgs, 'all_targets': all_targets},
111 |                out_file)
112 | 
113 | 
114 | # Apply transformations on adversarial images
115 | def transformation_on_adv(args, dataset, defense_name, adv_params,
116 |                           data_batch_idx=None):
117 |     pool = ThreadPool(args.n_threads)
118 | 
119 |     def generate(idx):
120 |         return dataset[idx]
121 | 
122 |     dataset = pool.map(generate, range(len(dataset)))
123 |     pool.close()
124 |     pool.join()
125 | 
126 |     # save all data in a file
127 |     all_adv = []
128 |     all_targets = []
129 |     for item in dataset:
130 |         all_adv.append(item[0])
131 |         all_targets.append(item[1])
132 | 
133 |     all_adv = torch.stack(all_adv, 0)
134 |     all_targets = torch.LongTensor(all_targets)
135 |     if not os.path.isdir(args.partition_dir):
136 |         os.makedirs(args.partition_dir)
137 |     start_idx, end_idx = _get_start_end_index(args)
138 |     out_file = paths.get_adversarial_file_path(
139 |         args, args.partition_dir, defense_name, adv_params,
140 |         end_idx, start_idx, data_batch_idx)
141 |     torch.save({'all_outputs': all_adv,
142 |                 'all_targets': all_targets},
143 |                out_file)
144 | 
145 |     print('Saved Transformed tensor at ' + out_file)
146 |     dataset = None
147 |     all_adv = None
148 | 
149 | 
150 | def transformation_on_raw(args, dataset, defense_name):
151 |     pool = ThreadPool(args.n_threads)
152 |     if not os.path.isdir(args.partition_dir):
153 |         os.makedirs(args.partition_dir)
154 | 
155 |     def generate(idx):
156 |         img, target_index, file_name = dataset[idx]
157 |         target = dataset.get_idx_to_class(target_index)
158 |         out_file = _get_out_file(args.partition_dir, defense_name, target, file_name)
159 |         ToPILImage()(img).save(out_file)
160 | 
161 |     pool.map(generate, range(len(dataset)))
162 | 
163 | 
164 | def save_samples(args):
165 |
assert args.data_file is not None and os.path.isfile(args.data_file), \ 166 | "Data file path required" 167 | 168 | basename = os.path.basename(args.data_file) 169 | # Validate if generated data is good 170 | result = torch.load(args.data_file) 171 | outputs = result['all_outputs'] 172 | for i in range(10): 173 | img = trans.ToPILImage()(outputs[i]) 174 | img_path = str("/tmp/test_img_" + basename + "_" + str(i) + ".JPEG") 175 | print("saving image: " + img_path) 176 | img.save(img_path) 177 | 178 | 179 | def generate_transformed_images(args): 180 | 181 | # Only runs one method at a time 182 | assert args.operation is not None, \ 183 | "operation to run can't be None" 184 | assert OperationType.has_value(args.operation), \ 185 | "\"{}\" operation not defined".format(args.operation) 186 | 187 | assert args.defenses is not None, "Defenses can't be None" 188 | assert not args.preprocessed_data, \ 189 | "Trying to apply transformations on already transformed images" 190 | 191 | if args.operation == str(OperationType.TRANSFORM_ADVERSARIAL): 192 | for idx, defense_name in enumerate(args.defenses): 193 | defense = get_defense(defense_name, args) 194 | adv_params = constants.get_adv_params(args, idx) 195 | print("| adv_params: ", adv_params) 196 | dataset = _load_partial_dataset(args, 'valid', defense, adv_params) 197 | 198 | if args.data_batches is None: 199 | transformation_on_adv(args, dataset, defense_name, adv_params) 200 | else: 201 | for i in range(args.data_batches): 202 | transformation_on_adv(args, dataset, defense_name, adv_params, 203 | data_batch_idx=i) 204 | 205 | elif args.operation == str(OperationType.CAT_DATA): 206 | for idx, defense_name in enumerate(args.defenses): 207 | adv_params = constants.get_adv_params(args, idx) 208 | print("| adv_params: ", adv_params) 209 | if args.data_batches is None: 210 | concatenate_data(args, defense_name, adv_params) 211 | else: 212 | for i in range(args.data_batches): 213 | concatenate_data(args, defense_name, adv_params, data_batch_idx=i) 214 | 215 | elif args.operation == str(OperationType.TRANSFORM_RAW): 216 | start_class_idx = args.partition * args.partition_size 217 | end_class_idx = (args.partition + 1) * args.partition_size 218 | class_indices = range(start_class_idx, end_class_idx) 219 | for defense_name in args.defenses: 220 | defense = get_defense(defense_name, args) 221 | data_type = args.data_type if args.data_type == "train" else "valid" 222 | dataset = load_dataset(args, data_type, defense, 223 | class_indices=class_indices) 224 | transformation_on_raw(args, dataset, defense_name) 225 | 226 | 227 | # run: 228 | if __name__ == '__main__': 229 | # parse input arguments: 230 | args = opts.parse_args(opts.OptType.TRANSFORMATION) 231 | generate_transformed_images(args) 232 | -------------------------------------------------------------------------------- /adversarial/index_patches.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | try: 14 | import cPickle as pickle 15 | except ImportError: 16 | import pickle 17 | import progressbar 18 | import random 19 | 20 | import torch 21 | import faiss 22 | 23 | from lib.dataset import load_dataset, get_data_loader 24 | import lib.opts as opts 25 | 26 | 27 | # function that indexes a large number of patches: 28 | def gather_patches(image_dataset, num_patches, patch_size, patch_transform=None): 29 | 30 | # assertions: 31 | assert isinstance(image_dataset, torch.utils.data.dataset.Dataset) 32 | assert type(num_patches) == int and num_patches > 0 33 | assert type(patch_size) == int and patch_size > 0 34 | if patch_transform is not None: 35 | assert callable(patch_transform) 36 | 37 | # gather patches (TODO: speed this up): 38 | patches, n = [], 0 39 | num_images = len(image_dataset) 40 | bar = progressbar.ProgressBar(num_patches) 41 | bar.start() 42 | data_loader = get_data_loader(image_dataset, batchsize=1, workers=1) 43 | for (img, _) in data_loader: 44 | img = img.squeeze() 45 | for _ in range(0, max(1, int(num_patches / num_images))): 46 | n += 1 47 | y = random.randint(0, img.size(1) - patch_size) 48 | x = random.randint(0, img.size(2) - patch_size) 49 | patch = img[:, y:y + patch_size, x:x + patch_size] 50 | if patch_transform is not None: 51 | patch = patch_transform(patch) 52 | patches.append(patch) 53 | if n % 100 == 0: 54 | bar.update(n) 55 | if n >= num_patches: 56 | break 57 | if n >= num_patches: 58 | break 59 | 60 | # copy all patches into single tensor: 61 | patches = torch.stack(patches, dim=0) 62 | patches = patches.view(patches.size(0), int(patches.nelement() / patches.size(0))) 63 | return patches 64 | 65 | 66 | # function that trains faiss index on patches and saves them: 67 | def index_patches(patches, index_file, pca_dims=64): 68 | 69 | # settings for faiss: 70 | num_lists, M, num_bits = 200, 16, 8 71 | 72 | # assertions: 73 | assert torch.is_tensor(patches) and patches.dim() == 2 74 | assert type(pca_dims) == int and pca_dims > 0 75 | if pca_dims > patches.size(1): 76 | print('WARNING: Input dimension < %d. Using fewer PCA dimensions.' 
% pca_dims) 77 | pca_dims = patches.size(1) - (patches.size(1) % M) 78 | 79 | # construct faiss index: 80 | quantizer = faiss.IndexFlatL2(pca_dims) 81 | assert pca_dims % M == 0 82 | sub_index = faiss.IndexIVFPQ(quantizer, pca_dims, num_lists, M, num_bits) 83 | pca_matrix = faiss.PCAMatrix(patches.size(1), pca_dims, 0, True) 84 | faiss_index = faiss.IndexPreTransform(pca_matrix, sub_index) 85 | 86 | # train faiss index: 87 | patches = patches.numpy() 88 | faiss_index.train(patches) 89 | faiss_index.add(patches) 90 | 91 | # save faiss index: 92 | print('| writing faiss index to %s' % index_file) 93 | faiss.write_index(faiss_index, index_file) 94 | 95 | 96 | # run all the things: 97 | def create_faiss_patches(args): 98 | 99 | # load image dataset: 100 | print('| set up image loader...') 101 | image_dataset = load_dataset(args, 'train', None, with_transformation=True) 102 | image_dataset.imgs = image_dataset.imgs[:20000] # we don't need all images 103 | 104 | # gather image patches: 105 | print('| gather image patches...') 106 | patches = gather_patches( 107 | image_dataset, args.num_patches, args.quilting_patch_size, 108 | patch_transform=None, 109 | ) 110 | 111 | # build faiss index: 112 | print('| training faiss index...') 113 | index_patches(patches, args.index_file, pca_dims=args.pca_dims) 114 | 115 | # save patches: 116 | with open(args.patches_file, 'wb') as fwrite: 117 | print('| writing patches to %s' % args.patches_file) 118 | pickle.dump(patches, fwrite, pickle.HIGHEST_PROTOCOL) 119 | 120 | 121 | # run: 122 | if __name__ == '__main__': 123 | # parse input arguments: 124 | args = opts.parse_args(opts.OptType.QUILTING_PATCHES) 125 | create_faiss_patches(args) 126 | -------------------------------------------------------------------------------- /adversarial/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | from __future__ import unicode_literals 11 | -------------------------------------------------------------------------------- /adversarial/lib/adversary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import torch 14 | import lib.transformations.tvm as minimize_tv 15 | import torchvision.transforms as trans 16 | from enum import Enum 17 | import lib.util as util 18 | from lib.transformations.transforms import Crop 19 | 20 | 21 | class FGSMode(Enum): 22 | CARLINI = 'carlini' 23 | LOGIT = 'logit' 24 | 25 | @classmethod 26 | def has_value(cls, value): 27 | return (any(value == item.value for item in cls)) 28 | 29 | def __str__(self): 30 | return str(self.value) 31 | 32 | 33 | # computes robustness, MSSIM, and success rate 34 | # commented out MSSIM since it is quite slow... 35 | def compute_stats(all_inputs, all_outputs, status): 36 | # computing ssim takes too long... 
37 | # ssim = MSSIM(all_inputs, all_outputs) 38 | ssim = None 39 | all_inputs = all_inputs.view(all_inputs.size(0), -1) 40 | all_outputs = all_outputs.view(all_outputs.size(0), -1) 41 | diff = (all_inputs - all_outputs).norm(2, 1).squeeze() 42 | diff = diff.div(all_inputs.norm(2, 1).squeeze()) 43 | n_succ = status.eq(1).sum() 44 | n_fail = status.eq(-1).sum() 45 | return (diff.mean(), ssim, float(n_succ) / float(n_succ + n_fail)) 46 | 47 | 48 | # Implementing fast gradient sign 49 | # Goodfellow et al. - Explaining and harnessing adversarial examples 50 | # Use logit mode if computing loss w.r.t. output scores rather than softmax 51 | def fgs(model, input, target, step_size=0.1, train_mode=False, mode=None, verbose=True): 52 | is_gpu = next(model.parameters()).is_cuda 53 | if mode: 54 | assert FGSMode.has_value(mode) 55 | if train_mode: 56 | model.train() 57 | else: 58 | model.eval() 59 | model.zero_grad() 60 | input_var = torch.autograd.Variable(input, requires_grad=True) 61 | output = model(input_var) 62 | if is_gpu: 63 | cpu_targets = target.clone() 64 | target = target.cuda(async=True) 65 | else: 66 | cpu_targets = target 67 | target_var = torch.autograd.Variable(target) 68 | _, pred = output.data.cpu().max(1) 69 | pred = pred.squeeze() 70 | corr = pred.eq(cpu_targets) 71 | if mode == str(FGSMode.CARLINI): 72 | output = output.mul(-1).add(1).log() 73 | criterion = torch.nn.NLLLoss() 74 | elif mode == str(FGSMode.LOGIT): 75 | criterion = torch.nn.NLLLoss() 76 | else: 77 | criterion = torch.nn.CrossEntropyLoss() 78 | if is_gpu: 79 | criterion = criterion.cuda() 80 | loss = criterion(output, target_var) 81 | loss.backward() 82 | grad_sign = input_var.grad.sign() 83 | input_var2 = input_var + step_size * grad_sign 84 | output2 = model(input_var2) 85 | _, pred2 = output2.data.cpu().max(1) 86 | pred2 = pred2.squeeze() 87 | status = torch.zeros(input_var.size(0)).long() 88 | status[corr] = 2 * pred[corr].ne(pred2[corr]).long() - 1 89 | return (status, step_size * grad_sign.data.cpu()) 90 | 91 | 92 | # uses line search to find the first step size that reaches desired robustness 93 | # model should have perfect accuracy on input 94 | # assumes epsilon that achieves desired robustness is in range 2^[a,b] 95 | def fgs_search(model, input, target, r, rb=0.9, precision=2, 96 | a=-10.0, b=0.0, batch_size=25, verbose=True): 97 | opt_exp = b 98 | for i in range(precision): 99 | # search through predefined range on first iteration 100 | if i == 0: 101 | lower = a 102 | else: 103 | lower = opt_exp - pow(10, 1 - i) 104 | exponents = torch.arange(lower, opt_exp, pow(10, -i)) 105 | for exponent in exponents: 106 | step_size = pow(2, exponent) 107 | succ = torch.zeros(input.size(0)).byte() 108 | dataset = torch.utils.data.TensorDataset(input + step_size * r, target) 109 | dataloader = torch.utils.data.DataLoader( 110 | dataset, batch_size=batch_size, shuffle=False, num_workers=4) 111 | count = 0 112 | for x, y in dataloader: 113 | x_batch = torch.autograd.Variable(x).cuda() 114 | output = model.forward(x_batch) 115 | _, pred = output.data.max(1) 116 | pred = pred.squeeze().cpu() 117 | succ[count:(count + x.size(0))] = pred.ne(y) 118 | count = count + x.size(0) 119 | success_rate = succ.float().mean() 120 | if verbose: 121 | print('step size = %1.4f, success rate = %1.4f' 122 | % (step_size, success_rate)) 123 | if success_rate >= rb: 124 | opt_exp = exponent 125 | break 126 | return (succ, pow(2, opt_exp)) 127 | 128 | 129 | def fgs_compute_status(model, inputs, outputs, targets, status, 130 | batch_size=25, 
threshold=0.9, verbose=True): 131 | all_idx = torch.arange(0, status.size(0)).long() 132 | corr = all_idx[status.ne(0)] 133 | r = outputs - inputs 134 | succ, eps = fgs_search( 135 | model, inputs[corr], targets[corr], r[corr], rb=threshold, 136 | batch_size=batch_size, verbose=verbose) 137 | succ = succ.long() 138 | status[corr] = 2 * succ - 1 139 | return (status, eps) 140 | 141 | 142 | # Implements iterative fast gradient sign 143 | # Kurakin et al. - Adversarial examples in the physical world 144 | def ifgs(model, input, target, max_iter=10, step_size=0.01, train_mode=False, 145 | mode=None, verbose=True): 146 | if train_mode: 147 | model.train() 148 | else: 149 | model.eval() 150 | pred = util.get_labels(model, input) 151 | corr = pred.eq(target) 152 | r = torch.zeros(input.size()) 153 | for _ in range(max_iter): 154 | _, ri = fgs( 155 | model, input, target, step_size, train_mode, mode, verbose=verbose) 156 | r = r + ri 157 | input = input + ri 158 | pred_xp = util.get_labels(model, input + r) 159 | status = torch.zeros(input.size(0)).long() 160 | status[corr] = 2 * pred[corr].ne(pred_xp[corr]).long() - 1 161 | return (status, r) 162 | 163 | 164 | # computes DeepFool for a single input image 165 | def deepfool_single(model, imgs, target, n_classes, train_mode, max_iter=10, 166 | step_size=0.1, batch_size=25, labels=None, verbose=True): 167 | is_gpu = next(model.parameters()).is_cuda 168 | if train_mode: 169 | model.train() 170 | else: 171 | model.eval() 172 | cpu_targets = target 173 | imgs_var = torch.autograd.Variable(imgs) 174 | imgs_var2 = imgs_var.clone() 175 | r = torch.zeros(imgs_var.size()) 176 | criterion = torch.nn.NLLLoss() 177 | if is_gpu: 178 | criterion = criterion.cuda() 179 | for m in range(max_iter): 180 | imgs_var_in = imgs_var2.expand(1, imgs_var2.size(0), imgs_var2.size(1), 181 | imgs_var2.size(2)) 182 | grad_input = imgs_var_in.repeat(n_classes, 1, 1, 1) 183 | output = model(imgs_var_in).clone() 184 | for j in range(int(n_classes / batch_size)): 185 | model.zero_grad() 186 | idx = torch.arange(j * batch_size, 187 | (j + 1) * batch_size).long() 188 | imgs_var_batch = torch.autograd.Variable( 189 | imgs_var_in.data.repeat(batch_size, 1, 1, 1), requires_grad=True) 190 | output_batch = model(imgs_var_batch) 191 | if is_gpu: 192 | _idx = idx.clone().cuda() 193 | else: 194 | _idx = idx.clone() 195 | loss_batch = criterion(output_batch, torch.autograd.Variable(_idx)) 196 | loss_batch.backward() 197 | grad_input.index_copy_(0, torch.autograd.Variable(idx), 198 | -imgs_var_batch.grad) 199 | f = (output - output[0][target].expand_as(output)).cpu() 200 | w = grad_input - grad_input[target].expand_as(grad_input) 201 | w_norm = w.view(n_classes, -1).norm(2, 1) 202 | ratio = torch.abs(f).div(w_norm).data 203 | ratio[0][target] = float('inf') 204 | min_ratio, min_idx = ratio.min(1) 205 | min_w = w[min_idx[0]] 206 | min_norm = w_norm[min_idx[0]].data 207 | min_ratio = min_ratio[0] 208 | min_norm = min_norm[0] 209 | ri = min_ratio / min_norm * step_size * min_w 210 | imgs_var2 = imgs_var2.add(ri) 211 | r = r.add(ri.data) 212 | imgs_var_in = imgs_var2.clone().expand(1, imgs_var2.size(0), 213 | imgs_var2.size(1), imgs_var2.size(2)) 214 | output2 = model.forward(imgs_var_in).clone() 215 | _, pred2 = output2.data.cpu().max(1) 216 | pred2 = pred2.squeeze()[0] 217 | diff = torch.norm(imgs_var - imgs_var2) / torch.norm(imgs_var) 218 | diff = diff.data[0] 219 | if verbose: 220 | print('iteration ' + str(m + 1) + 221 | ': perturbation norm ratio = ' + str(diff)) 222 | if pred2 != 
cpu_targets: 223 | if verbose: 224 | if labels: 225 | print('old label = %s, new label = %s' % (labels[cpu_targets], 226 | labels[pred2])) 227 | else: 228 | print('old label = %d, new label = %d' % (cpu_targets, pred2)) 229 | break 230 | return (pred2 != target, r) 231 | 232 | 233 | # Implements DeepFool for a batch of examples 234 | def deepfool(model, input, target, n_classes, train_mode=False, max_iter=5, 235 | step_size=0.1, batch_size=25, labels=None): 236 | pred = util.get_labels(model, input, batch_size) 237 | status = torch.zeros(input.size(0)).long() 238 | r = torch.zeros(input.size()) 239 | for i in range(input.size(0)): 240 | status[i], r[i] = deepfool_single( 241 | model, input[i], target[i], n_classes, train_mode, 242 | max_iter, step_size, batch_size, labels) 243 | status = 2 * status - 1 244 | status[pred.ne(target)] = 0 245 | return (status, r) 246 | 247 | 248 | # Implements universal adversarial perturbations 249 | # does not really work... 250 | def universal(model, input, target, n_classes, max_val=0.1, train_mode=False, 251 | max_iter=10, step_size=0.1, batch_size=25, data_dir=None, r=None, 252 | verbose=True): 253 | pred = util.get_labels(model, input, batch_size) 254 | if r is None: 255 | r = torch.zeros(input[0].size()) 256 | perm = torch.randperm(input.size(0)) 257 | for i in range(input.size(0)): 258 | idx = perm[i] 259 | if verbose: 260 | print('sample %d: index %d' % (i + 1, idx)) 261 | x_adv = torch.autograd.Variable((input[idx] + r)) 262 | x_adv = x_adv.expand(1, input.size(1), input.size(2), input.size(3)) 263 | output = model.forward(x_adv) 264 | _, pred_adv = output.max(1) 265 | pred_adv = pred_adv.data.cpu()[0][0] 266 | if pred[idx] == pred_adv: 267 | succ, ri = deepfool_single( 268 | model, input[idx] + r, pred[idx], n_classes, train_mode, max_iter, 269 | step_size, batch_size, data_dir) 270 | if succ: 271 | r = (r + ri).clamp(-max_val, max_val) 272 | x = input + r.expand_as(input) 273 | pred_xp = util.get_labels(model, x) 274 | status = 2 * pred_xp.ne(target).long() - 1 275 | status[pred.ne(target)] = 0 276 | return (status, r) 277 | 278 | 279 | # Implements Carlini-Wagner's L2 and Linf attacks 280 | # Carlini and Wagner - Towards evaluating the robustness of neural networks 281 | # Modified with TV minimization, random cropping, and random pixel dropping 282 | def cw(model, input, target, weight, loss_str, bound=0, tv_weight=0, 283 | max_iter=100, step_size=0.01, kappa=0, p=2, crop_frac=1.0, drop_rate=0.0, 284 | train_mode=False, verbose=True): 285 | is_gpu = next(model.parameters()).is_cuda 286 | if train_mode: 287 | model.train() 288 | else: 289 | model.eval() 290 | pred = util.get_labels(model, input) 291 | corr = pred.eq(target) 292 | w = torch.autograd.Variable(input, requires_grad=True) 293 | best_w = input.clone() 294 | best_loss = float('inf') 295 | optimizer = torch.optim.Adam([w], lr=step_size) 296 | input_var = torch.autograd.Variable(input) 297 | input_vec = input.view(input.size(0), -1) 298 | to_pil = trans.ToPILImage() 299 | scale_up = trans.Resize((w.size(2), w.size(3))) 300 | scale_down = trans.Resize((int(crop_frac * w.size(2)), int(crop_frac * w.size(3)))) 301 | to_tensor = trans.ToTensor() 302 | probs = util.get_probs(model, input) 303 | _, top2 = probs.topk(2, 1) 304 | argmax = top2[:, 0] 305 | for j in range(top2.size(0)): 306 | if argmax[j] == target[j]: 307 | argmax[j] = top2[j, 1] 308 | for i in range(max_iter): 309 | if i > 0: 310 | w.grad.data.fill_(0) 311 | model.zero_grad() 312 | if loss_str == 'l2': 313 | loss = torch.pow(w - 
input_var, 2).sum() 314 | elif loss_str == 'linf': 315 | loss = torch.clamp((w - input_var).abs() - bound, min=0).sum() 316 | else: 317 | raise ValueError('Unsupported loss: %s' % loss_str) 318 | recons_loss = loss.data[0] 319 | w_data = w.data 320 | if crop_frac < 1 and i % 3 == 1: 321 | w_cropped = torch.zeros( 322 | w.size(0), w.size(1), int(crop_frac * w.size(2)), 323 | int(crop_frac * w.size(3))) 324 | locs = torch.zeros(w.size(0), 4).long() 325 | w_in = torch.zeros(w.size()) 326 | for m in range(w.size(0)): 327 | locs[m] = torch.LongTensor(Crop('random', crop_frac)(w_data[m])) 328 | w_cropped = w_data[m, :, locs[m][0]:(locs[m][0] + locs[m][2]), 329 | locs[m][1]:(locs[m][1] + locs[m][3])] 330 | minimum = w_cropped.min() 331 | maximum = w_cropped.max() - minimum 332 | w_in[m] = to_tensor(scale_up(to_pil((w_cropped - minimum) / maximum))) 333 | w_in[m] = w_in[m] * maximum + minimum 334 | w_in = torch.autograd.Variable(w_in, requires_grad=True) 335 | else: 336 | w_in = torch.autograd.Variable(w_data, requires_grad=True) 337 | if drop_rate == 0 and i % 3 == 2: 338 | output = model.forward(w_in) 339 | else: 340 | output = model.forward(torch.nn.Dropout(p=drop_rate).forward(w_in)) 341 | for j in range(output.size(0)): 342 | loss += weight * torch.clamp( 343 | output[j][target[j]] - output[j][argmax[j]] + kappa, min=0).cpu() 344 | adv_loss = loss.data[0] - recons_loss 345 | if is_gpu: 346 | loss = loss.cuda() 347 | loss.backward() 348 | if crop_frac < 1 and i % 3 == 1: 349 | grad_full = torch.zeros(w.size()) 350 | grad_cpu = w_in.grad.data 351 | for m in range(w.size(0)): 352 | minimum = grad_cpu[m].min() 353 | maximum = grad_cpu[m].max() - minimum 354 | grad_m = to_tensor(scale_down( 355 | to_pil((grad_cpu[m] - minimum) / maximum))) 356 | grad_m = grad_m * maximum + minimum 357 | grad_full[m, :, locs[m][0]:(locs[m][0] + locs[m][2]), 358 | locs[m][1]:(locs[m][1] + locs[m][3])] = grad_m 359 | w.grad.data.add_(grad_full) 360 | else: 361 | w.grad.data.add_(w_in.grad.data) 362 | w_cpu = w.data.numpy() 363 | input_np = input.numpy() 364 | tv_loss = 0 365 | if tv_weight > 0: 366 | for j in range(output.size(0)): 367 | for k in range(3): 368 | tv_loss += tv_weight * minimize_tv.tv( 369 | w_cpu[j, k] - input_np[j, k], p) 370 | grad = tv_weight * torch.from_numpy( 371 | minimize_tv.tv_dx(w_cpu[j, k] - input_np[j, k], p)) 372 | w.grad.data[j, k].add_(grad.float()) 373 | optimizer.step() 374 | total_loss = loss.data.cpu()[0] + tv_loss 375 | # w.data = utils.img_to_tensor(utils.transform_img(w.data), scale=False) 376 | output_vec = w.data 377 | preds = util.get_labels(model, output_vec) 378 | output_vec = output_vec.view(output_vec.size(0), -1) 379 | diff = (input_vec - output_vec).norm(2, 1).squeeze() 380 | diff = diff.div(input_vec.norm(2, 1).squeeze()) 381 | rb = diff.mean() 382 | sr = float(preds.ne(target).sum()) / target.size(0) 383 | if verbose: 384 | print('iteration %d: loss = %f, %s_loss = %f, ' 385 | 'adv_loss = %f, tv_loss = %f' % ( 386 | i + 1, total_loss, loss_str, recons_loss, adv_loss, tv_loss)) 387 | print('robustness = %f, success rate = %f' % (rb, sr)) 388 | if total_loss < best_loss: 389 | best_loss = total_loss 390 | best_w = w.data.clone() 391 | pred_xp = util.get_labels(model, best_w) 392 | status = torch.zeros(input.size(0)).long() 393 | status[corr] = 2 * pred[corr].ne(pred_xp[corr]).long() - 1 394 | return (status, best_w - input) 395 | 396 | 397 | # random signs 398 | def rand_sign(model, input, target, step_size, num_bins=100): 399 | x = torch.autograd.Variable(input, 
requires_grad=True) 400 | output = model.forward(x) 401 | _, pred = output.data.max(1) 402 | pred = pred.squeeze().cpu() 403 | corr = pred.eq(target) 404 | target = torch.autograd.Variable(target) 405 | P = torch.ones(input.size(0), num_bins) 406 | sign = 2 * torch.bernoulli(P) - 1 407 | H = torch.rand(input.size()) 408 | H = (H * num_bins).floor().int() 409 | r = torch.zeros(input.size()) 410 | for i in range(input.size(0)): 411 | for j in range(num_bins): 412 | r[i][H[i].eq(j)] = sign[i, j] 413 | xp = x + step_size * torch.autograd.Variable(r.cuda()) 414 | output_xp = model.forward(xp).clone() 415 | _, pred_xp = output_xp.data.max(1) 416 | pred_xp = pred_xp.squeeze().cpu() 417 | status = torch.zeros(x.size(0)).long() 418 | status[corr] = 2 * pred[corr].ne(pred_xp[corr]).long() - 1 419 | return (status, step_size * r) 420 | -------------------------------------------------------------------------------- /adversarial/lib/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | from enum import Enum 14 | 15 | 16 | class DefenseType(Enum): 17 | RAW = "raw" 18 | TVM = 'tvm' 19 | QUILTING = 'quilting' 20 | ENSEMBLE_TRAINING = 'ensemble_training' 21 | JPEG = 'jpeg' 22 | QUANTIZATION = 'quantize' 23 | 24 | @classmethod 25 | def has_value(cls, value): 26 | return (any(value == item.value for item in cls)) 27 | 28 | def __str__(self): 29 | return str(self.value) 30 | 31 | 32 | class AdversaryType(Enum): 33 | FGS = "fgs" 34 | IFGS = 'ifgs' 35 | CWL2 = 'cwl2' 36 | CWLINF = 'cwlinf' 37 | DEEPFOOL = 'deepfool' 38 | 39 | @classmethod 40 | def has_value(cls, value): 41 | return (any(value == item.value for item in cls)) 42 | 43 | def __str__(self): 44 | return str(self.value) 45 | 46 | 47 | class AttackType(Enum): 48 | WHITEBOX = "whitebox" 49 | BLACKBOX = 'blackbox' 50 | 51 | @classmethod 52 | def has_value(cls, value): 53 | return (any(value == item.value for item in cls)) 54 | 55 | def __str__(self): 56 | return str(self.value) 57 | 58 | 59 | # Constants 60 | # Transformations params 61 | QUILTING_PATCH_SIZE = 5 62 | TVM_WEIGHT = 0.03 63 | PIXEL_DROP_RATE = 0.5 64 | TVM_METHOD = 'bregman' 65 | 66 | # Data params 67 | INCEPTION_V4_DATA_PARAMS = { 68 | 'MEAN_STD': { 69 | 'MEAN': [0.5, 0.5, 0.5], 70 | 'STD': [0.5, 0.5, 0.5], 71 | }, 72 | 'IMAGE_SIZE': 299, 73 | 'IMAGE_SCALE_SIZE': 342, 74 | 'NUM_CLASSES': 1000 75 | } 76 | 77 | RESNET_DENSENET_DATA_PARAMS = { 78 | 'MEAN_STD': { 79 | 'MEAN': [0.485, 0.456, 0.406], 80 | 'STD': [0.229, 0.224, 0.225], 81 | }, 82 | 'IMAGE_SIZE': 224, 83 | 'IMAGE_SCALE_SIZE': 256, 84 | 'NUM_CLASSES': 1000 85 | } 86 | 87 | # Same as paper:https://arxiv.org/pdf/1602.07261.pdf 88 | INCEPTION_V4_TRAINING_PARAMS = { 89 | 'LR': 0.045, 90 | 'LR_DECAY': 0.94, 91 | 'LR_DECAY_STEPSIZE': 2, 92 | 'EPOCHS': 160, 93 | 'RMS_EPS': 1.0, 94 | 'RMS_ALPHA': 0.9, 95 | } 96 | 97 | TRAINING_PARAMS = { 98 | 'LR': 0.1, 99 | 'LR_DECAY': 0.1, 100 | 'LR_DECAY_STEPSIZE': 30, 101 | 'EPOCHS': 90, 102 | } 103 | 104 | # List of supported models 105 | MODELS = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 106 | 'DenseNet', 'densenet121', 'densenet169', 
'densenet201', 'densenet161',
107 |           'Inception3', 'inception_v3', 'inception_v4']
108 | 
109 | 
110 | # Read pre-calculated params for adversary for different settings
111 | def get_adv_params(args, defense_idx=0):
112 | 
113 |     if args.attack_type is None:
114 |         assert args.adv_strength is None and args.learning_rate is None, \
115 |             ("adv_strength and learning_rate can only be provided"
116 |              " when attack_type is defined")
117 | 
118 |     if args.adversary_to_generate:
119 |         adversary = args.adversary_to_generate
120 |     else:
121 |         adversary = args.adversary
122 |     learning_rate = None
123 |     adv_strength = None
124 | 
125 |     # get adv_strength from input arguments
126 |     if args.adv_strength is not None:
127 |         # No defense
128 |         if args.defenses is None:
129 |             assert (len(args.adv_strength) == 1 and
130 |                     defense_idx == 0)
131 |         else:
132 |             # adv_strength is provided for each defense
133 |             assert (len(args.defenses) == len(args.adv_strength) and
134 |                     defense_idx < len(args.adv_strength))
135 |         adv_strength = args.adv_strength[defense_idx]
136 | 
137 |     # get learning_rate from input arguments
138 |     if args.learning_rate is not None:
139 |         # No defense
140 |         if args.defenses is None:
141 |             assert (len(args.learning_rate) == 1 and
142 |                     defense_idx == 0)
143 |         else:
144 |             # learning_rate is provided for each defense
145 |             assert (len(args.defenses) == len(args.learning_rate) and
146 |                     defense_idx < len(args.learning_rate))
147 |         learning_rate = args.learning_rate[defense_idx]
148 | 
149 |     # if adversary params are not provided in input arguments,
150 |     # then use below precomputed params on resnet50
151 |     # parameters maintain L2 dissimilarity of ~0.06
152 |     if adv_strength is None and learning_rate is None:
153 |         assert (args.attack_type is None or
154 |                 AttackType.has_value(args.attack_type))
155 |         consts = {adversary: (None, None)}
156 |         # params for blackbox attack
157 |         if args.attack_type == str(AttackType.BLACKBOX):
158 |             consts = {
159 |                 str(AdversaryType.IFGS): (0.021, None),
160 |                 str(AdversaryType.DEEPFOOL): (0.96, None),
161 |                 str(AdversaryType.CWL2): (None, 31.5),
162 |                 str(AdversaryType.FGS): (None, 0.07),
163 |             }
164 |         elif args.attack_type == str(AttackType.WHITEBOX):
165 |             if args.defenses[defense_idx] == str(DefenseType.TVM):
166 |                 consts = {
167 |                     str(AdversaryType.IFGS): (0.018, None),
168 |                     str(AdversaryType.DEEPFOOL): (3.36, None),
169 |                     str(AdversaryType.CWL2): (None, 126),
170 |                     str(AdversaryType.FGS): (None, 0.07),
171 |                 }
172 | 
173 |             elif args.defenses[defense_idx] == str(DefenseType.QUILTING):
174 |                 consts = {
175 |                     str(AdversaryType.IFGS): (0.015, None),
176 |                     str(AdversaryType.DEEPFOOL): (0.42, None),
177 |                     str(AdversaryType.CWL2): (None, 17.4),
178 |                     str(AdversaryType.FGS): (None, 0.07),
179 |                 }
180 | 
181 |             # If the model is InceptionResnetV2 with ensemble adversarial training
182 |             # (compare against the string value, like the branches above):
183 |             elif args.defenses[defense_idx] == str(DefenseType.ENSEMBLE_TRAINING):
184 |                 consts = {
185 |                     str(AdversaryType.IFGS): (0.01, None),
186 |                     str(AdversaryType.DEEPFOOL): (1.1, None),
187 |                     str(AdversaryType.CWL2): (None, 16.5),
188 |                     str(AdversaryType.FGS): (None, 0.07),
189 |                 }
190 |         return dict(zip(('learning_rate', 'adv_strength'), consts[adversary]))
191 |     else:
192 |         return {'learning_rate': learning_rate, 'adv_strength': adv_strength}
--------------------------------------------------------------------------------
/adversarial/lib/convnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import progressbar 14 | 15 | import torch 16 | import torch.nn as nn 17 | from lib.util import accuracy 18 | 19 | 20 | # function that trains a model: 21 | def train(model, criterion, optimizer, data_loader_hook=None, 22 | start_epoch_hook=None, end_epoch_hook=None, 23 | start_epoch=0, end_epoch=90, learning_rate=0.1): 24 | 25 | # assertions: 26 | assert isinstance(model, nn.Module) 27 | assert isinstance(criterion, nn.modules.loss._Loss) 28 | assert isinstance(optimizer, torch.optim.Optimizer) 29 | assert type(start_epoch) == int and start_epoch >= 0 30 | assert type(end_epoch) == int and end_epoch >= start_epoch 31 | assert type(learning_rate) == float and learning_rate > .0 32 | if start_epoch_hook is not None: 33 | assert callable(start_epoch_hook) 34 | if end_epoch_hook is not None: 35 | assert callable(end_epoch_hook) 36 | assert data_loader_hook is not None 37 | assert callable(data_loader_hook) 38 | 39 | # are we on CPU or GPU? 40 | is_gpu = not isinstance(model, torch.nn.backends.thnn.THNNFunctionBackend) 41 | 42 | # train the model: 43 | model.train() 44 | for epoch in range(start_epoch, end_epoch): 45 | 46 | data_loader = data_loader_hook(epoch) 47 | assert isinstance(data_loader, torch.utils.data.dataloader.DataLoader) 48 | 49 | # start-of-epoch hook: 50 | if start_epoch_hook is not None: 51 | start_epoch_hook(epoch, model, optimizer) 52 | 53 | # loop over training data: 54 | model.train() 55 | precs1, precs5, num_batches, num_total = [], [], 0, 0 56 | bar = progressbar.ProgressBar(len(data_loader)) 57 | bar.start() 58 | for num_batches, (imgs, targets) in enumerate(data_loader): 59 | 60 | # copy data to GPU: 61 | if is_gpu: 62 | cpu_targets = targets.clone() 63 | targets = targets.cuda(async=True) 64 | # Make sure the imgs are converted to cuda tensor too 65 | imgs = imgs.cuda(async=True) 66 | 67 | imgsvar = torch.autograd.Variable(imgs) 68 | tgtsvar = torch.autograd.Variable(targets) 69 | 70 | # perform forward pass: 71 | out = model(imgsvar) 72 | loss = criterion(out, tgtsvar) 73 | 74 | # measure accuracy: 75 | prec1, prec5 = accuracy(out.data.cpu(), cpu_targets, topk=(1, 5)) 76 | precs1.append(prec1[0] * targets.size(0)) 77 | precs5.append(prec5[0] * targets.size(0)) 78 | num_total += imgs.size(0) 79 | 80 | # compute gradient and do SGD step: 81 | optimizer.zero_grad() 82 | loss.backward() 83 | optimizer.step() 84 | bar.update(num_batches) 85 | 86 | # end-of-epoch hook: 87 | if end_epoch_hook is not None: 88 | prec1 = sum(precs1) / num_total 89 | prec5 = sum(precs5) / num_total 90 | end_epoch_hook(epoch, model, optimizer, prec1=prec1, prec5=prec5) 91 | 92 | # return trained model: 93 | return model 94 | 95 | 96 | # helper function that test a model: 97 | def _test(model, data_loader, return_probability=False): 98 | 99 | # assertions 100 | assert isinstance(model, torch.nn.Module) 101 | assert isinstance(data_loader, torch.utils.data.dataloader.DataLoader) 102 | 103 | # are we on CPU or GPU? 
104 | is_gpu = not isinstance(model, torch.nn.backends.thnn.THNNFunctionBackend) 105 | 106 | # loop over data: 107 | model.eval() 108 | precs1, precs5, num_batches, num_total = [], [], 0, 0 109 | probs, all_targets = None, None 110 | bar = progressbar.ProgressBar(len(data_loader)) 111 | bar.start() 112 | for num_batches, (imgs, targets) in enumerate(data_loader): 113 | 114 | # copy data to GPU: 115 | if is_gpu: 116 | cpu_targets = targets.clone() 117 | targets = targets.cuda(async=True) 118 | # Make sure the imgs are converted to cuda tensor too 119 | imgs = imgs.cuda(async=True) 120 | 121 | # perform prediction: 122 | imgsvar = torch.autograd.Variable(imgs.squeeze(), volatile=True) 123 | output = model(imgsvar) 124 | pred = output.data.cpu() 125 | 126 | if return_probability: 127 | probs = pred if probs is None else torch.cat((probs, pred), dim=0) 128 | all_targets = targets if all_targets is None else ( 129 | torch.cat((all_targets, targets), dim=0)) 130 | 131 | # measure accuracy: 132 | prec1, prec5 = accuracy(pred, cpu_targets, topk=(1, 5)) 133 | precs1.append(prec1[0] * targets.size(0)) 134 | precs5.append(prec5[0] * targets.size(0)) 135 | num_total += imgs.size(0) 136 | bar.update(num_batches) 137 | 138 | if return_probability: 139 | return probs, all_targets 140 | else: 141 | # return average accuracy (@ 1 and 5): 142 | return sum(precs1) / num_total, sum(precs5) / num_total 143 | 144 | 145 | def test(model, data_loader): 146 | return _test(model, data_loader, return_probability=False) 147 | 148 | 149 | def get_prob(model, data_loader): 150 | return _test(model, data_loader, return_probability=True) 151 | -------------------------------------------------------------------------------- /adversarial/lib/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | #
7 | 
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 | from __future__ import unicode_literals
12 | 
13 | import torch
14 | from lib.datasets.sub_dataset_folder import PartialImageFolder
15 | from lib.datasets.sub_dataset_tarfolder import PartialTarFolder
16 | from lib.datasets.transform_dataset import TransformDataset
17 | from lib.transformations.transformation_helper import setup_transformations
18 | import lib.paths as paths
19 | import os
20 | from lib.datasets.dataset_classes_folder import ImageClassFolder
21 | from six import string_types
22 | 
23 | 
24 | # helper function for loading adversarial data:
25 | def _load_adversarial_helper(path, eps=0,
26 |                              normalize=True, mask=None, preprocessed=False):
27 | 
28 |     # assertions:
29 |     assert isinstance(path, string_types)
30 |     if mask is not None:
31 |         assert torch.is_tensor(mask)
32 | 
33 |     # load file with images and perturbations:
34 |     result = torch.load(path)
35 |     all_outputs = result['all_outputs']
36 |     all_targets = result['all_targets']
37 | 
38 |     if not preprocessed:
39 |         # construct adversarial examples:
40 |         if eps > 0:
41 |             all_inputs = result['all_inputs']
42 |             r = all_outputs - all_inputs
43 |             if normalize:
44 |                 r = r.sign()
45 |             all_outputs = all_inputs + eps * r
46 |     if mask is not None:
47 |         all_idx = torch.arange(0, all_outputs.size(0)).long()
48 |         mask_idx = all_idx[mask]
49 |         all_outputs = all_outputs[mask_idx]
50 |         all_targets = all_targets[mask_idx]
51 | 
52 |     return all_outputs, all_targets
53 | 
54 | 
55 | # function that loads data for a given adversary
56 | def load_adversarial(args, img_dir, defense_name, adv_params,
57 |                      data_batch_idx=None):
58 |     # assertions:
59 |     assert isinstance(img_dir, string_types)
60 | 
61 |     data_file = paths.get_adversarial_file_path(
62 |         args, img_dir, defense_name, adv_params, args.n_samples,
63 |         data_batch_idx=data_batch_idx,
64 |         with_defense=args.preprocessed_data)
65 | 
66 |     assert os.path.isfile(data_file), \
67 |         "No file found at " + data_file
68 |     # load the adversarial examples:
69 |     print('| loading adversarial examples from %s...'
% data_file) 70 | if not (args.adversary == 'cwl2' or args.adversary == 'fgs'): 71 | adv_strength = 0.0 72 | else: 73 | adv_strength = adv_params['adv_strength'] 74 | 75 | normalize = True 76 | if args.adversary == 'cwl2' and adv_strength > 0: 77 | normalize = False 78 | return _load_adversarial_helper(data_file, adv_strength, 79 | normalize=normalize, 80 | preprocessed=args.preprocessed_data) 81 | 82 | 83 | # init dataset 84 | def get_dataset(args, img_dir, defense_name, adv_params, transform, 85 | data_batch_idx=None, class_indices=None, data_indices=None): 86 | 87 | # assertions 88 | if 'preprocessed_data' in args and args.preprocessed_data: 89 | assert defense_name is not None, ( 90 | "If data is already pre processed for defenses then " 91 | "defenses can't be None") 92 | 93 | # for data without adversary 94 | if 'adversary' not in args or args.adversary is None: 95 | # For pre-applied defense, read data from tar files 96 | if 'preprocessed_data' in args and args.preprocessed_data: 97 | # get prefix in tar member names 98 | if defense_name: 99 | if args.tar_prefix: 100 | tar_prefix = str(args.tar_prefix + '/' + defense_name) 101 | else: 102 | tar_prefix = defense_name 103 | else: 104 | tar_prefix = args.tar_prefix 105 | dataset = PartialTarFolder(img_dir, path_prefix=tar_prefix, 106 | transform=transform) 107 | else: 108 | if class_indices: 109 | # Load data for only target classes(helpful in parallel processing) 110 | dataset = ImageClassFolder(img_dir, class_indices, transform) 111 | else: 112 | # dataset = ImageFolder( 113 | # img_dir, transform=transform) 114 | dataset = PartialImageFolder( 115 | img_dir, data_indices=data_indices, transform=transform) 116 | 117 | else: # adversary 118 | # Load adversarial dataset 119 | adv_data, targets = load_adversarial(args, img_dir, 120 | defense_name, adv_params, 121 | data_batch_idx=data_batch_idx) 122 | dataset = TransformDataset( 123 | torch.utils.data.TensorDataset(adv_data, targets), transform, data_indices) 124 | 125 | return dataset 126 | 127 | 128 | def load_dataset(args, data_type, defense, adv_params=None, crop=None, 129 | epoch=-1, data_batch_idx=None, 130 | class_indices=None, data_indices=None, 131 | with_transformation=True): 132 | 133 | # assertions: 134 | assert (data_type == 'train' or data_type == 'valid'), ( 135 | "{} data type not defined. 
Defined types are \"train\" " 136 | "and \"valid\" ".format(data_type)) 137 | if defense is not None: 138 | assert callable(defense), ( 139 | "defense should be a callable method") 140 | 141 | # get data directory 142 | img_dir = paths.get_img_dir(args, data_type, epoch=epoch) 143 | 144 | # setup transformations to apply on loaded data 145 | transform = None 146 | if with_transformation: 147 | transform = setup_transformations(args, data_type, defense, 148 | crop=crop) 149 | defense_name = None if defense is None else defense.get_name() 150 | # initialize dataset 151 | print('| Loading data from ' + img_dir) 152 | dataset = get_dataset(args, img_dir, defense_name, adv_params, transform, 153 | data_batch_idx=data_batch_idx, 154 | class_indices=class_indices, 155 | data_indices=data_indices) 156 | return dataset 157 | 158 | 159 | # function that constructs a data loader for a dataset: 160 | def get_data_loader(dataset, batchsize=32, workers=10, device='cpu', shuffle=True): 161 | 162 | # assertions: 163 | assert isinstance(dataset, torch.utils.data.dataset.Dataset) 164 | assert type(batchsize) == int and batchsize > 0 165 | assert type(workers) == int and workers >= 0 166 | assert device == 'cpu' or device == 'gpu' 167 | 168 | # construct data loader: 169 | return torch.utils.data.DataLoader( 170 | dataset, 171 | batch_size=batchsize, 172 | shuffle=shuffle, 173 | num_workers=workers, 174 | pin_memory=(device == 'gpu'), 175 | ) 176 | 177 | 178 | # helper function to make image look pretty: 179 | def visualize_image(args, img, normalize=True): 180 | assert torch.is_tensor(img) and img.dim() == 3 181 | new_img = img.clone() 182 | if normalize: 183 | for c in range(new_img.size(0)): 184 | new_img[c].mul_( 185 | args.data_params['MEAN_STD']['STD'][c]).add_( 186 | args.data_params['MEAN_STD']['MEAN'][c]) 187 | return torch.mul(new_img, 255.).byte() 188 | -------------------------------------------------------------------------------- /adversarial/lib/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookarchive/adversarial_image_defenses/55bf56ddb017535fb6630a746c6f946202336052/adversarial/lib/datasets/__init__.py -------------------------------------------------------------------------------- /adversarial/lib/datasets/dataset_classes_folder.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch.utils.data as data 7 | import os 8 | import os.path 9 | from torchvision.datasets.folder import ( 10 | is_image_file, find_classes, default_loader, IMG_EXTENSIONS) 11 | 12 | 13 | def make_dataset(dir, class_to_idx, target_indices): 14 | images = [] 15 | dir = os.path.expanduser(dir) 16 | for target in sorted(os.listdir(dir)): 17 | if not os.path.isdir(os.path.join(dir, target)): 18 | continue 19 | # Only get data for target classes 20 | if not class_to_idx[target] in target_indices: 21 | continue 22 | d = os.path.join(dir, target) 23 | if not os.path.isdir(d): 24 | continue 25 | 26 | for root, _, fnames in sorted(os.walk(d)): 27 | for fname in sorted(fnames): 28 | if is_image_file(fname): 29 | path = os.path.join(root, fname) 30 | item = (path, class_to_idx[target], fname) 31 | images.append(item) 32 | 33 | return images 34 | 35 | 36 | # Extend data.Dataset similar to ImageFolder and allow reading only 
pre-defined classes 37 | class ImageClassFolder(data.Dataset): 38 | """A generic data loader where the images are arranged in this way: :: 39 | 40 | root/dog/xxx.png 41 | root/dog/xxy.png 42 | root/dog/xxz.png 43 | 44 | root/cat/123.png 45 | root/cat/nsdf3.png 46 | root/cat/asd932_.png 47 | 48 | Args: 49 | root (string): Root directory path. 50 | transform (callable, optional): A function/transform that takes in an PIL image 51 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 52 | target_transform (callable, optional): A function/transform that takes in the 53 | target and transforms it. 54 | loader (callable, optional): A function to load an image given its path. 55 | 56 | Attributes: 57 | classes (list): List of the class names. 58 | class_to_idx (dict): Dict with items (class_name, class_index). 59 | imgs (list): List of (image path, class_index) tuples 60 | """ 61 | 62 | def __init__(self, root, target_indices, 63 | transform=None, target_transform=None, loader=default_loader): 64 | classes, class_to_idx = find_classes(root) 65 | imgs = make_dataset(root, class_to_idx, target_indices) 66 | if len(imgs) == 0: 67 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 68 | "Supported image extensions are: " + 69 | ",".join(IMG_EXTENSIONS))) 70 | 71 | self.root = root 72 | self.imgs = imgs 73 | self.classes = classes 74 | self.class_to_idx = class_to_idx 75 | self.transform = transform 76 | self.target_transform = target_transform 77 | self.loader = loader 78 | 79 | def __getitem__(self, index): 80 | """ 81 | Args: 82 | index (int): Index 83 | 84 | Returns: 85 | tuple: (image, target) where target is class_index of the target class. 86 | """ 87 | path, target, file_name = self.imgs[index] 88 | img = self.loader(path) 89 | if self.transform is not None: 90 | img = self.transform(img) 91 | if self.target_transform is not None: 92 | target = self.target_transform(target) 93 | 94 | return img, target, file_name 95 | 96 | def __len__(self): 97 | return len(self.imgs) 98 | 99 | def get_idx_to_class(self, index): 100 | return self.classes[index] 101 | -------------------------------------------------------------------------------- /adversarial/lib/datasets/sub_dataset_folder.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | from torchvision.datasets.folder import ImageFolder, default_loader 7 | 8 | 9 | class PartialImageFolder(ImageFolder): 10 | def __init__(self, root, data_indices=None, 11 | transform=None, target_transform=None, loader=default_loader): 12 | super(PartialImageFolder, self).__init__( 13 | root, transform=transform, target_transform=target_transform, loader=loader) 14 | 15 | if data_indices: 16 | assert ('start_idx' in data_indices and 17 | isinstance(data_indices['start_idx'], int)), \ 18 | "data_indices expects argument start_idx of int type" 19 | assert ('end_idx' in data_indices and 20 | isinstance(data_indices['end_idx'], int)), \ 21 | "data_indices expects argument end_idx of int type" 22 | assert data_indices['start_idx'] < len(self.imgs) 23 | 24 | end_idx = min(data_indices['end_idx'], len(self.imgs)) 25 | self.imgs = self.imgs[data_indices['start_idx']:end_idx] 26 | 27 | def get_classes(self): 28 | return self.classes 29 | -------------------------------------------------------------------------------- 
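The `PartialImageFolder` above reads only the contiguous slice `[start_idx, end_idx)` of an `ImageFolder`-style directory, which helps when sharding work across processes. A minimal usage sketch (the directory path below is hypothetical):
```python
from torchvision import transforms

from lib.datasets.sub_dataset_folder import PartialImageFolder

# Hypothetical ImageNet-style tree: /path/to/imagenet/train/<class_name>/<img>.png
dataset = PartialImageFolder(
    '/path/to/imagenet/train',
    data_indices={'start_idx': 0, 'end_idx': 100},  # keep samples [0, 100)
    transform=transforms.ToTensor(),
)
print(len(dataset), dataset.get_classes()[:3])
```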
/adversarial/lib/datasets/sub_dataset_tarfolder.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | from lib.datasets.tarfolder import ImageTarFile, default_loader 7 | 8 | 9 | class PartialTarFolder(ImageTarFile): 10 | def __init__(self, tar_index_file, data_indices=None, 11 | path_prefix='', transform=None, 12 | target_transform=None, loader=default_loader): 13 | super(PartialTarFolder, self).__init__( 14 | tar_index_file, path_prefix=path_prefix, transform=transform, 15 | target_transform=target_transform, loader=loader) 16 | 17 | if data_indices: 18 | assert ('start_idx' in data_indices and 19 | isinstance(data_indices['start_idx'], int)), \ 20 | "data_indices expects argument start_idx of int type" 21 | assert ('end_idx' in data_indices and 22 | isinstance(data_indices['end_idx'], int)), \ 23 | "data_indices expects argument end_idx of int type" 24 | assert data_indices['start_idx'] < len(self.imgs) 25 | 26 | end_idx = min(data_indices['end_idx'], len(self.imgs)) 27 | self.imgs = self.imgs[data_indices['start_idx']:end_idx] 28 | self.img2tarfile = self.img2tarfile[data_indices['start_idx']:end_idx] 29 | -------------------------------------------------------------------------------- /adversarial/lib/datasets/tar_metadata.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | 7 | def extract_classname_from_member(member_name): 8 | # expects filename format as "root/class_name/img.ext" 9 | class_name = str(member_name.split('/')[-2]) 10 | return class_name 11 | 12 | 13 | # Handles a directory of tar files 14 | class TarDirMetaData(object): 15 | 16 | def __init__(self, tar_dir): 17 | assert isinstance(tar_dir, str) 18 | self.tar_dir = tar_dir 19 | self.tarfiles = [] 20 | self.classes = set() 21 | 22 | def __len__(self): 23 | return len(self.tarfiles) 24 | 25 | def add_file(self, tarfile): 26 | assert isinstance(tarfile, TarFileMetadata) 27 | self.tarfiles.append(tarfile) 28 | for class_name in tarfile.classes: 29 | if class_name not in self.classes: 30 | self.classes.add(class_name) 31 | 32 | def __getitem__(self, index): 33 | assert index < len(self.tarfiles) 34 | return self.tarfiles[index] 35 | 36 | 37 | # Handles a tarfile with data files 38 | class TarFileMetadata(object): 39 | 40 | def __init__(self, tarfile): 41 | assert isinstance(tarfile, str) 42 | self.tarfile = tarfile 43 | self.files = [] 44 | self.classes = set() 45 | 46 | def __len__(self): 47 | return len(self.files) 48 | 49 | def add_file(self, file_metadata): 50 | assert isinstance(file_metadata, DataFileMetadata) 51 | self.files.append(file_metadata) 52 | filename = file_metadata.filename 53 | class_name = extract_classname_from_member(filename) 54 | if class_name not in self.classes: 55 | self.classes.add(class_name) 56 | 57 | def __getitem__(self, index): 58 | assert index < len(self.files) 59 | return self.files[index] 60 | 61 | 62 | # Handles a data file inside a tarfile 63 | class DataFileMetadata(object): 64 | def __init__(self, filename, offset, size): 65 | self.filename = filename 66 | self.offset = offset 67 | self.size = size 68 | 69 | def get_metadata(self): 70 | return self.filename, self.offset, self.size 
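These metadata classes form a small containment hierarchy: a `TarDirMetaData` holds `TarFileMetadata` objects, which in turn hold `DataFileMetadata` entries, with class names accumulated from the `root/class_name/img.ext` member paths. A sketch of how they compose (the paths, offsets, and sizes below are hypothetical; in practice they come from `tarfile`'s `TarInfo` records, as in `indextar` in tarfolder.py):
```python
from lib.datasets.tar_metadata import (
    DataFileMetadata, TarDirMetaData, TarFileMetadata)

tar_meta = TarFileMetadata('/data/tars/train_000.tar')
tar_meta.add_file(DataFileMetadata('root/dog/img1.png', 512, 2048))
tar_meta.add_file(DataFileMetadata('root/cat/img2.png', 4096, 1024))

dir_meta = TarDirMetaData('/data/tars')
dir_meta.add_file(tar_meta)

print(sorted(dir_meta.classes))       # ['cat', 'dog']
print(dir_meta[0][1].get_metadata())  # ('root/cat/img2.png', 4096, 1024)
```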
71 | -------------------------------------------------------------------------------- /adversarial/lib/datasets/tarfolder.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import unicode_literals 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import torch.utils.data as data 7 | 8 | import sys 9 | from os import listdir 10 | from os.path import isdir, isfile, join, basename 11 | import tarfile 12 | from lib.datasets.tar_metadata import (TarDirMetaData, TarFileMetadata, 13 | DataFileMetadata, extract_classname_from_member) 14 | from PIL import Image 15 | try: 16 | import cPickle as pickle 17 | except ImportError: 18 | import pickle 19 | 20 | if sys.version_info[0] == 3: # for python3 21 | from io import StringIO 22 | py3 = True 23 | else: 24 | from cStringIO import StringIO 25 | py3 = False 26 | 27 | 28 | IMG_EXTENSIONS = [ 29 | '.jpg', '.JPG', '.jpeg', '.JPEG', 30 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 31 | ] 32 | TAR_EXTENSIONS = ['.tar', '.TAR', '.tar.gz', '.TAR.GZ'] 33 | 34 | 35 | def _is_tar_file(filename): 36 | return any(filename.endswith(extension) for extension in TAR_EXTENSIONS) 37 | 38 | 39 | def _is_image_file(filename): 40 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 41 | 42 | 43 | # creates an object of tarinfo from tarfile 44 | def indextar(dbtarfile, path_prefix=''): 45 | tar_file_obj = TarFileMetadata(dbtarfile) 46 | with tarfile.open(dbtarfile, 'r|') as db: 47 | counter = 0 48 | for tarinfo in db: 49 | if _is_image_file(tarinfo.name): 50 | # Only consider filenames with path_prefix 51 | if not path_prefix or path_prefix in tarinfo.name: 52 | data_file_obj = DataFileMetadata(tarinfo.name, 53 | tarinfo.offset_data, 54 | tarinfo.size) 55 | tar_file_obj.add_file(data_file_obj) 56 | counter += 1 57 | if counter % 1000 == 0: 58 | # tarfile object maintains a list of members of the archive, 59 | # and keeps this updated whenever you read or write members 60 | # free ram by clearing the member list... 
61 | db.members = [] 62 | 63 | if len(tar_file_obj) == 0: 64 | print('No file with {} prefix found in {}'.format(path_prefix, dbtarfile)) 65 | return None 66 | 67 | return tar_file_obj 68 | 69 | 70 | def gen_tar_index(tar_path, index_outdir, path_prefix='', verbose=False): 71 | # assertions 72 | assert isinstance(tar_path, str), \ 73 | "tar_path \'{}\'should be of str type".format(tar_path) 74 | assert (isdir(tar_path) or isfile(tar_path)), \ 75 | "No tar file(s) exist at path {}".format(tar_path) 76 | assert isinstance(index_outdir, str), \ 77 | "Expected index_outdir to be of str type" 78 | 79 | # Collect name of all tar files at given path 80 | tarfiles = [] 81 | if (isfile(tar_path) and 82 | _is_tar_file(tar_path)): 83 | tarfiles.append(tar_path) 84 | outfile = str(basename(tar_path).split('.')[0] + '.index') 85 | elif isdir(tar_path): 86 | for f in listdir(tar_path): 87 | if _is_tar_file(f) and isfile(join(tar_path, f)): 88 | tarfiles.append(join(tar_path, f)) 89 | outfile = str(basename(tar_path) + '.index') 90 | 91 | if len(tarfiles) == 0: 92 | raise(RuntimeError("No tarfile found at the given path")) 93 | 94 | # Read all tar file indices in a single object 95 | tar_dir_obj = TarDirMetaData(tar_path) 96 | count = {} 97 | for idx, tar in enumerate(tarfiles): 98 | tar_file_obj = indextar(tar, path_prefix) 99 | count[tar] = len(tar_file_obj) if tar_file_obj else 0 100 | if tar_file_obj: 101 | tar_dir_obj.add_file(tar_file_obj) 102 | sys.stdout.write("\r%0.2f%%" % ((float(idx) * 100) / len(tarfiles))) 103 | sys.stdout.flush() 104 | sys.stdout.write("\n") 105 | if verbose: 106 | print("Number of files for each tarfile:") 107 | print(count) 108 | 109 | outfile = join(index_outdir, outfile) 110 | 111 | # Save tar index object 112 | f = open(outfile, 'wb') 113 | pickle.dump(tar_dir_obj, f) 114 | f.close() 115 | 116 | print('Saved tar index object in file {}'.format(outfile)) 117 | return outfile 118 | 119 | 120 | def get_tar_index_files(tar_dir, index_dir, tar_ext='.tar.gz'): 121 | assert tar_dir is not None 122 | assert index_dir is not None 123 | tarfiles = [] 124 | indexfiles = [] 125 | for f in listdir(tar_dir): 126 | if (_is_tar_file(f) and isfile(join(tar_dir, f)) and 127 | isfile(join(index_dir, f.replace(tar_ext, '.index')))): 128 | tarfiles.append(join(tar_dir, f)) 129 | indexfiles.append(join(index_dir, f.replace('.tar.gz', '.index'))) 130 | return tarfiles, indexfiles 131 | 132 | 133 | def tar_lookup(tarfile, datafile_metadata): 134 | assert _is_tar_file(tarfile) 135 | assert isinstance(datafile_metadata, DataFileMetadata) 136 | with open(tarfile, 'r') as tar: 137 | tar.seek(int(datafile_metadata.offset)) 138 | buffer = tar.read(int(datafile_metadata.size)) 139 | return buffer 140 | 141 | 142 | def make_dataset(tar_file_obj, class_to_idx, path_prefix=''): 143 | assert isinstance(tar_file_obj, TarFileMetadata) 144 | assert isinstance(path_prefix, str) 145 | images = [] 146 | members_name = [] 147 | members = {} 148 | for idx in range(len(tar_file_obj)): 149 | data_file = tar_file_obj[idx] 150 | filename, offset, size = data_file.get_metadata() 151 | if _is_image_file(filename): 152 | if not path_prefix or (path_prefix in filename and path_prefix != filename): 153 | members_name.append(filename) 154 | members[filename] = data_file 155 | members_name = sorted(members_name) 156 | for member_name in members_name: 157 | # item = (DataFileMetadata, class index) 158 | item = (members[member_name], 159 | class_to_idx[extract_classname_from_member(member_name)]) 160 | images.append(item) 
161 | 
162 |     return images
163 | 
164 | 
165 | def default_loader(tarfile_path, datafile_metadata):
166 |     assert isinstance(tarfile_path, str)
167 |     assert isinstance(datafile_metadata, DataFileMetadata)
168 |     buffer = tar_lookup(tarfile_path, datafile_metadata)
169 |     # wrap the raw bytes in an in-memory file object for PIL; BytesIO
170 |     # handles binary image data on both Python 2 and Python 3:
171 |     from io import BytesIO
172 |     img = Image.open(BytesIO(buffer))
173 |     return img.convert('RGB')
174 | 
175 | 
176 | def get_classes_from_tar_dir(tar_dir_obj):
177 |     # convert set to list
178 |     classes = list(tar_dir_obj.classes)
179 |     classes.sort()
180 |     class_to_idx = {classes[i]: i for i in range(len(classes))}
181 | 
182 |     return classes, class_to_idx
183 | 
184 | 
185 | # reads a TarDirMetaData or TarFileMetadata object from a serialized file
186 | # which stores metadata for tar files
187 | def read_tar_index(tar_index_file):
188 |     assert tar_index_file is not None and isfile(tar_index_file)
189 | 
190 |     # Load tar indices
191 |     f = open(tar_index_file, 'rb')
192 |     tar_dir_obj = pickle.load(f)
193 |     f.close()
194 |     assert (isinstance(tar_dir_obj, TarDirMetaData) or
195 |             isinstance(tar_dir_obj, TarFileMetadata))
196 |     # For a single tar file, wrap it into the dir class to maintain consistency
197 |     if isinstance(tar_dir_obj, TarFileMetadata):
198 |         tar_file_obj, tar_dir_obj = tar_dir_obj, TarDirMetaData(tar_dir_obj.tarfile)
199 |         tar_dir_obj.add_file(tar_file_obj)
200 |     return tar_dir_obj
201 | 
202 | 
203 | class ImageTarFile(data.Dataset):
204 |     """A data loader where the images are tarred and arranged in this way:
205 | 
206 |         root/prefix_for_image_type/class_name/xxx.png
207 | 
208 |         root/tvm/dog/xxx.png
209 |         root/tvm/dog/xxy.png
210 |         root/tvm/dog/xxz.png
211 | 
212 |         root/quilting/cat/123.png
213 |         root/quilting/cat/nsdf3.png
214 |         root/quilting/cat/asd932_.png
215 | 
216 |     Args:
217 |         tar_index_file (string): Path for TarDirMetaData/TarFileMetadata object
218 |         path_prefix (string): path prefix in all tar files (default="")
219 |         transform (callable, optional): A function/transform that takes in a PIL image
220 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
221 |         target_transform (callable, optional): A function/transform that takes in the
222 |             target and transforms it.
223 |         loader (callable, optional): A function to load an image given its path.
224 |     Attributes:
225 |         classes (list): List of the class names.
226 |         class_to_idx (dict): Dict with items (class_name, class_index).
227 |         imgs (list): List of (DataFileMetadata, target class index) tuples
228 |     """
229 | 
230 |     def __init__(self, tar_index_file, path_prefix='', transform=None,
231 |                  target_transform=None, loader=default_loader):
232 | 
233 |         assert isinstance(tar_index_file, str), \
234 |             "Expect tar_index_file to be of str type"
235 |         assert isinstance(path_prefix, str), \
236 |             "Expect path_prefix to be of str type"
237 | 
238 |         tar_dir_obj = read_tar_index(tar_index_file)
239 |         classes, class_to_idx = get_classes_from_tar_dir(tar_dir_obj)
240 |         # get image indexes:
241 |         imgs, tar_file_list, img2tarfile = [], [], []
242 |         for idx in range(len(tar_dir_obj)):
243 |             tar_file = tar_dir_obj[idx]
244 |             _imgs = make_dataset(tar_file, class_to_idx, path_prefix)
245 |             imgs += _imgs  # NOTE: Does this need to be sorted again by target?
246 |             img2tarfile += [idx] * len(_imgs)
247 |             tar_file_list.append(tar_file.tarfile)
248 |         if len(imgs) == 0:
249 |             raise(RuntimeError("Found 0 images in " + str(len(tar_dir_obj.tarfiles)) +
250 |                                " TAR files.\n" +
251 |                                "Supported image extensions are: " +
252 |                                ",".join(IMG_EXTENSIONS)))
253 | 
254 |         # store some fields:
255 |         self.tar_dir_obj = tar_dir_obj
256 |         self.img2tarfile = img2tarfile
257 |         self.tar_file_list = tar_file_list
258 |         self.imgs = imgs
259 |         self.classes = classes
260 |         self.class_to_idx = class_to_idx
261 |         self.transform = transform
262 |         self.target_transform = target_transform
263 |         self.loader = loader
264 | 
265 |     def __getitem__(self, index):
266 |         """
267 |         Args:
268 |             index (int): Index
269 | 
270 |         Returns:
271 |             tuple: (image, target) where target is class_index of the target class.
272 |         """
273 |         data_file_obj, target = self.imgs[index]
274 |         tar_file = self.tar_file_list[self.img2tarfile[index]]
275 |         img = self.loader(tar_file, data_file_obj)
276 | 
277 |         if self.transform is not None:
278 |             img = self.transform(img)
279 |         if self.target_transform is not None:
280 |             target = self.target_transform(target)
281 | 
282 |         return img, target
283 | 
284 |     def __len__(self):
285 |         return len(self.imgs)
286 | 
287 |     def get_class_from_index(self, index):
288 |         assert index < len(self.classes), "index can't be greater than number of classes"
289 |         return self.classes[index]
290 | 
291 |     def get_classes(self):
292 |         return self.classes
293 | 
--------------------------------------------------------------------------------
/adversarial/lib/datasets/transform_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | 
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 | 
8 | import torch.utils.data
9 | 
10 | 
11 | class TransformDataset(torch.utils.data.Dataset):
12 | 
13 |     def __init__(self, dataset, transform=None, data_indices=None):
14 |         super(TransformDataset, self).__init__()
15 |         # assertions
16 |         assert isinstance(dataset, torch.utils.data.Dataset)
17 |         if transform is not None:
18 |             assert callable(transform)
19 |         if data_indices:
20 |             assert ('start_idx' in data_indices and
21 |                     isinstance(data_indices['start_idx'], int)), \
22 |                 "data_indices expects argument start_idx of int type"
23 |             assert ('end_idx' in data_indices and
24 |                     isinstance(data_indices['end_idx'], int)), \
25 |                 "data_indices expects argument end_idx of int type"
26 |             assert data_indices['start_idx'] < len(dataset)
27 | 
28 |         self.dataset = dataset
29 |         self.transform = transform
30 |         if data_indices:
31 |             end_idx = min(data_indices['end_idx'], len(dataset))
32 |             self.dataset.data_tensor = self.dataset.data_tensor[data_indices['start_idx']:end_idx]
33 |             self.dataset.target_tensor = self.dataset.target_tensor[data_indices['start_idx']:end_idx]
34 | 
35 |     # Apply the transform (if any) to the image in each (img, target) sample
36 |     def __getitem__(self, idx):
37 |         item = self.dataset[idx]  # sample is a tuple of form (img, target)
38 |         if self.transform is not None:
39 |             item = (self.transform(item[0]), item[1])
40 |         return item
41 | 
42 |     def __len__(self):
43 |         return len(self.dataset)
44 | 
--------------------------------------------------------------------------------
/adversarial/lib/defenses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | try: 14 | import cPickle as pickle 15 | except ImportError: 16 | import pickle 17 | import math 18 | 19 | import torch 20 | from lib.transformations.quilting_fast import quilting 21 | from lib.transformations.tvm import reconstruct as tvm 22 | from PIL import Image 23 | from lib.paths import get_quilting_filepaths 24 | from torchvision.transforms import ToPILImage, ToTensor 25 | try: 26 | from cStringIO import StringIO as BytesIO 27 | except ImportError: 28 | from io import BytesIO 29 | from lib.constants import DefenseType 30 | import os 31 | 32 | 33 | def _quantize_img(im, depth=8): 34 | assert torch.is_tensor(im) 35 | N = int(math.pow(2, depth)) 36 | im = (im * N).round() 37 | im = im / N 38 | return im 39 | 40 | 41 | def _jpeg_compression(im): 42 | assert torch.is_tensor(im) 43 | im = ToPILImage()(im) 44 | savepath = BytesIO() 45 | im.save(savepath, 'JPEG', quality=75) 46 | im = Image.open(savepath) 47 | im = ToTensor()(im) 48 | return im 49 | 50 | 51 | # class describing a defense transformation: 52 | class Defense(object): 53 | 54 | def __init__(self, defense, defense_name): 55 | assert callable(defense) 56 | self.defense = defense 57 | self.defense_name = defense_name 58 | 59 | def __call__(self, im): 60 | return self.defense(im) 61 | 62 | def get_name(self): 63 | return self.defense_name 64 | 65 | 66 | # function that returns defense: 67 | def get_defense(defense_name, args): 68 | 69 | print('| Defense: {}'.format(defense_name)) 70 | assert (defense_name is None or DefenseType.has_value(defense_name)), ( 71 | "{} defense type not defined".format(defense_name)) 72 | 73 | defense = None 74 | # set up quilting defense: 75 | if defense_name == str(DefenseType.RAW): 76 | # return image as it is 77 | def defense(im): 78 | return im 79 | defense = Defense(defense, defense_name) 80 | 81 | elif defense_name == str(DefenseType.QUILTING): 82 | # load faiss index: 83 | import faiss 84 | patches_filename, index_filename = get_quilting_filepaths(args) 85 | # If quilting patches doesn't exist, then create them 86 | assert (os.path.isfile(patches_filename) and 87 | os.path.isfile(index_filename)), ( 88 | "ERROR: No quilting patch data found at {}. \n" 89 | "No patch index data found at {}.\n" 90 | "Generate quilting patches using index_patches.py." 
91 | "See demo.py for example" 92 | .format(patches_filename, index_filename)) 93 | 94 | print('| Loading quilting patch data from {} ...'.format(patches_filename)) 95 | with open(patches_filename, 'rb') as fread: 96 | patches = pickle.load(fread) 97 | patch_size = int(math.sqrt(patches.size(1) / 3)) 98 | faiss_index = faiss.read_index(index_filename) 99 | 100 | # the actual quilting defense: 101 | def defense(im): 102 | im = quilting( 103 | im, faiss_index, patches, 104 | patch_size=patch_size, 105 | overlap=(patch_size // 2), 106 | graphcut=True, 107 | k=args.quilting_neighbors, 108 | random_stitch=args.quilting_random_stitch 109 | ) 110 | # Clamping because some values are overflowing in quilting 111 | im = torch.clamp(im, min=0.0, max=1.0) 112 | return im 113 | defense = Defense(defense, defense_name) 114 | 115 | # set up tvm defense: 116 | elif defense_name == str(DefenseType.TVM): 117 | # the actual tvm defense: 118 | def defense(im): 119 | im = tvm( 120 | im, 121 | args.pixel_drop_rate, 122 | args.tvm_method, 123 | args.tvm_weight 124 | ) 125 | return im 126 | defense = Defense(defense, defense_name) 127 | 128 | elif defense_name == str(DefenseType.QUANTIZATION): 129 | def defense(im): 130 | im = _quantize_img(im, depth=args.quantize_depth) 131 | return im 132 | defense = Defense(defense, defense_name) 133 | 134 | elif defense_name == str(DefenseType.JPEG): 135 | def defense(im): 136 | im = _jpeg_compression(im) 137 | return im 138 | defense = Defense(defense, defense_name) 139 | 140 | else: 141 | print('| No defense for \"%s\" is available' % (defense_name)) 142 | 143 | return defense 144 | -------------------------------------------------------------------------------- /adversarial/lib/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import os 14 | import pkgutil 15 | import torch.nn as nn 16 | import torch.nn.parallel 17 | import torchvision.models as models 18 | from lib.util import load_checkpoint 19 | import lib.constants as constants 20 | 21 | 22 | def _load_torchvision_model(model, pretrained=True): 23 | assert hasattr(models, model), ( 24 | "Model {} is not available in torchvision.models." 
25 | "Supported models are: {}".format(model, constants.MODELS)) 26 | model = getattr(models, model)(pretrained=pretrained) 27 | return model 28 | 29 | 30 | def _init_data_parallel(model, device): 31 | if device == 'gpu': 32 | import torch.backends.cudnn as cudnn 33 | cudnn.benchmark = True 34 | model = torch.nn.DataParallel(model).cuda() 35 | return model 36 | 37 | 38 | def _init_inceptionresnetv2(model, device): 39 | # inceptionresnetv2 has default 1001 classes, 40 | # get rid of first background class 41 | new_classif = nn.Linear(1536, 1000) 42 | new_classif.weight.data = model.classif.weight.data[1:] 43 | new_classif.bias.data = model.classif.bias.data[1:] 44 | model.classif = new_classif 45 | model = _init_data_parallel(model, device) 46 | return model 47 | 48 | 49 | def _load_model_from_checkpoint(model, model_path, model_name, 50 | defense_name=None, training=False): 51 | 52 | assert model is not None, "Model should not be None" 53 | assert model_path and model_name, \ 54 | "Model path is not provided" 55 | model_path = os.path.join(model_path, model_name) 56 | if defense_name is not None: 57 | model_path = str(model_path + '_' + defense_name) 58 | 59 | if not training: 60 | assert os.path.isdir(model_path), \ 61 | "Model directory doesn't exist at: {}".format(model_path) 62 | 63 | print('| loading model from checkpoint %s' % model_path) 64 | checkpoint = load_checkpoint(model_path) 65 | start_epoch, optimizer = None, None 66 | if checkpoint is not None: 67 | model.load_state_dict(checkpoint['model_state_dict']) 68 | if training: 69 | start_epoch = checkpoint['epoch'] 70 | optimizer = checkpoint['optimizer'] 71 | else: 72 | print('| no model available at %s...' % model_path) 73 | 74 | return model, start_epoch, optimizer 75 | 76 | 77 | def get_model(args, load_checkpoint=False, defense_name=None, training=False): 78 | 79 | assert (args.model in constants.MODELS), ("%s not a supported model" % args.model) 80 | model, start_epoch, optimizer = None, None, None 81 | # load model: 82 | print('| loading model {} ...'.format(args.model)) 83 | if args.model == 'inception_v4': 84 | assert "NUM_CLASSES" in args.data_params, \ 85 | "Inception parameters should have number of classes defined" 86 | if args.pretrained: 87 | pretrained = 'imagenet' 88 | else: 89 | pretrained = None 90 | assert pkgutil.find_loader("lib.models.inceptionv4") is not None, \ 91 | ("Module lib.models.inceptionv4 can't be found. " 92 | "Check the setup script and rebuild again to download") 93 | from lib.models.inceptionv4 import inceptionv4 94 | model = inceptionv4(pretrained=pretrained) 95 | elif args.model == 'inceptionresnetv2': 96 | assert not args.pretrained, \ 97 | "For inceptionresnetv2 pretrained not available" 98 | assert pkgutil.find_loader("lib.models.inceptionresnetv2") is not None, \ 99 | ("Module lib.models.inceptionresnetv2 can't be found. 
" 100 | "Check the setup script and rebuild again to download") 101 | from lib.models.inceptionresnetv2 import InceptionResnetV2 102 | model = InceptionResnetV2() 103 | else: 104 | model = _load_torchvision_model(args.model, 105 | pretrained=args.pretrained) 106 | 107 | # inceptionresnetv2 from adversarial ensemble training at checkpoint 108 | # is not saved with DataParallel 109 | # TODO: Save it with DataParallel to cleanup code below 110 | if not args.model == 'inceptionresnetv2': 111 | model = _init_data_parallel(model, args.device) 112 | 113 | if load_checkpoint and not args.pretrained: 114 | model, start_epoch, optimizer = _load_model_from_checkpoint( 115 | model, args.models_root, args.model, defense_name, 116 | training=training) 117 | 118 | if args.model == 'inceptionresnetv2': 119 | model = _init_inceptionresnetv2(model, args.device) 120 | 121 | return model, start_epoch, optimizer 122 | -------------------------------------------------------------------------------- /adversarial/lib/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookarchive/adversarial_image_defenses/55bf56ddb017535fb6630a746c6f946202336052/adversarial/lib/models/__init__.py -------------------------------------------------------------------------------- /adversarial/lib/opts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import argparse 14 | import os 15 | from enum import Enum 16 | import lib.constants as constants 17 | import json 18 | 19 | 20 | # Init paths from config 21 | path_config_file = str("path_config.json") 22 | if not os.path.isfile(path_config_file): 23 | path_config_file = os.path.join(os.path.dirname(__file__), str("path_config.json")) 24 | assert os.path.isfile(path_config_file), \ 25 | "path_config.json file not found at {}".format(path_config_file) 26 | path_config = json.load(open(path_config_file, "r")) 27 | 28 | DATA_ROOT = path_config["DATA_ROOT"] 29 | QUILTING_ROOT = path_config["QUILTING_ROOT"] 30 | MODELS_ROOT = path_config["MODELS_ROOT"] 31 | if not path_config["IMAGENET_DIR1"]: 32 | IMAGENET_DIR1 = os.path.join(os.path.dirname(__file__), "../test/images") 33 | print("\nWARNING: IMAGENET_DIR1 is not defined in path_config.json, " 34 | "so loading test images from {}".format(IMAGENET_DIR1)) 35 | print("To load imagenet data update path IMAGENET_DIR1 in path_config.json\n") 36 | else: 37 | IMAGENET_DIR1 = path_config["IMAGENET_DIR1"] 38 | if not path_config["IMAGENET_DIR2"]: 39 | IMAGENET_DIR2 = None 40 | else: 41 | IMAGENET_DIR2 = path_config["IMAGENET_DIR2"] 42 | 43 | # path to save/load tarred transformed data 44 | TAR_DIR = DATA_ROOT + '/imagenet_transformed_tarred' 45 | # path to save/load index objects for tarred data 46 | # used to directly reads tar data without untarring fully(much faster) 47 | TAR_INDEX_DIR = DATA_ROOT + '/imagenet_transformed_tarred_index' 48 | 49 | 50 | class OptType(Enum): 51 | QUILTING_PATCHES = 'QUILTING_PATCHES' 52 | TRAIN = 'TRAIN' 53 | CLASSIFY = 'CLASSIFY' 54 | TRANSFORMATION = 'TRANSFORMATION' 55 | ADVERSARIAL = 'ADVERSARIAL' 56 | 57 
58 | def _setup_common_args(parser):
59 |     # paths
60 |     parser.add_argument('--data_root', default=DATA_ROOT, type=str, metavar='N',
61 |                         help='Main data directory to save and read data')
62 |     parser.add_argument('--models_root', default=MODELS_ROOT, type=str, metavar='N',
63 |                         help='Directory to store/load models')
64 |     parser.add_argument('--tar_dir', default=TAR_DIR, type=str, metavar='N',
65 |                         help='Path for directory with processed (transformed) '
66 |                         'tar train/val files')
67 |     parser.add_argument('--tar_index_dir', default=TAR_INDEX_DIR, type=str, metavar='N',
68 |                         help='Path for directory with processed tar index files')
69 |     parser.add_argument('--tar_prefix', default='tmp/imagenet_transformed',
70 |                         type=str, metavar='N',
71 |                         help='Path prefix in all tar files')
72 |     parser.add_argument('--quilting_index_root', default=QUILTING_ROOT,
73 |                         type=str, metavar='N',
74 |                         help='Path for quilting index files')
75 |     parser.add_argument('--quilting_patch_root', default=QUILTING_ROOT,
76 |                         type=str, metavar='N',
77 |                         help='Path for quilting patches')
78 | 
79 |     parser.add_argument('--model', default='resnet50', type=str, metavar='N',
80 |                         help='model to use (default: resnet50)')
81 |     parser.add_argument('--device', default='gpu', type=str, metavar='N',
82 |                         help='device to use: cpu or gpu (default = gpu)')
83 |     # Set normalize to True for training, testing and generating adversarial images.
84 |     # For generating transformations, leave it False.
85 |     parser.add_argument('--normalize', default=False, action='store_true',
86 |                         help='Normalize image data.')
87 |     parser.add_argument('--batchsize', default=256, type=int, metavar='N',
88 |                         help='batch size (default = 256)')
89 |     parser.add_argument('--preprocessed_data', default=False, action='store_true',
90 |                         help='Defenses are already applied on saved images')
91 |     parser.add_argument('--defenses', default=None, nargs='*', type=str, metavar='N',
92 |                         help='List of defenses to apply, e.g. tvm, quilting')
93 |     parser.add_argument('--pretrained', default=False, action='store_true',
94 |                         help='use pretrained model from model-zoo')
95 | 
96 |     # Defense params
97 |     # TVM
98 |     parser.add_argument('--tvm_weight', default=constants.TVM_WEIGHT,
99 |                         type=float, metavar='N',
100 |                         help='weight for TVM')
101 |     parser.add_argument('--pixel_drop_rate', default=constants.PIXEL_DROP_RATE,
102 |                         type=float, metavar='N',
103 |                         help='Pixel drop rate to use in TVM')
104 |     parser.add_argument('--tvm_method', default=constants.TVM_METHOD,
105 |                         type=str, metavar='N',
106 |                         help='Reconstruction method to use in TVM')
107 |     # Quilting
108 |     parser.add_argument('--quilting_patch_size',
109 |                         default=constants.QUILTING_PATCH_SIZE,
110 |                         type=int, metavar='N', help='Patch size to use in quilting')
111 |     parser.add_argument('--quilting_neighbors', default=1, type=int, metavar='N',
112 |                         help='Number of nearest neighbors to use for quilting patches')
113 |     parser.add_argument('--quilting_random_stitch', default=False, action='store_true',
114 |                         help='Stitch randomly selected quilting patches')
115 |     # Quantization
116 |     parser.add_argument('--quantize_depth', default=8, type=int, metavar='N',
117 |                         help='Bit depth for quantization defense')
118 | 
119 |     return parser
120 | 
121 | 
122 | def _setup_adversary_args(parser):
123 |     # common params for generating or reading adversary images
124 |     parser.add_argument('--n_samples', default=50000, type=int, metavar='N',
125 |                         help='Max number of samples to test on')
126 |     parser.add_argument('--attack_type', default=None, type=str,
metavar='N', 127 | help='Attack type (None(No attack) | blackbox | whitebox)') 128 | # parser.add_argument('--renormalize', default=False, action='store_true', 129 | # help='Renormalize for inception data params') 130 | parser.add_argument('--adversary', default=None, type=str, metavar='N', 131 | help='Adversary to use for pre-generated attack images' 132 | '(default = None)') 133 | parser.add_argument('--adversary_model', default='resnet50', 134 | type=str, metavar='N', 135 | help='Adversarial model to use (default resnet50)') 136 | parser.add_argument('--learning_rate', default=None, nargs='*', 137 | type=float, metavar='N', 138 | help='List of adversarial learning rate for each defense') 139 | parser.add_argument('--adv_strength', default=None, nargs='*', 140 | type=float, metavar='N', 141 | help='List of adversarial strength for each defense') 142 | parser.add_argument('--adversarial_root', default=DATA_ROOT + '/adversarial', 143 | type=str, metavar='N', 144 | help='Directory path adversary data') 145 | 146 | # params for generating adversary images 147 | parser.add_argument('--operation', default='transformation_on_adv', 148 | type=str, metavar='N', 149 | help='Operation to run (generate_adversarial, ' 150 | 'concat_adversarial, compute_adversarial_stats)') 151 | parser.add_argument('--adversary_to_generate', default=None, type=str, metavar='N', 152 | help='Adversary to generate (default = None)') 153 | parser.add_argument('--partition', default=0, type=int, metavar='N', 154 | help='the data partition to work on (indexing from 0)') 155 | parser.add_argument('--partition_size', default=50000, type=int, metavar='N', 156 | help='the size of each data partition') 157 | parser.add_argument('--data_type', default='train', 158 | type=str, metavar='N', 159 | help='data_type (train|raw) for transformation_on_raw') 160 | parser.add_argument('--max_adv_iter', default=10, type=int, metavar='N', 161 | help='max iterations for iteratibe attacks') 162 | parser.add_argument('--fgs_mode', default=None, 163 | type=str, metavar='N', 164 | help='fgs_mode (logit | carlini) for loss computation in FGS') 165 | parser.add_argument('--margin', default=0, type=float, metavar='N', 166 | help='margin parameter for cwl2') 167 | parser.add_argument('--compute_stats', default=False, action='store_true', 168 | help='Compute adversarial stats(robustness, SSIM, ' 169 | 'success rate)') 170 | parser.add_argument('--crop_frac', default=[1.0], nargs='*', 171 | type=float, metavar='N', 172 | help='crop fraction for ensembling or Carlini-Wagner') 173 | 174 | return parser 175 | 176 | 177 | def _parse_train_opts(): 178 | parser = argparse.ArgumentParser(description='Train convolutional network') 179 | parser = _setup_common_args(parser) 180 | 181 | parser.add_argument('--resume', default=False, action='store_true', 182 | help='Resume training from checkpoint (if available)') 183 | parser.add_argument('--lr', default=None, type=float, metavar='N', 184 | help='Initial learning rate for training, \ 185 | for inception_v4 use lr=0.045, 0.1 for others)') 186 | parser.add_argument('--lr_decay', default=None, type=float, metavar='N', 187 | help='exponential learning rate decay(0.94 for \ 188 | inception_v4, 0.1 for others)') 189 | parser.add_argument('--lr_decay_stepsize', default=None, type=float, metavar='N', 190 | help='decay learning rate after every stepsize \ 191 | epochs(2 for inception_v4, 30 for others)') 192 | parser.add_argument('--momentum', default=0.9, type=float, metavar='N', 193 | help='amount of momentum (default 
= 0.9)')
194 |     parser.add_argument('--weight_decay', default=1e-4, type=float, metavar='N',
195 |                         help='amount of weight decay (default = 1e-4)')
196 |     parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
197 |                         help='index of first epoch (default = 0)')
198 |     parser.add_argument('--end_epoch', default=None, type=int, metavar='N',
199 |                         help='index of last epoch (default = 90; \
200 |                         for inception_v4 use end_epoch=160)')
201 |     parser.add_argument('--preprocessed_epoch_data', default=False, action='store_true',
202 |                         help='Randomly cropped data for each epoch is pre-generated')
203 | 
204 |     args = parser.parse_args()
205 | 
206 |     if args.model.startswith('inception'):
207 |         hyperparams = constants.INCEPTION_V4_TRAINING_PARAMS
208 |         # inception-specific hyperparams
209 |         args.rms_eps = constants.INCEPTION_V4_TRAINING_PARAMS['RMS_EPS']
210 |         args.rms_alpha = constants.INCEPTION_V4_TRAINING_PARAMS['RMS_ALPHA']
211 |     else:
212 |         hyperparams = constants.TRAINING_PARAMS
213 | 
214 |     # init common training hyperparams
215 |     if not args.lr:
216 |         args.lr = hyperparams['LR']
217 |     if not args.lr_decay:
218 |         args.lr_decay = hyperparams['LR_DECAY']
219 |     if not args.lr_decay_stepsize:
220 |         args.lr_decay_stepsize = hyperparams['LR_DECAY_STEPSIZE']
221 |     if not args.end_epoch:
222 |         args.end_epoch = hyperparams['EPOCHS']
223 | 
224 |     return args
225 | 
226 | 
227 | def _parse_classify_opts():
228 |     # set input arguments:
229 |     parser = argparse.ArgumentParser(description='Classify adversarial images')
230 |     parser = _setup_common_args(parser)
231 |     parser = _setup_adversary_args(parser)
232 | 
233 |     parser.add_argument('--ensemble', default=None, type=str, metavar='N',
234 |                         help='ensembling type (None | avg | max)')
235 |     parser.add_argument('--ncrops', default=None, nargs='*', type=int, metavar='N',
236 |                         help='list of number of crops for each defense'
237 |                         ' to use for ensembling')
238 |     parser.add_argument('--crop_type', default=None, nargs='*',
239 |                         type=str, metavar='N',
240 |                         help='Crop type (center=CenterCrop, '
241 |                         'random=RandomResizedCrop, sliding=sliding window crops)')
242 | 
243 |     args = parser.parse_args()
244 | 
245 |     return args
246 | 
247 | 
248 | # args for generating transformation data
249 | def _parse_generate_opts():
250 |     parser = argparse.ArgumentParser(description='Generate and save '
251 |                                      'image transformations')
252 |     parser = _setup_common_args(parser)
253 |     parser = _setup_adversary_args(parser)
254 | 
255 |     # paths
256 |     parser.add_argument('--out_dir', default=DATA_ROOT, type=str, metavar='N',
257 |                         help='Directory path to output concatenated '
258 |                         'transformed data')
259 |     parser.add_argument('--partition_dir', default=DATA_ROOT + '/partitioned',
260 |                         type=str, metavar='N',
261 |                         help='Directory path to output transformed data')
262 | 
263 |     parser.add_argument('--data_batches', default=None, type=int, metavar='N',
264 |                         help='Number of data batches to generate')
265 | 
266 |     parser.add_argument('--n_threads', default=20, type=int, metavar='N',
267 |                         help='Number of threads for raw image transformation')
268 | 
269 |     parser.add_argument('--data_file', default=None, type=str, metavar='N',
270 |                         help='Data file path to read images to visualize')
271 | 
272 |     args = parser.parse_args()
273 | 
274 |     return args
275 | 
276 | 
277 | def _setup_model_based_data_params(args):
278 |     if args.model.startswith('inception'):
279 |         args.data_params = constants.INCEPTION_V4_DATA_PARAMS
280 |     else:
281 |         args.data_params = constants.RESNET_DENSENET_DATA_PARAMS
282 | 
283 |     return args
284 | 
285 | 
286 | def _parse_adversarial_opts():
287 |     parser = argparse.ArgumentParser(description='Generate adversarial images')
288 |     parser = _setup_common_args(parser)
289 |     parser = _setup_adversary_args(parser)
290 |     args = parser.parse_args()
291 |     return args
292 | 
293 | 
294 | def _parse_quilting_patch_opts():
295 |     # set input arguments:
296 |     parser = argparse.ArgumentParser(description='Build FAISS index of patches')
297 |     parser = _setup_common_args(parser)
298 |     parser.add_argument('--num_patches', default=1000000, type=int, metavar='N',
299 |                         help='number of patches in index (default: 1M)')
300 |     parser.add_argument('--pca_dims', default=64, type=int, metavar='N',
301 |                         help='number of pca dimensions to use (default: 64)')
302 |     parser.add_argument('--patches_file', default='/tmp/tmp.pickle', type=str,
303 |                         metavar='N', help='filename in which to save patches')
304 |     parser.add_argument('--index_file', default='/tmp/tmp.faiss', type=str, metavar='N',
305 |                         help='filename in which to save faiss index')
306 |     args = parser.parse_args()
307 |     args.data_params = {
308 |         'IMAGE_SIZE': 224,
309 |         'IMAGE_SCALE_SIZE': 256,
310 |     }
311 |     return args
312 | 
313 | 
314 | def parse_args(opt_type):
315 |     assert isinstance(opt_type, OptType), \
316 |         '{} not an instance of OptType Enum'.format(opt_type)
317 |     assert DATA_ROOT, \
318 |         "DATA_ROOT can't be empty. Update path_config.json with the correct value"
319 | 
320 |     if opt_type == OptType.QUILTING_PATCHES:
321 |         args = _parse_quilting_patch_opts()
322 |     else:
323 |         if opt_type == OptType.TRAIN:
324 |             args = _parse_train_opts()
325 |         elif opt_type == OptType.CLASSIFY:
326 |             args = _parse_classify_opts()
327 |         elif opt_type == OptType.TRANSFORMATION:
328 |             args = _parse_generate_opts()
329 |         elif opt_type == OptType.ADVERSARIAL:
330 |             args = _parse_adversarial_opts()
331 | 
332 |         # model
333 |         assert args.model in constants.MODELS, \
334 |             "model \"{}\" is not defined".format(args.model)
335 | 
336 |         args = _setup_model_based_data_params(args)
337 | 
338 |         # imagenet dir
339 |         imagenet_dirs = [
340 |             IMAGENET_DIR1,
341 |             IMAGENET_DIR2,
342 |         ]
343 |         for path in imagenet_dirs:
344 |             if path and os.path.isdir(path):
345 |                 args.imagenet_dir = str(path)
346 |                 break
347 |         assert hasattr(args, "imagenet_dir"), (
348 |             "ERROR: Can't find imagenet data at paths: {}. \n "
349 |             "Update the IMAGENET_DIR1 in path_config.json with the correct path"
350 |             .format(imagenet_dirs))
351 | 
352 |     print("| Input args are:")
353 |     print(args)
354 | 
355 |     return args
356 | 
-------------------------------------------------------------------------------- /adversarial/lib/path_config.json: --------------------------------------------------------------------------------
1 | {
2 |     "DATA_ROOT": "/tmp/adversarial_dataset",
3 |     "QUILTING_ROOT": "/tmp/adversarial_dataset",
4 |     "MODELS_ROOT": "/tmp/adversarial_dataset",
5 |     "IMAGENET_DIR1": "",
6 |     "IMAGENET_DIR2": ""
7 | }
8 | 
-------------------------------------------------------------------------------- /adversarial/lib/paths.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | 
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 | from __future__ import unicode_literals
12 | 
13 | import os
14 | from lib.constants import AdversaryType, DefenseType, AttackType
15 | 
16 | 
17 | def _get_adv_path_part(args, defense_name, adv_params, with_defense=True):
18 | 
19 |     if args.adversary_to_generate:
20 |         adversary = args.adversary_to_generate
21 |     elif args.adversary:
22 |         adversary = args.adversary
23 |     else:
24 |         return ""
25 | 
26 |     advtext = "adversary-"
27 | 
28 |     if args.attack_type == str(AttackType.WHITEBOX) and defense_name:
29 |         advtext = advtext + defense_name + "_"
30 | 
31 |     if (adversary == str(AdversaryType.IFGS) or
32 |             adversary == str(AdversaryType.DEEPFOOL)):
33 |         assert adv_params is not None and adv_params['learning_rate'] is not None,\
34 |             "learning rate can't be None for iterative attacks"
35 |         advtext = '%s%s_%1.4f' % (advtext, adversary,
36 |                                   adv_params['learning_rate'])
37 |     else:
38 |         advtext = advtext + adversary
39 |         # non-iterative attacks are generated at run time, so adv_strength is only appended when a defense is given
40 |         if ((adversary == str(AdversaryType.CWL2) or
41 |                 adversary == str(AdversaryType.FGS)) and
42 |                 (adv_params is not None and adv_params['adv_strength'] is not None) and
43 |                 defense_name and with_defense):
44 |             advtext = '%s_%1.4f' % (advtext, adv_params['adv_strength'])
45 | 
46 |     return advtext
47 | 
48 | 
49 | def _get_path_part_from_defenses(args, defense_name):
50 |     assert 'preprocessed_data' in args, \
51 |         'preprocessed_data argument is expected but not present in args'
52 |     assert 'defenses' in args, \
53 |         'defenses argument is expected but not present in args'
54 |     assert 'quilting_random_stitch' in args, \
55 |         'quilting_random_stitch argument is expected but not present in args'
56 |     assert 'quilting_neighbors' in args, \
57 |         'quilting_neighbors argument is expected but not present in args'
58 |     assert 'tvm_weight' in args, \
59 |         'tvm_weight argument is expected but not present in args'
60 |     assert 'pixel_drop_rate' in args, \
61 |         'pixel_drop_rate argument is expected but not present in args'
62 | 
63 |     if not defense_name:
64 |         return ""
65 | 
66 |     d_str = "defense-" + defense_name
67 |     if defense_name == str(DefenseType.TVM):
68 |         d_str = str(d_str + '_drop_' + str(args.pixel_drop_rate) +
69 |                     '_weight_' + str(args.tvm_weight))
70 |     elif defense_name == str(DefenseType.QUILTING):
71 |         if args.quilting_random_stitch:
72 |             d_str = d_str + 'random-stitch'
73 |         elif args.quilting_neighbors > 1:
74 |             d_str = d_str + 'random-patch_' + str(args.quilting_neighbors)
75 | 
76 |     return d_str
77 | 
78 | 
79 | def get_adversarial_file_path(args, root_dir, defense_name, adv_params, end_idx,
80 |                               start_idx=0, data_batch_idx=None, with_defense=True):
81 |     assert 'adversary' in args, \
82 |         'adversary argument is expected but not present in args'
83 |     assert 'preprocessed_data' in args, \
84 |         'preprocessed_data argument is expected but not present in args'
85 |     assert 'adversary_model' in args, \
86 |         'adversary_model argument is expected but not present in args'
87 | 
88 |     d_str = None
89 |     if with_defense:
90 |         d_str = _get_path_part_from_defenses(args, defense_name)
91 |     adv_str = _get_adv_path_part(args, defense_name, adv_params,
92 |                                  with_defense=with_defense)
93 | 
94 |     file_path = root_dir + '/'
95 |     if d_str:
96 |         file_path = file_path + d_str + '_'
97 |     if adv_str:
98 |         file_path = file_path + adv_str + '_'
99 |     if args.adversary_model:
100 |         file_path = file_path + 
args.adversary_model + '_' 101 | 102 | file_path = '%s%s_%d-%d' % (file_path, 'val', start_idx + 1, end_idx) 103 | 104 | if data_batch_idx is not None: 105 | file_path = str('%s_%d' % (file_path, data_batch_idx)) 106 | file_path = file_path + '.pth' 107 | 108 | return file_path 109 | 110 | 111 | def _get_preprocessed_tar_index_dir(args, data_type, epoch): 112 | assert (data_type == 'train' or data_type == 'valid'), ( 113 | "{} data type not defined. Defined types are \"train\" " 114 | "and \"valid\" ".format(data_type)) 115 | assert os.path.isdir(args.tar_index_dir), \ 116 | "{} doesn't exist".format(args.tar_index_dir) 117 | # preprocessed train dataset for all epochs 118 | if data_type == 'train': 119 | assert epoch >= 0 120 | index_file = "{}/epoch_{}.index".format(args.tar_index_dir, epoch) 121 | # For validation, same dataset for all epochs 122 | else: 123 | index_file = "{}/val.index".format(args.tar_index_dir) 124 | assert os.path.isfile(index_file), \ 125 | "{} doesn't exist".format(index_file) 126 | 127 | return index_file 128 | 129 | 130 | # get location of the images 131 | def get_img_dir(args, 132 | data_type, 133 | epoch=-1): 134 | 135 | img_dir = None 136 | # Images are not adversarial 137 | if 'adversary' not in args or not args.adversary: 138 | if 'preprocessed_data' not in args or not args.preprocessed_data: 139 | dir_name = 'train' if data_type == 'train' else 'val' 140 | img_dir = str(os.path.join(args.imagenet_dir, dir_name)) 141 | # Data is preprocessed for defenses like tvm, quilting 142 | else: 143 | # this is index_file stored as img_dir 144 | img_dir = str(_get_preprocessed_tar_index_dir(args, data_type, epoch)) 145 | 146 | # If needs to work on pre-generated adversarial images 147 | else: 148 | img_dir = str(args.adversarial_root) 149 | 150 | assert os.path.isdir(img_dir) or os.path.isfile(img_dir), \ 151 | "Data directory {} doesn't exist. 
Update the IMAGENET_DIR1 in the \ 152 | config file with correct path".format(img_dir) 153 | 154 | return img_dir 155 | 156 | 157 | def get_quilting_filepaths(args): 158 | root = args.quilting_patch_root 159 | size = args.quilting_patch_size 160 | patches_filename = str("{root}/patches_{size}.pickle".format( 161 | root=root, size=size)) 162 | index_filename = str("{root}/index_{size}.faiss".format( 163 | root=root, size=size)) 164 | 165 | return patches_filename, index_filename 166 | -------------------------------------------------------------------------------- /adversarial/lib/transformations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookarchive/adversarial_image_defenses/55bf56ddb017535fb6630a746c6f946202336052/adversarial/lib/transformations/__init__.py -------------------------------------------------------------------------------- /adversarial/lib/transformations/_tv_bregman.patch: -------------------------------------------------------------------------------- 1 | --- _denoise_cy.pyx 2017-12-11 10:48:58.425296545 -0800 2 | +++ _tv_bregman.pyx 2017-12-11 10:50:10.234351675 -0800 3 | @@ -7,229 +7,82 @@ 4 | import numpy as np 5 | from libc.math cimport exp, fabs, sqrt 6 | from libc.float cimport DBL_MAX 7 | -from .._shared.interpolation cimport get_pixel3d 8 | -from ..util import img_as_float 9 | - 10 | - 11 | -cdef inline double _gaussian_weight(double sigma, double value): 12 | - return exp(-0.5 * (value / sigma)**2) 13 | - 14 | - 15 | -cdef double[:] _compute_color_lut(Py_ssize_t bins, double sigma, double max_value): 16 | - 17 | - cdef: 18 | - double[:] color_lut = np.empty(bins, dtype=np.double) 19 | - Py_ssize_t b 20 | - 21 | - for b in range(bins): 22 | - color_lut[b] = _gaussian_weight(sigma, b * max_value / bins) 23 | - 24 | - return color_lut 25 | - 26 | - 27 | -cdef double[:] _compute_range_lut(Py_ssize_t win_size, double sigma): 28 | - 29 | - cdef: 30 | - double[:] range_lut = np.empty(win_size**2, dtype=np.double) 31 | - Py_ssize_t kr, kc 32 | - Py_ssize_t window_ext = (win_size - 1) / 2 33 | - double dist 34 | - 35 | - for kr in range(win_size): 36 | - for kc in range(win_size): 37 | - dist = sqrt((kr - window_ext)**2 + (kc - window_ext)**2) 38 | - range_lut[kr * win_size + kc] = _gaussian_weight(sigma, dist) 39 | - 40 | - return range_lut 41 | - 42 | - 43 | -cdef inline Py_ssize_t Py_ssize_t_min(Py_ssize_t value1, Py_ssize_t value2): 44 | - if value1 < value2: 45 | - return value1 46 | - else: 47 | - return value2 48 | - 49 | - 50 | -def _denoise_bilateral(image, Py_ssize_t win_size, sigma_color, 51 | - double sigma_spatial, Py_ssize_t bins, 52 | - mode, double cval): 53 | - cdef: 54 | - double min_value, max_value 55 | - 56 | - min_value = image.min() 57 | - max_value = image.max() 58 | - 59 | - if min_value == max_value: 60 | - return image 61 | - 62 | - # if image.max() is 0, then dist_scale can have an unverified value 63 | - # and color_lut[(dist * dist_scale)] may cause a segmentation fault 64 | - # so we verify we have a positive image and that the max is not 0.0. 
65 | - if min_value < 0.0: 66 | - raise ValueError("Image must contain only positive values") 67 | - 68 | - if max_value == 0.0: 69 | - raise ValueError("The maximum value found in the image was 0.") 70 | - 71 | - image = np.atleast_3d(img_as_float(image)) 72 | - 73 | - cdef: 74 | - Py_ssize_t rows = image.shape[0] 75 | - Py_ssize_t cols = image.shape[1] 76 | - Py_ssize_t dims = image.shape[2] 77 | - Py_ssize_t window_ext = (win_size - 1) / 2 78 | - Py_ssize_t max_color_lut_bin = bins - 1 79 | - 80 | - double[:, :, ::1] cimage 81 | - double[:, :, ::1] out 82 | - 83 | - double[:] color_lut 84 | - double[:] range_lut 85 | - 86 | - Py_ssize_t r, c, d, wr, wc, kr, kc, rr, cc, pixel_addr, color_lut_bin 87 | - double value, weight, dist, total_weight, csigma_color, color_weight, \ 88 | - range_weight 89 | - double dist_scale 90 | - double[:] values 91 | - double[:] centres 92 | - double[:] total_values 93 | - 94 | - if sigma_color is None: 95 | - csigma_color = image.std() 96 | - else: 97 | - csigma_color = sigma_color 98 | - 99 | - if mode not in ('constant', 'wrap', 'symmetric', 'reflect', 'edge'): 100 | - raise ValueError("Invalid mode specified. Please use `constant`, " 101 | - "`edge`, `wrap`, `symmetric` or `reflect`.") 102 | - cdef char cmode = ord(mode[0].upper()) 103 | - 104 | - cimage = np.ascontiguousarray(image) 105 | - 106 | - out = np.zeros((rows, cols, dims), dtype=np.double) 107 | - color_lut = _compute_color_lut(bins, csigma_color, max_value) 108 | - range_lut = _compute_range_lut(win_size, sigma_spatial) 109 | - dist_scale = bins / dims / max_value 110 | - values = np.empty(dims, dtype=np.double) 111 | - centres = np.empty(dims, dtype=np.double) 112 | - total_values = np.empty(dims, dtype=np.double) 113 | - 114 | - for r in range(rows): 115 | - for c in range(cols): 116 | - total_weight = 0 117 | - for d in range(dims): 118 | - total_values[d] = 0 119 | - centres[d] = cimage[r, c, d] 120 | - for wr in range(-window_ext, window_ext + 1): 121 | - rr = wr + r 122 | - kr = wr + window_ext 123 | - for wc in range(-window_ext, window_ext + 1): 124 | - cc = wc + c 125 | - kc = wc + window_ext 126 | - 127 | - # save pixel values for all dims and compute euclidian 128 | - # distance between centre stack and current position 129 | - dist = 0 130 | - for d in range(dims): 131 | - value = get_pixel3d(&cimage[0, 0, 0], rows, cols, dims, 132 | - rr, cc, d, cmode, cval) 133 | - values[d] = value 134 | - dist += (centres[d] - value)**2 135 | - dist = sqrt(dist) 136 | 137 | - range_weight = range_lut[kr * win_size + kc] 138 | - 139 | - color_lut_bin = Py_ssize_t_min( 140 | - (dist * dist_scale), max_color_lut_bin) 141 | - color_weight = color_lut[color_lut_bin] 142 | 143 | - weight = range_weight * color_weight 144 | - for d in range(dims): 145 | - total_values[d] += values[d] * weight 146 | - total_weight += weight 147 | - for d in range(dims): 148 | - out[r, c, d] = total_values[d] / total_weight 149 | - 150 | - return np.squeeze(np.asarray(out)) 151 | - 152 | - 153 | -def _denoise_tv_bregman(image, double weight, int max_iter, double eps, 154 | - char isotropic): 155 | - image = np.atleast_3d(img_as_float(image)) 156 | +def _denoise_tv_bregman(image, mask, double weight, int max_iter, int gs_iter, 157 | + double eps, char isotropic): 158 | + image = np.atleast_3d(image) 159 | 160 | cdef: 161 | Py_ssize_t rows = image.shape[0] 162 | Py_ssize_t cols = image.shape[1] 163 | Py_ssize_t dims = image.shape[2] 164 | - Py_ssize_t rows2 = rows + 2 165 | - Py_ssize_t cols2 = cols + 2 166 | Py_ssize_t 
r, c, k 167 | 168 | Py_ssize_t total = rows * cols * dims 169 | 170 | - shape_ext = (rows2, cols2, dims) 171 | - u = np.zeros(shape_ext, dtype=np.double) 172 | + u = np.zeros(image.shape, dtype=np.double) 173 | + u[:, :, :] = image 174 | 175 | cdef: 176 | double[:, :, ::1] cimage = np.ascontiguousarray(image) 177 | + char[:, :, ::1] cmask = mask 178 | double[:, :, ::1] cu = u 179 | 180 | - double[:, :, ::1] dx = np.zeros(shape_ext, dtype=np.double) 181 | - double[:, :, ::1] dy = np.zeros(shape_ext, dtype=np.double) 182 | - double[:, :, ::1] bx = np.zeros(shape_ext, dtype=np.double) 183 | - double[:, :, ::1] by = np.zeros(shape_ext, dtype=np.double) 184 | + double[:, :, ::1] dx = np.zeros(image.shape, dtype=np.double) 185 | + double[:, :, ::1] dy = np.zeros(image.shape, dtype=np.double) 186 | + double[:, :, ::1] bx = np.zeros(image.shape, dtype=np.double) 187 | + double[:, :, ::1] by = np.zeros(image.shape, dtype=np.double) 188 | + double[:, :, ::1] z = np.zeros(image.shape, dtype=np.double) 189 | + double[:, :, ::1] uprev = np.ascontiguousarray(image) 190 | 191 | - double ux, uy, uprev, unew, bxx, byy, dxx, dyy, s 192 | + double ux, uy, unew, bxx, byy, dxx, dyy, s 193 | int i = 0 194 | double lam = 2 * weight 195 | double rmse = DBL_MAX 196 | - double norm = (weight + 4 * lam) 197 | - 198 | - u[1:-1, 1:-1] = image 199 | - 200 | - # reflect image 201 | - u[0, 1:-1] = image[1, :] 202 | - u[1:-1, 0] = image[:, 1] 203 | - u[-1, 1:-1] = image[-2, :] 204 | - u[1:-1, -1] = image[:, -2] 205 | + double neighbors = 0 206 | + double inner = 0 207 | 208 | while i < max_iter and rmse > eps: 209 | 210 | - rmse = 0 211 | - 212 | + for _ in range(gs_iter): 213 | - for k in range(dims): 214 | + for k in range(dims): 215 | - for r in range(1, rows + 1): 216 | - for c in range(1, cols + 1): 217 | - 218 | - uprev = cu[r, c, k] 219 | - 220 | - # forward derivatives 221 | - ux = cu[r, c + 1, k] - uprev 222 | - uy = cu[r + 1, c, k] - uprev 223 | - 224 | + for r in range(rows): 225 | + for c in range(cols): 226 | # Gauss-Seidel method 227 | - unew = ( 228 | - lam * ( 229 | - + cu[r + 1, c, k] 230 | - + cu[r - 1, c, k] 231 | - + cu[r, c + 1, k] 232 | - + cu[r, c - 1, k] 233 | - 234 | - + dx[r, c - 1, k] 235 | - - dx[r, c, k] 236 | - + dy[r - 1, c, k] 237 | - - dy[r, c, k] 238 | - 239 | - - bx[r, c - 1, k] 240 | - + bx[r, c, k] 241 | - - by[r - 1, c, k] 242 | - + by[r, c, k] 243 | - ) + weight * cimage[r - 1, c - 1, k] 244 | - ) / norm 245 | + inner = z[r, c, k] 246 | + neighbors = 0 247 | + if r > 0: 248 | + inner += cu[r - 1, c, k] 249 | + neighbors += 1 250 | + if r < rows - 1: 251 | + inner += cu[r + 1, c, k] 252 | + neighbors += 1 253 | + if c > 0: 254 | + inner += cu[r, c - 1, k] 255 | + neighbors += 1 256 | + if c < cols - 1: 257 | + inner += cu[r, c + 1, k] 258 | + neighbors += 1 259 | + if cmask[r, c, k] == 1: 260 | + unew = (lam * inner + weight * cimage[r, c, k]) / (weight + neighbors * lam) 261 | + else: 262 | + unew = inner / 4 263 | - cu[r, c, k] = unew 264 | + cu[r, c, k] = unew 265 | 266 | - # update root mean square error 267 | - rmse += (unew - uprev)**2 268 | + rmse = 0 269 | + for k in range(dims): 270 | + for r in range(rows): 271 | + for c in range(cols): 272 | + # forward derivatives 273 | + if c == cols - 1: 274 | + ux = 0 275 | + else: 276 | + ux = cu[r, c + 1, k] - cu[r, c, k] 277 | + if r == rows - 1: 278 | + uy = 0 279 | + else: 280 | + uy = cu[r + 1, c, k] - cu[r, c, k] 281 | 282 | bxx = bx[r, c, k] 283 | byy = by[r, c, k] 284 | @@ -262,7 +114,17 @@ 285 | bx[r, c, k] += ux - dxx 286 | 
by[r, c, k] += uy - dyy 287 | 288 | + z[r, c, k] = -dx[r, c, k] - dy[r, c, k] + bx[r, c, k] + by[r, c, k] 289 | + if r > 0: 290 | + z[r, c, k] += dy[r - 1, c, k] - by[r - 1, c, k] 291 | + if c > 0: 292 | + z[r, c, k] += dx[r, c - 1, k] - bx[r, c - 1, k] 293 | + 294 | + # update rmse 295 | + rmse += (cu[r, c, k] - uprev[r, c, k])**2 296 | + 297 | rmse = sqrt(rmse / total) 298 | + uprev = np.copy(cu) 299 | i += 1 300 | 301 | - return np.squeeze(np.asarray(u[1:-1, 1:-1])) 302 | + return np.squeeze(np.asarray(u)) 303 | -------------------------------------------------------------------------------- /adversarial/lib/transformations/findseam.cpp: -------------------------------------------------------------------------------- 1 | #include "findseam.h" 2 | #include "graph.h" 3 | 4 | double findseam( 5 | int numnodes, // number of nodes 6 | int numedges, // number of edges 7 | int* from, // from indices 8 | int* to, // to indices 9 | float* values, // values on edges 10 | float* tvalues, // values for terminal edges 11 | int* labels // memory in which to write the labels 12 | ) { 13 | // initialize graph: 14 | Graph* g = 15 | new Graph(numnodes, numedges); 16 | g->add_node(numnodes); 17 | 18 | // add edges: 19 | for (unsigned int i = 0; i < numedges; i++) { 20 | g->add_edge(from[i], to[i], values[i], 0.0f); 21 | } 22 | 23 | // add terminal nodes: 24 | for (unsigned int i = 0; i < numnodes; i++) { 25 | g->add_tweights(i, tvalues[i * 2], tvalues[i * 2 + 1]); 26 | } 27 | 28 | // run maxflow algorithm: 29 | double flow = g->maxflow(); 30 | for (unsigned int i = 0; i < numnodes; i++) { 31 | labels[i] = g->what_segment(i); 32 | } 33 | 34 | // return results: 35 | delete g; 36 | return flow; 37 | } 38 | -------------------------------------------------------------------------------- /adversarial/lib/transformations/findseam.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | double findseam( 8 | int numnodes, // number of nodes 9 | int numedges, // number of edges 10 | int* from, // from indices 11 | int* to, // to indices 12 | float* values, // values on edges 13 | float* tvalues, // values for terminal edges 14 | int* labels // memory in which to write the labels 15 | ); 16 | 17 | #ifdef __cplusplus 18 | } 19 | #endif 20 | -------------------------------------------------------------------------------- /adversarial/lib/transformations/quilting.cpp: -------------------------------------------------------------------------------- 1 | #include "quilting.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "findseam.h" 7 | 8 | void generatePatches( 9 | float* result, 10 | float* img, 11 | unsigned int imgH, 12 | unsigned int imgW, 13 | unsigned int patchSize, 14 | unsigned int overlap) { 15 | int n = 0; 16 | 17 | for (int y = 0; y < imgH - patchSize; y += patchSize - overlap) { 18 | for (int x = 0; x < imgW - patchSize; x += patchSize - overlap) { 19 | for (int c = 0; c < 3; c++) { 20 | for (int j = 0; j < patchSize; j++) { 21 | for (int i = 0; i < patchSize; i++) { 22 | result 23 | [n * 3 * patchSize * patchSize + c * patchSize * patchSize + 24 | j * patchSize + i] = 25 | img[c * imgH * imgW + (y + j) * imgW + (x + i)]; 26 | } 27 | } 28 | } 29 | n++; 30 | } 31 | } 32 | } 33 | 34 | // "from" nodes and "to" nodes 35 | using graphLattice = std::pair, std::vector>; 36 | 37 | std::map, graphLattice> cache; 38 | 39 | graphLattice _getFourLattice(unsigned int h, unsigned int w, bool 
useCache) { 40 | if (useCache) { 41 | auto iter = cache.find(std::make_pair(h, w)); 42 | if (iter != cache.end()) { 43 | return iter->second; 44 | } 45 | } 46 | 47 | std::vector from, to; 48 | 49 | // right 50 | for (int j = 0; j < h; j++) { 51 | for (int i = 0; i < w - 1; i++) { 52 | from.push_back(j * w + i); 53 | to.push_back(j * w + (i + 1)); 54 | } 55 | } 56 | 57 | // left 58 | for (int j = 0; j < h; j++) { 59 | for (int i = 1; i < w; i++) { 60 | from.push_back(j * w + i); 61 | to.push_back(j * w + (i - 1)); 62 | } 63 | } 64 | 65 | // down 66 | for (int j = 0; j < h - 1; j++) { 67 | for (int i = 0; i < w; i++) { 68 | from.push_back(j * w + i); 69 | to.push_back((j + 1) * w + i); 70 | } 71 | } 72 | 73 | // up 74 | for (int j = 1; j < h; j++) { 75 | for (int i = 0; i < w; i++) { 76 | from.push_back(j * w + i); 77 | to.push_back((j - 1) * w + i); 78 | } 79 | } 80 | 81 | graphLattice result = std::make_pair(from, to); 82 | 83 | if (useCache) { 84 | cache[std::make_pair(h, w)] = result; 85 | } 86 | 87 | return result; 88 | } 89 | 90 | void _findSeam( 91 | int* result, 92 | float* im1, 93 | float* im2, 94 | unsigned int patchSize, 95 | unsigned int* mask) { 96 | graphLattice graph = _getFourLattice(patchSize, patchSize, true); 97 | std::vector from = graph.first; 98 | std::vector to = graph.second; 99 | int edgeNum = 4 * patchSize * patchSize - 2 * (patchSize + patchSize); 100 | 101 | float* values = new float[edgeNum]; 102 | for (int i = 0; i < edgeNum; i++) { 103 | values[i] = 0; 104 | } 105 | 106 | for (int c = 0; c < 3; c++) { 107 | for (int i = 0; i < edgeNum; i++) { 108 | values[i] += fabs( 109 | im2[c * patchSize * patchSize + to[i]] - 110 | im1[c * patchSize * patchSize + from[i]]); 111 | } 112 | } 113 | 114 | int nodeNum = patchSize * patchSize; 115 | float* tvalues = new float[nodeNum * 2]; 116 | for (int i = 0; i < nodeNum * 2; i++) { 117 | tvalues[i] = 0; 118 | } 119 | 120 | for (int j = 0; j < patchSize; j++) { 121 | for (int i = 0; i < patchSize; i++) { 122 | for (int c = 0; c < 2; c++) { 123 | if (mask[j * patchSize + i] == c + 1) { 124 | tvalues[(j * patchSize + i) * 2 + c] = 125 | std::numeric_limits::infinity(); 126 | } 127 | } 128 | } 129 | } 130 | 131 | findseam(nodeNum, edgeNum, from.data(), to.data(), values, tvalues, result); 132 | delete[] values; 133 | delete[] tvalues; 134 | } 135 | 136 | void stitch( 137 | float* result, 138 | float* im1, 139 | float* im2, 140 | unsigned int patchSize, 141 | unsigned int overlap, 142 | unsigned int y, 143 | unsigned int x) { 144 | unsigned int* mask = new unsigned int[patchSize * patchSize]; 145 | 146 | for (int j = 0; j < patchSize; j++) { 147 | for (int i = 0; i < patchSize; i++) { 148 | mask[j * patchSize + i] = 2; 149 | } 150 | } 151 | 152 | if (y > 0) { 153 | for (int j = 0; j < overlap; j++) { 154 | for (int i = 0; i < patchSize; i++) { 155 | mask[j * patchSize + i] = 0; 156 | } 157 | } 158 | } 159 | 160 | if (x > 0) { 161 | for (int j = 0; j < patchSize; j++) { 162 | for (int i = 0; i < overlap; i++) { 163 | mask[j * patchSize + i] = 0; 164 | } 165 | } 166 | } 167 | 168 | int* seamMask = new int[patchSize * patchSize]; 169 | _findSeam(seamMask, im1, im2, patchSize, mask); 170 | 171 | int offset; 172 | for (int c = 0; c < 3; c++) { 173 | for (int j = 0; j < patchSize; j++) { 174 | for (int i = 0; i < patchSize; i++) { 175 | offset = c * patchSize * patchSize + j * patchSize + i; 176 | result[offset] = 177 | (seamMask[j * patchSize + i] == 1) ? 
im2[offset] : im1[offset]; 178 | } 179 | } 180 | } 181 | delete [] mask; 182 | delete [] seamMask; 183 | } 184 | 185 | void generateQuiltedImages( 186 | float* result, 187 | long* neighbors, 188 | float* patchDict, 189 | unsigned int imgH, 190 | unsigned int imgW, 191 | unsigned int patchSize, 192 | unsigned int overlap, 193 | bool graphcut) { 194 | int n = 0; 195 | for (int y = 0; y < imgH - patchSize; y += patchSize - overlap) { 196 | for (int x = 0; x < imgW - patchSize; x += patchSize - overlap) { 197 | if (neighbors[n] != -1) { 198 | if (graphcut) { 199 | float* patch = new float[3 * patchSize * patchSize]; 200 | for (int c = 0; c < 3; c++) { 201 | for (int j = 0; j < patchSize; j++) { 202 | for (int i = 0; i < patchSize; i++) { 203 | patch[c * patchSize * patchSize + j * patchSize + i] = 204 | result[c * imgH * imgW + (y + j) * imgW + (x + i)]; 205 | } 206 | } 207 | } 208 | 209 | float* stitched = new float[3 * patchSize * patchSize]; 210 | float* matched = 211 | patchDict + (neighbors[n] * 3 * patchSize * patchSize); 212 | stitch(stitched, patch, matched, patchSize, overlap, y, x); 213 | for (int c = 0; c < 3; c++) { 214 | for (int j = 0; j < patchSize; j++) { 215 | for (int i = 0; i < patchSize; i++) { 216 | result[c * imgH * imgW + (y + j) * imgW + (x + i)] = 217 | stitched[c * patchSize * patchSize + j * patchSize + i]; 218 | } 219 | } 220 | } 221 | delete[] patch; 222 | delete[] stitched; 223 | } else { 224 | for (int c = 0; c < 3; c++) { 225 | for (int j = 0; j < patchSize; j++) { 226 | for (int i = 0; i < patchSize; i++) { 227 | result[c * imgH * imgW + (y + j) * imgW + (x + i)] = patchDict 228 | [neighbors[n] * 3 * patchSize * patchSize + 229 | c * patchSize * patchSize + j * patchSize + i]; 230 | } 231 | } 232 | } 233 | } 234 | } 235 | 236 | n++; 237 | } 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /adversarial/lib/transformations/quilting.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | void generatePatches( 8 | float* result, // N x (C x P x P) 9 | float* img, // C x H x W 10 | unsigned int imgH, 11 | unsigned int imgW, 12 | unsigned int patchSize, 13 | unsigned int overlap); 14 | 15 | void generateQuiltedImages( 16 | float* result, // C x H x W 17 | long* neighbors, // M 18 | float* patchDict, // N x (C x P x P) 19 | unsigned int imgH, 20 | unsigned int imgW, 21 | unsigned int patchSize, 22 | unsigned int overlap, 23 | bool graphcut); 24 | 25 | #ifdef __cplusplus 26 | } 27 | #endif 28 | -------------------------------------------------------------------------------- /adversarial/lib/transformations/quilting.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import ctypes 9 | import torch 10 | 11 | # load seam-finding library: 12 | FINDSEAM_LIB = ctypes.cdll.LoadLibrary( 13 | 'libexperimental_deeplearning_lvdmaaten_adversarial_findseam.so') 14 | 15 | # other globals: 16 | LATTICE_CACHE = {} # cache lattices here 17 | 18 | 19 | # function that constructs a four-connected lattice: 20 | def __four_lattice__(height, width, use_cache=True): 21 | 22 | # try the cache first: 23 | if use_cache and (height, width) in LATTICE_CACHE: 24 | return LATTICE_CACHE[(height, width)] 25 | 
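    # note: a four-connected height x width lattice has height*(width-1)
    # horizontal and (height-1)*width vertical neighbor pairs; storing both
    # directions of each edge gives 4*N - 2*(height + width) entries, which
    # is exactly the size of the 'from'/'to' tensors allocated below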
26 | # assertions and initialization: 27 | assert type(width) == int and type(height) == int and \ 28 | width > 0 and height > 0, 'height and width should be positive integers' 29 | N = height * width 30 | height, width = width, height # tensors are in row-major format 31 | graph = { 32 | 'from': torch.LongTensor(4 * N - (height + width) * 2), 33 | 'to': torch.LongTensor(4 * N - (height + width) * 2), 34 | } 35 | 36 | # closure that copies stuff in: 37 | def add_edges(i, j, offset): 38 | graph['from'].narrow(0, offset, i.nelement()).copy_(i) 39 | graph['from'].narrow(0, offset + i.nelement(), j.nelement()).copy_(j) 40 | graph['to'].narrow(0, offset, j.nelement()).copy_(j) 41 | graph['to'].narrow(0, offset + j.nelement(), i.nelement()).copy_(i) 42 | 43 | # add vertical connections: 44 | i = torch.arange(0, N).squeeze().long() 45 | mask = torch.ByteTensor(N).fill_(1) 46 | mask.index_fill_(0, torch.arange(height - 1, N, height).squeeze().long(), 0) 47 | i = i[mask] 48 | add_edges(i, torch.add(i, 1), 0) 49 | 50 | # add horizontal connections: 51 | offset = 2 * i.nelement() 52 | i = torch.arange(0, N - height).squeeze().long() 53 | add_edges(i, torch.add(i, height), offset) 54 | 55 | # cache and return graph: 56 | if use_cache: 57 | LATTICE_CACHE[(height, width)] = graph 58 | return graph 59 | 60 | 61 | # utility function for checking inputs: 62 | def __assert_inputs__(im1, im2, mask=None): 63 | assert type(im1) == torch.ByteTensor or type(im1) == torch.FloatTensor, \ 64 | 'im1 should be a ByteTensor or FloatTensor' 65 | assert type(im2) == torch.ByteTensor or type(im2) == torch.FloatTensor, \ 66 | 'im2 should be a ByteTensor or FloatTensor' 67 | assert im1.dim() == 3, 'im1 should be three-dimensional' 68 | assert im2.dim() == 3, 'im2 should be three-dimensional' 69 | assert im1.size() == im2.size(), 'im1 and im2 should have same size' 70 | if mask is not None: 71 | assert mask.dim() == 2, 'mask should be two-dimensional' 72 | assert type(mask) == torch.ByteTensor, 'mask should be torch.ByteTensor' 73 | assert mask.size(0) == im1.size(1) and mask.size(1) == im1.size(2), \ 74 | 'mask should have same height and width as images' 75 | 76 | 77 | # function that finds seam between two images: 78 | def find_seam(im1, im2, mask): 79 | 80 | # assertions: 81 | __assert_inputs__(im1, im2, mask) 82 | im1 = im1.float() 83 | im2 = im2.float() 84 | 85 | # construct edge weights: 86 | graph = __four_lattice__(im1.size(1), im1.size(2)) 87 | values = torch.FloatTensor(graph['from'].size(0)).fill_(0.) 
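    # edge weights: accumulate per-channel absolute differences between the
    # two images across every lattice edge, so the min-cut below prefers
    # seams through pixels where the overlapping patches already agree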
88 | for c in range(im1.size(0)): 89 | im1c = im1[c].contiguous().view(im1.size(1) * im1.size(2)) 90 | im2c = im2[c].contiguous().view(im2.size(1) * im2.size(2)) 91 | values.add_(torch.abs( 92 | im2c.index_select(0, graph['to']) - 93 | im1c.index_select(0, graph['from']) 94 | )) 95 | 96 | # construct terminal weights: 97 | idxim = torch.arange(0, mask.nelement()).long().view(mask.size()) 98 | tvalues = torch.FloatTensor(mask.nelement(), 2).fill_(0) 99 | for c in range(2): 100 | select_c = (mask == (c + 1)) 101 | if select_c.any(): 102 | tvalues.select(1, c).index_fill_(0, idxim[select_c], float('inf')) 103 | 104 | # convert graph to IntTensor (make sure this is not GC'ed): 105 | graph_from = graph['from'].int() 106 | graph_to = graph['to'].int() 107 | 108 | # run the Boykov algorithm to obtain stitching mask: 109 | labels = torch.IntTensor(mask.nelement()) 110 | FINDSEAM_LIB.findseam( 111 | ctypes.c_int(mask.nelement()), 112 | ctypes.c_int(values.nelement()), 113 | ctypes.c_void_p(graph_from.data_ptr()), 114 | ctypes.c_void_p(graph_to.data_ptr()), 115 | ctypes.c_void_p(values.data_ptr()), 116 | ctypes.c_void_p(tvalues.data_ptr()), 117 | ctypes.c_void_p(labels.data_ptr()), 118 | ) 119 | mask = labels.resize_(mask.size()).byte() 120 | return mask 121 | 122 | 123 | # function that performs the stitch: 124 | def __stitch__(im1, im2, overlap, y, x): 125 | 126 | # assertions: 127 | __assert_inputs__(im1, im2) 128 | 129 | # construct mask: 130 | patch_size = im1.size(1) 131 | mask = torch.ByteTensor(patch_size, patch_size).fill_(2) 132 | if y > 0: # there is not overlap at the border 133 | mask.narrow(0, 0, overlap).fill_(0) 134 | if x > 0: # there is not overlap at the border 135 | mask.narrow(1, 0, overlap).fill_(0) 136 | 137 | # seam the two patches: 138 | seam_mask = find_seam(im1, im2, mask) 139 | stitched_im = im1.clone() 140 | for c in range(stitched_im.size(0)): 141 | stitched_im[c][seam_mask == 1] = im2[c][seam_mask] 142 | return stitched_im 143 | 144 | 145 | # main quilting function: 146 | def quilting(img, faiss_index, patch_dict, patch_size=5, overlap=2, 147 | graphcut=False, patch_transform=None): 148 | 149 | # assertions: 150 | assert torch.is_tensor(img) 151 | assert torch.is_tensor(patch_dict) and patch_dict.dim() == 2 152 | assert type(patch_size) == int and patch_size > 0 153 | assert type(overlap) == int and overlap > 0 154 | assert patch_size > overlap 155 | if patch_transform is not None: 156 | assert callable(patch_transform) 157 | 158 | # gather all image patches: 159 | patches = [] 160 | y_range = range(0, img.size(1) - patch_size, patch_size - overlap) 161 | x_range = range(0, img.size(2) - patch_size, patch_size - overlap) 162 | for y in y_range: 163 | for x in range(0, img.size(2) - patch_size, patch_size - overlap): 164 | patch = img[:, y:y + patch_size, x:x + patch_size] 165 | if patch_transform is not None: 166 | patch = patch_transform(patch) 167 | patches.append(patch) 168 | 169 | # find nearest patches in faiss index: 170 | patches = torch.stack(patches, dim=0) 171 | patches = patches.view(patches.size(0), int(patches.nelement() / patches.size(0))) 172 | faiss_index.nprobe = 5 173 | _, neighbors = faiss_index.search(patches.numpy(), 1) 174 | neighbors = torch.LongTensor(neighbors).squeeze() 175 | if (neighbors == -1).any(): 176 | print('WARNING: %d out of %d neighbor searches failed.' 
% 177 | ((neighbors == -1).sum(), neighbors.nelement())) 178 | 179 | # piece the image back together: 180 | n = 0 181 | quilt_img = img.clone().fill_(0) 182 | for y in y_range: 183 | for x in x_range: 184 | if neighbors[n] != -1: 185 | 186 | # get current image and new patch: 187 | patch = patch_dict[neighbors[n]].view( 188 | img.size(0), patch_size, patch_size 189 | ) 190 | cur_img = quilt_img[:, y:y + patch_size, x:x + patch_size] 191 | 192 | # compute graph cut if requested: 193 | if graphcut: 194 | patch = __stitch__(cur_img, patch, overlap, y, x) 195 | 196 | # copy the patch into the image: 197 | cur_img.copy_(patch) 198 | n += 1 199 | 200 | # return the quilted image: 201 | return quilt_img 202 | -------------------------------------------------------------------------------- /adversarial/lib/transformations/quilting_fast.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import ctypes 9 | import torch 10 | import random 11 | import numpy 12 | import os 13 | 14 | import pkgutil 15 | if pkgutil.find_loader("adversarial") is not None: 16 | # If adversarial module is created by pip install 17 | QUILTING_LIB = ctypes.cdll.LoadLibrary(os.path.join(os.path.dirname(__file__), "libquilting.so")) 18 | else: 19 | try: 20 | QUILTING_LIB = ctypes.cdll.LoadLibrary('libquilting.so') 21 | except ImportError: 22 | raise ImportError("libquilting.so not found. Check build script") 23 | 24 | 25 | def generate_patches(img, patch_size, overlap): 26 | assert torch.is_tensor(img) and img.dim() == 3 27 | assert type(patch_size) == int and patch_size > 0 28 | assert type(overlap) == int and overlap > 0 29 | assert patch_size > overlap 30 | 31 | y_range = range(0, img.size(1) - patch_size, patch_size - overlap) 32 | x_range = range(0, img.size(2) - patch_size, patch_size - overlap) 33 | num_patches = len(y_range) * len(x_range) 34 | patches = torch.FloatTensor(num_patches, 3 * patch_size * patch_size).zero_() 35 | 36 | QUILTING_LIB.generatePatches( 37 | ctypes.c_void_p(patches.data_ptr()), 38 | ctypes.c_void_p(img.data_ptr()), 39 | ctypes.c_uint(img.size(1)), 40 | ctypes.c_uint(img.size(2)), 41 | ctypes.c_uint(patch_size), 42 | ctypes.c_uint(overlap) 43 | ) 44 | 45 | return patches 46 | 47 | 48 | def generate_quilted_images(neighbors, patch_dict, img_h, img_w, patch_size, 49 | overlap, graphcut=False, random_stitch=False): 50 | assert torch.is_tensor(neighbors) and neighbors.dim() == 1 51 | assert torch.is_tensor(patch_dict) and patch_dict.dim() == 2 52 | assert type(img_h) == int and img_h > 0 53 | assert type(img_w) == int and img_w > 0 54 | assert type(patch_size) == int and patch_size > 0 55 | assert type(overlap) == int and overlap > 0 56 | assert patch_size > overlap 57 | 58 | result = torch.FloatTensor(3, img_h, img_w).zero_() 59 | 60 | QUILTING_LIB.generateQuiltedImages( 61 | ctypes.c_void_p(result.data_ptr()), 62 | ctypes.c_void_p(neighbors.data_ptr()), 63 | ctypes.c_void_p(patch_dict.data_ptr()), 64 | ctypes.c_uint(img_h), 65 | ctypes.c_uint(img_w), 66 | ctypes.c_uint(patch_size), 67 | ctypes.c_uint(overlap), 68 | ctypes.c_bool(graphcut) 69 | ) 70 | 71 | return result 72 | 73 | 74 | def select_random_neighbor(neighbors): 75 | if len(neighbors.shape) == 1: 76 | # If only 1 neighbor per path is available then return 77 | return neighbors 78 | else: 79 | # Pick a 
neighbor randomly from top k neighbors for all queries 80 | nrows = neighbors.shape[0] 81 | ncols = neighbors.shape[1] 82 | random_patched_neighbors = numpy.zeros(nrows).astype('int') 83 | for i in range(0, nrows): 84 | col = random.randint(0, ncols - 1) 85 | random_patched_neighbors[i] = neighbors[i, col] 86 | return random_patched_neighbors 87 | 88 | 89 | # main quilting function: 90 | def quilting(img, faiss_index, patch_dict, patch_size=9, overlap=2, 91 | graphcut=False, k=1, random_stitch=False): 92 | 93 | # assertions: 94 | assert torch.is_tensor(img) 95 | assert torch.is_tensor(patch_dict) and patch_dict.dim() == 2 96 | assert type(patch_size) == int and patch_size > 0 97 | assert type(overlap) == int and overlap > 0 98 | assert patch_size > overlap 99 | 100 | # generate image patches 101 | patches = generate_patches(img, patch_size, overlap) 102 | 103 | # find nearest patches in faiss index: 104 | faiss_index.nprobe = 5 105 | # get top k neighbors of all queries 106 | _, neighbors = faiss_index.search(patches.numpy(), k) 107 | neighbors = select_random_neighbor(neighbors) 108 | neighbors = torch.LongTensor(neighbors).squeeze() 109 | if (neighbors == -1).any(): 110 | print('WARNING: %d out of %d neighbor searches failed.' % 111 | ((neighbors == -1).sum(), neighbors.nelement())) 112 | 113 | # stitch nn patches in the dict 114 | quilted_img = generate_quilted_images(neighbors, patch_dict, img.size(1), 115 | img.size(2), patch_size, overlap, 116 | graphcut) 117 | 118 | return quilted_img 119 | -------------------------------------------------------------------------------- /adversarial/lib/transformations/transformation_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torchvision.transforms as torch_trans 9 | import lib.transformations.transforms as transforms 10 | from lib.datasets.transform_dataset import TransformDataset 11 | 12 | 13 | # Initialize transformations to be applied to dataset 14 | def setup_transformations(args, data_type, defense, crop=None): 15 | if 'preprocessed_data' in args and args.preprocessed_data: 16 | assert defense is not None, ( 17 | "If data is already pre processed for defenses then " 18 | "defenses can't be None") 19 | if crop: 20 | assert callable(crop), "crop should be a callable method" 21 | 22 | transform = [] 23 | # setup transformation without adversary 24 | if 'adversary' not in args or args.adversary is None: 25 | if (data_type == 'train'): 26 | if 'preprocessed_data' in args and args.preprocessed_data: 27 | # Defenses are already applied on randomly cropped images 28 | transform.append(torch_trans.Scale(args.data_params['IMAGE_SIZE'])) 29 | else: 30 | transform.append( 31 | torch_trans.RandomSizedCrop(args.data_params['IMAGE_SIZE'])) 32 | 33 | transform.append(torch_trans.RandomHorizontalFlip()) 34 | transform.append(torch_trans.ToTensor()) 35 | 36 | else: # validation 37 | # No augmentation for validation 38 | if 'preprocessed_data' not in args or not args.preprocessed_data: 39 | transform.append(torch_trans.Scale(args.data_params['IMAGE_SCALE_SIZE'])) 40 | transform.append(torch_trans.CenterCrop( 41 | args.data_params['IMAGE_SIZE'])) 42 | transform.append(torch_trans.ToTensor()) 43 | if crop: 44 | transform.append(crop) 45 | 46 | transform.append(transforms.Scale(args.data_params['IMAGE_SIZE'])) 
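        # the optional crop above can shrink the tensor, so rescale back to
        # the model input resolution before any runtime defense sees the image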
47 | 
48 |         # Apply defenses at runtime (VERY SLOW)
49 |         # Prefer pre-processing and saving data, and then using it
50 |         if ('preprocessed_data' in args and not args.preprocessed_data and
51 |                 defense is not None):
52 |             transform = transform + [defense]
53 | 
54 |     else:  # Adversarial images
55 |         if crop is not None:
56 |             transform.append(crop)
57 | 
58 |         transform.append(transforms.Scale(args.data_params['IMAGE_SIZE'],
59 |                                           args.data_params['MEAN_STD']))
60 | 
61 |         # Apply defenses at runtime (VERY SLOW)
62 |         # Prefer pre-processing and saving data, and then using it
63 |         if not args.preprocessed_data and defense is not None:
64 |             transform.append(defense)
65 | 
66 |     if 'normalize' in args and args.normalize:
67 |         transform.append(
68 |             torch_trans.Normalize(mean=args.data_params['MEAN_STD']['MEAN'],
69 |                                   std=args.data_params['MEAN_STD']['STD']))
70 | 
71 |     if len(transform) == 0:
72 |         transform = None
73 |     else:
74 |         transform = torch_trans.Compose(transform)
75 | 
76 |     return transform
77 | 
78 | 
79 | # Update dataset
80 | def update_dataset_transformation(dataset, args, data_type,
81 |                                   defense, crop):
82 | 
83 |     # only supported for TransformDataset at the moment
84 |     assert isinstance(dataset, TransformDataset), (
85 |         "updating dataset transformation is only supported for TransformDataset "
86 |         "for adversaries")
87 | 
88 |     assert data_type != 'train', \
89 |         "updating dataset transformation is not supported in training"
90 | 
91 |     transform = setup_transformations(args, data_type, defense, crop)
92 |     dataset.update_transformation(transform=transform)
93 | 
-------------------------------------------------------------------------------- /adversarial/lib/transformations/transforms.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 | 
6 | # from PIL import Image
7 | import torchvision.transforms as torch_trans
8 | import random
9 | from torch import is_tensor
10 | 
11 | CROP_TYPE = ['center', 'random', 'sliding']
12 | 
13 | 
14 | class Crop(object):
15 |     """Crops the given img tensor.
16 |     Args:
17 |         crop_type (str): one of 'center', 'random', or 'sliding';
18 |             None applies no crop.
19 |         crop_frac (float): fraction of the image height/width to keep.
20 |         sliding_crop_position (int): position index (0-8) for sliding crops.
21 |     """
22 | 
23 |     def __init__(self, crop_type=None, crop_frac=1.0,
24 |                  sliding_crop_position=None):
25 |         assert crop_frac <= 1.0, \
26 |             "crop_frac can't be greater than 1.0"
27 |         if sliding_crop_position is not None:
28 |             # max positions are fixed to 9
29 |             assert sliding_crop_position < 9
30 | 
31 |         assert (crop_type is None or crop_type in CROP_TYPE), (
32 |             "{} is not a valid crop_type".format(crop_type))
33 | 
34 |         self.crop_type = crop_type
35 |         self.crop_frac = crop_frac
36 |         self.sliding_crop_position = sliding_crop_position
37 | 
38 |     def __call__(self, img):
39 |         """
40 |         Args:
41 |             img (Tensor): Image to be cropped.
42 |         Returns: the cropped image tensor.
43 |         """
44 |         assert img is not None, "img should not be None"
45 |         assert is_tensor(img), "Tensor expected"
46 |         h = img.size(1)
47 |         w = img.size(2)
48 |         h2 = int(h * self.crop_frac)
49 |         w2 = int(w * self.crop_frac)
50 |         h_range = h - h2
51 |         w_range = w - w2
52 |
53 |         if self.crop_type == 'sliding':
54 |             assert self.sliding_crop_position is not None
55 |             row = self.sliding_crop_position // 3
56 |             col = self.sliding_crop_position % 3
57 |             x = col * (w_range // 2)
58 |             y = row * (h_range // 2)
59 |
60 |         elif self.crop_type == 'random':
61 |             x, y = random.randint(0, w_range), random.randint(0, h_range)
62 |
63 |         elif self.crop_type == 'center':
64 |             y = h_range // 2
65 |             x = w_range // 2
66 |
67 |         if self.crop_type is not None:
68 |             img = img.narrow(1, y, h2).narrow(2, x, w2).clone()
69 |
70 |         return img
71 |
72 |     def update_sliding_position(self, sliding_crop_position):
73 |         assert 0 <= sliding_crop_position < 9, \
74 |             "Only 9 sliding positions supported"
75 |         self.sliding_crop_position = sliding_crop_position
76 |
77 |
78 | class Scale(object):
79 |     """Scales the given img tensor to the given square size.
80 |     Args:
81 |         size (int): desired output size of the scaled image.
82 |         mean_std (dict, optional): 'MEAN' and 'STD' values used to
83 |             unnormalize before and renormalize after scaling.
84 |     """
85 |     def __init__(self, size, mean_std=None):
86 |
87 |         if mean_std is not None:
88 |             assert 'MEAN' in mean_std
89 |             assert 'STD' in mean_std
90 |         self.size = size
91 |         self.mean_std = mean_std
92 |
93 |     def __call__(self, img):
94 |         """
95 |         Args:
96 |             img (Tensor): Image to be scaled.
97 |         Returns: the scaled image tensor.
98 |
99 |         """
100 |         assert img is not None, "img should not be None"
101 |         assert is_tensor(img), "Tensor expected"
102 |
103 |         if img.size(1) != self.size:
104 |             # TODO: we should not need to unnormalize for scaling (validate whether that's true)
105 |             if self.mean_std:
106 |                 img = Unnormalize(mean=self.mean_std['MEAN'],
107 |                                   std=self.mean_std['STD'])(img)
108 |             img = torch_trans.ToPILImage()(img)
109 |             img = torch_trans.Scale(self.size)(img)
110 |             img = torch_trans.ToTensor()(img)
111 |             if self.mean_std:
112 |                 img = torch_trans.Normalize(mean=self.mean_std['MEAN'],
113 |                                             std=self.mean_std['STD'])(img)
114 |
115 |         return img
116 |
117 |
118 | class Unnormalize(object):
119 |     def __init__(self, mean, std):
120 |         self.mean = mean
121 |         self.std = std
122 |
123 |     def __call__(self, imgs):
124 |         assert imgs is not None, "imgs should not be None"
125 |         assert is_tensor(imgs), "Tensor expected"
126 |         imgs_trans = imgs.clone()
127 |         if len(imgs.size()) == 3:
128 |             for i in range(imgs.size(0)):
129 |                 imgs_trans[i, :, :] = imgs_trans[i, :, :] * self.std[i] + self.mean[i]
130 |         else:
131 |             for i in range(imgs.size(1)):
132 |                 imgs_trans[:, i, :, :] = ((imgs_trans[:, i, :, :] * self.std[i]) +
133 |                                           self.mean[i])
134 |         return imgs_trans
135 |
136 |
137 | class Normalize(object):
138 |     def __init__(self, mean, std):
139 |         self.mean = mean
140 |         self.std = std
141 |
142 |     def __call__(self, imgs):
143 |         assert imgs is not None, "imgs should not be None"
144 |         assert is_tensor(imgs), "Tensor expected"
145 |         imgs_trans = imgs.clone()
146 |         if len(imgs.size()) == 3:
147 |             for i in range(imgs.size(0)):
148 |                 imgs_trans[i, :, :] = (imgs_trans[i, :, :] - self.mean[i]) / self.std[i]
149 |         else:
150 |             for i in range(imgs.size(1)):
151 |                 imgs_trans[:, i, :, :] = ((imgs_trans[:, i, :, :] - self.mean[i]) /
152 |                                           self.std[i])
153 |         return imgs_trans
154 |
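# Illustrative usage (a sketch, not part of the original module; assumes a
# CxHxW image tensor `img` with values in [0, 1]):
#
#   crop = Crop(crop_type='sliding', crop_frac=0.5, sliding_crop_position=4)
#   patch = crop(img)          # position 4 = the central 3x3 grid window
#   patch = Scale(224)(patch)  # rescale the crop back to 224x224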
-------------------------------------------------------------------------------- /adversarial/lib/transformations/tvm.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import torch 6 | import numpy as np 7 | from skimage.restoration import denoise_tv_chambolle, denoise_tv_bregman 8 | from skimage.util import random_noise 9 | from scipy.optimize import minimize 10 | from skimage.util import img_as_float 11 | from skimage import color 12 | 13 | try: 14 | from lib.transformations.tv_bregman import _denoise_tv_bregman 15 | except ImportError: 16 | raise ImportError("tv_bregman not found. Check build script") 17 | 18 | 19 | def tv(x, p): 20 | f = np.linalg.norm(x[1:, :] - x[:-1, :], p, axis=1).sum() 21 | f += np.linalg.norm(x[:, 1:] - x[:, :-1], p, axis=0).sum() 22 | return f 23 | 24 | 25 | def tv_dx(x, p): 26 | if p == 1: 27 | x_diff0 = np.sign(x[1:, :] - x[:-1, :]) 28 | x_diff1 = np.sign(x[:, 1:] - x[:, :-1]) 29 | elif p > 1: 30 | x_diff0_norm = np.power(np.linalg.norm(x[1:, :] - x[:-1, :], p, axis=1), p - 1) 31 | x_diff1_norm = np.power(np.linalg.norm(x[:, 1:] - x[:, :-1], p, axis=0), p - 1) 32 | x_diff0_norm[x_diff0_norm < 1e-3] = 1e-3 33 | x_diff1_norm[x_diff1_norm < 1e-3] = 1e-3 34 | x_diff0_norm = np.repeat(x_diff0_norm[:, np.newaxis], x.shape[1], axis=1) 35 | x_diff1_norm = np.repeat(x_diff1_norm[np.newaxis, :], x.shape[0], axis=0) 36 | x_diff0 = p * np.power(x[1:, :] - x[:-1, :], p - 1) / x_diff0_norm 37 | x_diff1 = p * np.power(x[:, 1:] - x[:, :-1], p - 1) / x_diff1_norm 38 | df = np.zeros(x.shape) 39 | df[:-1, :] = -x_diff0 40 | df[1:, :] += x_diff0 41 | df[:, :-1] -= x_diff1 42 | df[:, 1:] += x_diff1 43 | return df 44 | 45 | 46 | def tv_l2(x, y, w, lam, p): 47 | f = 0.5 * np.power(x - y.flatten(), 2).dot(w.flatten()) 48 | x = np.reshape(x, y.shape) 49 | return f + lam * tv(x, p) 50 | 51 | 52 | def tv_l2_dx(x, y, w, lam, p): 53 | x = np.reshape(x, y.shape) 54 | df = (x - y) * w 55 | return df.flatten() + lam * tv_dx(x, p).flatten() 56 | 57 | 58 | def tv_inf(x, y, lam, p, tau): 59 | x = np.reshape(x, y.shape) 60 | return tau + lam * tv(x, p) 61 | 62 | 63 | def tv_inf_dx(x, y, lam, p, tau): 64 | x = np.reshape(x, y.shape) 65 | return lam * tv_dx(x, p).flatten() 66 | 67 | 68 | def minimize_tv(img, w, lam=0.01, p=2, solver='L-BFGS-B', maxiter=100, verbose=False): 69 | x_opt = np.copy(img) 70 | if solver == 'L-BFGS-B' or solver == 'CG' or solver == 'Newton-CG': 71 | for i in range(img.shape[2]): 72 | options = {'disp': verbose, 'maxiter': maxiter} 73 | res = minimize( 74 | tv_l2, x_opt[:, :, i], (img[:, :, i], w[:, :, i], lam, p), 75 | method=solver, jac=tv_l2_dx, options=options).x 76 | x_opt[:, :, i] = np.reshape(res, x_opt[:, :, i].shape) 77 | else: 78 | print('unsupported solver ' + solver) 79 | exit() 80 | return x_opt 81 | 82 | 83 | def minimize_tv_inf(img, w, tau=0.1, lam=0.01, p=2, solver='L-BFGS-B', maxiter=100, 84 | verbose=False): 85 | x_opt = np.copy(img) 86 | if solver == 'L-BFGS-B' or solver == 'CG' or solver == 'Newton-CG': 87 | for i in range(img.shape[2]): 88 | options = {'disp': verbose, 'maxiter': maxiter} 89 | lower = img[:, :, i] - tau 90 | upper = img[:, :, i] + tau 91 | lower[w[:, :, i] < 1e-6] = 0 92 | upper[w[:, :, i] < 1e-6] = 1 93 | bounds = np.array([lower.flatten(), upper.flatten()]).transpose() 94 | res = minimize( 95 | tv_inf, x_opt[:, :, i], (img[:, :, i], 
lam, p, tau), 96 | method=solver, bounds=bounds, jac=tv_inf_dx, options=options).x 97 | x_opt[:, :, i] = np.reshape(res, x_opt[:, :, i].shape) 98 | else: 99 | print('unsupported solver ' + solver) 100 | exit() 101 | return x_opt 102 | 103 | 104 | def minimize_tv_bregman(img, mask, weight, maxiter=100, gsiter=10, eps=0.001, 105 | isotropic=True): 106 | img = img_as_float(img) 107 | mask = mask.astype('uint8', order='C') 108 | return _denoise_tv_bregman(img, mask, weight, maxiter, gsiter, eps, isotropic) 109 | 110 | 111 | # applies TV reconstruction 112 | def reconstruct(img, drop_rate, recons, weight, drop_rate_post=0, lab=False, 113 | verbose=False, input_filepath=''): 114 | assert torch.is_tensor(img) 115 | temp = np.rollaxis(img.numpy(), 0, 3) 116 | w = np.ones_like(temp) 117 | if drop_rate > 0: 118 | # independent channel/pixel salt and pepper 119 | temp2 = random_noise(temp, 's&p', amount=drop_rate, salt_vs_pepper=0) 120 | # per-pixel all channel salt and pepper 121 | r = temp2 - temp 122 | w = (np.absolute(r) < 1e-6).astype('float') 123 | temp = temp + r 124 | if lab: 125 | temp = color.rgb2lab(temp) 126 | if recons == 'none': 127 | temp = temp 128 | elif recons == 'chambolle': 129 | temp = denoise_tv_chambolle(temp, weight=weight, multichannel=True) 130 | elif recons == 'bregman': 131 | if drop_rate == 0: 132 | temp = denoise_tv_bregman(temp, weight=1 / weight, isotropic=True) 133 | else: 134 | temp = minimize_tv_bregman( 135 | temp, w, weight=1 / weight, gsiter=10, eps=0.01, isotropic=True) 136 | elif recons == 'tvl2': 137 | temp = minimize_tv(temp, w, lam=weight, p=2, solver='L-BFGS-B', verbose=verbose) 138 | elif recons == 'tvinf': 139 | temp = minimize_tv_inf( 140 | temp, w, tau=weight, p=2, solver='L-BFGS-B', verbose=verbose) 141 | else: 142 | print('unsupported reconstruction method ' + recons) 143 | exit() 144 | if lab: 145 | temp = color.lab2rgb(temp) 146 | # temp = random_noise(temp, 's&p', amount=drop_rate_post, salt_vs_pepper=0) 147 | temp = torch.from_numpy(np.rollaxis(temp, 2, 0)).float() 148 | return temp 149 | -------------------------------------------------------------------------------- /adversarial/lib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import os 14 | import tempfile 15 | 16 | import torch 17 | 18 | # constants: 19 | CHECKPOINT_FILE = 'checkpoint.torch' 20 | 21 | 22 | # function that measures top-k accuracy: 23 | def accuracy(output, target, topk=(1,)): 24 | maxk = max(topk) 25 | batch_size = target.size(0) 26 | _, pred = output.topk(maxk, 1, True, True) 27 | pred = pred.t() 28 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 29 | res = [] 30 | for k in topk: 31 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 32 | res.append(correct_k.mul_(100. 
/ batch_size))
33 |     return res
34 |
35 |
36 | # function that tries to load a checkpoint:
37 | def load_checkpoint(checkpoint_folder):
38 |
39 |     # read what the latest model file is:
40 |     filename = os.path.join(checkpoint_folder, CHECKPOINT_FILE)
41 |     if not os.path.exists(filename):
42 |         return None
43 |
44 |     # load and return the checkpoint:
45 |     return torch.load(filename)
46 |
47 |
48 | # function that saves checkpoint:
49 | def save_checkpoint(checkpoint_folder, state):
50 |
51 |     # make sure that we have a checkpoint folder:
52 |     if not os.path.isdir(checkpoint_folder):
53 |         try:
54 |             os.makedirs(checkpoint_folder)
55 |         except BaseException:
56 |             print('| WARNING: could not create directory %s' % checkpoint_folder)
57 |     if not os.path.isdir(checkpoint_folder):
58 |         return False
59 |
60 |     # write checkpoint atomically:
61 |     try:
62 |         with tempfile.NamedTemporaryFile(
63 |                 'w', dir=checkpoint_folder, delete=False) as fwrite:
64 |             tmp_filename = fwrite.name
65 |             torch.save(state, fwrite.name)
66 |         os.rename(tmp_filename, os.path.join(checkpoint_folder, CHECKPOINT_FILE))
67 |         return True
68 |     except BaseException:
69 |         print('| WARNING: could not write checkpoint to %s.' % checkpoint_folder)
70 |         return False
71 |
72 |
73 | # function that adjusts the learning rate:
74 | def adjust_learning_rate(base_lr, epoch, optimizer, lr_decay, lr_decay_stepsize):
75 |     lr = base_lr * (lr_decay ** (epoch // lr_decay_stepsize))
76 |     for param_group in optimizer.param_groups:
77 |         param_group['lr'] = lr
78 |
79 |
80 | # adversary functions
81 | # computes SSIM for a single block
82 | def SSIM(x, y):
83 |     x = x.resize_(x.size(0), x.size(1) * x.size(2) * x.size(3))  # flattens in place
84 |     y = y.resize_(y.size(0), y.size(1) * y.size(2) * y.size(3))
85 |     N = x.size(1)
86 |     mu_x = x.mean(1)
87 |     mu_y = y.mean(1)
88 |     sigma_x = x.std(1)
89 |     sigma_y = y.std(1)
90 |     sigma_xy = ((x - mu_x.expand_as(x)) * (y - mu_y.expand_as(y))).sum(1) / (N - 1)
91 |     ssim = (2 * mu_x * mu_y) * (2 * sigma_xy)
92 |     ssim = ssim / (mu_x.pow(2) + mu_y.pow(2))
93 |     ssim = ssim / (sigma_x.pow(2) + sigma_y.pow(2))
94 |     return ssim
95 |
96 |
97 | # mean SSIM using local block averaging
98 | def MSSIM(x, y, window_size=16, stride=4):
99 |     ssim = torch.zeros(x.size(0))
100 |     L = x.size(2)
101 |     W = x.size(3)
102 |     x_inds = torch.arange(0, L - window_size + 1, stride).long()
103 |     y_inds = torch.arange(0, W - window_size + 1, stride).long()
104 |     for i in x_inds:
105 |         for j in y_inds:
106 |             x_sub = x[:, :, i:(i + window_size), j:(j + window_size)]
107 |             y_sub = y[:, :, i:(i + window_size), j:(j + window_size)]
108 |             ssim = ssim + SSIM(x_sub, y_sub)
109 |     return ssim / x_inds.size(0) / y_inds.size(0)
110 |
111 |
112 | # forwards input through model to get probabilities
113 | def get_probs(model, imgs, output_prob=False):
114 |     softmax = torch.nn.Softmax(1)
115 |     imgsvar = torch.autograd.Variable(imgs.squeeze(), volatile=True)
116 |     output = model(imgsvar)
117 |     if output_prob:
118 |         probs = output.data.cpu()
119 |     else:
120 |         probs = softmax(output).data.cpu()
121 |
122 |     return probs
123 |
124 |
125 | # calls get_probs to get predictions
126 | def get_labels(model, input, output_prob=False):
127 |     probs = get_probs(model, input, output_prob)
128 |     _, label = probs.max(1)
129 |     return label.squeeze()
130 |
--------------------------------------------------------------------------------
/adversarial/test/images/sample/lena_quilting.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookarchive/adversarial_image_defenses/55bf56ddb017535fb6630a746c6f946202336052/adversarial/test/images/sample/lena_quilting.png -------------------------------------------------------------------------------- /adversarial/test/images/sample/lena_tvm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookarchive/adversarial_image_defenses/55bf56ddb017535fb6630a746c6f946202336052/adversarial/test/images/sample/lena_tvm.png -------------------------------------------------------------------------------- /adversarial/test/images/train/class_name/lena-copy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookarchive/adversarial_image_defenses/55bf56ddb017535fb6630a746c6f946202336052/adversarial/test/images/train/class_name/lena-copy.png -------------------------------------------------------------------------------- /adversarial/test/images/train/class_name/lena.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookarchive/adversarial_image_defenses/55bf56ddb017535fb6630a746c6f946202336052/adversarial/test/images/train/class_name/lena.png -------------------------------------------------------------------------------- /adversarial/test/images/val/class_name/lena-copy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookarchive/adversarial_image_defenses/55bf56ddb017535fb6630a746c6f946202336052/adversarial/test/images/val/class_name/lena-copy.png -------------------------------------------------------------------------------- /adversarial/test/images/val/class_name/lena.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookarchive/adversarial_image_defenses/55bf56ddb017535fb6630a746c6f946202336052/adversarial/test/images/val/class_name/lena.png -------------------------------------------------------------------------------- /adversarial/train_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | #
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 | from __future__ import unicode_literals
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.optim
16 | from lib.convnet import train, test
17 | from lib.dataset import load_dataset, get_data_loader
18 | from lib.defenses import get_defense
19 | from lib.util import adjust_learning_rate, save_checkpoint
20 | import lib.opts as opts
21 | from lib.model import get_model
22 |
23 |
24 | def _get_optimizer(model, args):
25 |     if args.model.startswith('inceptionv4'):
26 |         optimizer = torch.optim.RMSprop(
27 |             model.parameters(), lr=args.lr,
28 |             alpha=args.rms_alpha, eps=args.rms_eps)
29 |     else:
30 |         optimizer = torch.optim.SGD(
31 |             model.parameters(), args.lr,
32 |             momentum=args.momentum, weight_decay=args.weight_decay,
33 |         )
34 |
35 |     return optimizer
36 |
37 |
38 | # main training routine:
39 | def train_model(args):
40 |
41 |     # at most one defense; there is no ensembling during training
42 |     assert args.defenses is None or len(args.defenses) == 1
43 |     defense_name = None if not args.defenses else args.defenses[0]
44 |     defense = get_defense(defense_name, args)
45 |
46 |     # load model:
47 |     model, start_epoch, optimizer_ = get_model(
48 |         args, load_checkpoint=args.resume, defense_name=defense_name, training=True)
49 |
50 |     # set up optimizer:
51 |     optimizer = _get_optimizer(model, args)
52 |
53 |     # restore epoch and optimizer state from checkpoint if available:
54 |     if start_epoch and optimizer:
55 |         args.start_epoch = start_epoch
56 |         optimizer.load_state_dict(optimizer_)
57 |
58 |     # set up criterion:
59 |     criterion = nn.CrossEntropyLoss()
60 |
61 |     if args.device == 'gpu':
62 |         # move criterion and model to the GPU:
63 |         criterion = criterion.cuda()
64 |         model = model.cuda()
65 |
66 |     loaders = {}
67 |
68 |     # set up start-of-epoch hook:
69 |     def start_epoch_hook(epoch, model, optimizer):
70 |         print('| epoch %d, training:' % epoch)
71 |         adjust_learning_rate(
72 |             args.lr, epoch, optimizer,
73 |             args.lr_decay, args.lr_decay_stepsize
74 |         )
75 |
76 |     # set up the end-of-epoch hook:
77 |     def end_epoch_hook(epoch, model, optimizer, prec1=None, prec5=None):
78 |
79 |         # print training error:
80 |         if prec1 is not None:
81 |             print('| training error @1 (epoch %d): %2.5f' % (epoch, 100. - prec1))
82 |         if prec5 is not None:
83 |             print('| training error @5 (epoch %d): %2.5f' % (epoch, 100. - prec5))
84 |
85 |         # save checkpoint:
86 |         print('| epoch %d, testing:' % epoch)
87 |         save_checkpoint(args.models_root, {
88 |             'epoch': epoch + 1,
89 |             'model_name': args.model,
90 |             'model_state_dict': model.state_dict(),
91 |             'optimizer': optimizer.state_dict(),
92 |         })
93 |
94 |         # measure validation error:
95 |         prec1, prec5 = test(model, loaders['valid'])
96 |         print('| validation error @1 (epoch %d): %2.5f' % (epoch, 100. - prec1))
97 |         print('| validation error @5 (epoch %d): %2.5f' % (epoch, 100. - prec5))
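    # NOTE (illustrative): lib.convnet.train drives the epoch loop and invokes
    # the hooks defined above; data_loader_hook below must return the training
    # data loader for the given epoch, reloading per-epoch data from disk when
    # args.preprocessed_epoch_data is set.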
98 |
99 |     def data_loader_hook(epoch):
100 |         # reload the data loaders for this epoch:
101 |         if args.preprocessed_epoch_data:
102 |             print('| epoch %d, loading data:' % epoch)
103 |             for key in ('train', 'valid'):
104 |                 # validation data is loaded only once; skip reloading it
105 |                 if key == 'valid' and 'valid' in loaders:
106 |                     continue
107 |                 loaders[key] = get_data_loader(
108 |                     load_dataset(args, key, defense, epoch=epoch),
109 |                     batchsize=args.batchsize,
110 |                     device=args.device,
111 |                     shuffle=True,
112 |                 )
113 |         # if data needs to be loaded only once and is not yet loaded:
114 |         elif len(loaders) == 0:
115 |             print('| epoch %d, loading data:' % epoch)
116 |             for key in ('train', 'valid'):
117 |                 loaders[key] = get_data_loader(
118 |                     load_dataset(args, key, defense),
119 |                     batchsize=args.batchsize,
120 |                     device=args.device,
121 |                     shuffle=True,
122 |                 )
123 |
124 |         return loaders['train']
125 |
126 |     # train the model:
127 |     print('| training model...')
128 |     train(model, criterion, optimizer,
129 |           start_epoch_hook=start_epoch_hook,
130 |           end_epoch_hook=end_epoch_hook,
131 |           data_loader_hook=data_loader_hook,
132 |           start_epoch=args.start_epoch,
133 |           end_epoch=args.end_epoch,
134 |           learning_rate=args.lr)
135 |     print('| done.')
136 |
137 |
138 | # run all the things:
139 | if __name__ == '__main__':
140 |     # parse input arguments:
141 |     args = opts.parse_args(opts.OptType.TRAIN)
142 |     train_model(args)
143 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | from setuptools import setup, Extension 14 | from setuptools.command.build_ext import build_ext 15 | import os 16 | import subprocess 17 | import zipfile 18 | import sys 19 | if sys.version_info[0] == 3: # for python3 20 | import urllib.request as urllib 21 | py3 = True 22 | else: # for python2 23 | import urllib as urllib 24 | py3 = False 25 | 26 | ext_code_download = True 27 | 28 | readme = open('README.md').read() 29 | 30 | TRANSFORMATION_DIR = str("adversarial/lib/transformations") 31 | cwd = os.path.dirname(os.path.abspath(__file__)) 32 | 33 | # Download 3rd party code (Respective licenses are applicable) 34 | if ext_code_download: 35 | # Download inception_v4 and inceptionresnetv2 36 | MODEL_DIR = str("adversarial/lib/models") 37 | INCEPTION_V4_URL = ("https://raw.githubusercontent.com/Cadene/" 38 | "tensorflow-model-zoo.torch/" 39 | "f43005c4b4cdd745e9788b22e182c91453c54daf/inceptionv4" 40 | "/pytorch_load.py") 41 | INCEPTION_RESNET_V2_URL = ("https://raw.githubusercontent.com/Cadene/" 42 | "tensorflow-model-zoo.torch/" 43 | "f43005c4b4cdd745e9788b22e182c91453c54daf/" 44 | "inceptionresnetv2/pytorch_load.py") 45 | urlopener = urllib.URLopener() 46 | urlopener.retrieve(INCEPTION_V4_URL, os.path.join(MODEL_DIR, "inceptionv4.py")) 47 | urlopener.retrieve(INCEPTION_RESNET_V2_URL, 48 | os.path.join(MODEL_DIR, "inceptionresnetv2.py")) 49 | 50 | # Download denoising code for tv_bregman from scikit-image 51 | DENOISE_URL = ("https://raw.githubusercontent.com/scikit-image/scikit-image/" 52 | "902a9a68add274c4125a358b29e3263b9d94f686/skimage/" 53 | "restoration/_denoise_cy.pyx") 54 | urlopener = urllib.URLopener() 55 | urlopener.retrieve(DENOISE_URL, os.path.join(TRANSFORMATION_DIR, "_denoise_cy.pyx")) 56 | # Apply patch to support TVM compressed sensing 57 | # _tv_bregman.patch was created from commit 902a9a6 58 | # would need to be updated if the source gets updated 59 | cmd = ("(cd adversarial/lib/transformations && patch -p0 -o tv_bregman.pyx) " 60 | "< adversarial/lib/transformations/_tv_bregman.patch") 61 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, shell=True) 62 | process.communicate() 63 | 64 | # Download and unzip the code for maxflow 65 | MAXFLOW_URL = "http://mouse.cs.uwaterloo.ca/code/maxflow-v3.01.zip" 66 | urlopener = urllib.URLopener() 67 | maxflow_file = os.path.join(TRANSFORMATION_DIR, "maxflow.zip") 68 | urlopener.retrieve(MAXFLOW_URL, maxflow_file) 69 | zip_ref = zipfile.ZipFile(maxflow_file, 'r') 70 | zip_ref.extractall(TRANSFORMATION_DIR) 71 | zip_ref.close() 72 | 73 | # Create Extension to build quilting code 74 | quilting_c_src = [ 75 | str(os.path.join(TRANSFORMATION_DIR, 'graph.cpp')), 76 | str(os.path.join(TRANSFORMATION_DIR, 'maxflow.cpp')), 77 | str(os.path.join(TRANSFORMATION_DIR, 'quilting.cpp')), 78 | str(os.path.join(TRANSFORMATION_DIR, 'findseam.cpp')) 79 | ] 80 | 81 | include_dirs = [ 82 | cwd, 83 | os.path.join(cwd, TRANSFORMATION_DIR), 84 | ] 85 | library_dirs = [os.path.join(cwd, 'adversarial', 'lib')] 86 | c_compile_args = [str('-std=c++11')] 87 | 88 | extensions = [] 89 | quilting_ext = Extension(str("libquilting"), 90 | sources=quilting_c_src, 91 | language='c++', 92 | include_dirs=include_dirs, 93 | library_dirs=library_dirs, 94 | extra_compile_args=c_compile_args) 95 | extensions = [quilting_ext] 96 | 97 | # Create Extension to build cython code for TVM 
98 | cython_ext = Extension(str('tv_bregman'), 99 | sources=[str("adversarial/lib/transformations/tv_bregman.pyx")], 100 | ) 101 | extensions.append(cython_ext) 102 | 103 | requirements = ['pillow', 'torchvision', 'scipy', 'scikit-image', 104 | 'Cython', 'enum34'] 105 | if py3: 106 | requirements += ['progressbar33'] 107 | else: 108 | requirements += ['progressbar'] 109 | 110 | 111 | # cython Extension needs numpy include dir path 112 | # this will be called only after installing numpy from setup_requires 113 | class CustomBuildExt(build_ext): 114 | def finalize_options(self): 115 | build_ext.finalize_options(self) 116 | # Prevent numpy from thinking it is still in its setup process: 117 | __builtins__.__NUMPY_SETUP__ = False 118 | import numpy 119 | self.include_dirs.append(numpy.get_include()) 120 | 121 | 122 | setup( 123 | # Metadata 124 | name="adversarial", 125 | version="0.1.0", 126 | author="Mayank Rana", 127 | author_email="mayankrana@fb.com", 128 | url="https://github.com/facebookresearch/adversarial_image_defenses", 129 | description="Code for Countering Adversarial Images using Input Transformations", 130 | long_description=readme, 131 | license='CC-BY-4.0', 132 | 133 | # Package Info 134 | packages=['adversarial'], 135 | package_dir={'adversarial': 'adversarial'}, 136 | include_package_data=True, 137 | zip_safe=False, 138 | install_requires=requirements, 139 | setup_requires=['setuptools>=18.0', 'cython', 'numpy'], 140 | cmdclass={'build_ext': CustomBuildExt}, 141 | ext_package='adversarial.lib.transformations', 142 | ext_modules=extensions, 143 | ) 144 | --------------------------------------------------------------------------------