├── .gitignore
├── LICENSE
├── README.md
├── config.json
├── eval.py
├── fetch_model.py
├── model.py
├── model_robustml.py
├── pgd_attack.py
├── run_attack.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # data files
2 | MNIST_DATA
3 |
4 | # attack files
5 | *.npy
6 |
7 | # model files
8 | models
9 |
10 | # compiled python files
11 | *.pyc
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, and Adrian Vladu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MNIST Adversarial Examples Challenge
2 |
3 | Recently, there has been much progress on adversarial *attacks* against neural networks, such as the [cleverhans](https://github.com/tensorflow/cleverhans) library and the code by [Carlini and Wagner](https://github.com/carlini/nn_robust_attacks).
4 | We now complement these advances by proposing an *attack challenge* for the
5 | [MNIST](http://yann.lecun.com/exdb/mnist/) dataset (we recently released [a
6 | CIFAR10 variant of this
7 | challenge](https://github.com/MadryLab/cifar10_challenge)).
8 | We have trained a robust network, and the objective is to find a set of adversarial examples on which this network achieves only a low accuracy.
9 | To train an adversarially robust network, we followed the approach from our recent paper:
10 |
11 | **Towards Deep Learning Models Resistant to Adversarial Attacks**
12 | *Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, Adrian Vladu*
13 | https://arxiv.org/abs/1706.06083.
14 |
15 | As part of the challenge, we release both the training code and the network architecture, but keep the network weights secret.
16 | We invite any researcher to submit attacks against our model (see the detailed instructions below).
17 | We will maintain a leaderboard of the best attacks for the next two months and then publish our secret network weights.
18 |
19 | The goal of our challenge is to clarify the state-of-the-art for adversarial robustness on MNIST. Moreover, we hope that future work on defense mechanisms will adopt a similar challenge format in order to improve reproducibility and empirical comparisons.
20 |
21 | **Update 2022-05-03:** We will no longer be accepting submissions to this challenge.
22 |
23 | **Update 2017-09-14:** Due to recently increased interest in our challenge, we are extending its duration until October 15th.
24 |
25 | **Update 2017-10-19:** We have released our secret model; you can download it by
26 | running `python fetch_model.py secret`. As of Oct 15 we are no longer
27 | accepting black-box challenge submissions. We will soon set up a leaderboard to keep track
28 | of white-box attacks. Many thanks to everyone who participated!
29 |
30 | **Update 2017-11-06:** We have set up a leaderboard for white-box attacks on the (now released) secret model. The submission format is the same as before. We plan to continue evaluating submissions and maintaining the leaderboard for the foreseeable future.
31 |
32 | ## Black-Box Leaderboard (Original Challenge)
33 |
34 | | Attack | Submitted by | Accuracy | Submission Date |
35 | | -------------------------------------- | ------------- | -------- | ---- |
36 | | AdvGAN from ["Generating Adversarial Examples with Adversarial Networks"](https://arxiv.org/abs/1801.02610) | AdvGAN | **92.76%** | Sep 25, 2017 |
37 | | PGD against three independently and adversarially trained copies of the network | [Florian Tramèr](http://floriantramer.com/) | 93.54% | Jul 5, 2017 |
38 | | FGSM on the [CW](https://github.com/carlini/nn_robust_attacks) loss for model B from ["Ensemble Adversarial Training [...]"](https://arxiv.org/abs/1705.07204) | [Florian Tramèr](http://floriantramer.com/) | 94.36% | Jun 29, 2017 |
39 | | FGSM on the [CW](https://github.com/carlini/nn_robust_attacks) loss for the naturally trained public network | (initial entry) | 96.08% | Jun 28, 2017 |
40 | | PGD on the cross-entropy loss for the naturally trained public network | (initial entry) | 96.81% | Jun 28, 2017 |
41 | | Attack using Gaussian Filter for selected pixels on the adversarially trained public network | Anonymous | 97.33% | Aug 27, 2017 |
42 | | FGSM on the cross-entropy loss for the adversarially trained public network | (initial entry) | 97.66% | Jun 28, 2017 |
43 | | PGD on the cross-entropy loss for the adversarially trained public network | (initial entry) | 97.79% | Jun 28, 2017 |
44 |
45 | ## White-Box Leaderboard
46 |
47 | | Attack | Submitted by | Accuracy | Submission Date |
48 | | -------------------------------------- | ------------- | -------- | ---- |
49 | | Guided Local Attack | Siyuan Yi | **88.00%** | Aug 30, 2021 |
50 | | [PCROS Attack](https://github.com/wan-lab/PCROS) | Chen Wan | 88.04% | Oct 28, 2020 |
51 | | Adaptive [Distributionally Adversarial Attack](https://github.com/tianzheng4/Distributionally-Adversarial-Attack) | Tianhang Zheng | 88.06% | Feb 29, 2019 |
52 | | [PGD attack with Output Diversified Initialization](https://arxiv.org/abs/2003.06878) | Yusuke Tashiro | 88.13% | Feb 15, 2020 |
53 | | [Square Attack](https://github.com/max-andr/square-attack) | Francesco Croce | 88.25% | Jan 14, 2020 |
54 | | First-Order Adversary with Quantized Gradients | Zhuanghua Liu | 88.32% | Oct 16, 2019 |
55 | | [MultiTargeted](https://arxiv.org/abs/1910.09338) | Sven Gowal | 88.36% | Aug 28, 2019 |
56 | | [Interval Attacks](https://github.com/tcwangshiqi-columbia/Interval-Attack) | [Shiqi Wang](https://www.cs.columbia.edu/~tcwangshiqi/) | 88.42% | Feb 28, 2019 |
57 | | [Distributionally Adversarial Attack](https://github.com/tianzheng4/Distributionally-Adversarial-Attack) merging multiple hyperparameters | Tianhang Zheng | 88.56% | Jan 13, 2019 |
58 | | [Interval Attacks](https://github.com/tcwangshiqi-columbia/Interval-Attack) | [Shiqi Wang](https://www.cs.columbia.edu/~tcwangshiqi/) | 88.59% | Jan 6, 2019 |
59 | | [Distributionally Adversarial Attack](https://github.com/tianzheng4/Distributionally-Adversarial-Attack) | Tianhang Zheng | 88.79% | Aug 13, 2018 |
60 | | First-order attack on logit difference for optimally chosen target label | Samarth Gupta | 88.85% | May 23, 2018 |
61 | | 100-step PGD on the cross-entropy loss with 50 random restarts | (initial entry) | 89.62% | Nov 6, 2017 |
62 | | 100-step PGD on the [CW](https://github.com/carlini/nn_robust_attacks) loss with 50 random restarts | (initial entry) | 89.71% | Nov 6, 2017 |
63 | | 100-step PGD on the cross-entropy loss | (initial entry) | 92.52% | Nov 6, 2017 |
64 | | 100-step PGD on the [CW](https://github.com/carlini/nn_robust_attacks) loss | (initial entry) | 93.04% | Nov 6, 2017 |
65 | | FGSM on the cross-entropy loss | (initial entry) | 96.36% | Nov 6, 2017 |
66 | | FGSM on the [CW](https://github.com/carlini/nn_robust_attacks) loss | (initial entry) | 96.40% | Nov 6, 2017 |
67 |
68 | ## Format and Rules
69 |
70 | The objective of the challenge is to find black-box (transfer) attacks that are effective against our MNIST model.
71 | Attacks are allowed to perturb each pixel of the input image by at most `epsilon=0.3`.
72 | To ensure that the attacks are indeed black-box, we release our training code and model architecture, but keep the actual network weights secret.
73 |
74 | We invite any interested researchers to submit attacks against our model.
75 | The most successful attacks will be listed in the leaderboard above.
76 | As a reference point, we have seeded the leaderboard with the results of some standard attacks.
77 |
78 | ### The MNIST Model
79 |
80 | We used the code published in this repository to produce an adversarially robust model for MNIST classification. The model is a convolutional neural network consisting of two convolutional layers (each followed by max-pooling), a fully connected layer of size 1024, and a final linear output layer over the 10 classes. This architecture is derived from the [MNIST tensorflow tutorial](https://www.tensorflow.org/get_started/mnist/pros).
81 | The network was trained against an iterative adversary that is allowed to perturb each pixel by at most `epsilon=0.3`.
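For reference, here is a sketch of the tensor shapes implied by `model.py` (batch dimension omitted; both convolutions use SAME padding):
```
input   784, reshaped to 28x28x1
conv1   5x5, 32 filters, ReLU -> 28x28x32 -> 2x2 max-pool -> 14x14x32
conv2   5x5, 64 filters, ReLU -> 14x14x64 -> 2x2 max-pool -> 7x7x64
fc1     7*7*64 = 3136 -> 1024, ReLU
output  1024 -> 10 logits
```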
82 |
83 | The random seed used for training and the trained network weights will be kept secret.
84 |
85 | The `sha256()` digest of our model file is:
86 | ```
87 | 14eea09c72092db5c2eb5e34cd105974f42569281d2f34826316e356d057f96d
88 | ```
89 | We will release the corresponding model file on October 15th 2017, which is roughly two months after the start of this competition.
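To verify a downloaded model file against this digest, you can recompute the hash the same way `fetch_model.py` does (the filename below is illustrative):
```
import hashlib

with open('secret.zip', 'rb') as f:  # the zip file fetched by fetch_model.py
    print(hashlib.sha256(f.read()).hexdigest())
```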
90 |
91 | ### The Attack Model
92 |
93 | We are interested in adversarial inputs that are derived from the MNIST test set.
94 | Each pixel can be perturbed by at most `epsilon=0.3` from its initial value.
95 | All pixels can be perturbed independently, so this is an l_infinity attack.
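A minimal numpy sketch of this feasible set, mirroring the clipping performed in `pgd_attack.py` (the function name is ours):
```
import numpy as np

def project(x_adv, x_nat, epsilon=0.3):
    """Project x_adv onto the attack model's feasible set: the l_infinity
    ball of radius epsilon around x_nat, intersected with [0, 1]."""
    x = np.clip(x_adv, x_nat - epsilon, x_nat + epsilon)
    return np.clip(x, 0.0, 1.0)
```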
96 |
97 | ### Submitting an Attack
98 |
99 | Each attack should consist of a perturbed version of the MNIST test set.
100 | Each perturbed image in this test set should follow the above attack model.
101 |
102 | The adversarial test set should be formatted as a numpy array with one row per example, each row containing a flattened
103 | array of 28x28 pixels.
104 | Hence the overall dimensions are 10,000 rows and 784 columns.
105 | Each pixel must be in the [0,1] range.
106 | See the script `pgd_attack.py` for an attack that generates an adversarial test set in this format.
107 |
108 | In order to submit your attack, save the matrix containing your adversarial examples with `numpy.save` and email the resulting file to mnist.challenge@gmail.com.
109 | We will then run the `run_attack.py` script on your file to verify that the attack is valid and to evaluate the accuracy of our secret model on your examples.
110 | After that, we will reply with the predictions of our model on each of your examples and the overall accuracy of our model on your evaluation set.
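For example, a sketch of sanity-checking and saving a submission in this format (the array `x_adv` stands for your adversarial examples):
```
import numpy as np

def save_submission(x_adv, path='attack.npy'):
    """Check and save adversarial examples in the submission format:
    10,000 rows (one per MNIST test image), 784 columns (flattened
    28x28 image), with all pixel values in [0, 1]."""
    assert x_adv.shape == (10000, 784)
    assert x_adv.min() >= 0.0 and x_adv.max() <= 1.0
    np.save(path, x_adv)
```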
111 |
112 | If the attack is valid and outperforms all current attacks in the leaderboard, it will appear at the top of the leaderboard.
113 | Novel types of attacks might be included in the leaderboard even if they do not perform best.
114 |
115 | We strongly encourage you to disclose your attack method.
116 | We would be happy to add a link to your code in our leaderboard.
117 |
118 | ## Overview of the Code
119 | The code consists of seven Python files and the file `config.json`, which contains various parameter settings.
120 |
121 | ### Running the code
122 | - `python train.py`: trains the network, storing checkpoints along
123 | the way.
124 | - `python eval.py`: an infinite evaluation loop, processing each new
125 | checkpoint as it is created while logging summaries. It is intended
126 | to be run in parallel with the `train.py` script.
127 | - `python pgd_attack.py`: applies the attack to the MNIST eval set and
128 | stores the resulting adversarial eval set in a `.npy` file. This file is
129 | in a valid attack format for our challenge.
130 | - `python run_attack.py`: evaluates the model on the examples in
131 | the `.npy` file specified in config, while ensuring that the adversarial examples
132 | are indeed a valid attack. The script also saves the network predictions in `pred.npy`.
133 | - `python fetch_model.py name`: downloads the pre-trained model with the
134 | specified name (one of `adv_trained`, `natural`, or `secret`), prints its sha256
135 | hash, and places it in the models directory.
136 |
137 | ### Parameters in `config.json`
138 |
139 | Model configuration:
140 | - `model_dir`: the path to the directory of the model currently being
141 | trained or evaluated.
142 |
143 | Training configuration:
144 | - `random_seed`: the seed for the RNG used to initialize the network
145 | weights.
146 | - `max_num_training_steps`: the number of training steps.
147 | - `num_output_steps`: the number of training steps between printing
148 | progress in standard output.
149 | - `num_summary_steps`: the number of training steps between storing
150 | tensorboard summaries.
151 | - `num_checkpoint_steps`: the number of training steps between storing
152 | model checkpoints.
153 | - `training_batch_size`: the size of the training batch.
154 |
155 | Evaluation configuration:
156 | - `num_eval_examples`: the number of MNIST examples to evaluate the
157 | model on.
158 | - `eval_batch_size`: the size of the evaluation batches.
159 | - `eval_on_cpu`: forces the `eval.py` script to run on the CPU so it does not compete with `train.py` for GPU resources.
160 |
161 | Adversarial examples configuration:
162 | - `epsilon`: the maximum allowed perturbation per pixel.
163 | - `k`: the number of PGD iterations used by the adversary.
164 | - `a`: the PGD step size (see the sketch after this list).
165 | - `random_start`: specifies whether the adversary will start iterating
166 | from the natural example or a random perturbation of it.
167 | - `loss_func`: the loss function that PGD optimizes. `xent` corresponds to the
168 | standard cross-entropy loss, and `cw` corresponds to the loss function
169 | of [Carlini and Wagner](https://arxiv.org/abs/1608.04644).
170 | - `store_adv_path`: the file in which adversarial examples are stored.
171 | Relevant for the `pgd_attack.py` and `run_attack.py` scripts.
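The `epsilon`, `k`, and `a` settings drive the update implemented in `LinfPGDAttack.perturb`; schematically, each of the `k` iterations performs the following step (a minimal numpy sketch, not the full attack):
```
import numpy as np

def pgd_step(x, grad, x_nat, epsilon, a):
    """One l_infinity PGD step, as in pgd_attack.py."""
    x = x + a * np.sign(grad)                         # move along the gradient sign
    x = np.clip(x, x_nat - epsilon, x_nat + epsilon)  # project into the epsilon-ball
    return np.clip(x, 0, 1)                           # keep pixels in [0, 1]
```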
172 |
173 | ## Example usage
174 | After cloning the repository, you can either train a new network or evaluate/attack one of our pre-trained networks.
175 | #### Training a new network
176 | * Start training by running:
177 | ```
178 | python train.py
179 | ```
180 | * (Optional) Evaluation summaries can be logged by simultaneously
181 | running:
182 | ```
183 | python eval.py
184 | ```
185 | #### Download a pre-trained network
186 | * For an adversarially trained network, run
187 | ```
188 | python fetch_model.py adv_trained
189 | ```
190 | and use the `config.json` file to set `"model_dir": "models/adv_trained"`.
191 | * For a naturally trained network, run
192 | ```
193 | python fetch_model.py natural
194 | ```
195 | and use the `config.json` file to set `"model_dir": "models/natural"`.
196 | #### Test the network
197 | * Create an attack file by running
198 | ```
199 | python pgd_attack.py
200 | ```
201 | * Evaluate the network with
202 | ```
203 | python run_attack.py
204 | ```
205 |
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_comment": "===== MODEL CONFIGURATION =====",
3 | "model_dir": "models/a_very_robust_model",
4 |
5 | "_comment": "===== TRAINING CONFIGURATION =====",
6 | "random_seed": 4557077,
7 | "max_num_training_steps": 100000,
8 | "num_output_steps": 100,
9 | "num_summary_steps": 100,
10 | "num_checkpoint_steps": 300,
11 | "training_batch_size": 50,
12 |
13 | "_comment": "===== EVAL CONFIGURATION =====",
14 | "num_eval_examples": 10000,
15 | "eval_batch_size": 200,
16 | "eval_on_cpu": true,
17 |
18 | "_comment": "=====ADVERSARIAL EXAMPLES CONFIGURATION=====",
19 | "epsilon": 0.3,
20 | "k": 40,
21 | "a": 0.01,
22 | "random_start": true,
23 | "loss_func": "xent",
24 | "store_adv_path": "attack.npy"
25 | }
26 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Infinite evaluation loop going through the checkpoints in the model directory
3 | as they appear and evaluating them. Accuracy and average loss are printed and
4 | added as tensorboard summaries.
5 | """
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | from datetime import datetime
11 | import json
12 | import math
13 | import os
14 | import sys
15 | import time
16 |
17 | import tensorflow as tf
18 | from tensorflow.examples.tutorials.mnist import input_data
19 |
20 | from model import Model
21 | from pgd_attack import LinfPGDAttack
22 |
23 | # Global constants
24 | with open('config.json') as config_file:
25 | config = json.load(config_file)
26 | num_eval_examples = config['num_eval_examples']
27 | eval_batch_size = config['eval_batch_size']
28 | eval_on_cpu = config['eval_on_cpu']
29 |
30 | model_dir = config['model_dir']
31 |
32 | # Set up the data, hyperparameters, and the model
33 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
34 |
35 | if eval_on_cpu:  # build the graph on the CPU so evaluation does not compete with train.py for the GPU
36 | with tf.device("/cpu:0"):
37 | model = Model()
38 | attack = LinfPGDAttack(model,
39 | config['epsilon'],
40 | config['k'],
41 | config['a'],
42 | config['random_start'],
43 | config['loss_func'])
44 | else:
45 | model = Model()
46 | attack = LinfPGDAttack(model,
47 | config['epsilon'],
48 | config['k'],
49 | config['a'],
50 | config['random_start'],
51 | config['loss_func'])
52 |
53 | global_step = tf.contrib.framework.get_or_create_global_step()
54 |
55 | # Setting up the Tensorboard and checkpoint outputs
56 | if not os.path.exists(model_dir):
57 | os.makedirs(model_dir)
58 | eval_dir = os.path.join(model_dir, 'eval')
59 | if not os.path.exists(eval_dir):
60 | os.makedirs(eval_dir)
61 |
62 | last_checkpoint_filename = ''
63 | already_seen_state = False
64 |
65 | saver = tf.train.Saver()
66 | summary_writer = tf.summary.FileWriter(eval_dir)
67 |
68 | # A function for evaluating a single checkpoint
69 | def evaluate_checkpoint(filename):
70 | with tf.Session() as sess:
71 | # Restore the checkpoint
72 | saver.restore(sess, filename)
73 |
74 | # Iterate over the samples batch-by-batch
75 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
76 | total_xent_nat = 0.
77 | total_xent_adv = 0.
78 | total_corr_nat = 0
79 | total_corr_adv = 0
80 |
81 | for ibatch in range(num_batches):
82 | bstart = ibatch * eval_batch_size
83 | bend = min(bstart + eval_batch_size, num_eval_examples)
84 |
85 | x_batch = mnist.test.images[bstart:bend, :]
86 | y_batch = mnist.test.labels[bstart:bend]
87 |
88 | dict_nat = {model.x_input: x_batch,
89 | model.y_input: y_batch}
90 |
91 | x_batch_adv = attack.perturb(x_batch, y_batch, sess)
92 |
93 | dict_adv = {model.x_input: x_batch_adv,
94 | model.y_input: y_batch}
95 |
96 | cur_corr_nat, cur_xent_nat = sess.run(
97 | [model.num_correct,model.xent],
98 | feed_dict = dict_nat)
99 | cur_corr_adv, cur_xent_adv = sess.run(
100 | [model.num_correct,model.xent],
101 | feed_dict = dict_adv)
102 |
103 | total_xent_nat += cur_xent_nat
104 | total_xent_adv += cur_xent_adv
105 | total_corr_nat += cur_corr_nat
106 | total_corr_adv += cur_corr_adv
107 |
108 | avg_xent_nat = total_xent_nat / num_eval_examples
109 | avg_xent_adv = total_xent_adv / num_eval_examples
110 | acc_nat = total_corr_nat / num_eval_examples
111 | acc_adv = total_corr_adv / num_eval_examples
112 |
113 | summary = tf.Summary(value=[
114 | tf.Summary.Value(tag='xent adv eval', simple_value= avg_xent_adv),
115 | tf.Summary.Value(tag='xent adv', simple_value= avg_xent_adv),
116 | tf.Summary.Value(tag='xent nat', simple_value= avg_xent_nat),
117 | tf.Summary.Value(tag='accuracy adv eval', simple_value= acc_adv),
118 | tf.Summary.Value(tag='accuracy adv', simple_value= acc_adv),
119 | tf.Summary.Value(tag='accuracy nat', simple_value= acc_nat)])
120 | summary_writer.add_summary(summary, global_step.eval(sess))
121 |
122 | print('natural: {:.2f}%'.format(100 * acc_nat))
123 | print('adversarial: {:.2f}%'.format(100 * acc_adv))
124 | print('avg nat loss: {:.4f}'.format(avg_xent_nat))
125 | print('avg adv loss: {:.4f}'.format(avg_xent_adv))
126 |
127 | # Infinite eval loop
128 | while True:
129 | cur_checkpoint = tf.train.latest_checkpoint(model_dir)
130 |
131 | # Case 1: No checkpoint yet
132 | if cur_checkpoint is None:
133 | if not already_seen_state:
134 | print('No checkpoint yet, waiting ...', end='')
135 | already_seen_state = True
136 | else:
137 | print('.', end='')
138 | sys.stdout.flush()
139 | time.sleep(10)
140 | # Case 2: Previously unseen checkpoint
141 | elif cur_checkpoint != last_checkpoint_filename:
142 | print('\nCheckpoint {}, evaluating ... ({})'.format(cur_checkpoint,
143 | datetime.now()))
144 | sys.stdout.flush()
145 | last_checkpoint_filename = cur_checkpoint
146 | already_seen_state = False
147 | evaluate_checkpoint(cur_checkpoint)
148 | # Case 3: Previously evaluated checkpoint
149 | else:
150 | if not already_seen_state:
151 | print('Waiting for the next checkpoint ... ({}) '.format(
152 | datetime.now()),
153 | end='')
154 | already_seen_state = True
155 | else:
156 | print('.', end='')
157 | sys.stdout.flush()
158 | time.sleep(10)
159 |
--------------------------------------------------------------------------------
/fetch_model.py:
--------------------------------------------------------------------------------
1 | """Downloads a model, computes its SHA256 hash and unzips it
2 | at the proper location."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import sys
8 | import zipfile
9 | import hashlib
10 |
11 | if len(sys.argv) != 2 or sys.argv[1] not in ['natural',
12 | 'adv_trained',
13 | 'secret']:
14 | print('Usage: python fetch_model.py [natural, adv_trained, secret]')
15 | sys.exit(1)
16 |
17 | if sys.argv[1] == 'natural':
18 | url = 'https://github.com/MadryLab/mnist_challenge_models/raw/master/natural.zip'
19 | elif sys.argv[1] == 'secret':
20 | url = 'https://github.com/MadryLab/mnist_challenge_models/raw/master/secret.zip'
21 | else: # fetch adv_trained model
22 | url = 'https://github.com/MadryLab/mnist_challenge_models/raw/master/adv_trained.zip'
23 |
24 | fname = url.split('/')[-1] # get the name of the file
25 |
26 | # model download
27 | print('Downloading models')
28 | if sys.version_info >= (3,):
29 | import urllib.request
30 | urllib.request.urlretrieve(url, fname)
31 | else:
32 | import urllib
33 | urllib.urlretrieve(url, fname)
34 |
35 | # computing model hash
36 | sha256 = hashlib.sha256()
37 | with open(fname, 'rb') as f:
38 | data = f.read()
39 | sha256.update(data)
40 | print('SHA256 hash: {}'.format(sha256.hexdigest()))
41 |
42 | # extracting model
43 | print('Extracting model')
44 | with zipfile.ZipFile(fname, 'r') as model_zip:
45 | model_zip.extractall()
46 | print('Extracted model in {}'.format(model_zip.namelist()[0]))
47 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | """
2 | The model is adapted from the tensorflow tutorial:
3 | https://www.tensorflow.org/get_started/mnist/pros
4 | """
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import tensorflow as tf
10 |
11 | class Model(object):
12 | def __init__(self):
13 | self.x_input = tf.placeholder(tf.float32, shape = [None, 784])
14 | self.y_input = tf.placeholder(tf.int64, shape = [None])
15 |
16 | self.x_image = tf.reshape(self.x_input, [-1, 28, 28, 1])
17 |
18 | # first convolutional layer
19 | W_conv1 = self._weight_variable([5,5,1,32])
20 | b_conv1 = self._bias_variable([32])
21 |
22 | h_conv1 = tf.nn.relu(self._conv2d(self.x_image, W_conv1) + b_conv1)
23 | h_pool1 = self._max_pool_2x2(h_conv1)
24 |
25 | # second convolutional layer
26 | W_conv2 = self._weight_variable([5,5,32,64])
27 | b_conv2 = self._bias_variable([64])
28 |
29 | h_conv2 = tf.nn.relu(self._conv2d(h_pool1, W_conv2) + b_conv2)
30 | h_pool2 = self._max_pool_2x2(h_conv2)
31 |
32 | # first fully connected layer
33 | W_fc1 = self._weight_variable([7 * 7 * 64, 1024])
34 | b_fc1 = self._bias_variable([1024])
35 |
36 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
37 | h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
38 |
39 | # output layer
40 | W_fc2 = self._weight_variable([1024,10])
41 | b_fc2 = self._bias_variable([10])
42 |
43 | self.pre_softmax = tf.matmul(h_fc1, W_fc2) + b_fc2
44 |
45 | y_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
46 | labels=self.y_input, logits=self.pre_softmax)
47 |
48 | self.xent = tf.reduce_sum(y_xent)
49 |
50 | self.y_pred = tf.argmax(self.pre_softmax, 1)
51 |
52 | correct_prediction = tf.equal(self.y_pred, self.y_input)
53 |
54 | self.num_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.int64))
55 | self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
56 |
57 | @staticmethod
58 | def _weight_variable(shape):
59 | initial = tf.truncated_normal(shape, stddev=0.1)
60 | return tf.Variable(initial)
61 |
62 | @staticmethod
63 | def _bias_variable(shape):
64 | initial = tf.constant(0.1, shape = shape)
65 | return tf.Variable(initial)
66 |
67 | @staticmethod
68 | def _conv2d(x, W):
69 | return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
70 |
71 | @staticmethod
72 |   def _max_pool_2x2(x):
73 | return tf.nn.max_pool(x,
74 | ksize = [1,2,2,1],
75 | strides=[1,2,2,1],
76 | padding='SAME')
77 |
--------------------------------------------------------------------------------
/model_robustml.py:
--------------------------------------------------------------------------------
1 | import robustml
2 | import tensorflow as tf
3 |
4 | import model
5 |
6 | class Model(robustml.model.Model):
7 | def __init__(self, sess):
8 | self._model = model.Model()
9 |
10 | saver = tf.train.Saver()
11 | checkpoint = tf.train.latest_checkpoint('models/secret')
12 | saver.restore(sess, checkpoint)
13 |
14 | self._sess = sess
15 | self._input = self._model.x_input
16 | self._logits = self._model.pre_softmax
17 | self._predictions = self._model.y_pred
18 | self._dataset = robustml.dataset.MNIST()
19 | self._threat_model = robustml.threat_model.Linf(epsilon=0.3)
20 |
21 | @property
22 | def dataset(self):
23 | return self._dataset
24 |
25 | @property
26 | def threat_model(self):
27 | return self._threat_model
28 |
29 | def classify(self, x):
30 | return self._sess.run(self._predictions,
31 | {self._input: x})[0]
32 |
33 | # expose attack interface
34 |
35 | @property
36 | def input(self):
37 | return self._input
38 |
39 | @property
40 | def logits(self):
41 | return self._logits
42 |
43 | @property
44 | def predictions(self):
45 | return self._predictions
46 |
--------------------------------------------------------------------------------
/pgd_attack.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of attack methods. Running this file as a program will
3 | apply the attack to the model specified by the config file and store
4 | the examples in an .npy file.
5 | """
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | import tensorflow as tf
11 | import numpy as np
12 |
13 |
14 | class LinfPGDAttack:
15 | def __init__(self, model, epsilon, k, a, random_start, loss_func):
16 | """Attack parameter initialization. The attack performs k steps of
17 | size a, while always staying within epsilon from the initial
18 | point."""
19 | self.model = model
20 | self.epsilon = epsilon
21 | self.k = k
22 | self.a = a
23 | self.rand = random_start
24 |
25 | if loss_func == 'xent':
26 | loss = model.xent
27 | elif loss_func == 'cw':
28 | label_mask = tf.one_hot(model.y_input,
29 | 10,
30 | on_value=1.0,
31 | off_value=0.0,
32 | dtype=tf.float32)
33 |       correct_logit = tf.reduce_sum(label_mask * model.pre_softmax, axis=1)  # logit of the true class
34 |       wrong_logit = tf.reduce_max((1-label_mask) * model.pre_softmax
35 |                                   - 1e4*label_mask, axis=1)  # largest logit among the wrong classes
36 |       loss = -tf.nn.relu(correct_logit - wrong_logit + 50)  # negated CW margin loss with confidence 50
37 | else:
38 | print('Unknown loss function. Defaulting to cross-entropy')
39 | loss = model.xent
40 |
41 | self.grad = tf.gradients(loss, model.x_input)[0]
42 |
43 | def perturb(self, x_nat, y, sess):
44 | """Given a set of examples (x_nat, y), returns a set of adversarial
45 | examples within epsilon of x_nat in l_infinity norm."""
46 | if self.rand:
47 | x = x_nat + np.random.uniform(-self.epsilon, self.epsilon, x_nat.shape)
48 | x = np.clip(x, 0, 1) # ensure valid pixel range
49 | else:
50 | x = np.copy(x_nat)
51 |
52 | for i in range(self.k):
53 | grad = sess.run(self.grad, feed_dict={self.model.x_input: x,
54 | self.model.y_input: y})
55 |
56 |       x += self.a * np.sign(grad)  # step of size a along the gradient sign
57 |
58 | x = np.clip(x, x_nat - self.epsilon, x_nat + self.epsilon)
59 | x = np.clip(x, 0, 1) # ensure valid pixel range
60 |
61 | return x
62 |
63 |
64 | if __name__ == '__main__':
65 | import json
66 | import sys
67 | import math
68 |
69 | from tensorflow.examples.tutorials.mnist import input_data
70 |
71 | from model import Model
72 |
73 | with open('config.json') as config_file:
74 | config = json.load(config_file)
75 |
76 | model_file = tf.train.latest_checkpoint(config['model_dir'])
77 | if model_file is None:
78 | print('No model found')
79 | sys.exit()
80 |
81 | model = Model()
82 | attack = LinfPGDAttack(model,
83 | config['epsilon'],
84 | config['k'],
85 | config['a'],
86 | config['random_start'],
87 | config['loss_func'])
88 | saver = tf.train.Saver()
89 |
90 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
91 |
92 | with tf.Session() as sess:
93 | # Restore the checkpoint
94 | saver.restore(sess, model_file)
95 |
96 | # Iterate over the samples batch-by-batch
97 | num_eval_examples = config['num_eval_examples']
98 | eval_batch_size = config['eval_batch_size']
99 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
100 |
101 | x_adv = [] # adv accumulator
102 |
103 | print('Iterating over {} batches'.format(num_batches))
104 |
105 | for ibatch in range(num_batches):
106 | bstart = ibatch * eval_batch_size
107 | bend = min(bstart + eval_batch_size, num_eval_examples)
108 | print('batch size: {}'.format(bend - bstart))
109 |
110 | x_batch = mnist.test.images[bstart:bend, :]
111 | y_batch = mnist.test.labels[bstart:bend]
112 |
113 | x_batch_adv = attack.perturb(x_batch, y_batch, sess)
114 |
115 | x_adv.append(x_batch_adv)
116 |
117 | print('Storing examples')
118 | path = config['store_adv_path']
119 | x_adv = np.concatenate(x_adv, axis=0)
120 | np.save(path, x_adv)
121 | print('Examples stored in {}'.format(path))
122 |
--------------------------------------------------------------------------------
/run_attack.py:
--------------------------------------------------------------------------------
1 | """Evaluates a model against examples from a .npy file as specified
2 | in config.json"""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from datetime import datetime
8 | import json
9 | import math
10 | import os
11 | import sys
12 | import time
13 |
14 | import tensorflow as tf
15 | from tensorflow.examples.tutorials.mnist import input_data
16 |
17 | import numpy as np
18 |
19 | from model import Model
20 |
21 | def run_attack(checkpoint, x_adv, epsilon):
22 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
23 |
24 | model = Model()
25 |
26 | saver = tf.train.Saver()
27 |
28 | num_eval_examples = 10000
29 | eval_batch_size = 64
30 |
31 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
32 | total_corr = 0
33 |
34 | x_nat = mnist.test.images
35 | l_inf = np.amax(np.abs(x_nat - x_adv))
36 |
37 |   if l_inf > epsilon + 0.0001:  # small slack for floating-point round-off
38 | print('maximum perturbation found: {}'.format(l_inf))
39 | print('maximum perturbation allowed: {}'.format(epsilon))
40 | return
41 |
42 | y_pred = [] # label accumulator
43 |
44 | with tf.Session() as sess:
45 | # Restore the checkpoint
46 | saver.restore(sess, checkpoint)
47 |
48 | # Iterate over the samples batch-by-batch
49 | for ibatch in range(num_batches):
50 | bstart = ibatch * eval_batch_size
51 | bend = min(bstart + eval_batch_size, num_eval_examples)
52 |
53 | x_batch = x_adv[bstart:bend, :]
54 | y_batch = mnist.test.labels[bstart:bend]
55 |
56 | dict_adv = {model.x_input: x_batch,
57 | model.y_input: y_batch}
58 | cur_corr, y_pred_batch = sess.run([model.num_correct, model.y_pred],
59 | feed_dict=dict_adv)
60 |
61 | total_corr += cur_corr
62 | y_pred.append(y_pred_batch)
63 |
64 | accuracy = total_corr / num_eval_examples
65 |
66 | print('Accuracy: {:.2f}%'.format(100.0 * accuracy))
67 | y_pred = np.concatenate(y_pred, axis=0)
68 | np.save('pred.npy', y_pred)
69 | print('Output saved at pred.npy')
70 |
71 | if __name__ == '__main__':
72 | import json
73 |
74 | with open('config.json') as config_file:
75 | config = json.load(config_file)
76 |
77 | model_dir = config['model_dir']
78 |
79 | checkpoint = tf.train.latest_checkpoint(model_dir)
80 | x_adv = np.load(config['store_adv_path'])
81 |
82 | if checkpoint is None:
83 | print('No checkpoint found')
84 | elif x_adv.shape != (10000, 784):
85 | print('Invalid shape: expected (10000,784), found {}'.format(x_adv.shape))
86 | elif np.amax(x_adv) > 1.0001 or \
87 | np.amin(x_adv) < -0.0001 or \
88 | np.isnan(np.amax(x_adv)):
89 | print('Invalid pixel range. Expected [0, 1], found [{}, {}]'.format(
90 | np.amin(x_adv),
91 | np.amax(x_adv)))
92 | else:
93 | run_attack(checkpoint, x_adv, config['epsilon'])
94 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Trains a model, saving checkpoints and tensorboard summaries along
2 | the way."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from datetime import datetime
8 | import json
9 | import os
10 | import shutil
11 | from timeit import default_timer as timer
12 |
13 | import tensorflow as tf
14 | import numpy as np
15 | from tensorflow.examples.tutorials.mnist import input_data
16 |
17 | from model import Model
18 | from pgd_attack import LinfPGDAttack
19 |
20 | with open('config.json') as config_file:
21 | config = json.load(config_file)
22 |
23 | # Setting up training parameters
24 | tf.set_random_seed(config['random_seed'])
25 |
26 | max_num_training_steps = config['max_num_training_steps']
27 | num_output_steps = config['num_output_steps']
28 | num_summary_steps = config['num_summary_steps']
29 | num_checkpoint_steps = config['num_checkpoint_steps']
30 |
31 | batch_size = config['training_batch_size']
32 |
33 | # Setting up the data and the model
34 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
35 | global_step = tf.contrib.framework.get_or_create_global_step()
36 | model = Model()
37 |
38 | # Setting up the optimizer
39 | train_step = tf.train.AdamOptimizer(1e-4).minimize(model.xent,
40 | global_step=global_step)
41 |
42 | # Set up adversary
43 | attack = LinfPGDAttack(model,
44 | config['epsilon'],
45 | config['k'],
46 | config['a'],
47 | config['random_start'],
48 | config['loss_func'])
49 |
50 | # Setting up the Tensorboard and checkpoint outputs
51 | model_dir = config['model_dir']
52 | if not os.path.exists(model_dir):
53 | os.makedirs(model_dir)
54 |
55 | # We add accuracy and xent twice so we can easily make three types of
56 | # comparisons in Tensorboard:
57 | # - train vs eval (for a single run)
58 | # - train of different runs
59 | # - eval of different runs
60 |
61 | saver = tf.train.Saver(max_to_keep=3)
62 | tf.summary.scalar('accuracy adv train', model.accuracy)
63 | tf.summary.scalar('accuracy adv', model.accuracy)
64 | tf.summary.scalar('xent adv train', model.xent / batch_size)  # model.xent is a sum, so divide for the per-example loss
65 | tf.summary.scalar('xent adv', model.xent / batch_size)
66 | tf.summary.image('images adv train', model.x_image)
67 | merged_summaries = tf.summary.merge_all()
68 |
69 | shutil.copy('config.json', model_dir)
70 |
71 | with tf.Session() as sess:
72 | # Initialize the summary writer, global variables, and our time counter.
73 | summary_writer = tf.summary.FileWriter(model_dir, sess.graph)
74 | sess.run(tf.global_variables_initializer())
75 | training_time = 0.0
76 |
77 | # Main training loop
78 | for ii in range(max_num_training_steps):
79 | x_batch, y_batch = mnist.train.next_batch(batch_size)
80 |
81 | # Compute Adversarial Perturbations
82 | start = timer()
83 | x_batch_adv = attack.perturb(x_batch, y_batch, sess)
84 | end = timer()
85 | training_time += end - start
86 |
87 | nat_dict = {model.x_input: x_batch,
88 | model.y_input: y_batch}
89 |
90 | adv_dict = {model.x_input: x_batch_adv,
91 | model.y_input: y_batch}
92 |
93 | # Output to stdout
94 | if ii % num_output_steps == 0:
95 | nat_acc = sess.run(model.accuracy, feed_dict=nat_dict)
96 | adv_acc = sess.run(model.accuracy, feed_dict=adv_dict)
97 | print('Step {}: ({})'.format(ii, datetime.now()))
98 | print(' training nat accuracy {:.4}%'.format(nat_acc * 100))
99 | print(' training adv accuracy {:.4}%'.format(adv_acc * 100))
100 | if ii != 0:
101 | print(' {} examples per second'.format(
102 | num_output_steps * batch_size / training_time))
103 | training_time = 0.0
104 | # Tensorboard summaries
105 | if ii % num_summary_steps == 0:
106 | summary = sess.run(merged_summaries, feed_dict=adv_dict)
107 | summary_writer.add_summary(summary, global_step.eval(sess))
108 |
109 | # Write a checkpoint
110 | if ii % num_checkpoint_steps == 0:
111 | saver.save(sess,
112 | os.path.join(model_dir, 'checkpoint'),
113 | global_step=global_step)
114 |
115 | # Actual training step
116 | start = timer()
117 | sess.run(train_step, feed_dict=adv_dict)
118 | end = timer()
119 | training_time += end - start
120 |
--------------------------------------------------------------------------------