├── .gitignore
├── LICENSE
├── README.md
├── config.json
├── eval.py
├── fetch_model.py
├── model.py
├── model_robustml.py
├── pgd_attack.py
├── run_attack.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # data files
2 | MNIST_DATA
3 |
4 | # attack files
5 | *.npy
6 |
7 | # model files
8 | models
9 |
10 | # compiled python files
11 | *.pyc
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, and Adrian Vladu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MNIST Adversarial Examples Challenge
2 |
3 | Recently, there has been much progress on adversarial *attacks* against neural networks, such as the [cleverhans](https://github.com/tensorflow/cleverhans) library and the code by [Carlini and Wagner](https://github.com/carlini/nn_robust_attacks).
4 | We now complement these advances by proposing an *attack challenge* for the
5 | [MNIST](http://yann.lecun.com/exdb/mnist/) dataset (we recently released [a
6 | CIFAR10 variant of this
7 | challenge](https://github.com/MadryLab/cifar10_challenge)).
8 | We have trained a robust network, and the objective is to find a set of adversarial examples on which this network achieves only a low accuracy.
9 | To train an adversarially robust network, we followed the approach from our recent paper:
10 |
11 | **Towards Deep Learning Models Resistant to Adversarial Attacks**
12 | *Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, Adrian Vladu*
13 | https://arxiv.org/abs/1706.06083.
14 |
15 | As part of the challenge, we release both the training code and the network architecture, but keep the network weights secret.
16 | We invite any researcher to submit attacks against our model (see the detailed instructions below).
17 | We will maintain a leaderboard of the best attacks for the next two months and then publish our secret network weights.
18 |
19 | The goal of our challenge is to clarify the state-of-the-art for adversarial robustness on MNIST. Moreover, we hope that future work on defense mechanisms will adopt a similar challenge format in order to improve reproducibility and empirical comparisons.
20 |
21 | **Update 2022-05-03:** We will no longer be accepting submissions to this challenge.
22 |
23 | **Update 2017-09-14:** Due to recently increased interest in our challenge, we are extending its duration until October 15th.
24 |
25 | **Update 2017-10-19:** We have released our secret model; you can download it by
26 | running `python fetch_model.py secret`. As of Oct 15 we are no longer
27 | accepting black-box challenge submissions. We will soon set up a leaderboard to keep track
28 | of white-box attacks. Many thanks to everyone who participated!
29 |
30 | **Update 2017-11-06:** We have set up a leaderboard for white-box attacks on the (now released) secret model. The submission format is the same as before. We plan to continue evaluating submissions and maintaining the leaderboard for the foreseeable future.
31 |
32 | ## Black-Box Leaderboard (Original Challenge)
33 |
34 | | Attack | Submitted by | Accuracy | Submission Date |
35 | | -------------------------------------- | ------------- | -------- | ---- |
36 | | AdvGAN from ["Generating Adversarial Examples with Adversarial Networks"](https://arxiv.org/abs/1801.02610) | AdvGAN | **92.76%** | Sep 25, 2017 |
37 | | PGD against three independently and adversarially trained copies of the network | [Florian Tramèr](http://floriantramer.com/) | 93.54% | Jul 5, 2017 |
38 | | FGSM on the [CW](https://github.com/carlini/nn_robust_attacks) loss for model B from ["Ensemble Adversarial Training [...]"](https://arxiv.org/abs/1705.07204) | [Florian Tramèr](http://floriantramer.com/) | 94.36% | Jun 29, 2017 |
39 | | FGSM on the [CW](https://github.com/carlini/nn_robust_attacks) loss for the naturally trained public network | (initial entry) | 96.08% | Jun 28, 2017 |
40 | | PGD on the cross-entropy loss for the naturally trained public network | (initial entry) | 96.81% | Jun 28, 2017 |
41 | | Attack using Gaussian Filter for selected pixels on the adversarially trained public network | Anonymous | 97.33% | Aug 27, 2017 |
42 | | FGSM on the cross-entropy loss for the adversarially trained public network | (initial entry) | 97.66% | Jun 28, 2017 |
43 | | PGD on the cross-entropy loss for the adversarially trained public network | (initial entry) | 97.79% | Jun 28, 2017 |
44 |
45 | ## White-Box Leaderboard
46 |
47 | | Attack | Submitted by | Accuracy | Submission Date |
48 | | -------------------------------------- | ------------- | -------- | ---- |
49 | | Guided Local Attack | Siyuan Yi | **88.00%** | Aug 30, 2021 |
50 | | [PCROS Attack](https://github.com/wan-lab/PCROS) | Chen Wan | 88.04% | Oct 28, 2020 |
51 | | Adaptive [Distributionally Adversarial Attack](https://github.com/tianzheng4/Distributionally-Adversarial-Attack) | Tianhang Zheng | 88.06% | Feb 29, 2019 |
52 | | [PGD attack with Output Diversified Initialization](https://arxiv.org/abs/2003.06878) | Yusuke Tashiro | 88.13% | Feb 15, 2020 |
53 | | [Square Attack](https://github.com/max-andr/square-attack) | Francesco Croce | 88.25% | Jan 14, 2020 |
54 | | First-Order Adversary with Quantized Gradients | Zhuanghua Liu | 88.32% | Oct 16, 2019 |
55 | | [MultiTargeted](https://arxiv.org/abs/1910.09338) | Sven Gowal | 88.36% | Aug 28, 2019 |
56 | | [Interval Attacks](https://github.com/tcwangshiqi-columbia/Interval-Attack) | [Shiqi Wang](https://www.cs.columbia.edu/~tcwangshiqi/) | 88.42% | Feb 28, 2019 |
57 | | [Distributionally Adversarial Attack](https://github.com/tianzheng4/Distributionally-Adversarial-Attack) merging multiple hyperparameters | Tianhang Zheng | 88.56% | Jan 13, 2019 |
58 | | [Interval Attacks](https://github.com/tcwangshiqi-columbia/Interval-Attack) | [Shiqi Wang](https://www.cs.columbia.edu/~tcwangshiqi/) | 88.59% | Jan 6, 2019 |
59 | | [Distributionally Adversarial Attack](https://github.com/tianzheng4/Distributionally-Adversarial-Attack) | Tianhang Zheng | 88.79% | Aug 13, 2018 |
60 | | First-order attack on logit difference for optimally chosen target label | Samarth Gupta | 88.85% | May 23, 2018 |
61 | | 100-step PGD on the cross-entropy loss with 50 random restarts | (initial entry) | 89.62% | Nov 6, 2017 |
62 | | 100-step PGD on the [CW](https://github.com/carlini/nn_robust_attacks) loss with 50 random restarts | (initial entry) | 89.71% | Nov 6, 2017 |
63 | | 100-step PGD on the cross-entropy loss | (initial entry) | 92.52% | Nov 6, 2017 |
64 | | 100-step PGD on the [CW](https://github.com/carlini/nn_robust_attacks) loss | (initial entry) | 93.04% | Nov 6, 2017 |
65 | | FGSM on the cross-entropy loss | (initial entry) | 96.36% | Nov 6, 2017 |
66 | | FGSM on the [CW](https://github.com/carlini/nn_robust_attacks) loss | (initial entry) | 96.40% | Nov 6, 2017 |
67 |
68 | ## Format and Rules
69 |
70 | The objective of the challenge is to find black-box (transfer) attacks that are effective against our MNIST model.
71 | Attacks are allowed to perturb each pixel of the input image by at most `epsilon=0.3`.
72 | To ensure that the attacks are indeed black-box, we release our training code and model architecture, but keep the actual network weights secret.
73 |
74 | We invite any interested researchers to submit attacks against our model.
75 | The most successful attacks will be listed in the leaderboard above.
76 | As a reference point, we have seeded the leaderboard with the results of some standard attacks.
77 |
78 | ### The MNIST Model
79 |
80 | We used the code published in this repository to produce an adversarially robust model for MNIST classification. The model is a convolutional neural network consisting of two convolutional layers (each followed by max-pooling), a fully connected layer of size 1024, and a final linear output layer over the 10 classes. This architecture is derived from the [MNIST tensorflow tutorial](https://www.tensorflow.org/get_started/mnist/pros).
81 | The network was trained against an iterative adversary that is allowed to perturb each pixel by at most `epsilon=0.3`.
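For reference, here is a sketch of the tensor shapes implied by `model.py` (batch dimension omitted; both convolutions use SAME padding):
```
input   784, reshaped to 28x28x1
conv1   5x5, 32 filters, ReLU -> 28x28x32 -> 2x2 max-pool -> 14x14x32
conv2   5x5, 64 filters, ReLU -> 14x14x64 -> 2x2 max-pool -> 7x7x64
fc1     7*7*64 = 3136 -> 1024, ReLU
output  1024 -> 10 logits
```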
82 |
83 | The random seed used for training and the trained network weights will be kept secret.
84 |
85 | The `sha256()` digest of our model file is:
86 | ```
87 | 14eea09c72092db5c2eb5e34cd105974f42569281d2f34826316e356d057f96d
88 | ```
89 | We will release the corresponding model file on October 15th 2017, which is roughly two months after the start of this competition.
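To verify a downloaded model file against this digest, you can recompute the hash the same way `fetch_model.py` does (the filename below is illustrative):
```
import hashlib

with open('secret.zip', 'rb') as f:  # the zip file fetched by fetch_model.py
    print(hashlib.sha256(f.read()).hexdigest())
```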
90 |
91 | ### The Attack Model
92 |
93 | We are interested in adversarial inputs that are derived from the MNIST test set.
94 | Each pixel can be perturbed by at most `epsilon=0.3` from its initial value.
95 | All pixels can be perturbed independently, so this is an l_infinity attack.
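A minimal numpy sketch of this feasible set, mirroring the clipping performed in `pgd_attack.py` (the function name is ours):
```
import numpy as np

def project(x_adv, x_nat, epsilon=0.3):
    """Project x_adv onto the attack model's feasible set: the l_infinity
    ball of radius epsilon around x_nat, intersected with [0, 1]."""
    x = np.clip(x_adv, x_nat - epsilon, x_nat + epsilon)
    return np.clip(x, 0.0, 1.0)
```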
96 |
97 | ### Submitting an Attack
98 |
99 | Each attack should consist of a perturbed version of the MNIST test set.
100 | Each perturbed image in this test set should follow the above attack model.
101 |
102 | The adversarial test set should be formatted as a numpy array with one row per example, each row containing a flattened
103 | array of 28x28 pixels.
104 | Hence the overall dimensions are 10,000 rows and 784 columns.
105 | Each pixel must be in the [0,1] range.
106 | See the script `pgd_attack.py` for an attack that generates an adversarial test set in this format.
107 |
108 | In order to submit your attack, save the matrix containing your adversarial examples with `numpy.save` and email the resulting file to mnist.challenge@gmail.com.
109 | We will then run the `run_attack.py` script on your file to verify that the attack is valid and to evaluate the accuracy of our secret model on your examples.
110 | After that, we will reply with the predictions of our model on each of your examples and the overall accuracy of our model on your evaluation set.
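For example, a sketch of sanity-checking and saving a submission in this format (the array `x_adv` stands for your adversarial examples):
```
import numpy as np

def save_submission(x_adv, path='attack.npy'):
    """Check and save adversarial examples in the submission format:
    10,000 rows (one per MNIST test image), 784 columns (flattened
    28x28 image), with all pixel values in [0, 1]."""
    assert x_adv.shape == (10000, 784)
    assert x_adv.min() >= 0.0 and x_adv.max() <= 1.0
    np.save(path, x_adv)
```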
111 |
112 | If the attack is valid and outperforms all current attacks in the leaderboard, it will appear at the top of the leaderboard.
113 | Novel types of attacks might be included in the leaderboard even if they do not perform best.
114 |
115 | We strongly encourage you to disclose your attack method.
116 | We would be happy to add a link to your code in our leaderboard.
117 |
118 | ## Overview of the Code
119 | The code consists of seven Python files and the file `config.json`, which contains various parameter settings.
120 |
121 | ### Running the code
122 | - `python train.py`: trains the network, storing checkpoints along
123 | the way.
124 | - `python eval.py`: an infinite evaluation loop, processing each new
125 | checkpoint as it is created while logging summaries. It is intended
126 | to be run in parallel with the `train.py` script.
127 | - `python pgd_attack.py`: applies the attack to the MNIST eval set and
128 | stores the resulting adversarial eval set in a `.npy` file. This file is
129 | in a valid attack format for our challenge.
130 | - `python run_attack.py`: evaluates the model on the examples in
131 | the `.npy` file specified in config, while ensuring that the adversarial examples
132 | are indeed a valid attack. The script also saves the network predictions in `pred.npy`.
133 | - `python fetch_model.py name`: downloads the pre-trained model with the
134 | specified name (one of `adv_trained`, `natural`, or `secret`), prints its sha256
135 | hash, and places it in the models directory.
136 |
137 | ### Parameters in `config.json`
138 |
139 | Model configuration:
140 | - `model_dir`: the path to the directory of the model currently being
141 | trained or evaluated.
142 |
143 | Training configuration:
144 | - `random_seed`: the seed for the RNG used to initialize the network
145 | weights.
146 | - `max_num_training_steps`: the number of training steps.
147 | - `num_output_steps`: the number of training steps between printing
148 | progress in standard output.
149 | - `num_summary_steps`: the number of training steps between storing
150 | tensorboard summaries.
151 | - `num_checkpoint_steps`: the number of training steps between storing
152 | model checkpoints.
153 | - `training_batch_size`: the size of the training batch.
154 |
155 | Evaluation configuration:
156 | - `num_eval_examples`: the number of MNIST examples to evaluate the
157 | model on.
158 | - `eval_batch_size`: the size of the evaluation batches.
159 | - `eval_on_cpu`: forces the `eval.py` script to run on the CPU so it does not compete with `train.py` for GPU resources.
160 |
161 | Adversarial examples configuration:
162 | - `epsilon`: the maximum allowed perturbation per pixel.
163 | - `k`: the number of PGD iterations used by the adversary.
164 | - `a`: the PGD step size (see the sketch after this list).
165 | - `random_start`: specifies whether the adversary will start iterating
166 | from the natural example or a random perturbation of it.
167 | - `loss_func`: the loss function that PGD optimizes. `xent` corresponds to the
168 | standard cross-entropy loss, and `cw` corresponds to the loss function
169 | of [Carlini and Wagner](https://arxiv.org/abs/1608.04644).
170 | - `store_adv_path`: the file in which adversarial examples are stored.
171 | Relevant for the `pgd_attack.py` and `run_attack.py` scripts.
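The `epsilon`, `k`, and `a` settings drive the update implemented in `LinfPGDAttack.perturb`; schematically, each of the `k` iterations performs the following step (a minimal numpy sketch, not the full attack):
```
import numpy as np

def pgd_step(x, grad, x_nat, epsilon, a):
    """One l_infinity PGD step, as in pgd_attack.py."""
    x = x + a * np.sign(grad)                         # move along the gradient sign
    x = np.clip(x, x_nat - epsilon, x_nat + epsilon)  # project into the epsilon-ball
    return np.clip(x, 0, 1)                           # keep pixels in [0, 1]
```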
172 |
173 | ## Example usage
174 | After cloning the repository, you can either train a new network or evaluate/attack one of our pre-trained networks.
175 | #### Training a new network
176 | * Start training by running:
177 | ```
178 | python train.py
179 | ```
180 | * (Optional) Evaluation summaries can be logged by simultaneously
181 | running:
182 | ```
183 | python eval.py
184 | ```
185 | #### Download a pre-trained network
186 | * For an adversarially trained network, run
187 | ```
188 | python fetch_model.py adv_trained
189 | ```
190 | and use the `config.json` file to set `"model_dir": "models/adv_trained"`.
191 | * For a naturally trained network, run
192 | ```
193 | python fetch_model.py natural
194 | ```
195 | and use the `config.json` file to set `"model_dir": "models/natural"`.
196 | #### Test the network
197 | * Create an attack file by running
198 | ```
199 | python pgd_attack.py
200 | ```
201 | * Evaluate the network with
202 | ```
203 | python run_attack.py
204 | ```
205 |
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_comment": "===== MODEL CONFIGURATION =====",
3 | "model_dir": "models/a_very_robust_model",
4 |
5 | "_comment": "===== TRAINING CONFIGURATION =====",
6 | "random_seed": 4557077,
7 | "max_num_training_steps": 100000,
8 | "num_output_steps": 100,
9 | "num_summary_steps": 100,
10 | "num_checkpoint_steps": 300,
11 | "training_batch_size": 50,
12 |
13 | "_comment": "===== EVAL CONFIGURATION =====",
14 | "num_eval_examples": 10000,
15 | "eval_batch_size": 200,
16 | "eval_on_cpu": true,
17 |
18 | "_comment": "=====ADVERSARIAL EXAMPLES CONFIGURATION=====",
19 | "epsilon": 0.3,
20 | "k": 40,
21 | "a": 0.01,
22 | "random_start": true,
23 | "loss_func": "xent",
24 | "store_adv_path": "attack.npy"
25 | }
26 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Infinite evaluation loop going through the checkpoints in the model directory
3 | as they appear and evaluating them. Accuracy and average loss are printed and
4 | added as tensorboard summaries.
5 | """
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | from datetime import datetime
11 | import json
12 | import math
13 | import os
14 | import sys
15 | import time
16 |
17 | import tensorflow as tf
18 | from tensorflow.examples.tutorials.mnist import input_data
19 |
20 | from model import Model
21 | from pgd_attack import LinfPGDAttack
22 |
23 | # Global constants
24 | with open('config.json') as config_file:
25 | config = json.load(config_file)
26 | num_eval_examples = config['num_eval_examples']
27 | eval_batch_size = config['eval_batch_size']
28 | eval_on_cpu = config['eval_on_cpu']
29 |
30 | model_dir = config['model_dir']
31 |
32 | # Set up the data, hyperparameters, and the model
33 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
34 |
35 | if eval_on_cpu:  # build the graph on the CPU so evaluation does not compete with train.py for the GPU
36 | with tf.device("/cpu:0"):
37 | model = Model()
38 | attack = LinfPGDAttack(model,
39 | config['epsilon'],
40 | config['k'],
41 | config['a'],
42 | config['random_start'],
43 | config['loss_func'])
44 | else:
45 | model = Model()
46 | attack = LinfPGDAttack(model,
47 | config['epsilon'],
48 | config['k'],
49 | config['a'],
50 | config['random_start'],
51 | config['loss_func'])
52 |
53 | global_step = tf.contrib.framework.get_or_create_global_step()
54 |
55 | # Setting up the Tensorboard and checkpoint outputs
56 | if not os.path.exists(model_dir):
57 | os.makedirs(model_dir)
58 | eval_dir = os.path.join(model_dir, 'eval')
59 | if not os.path.exists(eval_dir):
60 | os.makedirs(eval_dir)
61 |
62 | last_checkpoint_filename = ''
63 | already_seen_state = False
64 |
65 | saver = tf.train.Saver()
66 | summary_writer = tf.summary.FileWriter(eval_dir)
67 |
68 | # A function for evaluating a single checkpoint
69 | def evaluate_checkpoint(filename):
70 | with tf.Session() as sess:
71 | # Restore the checkpoint
72 | saver.restore(sess, filename)
73 |
74 | # Iterate over the samples batch-by-batch
75 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
76 | total_xent_nat = 0.
77 | total_xent_adv = 0.
78 | total_corr_nat = 0
79 | total_corr_adv = 0
80 |
81 | for ibatch in range(num_batches):
82 | bstart = ibatch * eval_batch_size
83 | bend = min(bstart + eval_batch_size, num_eval_examples)
84 |
85 | x_batch = mnist.test.images[bstart:bend, :]
86 | y_batch = mnist.test.labels[bstart:bend]
87 |
88 | dict_nat = {model.x_input: x_batch,
89 | model.y_input: y_batch}
90 |
91 | x_batch_adv = attack.perturb(x_batch, y_batch, sess)
92 |
93 | dict_adv = {model.x_input: x_batch_adv,
94 | model.y_input: y_batch}
95 |
96 | cur_corr_nat, cur_xent_nat = sess.run(
97 | [model.num_correct,model.xent],
98 | feed_dict = dict_nat)
99 | cur_corr_adv, cur_xent_adv = sess.run(
100 | [model.num_correct,model.xent],
101 | feed_dict = dict_adv)
102 |
103 | total_xent_nat += cur_xent_nat
104 | total_xent_adv += cur_xent_adv
105 | total_corr_nat += cur_corr_nat
106 | total_corr_adv += cur_corr_adv
107 |
108 | avg_xent_nat = total_xent_nat / num_eval_examples
109 | avg_xent_adv = total_xent_adv / num_eval_examples
110 | acc_nat = total_corr_nat / num_eval_examples
111 | acc_adv = total_corr_adv / num_eval_examples
112 |
113 | summary = tf.Summary(value=[
114 | tf.Summary.Value(tag='xent adv eval', simple_value= avg_xent_adv),
115 | tf.Summary.Value(tag='xent adv', simple_value= avg_xent_adv),
116 | tf.Summary.Value(tag='xent nat', simple_value= avg_xent_nat),
117 | tf.Summary.Value(tag='accuracy adv eval', simple_value= acc_adv),
118 | tf.Summary.Value(tag='accuracy adv', simple_value= acc_adv),
119 | tf.Summary.Value(tag='accuracy nat', simple_value= acc_nat)])
120 | summary_writer.add_summary(summary, global_step.eval(sess))
121 |
122 | print('natural: {:.2f}%'.format(100 * acc_nat))
123 | print('adversarial: {:.2f}%'.format(100 * acc_adv))
124 | print('avg nat loss: {:.4f}'.format(avg_xent_nat))
125 | print('avg adv loss: {:.4f}'.format(avg_xent_adv))
126 |
127 | # Infinite eval loop
128 | while True:
129 | cur_checkpoint = tf.train.latest_checkpoint(model_dir)
130 |
131 | # Case 1: No checkpoint yet
132 | if cur_checkpoint is None:
133 | if not already_seen_state:
134 | print('No checkpoint yet, waiting ...', end='')
135 | already_seen_state = True
136 | else:
137 | print('.', end='')
138 | sys.stdout.flush()
139 | time.sleep(10)
140 | # Case 2: Previously unseen checkpoint
141 | elif cur_checkpoint != last_checkpoint_filename:
142 | print('\nCheckpoint {}, evaluating ... ({})'.format(cur_checkpoint,
143 | datetime.now()))
144 | sys.stdout.flush()
145 | last_checkpoint_filename = cur_checkpoint
146 | already_seen_state = False
147 | evaluate_checkpoint(cur_checkpoint)
148 | # Case 3: Previously evaluated checkpoint
149 | else:
150 | if not already_seen_state:
151 | print('Waiting for the next checkpoint ... ({}) '.format(
152 | datetime.now()),
153 | end='')
154 | already_seen_state = True
155 | else:
156 | print('.', end='')
157 | sys.stdout.flush()
158 | time.sleep(10)
159 |
--------------------------------------------------------------------------------
/fetch_model.py:
--------------------------------------------------------------------------------
1 | """Downloads a model, computes its SHA256 hash and unzips it
2 | at the proper location."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import sys
8 | import zipfile
9 | import hashlib
10 |
11 | if len(sys.argv) != 2 or sys.argv[1] not in ['natural',
12 | 'adv_trained',
13 | 'secret']:
14 | print('Usage: python fetch_model.py [natural, adv_trained, secret]')
15 | sys.exit(1)
16 |
17 | if sys.argv[1] == 'natural':
18 | url = 'https://github.com/MadryLab/mnist_challenge_models/raw/master/natural.zip'
19 | elif sys.argv[1] == 'secret':
20 | url = 'https://github.com/MadryLab/mnist_challenge_models/raw/master/secret.zip'
21 | else: # fetch adv_trained model
22 | url = 'https://github.com/MadryLab/mnist_challenge_models/raw/master/adv_trained.zip'
23 |
24 | fname = url.split('/')[-1] # get the name of the file
25 |
26 | # model download
27 | print('Downloading models')
28 | if sys.version_info >= (3,):
29 | import urllib.request
30 | urllib.request.urlretrieve(url, fname)
31 | else:
32 | import urllib
33 | urllib.urlretrieve(url, fname)
34 |
35 | # computing model hash
36 | sha256 = hashlib.sha256()
37 | with open(fname, 'rb') as f:
38 | data = f.read()
39 | sha256.update(data)
40 | print('SHA256 hash: {}'.format(sha256.hexdigest()))
41 |
42 | # extracting model
43 | print('Extracting model')
44 | with zipfile.ZipFile(fname, 'r') as model_zip:
45 | model_zip.extractall()
46 | print('Extracted model in {}'.format(model_zip.namelist()[0]))
47 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | """
2 | The model is adapted from the tensorflow tutorial:
3 | https://www.tensorflow.org/get_started/mnist/pros
4 | """
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import tensorflow as tf
10 |
11 | class Model(object):
12 | def __init__(self):
13 | self.x_input = tf.placeholder(tf.float32, shape = [None, 784])
14 | self.y_input = tf.placeholder(tf.int64, shape = [None])
15 |
16 | self.x_image = tf.reshape(self.x_input, [-1, 28, 28, 1])
17 |
18 | # first convolutional layer
19 | W_conv1 = self._weight_variable([5,5,1,32])
20 | b_conv1 = self._bias_variable([32])
21 |
22 | h_conv1 = tf.nn.relu(self._conv2d(self.x_image, W_conv1) + b_conv1)
23 | h_pool1 = self._max_pool_2x2(h_conv1)
24 |
25 | # second convolutional layer
26 | W_conv2 = self._weight_variable([5,5,32,64])
27 | b_conv2 = self._bias_variable([64])
28 |
29 | h_conv2 = tf.nn.relu(self._conv2d(h_pool1, W_conv2) + b_conv2)
30 | h_pool2 = self._max_pool_2x2(h_conv2)
31 |
32 | # first fully connected layer
33 | W_fc1 = self._weight_variable([7 * 7 * 64, 1024])
34 | b_fc1 = self._bias_variable([1024])
35 |
36 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
37 | h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
38 |
39 | # output layer
40 | W_fc2 = self._weight_variable([1024,10])
41 | b_fc2 = self._bias_variable([10])
42 |
43 | self.pre_softmax = tf.matmul(h_fc1, W_fc2) + b_fc2
44 |
45 | y_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
46 | labels=self.y_input, logits=self.pre_softmax)
47 |
48 | self.xent = tf.reduce_sum(y_xent)
49 |
50 | self.y_pred = tf.argmax(self.pre_softmax, 1)
51 |
52 | correct_prediction = tf.equal(self.y_pred, self.y_input)
53 |
54 | self.num_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.int64))
55 | self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
56 |
57 | @staticmethod
58 | def _weight_variable(shape):
59 | initial = tf.truncated_normal(shape, stddev=0.1)
60 | return tf.Variable(initial)
61 |
62 | @staticmethod
63 | def _bias_variable(shape):
64 | initial = tf.constant(0.1, shape = shape)
65 | return tf.Variable(initial)
66 |
67 | @staticmethod
68 | def _conv2d(x, W):
69 | return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
70 |
71 | @staticmethod
72 |   def _max_pool_2x2(x):
73 | return tf.nn.max_pool(x,
74 | ksize = [1,2,2,1],
75 | strides=[1,2,2,1],
76 | padding='SAME')
77 |
--------------------------------------------------------------------------------
/model_robustml.py:
--------------------------------------------------------------------------------
1 | import robustml
2 | import tensorflow as tf
3 |
4 | import model
5 |
6 | class Model(robustml.model.Model):
7 | def __init__(self, sess):
8 | self._model = model.Model()
9 |
10 | saver = tf.train.Saver()
11 | checkpoint = tf.train.latest_checkpoint('models/secret')
12 | saver.restore(sess, checkpoint)
13 |
14 | self._sess = sess
15 | self._input = self._model.x_input
16 | self._logits = self._model.pre_softmax
17 | self._predictions = self._model.y_pred
18 | self._dataset = robustml.dataset.MNIST()
19 | self._threat_model = robustml.threat_model.Linf(epsilon=0.3)
20 |
21 | @property
22 | def dataset(self):
23 | return self._dataset
24 |
25 | @property
26 | def threat_model(self):
27 | return self._threat_model
28 |
29 | def classify(self, x):
30 | return self._sess.run(self._predictions,
31 | {self._input: x})[0]
32 |
33 | # expose attack interface
34 |
35 | @property
36 | def input(self):
37 | return self._input
38 |
39 | @property
40 | def logits(self):
41 | return self._logits
42 |
43 | @property
44 | def predictions(self):
45 | return self._predictions
46 |
--------------------------------------------------------------------------------
/pgd_attack.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of attack methods. Running this file as a program will
3 | apply the attack to the model specified by the config file and store
4 | the examples in an .npy file.
5 | """
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | import tensorflow as tf
11 | import numpy as np
12 |
13 |
14 | class LinfPGDAttack:
15 | def __init__(self, model, epsilon, k, a, random_start, loss_func):
16 | """Attack parameter initialization. The attack performs k steps of
17 | size a, while always staying within epsilon from the initial
18 | point."""
19 | self.model = model
20 | self.epsilon = epsilon
21 | self.k = k
22 | self.a = a
23 | self.rand = random_start
24 |
25 | if loss_func == 'xent':
26 | loss = model.xent
27 | elif loss_func == 'cw':
28 | label_mask = tf.one_hot(model.y_input,
29 | 10,
30 | on_value=1.0,
31 | off_value=0.0,
32 | dtype=tf.float32)
33 |       correct_logit = tf.reduce_sum(label_mask * model.pre_softmax, axis=1)  # logit of the true class
34 |       wrong_logit = tf.reduce_max((1-label_mask) * model.pre_softmax
35 |                                   - 1e4*label_mask, axis=1)  # largest logit among the wrong classes
36 |       loss = -tf.nn.relu(correct_logit - wrong_logit + 50)  # negated CW margin loss with confidence 50
37 | else:
38 | print('Unknown loss function. Defaulting to cross-entropy')
39 | loss = model.xent
40 |
41 | self.grad = tf.gradients(loss, model.x_input)[0]
42 |
43 | def perturb(self, x_nat, y, sess):
44 | """Given a set of examples (x_nat, y), returns a set of adversarial
45 | examples within epsilon of x_nat in l_infinity norm."""
46 | if self.rand:
47 | x = x_nat + np.random.uniform(-self.epsilon, self.epsilon, x_nat.shape)
48 | x = np.clip(x, 0, 1) # ensure valid pixel range
49 | else:
50 | x = np.copy(x_nat)
51 |
52 | for i in range(self.k):
53 | grad = sess.run(self.grad, feed_dict={self.model.x_input: x,
54 | self.model.y_input: y})
55 |
56 |       x += self.a * np.sign(grad)  # step of size a along the gradient sign
57 |
58 | x = np.clip(x, x_nat - self.epsilon, x_nat + self.epsilon)
59 | x = np.clip(x, 0, 1) # ensure valid pixel range
60 |
61 | return x
62 |
63 |
64 | if __name__ == '__main__':
65 | import json
66 | import sys
67 | import math
68 |
69 | from tensorflow.examples.tutorials.mnist import input_data
70 |
71 | from model import Model
72 |
73 | with open('config.json') as config_file:
74 | config = json.load(config_file)
75 |
76 | model_file = tf.train.latest_checkpoint(config['model_dir'])
77 | if model_file is None:
78 | print('No model found')
79 | sys.exit()
80 |
81 | model = Model()
82 | attack = LinfPGDAttack(model,
83 | config['epsilon'],
84 | config['k'],
85 | config['a'],
86 | config['random_start'],
87 | config['loss_func'])
88 | saver = tf.train.Saver()
89 |
90 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
91 |
92 | with tf.Session() as sess:
93 | # Restore the checkpoint
94 | saver.restore(sess, model_file)
95 |
96 | # Iterate over the samples batch-by-batch
97 | num_eval_examples = config['num_eval_examples']
98 | eval_batch_size = config['eval_batch_size']
99 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
100 |
101 | x_adv = [] # adv accumulator
102 |
103 | print('Iterating over {} batches'.format(num_batches))
104 |
105 | for ibatch in range(num_batches):
106 | bstart = ibatch * eval_batch_size
107 | bend = min(bstart + eval_batch_size, num_eval_examples)
108 | print('batch size: {}'.format(bend - bstart))
109 |
110 | x_batch = mnist.test.images[bstart:bend, :]
111 | y_batch = mnist.test.labels[bstart:bend]
112 |
113 | x_batch_adv = attack.perturb(x_batch, y_batch, sess)
114 |
115 | x_adv.append(x_batch_adv)
116 |
117 | print('Storing examples')
118 | path = config['store_adv_path']
119 | x_adv = np.concatenate(x_adv, axis=0)
120 | np.save(path, x_adv)
121 | print('Examples stored in {}'.format(path))
122 |
--------------------------------------------------------------------------------
/run_attack.py:
--------------------------------------------------------------------------------
1 | """Evaluates a model against examples from a .npy file as specified
2 | in config.json"""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from datetime import datetime
8 | import json
9 | import math
10 | import os
11 | import sys
12 | import time
13 |
14 | import tensorflow as tf
15 | from tensorflow.examples.tutorials.mnist import input_data
16 |
17 | import numpy as np
18 |
19 | from model import Model
20 |
21 | def run_attack(checkpoint, x_adv, epsilon):
22 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
23 |
24 | model = Model()
25 |
26 | saver = tf.train.Saver()
27 |
28 | num_eval_examples = 10000
29 | eval_batch_size = 64
30 |
31 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
32 | total_corr = 0
33 |
34 | x_nat = mnist.test.images
35 | l_inf = np.amax(np.abs(x_nat - x_adv))
36 |
37 |   if l_inf > epsilon + 0.0001:  # small slack for floating-point round-off
38 | print('maximum perturbation found: {}'.format(l_inf))
39 | print('maximum perturbation allowed: {}'.format(epsilon))
40 | return
41 |
42 | y_pred = [] # label accumulator
43 |
44 | with tf.Session() as sess:
45 | # Restore the checkpoint
46 | saver.restore(sess, checkpoint)
47 |
48 | # Iterate over the samples batch-by-batch
49 | for ibatch in range(num_batches):
50 | bstart = ibatch * eval_batch_size
51 | bend = min(bstart + eval_batch_size, num_eval_examples)
52 |
53 | x_batch = x_adv[bstart:bend, :]
54 | y_batch = mnist.test.labels[bstart:bend]
55 |
56 | dict_adv = {model.x_input: x_batch,
57 | model.y_input: y_batch}
58 | cur_corr, y_pred_batch = sess.run([model.num_correct, model.y_pred],
59 | feed_dict=dict_adv)
60 |
61 | total_corr += cur_corr
62 | y_pred.append(y_pred_batch)
63 |
64 | accuracy = total_corr / num_eval_examples
65 |
66 | print('Accuracy: {:.2f}%'.format(100.0 * accuracy))
67 | y_pred = np.concatenate(y_pred, axis=0)
68 | np.save('pred.npy', y_pred)
69 | print('Output saved at pred.npy')
70 |
71 | if __name__ == '__main__':
72 | import json
73 |
74 | with open('config.json') as config_file:
75 | config = json.load(config_file)
76 |
77 | model_dir = config['model_dir']
78 |
79 | checkpoint = tf.train.latest_checkpoint(model_dir)
80 | x_adv = np.load(config['store_adv_path'])
81 |
82 | if checkpoint is None:
83 | print('No checkpoint found')
84 | elif x_adv.shape != (10000, 784):
85 | print('Invalid shape: expected (10000,784), found {}'.format(x_adv.shape))
86 | elif np.amax(x_adv) > 1.0001 or \
87 | np.amin(x_adv) < -0.0001 or \
88 | np.isnan(np.amax(x_adv)):
89 | print('Invalid pixel range. Expected [0, 1], found [{}, {}]'.format(
90 | np.amin(x_adv),
91 | np.amax(x_adv)))
92 | else:
93 | run_attack(checkpoint, x_adv, config['epsilon'])
94 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Trains a model, saving checkpoints and tensorboard summaries along
2 | the way."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from datetime import datetime
8 | import json
9 | import os
10 | import shutil
11 | from timeit import default_timer as timer
12 |
13 | import tensorflow as tf
14 | import numpy as np
15 | from tensorflow.examples.tutorials.mnist import input_data
16 |
17 | from model import Model
18 | from pgd_attack import LinfPGDAttack
19 |
20 | with open('config.json') as config_file:
21 | config = json.load(config_file)
22 |
23 | # Setting up training parameters
24 | tf.set_random_seed(config['random_seed'])
25 |
26 | max_num_training_steps = config['max_num_training_steps']
27 | num_output_steps = config['num_output_steps']
28 | num_summary_steps = config['num_summary_steps']
29 | num_checkpoint_steps = config['num_checkpoint_steps']
30 |
31 | batch_size = config['training_batch_size']
32 |
33 | # Setting up the data and the model
34 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
35 | global_step = tf.contrib.framework.get_or_create_global_step()
36 | model = Model()
37 |
38 | # Setting up the optimizer
39 | train_step = tf.train.AdamOptimizer(1e-4).minimize(model.xent,
40 | global_step=global_step)
41 |
42 | # Set up adversary
43 | attack = LinfPGDAttack(model,
44 | config['epsilon'],
45 | config['k'],
46 | config['a'],
47 | config['random_start'],
48 | config['loss_func'])
49 |
50 | # Setting up the Tensorboard and checkpoint outputs
51 | model_dir = config['model_dir']
52 | if not os.path.exists(model_dir):
53 | os.makedirs(model_dir)
54 |
55 | # We add accuracy and xent twice so we can easily make three types of
56 | # comparisons in Tensorboard:
57 | # - train vs eval (for a single run)
58 | # - train of different runs
59 | # - eval of different runs
60 |
61 | saver = tf.train.Saver(max_to_keep=3)
62 | tf.summary.scalar('accuracy adv train', model.accuracy)
63 | tf.summary.scalar('accuracy adv', model.accuracy)
64 | tf.summary.scalar('xent adv train', model.xent / batch_size)  # model.xent is a sum, so divide for the per-example loss
65 | tf.summary.scalar('xent adv', model.xent / batch_size)
66 | tf.summary.image('images adv train', model.x_image)
67 | merged_summaries = tf.summary.merge_all()
68 |
69 | shutil.copy('config.json', model_dir)
70 |
71 | with tf.Session() as sess:
72 | # Initialize the summary writer, global variables, and our time counter.
73 | summary_writer = tf.summary.FileWriter(model_dir, sess.graph)
74 | sess.run(tf.global_variables_initializer())
75 | training_time = 0.0
76 |
77 | # Main training loop
78 | for ii in range(max_num_training_steps):
79 | x_batch, y_batch = mnist.train.next_batch(batch_size)
80 |
81 | # Compute Adversarial Perturbations
82 | start = timer()
83 | x_batch_adv = attack.perturb(x_batch, y_batch, sess)
84 | end = timer()
85 | training_time += end - start
86 |
87 | nat_dict = {model.x_input: x_batch,
88 | model.y_input: y_batch}
89 |
90 | adv_dict = {model.x_input: x_batch_adv,
91 | model.y_input: y_batch}
92 |
93 | # Output to stdout
94 | if ii % num_output_steps == 0:
95 | nat_acc = sess.run(model.accuracy, feed_dict=nat_dict)
96 | adv_acc = sess.run(model.accuracy, feed_dict=adv_dict)
97 | print('Step {}: ({})'.format(ii, datetime.now()))
98 | print(' training nat accuracy {:.4}%'.format(nat_acc * 100))
99 | print(' training adv accuracy {:.4}%'.format(adv_acc * 100))
100 | if ii != 0:
101 | print(' {} examples per second'.format(
102 | num_output_steps * batch_size / training_time))
103 | training_time = 0.0
104 | # Tensorboard summaries
105 | if ii % num_summary_steps == 0:
106 | summary = sess.run(merged_summaries, feed_dict=adv_dict)
107 | summary_writer.add_summary(summary, global_step.eval(sess))
108 |
109 | # Write a checkpoint
110 | if ii % num_checkpoint_steps == 0:
111 | saver.save(sess,
112 | os.path.join(model_dir, 'checkpoint'),
113 | global_step=global_step)
114 |
115 | # Actual training step
116 | start = timer()
117 | sess.run(train_step, feed_dict=adv_dict)
118 | end = timer()
119 | training_time += end - start
120 |
--------------------------------------------------------------------------------