├── .gitignore
├── LICENSE
├── README.md
├── attack.py
├── data.py
├── images
│   ├── ablation_study.png
│   ├── adv_examples.png
│   ├── algorithm_rs.png
│   ├── ezgif.com-gif-maker-50-conf-small.gif
│   ├── ezgif.com-gif-maker-img-53-l2-2.gif
│   ├── main_results_imagenet.png
│   ├── main_results_imagenet_l2_commonly_successful.png
│   ├── repository_picture.png
│   ├── sensitivity_wrt_p.png
│   ├── success_rate_curves_full.png
│   ├── table_clp_lsq.png
│   ├── table_madry_mnist_l2.png
│   ├── table_madry_trades_mnist_linf.png
│   └── table_post_avg.png
├── logit_pairing
│   └── models.py
├── madry_cifar10
│   ├── LICENSE
│   ├── README.md
│   ├── cifar10_input.py
│   ├── config.json
│   ├── eval.py
│   ├── fetch_model.py
│   ├── model.py
│   ├── model_robustml.py
│   ├── pgd_attack.py
│   ├── run_attack.py
│   └── train.py
├── madry_mnist
│   ├── LICENSE
│   ├── config.json
│   ├── eval.py
│   ├── fetch_model.py
│   ├── model.py
│   ├── run_attack.py
│   └── train.py
├── metrics
│   ├── 2019-11-10 15:57:14 model=pt_inception dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy
│   ├── 2019-11-10 15:57:14 model=pt_resnet dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy
│   ├── 2019-11-10 15:57:14 model=pt_vgg dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy
│   ├── square_l2_inceptionv3_queries.npy
│   ├── square_l2_resnet50_queries.npy
│   └── square_l2_vgg16_queries.npy
├── models.py
├── post_avg
│   ├── LICENSE.txt
│   ├── PADefense.py
│   ├── README.md
│   ├── attacks.py
│   ├── postAveragedModels.py
│   ├── resnetSmall.py
│   ├── robustml_test_cifar10.py
│   ├── robustml_test_imagenet.py
│   └── visualHelper.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
3 | # data files
4 | MNIST_DATA
5 |
6 | fast_mnist/
7 |
8 | # model files
9 | models
10 | secret.zip
11 |
12 | # compiled python files
13 | *.pyc
14 |
15 | .idea/
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019, Maksym Andriushchenko, Francesco Croce, Nicolas Flammarion, Matthias Hein
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 | * Redistributions of source code must retain the above copyright
7 | notice, this list of conditions and the following disclaimer.
8 | * Redistributions in binary form must reproduce the above copyright
9 | notice, this list of conditions and the following disclaimer in the
10 | documentation and/or other materials provided with the distribution.
11 | * Neither the name of the copyright holder nor the
12 | names of its contributors may be used to endorse or promote products
13 | derived from this software without specific prior written permission.
14 |
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY
19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Square Attack: a query-efficient black-box adversarial attack via random search
2 | **ECCV 2020**
3 |
4 | **Maksym Andriushchenko\*, Francesco Croce\*, Nicolas Flammarion, Matthias Hein**
5 |
6 | **EPFL, University of Tübingen**
7 |
8 | **Paper:** [https://arxiv.org/abs/1912.00049](https://arxiv.org/abs/1912.00049)
9 |
10 | \* denotes equal contribution
11 |
12 |
13 | ## News
14 | + [Jul 2020] The paper is accepted at **ECCV 2020**! Please stop by our virtual poster for the latest insights into black-box adversarial attacks (also check out our recent [Sparse-RS preprint](https://arxiv.org/abs/2006.12834), where we use random search for sparse attacks).
15 | + [Mar 2020] Our attack is now part of [AutoAttack](https://github.com/fra31/auto-attack), an ensemble of attacks used
16 | for automatic (i.e., no hyperparameter tuning needed) robustness evaluation. Table 2 in the [AutoAttack paper](https://arxiv.org/abs/2003.01690)
17 | shows that on at least 6 models our **black-box** attack outperforms gradient-based methods. It is always useful to have a black-box attack at hand to prevent inaccurate robustness claims!
18 | + [Mar 2020] We also achieve the best results on the [TRADES MNIST benchmark](https://github.com/yaodongyu/TRADES)!
19 | + [Jan 2020] The Square Attack achieves the best results on [MadryLab's MNIST challenge](https://github.com/MadryLab/mnist_challenge),
20 | outperforming all white-box attacks! In this case we used 50 random restarts of our attack, each with a query limit of 20000.
21 | + [Nov 2019] The Square Attack breaks the recently proposed defense from "Bandlimiting Neural Networks Against Adversarial Attacks"
22 | ([https://github.com/robust-ml/robust-ml.github.io/issues/15](https://github.com/robust-ml/robust-ml.github.io/issues/15)).
23 |
24 |
25 | ## Abstract
26 | We propose the *Square Attack*, a score-based black-box L2- and Linf-adversarial attack that does not
27 | rely on local gradient information and thus is not affected by gradient masking. Square Attack is based on a randomized
28 | search scheme which selects localized square-shaped updates at random positions so that at each iteration the perturbation
29 | is situated approximately at the boundary of the feasible set. Our method is significantly more query efficient and achieves a higher success rate compared to the state-of-the-art
30 | methods, especially in the untargeted setting. In particular, on ImageNet we improve the average query efficiency in the untargeted setting for various deep networks
31 | by a factor of at least 1.8 and up to 3 compared to the recent state-of-the-art Linf-attack of Al-Dujaili & O’Reilly.
32 | Moreover, although our attack is *black-box*, it can also outperform gradient-based *white-box* attacks
33 | on the standard benchmarks, achieving a new state of the art in terms of the success rate.
34 |
35 | -----
36 |
37 | The code of the Square Attack can be found in `square_attack_linf(...)` and `square_attack_l2(...)` in `attack.py`.\
38 | Below we show adversarial examples generated by our method for Linf and L2 perturbations:
39 |
40 | ![Adversarial examples for Linf and L2 perturbations](images/adv_examples.png)
41 |
45 | (Animated examples: `images/ezgif.com-gif-maker-50-conf-small.gif` and `images/ezgif.com-gif-maker-img-53-l2-2.gif`.)
46 |
47 | ## About the paper
48 | The general algorithm of the attack is extremely simple and relies on random search: we try some update and
49 | accept it only if it improves the loss:
50 | ![The random search algorithm behind the Square Attack](images/algorithm_rs.png)
51 |
52 | The only thing we customize is the sampling distribution P (see the paper for details). The main idea behind the choice
53 | of the sampling distribution is the following:
54 | - We start at the boundary of the feasible set with a good initialization that helps to improve the query efficiency (particularly for the Linf-attack).
55 | - At every iteration, we stay at the boundary of the feasible set by changing square-shaped regions of the image.
56 |
57 | In the paper, we also provide a convergence analysis of a variant of our attack in the non-convex setting, and justify
58 | the main algorithmic choices such as modifying squares and using the same sign of the update.
59 |
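To make this concrete, below is a minimal, self-contained sketch of such a random search loop (a simplification for illustration, not the actual implementation from `attack.py`: the loss oracle, the sampler of square-shaped updates, and the projection onto the feasible set are placeholder arguments):

```
def random_search(loss, x, sample_update, project, n_iters):
    """Generic random search: propose an update and keep it only if the loss decreases."""
    delta = project(sample_update(x))  # initialization on the boundary of the feasible set
    loss_min = loss(x + delta)
    for _ in range(n_iters):
        delta_new = project(delta + sample_update(x))  # in the Square Attack: a random square-shaped update
        loss_new = loss(x + delta_new)
        if loss_new < loss_min:  # accept only updates that improve the loss
            delta, loss_min = delta_new, loss_new
    return x + delta
```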
60 | This simple algorithm is sufficient to significantly outperform much more complex approaches in terms of the success rate
61 | and query efficiency:
62 | ![Main results on ImageNet, Linf attacks](images/main_results_imagenet.png)
63 | ![Main results on ImageNet, L2 attacks](images/main_results_imagenet_l2_commonly_successful.png)
64 |
65 | Here are the complete success rate curves with respect to different numbers of queries. We note that the Square Attack
66 | also outperforms the competing approaches in the low-query regime.
67 | ![Success rate curves for different numbers of queries](images/success_rate_curves_full.png)
68 |
69 | The Square Attack also performs very well on adversarially trained models on MNIST, achieving results competitive with or
70 | better than those of *white-box* attacks, despite the fact that our attack is *black-box*:
71 | ![Results on the adversarially trained MNIST models, Linf](images/table_madry_trades_mnist_linf.png)
72 |
73 | Interestingly, the L2 perturbations for the Linf adversarially trained model are challenging for many attacks, including
74 | white-box PGD, and also other black-box attacks. However, the Square Attack is able to assess the robustness in this
75 | setting much more accurately:
76 | ![L2 results on the Linf adversarially trained MNIST model](images/table_madry_mnist_l2.png)
77 |
78 |
93 |
94 |
95 |
96 | ## Running the code
97 | `attack.py` is the main module that implements the Square Attack, see the command line arguments there.
98 | The main functions which implement the attack are `square_attack_linf()` and `square_attack_l2()`.
99 |
100 | In order to run the untargeted Linf Square Attack on ImageNet models from the PyTorch repository, you need to specify the correct path
101 | to the validation set (see `IMAGENET_PATH` in `data.py`) and then run:
102 | - ``` python attack.py --attack=square_linf --model=pt_vgg --n_ex=1000 --eps=12.75 --p=0.05 --n_iter=10000 ```
103 | - ``` python attack.py --attack=square_linf --model=pt_resnet --n_ex=1000 --eps=12.75 --p=0.05 --n_iter=10000 ```
104 | - ``` python attack.py --attack=square_linf --model=pt_inception --n_ex=1000 --eps=12.75 --p=0.05 --n_iter=10000 ```
105 |
106 | Note that eps=12.75 is divided by 255 internally, so the effective Linf radius is 12.75/255 = 0.05.
107 |
108 | For targeted attacks, one should additionally use the flag `--targeted`, use a lower `p`, and specify more
109 | iterations, e.g. `--n_iter=100000`, since it usually takes more iterations to achieve misclassification into some
110 | particular, randomly chosen class.
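For instance, a targeted Linf run on the ImageNet ResNet-50 could look as follows (the value of `p` here is only an illustration, not a tuned setting):
- ``` python attack.py --attack=square_linf --model=pt_resnet --n_ex=1000 --eps=12.75 --p=0.01 --n_iter=100000 --targeted ```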
111 |
112 | The rest of the models have to be downloaded first (see the instructions below) and can then be evaluated in the following way:
113 |
114 | Post-averaging models:
115 | - ``` python attack.py --attack=square_linf --model=pt_post_avg_cifar10 --n_ex=1000 --eps=8.0 --p=0.3 --n_iter=20000 ```
116 | - ``` python attack.py --attack=square_linf --model=pt_post_avg_imagenet --n_ex=1000 --eps=8.0 --p=0.3 --n_iter=20000 ```
117 |
118 | Clean logit pairing and logit squeezing models:
119 | - ``` python attack.py --attack=square_linf --model=clp_mnist --n_ex=1000 --eps=0.3 --p=0.3 --n_iter=20000 ```
120 | - ``` python attack.py --attack=square_linf --model=lsq_mnist --n_ex=1000 --eps=0.3 --p=0.3 --n_iter=20000 ```
121 | - ``` python attack.py --attack=square_linf --model=clp_cifar10 --n_ex=1000 --eps=16.0 --p=0.3 --n_iter=20000 ```
122 | - ``` python attack.py --attack=square_linf --model=lsq_cifar10 --n_ex=1000 --eps=16.0 --p=0.3 --n_iter=20000 ```
123 |
124 | Adversarially trained model (with only 1 restart; note that the results in the paper are based on 50 restarts):
125 | - ``` python attack.py --attack=square_linf --model=madry_mnist_robust --n_ex=10000 --eps=0.3 --p=0.8 --n_iter=20000 ```
126 |
127 | The L2 Square Attack can be run similarly, but please check the recommended hyperparameters in the paper (Section B of the supplement)
128 | and make sure that you specify the right value of `eps`, taking into account whether the pixels are in [0, 1] or in [0, 255]
129 | for a particular dataset and model.
130 | For example, for the standard ImageNet models, the correct L2 eps to specify is 1275 since after division by 255 it becomes 5.0.
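For instance, an untargeted L2 run on the standard ImageNet ResNet-50 could look as follows (`p=0.1` is the "L2 standard" value from the help text in `attack.py`; treat it as a starting point rather than a tuned setting):
- ``` python attack.py --attack=square_l2 --model=pt_resnet --n_ex=1000 --eps=1275 --p=0.1 --n_iter=10000 ```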
131 |
132 |
133 |
134 | ## Saved statistics
135 | In the folder `metrics`, we provide saved statistics of the attack on 3 models: Inception-v3, ResNet-50, and VGG-16-BN.\
136 | Here are simple examples of how to load the metrics files.
137 |
138 | ### Linf attack
139 | To print the statistics from the last iteration:
```
import numpy as np

metrics = np.load('metrics/2019-11-10 15:57:14 model=pt_resnet dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy')
iteration = np.argmax(metrics[:, -1])  # the row with the maximal time is the last recorded iteration
# column order (see square_attack_linf in attack.py): acc, acc_corr, mean_nq, mean_nq_ae, median_nq_ae, avg_loss, time_total
acc, acc_corr, mean_nq, mean_nq_ae, median_nq_ae, avg_loss, time_total = metrics[iteration]
# the hyperparameters p, n_ex and eps are encoded in the file name
print('[iter {}] acc={:.2%} acc_corr={:.2%} avg#q={:.2f} avg#q_ae={:.2f} med#q_ae={:.2f} ({:.2f}min)'.
      format(iteration + 1, acc, acc_corr, mean_nq, mean_nq_ae, median_nq_ae, time_total / 60))
```
147 |
148 | Then one can also create different plots based on the data contained in `metrics`. For example, one can use `1 - acc_corr`
149 | to plot the success rate of the Square Attack at different numbers of queries, as sketched below.
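A minimal plotting sketch (assuming `matplotlib` is installed; column 1 of `metrics` is `acc_corr`, following the column order above):

```
import numpy as np
import matplotlib.pyplot as plt

metrics = np.load('metrics/2019-11-10 15:57:14 model=pt_resnet dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy')
metrics = metrics[metrics[:, -1] > 0]  # keep only the recorded iterations (trailing rows stay zero if the attack stopped early)
plt.plot(np.arange(1, len(metrics) + 1), 1 - metrics[:, 1])  # success rate = 1 - acc_corr
plt.xlabel('iteration')
plt.ylabel('success rate')
plt.show()
```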
150 |
151 | ### L2 attack
152 | In this case, we provide the number of queries necessary to achieve misclassification (`n_queries[i] = 0` means that image `i` was initially misclassified; `n_queries[i] = 10001` indicates that the attack could not find an adversarial example for image `i`).
153 | To load the metrics and compute the success rate of the Square Attack after `k` queries, you can run:
```
import numpy as np

k = 1000  # example query budget
# the file stores a dict; recent NumPy versions need allow_pickle=True and .item() to recover it
n_queries = np.load('metrics/square_l2_resnet50_queries.npy', allow_pickle=True).item()['n_queries']
success_rate = float(((n_queries > 0) & (n_queries <= k)).sum()) / (n_queries > 0).sum()
```
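To obtain the whole success rate curve rather than a single point, one can sweep `k` (a small sketch under the same assumptions as above):

```
import numpy as np

n_queries = np.load('metrics/square_l2_resnet50_queries.npy', allow_pickle=True).item()['n_queries']
ks = np.arange(1, 10001)  # query budgets up to the 10000-query limit
curve = [float(((n_queries > 0) & (n_queries <= k)).sum()) / (n_queries > 0).sum() for k in ks]
```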
158 |
159 |
160 | ## Models
161 | Note that in order to evaluate other models, one has to first download them and move them to the folders specified in
162 | `model_path_dict` from `models.py`:
163 | - [Clean Logit Pairing on MNIST](https://oc.cs.uni-saarland.de/owncloud/index.php/s/w2yegcfx8mc8kNa)
164 | - [Logit Squeezing on MNIST](https://oc.cs.uni-saarland.de/owncloud/index.php/s/a5ZY72BDCPEtb2S)
165 | - [Clean Logit Pairing on CIFAR-10](https://oc.cs.uni-saarland.de/owncloud/index.php/s/odcd7FgFdbqq6zL)
166 | - [Logit Squeezing on CIFAR-10](https://oc.cs.uni-saarland.de/owncloud/index.php/s/EYnbHDeMbe4mq5M)
167 | - MNIST, Madry adversarial training: run `python madry_mnist/fetch_model.py secret`
168 | - MNIST, TRADES: download the [models](https://drive.google.com/file/d/1scTd9-YO3-5Ul3q5SJuRrTNX__LYLD_M) and see their [repository](https://github.com/yaodongyu/TRADES)
169 | - [Post-averaging defense](https://github.com/YupingLin171/PostAvgDefense/blob/master/trainedModel/resnet110.th): the model can be downloaded directly from the repository
170 |
171 | For the first 4 models, one has to additionally update the paths in the `checkpoint` file in the following way:
172 | ```
173 | model_checkpoint_path: "model.ckpt"
174 | all_model_checkpoint_paths: "model.ckpt"
175 | ```
176 |
177 |
178 |
179 | ## Requirements
180 | - PyTorch 1.0.0
181 | - Tensorflow 1.12.0
182 |
183 |
184 |
185 | ## Contact
186 | Do you have a problem or question regarding the code?
187 | Please don't hesitate to open an issue or contact [Maksym Andriushchenko](https://github.com/max-andr) or
188 | [Francesco Croce](https://github.com/fra31) directly.
189 |
190 |
191 | ## Citation
192 | ```
193 | @inproceedings{ACFH2020square,
194 | title={Square Attack: a query-efficient black-box adversarial attack via random search},
195 | author={Andriushchenko, Maksym and Croce, Francesco and Flammarion, Nicolas and Hein, Matthias},
196 | booktitle={ECCV},
197 | year={2020}
198 | }
199 | ```
200 |
--------------------------------------------------------------------------------
/attack.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 | import data
5 | import models
6 | import os
7 | import utils
8 | from datetime import datetime
9 | np.set_printoptions(precision=5, suppress=True)
10 |
11 |
12 | def p_selection(p_init, it, n_iters):
13 | """ Piece-wise constant schedule for p (the fraction of pixels changed on every iteration). """
14 |     it = int(it / n_iters * 10000)  # rescale the iteration counter to a reference scale of 10000 iterations
15 |
16 | if 10 < it <= 50:
17 | p = p_init / 2
18 | elif 50 < it <= 200:
19 | p = p_init / 4
20 | elif 200 < it <= 500:
21 | p = p_init / 8
22 | elif 500 < it <= 1000:
23 | p = p_init / 16
24 | elif 1000 < it <= 2000:
25 | p = p_init / 32
26 | elif 2000 < it <= 4000:
27 | p = p_init / 64
28 | elif 4000 < it <= 6000:
29 | p = p_init / 128
30 | elif 6000 < it <= 8000:
31 | p = p_init / 256
32 | elif 8000 < it <= 10000:
33 | p = p_init / 512
34 | else:
35 | p = p_init
36 |
37 | return p
38 |
39 |
40 | def pseudo_gaussian_pert_rectangles(x, y):
41 | delta = np.zeros([x, y])
42 | x_c, y_c = x // 2 + 1, y // 2 + 1
43 |
44 | counter2 = [x_c - 1, y_c - 1]
45 | for counter in range(0, max(x_c, y_c)):
46 | delta[max(counter2[0], 0):min(counter2[0] + (2 * counter + 1), x),
47 | max(0, counter2[1]):min(counter2[1] + (2 * counter + 1), y)] += 1.0 / (counter + 1) ** 2
48 |
49 | counter2[0] -= 1
50 | counter2[1] -= 1
51 |
52 | delta /= np.sqrt(np.sum(delta ** 2, keepdims=True))
53 |
54 | return delta
55 |
56 |
57 | def meta_pseudo_gaussian_pert(s):
58 | delta = np.zeros([s, s])
59 | n_subsquares = 2
60 | if n_subsquares == 2:
61 | delta[:s // 2] = pseudo_gaussian_pert_rectangles(s // 2, s)
62 | delta[s // 2:] = pseudo_gaussian_pert_rectangles(s - s // 2, s) * (-1)
63 | delta /= np.sqrt(np.sum(delta ** 2, keepdims=True))
64 | if np.random.rand(1) > 0.5: delta = np.transpose(delta)
65 |
66 | elif n_subsquares == 4:
67 | delta[:s // 2, :s // 2] = pseudo_gaussian_pert_rectangles(s // 2, s // 2) * np.random.choice([-1, 1])
68 | delta[s // 2:, :s // 2] = pseudo_gaussian_pert_rectangles(s - s // 2, s // 2) * np.random.choice([-1, 1])
69 | delta[:s // 2, s // 2:] = pseudo_gaussian_pert_rectangles(s // 2, s - s // 2) * np.random.choice([-1, 1])
70 | delta[s // 2:, s // 2:] = pseudo_gaussian_pert_rectangles(s - s // 2, s - s // 2) * np.random.choice([-1, 1])
71 | delta /= np.sqrt(np.sum(delta ** 2, keepdims=True))
72 |
73 | return delta
74 |
75 |
76 | def square_attack_l2(model, x, y, corr_classified, eps, n_iters, p_init, metrics_path, targeted, loss_type):
77 | """ The L2 square attack """
78 | np.random.seed(0)
79 |
80 | min_val, max_val = 0, 1
81 | c, h, w = x.shape[1:]
82 | n_features = c * h * w
83 | n_ex_total = x.shape[0]
84 | x, y = x[corr_classified], y[corr_classified]
85 |
86 | ### initialization
87 | delta_init = np.zeros(x.shape)
88 | s = h // 5
89 | log.print('Initial square side={} for bumps'.format(s))
90 | sp_init = (h - s * 5) // 2
91 | center_h = sp_init + 0
92 | for counter in range(h // s):
93 | center_w = sp_init + 0
94 | for counter2 in range(w // s):
95 | delta_init[:, :, center_h:center_h + s, center_w:center_w + s] += meta_pseudo_gaussian_pert(s).reshape(
96 | [1, 1, s, s]) * np.random.choice([-1, 1], size=[x.shape[0], c, 1, 1])
97 | center_w += s
98 | center_h += s
99 |
100 | x_best = np.clip(x + delta_init / np.sqrt(np.sum(delta_init ** 2, axis=(1, 2, 3), keepdims=True)) * eps, 0, 1)
101 |
102 | logits = model.predict(x_best)
103 | loss_min = model.loss(y, logits, targeted, loss_type=loss_type)
104 | margin_min = model.loss(y, logits, targeted, loss_type='margin_loss')
105 | n_queries = np.ones(x.shape[0]) # ones because we have already used 1 query
106 |
107 | time_start = time.time()
108 | s_init = int(np.sqrt(p_init * n_features / c))
109 | metrics = np.zeros([n_iters, 7])
110 | for i_iter in range(n_iters):
111 | idx_to_fool = (margin_min > 0.0)
112 |
113 | x_curr, x_best_curr = x[idx_to_fool], x_best[idx_to_fool]
114 | y_curr, margin_min_curr = y[idx_to_fool], margin_min[idx_to_fool]
115 | loss_min_curr = loss_min[idx_to_fool]
116 | delta_curr = x_best_curr - x_curr
117 |
118 | p = p_selection(p_init, i_iter, n_iters)
119 | s = max(int(round(np.sqrt(p * n_features / c))), 3)
120 |
121 | if s % 2 == 0:
122 | s += 1
123 |
124 | s2 = s + 0
125 | ### window_1
126 | center_h = np.random.randint(0, h - s)
127 | center_w = np.random.randint(0, w - s)
128 | new_deltas_mask = np.zeros(x_curr.shape)
129 | new_deltas_mask[:, :, center_h:center_h + s, center_w:center_w + s] = 1.0
130 |
131 | ### window_2
132 | center_h_2 = np.random.randint(0, h - s2)
133 | center_w_2 = np.random.randint(0, w - s2)
134 | new_deltas_mask_2 = np.zeros(x_curr.shape)
135 | new_deltas_mask_2[:, :, center_h_2:center_h_2 + s2, center_w_2:center_w_2 + s2] = 1.0
136 | norms_window_2 = np.sqrt(
137 | np.sum(delta_curr[:, :, center_h_2:center_h_2 + s2, center_w_2:center_w_2 + s2] ** 2, axis=(-2, -1),
138 | keepdims=True))
139 |
140 | ### compute total norm available
141 | curr_norms_window = np.sqrt(
142 | np.sum(((x_best_curr - x_curr) * new_deltas_mask) ** 2, axis=(2, 3), keepdims=True))
143 | curr_norms_image = np.sqrt(np.sum((x_best_curr - x_curr) ** 2, axis=(1, 2, 3), keepdims=True))
144 | mask_2 = np.maximum(new_deltas_mask, new_deltas_mask_2)
145 | norms_windows = np.sqrt(np.sum((delta_curr * mask_2) ** 2, axis=(2, 3), keepdims=True))
146 |
147 | ### create the updates
148 | new_deltas = np.ones([x_curr.shape[0], c, s, s])
149 | new_deltas = new_deltas * meta_pseudo_gaussian_pert(s).reshape([1, 1, s, s])
150 | new_deltas *= np.random.choice([-1, 1], size=[x_curr.shape[0], c, 1, 1])
151 | old_deltas = delta_curr[:, :, center_h:center_h + s, center_w:center_w + s] / (1e-10 + curr_norms_window)
152 | new_deltas += old_deltas
153 | new_deltas = new_deltas / np.sqrt(np.sum(new_deltas ** 2, axis=(2, 3), keepdims=True)) * (
154 | np.maximum(eps ** 2 - curr_norms_image ** 2, 0) / c + norms_windows ** 2) ** 0.5
155 | delta_curr[:, :, center_h_2:center_h_2 + s2, center_w_2:center_w_2 + s2] = 0.0 # set window_2 to 0
156 | delta_curr[:, :, center_h:center_h + s, center_w:center_w + s] = new_deltas + 0 # update window_1
157 |
158 | hps_str = 's={}->{}'.format(s_init, s)
159 | x_new = x_curr + delta_curr / np.sqrt(np.sum(delta_curr ** 2, axis=(1, 2, 3), keepdims=True)) * eps
160 | x_new = np.clip(x_new, min_val, max_val)
161 | curr_norms_image = np.sqrt(np.sum((x_new - x_curr) ** 2, axis=(1, 2, 3), keepdims=True))
162 |
163 | logits = model.predict(x_new)
164 | loss = model.loss(y_curr, logits, targeted, loss_type=loss_type)
165 | margin = model.loss(y_curr, logits, targeted, loss_type='margin_loss')
166 |
167 | idx_improved = loss < loss_min_curr
168 | loss_min[idx_to_fool] = idx_improved * loss + ~idx_improved * loss_min_curr
169 | margin_min[idx_to_fool] = idx_improved * margin + ~idx_improved * margin_min_curr
170 |
171 | idx_improved = np.reshape(idx_improved, [-1, *[1] * len(x.shape[:-1])])
172 | x_best[idx_to_fool] = idx_improved * x_new + ~idx_improved * x_best_curr
173 | n_queries[idx_to_fool] += 1
174 |
175 | acc = (margin_min > 0.0).sum() / n_ex_total
176 | acc_corr = (margin_min > 0.0).mean()
177 | mean_nq, mean_nq_ae, median_nq, median_nq_ae = np.mean(n_queries), np.mean(
178 | n_queries[margin_min <= 0]), np.median(n_queries), np.median(n_queries[margin_min <= 0])
179 |
180 | time_total = time.time() - time_start
181 | log.print(
182 | '{}: acc={:.2%} acc_corr={:.2%} avg#q_ae={:.1f} med#q_ae={:.1f} {}, n_ex={}, {:.0f}s, loss={:.3f}, max_pert={:.1f}, impr={:.0f}'.
183 | format(i_iter + 1, acc, acc_corr, mean_nq_ae, median_nq_ae, hps_str, x.shape[0], time_total,
184 | np.mean(margin_min), np.amax(curr_norms_image), np.sum(idx_improved)))
185 | metrics[i_iter] = [acc, acc_corr, mean_nq, mean_nq_ae, median_nq, margin_min.mean(), time_total]
186 |         if (i_iter <= 500 and i_iter % 20 == 0) or (i_iter > 100 and i_iter % 50 == 0) or i_iter + 1 == n_iters or acc == 0:  # save the metrics periodically, mirroring square_attack_linf
187 | np.save(metrics_path, metrics)
188 | if acc == 0:
189 | curr_norms_image = np.sqrt(np.sum((x_best - x) ** 2, axis=(1, 2, 3), keepdims=True))
190 | print('Maximal norm of the perturbations: {:.5f}'.format(np.amax(curr_norms_image)))
191 | break
192 |
193 | curr_norms_image = np.sqrt(np.sum((x_best - x) ** 2, axis=(1, 2, 3), keepdims=True))
194 | print('Maximal norm of the perturbations: {:.5f}'.format(np.amax(curr_norms_image)))
195 |
196 | return n_queries, x_best
197 |
198 |
199 | def square_attack_linf(model, x, y, corr_classified, eps, n_iters, p_init, metrics_path, targeted, loss_type):
200 | """ The Linf square attack """
201 | np.random.seed(0) # important to leave it here as well
202 | min_val, max_val = 0, 1 if x.max() <= 1 else 255
203 | c, h, w = x.shape[1:]
204 | n_features = c*h*w
205 | n_ex_total = x.shape[0]
206 | x, y = x[corr_classified], y[corr_classified]
207 |
208 | # [c, 1, w], i.e. vertical stripes work best for untargeted attacks
209 | init_delta = np.random.choice([-eps, eps], size=[x.shape[0], c, 1, w])
210 | x_best = np.clip(x + init_delta, min_val, max_val)
211 |
212 | logits = model.predict(x_best)
213 | loss_min = model.loss(y, logits, targeted, loss_type=loss_type)
214 | margin_min = model.loss(y, logits, targeted, loss_type='margin_loss')
215 | n_queries = np.ones(x.shape[0]) # ones because we have already used 1 query
216 |
217 | time_start = time.time()
218 | metrics = np.zeros([n_iters, 7])
219 | for i_iter in range(n_iters - 1):
220 | idx_to_fool = margin_min > 0
221 | x_curr, x_best_curr, y_curr = x[idx_to_fool], x_best[idx_to_fool], y[idx_to_fool]
222 | loss_min_curr, margin_min_curr = loss_min[idx_to_fool], margin_min[idx_to_fool]
223 | deltas = x_best_curr - x_curr
224 |
225 | p = p_selection(p_init, i_iter, n_iters)
226 | for i_img in range(x_best_curr.shape[0]):
227 | s = int(round(np.sqrt(p * n_features / c)))
228 | s = min(max(s, 1), h-1) # at least c x 1 x 1 window is taken and at most c x h-1 x h-1
229 | center_h = np.random.randint(0, h - s)
230 | center_w = np.random.randint(0, w - s)
231 |
232 | x_curr_window = x_curr[i_img, :, center_h:center_h+s, center_w:center_w+s]
233 | x_best_curr_window = x_best_curr[i_img, :, center_h:center_h+s, center_w:center_w+s]
234 | # prevent trying out a delta if it doesn't change x_curr (e.g. an overlapping patch)
235 | while np.sum(np.abs(np.clip(x_curr_window + deltas[i_img, :, center_h:center_h+s, center_w:center_w+s], min_val, max_val) - x_best_curr_window) < 10**-7) == c*s*s:
236 | deltas[i_img, :, center_h:center_h+s, center_w:center_w+s] = np.random.choice([-eps, eps], size=[c, 1, 1])
237 |
238 | x_new = np.clip(x_curr + deltas, min_val, max_val)
239 |
240 | logits = model.predict(x_new)
241 | loss = model.loss(y_curr, logits, targeted, loss_type=loss_type)
242 | margin = model.loss(y_curr, logits, targeted, loss_type='margin_loss')
243 |
244 | idx_improved = loss < loss_min_curr
245 | loss_min[idx_to_fool] = idx_improved * loss + ~idx_improved * loss_min_curr
246 | margin_min[idx_to_fool] = idx_improved * margin + ~idx_improved * margin_min_curr
247 | idx_improved = np.reshape(idx_improved, [-1, *[1]*len(x.shape[:-1])])
248 | x_best[idx_to_fool] = idx_improved * x_new + ~idx_improved * x_best_curr
249 | n_queries[idx_to_fool] += 1
250 |
251 | acc = (margin_min > 0.0).sum() / n_ex_total
252 | acc_corr = (margin_min > 0.0).mean()
253 | mean_nq, mean_nq_ae, median_nq_ae = np.mean(n_queries), np.mean(n_queries[margin_min <= 0]), np.median(n_queries[margin_min <= 0])
254 | avg_margin_min = np.mean(margin_min)
255 | time_total = time.time() - time_start
256 | log.print('{}: acc={:.2%} acc_corr={:.2%} avg#q_ae={:.2f} med#q={:.1f}, avg_margin={:.2f} (n_ex={}, eps={:.3f}, {:.2f}s)'.
257 | format(i_iter+1, acc, acc_corr, mean_nq_ae, median_nq_ae, avg_margin_min, x.shape[0], eps, time_total))
258 |
259 | metrics[i_iter] = [acc, acc_corr, mean_nq, mean_nq_ae, median_nq_ae, margin_min.mean(), time_total]
260 | if (i_iter <= 500 and i_iter % 20 == 0) or (i_iter > 100 and i_iter % 50 == 0) or i_iter + 1 == n_iters or acc == 0:
261 | np.save(metrics_path, metrics)
262 | if acc == 0:
263 | break
264 |
265 | return n_queries, x_best
266 |
267 |
268 | if __name__ == '__main__':
269 | parser = argparse.ArgumentParser(description='Define hyperparameters.')
270 | parser.add_argument('--model', type=str, default='pt_resnet', choices=models.all_model_names, help='Model name.')
271 | parser.add_argument('--attack', type=str, default='square_linf', choices=['square_linf', 'square_l2'], help='Attack.')
272 | parser.add_argument('--exp_folder', type=str, default='exps', help='Experiment folder to store all output.')
273 | parser.add_argument('--gpu', type=str, default='7', help='GPU number. Multiple GPUs are possible for PT models.')
274 | parser.add_argument('--n_ex', type=int, default=10000, help='Number of test ex to test on.')
275 | parser.add_argument('--p', type=float, default=0.05,
276 | help='Probability of changing a coordinate. Note: check the paper for the best values. '
277 | 'Linf standard: 0.05, L2 standard: 0.1. But robust models require higher p.')
278 | parser.add_argument('--eps', type=float, default=0.05, help='Radius of the Lp ball.')
279 | parser.add_argument('--n_iter', type=int, default=10000)
280 | parser.add_argument('--targeted', action='store_true', help='Targeted or untargeted attack.')
281 | args = parser.parse_args()
282 | args.loss = 'margin_loss' if not args.targeted else 'cross_entropy'
283 |
284 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
285 | dataset = 'mnist' if 'mnist' in args.model else 'cifar10' if 'cifar10' in args.model else 'imagenet'
286 | timestamp = str(datetime.now())[:-7]
287 | hps_str = '{} model={} dataset={} attack={} n_ex={} eps={} p={} n_iter={}'.format(
288 | timestamp, args.model, dataset, args.attack, args.n_ex, args.eps, args.p, args.n_iter)
289 | args.eps = args.eps / 255.0 if dataset == 'imagenet' else args.eps # for mnist and cifar10 we leave as it is
290 | batch_size = data.bs_dict[dataset]
291 | model_type = 'pt' if 'pt_' in args.model else 'tf'
292 | n_cls = 1000 if dataset == 'imagenet' else 10
293 | gpu_memory = 0.5 if dataset == 'mnist' and args.n_ex > 1000 else 0.15 if dataset == 'mnist' else 0.99
294 |
295 | log_path = '{}/{}.log'.format(args.exp_folder, hps_str)
296 | metrics_path = '{}/{}.metrics'.format(args.exp_folder, hps_str)
297 |
298 | log = utils.Logger(log_path)
299 | log.print('All hps: {}'.format(hps_str))
300 |
301 | if args.model != 'pt_inception':
302 | x_test, y_test = data.datasets_dict[dataset](args.n_ex)
303 | else: # exception for inception net on imagenet -- 299x299 images instead of 224x224
304 | x_test, y_test = data.datasets_dict[dataset](args.n_ex, size=299)
305 | x_test, y_test = x_test[:args.n_ex], y_test[:args.n_ex]
306 |
307 | if args.model == 'pt_post_avg_cifar10':
308 | x_test /= 255.0
309 | args.eps = args.eps / 255.0
310 |
311 | models_class_dict = {'tf': models.ModelTF, 'pt': models.ModelPT}
312 | model = models_class_dict[model_type](args.model, batch_size, gpu_memory)
313 |
314 | logits_clean = model.predict(x_test)
315 | corr_classified = logits_clean.argmax(1) == y_test
316 | # important to check that the model was restored correctly and the clean accuracy is high
317 | log.print('Clean accuracy: {:.2%}'.format(np.mean(corr_classified)))
318 |
319 | square_attack = square_attack_linf if args.attack == 'square_linf' else square_attack_l2
320 | y_target = utils.random_classes_except_current(y_test, n_cls) if args.targeted else y_test
321 | y_target_onehot = utils.dense_to_onehot(y_target, n_cls=n_cls)
322 | # Note: we count the queries only across correctly classified images
323 | n_queries, x_adv = square_attack(model, x_test, y_target_onehot, corr_classified, args.eps, args.n_iter,
324 | args.p, metrics_path, args.targeted, args.loss)
325 |
326 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torchvision import transforms
4 | from torchvision.datasets import ImageFolder
5 | from torch.utils.data import DataLoader
6 |
7 |
8 | def load_mnist(n_ex):
9 | from tensorflow.keras.datasets import mnist as mnist_keras
10 |
11 | x_test, y_test = mnist_keras.load_data()[1]
12 | x_test = x_test.astype(np.float64) / 255.0
13 | x_test = x_test[:, None, :, :]
14 |
15 | return x_test[:n_ex], y_test[:n_ex]
16 |
17 |
18 | def load_cifar10(n_ex):
19 | from madry_cifar10.cifar10_input import CIFAR10Data
20 | cifar = CIFAR10Data('madry_cifar10/cifar10_data')
21 | x_test, y_test = cifar.eval_data.xs.astype(np.float32), cifar.eval_data.ys
22 | x_test = np.transpose(x_test, axes=[0, 3, 1, 2])
23 | return x_test[:n_ex], y_test[:n_ex]
24 |
25 |
26 | def load_imagenet(n_ex, size=224):
27 | IMAGENET_SL = size
28 |     IMAGENET_PATH = "/scratch/maksym/imagenet/val_orig"  # change this to your local path to the ImageNet validation set (see README)
29 | imagenet = ImageFolder(IMAGENET_PATH,
30 | transforms.Compose([
31 | transforms.Resize(IMAGENET_SL),
32 | transforms.CenterCrop(IMAGENET_SL),
33 | transforms.ToTensor()
34 | ]))
35 | torch.manual_seed(0)
36 |
37 | imagenet_loader = DataLoader(imagenet, batch_size=n_ex, shuffle=True, num_workers=1)
38 | x_test, y_test = next(iter(imagenet_loader))
39 |
40 | return np.array(x_test, dtype=np.float32), np.array(y_test)
41 |
42 |
43 | datasets_dict = {'mnist': load_mnist,
44 | 'cifar10': load_cifar10,
45 | 'imagenet': load_imagenet,
46 | }
47 | bs_dict = {'mnist': 10000,
48 | 'cifar10': 4096, # 4096 is the maximum that fits
49 | 'imagenet': 100,
50 | }
51 |
--------------------------------------------------------------------------------
/images/ablation_study.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/ablation_study.png
--------------------------------------------------------------------------------
/images/adv_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/adv_examples.png
--------------------------------------------------------------------------------
/images/algorithm_rs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/algorithm_rs.png
--------------------------------------------------------------------------------
/images/ezgif.com-gif-maker-50-conf-small.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/ezgif.com-gif-maker-50-conf-small.gif
--------------------------------------------------------------------------------
/images/ezgif.com-gif-maker-img-53-l2-2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/ezgif.com-gif-maker-img-53-l2-2.gif
--------------------------------------------------------------------------------
/images/main_results_imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/main_results_imagenet.png
--------------------------------------------------------------------------------
/images/main_results_imagenet_l2_commonly_successful.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/main_results_imagenet_l2_commonly_successful.png
--------------------------------------------------------------------------------
/images/repository_picture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/repository_picture.png
--------------------------------------------------------------------------------
/images/sensitivity_wrt_p.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/sensitivity_wrt_p.png
--------------------------------------------------------------------------------
/images/success_rate_curves_full.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/success_rate_curves_full.png
--------------------------------------------------------------------------------
/images/table_clp_lsq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/table_clp_lsq.png
--------------------------------------------------------------------------------
/images/table_madry_mnist_l2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/table_madry_mnist_l2.png
--------------------------------------------------------------------------------
/images/table_madry_trades_mnist_linf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/table_madry_trades_mnist_linf.png
--------------------------------------------------------------------------------
/images/table_post_avg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/table_post_avg.png
--------------------------------------------------------------------------------
/logit_pairing/models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from collections import OrderedDict
3 |
4 |
5 | # -------------------------------------------------------------
6 | # Models
7 | # -------------------------------------------------------------
8 |
9 | class LeNet:
10 | def __init__(self):
11 | super().__init__()
12 | self.nb_classes = 10
13 |         self.input_shape = [28, 28, 1]  # MNIST inputs are single-channel (matches filter_in=1 of conv1)
14 | self.weights_init = 'He'
15 | self.filters = 32 # 32 is the default here for all our pre-trained models
16 | self.is_training = False
17 | self.bn = False
18 | self.bn_scale = False
19 | self.bn_bias = False
20 | self.parameters = 0
21 |
22 | # Create variables
23 | with tf.variable_scope('conv1_vars'):
24 | self.W_conv1 = create_conv2d_weights(kernel_size=3, filter_in=1, filter_out=self.filters,
25 | init=self.weights_init)
26 | self.parameters += 3 * 3 * self.input_shape[-1] * self.filters
27 |
28 | self.b_conv1 = create_biases(size=self.filters)
29 | self.parameters += self.filters
30 |
31 | with tf.variable_scope('conv2_vars'):
32 | self.W_conv2 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters * 2,
33 | init=self.weights_init)
34 | self.parameters += 3 * 3 * self.filters * (self.filters * 2)
35 |
36 | self.b_conv2 = create_biases(size=self.filters * 2)
37 | self.parameters += self.filters * 2
38 |
39 | with tf.variable_scope('fc1_vars'):
40 | self.W_fc1 = create_weights(units_in=7 * 7 * self.filters * 2, units_out=1024, init=self.weights_init)
41 | self.parameters += (7 * 7 * self.filters * 2) * 1024
42 |
43 | self.b_fc1 = create_biases(size=1024)
44 | self.parameters += 1024
45 |
46 | with tf.variable_scope('fc2_vars'):
47 | self.W_fc2 = create_weights(units_in=1024, units_out=self.nb_classes, init=self.weights_init)
48 | self.parameters += 1024 * self.nb_classes
49 |
50 | self.b_fc2 = create_biases(size=self.nb_classes)
51 | self.parameters += self.nb_classes
52 |
53 | self.x_input = tf.placeholder(tf.float32, shape=[None, 784])
54 | self.y_input = tf.placeholder(tf.int64, shape=[None])
55 |
56 | x = tf.reshape(self.x_input, [-1, 28, 28, 1])
57 |
58 | with tf.name_scope('conv-block-1'):
59 | conv1 = conv_layer(x, self.is_training, self.W_conv1, stride=1, padding='SAME', bn=self.bn,
60 | bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv1', bias=self.b_conv1)
61 |
62 | with tf.name_scope('max-pool-1'):
63 | conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
64 |
65 | with tf.name_scope('conv-block-2'):
66 | conv2 = conv_layer(conv1, self.is_training, self.W_conv2, stride=1, padding='SAME', bn=self.bn,
67 | bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv2', bias=self.b_conv2)
68 |
69 | with tf.name_scope('max-pool-2'):
70 | conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
71 |
72 | with tf.name_scope('fc-block'):
73 | conv2 = tf.layers.flatten(conv2)
74 | fc1 = fc_layer(conv2, self.is_training, self.W_fc1, bn=self.bn, bn_scale=self.bn_scale,
75 | bn_bias=self.bn_bias, name='fc1', non_linearity='relu', bias=self.b_fc1)
76 |
77 | logits = fc_layer(fc1, self.is_training, self.W_fc2, bn=self.bn, bn_scale=self.bn_scale,
78 | bn_bias=self.bn_bias, name='fc2', non_linearity='linear', bias=self.b_fc2)
79 |
80 | self.summaries = False
81 | self.logits = logits
82 |
83 |
84 | class ResNet20_v2:
85 | def __init__(self):
86 | super().__init__()
87 | self.nb_classes = 10
88 | self.input_shape = [32, 32, 3]
89 | self.weights_init = 'He'
90 | self.filters = 64 # 64 is the default here for all our pre-trained models
91 | self.is_training = False
92 | self.bn = True
93 | self.bn_scale = True
94 | self.bn_bias = True
95 | self.parameters = 0
96 |
97 | # Create variables
98 | with tf.variable_scope('conv1_vars'):
99 | self.W_conv1 = create_conv2d_weights(kernel_size=3, filter_in=self.input_shape[-1], filter_out=self.filters,
100 | init=self.weights_init)
101 | self.parameters += 3 * 3 * self.input_shape[-1] * self.filters
102 |
103 | with tf.variable_scope('conv2_vars'):
104 | self.W_conv2 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters,
105 | init=self.weights_init)
106 | self.parameters += 3 * 3 * self.filters * self.filters
107 |
108 | with tf.variable_scope('conv3_vars'):
109 | self.W_conv3 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters,
110 | init=self.weights_init)
111 | self.parameters += 3 * 3 * self.filters * self.filters
112 |
113 | with tf.variable_scope('conv4_vars'):
114 | self.W_conv4 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters,
115 | init=self.weights_init)
116 | self.parameters += 3 * 3 * self.filters * self.filters
117 |
118 | with tf.variable_scope('conv5_vars'):
119 | self.W_conv5 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters,
120 | init=self.weights_init)
121 | self.parameters += 3 * 3 * self.filters * self.filters
122 |
123 | with tf.variable_scope('conv6_vars'):
124 | self.W_conv6 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters,
125 | init=self.weights_init)
126 | self.parameters += 3 * 3 * self.filters * self.filters
127 |
128 | with tf.variable_scope('conv7_vars'):
129 | self.W_conv7 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters,
130 | init=self.weights_init)
131 | self.parameters += 3 * 3 * self.filters * self.filters
132 |
133 | with tf.variable_scope('conv8_vars'):
134 | self.W_conv8 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters * 2,
135 | init=self.weights_init)
136 | self.parameters += 3 * 3 * self.filters * (self.filters * 2)
137 |
138 | with tf.variable_scope('conv9_vars'):
139 | self.W_conv9 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 2, filter_out=self.filters * 2,
140 | init=self.weights_init)
141 | self.parameters += 3 * 3 * (self.filters * 2) * (self.filters * 2)
142 |
143 | with tf.variable_scope('conv10_vars'):
144 | self.W_conv10 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 2,
145 | filter_out=self.filters * 2, init=self.weights_init)
146 | self.parameters += 3 * 3 * (self.filters * 2) * (self.filters * 2)
147 |
148 | with tf.variable_scope('conv11_vars'):
149 | self.W_conv11 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 2,
150 | filter_out=self.filters * 2, init=self.weights_init)
151 | self.parameters += 3 * 3 * (self.filters * 2) * (self.filters * 2)
152 |
153 | with tf.variable_scope('conv12_vars'):
154 | self.W_conv12 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 2,
155 | filter_out=self.filters * 2, init=self.weights_init)
156 | self.parameters += 3 * 3 * (self.filters * 2) * (self.filters * 2)
157 |
158 | with tf.variable_scope('conv13_vars'):
159 | self.W_conv13 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 2,
160 | filter_out=self.filters * 2, init=self.weights_init)
161 | self.parameters += 3 * 3 * (self.filters * 2) * (self.filters * 2)
162 |
163 | with tf.variable_scope('conv14_vars'):
164 | self.W_conv14 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 2,
165 | filter_out=self.filters * 4, init=self.weights_init)
166 | self.parameters += 3 * 3 * (self.filters * 2) * (self.filters * 4)
167 |
168 | with tf.variable_scope('conv15_vars'):
169 | self.W_conv15 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 4,
170 | filter_out=self.filters * 4, init=self.weights_init)
171 | self.parameters += 3 * 3 * (self.filters * 4) * (self.filters * 4)
172 |
173 | with tf.variable_scope('conv16_vars'):
174 | self.W_conv16 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 4,
175 | filter_out=self.filters * 4, init=self.weights_init)
176 | self.parameters += 3 * 3 * (self.filters * 4) * (self.filters * 4)
177 |
178 | with tf.variable_scope('conv17_vars'):
179 | self.W_conv17 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 4,
180 | filter_out=self.filters * 4, init=self.weights_init)
181 | self.parameters += 3 * 3 * (self.filters * 4) * (self.filters * 4)
182 |
183 | with tf.variable_scope('conv18_vars'):
184 | self.W_conv18 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 4,
185 | filter_out=self.filters * 4, init=self.weights_init)
186 | self.parameters += 3 * 3 * (self.filters * 4) * (self.filters * 4)
187 |
188 | with tf.variable_scope('conv19_vars'):
189 | self.W_conv19 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 4,
190 | filter_out=self.filters * 4, init=self.weights_init)
191 | self.parameters += 3 * 3 * (self.filters * 4) * (self.filters * 4)
192 |
193 | with tf.variable_scope('fc1_vars'):
194 | self.W_fc1 = create_weights(units_in=self.filters * 4, units_out=self.nb_classes, init=self.weights_init)
195 | self.parameters += (self.filters * 4) * self.nb_classes
196 |
197 | self.b_fc1 = create_biases(size=self.nb_classes)
198 | self.parameters += self.nb_classes
199 |
200 | with tf.variable_scope('scip1_vars'):
201 | self.W_scip1 = create_conv2d_weights(kernel_size=1, filter_in=self.filters, filter_out=self.filters,
202 | init=self.weights_init)
203 | self.parameters += 1 * 1 * self.filters * self.filters
204 |
205 | with tf.variable_scope('scip2_vars'):
206 | self.W_scip2 = create_conv2d_weights(kernel_size=1, filter_in=self.filters, filter_out=self.filters * 2,
207 | init=self.weights_init)
208 | self.parameters += 1 * 1 * self.filters * (self.filters * 2)
209 |
210 | with tf.variable_scope('scip3_vars'):
211 | self.W_scip3 = create_conv2d_weights(kernel_size=1, filter_in=self.filters * 2, filter_out=self.filters * 4,
212 | init=self.weights_init)
213 | self.parameters += 1 * 1 * (self.filters * 2) * (self.filters * 4)
214 |
215 | self.x_input = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
216 | self.y_input = tf.placeholder(tf.int64, shape=None)
217 | x = self.x_input / 255.0
218 |
219 | # Specify forward pass
220 | with tf.name_scope('input-block'):
221 | conv1 = conv_layer(x, self.is_training, self.W_conv1, stride=1, padding='SAME',
222 | bn=False, bn_scale=self.bn_scale, bn_bias=self.bn_bias,
223 | name='conv1',
224 | non_linearity='linear')
225 |
226 | with tf.name_scope('conv-block-1'):
227 | conv2 = pre_act_conv_layer(conv1, self.is_training, self.W_conv2, stride=1, padding='SAME',
228 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv2')
229 |
230 | conv3 = pre_act_conv_layer(conv2, self.is_training, self.W_conv3, stride=1, padding='SAME',
231 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv3')
232 |
233 | # skip connection
234 | conv3 += tf.nn.conv2d(conv1, self.W_scip1, strides=[1, 1, 1, 1], padding='SAME', name='conv-skip1')
235 |
236 | conv4 = pre_act_conv_layer(conv3, self.is_training, self.W_conv4, stride=1, padding='SAME',
237 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv4')
238 |
239 | conv5 = pre_act_conv_layer(conv4, self.is_training, self.W_conv5, stride=1, padding='SAME',
240 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv5')
241 |
242 | # skip connection
243 | conv5 += conv3
244 |
245 | conv6 = pre_act_conv_layer(conv5, self.is_training, self.W_conv6, stride=1, padding='SAME',
246 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv6')
247 |
248 | conv7 = pre_act_conv_layer(conv6, self.is_training, self.W_conv7, stride=1, padding='SAME',
249 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv7')
250 |
251 | # skip connection
252 | conv7 += conv5
253 |
254 | with tf.name_scope('conv-block-2'):
255 | conv8 = pre_act_conv_layer(conv7, self.is_training, self.W_conv8, stride=2, padding='SAME',
256 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv8')
257 |
258 | conv9 = pre_act_conv_layer(conv8, self.is_training, self.W_conv9, stride=1, padding='SAME',
259 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv9')
260 |
261 | # skip connection
262 | conv9 += tf.nn.conv2d(conv7, self.W_scip2, strides=[1, 2, 2, 1], padding='SAME', name='conv-skip2')
263 |
264 | conv10 = pre_act_conv_layer(conv9, self.is_training, self.W_conv10, stride=1, padding='SAME',
265 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv10')
266 |
267 | conv11 = pre_act_conv_layer(conv10, self.is_training, self.W_conv11, stride=1, padding='SAME',
268 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv11')
269 |
270 | # skip connection
271 | conv11 += conv9
272 |
273 | conv12 = pre_act_conv_layer(conv11, self.is_training, self.W_conv12, stride=1, padding='SAME',
274 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv12')
275 |
276 | conv13 = pre_act_conv_layer(conv12, self.is_training, self.W_conv13, stride=1, padding='SAME',
277 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv13')
278 |
279 | # skip connection
280 | conv13 += conv11
281 |
282 | with tf.name_scope('conv-block-3'):
283 | conv14 = pre_act_conv_layer(conv13, self.is_training, self.W_conv14, stride=2, padding='SAME',
284 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv14')
285 |
286 | conv15 = pre_act_conv_layer(conv14, self.is_training, self.W_conv15, stride=1, padding='SAME',
287 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv15')
288 |
289 | # skip connection
290 | conv15 += tf.nn.conv2d(conv13, self.W_scip3, strides=[1, 2, 2, 1], padding='SAME', name='conv-skip3')
291 |
292 | conv16 = pre_act_conv_layer(conv15, self.is_training, self.W_conv16, stride=1, padding='SAME',
293 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv16')
294 |
295 | conv17 = pre_act_conv_layer(conv16, self.is_training, self.W_conv17, stride=1, padding='SAME',
296 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv17')
297 |
298 | # skip connection
299 | conv17 += conv15
300 |
301 | conv18 = pre_act_conv_layer(conv17, self.is_training, self.W_conv18, stride=1, padding='SAME',
302 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv18')
303 |
304 | conv19 = pre_act_conv_layer(conv18, self.is_training, self.W_conv19, stride=1, padding='SAME',
305 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv19')
306 |
307 | # skip connection
308 | conv19 += conv17
309 | conv19 = nonlinearity(conv19)
310 |
311 | with tf.name_scope('output-block'):
312 | with tf.name_scope('global-average-pooling'):
313 | fc1 = tf.reduce_mean(conv19, axis=[1, 2])
314 |
315 | logits = fc_layer(fc1, self.is_training, self.W_fc1, bn=False, bn_scale=self.bn_scale, bn_bias=self.bn_bias,
316 | name='fc1',
317 | non_linearity='linear', bias=self.b_fc1)
318 |
319 | self.summaries = False
320 | self.logits = logits
321 |
322 |
323 | # -------------------------------------------------------------
324 | # Helpers
325 | # -------------------------------------------------------------
326 |
327 | def create_weights(units_in, units_out, init='Xavier', seed=None):
328 | if init == 'Xavier':
329 | initializer = tf.variance_scaling_initializer(scale=1.0,
330 | mode='fan_in',
331 | distribution='normal',
332 | seed=None,
333 | dtype=tf.float32)
334 | elif init == 'He':
335 | initializer = tf.variance_scaling_initializer(scale=2.0,
336 | mode='fan_in',
337 | distribution='normal',
338 | seed=None,
339 | dtype=tf.float32)
340 | else:
341 | initializer = tf.truncated_normal_initializer(mean=0.0, stddev=0.01, seed=seed, dtype=tf.float32)
342 |
343 | weights = tf.get_variable(name='weights',
344 | shape=[units_in, units_out],
345 | dtype=tf.float32,
346 | initializer=initializer)
347 | return weights
348 |
349 |
350 | def create_conv2d_weights(kernel_size, filter_in, filter_out, init='Xavier', seed=None):
351 | if init == 'Xavier':
352 | initializer = tf.variance_scaling_initializer(scale=1.0,
353 | mode='fan_in',
354 | distribution='normal',
355 | seed=None,
356 | dtype=tf.float32)
357 | elif init == 'He':
358 | initializer = tf.variance_scaling_initializer(scale=2.0,
359 | mode='fan_in',
360 | distribution='normal',
361 | seed=None,
362 | dtype=tf.float32)
363 | else:
364 | initializer = tf.truncated_normal_initializer(mean=0.0, stddev=0.01, seed=seed, dtype=tf.float32)
365 |
366 | weights = tf.get_variable(name='weights',
367 | shape=[kernel_size, kernel_size, filter_in, filter_out],
368 | dtype=tf.float32,
369 | initializer=initializer)
370 | return weights
371 |
372 |
373 | def create_biases(size):
374 | return tf.get_variable(name='biases', shape=[size], dtype=tf.float32, initializer=tf.zeros_initializer())
375 |
376 |
377 | def batch_norm(x, is_training, scale, bias, name, reuse):
378 | return tf.contrib.layers.batch_norm(
379 | x,
380 | decay=0.999,
381 | center=bias,
382 | scale=scale,
383 | epsilon=0.001,
384 | param_initializers=None,
385 | updates_collections=tf.GraphKeys.UPDATE_OPS,
386 | is_training=is_training,
387 | reuse=reuse,
388 | variables_collections=['batch-norm'],
389 | outputs_collections=None,
390 | trainable=True,
391 | batch_weights=None,
392 | fused=False,
393 | zero_debias_moving_mean=False,
394 | scope=name,
395 | renorm=False,
396 | renorm_clipping=None,
397 | renorm_decay=0.99
398 | )
399 |
400 |
401 | def nonlinearity(x, non_linearity='relu'):
402 | if non_linearity == 'linear':
403 | return tf.identity(x)
404 | if non_linearity == 'sigmoid':
405 | return tf.nn.sigmoid(x)
406 | if non_linearity == 'tanh':
407 | return tf.nn.tanh(x)
408 | if non_linearity == 'relu':
409 | return tf.nn.relu(x)
410 | if non_linearity == 'elu':
411 | return tf.nn.elu(x)
412 | if non_linearity == 'selu':
413 | return tf.nn.selu(x)
414 |
415 |
416 | def conv_layer(inputs, is_training, weights, stride, padding, bn, bn_scale, bn_bias, name,
417 | non_linearity='relu', bias=None):
418 | if bias is not None:
419 | inputs = tf.nn.conv2d(inputs, weights, strides=[1, stride, stride, 1], padding=padding) + bias
420 | else:
421 | inputs = tf.nn.conv2d(inputs, weights, strides=[1, stride, stride, 1], padding=padding)
422 |
423 | if bn:
424 | inputs = batch_norm(inputs, is_training=is_training, scale=bn_scale, bias=bn_bias,
425 | name='batch-norm-{:s}'.format(name),
426 | reuse=tf.AUTO_REUSE)
427 |
428 | activations = nonlinearity(inputs, non_linearity=non_linearity)
429 |
430 | return activations
431 |
432 |
433 | def pre_act_conv_layer(inputs, is_training, weights, stride, padding, bn, bn_scale, bn_bias, name,
434 | non_linearity='relu'):
435 | if bn:
436 | inputs = batch_norm(inputs, is_training=is_training, scale=bn_scale, bias=bn_bias,
437 | name='batch-norm-{:s}'.format(name),
438 | reuse=tf.AUTO_REUSE)
439 |
440 | activations = nonlinearity(inputs, non_linearity=non_linearity)
441 |
442 | outputs = tf.nn.conv2d(activations, weights, strides=[1, stride, stride, 1], padding=padding)
443 |
444 | return outputs
445 |
446 |
447 | def fc_layer(inputs, is_training, weights, bn, bn_scale, bn_bias, name, non_linearity='relu', bias=None):
448 | if bias is not None:
449 | inputs = tf.matmul(inputs, weights) + bias
450 | else:
451 | inputs = tf.matmul(inputs, weights)
452 |
453 | if bn:
454 | inputs = batch_norm(inputs, is_training=is_training, scale=bn_scale, bias=bn_bias,
455 | name='batch-norm-{:s}'.format(name),
456 | reuse=tf.AUTO_REUSE)
457 |
458 | activations = nonlinearity(inputs, non_linearity)
459 |
460 | return activations
461 |
--------------------------------------------------------------------------------
/madry_cifar10/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, and Adrian Vladu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/madry_cifar10/README.md:
--------------------------------------------------------------------------------
1 | # CIFAR10 Adversarial Examples Challenge
2 |
3 | Recently, there has been much progress on adversarial *attacks* against neural networks, such as the [cleverhans](https://github.com/tensorflow/cleverhans) library and the code by [Carlini and Wagner](https://github.com/carlini/nn_robust_attacks).
4 | We now complement these advances by proposing an *attack challenge* for the
5 | [CIFAR10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) which follows the
6 | format of [our earlier MNIST challenge](https://github.com/MadryLab/mnist_challenge).
7 | We have trained a robust network, and the objective is to find a set of adversarial examples on which this network achieves only a low accuracy.
8 | To train an adversarially-robust network, we followed the approach from our recent paper:
9 |
10 | **Towards Deep Learning Models Resistant to Adversarial Attacks**
11 | *Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, Adrian Vladu*
12 | https://arxiv.org/abs/1706.06083.
13 |
14 | As part of the challenge, we release both the training code and the network architecture, but keep the network weights secret.
15 | We invite any researcher to submit attacks against our model (see the detailed instructions below).
16 | We will maintain a leaderboard of the best attacks for the next two months and then publish our secret network weights.
17 |
18 | Analogously to our MNIST challenge, the goal of this challenge is to clarify the state-of-the-art for adversarial robustness on CIFAR10. Moreover, we hope that future work on defense mechanisms will adopt a similar challenge format in order to improve reproducibility and empirical comparisons.
19 |
20 | **Update 2017-12-10**: We released our secret model. You can download it by running `python fetch_model.py secret`. As of Dec 10 we are no longer accepting black-box challenge submissions. We have set up a leaderboard for white-box attacks on the (now released) secret model. The submission format is the same as before. We plan to continue evaluating submissions and maintaining the leaderboard for the foreseeable future.
21 |
22 | ## Black-Box Leaderboard (Original Challenge)
23 |
24 | | Attack | Submitted by | Accuracy | Submission Date |
25 | | -------------------------------------- | ------------- | -------- | ---- |
26 | | PGD on the cross-entropy loss for the adversarially trained public network | (initial entry) | **63.39%** | Jul 12, 2017 |
27 | | PGD on the [CW](https://github.com/carlini/nn_robust_attacks) loss for the adversarially trained public network | (initial entry) | 64.38% | Jul 12, 2017 |
28 | | FGSM on the [CW](https://github.com/carlini/nn_robust_attacks) loss for the adversarially trained public network | (initial entry) | 67.25% | Jul 12, 2017 |
29 | | FGSM on the [CW](https://github.com/carlini/nn_robust_attacks) loss for the naturally trained public network | (initial entry) | 85.23% | Jul 12, 2017 |
30 |
31 | ## White-Box Leaderboard
32 |
33 | | Attack | Submitted by | Accuracy | Submission Date |
34 | | -------------------------------------- | ------------- | -------- | ---- |
35 | | [FAB: Fast Adaptive Boundary Attack](https://github.com/fra31/fab-attack) | Francesco Croce | **44.51%** | Jun 7, 2019 |
36 | | [Distributionally Adversarial Attack](https://github.com/tianzheng4/Distributionally-Adversarial-Attack) | Tianhang Zheng | 44.71% | Aug 21, 2018 |
37 | | 20-step PGD on the cross-entropy loss with 10 random restarts | Tianhang Zheng | 45.21% | Aug 24, 2018 |
38 | | 20-step PGD on the cross-entropy loss | (initial entry) | 47.04% | Dec 10, 2017 |
39 | | 20-step PGD on the [CW](https://github.com/carlini/nn_robust_attacks) loss | (initial entry) | 47.76% | Dec 10, 2017 |
40 | | FGSM on the [CW](https://github.com/carlini/nn_robust_attacks) loss | (initial entry) | 54.92% | Dec 10, 2017 |
41 | | FGSM on the cross-entropy loss | (initial entry) | 55.55% | Dec 10, 2017 |
42 |
43 |
44 |
45 |
46 |
47 | ## Format and Rules
48 |
49 | The objective of the challenge is to find black-box (transfer) attacks that are effective against our CIFAR10 model.
50 | Attacks are allowed to perturb each pixel of the input image by at most `epsilon=8.0` on a `0-255` pixel scale.
51 | To ensure that the attacks are indeed black-box, we release our training code and model architecture, but keep the actual network weights secret.
52 |
53 | We invite any interested researchers to submit attacks against our model.
54 | The most successful attacks will be listed in the leaderboard above.
55 | As a reference point, we have seeded the leaderboard with the results of some standard attacks.
56 |
57 | ### The CIFAR10 Model
58 |
59 | We used the code published in this repository to produce an adversarially robust model for CIFAR10 classification. The model is a residual convolutional neural network consisting of five residual units and a fully connected layer. This architecture is derived from the "w32-10 wide" variant of the [Tensorflow model repository](https://github.com/tensorflow/models/blob/master/resnet/resnet_model.py).
60 | The network was trained against an iterative adversary that is allowed to perturb each pixel by at most `epsilon=8.0`.
61 |
62 | The random seed used for training and the trained network weights will be kept secret.
63 |
64 | The `sha256()` digest of our model file is:
65 | ```
66 | 555be6e892372599380c9da5d5f9802f9cbd098be8a47d24d96937a002305fd4
67 | ```
68 | We will release the corresponding model file on September 15 2017, which is roughly two months after the start of this competition. **Edit: We are extending the deadline for submitting attacks to October 15th due to requests.**
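Once the model file is public, the digest can be reproduced with Python's `hashlib`; `fetch_model.py` prints the same hash on download. A minimal sketch (the filename `secret.zip` is an assumption):
```
import hashlib

# hash the downloaded model archive and compare against the digest above
with open('secret.zip', 'rb') as f:
    print(hashlib.sha256(f.read()).hexdigest())
```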
69 |
70 | ### The Attack Model
71 |
72 | We are interested in adversarial inputs that are derived from the CIFAR10 test set.
73 | Each pixel can be perturbed by at most `epsilon=8.0` from its initial value on the `0-255` pixel scale.
74 | All pixels can be perturbed independently, so this is an l_infinity attack.
75 |
76 | ### Submitting an Attack
77 |
78 | Each attack should consist of a perturbed version of the CIFAR10 test set.
79 | Each perturbed image in this test set should follow the above attack model.
80 |
81 | The adversarial test set should be formatted as a numpy array with one row per example and each row containing a 32x32x3
82 | array of pixels.
83 | Hence the overall dimensions are 10,000x32x32x3.
84 | Each pixel must be in the [0, 255] range.
85 | See the script `pgd_attack.py` for an attack that generates an adversarial test set in this format.
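For instance, assuming `x_adv` is the array produced by your attack, the following sketch verifies the required format and writes the submission file (`attack.npy` is the default `store_adv_path`):
```
import numpy as np

assert x_adv.shape == (10000, 32, 32, 3)
assert x_adv.min() >= 0 and x_adv.max() <= 255
np.save('attack.npy', x_adv)
```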
86 |
87 | In order to submit your attack, save the matrix containing your adversarial examples with `numpy.save` and email the resulting file to cifar10.challenge@gmail.com.
88 | We will then run the `run_attack.py` script on your file to verify that the attack is valid and to evaluate the accuracy of our secret model on your examples.
89 | After that, we will reply with the predictions of our model on each of your examples and the overall accuracy of our model on your evaluation set.
90 |
91 | If the attack is valid and outperforms all current attacks in the leaderboard, it will appear at the top of the leaderboard.
92 | Novel types of attacks might be included in the leaderboard even if they do not perform best.
93 |
94 | We strongly encourage you to disclose your attack method.
95 | We would be happy to add a link to your code in our leaderboard.
96 |
97 | ## Overview of the Code
98 | The code consists of seven Python scripts and the file `config.json` that contains various parameter settings.
99 |
100 | ### Running the code
101 | - `python train.py`: trains the network, storing checkpoints along
102 | the way.
103 | - `python eval.py`: an infinite evaluation loop, processing each new
104 | checkpoint as it is created while logging summaries. It is intended
105 | to be run in parallel with the `train.py` script.
106 | - `python pgd_attack.py`: applies the attack to the CIFAR10 eval set and
107 | stores the resulting adversarial eval set in a `.npy` file. This file is
108 | in a valid attack format for our challenge.
109 | - `python run_attack.py`: evaluates the model on the examples in
110 | the `.npy` file specified in config, while ensuring that the adversarial examples
111 | are indeed a valid attack. The script also saves the network predictions in `pred.npy`.
112 | - `python fetch_model.py name`: downloads the pre-trained model with the
113 | specified name (at the moment `adv_trained` or `natural`), prints the sha256
114 | hash, and places it in the models directory.
115 | - `cifar10_input.py` provides utility functions and classes for loading the CIFAR10 dataset.
116 |
117 | ### Parameters in `config.json`
118 |
119 | Model configuration:
120 | - `model_dir`: contains the path to the directory of the currently
121 | trained/evaluated model.
122 |
123 | Training configuration:
124 | - `tf_random_seed`: the seed for the RNG used to initialize the network
125 | weights.
126 | - `np_random_seed`: the seed for the RNG used to pass over the dataset in random order.
127 | - `max_num_training_steps`: the number of training steps.
128 | - `num_output_steps`: the number of training steps between printing
129 | progress in standard output.
130 | - `num_summary_steps`: the number of training steps between storing
131 | tensorboard summaries.
132 | - `num_checkpoint_steps`: the number of training steps between storing
133 | model checkpoints.
134 | - `training_batch_size`: the size of the training batch.
135 |
136 | Evaluation configuration:
137 | - `num_eval_examples`: the number of CIFAR10 examples to evaluate the
138 | model on.
139 | - `eval_batch_size`: the size of the evaluation batches.
140 | - `eval_on_cpu`: forces the `eval.py` script to run on the CPU so it does not compete with `train.py` for GPU resources.
141 |
142 | Adversarial examples configuration:
143 | - `epsilon`: the maximum allowed perturbation per pixel.
144 | - `num_steps`: the number of PGD iterations used by the adversary.
145 | - `step_size`: the size of the PGD adversary steps.
146 | - `random_start`: specifies whether the adversary will start iterating
147 | from the natural example or a random perturbation of it.
148 | - `loss_func`: the loss function used to run pgd on. `xent` corresponds to the
149 | standard cross-entropy loss, `cw` corresponds to the loss function
150 | of [Carlini and Wagner](https://arxiv.org/abs/1608.04644).
151 | - `store_adv_path`: the file in which adversarial examples are stored.
152 | Relevant for the `pgd_attack.py` and `run_attack.py` scripts.
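All scripts load these settings in the same way, so they can also be inspected programmatically. A minimal sketch:
```
import json

with open('config.json') as config_file:
    config = json.load(config_file)

print(config['epsilon'], config['num_steps'], config['step_size'])
```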
153 |
154 | ## Example usage
155 | After cloning the repository you can either train a new network or evaluate/attack one of our pre-trained networks.
156 | #### Training a new network
157 | * Start training by running:
158 | ```
159 | python train.py
160 | ```
161 | * (Optional) Evaluation summaries can be logged by simultaneously
162 | running:
163 | ```
164 | python eval.py
165 | ```
166 | #### Download a pre-trained network
167 | * For an adversarially trained network, run
168 | ```
169 | python fetch_model.py adv_trained
170 | ```
171 | and use the `config.json` file to set `"model_dir": "models/adv_trained"`.
172 | * For a naturally trained network, run
173 | ```
174 | python fetch_model.py natural
175 | ```
176 | and use the `config.json` file to set `"model_dir": "models/naturally_trained"`.
177 | #### Test the network
178 | * Create an attack file by running
179 | ```
180 | python pgd_attack.py
181 | ```
182 | * Evaluate the network with
183 | ```
184 | python run_attack.py
185 | ```
186 |
--------------------------------------------------------------------------------
/madry_cifar10/cifar10_input.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for importing the CIFAR10 dataset.
3 |
4 | Each image in the dataset is a numpy array of shape (32, 32, 3), with the values
5 | being unsigned integers (i.e., in the range 0,1,...,255).
6 | """
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import os
13 | import pickle
14 | import sys
15 | import tensorflow as tf
16 |
17 | version = sys.version_info
18 |
19 | import numpy as np
20 |
21 | class CIFAR10Data(object):
22 | """
23 | Unpickles the CIFAR10 dataset from a specified folder containing a pickled
24 | version following the format of Krizhevsky which can be found
25 | [here](https://www.cs.toronto.edu/~kriz/cifar.html).
26 |
27 | Inputs to constructor
28 | =====================
29 |
30 | - path: path to the pickled dataset. The training data must be pickled
31 | into five files named data_batch_i for i = 1, ..., 5, containing 10,000
32 | examples each, the test data
33 | must be pickled into a single file called test_batch containing 10,000
34 | examples, and the 10 class names must be
35 | pickled into a file called batches.meta. The pickled examples should
36 | be stored as a tuple of two objects: an array of 10,000 32x32x3-shaped
37 | arrays, and an array of their 10,000 true labels.
38 |
39 | """
40 | def __init__(self, path):
41 | train_filenames = ['data_batch_{}'.format(ii + 1) for ii in range(5)]
42 | eval_filename = 'test_batch'
43 | metadata_filename = 'batches.meta'
44 |
45 | train_images = np.zeros((50000, 32, 32, 3), dtype='uint8')
46 | train_labels = np.zeros(50000, dtype='int32')
47 | for ii, fname in enumerate(train_filenames):
48 | cur_images, cur_labels = self._load_datafile(os.path.join(path, fname))
49 | train_images[ii * 10000 : (ii+1) * 10000, ...] = cur_images
50 | train_labels[ii * 10000 : (ii+1) * 10000, ...] = cur_labels
51 | eval_images, eval_labels = self._load_datafile(
52 | os.path.join(path, eval_filename))
53 |
54 | with open(os.path.join(path, metadata_filename), 'rb') as fo:
55 | if version.major == 3:
56 | data_dict = pickle.load(fo, encoding='bytes')
57 | else:
58 | data_dict = pickle.load(fo)
59 |
60 | self.label_names = data_dict[b'label_names']
61 | for ii in range(len(self.label_names)):
62 | self.label_names[ii] = self.label_names[ii].decode('utf-8')
63 |
64 | self.train_data = DataSubset(train_images, train_labels)
65 | self.eval_data = DataSubset(eval_images, eval_labels)
66 |
67 | @staticmethod
68 | def _load_datafile(filename):
69 | with open(filename, 'rb') as fo:
70 | if version.major == 3:
71 | data_dict = pickle.load(fo, encoding='bytes')
72 | else:
73 | data_dict = pickle.load(fo)
74 |
75 | assert data_dict[b'data'].dtype == np.uint8
76 | image_data = data_dict[b'data']
77 | image_data = image_data.reshape((10000, 3, 32, 32)).transpose(0, 2, 3, 1)
78 | return image_data, np.array(data_dict[b'labels'])
79 |
80 | class AugmentedCIFAR10Data(object):
81 | """
82 | Data augmentation wrapper over a loaded dataset.
83 |
84 | Inputs to constructor
85 | =====================
86 | - raw_cifar10data: the loaded CIFAR10 dataset, via the CIFAR10Data class
87 | - sess: current tensorflow session
88 | - model: current model (needed for input tensor)
89 | """
90 | def __init__(self, raw_cifar10data, sess, model):
91 | assert isinstance(raw_cifar10data, CIFAR10Data)
92 | self.image_size = 32
93 |
94 | # create augmentation computational graph
95 | self.x_input_placeholder = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
96 | padded = tf.map_fn(lambda img: tf.image.resize_image_with_crop_or_pad(
97 | img, self.image_size + 4, self.image_size + 4),
98 | self.x_input_placeholder)
99 | cropped = tf.map_fn(lambda img: tf.random_crop(img, [self.image_size,
100 | self.image_size,
101 | 3]), padded)
102 | flipped = tf.map_fn(lambda img: tf.image.random_flip_left_right(img), cropped)
103 | self.augmented = flipped
104 |
105 | self.train_data = AugmentedDataSubset(raw_cifar10data.train_data, sess,
106 | self.x_input_placeholder,
107 | self.augmented)
108 | self.eval_data = AugmentedDataSubset(raw_cifar10data.eval_data, sess,
109 | self.x_input_placeholder,
110 | self.augmented)
111 | self.label_names = raw_cifar10data.label_names
112 |
113 |
114 | class DataSubset(object):
115 | def __init__(self, xs, ys):
116 | self.xs = xs
117 | self.n = xs.shape[0]
118 | self.ys = ys
119 | self.batch_start = 0
120 | self.cur_order = np.random.permutation(self.n)
121 |
122 | def get_next_batch(self, batch_size, multiple_passes=False, reshuffle_after_pass=True):
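        # Single-pass mode (multiple_passes=False) raises an error once the data
        # is exhausted; otherwise, when a pass completes, the order is reshuffled
        # (if requested) and iteration restarts from the beginning.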
123 | if self.n < batch_size:
124 | raise ValueError('Batch size can be at most the dataset size')
125 | if not multiple_passes:
126 | actual_batch_size = min(batch_size, self.n - self.batch_start)
127 | if actual_batch_size <= 0:
128 | raise ValueError('Pass through the dataset is complete.')
129 | batch_end = self.batch_start + actual_batch_size
130 | batch_xs = self.xs[self.cur_order[self.batch_start : batch_end], ...]
131 | batch_ys = self.ys[self.cur_order[self.batch_start : batch_end], ...]
132 | self.batch_start += actual_batch_size
133 | return batch_xs, batch_ys
134 | actual_batch_size = min(batch_size, self.n - self.batch_start)
135 | if actual_batch_size < batch_size:
136 | if reshuffle_after_pass:
137 | self.cur_order = np.random.permutation(self.n)
138 | self.batch_start = 0
139 | batch_end = self.batch_start + batch_size
140 | batch_xs = self.xs[self.cur_order[self.batch_start : batch_end], ...]
141 | batch_ys = self.ys[self.cur_order[self.batch_start : batch_end], ...]
142 | self.batch_start += batch_size
143 | return batch_xs, batch_ys
144 |
145 |
146 | class AugmentedDataSubset(object):
147 | def __init__(self, raw_datasubset, sess, x_input_placeholder,
148 | augmented):
149 | self.sess = sess
150 | self.raw_datasubset = raw_datasubset
151 | self.x_input_placeholder = x_input_placeholder
152 | self.augmented = augmented
153 |
154 | def get_next_batch(self, batch_size, multiple_passes=False, reshuffle_after_pass=True):
155 | raw_batch = self.raw_datasubset.get_next_batch(batch_size, multiple_passes,
156 | reshuffle_after_pass)
157 |         images = raw_batch[0].astype(np.float32)
158 |         return self.sess.run(self.augmented, feed_dict={self.x_input_placeholder:
159 |                                                          images}), raw_batch[1]
160 |
161 |
--------------------------------------------------------------------------------
/madry_cifar10/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_comment": "===== MODEL CONFIGURATION =====",
3 | "model_dir": "models/secret",
4 |
5 | "_comment": "===== DATASET CONFIGURATION =====",
6 | "data_path": "cifar10_data",
7 |
8 | "_comment": "===== TRAINING CONFIGURATION =====",
9 | "tf_random_seed": 451760341,
10 | "np_random_seed": 216105420,
11 | "max_num_training_steps": 80000,
12 | "num_output_steps": 100,
13 | "num_summary_steps": 100,
14 | "num_checkpoint_steps": 1000,
15 | "training_batch_size": 128,
16 | "step_size_schedule": [[0, 0.1], [40000, 0.01], [60000, 0.001]],
17 | "weight_decay": 0.0002,
18 | "momentum": 0.9,
19 |
20 | "_comment": "===== EVAL CONFIGURATION =====",
21 | "num_eval_examples": 100,
22 | "eval_batch_size": 100,
23 | "eval_on_cpu": false,
24 |
25 | "_comment": "=====ADVERSARIAL EXAMPLES CONFIGURATION=====",
26 | "epsilon": 8.0,
27 | "num_steps": 10,
28 | "step_size": 2.0,
29 | "random_start": true,
30 | "loss_func": "xent",
31 | "store_adv_path": "attack.npy"
32 | }
33 |
--------------------------------------------------------------------------------
/madry_cifar10/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Infinite evaluation loop going through the checkpoints in the model directory
3 | as they appear and evaluating them. Accuracy and average loss are printed and
4 | added as tensorboard summaries.
5 | """
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | from datetime import datetime
11 | import json
12 | import math
13 | import os
14 | import sys
15 | import time
16 |
17 | import tensorflow as tf
18 |
19 | import cifar10_input
20 | from model import Model
21 | from pgd_attack import LinfPGDAttack
22 |
23 | # Global constants
24 | with open('config.json') as config_file:
25 | config = json.load(config_file)
26 | num_eval_examples = config['num_eval_examples']
27 | eval_batch_size = config['eval_batch_size']
28 | eval_on_cpu = config['eval_on_cpu']
29 | data_path = config['data_path']
30 |
31 | model_dir = config['model_dir']
32 |
33 | # Set up the data, hyperparameters, and the model
34 | cifar = cifar10_input.CIFAR10Data(data_path)
35 |
36 | if eval_on_cpu:
37 | with tf.device("/cpu:0"):
38 | model = Model(mode='eval')
39 | attack = LinfPGDAttack(model,
40 | config['epsilon'],
41 | config['num_steps'],
42 | config['step_size'],
43 | config['random_start'],
44 | config['loss_func'])
45 | else:
46 | model = Model(mode='eval')
47 | attack = LinfPGDAttack(model,
48 | config['epsilon'],
49 | config['num_steps'],
50 | config['step_size'],
51 | config['random_start'],
52 | config['loss_func'])
53 |
54 | global_step = tf.contrib.framework.get_or_create_global_step()
55 |
56 | # Setting up the Tensorboard and checkpoint outputs
57 | if not os.path.exists(model_dir):
58 | os.makedirs(model_dir)
59 | eval_dir = os.path.join(model_dir, 'eval')
60 | if not os.path.exists(eval_dir):
61 | os.makedirs(eval_dir)
62 |
63 | last_checkpoint_filename = ''
64 | already_seen_state = False
65 |
66 | saver = tf.train.Saver()
67 | summary_writer = tf.summary.FileWriter(eval_dir)
68 |
69 | # A function for evaluating a single checkpoint
70 | def evaluate_checkpoint(filename):
71 | with tf.Session() as sess:
72 | # Restore the checkpoint
73 | saver.restore(sess, filename)
74 |
75 | # Iterate over the samples batch-by-batch
76 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
77 | total_xent_nat = 0.
78 | total_xent_adv = 0.
79 | total_corr_nat = 0
80 | total_corr_adv = 0
81 |
82 | for ibatch in range(num_batches):
83 | bstart = ibatch * eval_batch_size
84 | bend = min(bstart + eval_batch_size, num_eval_examples)
85 |
86 | x_batch = cifar.eval_data.xs[bstart:bend, :]
87 | y_batch = cifar.eval_data.ys[bstart:bend]
88 |
89 | dict_nat = {model.x_input: x_batch,
90 | model.y_input: y_batch}
91 |
92 | x_batch_adv = attack.perturb(x_batch, y_batch, sess)
93 |
94 | dict_adv = {model.x_input: x_batch_adv,
95 | model.y_input: y_batch}
96 |
97 | cur_corr_nat, cur_xent_nat = sess.run(
98 | [model.num_correct,model.xent],
99 | feed_dict = dict_nat)
100 | cur_corr_adv, cur_xent_adv = sess.run(
101 | [model.num_correct,model.xent],
102 | feed_dict = dict_adv)
103 |
104 |       print('batch size: {}'.format(bend - bstart))
105 | print("Correctly classified natural examples: {}".format(cur_corr_nat))
106 | print("Correctly classified adversarial examples: {}".format(cur_corr_adv))
107 | total_xent_nat += cur_xent_nat
108 | total_xent_adv += cur_xent_adv
109 | total_corr_nat += cur_corr_nat
110 | total_corr_adv += cur_corr_adv
111 |
112 | avg_xent_nat = total_xent_nat / num_eval_examples
113 | avg_xent_adv = total_xent_adv / num_eval_examples
114 | acc_nat = total_corr_nat / num_eval_examples
115 | acc_adv = total_corr_adv / num_eval_examples
116 |
117 | summary = tf.Summary(value=[
118 | tf.Summary.Value(tag='xent adv eval', simple_value= avg_xent_adv),
119 | tf.Summary.Value(tag='xent adv', simple_value= avg_xent_adv),
120 | tf.Summary.Value(tag='xent nat', simple_value= avg_xent_nat),
121 | tf.Summary.Value(tag='accuracy adv eval', simple_value= acc_adv),
122 | tf.Summary.Value(tag='accuracy adv', simple_value= acc_adv),
123 | tf.Summary.Value(tag='accuracy nat', simple_value= acc_nat)])
124 | summary_writer.add_summary(summary, global_step.eval(sess))
125 |
126 | print('natural: {:.2f}%'.format(100 * acc_nat))
127 | print('adversarial: {:.2f}%'.format(100 * acc_adv))
128 | print('avg nat loss: {:.4f}'.format(avg_xent_nat))
129 | print('avg adv loss: {:.4f}'.format(avg_xent_adv))
130 |
131 | # Infinite eval loop
132 | while True:
133 | cur_checkpoint = tf.train.latest_checkpoint(model_dir)
134 |
135 | # Case 1: No checkpoint yet
136 | if cur_checkpoint is None:
137 | if not already_seen_state:
138 | print('No checkpoint yet, waiting ...', end='')
139 | already_seen_state = True
140 | else:
141 | print('.', end='')
142 | sys.stdout.flush()
143 | time.sleep(10)
144 | # Case 2: Previously unseen checkpoint
145 | elif cur_checkpoint != last_checkpoint_filename:
146 | print('\nCheckpoint {}, evaluating ... ({})'.format(cur_checkpoint,
147 | datetime.now()))
148 | sys.stdout.flush()
149 | last_checkpoint_filename = cur_checkpoint
150 | already_seen_state = False
151 | evaluate_checkpoint(cur_checkpoint)
152 | # Case 3: Previously evaluated checkpoint
153 | else:
154 | if not already_seen_state:
155 | print('Waiting for the next checkpoint ... ({}) '.format(
156 | datetime.now()),
157 | end='')
158 | already_seen_state = True
159 | else:
160 | print('.', end='')
161 | sys.stdout.flush()
162 | time.sleep(10)
163 |
--------------------------------------------------------------------------------
/madry_cifar10/fetch_model.py:
--------------------------------------------------------------------------------
1 | """Downloads a model, computes its SHA256 hash and unzips it
2 | at the proper location."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import sys
8 | import zipfile
9 | import hashlib
10 |
11 | if len(sys.argv) != 2 or sys.argv[1] not in ['natural',
12 | 'adv_trained',
13 | 'secret']:
14 |   print('Usage: python fetch_model.py [natural, adv_trained, secret]')
15 | sys.exit(1)
16 |
17 | if sys.argv[1] == 'natural':
18 | url = 'https://www.dropbox.com/s/cgzd5odqoojvxzk/natural.zip?dl=1'
19 | elif sys.argv[1] == 'adv_trained':
20 | url = 'https://www.dropbox.com/s/g4b6ntrp8zrudbz/adv_trained.zip?dl=1'
21 | else: # fetch secret model
22 | url = 'https://www.dropbox.com/s/ywc0hg8lr5ba8zd/secret.zip?dl=1'
23 |
24 | fname = url.split('/')[-1].split('?')[0] # get the name of the file
25 |
26 | # model download
27 | print('Downloading models')
28 | if sys.version_info >= (3,):
29 | import urllib.request
30 | urllib.request.urlretrieve(url, fname)
31 | else:
32 | import urllib
33 | urllib.urlretrieve(url, fname)
34 |
35 | # computing model hash
36 | sha256 = hashlib.sha256()
37 | with open(fname, 'rb') as f:
38 | data = f.read()
39 | sha256.update(data)
40 | print('SHA256 hash: {}'.format(sha256.hexdigest()))
41 |
42 | # extracting model
43 | print('Extracting model')
44 | with zipfile.ZipFile(fname, 'r') as model_zip:
45 | model_zip.extractall()
46 | print('Extracted model in {}'.format(model_zip.namelist()[0]))
47 |
--------------------------------------------------------------------------------
/madry_cifar10/model.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/tensorflow/models/tree/master/resnet
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | class Model(object):
10 | """ResNet model."""
11 |
12 | def __init__(self, mode='eval'):
13 | """ResNet constructor.
14 |
15 | Args:
16 | mode: One of 'train' and 'eval'.
17 | """
18 | self.mode = mode
19 | self._build_model()
20 |
21 | def add_internal_summaries(self):
22 | pass
23 |
24 | def _stride_arr(self, stride):
25 | """Map a stride scalar to the stride array for tf.nn.conv2d."""
26 | return [1, stride, stride, 1]
27 |
28 | def _build_model(self):
29 |     """Build the core model within the graph."""
30 |     assert self.mode == 'train' or self.mode == 'eval'
31 | with tf.variable_scope('input'):
32 |
33 | self.x_input = tf.placeholder(
34 | tf.float32,
35 | shape=[None, 32, 32, 3])
36 |
37 | self.y_input = tf.placeholder(tf.int64, shape=None)
38 |
39 |
40 | input_standardized = tf.map_fn(lambda img: tf.image.per_image_standardization(img),
41 | self.x_input)
42 | x = self._conv('init_conv', input_standardized, 3, 3, 16, self._stride_arr(1))
43 |
44 |
45 |
46 | strides = [1, 2, 2]
47 | activate_before_residual = [True, False, False]
48 | res_func = self._residual
49 |
50 |       # Filter counts for the wide residual network (widening factor 10).
51 |       # It is more memory-efficient than a very deep residual network and
52 |       # has comparably good performance.
53 | # https://arxiv.org/pdf/1605.07146v1.pdf
54 | filters = [16, 160, 320, 640]
55 |
56 |
57 | # Update hps.num_residual_units to 9
58 |
59 | with tf.variable_scope('unit_1_0'):
60 | x = res_func(x, filters[0], filters[1], self._stride_arr(strides[0]),
61 | activate_before_residual[0])
62 | for i in range(1, 5):
63 | with tf.variable_scope('unit_1_%d' % i):
64 | x = res_func(x, filters[1], filters[1], self._stride_arr(1), False)
65 |
66 | with tf.variable_scope('unit_2_0'):
67 | x = res_func(x, filters[1], filters[2], self._stride_arr(strides[1]),
68 | activate_before_residual[1])
69 | for i in range(1, 5):
70 | with tf.variable_scope('unit_2_%d' % i):
71 | x = res_func(x, filters[2], filters[2], self._stride_arr(1), False)
72 |
73 | with tf.variable_scope('unit_3_0'):
74 | x = res_func(x, filters[2], filters[3], self._stride_arr(strides[2]),
75 | activate_before_residual[2])
76 | for i in range(1, 5):
77 | with tf.variable_scope('unit_3_%d' % i):
78 | x = res_func(x, filters[3], filters[3], self._stride_arr(1), False)
79 |
80 | with tf.variable_scope('unit_last'):
81 | x = self._batch_norm('final_bn', x)
82 | x = self._relu(x, 0.1)
83 | x = self._global_avg_pool(x)
84 |
85 | with tf.variable_scope('logit'):
86 | self.pre_softmax = self._fully_connected(x, 10)
87 |
88 | self.predictions = tf.argmax(self.pre_softmax, 1)
89 | self.correct_prediction = tf.equal(self.predictions, self.y_input)
90 | self.num_correct = tf.reduce_sum(
91 | tf.cast(self.correct_prediction, tf.int64))
92 | self.accuracy = tf.reduce_mean(
93 | tf.cast(self.correct_prediction, tf.float32))
94 |
95 | with tf.variable_scope('costs'):
96 | self.y_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
97 | logits=self.pre_softmax, labels=self.y_input)
98 | self.xent_per_point = self.y_xent
99 | self.xent = tf.reduce_sum(self.y_xent, name='y_xent')
100 | self.mean_xent = tf.reduce_mean(self.y_xent)
101 | self.weight_decay_loss = self._decay()
102 |
103 | def _batch_norm(self, name, x):
104 | """Batch normalization."""
105 | with tf.name_scope(name):
106 | return tf.contrib.layers.batch_norm(
107 | inputs=x,
108 | decay=.9,
109 | center=True,
110 | scale=True,
111 | activation_fn=None,
112 | updates_collections=None,
113 | is_training=(self.mode == 'train'))
114 |
115 | def _residual(self, x, in_filter, out_filter, stride,
116 | activate_before_residual=False):
117 | """Residual unit with 2 sub layers."""
118 | if activate_before_residual:
119 | with tf.variable_scope('shared_activation'):
120 | x = self._batch_norm('init_bn', x)
121 | x = self._relu(x, 0.1)
122 | orig_x = x
123 | else:
124 | with tf.variable_scope('residual_only_activation'):
125 | orig_x = x
126 | x = self._batch_norm('init_bn', x)
127 | x = self._relu(x, 0.1)
128 |
129 | with tf.variable_scope('sub1'):
130 | x = self._conv('conv1', x, 3, in_filter, out_filter, stride)
131 |
132 | with tf.variable_scope('sub2'):
133 | x = self._batch_norm('bn2', x)
134 | x = self._relu(x, 0.1)
135 | x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1])
136 |
137 | with tf.variable_scope('sub_add'):
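      # identity shortcut; when the filter count grows, downsample the input
      # with average pooling and zero-pad the extra channels before adding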
138 | if in_filter != out_filter:
139 | orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID')
140 | orig_x = tf.pad(
141 | orig_x, [[0, 0], [0, 0], [0, 0],
142 | [(out_filter-in_filter)//2, (out_filter-in_filter)//2]])
143 | x += orig_x
144 |
145 | tf.logging.debug('image after unit %s', x.get_shape())
146 | return x
147 |
148 | def _decay(self):
149 | """L2 weight decay loss."""
150 | costs = []
151 | for var in tf.trainable_variables():
152 | if var.op.name.find('DW') > 0:
153 | costs.append(tf.nn.l2_loss(var))
154 | return tf.add_n(costs)
155 |
156 | def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
157 | """Convolution."""
158 | with tf.variable_scope(name):
159 | n = filter_size * filter_size * out_filters
160 | kernel = tf.get_variable(
161 | 'DW', [filter_size, filter_size, in_filters, out_filters],
162 | tf.float32, initializer=tf.random_normal_initializer(
163 | stddev=np.sqrt(2.0/n)))
164 | return tf.nn.conv2d(x, kernel, strides, padding='SAME')
165 |
166 | def _relu(self, x, leakiness=0.0):
167 | """Relu, with optional leaky support."""
168 | return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')
169 |
170 | def _fully_connected(self, x, out_dim):
171 | """FullyConnected layer for final output."""
172 | num_non_batch_dimensions = len(x.shape)
173 | prod_non_batch_dimensions = 1
174 | for ii in range(num_non_batch_dimensions - 1):
175 | prod_non_batch_dimensions *= int(x.shape[ii + 1])
176 | x = tf.reshape(x, [tf.shape(x)[0], -1])
177 | w = tf.get_variable(
178 | 'DW', [prod_non_batch_dimensions, out_dim],
179 | initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
180 | b = tf.get_variable('biases', [out_dim],
181 | initializer=tf.constant_initializer())
182 | return tf.nn.xw_plus_b(x, w, b)
183 |
184 | def _global_avg_pool(self, x):
185 | assert x.get_shape().ndims == 4
186 | return tf.reduce_mean(x, [1, 2])
187 |
188 |
189 |
190 |
--------------------------------------------------------------------------------
/madry_cifar10/model_robustml.py:
--------------------------------------------------------------------------------
1 | import robustml
2 | import tensorflow as tf
3 |
4 | import model
5 |
6 | class Model(robustml.model.Model):
7 | def __init__(self, sess):
8 | self._model = model.Model('eval')
9 |
10 | saver = tf.train.Saver()
11 | checkpoint = tf.train.latest_checkpoint('models/secret')
12 | saver.restore(sess, checkpoint)
13 |
14 | self._sess = sess
15 | self._input = self._model.x_input
16 | self._logits = self._model.pre_softmax
17 | self._predictions = self._model.predictions
18 | self._dataset = robustml.dataset.CIFAR10()
19 | self._threat_model = robustml.threat_model.Linf(epsilon=0.03)
20 |
21 | @property
22 | def dataset(self):
23 | return self._dataset
24 |
25 | @property
26 | def threat_model(self):
27 | return self._threat_model
28 |
29 | def classify(self, x):
30 | return self._sess.run(self._predictions,
31 | {self._input: x})[0]
32 |
33 | # expose attack interface
34 |
35 | @property
36 | def input(self):
37 | return self._input
38 |
39 | @property
40 | def logits(self):
41 | return self._logits
42 |
43 | @property
44 | def predictions(self):
45 | return self._predictions
46 |
--------------------------------------------------------------------------------
/madry_cifar10/pgd_attack.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of attack methods. Running this file as a program will
3 | apply the attack to the model specified by the config file and store
4 | the examples in an .npy file.
5 | """
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | import tensorflow as tf
11 | import numpy as np
12 |
13 | import cifar10_input
14 |
15 |
16 | class LinfPGDAttack:
17 | def __init__(self, model, epsilon, num_steps, step_size, random_start, loss_func):
18 | """Attack parameter initialization. The attack performs k steps of
19 | size a, while always staying within epsilon from the initial
20 | point."""
21 | self.model = model
22 | self.epsilon = epsilon
23 | self.num_steps = num_steps
24 | self.step_size = step_size
25 | self.rand = random_start
26 |
27 | if loss_func == 'xent':
28 | loss = model.xent
29 | elif loss_func == 'cw':
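            # Carlini-Wagner margin loss: maximizing it pushes the highest
            # wrong-class logit above the correct logit by a margin of 50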
30 | label_mask = tf.one_hot(model.y_input,
31 | 10,
32 | on_value=1.0,
33 | off_value=0.0,
34 | dtype=tf.float32)
35 | correct_logit = tf.reduce_sum(label_mask * model.pre_softmax, axis=1)
36 | wrong_logit = tf.reduce_max((1 - label_mask) * model.pre_softmax - 1e4 * label_mask, axis=1)
37 | loss = -tf.nn.relu(correct_logit - wrong_logit + 50)
38 | else:
39 | print('Unknown loss function. Defaulting to cross-entropy')
40 | loss = model.xent
41 |
42 | self.grad = tf.gradients(loss, model.x_input)[0]
43 |
44 | def perturb(self, x_nat, y, sess):
45 | """Given a set of examples (x_nat, y), returns a set of adversarial
46 | examples within epsilon of x_nat in l_infinity norm."""
47 | if self.rand:
48 | x = x_nat + np.random.uniform(-self.epsilon, self.epsilon, x_nat.shape)
49 | x = np.clip(x, 0, 255) # ensure valid pixel range
50 | else:
51 | x = np.copy(x_nat)
52 |
53 | for i in range(self.num_steps):
54 | grad = sess.run(self.grad, feed_dict={self.model.x_input: x,
55 | self.model.y_input: y})
56 |
57 | x = np.add(x, self.step_size * np.sign(grad), out=x, casting='unsafe')
58 |
59 | x = np.clip(x, x_nat - self.epsilon, x_nat + self.epsilon)
60 | x = np.clip(x, 0, 255) # ensure valid pixel range
61 |
62 | return x
63 |
64 |
65 | if __name__ == '__main__':
66 | import json
67 | import sys
68 | import math
69 |
70 | from model import Model
71 |
72 | with open('config.json') as config_file:
73 | config = json.load(config_file)
74 |
75 | model_file = tf.train.latest_checkpoint(config['model_dir'])
76 | if model_file is None:
77 | print('No model found')
78 | sys.exit()
79 |
80 | model = Model(mode='eval')
81 | attack = LinfPGDAttack(model,
82 | config['epsilon'],
83 | config['num_steps'],
84 | config['step_size'],
85 | config['random_start'],
86 | config['loss_func'])
87 | saver = tf.train.Saver()
88 |
89 | data_path = config['data_path']
90 | cifar = cifar10_input.CIFAR10Data(data_path)
91 |
92 | gpu_options = tf.GPUOptions(visible_device_list='7', per_process_gpu_memory_fraction=0.5)
93 | tf_config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
94 | with tf.Session(config=tf_config) as sess:
95 | # Restore the checkpoint
96 | saver.restore(sess, model_file)
97 |
98 | # Iterate over the samples batch-by-batch
99 | num_eval_examples = config['num_eval_examples']
100 | eval_batch_size = config['eval_batch_size']
101 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
102 |
103 | x_adv = [] # adv accumulator
104 |
105 | print('Iterating over {} batches'.format(num_batches))
106 |
107 | for ibatch in range(num_batches):
108 | bstart = ibatch * eval_batch_size
109 | bend = min(bstart + eval_batch_size, num_eval_examples)
110 | print('batch size: {}'.format(bend - bstart))
111 |
112 | x_batch = cifar.eval_data.xs[bstart:bend, :]
113 | y_batch = cifar.eval_data.ys[bstart:bend]
114 |
115 | x_batch_adv = attack.perturb(x_batch, y_batch, sess)
116 |
117 | x_adv.append(x_batch_adv)
118 |
119 | print('Storing examples')
120 | path = config['store_adv_path']
121 | x_adv = np.concatenate(x_adv, axis=0)
122 | np.save(path, x_adv)
123 | print('Examples stored in {}'.format(path))
124 |
--------------------------------------------------------------------------------
/madry_cifar10/run_attack.py:
--------------------------------------------------------------------------------
1 | """Evaluates a model against examples from a .npy file as specified
2 | in config.json"""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from datetime import datetime
8 | import json
9 | import math
10 | import os
11 | import sys
12 | import time
13 |
14 | import tensorflow as tf
15 | import numpy as np
16 |
17 | from model import Model
18 | import cifar10_input
19 |
20 | with open('config.json') as config_file:
21 | config = json.load(config_file)
22 |
23 | data_path = config['data_path']
24 |
25 | def run_attack(checkpoint, x_adv, epsilon):
26 | cifar = cifar10_input.CIFAR10Data(data_path)
27 |
28 | model = Model(mode='eval')
29 |
30 | saver = tf.train.Saver()
31 |
32 | num_eval_examples = 10000
33 | eval_batch_size = 100
34 |
35 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
36 | total_corr = 0
37 |
38 | x_nat = cifar.eval_data.xs
39 | l_inf = np.amax(np.abs(x_nat - x_adv))
40 |
41 | if l_inf > epsilon + 0.0001:
42 | print('maximum perturbation found: {}'.format(l_inf))
43 | print('maximum perturbation allowed: {}'.format(epsilon))
44 | return
45 |
46 | y_pred = [] # label accumulator
47 |
48 | with tf.Session() as sess:
49 | # Restore the checkpoint
50 | saver.restore(sess, checkpoint)
51 |
52 | # Iterate over the samples batch-by-batch
53 | for ibatch in range(num_batches):
54 | bstart = ibatch * eval_batch_size
55 | bend = min(bstart + eval_batch_size, num_eval_examples)
56 |
57 | x_batch = x_adv[bstart:bend, :]
58 | y_batch = cifar.eval_data.ys[bstart:bend]
59 |
60 | dict_adv = {model.x_input: x_batch,
61 | model.y_input: y_batch}
62 | cur_corr, y_pred_batch = sess.run([model.num_correct, model.predictions],
63 | feed_dict=dict_adv)
64 |
65 | total_corr += cur_corr
66 | y_pred.append(y_pred_batch)
67 |
68 | accuracy = total_corr / num_eval_examples
69 |
70 | print('Accuracy: {:.2f}%'.format(100.0 * accuracy))
71 | y_pred = np.concatenate(y_pred, axis=0)
72 | np.save('pred.npy', y_pred)
73 | print('Output saved at pred.npy')
74 |
75 | if __name__ == '__main__':
76 | import json
77 |
78 | with open('config.json') as config_file:
79 | config = json.load(config_file)
80 |
81 | model_dir = config['model_dir']
82 |
83 | checkpoint = tf.train.latest_checkpoint(model_dir)
84 | x_adv = np.load(config['store_adv_path'])
85 |
86 | if checkpoint is None:
87 | print('No checkpoint found')
88 | elif x_adv.shape != (10000, 32, 32, 3):
89 | print('Invalid shape: expected (10000, 32, 32, 3), found {}'.format(x_adv.shape))
90 | elif np.amax(x_adv) > 255.0001 or np.amin(x_adv) < -0.0001:
91 | print('Invalid pixel range. Expected [0, 255], found [{}, {}]'.format(
92 | np.amin(x_adv),
93 | np.amax(x_adv)))
94 | else:
95 | run_attack(checkpoint, x_adv, config['epsilon'])
96 |
--------------------------------------------------------------------------------
/madry_cifar10/train.py:
--------------------------------------------------------------------------------
1 | """Trains a model, saving checkpoints and tensorboard summaries along
2 | the way."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from datetime import datetime
8 | import json
9 | import os
10 | import shutil
11 | from timeit import default_timer as timer
12 |
13 | import tensorflow as tf
14 | import numpy as np
15 |
16 | from model import Model
17 | import cifar10_input
18 | from pgd_attack import LinfPGDAttack
19 |
20 | with open('config.json') as config_file:
21 | config = json.load(config_file)
22 |
23 | # seeding randomness
24 | tf.set_random_seed(config['tf_random_seed'])
25 | np.random.seed(config['np_random_seed'])
26 |
27 | # Setting up training parameters
28 | max_num_training_steps = config['max_num_training_steps']
29 | num_output_steps = config['num_output_steps']
30 | num_summary_steps = config['num_summary_steps']
31 | num_checkpoint_steps = config['num_checkpoint_steps']
32 | step_size_schedule = config['step_size_schedule']
33 | weight_decay = config['weight_decay']
34 | data_path = config['data_path']
35 | momentum = config['momentum']
36 | batch_size = config['training_batch_size']
37 |
38 | # Setting up the data and the model
39 | raw_cifar = cifar10_input.CIFAR10Data(data_path)
40 | global_step = tf.contrib.framework.get_or_create_global_step()
41 | model = Model(mode='train')
42 |
43 | # Setting up the optimizer
44 | boundaries = [int(sss[0]) for sss in step_size_schedule]
45 | boundaries = boundaries[1:]
46 | values = [sss[1] for sss in step_size_schedule]
47 | learning_rate = tf.train.piecewise_constant(
48 | tf.cast(global_step, tf.int32),
49 | boundaries,
50 | values)
51 | total_loss = model.mean_xent + weight_decay * model.weight_decay_loss
52 | train_step = tf.train.MomentumOptimizer(learning_rate, momentum).minimize(
53 | total_loss,
54 | global_step=global_step)
55 |
56 | # Set up adversary
57 | attack = LinfPGDAttack(model,
58 | config['epsilon'],
59 | config['num_steps'],
60 | config['step_size'],
61 | config['random_start'],
62 | config['loss_func'])
63 |
64 | # Setting up the Tensorboard and checkpoint outputs
65 | model_dir = config['model_dir']
66 | if not os.path.exists(model_dir):
67 | os.makedirs(model_dir)
68 |
69 | # We add accuracy and xent twice so we can easily make three types of
70 | # comparisons in Tensorboard:
71 | # - train vs eval (for a single run)
72 | # - train of different runs
73 | # - eval of different runs
74 |
75 | saver = tf.train.Saver(max_to_keep=3)
76 | tf.summary.scalar('accuracy adv train', model.accuracy)
77 | tf.summary.scalar('accuracy adv', model.accuracy)
78 | tf.summary.scalar('xent adv train', model.xent / batch_size)
79 | tf.summary.scalar('xent adv', model.xent / batch_size)
80 | tf.summary.image('images adv train', model.x_input)
81 | merged_summaries = tf.summary.merge_all()
82 |
83 | # keep the configuration file with the model for reproducibility
84 | shutil.copy('config.json', model_dir)
85 |
86 | with tf.Session() as sess:
87 |
88 | # initialize data augmentation
89 | cifar = cifar10_input.AugmentedCIFAR10Data(raw_cifar, sess, model)
90 |
91 | # Initialize the summary writer, global variables, and our time counter.
92 | summary_writer = tf.summary.FileWriter(model_dir, sess.graph)
93 | sess.run(tf.global_variables_initializer())
94 | training_time = 0.0
95 |
96 | # Main training loop
97 | for ii in range(max_num_training_steps):
98 | x_batch, y_batch = cifar.train_data.get_next_batch(batch_size,
99 | multiple_passes=True)
100 |
101 | # Compute Adversarial Perturbations
102 | start = timer()
103 | x_batch_adv = attack.perturb(x_batch, y_batch, sess)
104 | end = timer()
105 | training_time += end - start
106 |
107 | nat_dict = {model.x_input: x_batch,
108 | model.y_input: y_batch}
109 |
110 | adv_dict = {model.x_input: x_batch_adv,
111 | model.y_input: y_batch}
112 |
113 | # Output to stdout
114 | if ii % num_output_steps == 0:
115 | nat_acc = sess.run(model.accuracy, feed_dict=nat_dict)
116 | adv_acc = sess.run(model.accuracy, feed_dict=adv_dict)
117 | print('Step {}: ({})'.format(ii, datetime.now()))
118 | print(' training nat accuracy {:.4}%'.format(nat_acc * 100))
119 | print(' training adv accuracy {:.4}%'.format(adv_acc * 100))
120 | if ii != 0:
121 | print(' {} examples per second'.format(
122 | num_output_steps * batch_size / training_time))
123 | training_time = 0.0
124 | # Tensorboard summaries
125 | if ii % num_summary_steps == 0:
126 | summary = sess.run(merged_summaries, feed_dict=adv_dict)
127 | summary_writer.add_summary(summary, global_step.eval(sess))
128 |
129 | # Write a checkpoint
130 | if ii % num_checkpoint_steps == 0:
131 | saver.save(sess,
132 | os.path.join(model_dir, 'checkpoint'),
133 | global_step=global_step)
134 |
135 | # Actual training step
136 | start = timer()
137 | sess.run(train_step, feed_dict=adv_dict)
138 | end = timer()
139 | training_time += end - start
140 |
--------------------------------------------------------------------------------
/madry_mnist/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, and Adrian Vladu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/madry_mnist/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_comment": "===== MODEL CONFIGURATION =====",
3 | "model_dir": "models/secret",
4 |
5 | "_comment": "===== TRAINING CONFIGURATION =====",
6 | "random_seed": 4557077,
7 | "max_num_training_steps": 100000,
8 | "num_output_steps": 100,
9 | "num_summary_steps": 100,
10 | "num_checkpoint_steps": 300,
11 | "training_batch_size": 50,
12 |
13 | "_comment": "===== EVAL CONFIGURATION =====",
14 | "num_eval_examples": 10000,
15 | "eval_on_cpu": false,
16 |
17 | "_comment": "=====ADVERSARIAL EXAMPLES CONFIGURATION=====",
18 | "epsilon": 0.3,
19 | "k": 100,
20 | "a": 0.01,
21 | "random_start": true,
22 | "loss_func": "xent",
23 | "store_adv_path": "attack.npy"
24 | }
25 |
--------------------------------------------------------------------------------
/madry_mnist/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Infinite evaluation loop going through the checkpoints in the model directory
3 | as they appear and evaluating them. Accuracy and average loss are printed and
4 | added as tensorboard summaries.
5 | """
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | from datetime import datetime
11 | import json
12 | import math
13 | import os
14 | import sys
15 | import time
16 |
17 | import tensorflow as tf
18 | from tensorflow.examples.tutorials.mnist import input_data
19 |
20 | from model import Model
21 | from attack import LinfPGDAttack
22 |
23 | # Global constants
24 | with open('config.json') as config_file:
25 | config = json.load(config_file)
26 | num_eval_examples = config['num_eval_examples']
27 | eval_batch_size = config['eval_batch_size']
28 | eval_on_cpu = config['eval_on_cpu']
29 |
30 | model_dir = config['model_dir']
31 |
32 | # Set up the data, hyperparameters, and the model
33 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
34 |
35 | if eval_on_cpu:
36 | with tf.device("/cpu:0"):
37 | model = Model()
38 | attack = LinfPGDAttack(model,
39 | config['epsilon'],
40 | config['k'],
41 | config['a'],
42 | config['random_start'],
43 | config['loss_func'])
44 | else:
45 | model = Model()
46 | attack = LinfPGDAttack(model,
47 | config['epsilon'],
48 | config['k'],
49 | config['a'],
50 | config['random_start'],
51 | config['loss_func'])
52 |
53 | global_step = tf.contrib.framework.get_or_create_global_step()
54 |
55 | # Setting up the Tensorboard and checkpoint outputs
56 | if not os.path.exists(model_dir):
57 | os.makedirs(model_dir)
58 | eval_dir = os.path.join(model_dir, 'eval')
59 | if not os.path.exists(eval_dir):
60 | os.makedirs(eval_dir)
61 |
62 | last_checkpoint_filename = ''
63 | already_seen_state = False
64 |
65 | saver = tf.train.Saver()
66 | summary_writer = tf.summary.FileWriter(eval_dir)
67 |
68 | # A function for evaluating a single checkpoint
69 | def evaluate_checkpoint(filename):
70 | with tf.Session() as sess:
71 | # Restore the checkpoint
72 | saver.restore(sess, filename)
73 |
74 | # Iterate over the samples batch-by-batch
75 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
76 | total_xent_nat = 0.
77 | total_xent_adv = 0.
78 | total_corr_nat = 0
79 | total_corr_adv = 0
80 |
81 | for ibatch in range(num_batches):
82 | bstart = ibatch * eval_batch_size
83 | bend = min(bstart + eval_batch_size, num_eval_examples)
84 |
85 | x_batch = mnist.test.images[bstart:bend, :]
86 | y_batch = mnist.test.labels[bstart:bend]
87 |
88 | dict_nat = {model.x_input: x_batch,
89 | model.y_input: y_batch}
90 |
91 | x_batch_adv = attack.perturb(x_batch, y_batch, sess)
92 |
93 | dict_adv = {model.x_input: x_batch_adv,
94 | model.y_input: y_batch}
95 |
96 | cur_corr_nat, cur_xent_nat = sess.run(
97 | [model.num_correct,model.xent],
98 | feed_dict = dict_nat)
99 | cur_corr_adv, cur_xent_adv = sess.run(
100 | [model.num_correct,model.xent],
101 | feed_dict = dict_adv)
102 |
103 | total_xent_nat += cur_xent_nat
104 | total_xent_adv += cur_xent_adv
105 | total_corr_nat += cur_corr_nat
106 | total_corr_adv += cur_corr_adv
107 |
108 | avg_xent_nat = total_xent_nat / num_eval_examples
109 | avg_xent_adv = total_xent_adv / num_eval_examples
110 | acc_nat = total_corr_nat / num_eval_examples
111 | acc_adv = total_corr_adv / num_eval_examples
112 |
113 | summary = tf.Summary(value=[
114 | tf.Summary.Value(tag='xent adv eval', simple_value= avg_xent_adv),
115 | tf.Summary.Value(tag='xent adv', simple_value= avg_xent_adv),
116 | tf.Summary.Value(tag='xent nat', simple_value= avg_xent_nat),
117 | tf.Summary.Value(tag='accuracy adv eval', simple_value= acc_adv),
118 | tf.Summary.Value(tag='accuracy adv', simple_value= acc_adv),
119 | tf.Summary.Value(tag='accuracy nat', simple_value= acc_nat)])
120 | summary_writer.add_summary(summary, global_step.eval(sess))
121 |
122 | print('natural: {:.2f}%'.format(100 * acc_nat))
123 | print('adversarial: {:.2f}%'.format(100 * acc_adv))
124 | print('avg nat loss: {:.4f}'.format(avg_xent_nat))
125 | print('avg adv loss: {:.4f}'.format(avg_xent_adv))
126 |
127 | # Infinite eval loop
128 | while True:
129 | cur_checkpoint = tf.train.latest_checkpoint(model_dir)
130 |
131 | # Case 1: No checkpoint yet
132 | if cur_checkpoint is None:
133 | if not already_seen_state:
134 | print('No checkpoint yet, waiting ...', end='')
135 | already_seen_state = True
136 | else:
137 | print('.', end='')
138 | sys.stdout.flush()
139 | time.sleep(10)
140 | # Case 2: Previously unseen checkpoint
141 | elif cur_checkpoint != last_checkpoint_filename:
142 | print('\nCheckpoint {}, evaluating ... ({})'.format(cur_checkpoint,
143 | datetime.now()))
144 | sys.stdout.flush()
145 | last_checkpoint_filename = cur_checkpoint
146 | already_seen_state = False
147 | evaluate_checkpoint(cur_checkpoint)
148 | # Case 3: Previously evaluated checkpoint
149 | else:
150 | if not already_seen_state:
151 | print('Waiting for the next checkpoint ... ({}) '.format(
152 | datetime.now()),
153 | end='')
154 | already_seen_state = True
155 | else:
156 | print('.', end='')
157 | sys.stdout.flush()
158 | time.sleep(10)
159 |
--------------------------------------------------------------------------------
/madry_mnist/fetch_model.py:
--------------------------------------------------------------------------------
1 | """Downloads a model, computes its SHA256 hash and unzips it
2 | at the proper location."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import sys
8 | import zipfile
9 | import hashlib
10 |
11 | if len(sys.argv) != 2 or sys.argv[1] not in ['natural',
12 | 'adv_trained',
13 | 'secret']:
14 | print('Usage: python fetch_model.py [natural, adv_trained, secret]')
15 | sys.exit(1)
16 |
17 | if sys.argv[1] == 'natural':
18 | url = 'https://github.com/MadryLab/mnist_challenge_models/raw/master/natural.zip'
19 | elif sys.argv[1] == 'secret':
20 | url = 'https://github.com/MadryLab/mnist_challenge_models/raw/master/secret.zip'
21 | else: # fetch adv_trained model
22 | url = 'https://github.com/MadryLab/mnist_challenge_models/raw/master/adv_trained.zip'
23 |
24 | fname = url.split('/')[-1] # get the name of the file
25 |
26 | # model download
27 | print('Downloading models')
28 | if sys.version_info >= (3,):
29 | import urllib.request
30 | urllib.request.urlretrieve(url, fname)
31 | else:
32 | import urllib
33 | urllib.urlretrieve(url, fname)
34 |
35 | # computing model hash
36 | sha256 = hashlib.sha256()
37 | with open(fname, 'rb') as f:
38 | data = f.read()
39 | sha256.update(data)
40 | print('SHA256 hash: {}'.format(sha256.hexdigest()))
41 |
42 | # extracting model
43 | print('Extracting model')
44 | with zipfile.ZipFile(fname, 'r') as model_zip:
45 | model_zip.extractall()
46 | print('Extracted model in {}'.format(model_zip.namelist()[0]))
47 |
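48 | # Note: the SHA256 digest above is printed but never checked. A hypothetical
49 | # verification against a published known-good digest could look like:
50 | #
51 | #   expected = '...'  # placeholder: the published SHA256 of the chosen zip
52 | #   assert sha256.hexdigest() == expected, 'download appears corrupted'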
--------------------------------------------------------------------------------
/madry_mnist/model.py:
--------------------------------------------------------------------------------
1 | """
2 | The model is adapted from the tensorflow tutorial:
3 | https://www.tensorflow.org/get_started/mnist/pros
4 | """
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import tensorflow as tf
10 |
11 | class Model(object):
12 | def __init__(self):
13 | self.x_input = tf.placeholder(tf.float32, shape = [None, 784])
14 | self.y_input = tf.placeholder(tf.int64, shape = [None])
15 |
16 | self.x_image = tf.reshape(self.x_input, [-1, 28, 28, 1])
17 |
18 | # first convolutional layer
19 | W_conv1 = self._weight_variable([5,5,1,32])
20 | b_conv1 = self._bias_variable([32])
21 |
22 | h_conv1 = tf.nn.relu(self._conv2d(self.x_image, W_conv1) + b_conv1)
23 | h_pool1 = self._max_pool_2x2(h_conv1)
24 |
25 | # second convolutional layer
26 | W_conv2 = self._weight_variable([5,5,32,64])
27 | b_conv2 = self._bias_variable([64])
28 |
29 | h_conv2 = tf.nn.relu(self._conv2d(h_pool1, W_conv2) + b_conv2)
30 | h_pool2 = self._max_pool_2x2(h_conv2)
31 |
32 | # first fully connected layer
33 | W_fc1 = self._weight_variable([7 * 7 * 64, 1024])
34 | b_fc1 = self._bias_variable([1024])
35 |
36 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
37 | h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
38 |
39 | # output layer
40 | W_fc2 = self._weight_variable([1024,10])
41 | b_fc2 = self._bias_variable([10])
42 |
43 | self.pre_softmax = tf.matmul(h_fc1, W_fc2) + b_fc2
44 |
45 | y_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
46 | labels=self.y_input, logits=self.pre_softmax)
47 |
48 | self.xent_per_point = y_xent
49 | self.xent = tf.reduce_sum(y_xent)
50 |
51 | self.y_pred = tf.argmax(self.pre_softmax, 1)
52 |
53 | correct_prediction = tf.equal(self.y_pred, self.y_input)
54 |
55 | self.num_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.int64))
56 | self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
57 |
58 | @staticmethod
59 | def _weight_variable(shape):
60 | initial = tf.truncated_normal(shape, stddev=0.1)
61 | return tf.Variable(initial)
62 |
63 | @staticmethod
64 | def _bias_variable(shape):
65 | initial = tf.constant(0.1, shape = shape)
66 | return tf.Variable(initial)
67 |
68 | @staticmethod
69 | def _conv2d(x, W):
70 | return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
71 |
72 | @staticmethod
73 | def _max_pool_2x2( x):
74 | return tf.nn.max_pool(x,
75 | ksize = [1,2,2,1],
76 | strides=[1,2,2,1],
77 | padding='SAME')
78 |
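79 | # Shape sketch for the graph above: with SAME padding, each _max_pool_2x2 halves
80 | # the spatial resolution, so the 28x28 input becomes 14x14 after h_pool1 and
81 | # 7x7 after h_pool2; with 64 channels this gives the 7 * 7 * 64 = 3136 inputs
82 | # expected by W_fc1.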
--------------------------------------------------------------------------------
/madry_mnist/run_attack.py:
--------------------------------------------------------------------------------
1 | """Evaluates a model against examples from a .npy file as specified
2 | in config.json"""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from datetime import datetime
8 | import json
9 | import math
10 | import os
11 | import sys
12 | import time
13 |
14 | import tensorflow as tf
15 | from tensorflow.examples.tutorials.mnist import input_data
16 |
17 | import numpy as np
18 |
19 | from model import Model
20 |
21 |
22 | def run_attack(checkpoint, x_adv, epsilon):
23 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
24 |
25 | model = Model()
26 |
27 | saver = tf.train.Saver()
28 |
29 | num_eval_examples = 10000
30 | eval_batch_size = 64
31 |
32 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
33 | total_corr = 0
34 |
35 | x_nat = mnist.test.images
36 | l_inf = np.amax(np.abs(x_nat - x_adv))
37 |
38 | if l_inf > epsilon + 0.0001:
39 | print('maximum perturbation found: {}'.format(l_inf))
40 | print('maximum perturbation allowed: {}'.format(epsilon))
41 | return
42 |
43 | y_pred = [] # label accumulator
44 |
45 | with tf.Session() as sess:
46 | # Restore the checkpoint
47 | saver.restore(sess, checkpoint)
48 |
49 | # Iterate over the samples batch-by-batch
50 | for ibatch in range(num_batches):
51 | bstart = ibatch * eval_batch_size
52 | bend = min(bstart + eval_batch_size, num_eval_examples)
53 |
54 | x_batch = x_adv[bstart:bend, :]
55 | y_batch = mnist.test.labels[bstart:bend]
56 |
57 | dict_adv = {model.x_input: x_batch,
58 | model.y_input: y_batch}
59 | cur_corr, y_pred_batch = sess.run([model.num_correct, model.y_pred],
60 | feed_dict=dict_adv)
61 |
62 | total_corr += cur_corr
63 | y_pred.append(y_pred_batch)
64 |
65 | accuracy = total_corr / num_eval_examples
66 |
67 | print('Accuracy: {:.2f}%'.format(100.0 * accuracy))
68 | y_pred = np.concatenate(y_pred, axis=0)
69 | np.save('pred.npy', y_pred)
70 | print('Output saved at pred.npy')
71 |
72 |
73 | if __name__ == '__main__':
74 | import json
75 |
76 | with open('config.json') as config_file:
77 | config = json.load(config_file)
78 |
79 | model_dir = config['model_dir']
80 |
81 | checkpoint = tf.train.latest_checkpoint(model_dir)
82 | x_adv = np.load(config['store_adv_path'])
83 |
84 | if checkpoint is None:
85 | print('No checkpoint found')
86 | elif x_adv.shape != (10000, 784):
87 | print('Invalid shape: expected (10000,784), found {}'.format(x_adv.shape))
88 | elif np.amax(x_adv) > 1.0001 or \
89 | np.amin(x_adv) < -0.0001 or \
90 | np.isnan(np.amax(x_adv)):
91 | print('Invalid pixel range. Expected [0, 1], found [{}, {}]'.format(
92 | np.amin(x_adv),
93 | np.amax(x_adv)))
94 | else:
95 | run_attack(checkpoint, x_adv, config['epsilon'])
96 |
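97 | # A minimal sketch of producing a file that passes the checks above (a
98 | # hypothetical smoke test, not an actual attack): uniform noise of magnitude
99 | # config['epsilon'] around the natural test images has the right shape,
100 | # stays in [0, 1] after clipping, and respects the l_inf budget.
101 | #
102 | #   mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
103 | #   eps = config['epsilon']
104 | #   noise = np.random.uniform(-eps, eps, size=(10000, 784))
105 | #   x_adv = np.clip(mnist.test.images + noise, 0., 1.).astype(np.float32)
106 | #   np.save(config['store_adv_path'], x_adv)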
--------------------------------------------------------------------------------
/madry_mnist/train.py:
--------------------------------------------------------------------------------
1 | """Trains a model, saving checkpoints and tensorboard summaries along
2 | the way."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from datetime import datetime
8 | import json
9 | import os
10 | import shutil
11 | from timeit import default_timer as timer
12 |
13 | import tensorflow as tf
14 | import numpy as np
15 | from tensorflow.examples.tutorials.mnist import input_data
16 |
17 | from model import Model
18 | from attack import LinfPGDAttack
19 |
20 | with open('config.json') as config_file:
21 | config = json.load(config_file)
22 |
23 | # Setting up training parameters
24 | tf.set_random_seed(config['random_seed'])
25 |
26 | max_num_training_steps = config['max_num_training_steps']
27 | num_output_steps = config['num_output_steps']
28 | num_summary_steps = config['num_summary_steps']
29 | num_checkpoint_steps = config['num_checkpoint_steps']
30 |
31 | batch_size = config['training_batch_size']
32 |
33 | # Setting up the data and the model
34 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
35 | global_step = tf.contrib.framework.get_or_create_global_step()
36 | model = Model()
37 |
38 | # Setting up the optimizer
39 | train_step = tf.train.AdamOptimizer(1e-4).minimize(model.xent,
40 | global_step=global_step)
41 |
42 | # Set up adversary
43 | attack = LinfPGDAttack(model,
44 | config['epsilon'],
45 | config['k'],
46 | config['a'],
47 | config['random_start'],
48 | config['loss_func'])
49 |
50 | # Setting up the Tensorboard and checkpoint outputs
51 | model_dir = config['model_dir']
52 | if not os.path.exists(model_dir):
53 | os.makedirs(model_dir)
54 |
55 | # We add accuracy and xent twice so we can easily make three types of
56 | # comparisons in Tensorboard:
57 | # - train vs eval (for a single run)
58 | # - train of different runs
59 | # - eval of different runs
60 |
61 | saver = tf.train.Saver(max_to_keep=3)
62 | tf.summary.scalar('accuracy adv train', model.accuracy)
63 | tf.summary.scalar('accuracy adv', model.accuracy)
64 | tf.summary.scalar('xent adv train', model.xent / batch_size)
65 | tf.summary.scalar('xent adv', model.xent / batch_size)
66 | tf.summary.image('images adv train', model.x_image)
67 | merged_summaries = tf.summary.merge_all()
68 |
69 | shutil.copy('config.json', model_dir)
70 |
71 | with tf.Session() as sess:
72 | # Initialize the summary writer, global variables, and our time counter.
73 | summary_writer = tf.summary.FileWriter(model_dir, sess.graph)
74 | sess.run(tf.global_variables_initializer())
75 | training_time = 0.0
76 |
77 | # Main training loop
78 | for ii in range(max_num_training_steps):
79 | x_batch, y_batch = mnist.train.next_batch(batch_size)
80 |
81 | # Compute Adversarial Perturbations
82 | start = timer()
83 | x_batch_adv = attack.perturb(x_batch, y_batch, sess)
84 | end = timer()
85 | training_time += end - start
86 |
87 | nat_dict = {model.x_input: x_batch,
88 | model.y_input: y_batch}
89 |
90 | adv_dict = {model.x_input: x_batch_adv,
91 | model.y_input: y_batch}
92 |
93 | # Output to stdout
94 | if ii % num_output_steps == 0:
95 | nat_acc = sess.run(model.accuracy, feed_dict=nat_dict)
96 | adv_acc = sess.run(model.accuracy, feed_dict=adv_dict)
97 | print('Step {}: ({})'.format(ii, datetime.now()))
98 | print(' training nat accuracy {:.4}%'.format(nat_acc * 100))
99 | print(' training adv accuracy {:.4}%'.format(adv_acc * 100))
100 | if ii != 0:
101 | print(' {} examples per second'.format(
102 | num_output_steps * batch_size / training_time))
103 | training_time = 0.0
104 | # Tensorboard summaries
105 | if ii % num_summary_steps == 0:
106 | summary = sess.run(merged_summaries, feed_dict=adv_dict)
107 | summary_writer.add_summary(summary, global_step.eval(sess))
108 |
109 | # Write a checkpoint
110 | if ii % num_checkpoint_steps == 0:
111 | saver.save(sess,
112 | os.path.join(model_dir, 'checkpoint'),
113 | global_step=global_step)
114 |
115 | # Actual training step
116 | start = timer()
117 | sess.run(train_step, feed_dict=adv_dict)
118 | end = timer()
119 | training_time += end - start
120 |
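121 | # Note on the loop above: the parameter update at the end is computed only on
122 | # the adversarial batch (adv_dict), i.e. plain adversarial training; the clean
123 | # batch (nat_dict) is used solely for the accuracy printout.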
--------------------------------------------------------------------------------
/metrics/2019-11-10 15:57:14 model=pt_inception dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/metrics/2019-11-10 15:57:14 model=pt_inception dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy
--------------------------------------------------------------------------------
/metrics/2019-11-10 15:57:14 model=pt_resnet dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/metrics/2019-11-10 15:57:14 model=pt_resnet dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy
--------------------------------------------------------------------------------
/metrics/2019-11-10 15:57:14 model=pt_vgg dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/metrics/2019-11-10 15:57:14 model=pt_vgg dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy
--------------------------------------------------------------------------------
/metrics/square_l2_inceptionv3_queries.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/metrics/square_l2_inceptionv3_queries.npy
--------------------------------------------------------------------------------
/metrics/square_l2_resnet50_queries.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/metrics/square_l2_resnet50_queries.npy
--------------------------------------------------------------------------------
/metrics/square_l2_vgg16_queries.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/metrics/square_l2_vgg16_queries.npy
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tensorflow as tf
3 | import numpy as np
4 | import math
5 | import utils
6 | from torchvision import models as torch_models
7 | from torch.nn import DataParallel
8 | from madry_mnist.model import Model as madry_model_mnist
9 | from madry_cifar10.model import Model as madry_model_cifar10
10 | from logit_pairing.models import LeNet as lp_model_mnist, ResNet20_v2 as lp_model_cifar10
11 | from post_avg.postAveragedModels import pa_resnet110_config1 as post_avg_cifar10_resnet
12 | from post_avg.postAveragedModels import pa_resnet152_config1 as post_avg_imagenet_resnet
13 |
14 |
15 | class Model:
16 | def __init__(self, batch_size, gpu_memory):
17 | self.batch_size = batch_size
18 | self.gpu_memory = gpu_memory
19 |
20 | def predict(self, x):
21 | raise NotImplementedError('use ModelTF or ModelPT')
22 |
23 | def loss(self, y, logits, targeted=False, loss_type='margin_loss'):
24 |         """ Implements the margin loss (the gap between the correct class and the runner-up) or the cross-entropy loss; a worked example is given at the end of this file. """
25 | if loss_type == 'margin_loss':
26 | preds_correct_class = (logits * y).sum(1, keepdims=True)
27 | diff = preds_correct_class - logits # difference between the correct class and all other classes
28 | diff[y] = np.inf # to exclude zeros coming from f_correct - f_correct
29 | margin = diff.min(1, keepdims=True)
30 | loss = margin * -1 if targeted else margin
31 | elif loss_type == 'cross_entropy':
32 | probs = utils.softmax(logits)
33 | loss = -np.log(probs[y])
34 | loss = loss * -1 if not targeted else loss
35 | else:
36 | raise ValueError('Wrong loss.')
37 | return loss.flatten()
38 |
39 |
40 | class ModelTF(Model):
41 | """
42 | Wrapper class around TensorFlow models.
43 |
44 | In order to incorporate a new model, one has to ensure that self.model has a TF variable `logits`,
45 |     and that the preprocessing of the inputs is done correctly (e.g. subtracting the mean and dividing by the
46 | standard deviation).
47 | """
48 | def __init__(self, model_name, batch_size, gpu_memory):
49 | super().__init__(batch_size, gpu_memory)
50 | model_folder = model_path_dict[model_name]
51 | model_file = tf.train.latest_checkpoint(model_folder)
52 | self.model = model_class_dict[model_name]()
53 | self.batch_size = batch_size
54 | self.model_name = model_name
55 | self.model_file = model_file
56 | if 'logits' not in self.model.__dict__:
57 | self.model.logits = self.model.pre_softmax
58 |
59 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory)
60 | config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
61 | self.sess = tf.Session(config=config)
62 | tf.train.Saver().restore(self.sess, model_file)
63 |
64 | def predict(self, x):
65 | if 'mnist' in self.model_name:
66 | shape = self.model.x_input.shape[1:].as_list()
67 | x = np.reshape(x, [-1, *shape])
68 | elif 'cifar10' in self.model_name:
69 | x = np.transpose(x, axes=[0, 2, 3, 1])
70 |
71 | n_batches = math.ceil(x.shape[0] / self.batch_size)
72 | logits_list = []
73 | for i in range(n_batches):
74 | x_batch = x[i*self.batch_size:(i+1)*self.batch_size]
75 | logits = self.sess.run(self.model.logits, feed_dict={self.model.x_input: x_batch})
76 | logits_list.append(logits)
77 | logits = np.vstack(logits_list)
78 | return logits
79 |
80 |
81 | class ModelPT(Model):
82 | """
83 | Wrapper class around PyTorch models.
84 |
85 | In order to incorporate a new model, one has to ensure that self.model is a callable object that returns logits,
86 |     and that the preprocessing of the inputs is done correctly (e.g. subtracting the mean and dividing by the
87 | standard deviation).
88 | """
89 | def __init__(self, model_name, batch_size, gpu_memory):
90 | super().__init__(batch_size, gpu_memory)
91 | if model_name in ['pt_vgg', 'pt_resnet', 'pt_inception', 'pt_densenet']:
92 | model = model_class_dict[model_name](pretrained=True)
93 | self.mean = np.reshape([0.485, 0.456, 0.406], [1, 3, 1, 1])
94 | self.std = np.reshape([0.229, 0.224, 0.225], [1, 3, 1, 1])
95 | model = DataParallel(model.cuda())
96 | else:
97 | model = model_class_dict[model_name]()
98 | if model_name in ['pt_post_avg_cifar10', 'pt_post_avg_imagenet']:
99 | # checkpoint = torch.load(model_path_dict[model_name])
100 | self.mean = np.reshape([0.485, 0.456, 0.406], [1, 3, 1, 1])
101 | self.std = np.reshape([0.229, 0.224, 0.225], [1, 3, 1, 1])
102 | else:
103 | model = DataParallel(model).cuda()
104 | checkpoint = torch.load(model_path_dict[model_name] + '.pth')
105 | self.mean = np.reshape([0.485, 0.456, 0.406], [1, 3, 1, 1])
106 | self.std = np.reshape([0.225, 0.225, 0.225], [1, 3, 1, 1])
107 | model.load_state_dict(checkpoint)
108 | model.float()
109 | self.mean, self.std = self.mean.astype(np.float32), self.std.astype(np.float32)
110 |
111 | model.eval()
112 | self.model = model
113 |
114 | def predict(self, x):
115 | x = (x - self.mean) / self.std
116 | x = x.astype(np.float32)
117 |
118 | n_batches = math.ceil(x.shape[0] / self.batch_size)
119 | logits_list = []
120 | with torch.no_grad(): # otherwise consumes too much memory and leads to a slowdown
121 | for i in range(n_batches):
122 | x_batch = x[i*self.batch_size:(i+1)*self.batch_size]
123 | x_batch_torch = torch.as_tensor(x_batch, device=torch.device('cuda'))
124 | logits = self.model(x_batch_torch).cpu().numpy()
125 | logits_list.append(logits)
126 | logits = np.vstack(logits_list)
127 | return logits
128 |
129 |
130 | model_path_dict = {'madry_mnist_robust': 'madry_mnist/models/robust',
131 | 'madry_cifar10_robust': 'madry_cifar10/models/robust',
132 | 'clp_mnist': 'logit_pairing/models/clp_mnist',
133 | 'lsq_mnist': 'logit_pairing/models/lsq_mnist',
134 | 'clp_cifar10': 'logit_pairing/models/clp_cifar10',
135 | 'lsq_cifar10': 'logit_pairing/models/lsq_cifar10',
136 | 'pt_post_avg_cifar10': 'post_avg/trainedModel/resnet110.th'
137 | }
138 | model_class_dict = {'pt_vgg': torch_models.vgg16_bn,
139 | 'pt_resnet': torch_models.resnet50,
140 | 'pt_inception': torch_models.inception_v3,
141 | 'pt_densenet': torch_models.densenet121,
142 | 'madry_mnist_robust': madry_model_mnist,
143 | 'madry_cifar10_robust': madry_model_cifar10,
144 | 'clp_mnist': lp_model_mnist,
145 | 'lsq_mnist': lp_model_mnist,
146 | 'clp_cifar10': lp_model_cifar10,
147 | 'lsq_cifar10': lp_model_cifar10,
148 | 'pt_post_avg_cifar10': post_avg_cifar10_resnet,
149 | 'pt_post_avg_imagenet': post_avg_imagenet_resnet,
150 | }
151 | all_model_names = list(model_class_dict.keys())
152 |
153 |
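154 | # Worked example for Model.loss with loss_type='margin_loss': for logits
155 | # [3.0, 1.0, 0.5] and a one-hot boolean label y = [True, False, False]
156 | # (correct class 0), diff = [0.0, 2.0, 2.5]; masking the correct class with
157 | # np.inf leaves margin = 2.0, i.e. the point is still correctly classified
158 | # with a margin of 2.0. An untargeted attack succeeds once the margin < 0.
159 | #
160 | # Minimal usage sketch (assumes a CUDA device and torchvision's pretrained
161 | # ImageNet weights; x is a numpy NCHW float array with values in [0, 1]):
162 | #
163 | #   model = ModelPT('pt_resnet', batch_size=64, gpu_memory=0.5)
164 | #   logits = model.predict(x)
165 | #   loss = model.loss(y_onehot, logits, targeted=False, loss_type='margin_loss')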
--------------------------------------------------------------------------------
/post_avg/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) [2019] [Yuping Lin]
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/post_avg/PADefense.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import time
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.utils.data as data
8 | import torch.cuda as cuda
9 | import torchvision.transforms as transforms
10 | import torchvision.utils as utl
11 | import torch.backends.cudnn as cudnn
12 |
13 | import torchvision.datasets as datasets
14 | import torchvision.models as mdl
15 |
16 | def checkEntropy(scores):
17 | scores = scores.squeeze()
18 | scr = scores.clone()
19 | scr[scr <= 0] = 1.0
20 | return - torch.sum(scores * torch.log(scr))
21 |
22 |
23 | def checkConfidence(scores, K=10):
24 | scores = scores.squeeze()
25 | hScores, _ = torch.sort(scores, dim=0, descending=True)
26 |
27 | return hScores[0] / torch.sum(hScores[:K])
28 |
29 |
30 | def integratedForward(model, sps, batchSize, nClasses, device='cpu', voteMethod='avg_softmax'):
31 | N = sps.size(0)
32 | feats = torch.empty(N, nClasses)
33 | model = model.to(device)
34 |
35 | with torch.no_grad():
36 | baseInx = 0
37 | while baseInx < N:
38 | cuda.empty_cache()
39 | endInx = min(baseInx + batchSize, N)
40 | y = model(sps[baseInx:endInx, :].to(device)).detach().to('cpu')
41 | feats[baseInx:endInx, :] = y
42 | baseInx = endInx
43 |
44 | if voteMethod == 'avg_feat':
45 | feat = torch.mean(feats, dim=0, keepdim=True)
46 | elif voteMethod == 'most_vote':
47 | maxV, _ = torch.max(feats, dim=1, keepdim=True)
48 | feat = torch.sum(feats == maxV, dim=0, keepdim=True)
49 | elif voteMethod == 'weighted_feat':
50 | feat = torch.mean(feats, dim=0, keepdim=True)
51 | maxV, _ = torch.max(feats, dim=1, keepdim=True)
52 | feat = feat * torch.sum(feats == maxV, dim=0, keepdim=True).float()
53 | elif voteMethod == 'avg_softmax':
54 | feats = nn.functional.softmax(feats, dim=1)
55 | feat = torch.mean(feats, dim=0, keepdim=True)
56 | else:
57 | # default method: avg_softmax
58 | feats = nn.functional.softmax(feats, dim=1)
59 | feat = torch.mean(feats, dim=0, keepdim=True)
60 |
61 | return feat, feats
62 |
63 | # not updated, deprecated
64 | def integratedForward_cls(model, sps, batchSize, nClasses, device='cpu', count_votes=False):
65 | N = sps.size(0)
66 | feats = torch.empty(N, nClasses)
67 | model = model.to(device)
68 |
69 | with torch.no_grad():
70 | baseInx = 0
71 | while baseInx < N:
72 | cuda.empty_cache()
73 | endInx = min(baseInx + batchSize, N)
74 | y = model.classifier(sps[baseInx:endInx, :].to(device)).detach().to('cpu')
75 | feats[baseInx:endInx, :] = y
76 | baseInx = endInx
77 |
78 | if count_votes:
79 | maxV, _ = torch.max(feats, dim=1, keepdim=True)
80 | feat = torch.sum(feats == maxV, dim=0, keepdim=True)
81 | else:
82 | feat = torch.mean(feats, dim=0, keepdim=True)
83 |
84 | return feat, feats
85 |
86 |
87 | def findNeighbors_random(sp, K, r=[2], direction='both'):
88 | # only accept single sample
89 | if sp.size(0) != 1:
90 | return None
91 |
92 | if isinstance(K, list):
93 | K = sum(K)
94 |
95 | # randomly select directions
96 | shifts = torch.randn(K, sp.size(1) * sp.size(2) * sp.size(3)).to('cuda')
97 | shifts = nn.functional.normalize(shifts, p=2, dim=1)
98 | shifts = shifts.view(K, sp.size(1), sp.size(2), sp.size(3)).contiguous()
99 |
100 | if direction == 'both':
101 | shifts = torch.cat([shifts, -shifts], dim=0)
102 |
103 | nbs = []
104 | for rInx in range(len(r)):
105 | nbs.append(sp.to('cuda') + r[rInx] * shifts)
106 |
107 | return torch.cat(nbs, dim=0)
108 |
109 |
110 | def findNeighbors_plain_vgg(model, sp, K, r=[2], direction='both', device='cpu'):
111 | # only accept single sample
112 | if sp.size(0) != 1:
113 | return None
114 |
115 | # storages for K selected distances and linear mapping
116 | selected_list = []
117 |
118 | # set model to evaluation mode
119 | model = model.to(device)
120 | model = model.eval()
121 |
122 | # place holder for input, and set to require gradient
123 | x = sp.clone().to(device)
124 | x.requires_grad = True
125 |
126 | # forward through the feature part
127 | y = model.features(x)
128 | y = model.avgpool(y)
129 | y = y.view(y.size(0), -1)
130 |
131 | # forward through classifier layer by layer
132 | for lyInx, module in model.classifier.named_children():
133 | # forward
134 | y = module(y)
135 |
136 | # at each layer activation
137 | if isinstance(module, nn.Linear):
138 | # for each neuron
139 | for i in range(y.size(1)):
140 | # clear previous gradients
141 | x.grad = None
142 |
143 | # compute gradients
144 | goal = torch.abs(y[0, i])
145 | goal.backward(retain_graph=True) # retain graph for future computation
146 |
147 | # compute distance
148 | d = torch.abs(y[0, i]) / torch.norm(x.grad)
149 |
150 | # keep K shortest distances
151 | selected_list.append((d.clone().detach().to('cpu'), x.grad.clone().detach().to('cpu')))
152 | selected_list = sorted(selected_list, key=lambda x:x[0], reverse=False)
153 | selected_list = selected_list[0:K]
154 |
155 | # generate neighboring samples
156 | grad_list = [e[1] / torch.norm(e[1]) for e in selected_list]
157 | unit_shifts = torch.cat(grad_list, dim=0)
158 | nbs = []
159 | for rInx in range(len(r)):
160 | if direction == 'inc':
161 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts)
162 | elif direction == 'dec':
163 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts)
164 | else:
165 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts)
166 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts)
167 | nbs = torch.cat(nbs, dim=0)
168 | nbs = nbs.detach()
169 | nbs.requires_grad = False
170 |
171 | return nbs
172 |
173 |
174 | def findNeighbors_lastLy_vgg(model, sp, K, r=[2], direction='both', device='cpu'):
175 | # only accept single sample
176 | if sp.size(0) != 1:
177 | return None
178 |
179 | # storages for K selected distances and linear mapping
180 | selected_list = []
181 |
182 | # set model to evaluation mode
183 | model = model.to(device)
184 | model = model.eval()
185 |
186 | # place holder for input, and set to require gradient
187 | x = sp.clone().to(device)
188 | x.requires_grad = True
189 |
190 | # forward through the feature part
191 | y = model(x)
192 | y = y.view(y.size(0), -1)
193 |
194 | for i in range(y.size(1)):
195 | # clear previous gradients
196 | x.grad = None
197 |
198 | # compute gradients
199 | goal = torch.abs(y[0, i])
200 | if i < y.size(1) - 1:
201 | goal.backward(retain_graph=True) # retain graph for future computation
202 | else:
203 | goal.backward(retain_graph=False)
204 |
205 | # compute distance
206 | d = torch.abs(y[0, i]) / torch.norm(x.grad)
207 |
208 | # keep K shortest distances
209 | selected_list.append((d.clone().detach().to('cpu'), x.grad.clone().detach().to('cpu')))
210 | selected_list = sorted(selected_list, key=lambda x:x[0], reverse=False)
211 | selected_list = selected_list[0:K]
212 |
213 | # generate neighboring samples
214 | grad_list = [e[1] / torch.norm(e[1]) for e in selected_list]
215 | unit_shifts = torch.cat(grad_list, dim=0)
216 | nbs = []
217 | for rInx in range(len(r)):
218 | if direction == 'inc':
219 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts)
220 | elif direction == 'dec':
221 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts)
222 | else:
223 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts)
224 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts)
225 | nbs = torch.cat(nbs, dim=0)
226 | nbs = nbs.detach()
227 | nbs.requires_grad = False
228 |
229 | return nbs
230 |
231 |
232 | def findNeighbors_approx_vgg(model, sp, K, r=[2], direction='both', device='cpu'):
233 | # only accept single sample
234 | if sp.size(0) != 1:
235 | return None
236 |
237 | # storages for K selected distances and linear mapping
238 | selected_list = []
239 |
240 | # set model to evaluation mode
241 | model = model.to(device)
242 | model = model.eval()
243 |
244 | # place holder for input, and set to require gradient
245 | x = sp.clone().to(device)
246 | x.requires_grad = True
247 |
248 | # forward through the feature part
249 | y = model.features(x)
250 | y = model.avgpool(y)
251 | y = y.view(y.size(0), -1)
252 |
253 | # forward through classifier layer by layer
254 | lnLy_inx = 0
255 | for lyInx, module in model.classifier.named_children():
256 | # forward
257 | y = module(y)
258 |
259 | # at each layer activation
260 | if isinstance(module, nn.Linear):
261 | KInx = min(lnLy_inx, len(K)-1)
262 | if K[KInx] > 0:
263 | with torch.no_grad():
264 | # compute weight norm
265 | w_norm = torch.norm(module.weight, dim=1, keepdim=True)
266 |
267 | # compute distance
268 | d = torch.abs(y) / w_norm.t()
269 | _, sortedInx = torch.sort(d, dim=1, descending=False)
270 |
271 | # for each selected neuron
272 | for i in range(K[KInx]):
273 |
274 | # clear previous gradients
275 | x.grad = None
276 |
277 | # compute gradients
278 | goal = torch.abs(y[0, sortedInx[0, i]])
279 | goal.backward(retain_graph=True) # retain graph for future computation
280 |
281 | # record gradients
282 | selected_list.append(x.grad.clone().detach().to('cpu') / torch.norm(x.grad).detach().to('cpu'))
283 |
284 | # update number of linear layer sampled
285 | lnLy_inx = lnLy_inx + 1
286 |
287 | # generate neighboring samples
288 | unit_shifts = torch.cat(selected_list, dim=0)
289 | nbs = []
290 | for rInx in range(len(r)):
291 | if direction == 'inc':
292 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts)
293 | elif direction == 'dec':
294 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts)
295 | else:
296 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts)
297 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts)
298 | nbs = torch.cat(nbs, dim=0)
299 | nbs = nbs.detach()
300 | nbs.requires_grad = False
301 |
302 | return nbs
303 |
304 |
305 | def findNeighbors_randPick_vgg(model, sp, K, r=[2], direction='both', device='cpu'):
306 | # only accept single sample
307 | if sp.size(0) != 1:
308 | return None
309 |
310 | # storages for K selected distances and linear mapping
311 | selected_list = []
312 |
313 | # set model to evaluation mode
314 | model = model.to(device)
315 | model = model.eval()
316 |
317 | # place holder for input, and set to require gradient
318 | x = sp.clone().to(device)
319 | x.requires_grad = True
320 |
321 | # forward through the feature part
322 | y = model.features(x)
323 | y = model.avgpool(y)
324 | y = y.view(y.size(0), -1)
325 |
326 | # forward through classifier layer by layer
327 | lnLy_inx = 0
328 | for lyInx, module in model.classifier.named_children():
329 | # forward
330 | y = module(y)
331 |
332 | # at each layer activation
333 | if isinstance(module, nn.Linear):
334 | KInx = min(lnLy_inx, len(K)-1)
335 | if K[KInx] > 0:
336 | # randomly permute indices
337 | pickInx = torch.randperm(y.size(1))
338 |
339 | # for each selected neuron
340 | for i in range(K[KInx]):
341 |
342 | # clear previous gradients
343 | x.grad = None
344 |
345 | # compute gradients
346 | goal = torch.abs(y[0, pickInx[i]])
347 | goal.backward(retain_graph=True) # retain graph for future computation
348 |
349 | # record gradients
350 | selected_list.append(x.grad.clone().detach().to('cpu') / torch.norm(x.grad).detach().to('cpu'))
351 |
352 | # update number of linear layer sampled
353 | lnLy_inx = lnLy_inx + 1
354 |
355 | # generate neighboring samples
356 | unit_shifts = torch.cat(selected_list, dim=0)
357 | nbs = []
358 | for rInx in range(len(r)):
359 | if direction == 'inc':
360 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts)
361 | elif direction == 'dec':
362 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts)
363 | else:
364 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts)
365 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts)
366 | nbs = torch.cat(nbs, dim=0)
367 | nbs = nbs.detach()
368 | nbs.requires_grad = False
369 |
370 | return nbs
371 |
372 | # not updated, deprecated
373 | def findNeighbors_feats_lastLy_vgg(model, sp, K, r=[2], direction='both', device='cpu', includeOriginal=True):
374 | # only accept single sample
375 | if sp.size(0) != 1:
376 | return None
377 |
378 | # storages for K selected distances and linear mapping
379 | selected_list = []
380 |
381 | # set model to evaluation mode
382 | model = model.to(device)
383 | model = model.eval()
384 |
385 | # forward through the feature part
386 | with torch.no_grad():
387 | feat = model.features(sp.to(device))
388 | feat = feat.view(feat.size(0), -1).contiguous().detach()
389 |
390 | # place holder for feature, and set to require gradient
391 | x = feat.clone().detach()
392 | x.requires_grad = True
393 |
394 | # forward through the classifier part
395 | y = model.classifier(x)
396 | y = y.view(y.size(0), -1)
397 |
398 | for i in range(y.size(1)):
399 | # clear previous gradients
400 | x.grad = None
401 |
402 | # compute gradients
403 | goal = torch.abs(y[0, i])
404 | if i < y.size(1) - 1:
405 | goal.backward(retain_graph=True) # retain graph for future computation
406 | else:
407 | goal.backward(retain_graph=False)
408 |
409 | # compute distance
410 | d = torch.abs(y[0, i]) / torch.norm(x.grad)
411 |
412 | # keep K shortest distances
413 | selected_list.append((d.clone().detach().to('cpu'), x.grad.clone().detach().to('cpu')))
414 | selected_list = sorted(selected_list, key=lambda x:x[0], reverse=False)
415 | selected_list = selected_list[0:K]
416 |
417 | # generate neighboring samples
418 | grad_list = [e[1] / torch.norm(e[1]) for e in selected_list]
419 | unit_shifts = torch.cat(grad_list, dim=0)
420 | if includeOriginal:
421 | nbs = [feat.to('cpu')]
422 | else:
423 | nbs = []
424 | for rInx in range(len(r)):
425 | if direction == 'inc':
426 | nbs.append(feat.to('cpu') + r[rInx] * unit_shifts)
427 | elif direction == 'dec':
428 | nbs.append(feat.to('cpu') - r[rInx] * unit_shifts)
429 | else:
430 | nbs.append(feat.to('cpu') + r[rInx] * unit_shifts)
431 | nbs.append(feat.to('cpu') - r[rInx] * unit_shifts)
432 | nbs = torch.cat(nbs, dim=0)
433 | nbs = nbs.detach()
434 | nbs.requires_grad = False
435 |
436 | return nbs
437 |
438 | # not updated, deprecated
439 | def findNeighbors_feats_approx_vgg(model, sp, K, r=[2], direction='both', device='cpu', includeOriginal=True):
440 | # only accept single sample
441 | if sp.size(0) != 1:
442 | return None
443 |
444 | # storages for K selected distances and linear mapping
445 | selected_list = []
446 |
447 | # set model to evaluation mode
448 | model = model.to(device)
449 | model = model.eval()
450 |
451 | # forward through the feature part
452 | with torch.no_grad():
453 | feat = model.features(sp.to(device))
454 | feat = feat.view(feat.size(0), -1).contiguous().detach()
455 |
456 | # place holder for feature, and set to require gradient
457 | x = feat.clone().detach()
458 | x.requires_grad = True
459 | y = x
460 |
461 | # forward through classifier layer by layer
462 | lnLy_inx = 0
463 | for lyInx, module in model.classifier.named_children():
464 | # forward
465 | y = module(y)
466 |
467 | # at each layer activation
468 | if isinstance(module, nn.Linear):
469 | KInx = min(lnLy_inx, len(K)-1)
470 | if K[KInx] > 0:
471 | with torch.no_grad():
472 | # compute weight norm
473 | w_norm = torch.norm(module.weight, dim=1, keepdim=True)
474 |
475 | # compute distance
476 | d = torch.abs(y) / w_norm.t()
477 | _, sortedInx = torch.sort(d, dim=1, descending=False)
478 |
479 | # for each selected neuron
480 | for i in range(K[KInx]):
481 |
482 | # clear previous gradients
483 | x.grad = None
484 |
485 | # compute gradients
486 | goal = torch.abs(y[0, sortedInx[0, i]])
487 | goal.backward(retain_graph=True) # retain graph for future computation
488 |
489 | # record gradients
490 | selected_list.append(x.grad.clone().detach().to('cpu') / torch.norm(x.grad).detach().to('cpu'))
491 |
492 | # update number of linear layer sampled
493 | lnLy_inx = lnLy_inx + 1
494 |
495 | # generate neighboring samples
496 | unit_shifts = torch.cat(selected_list, dim=0)
497 | if includeOriginal:
498 | nbs = [feat.to('cpu')]
499 | else:
500 | nbs = []
501 | for rInx in range(len(r)):
502 | if direction == 'inc':
503 | nbs.append(feat.to('cpu') + r[rInx] * unit_shifts)
504 | elif direction == 'dec':
505 | nbs.append(feat.to('cpu') - r[rInx] * unit_shifts)
506 | else:
507 | nbs.append(feat.to('cpu') + r[rInx] * unit_shifts)
508 | nbs.append(feat.to('cpu') - r[rInx] * unit_shifts)
509 | nbs = torch.cat(nbs, dim=0)
510 | nbs = nbs.detach()
511 | nbs.requires_grad = False
512 |
513 | return nbs
514 |
515 |
516 | def formSquad_vgg(method, model, sp, K, r=[2], direction='both', device='cpu', includeOriginal=True):
517 | if method == 'random':
518 | nbs = findNeighbors_random(sp, K, r, direction=direction)
519 | if includeOriginal:
520 | nbs = torch.cat([sp, nbs], dim=0)
521 | elif method == 'plain':
522 | nbs = findNeighbors_plain_vgg(model, sp, K, r, direction=direction, device=device)
523 | if includeOriginal:
524 | nbs = torch.cat([sp, nbs], dim=0)
525 | elif method == 'lastLy':
526 | nbs = findNeighbors_lastLy_vgg(model, sp, K, r, direction=direction, device=device)
527 | if includeOriginal:
528 | nbs = torch.cat([sp, nbs], dim=0)
529 | elif method == 'approx':
530 | nbs = findNeighbors_approx_vgg(model, sp, K, r, direction=direction, device=device)
531 | if includeOriginal:
532 | nbs = torch.cat([sp, nbs], dim=0)
533 | elif method == 'randPick':
534 | nbs = findNeighbors_randPick_vgg(model, sp, K, r, direction=direction, device=device)
535 | if includeOriginal:
536 | nbs = torch.cat([sp, nbs], dim=0)
537 | elif method == 'feats_lastLy':
538 | nbs = findNeighbors_feats_lastLy_vgg(model, sp, K, r, direction=direction, device=device, includeOriginal=includeOriginal)
539 | elif method == 'feats_approx':
540 | nbs = findNeighbors_feats_approx_vgg(model, sp, K, r, direction=direction, device=device, includeOriginal=includeOriginal)
541 | else:
542 | # if invalid method, use default setting. (actually should raise error here)
543 | nbs = findNeighbors_random(sp, K, r, direction=direction)
544 | if includeOriginal:
545 | nbs = torch.cat([sp, nbs], dim=0)
546 |
547 | return nbs
548 |
549 |
550 | def findNeighbors_approx_resnet(model, sp, K, r=[2], direction='both', device='cpu'):
551 | # only accept single sample
552 | if sp.size(0) != 1:
553 | return None
554 |
555 | # storages for K selected distances and linear mapping
556 | selected_list = []
557 |
558 | # set model to evaluation mode
559 | model = model.to(device)
560 | model = model.eval()
561 |
562 | # place holder for input, and set to require gradient
563 | x = sp.clone().to(device)
564 | x.requires_grad = True
565 |
566 | # forward through the model
567 | y = model(x)
568 | y = y.view(y.size(0), -1)
569 |
570 | if K > 0:
571 | with torch.no_grad():
572 | # compute weight norm
573 | w_norm = torch.norm(model.fc.weight, dim=1, keepdim=True)
574 |
575 | # compute distance
576 | d = torch.abs(y) / w_norm.t()
577 | _, sortedInx = torch.sort(d, dim=1, descending=False)
578 |
579 | # for each selected neuron
580 | for i in range(K):
581 |
582 | # clear previous gradients
583 | x.grad = None
584 |
585 | # compute gradients
586 | goal = torch.abs(y[0, sortedInx[0, i]])
587 | goal.backward(retain_graph=True) # retain graph for future computation
588 |
589 | # record gradients
590 | selected_list.append(x.grad.clone().detach().to('cpu') / torch.norm(x.grad).detach().to('cpu'))
591 |
592 | # generate neighboring samples
593 | unit_shifts = torch.cat(selected_list, dim=0)
594 | nbs = []
595 | for rInx in range(len(r)):
596 | if direction == 'inc':
597 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts)
598 | elif direction == 'dec':
599 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts)
600 | else:
601 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts)
602 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts)
603 | nbs = torch.cat(nbs, dim=0)
604 | nbs = nbs.detach()
605 | nbs.requires_grad = False
606 |
607 | return nbs
608 |
609 |
610 | def findNeighbors_approx_resnet_small(model, sp, K, r=[2], direction='both', device='cpu'):
611 | # only accept single sample
612 | if sp.size(0) != 1:
613 | return None
614 |
615 | # storages for K selected distances and linear mapping
616 | selected_list = []
617 |
618 | # set model to evaluation mode
619 | model = model.to(device)
620 | model = model.eval()
621 |
622 | # place holder for input, and set to require gradient
623 | x = sp.clone().to(device)
624 | x.requires_grad = True
625 |
626 | # forward through the model
627 | y = model(x)
628 | y = y.view(y.size(0), -1)
629 |
630 | if K > 0:
631 | with torch.no_grad():
632 | # compute weight norm
633 | w_norm = torch.norm(model.linear.weight, dim=1, keepdim=True)
634 |
635 | # compute distance
636 | d = torch.abs(y) / w_norm.t()
637 | _, sortedInx = torch.sort(d, dim=1, descending=False)
638 |
639 | # for each selected neuron
640 | for i in range(K):
641 |
642 | # clear previous gradients
643 | x.grad = None
644 |
645 | # compute gradients
646 | goal = torch.abs(y[0, sortedInx[0, i]])
647 | goal.backward(retain_graph=True) # retain graph for future computation
648 |
649 | # record gradients
650 | selected_list.append(x.grad.clone().detach().to('cpu') / torch.norm(x.grad).detach().to('cpu'))
651 |
652 | # generate neighboring samples
653 | unit_shifts = torch.cat(selected_list, dim=0)
654 | nbs = []
655 | for rInx in range(len(r)):
656 | if direction == 'inc':
657 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts)
658 | elif direction == 'dec':
659 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts)
660 | else:
661 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts)
662 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts)
663 | nbs = torch.cat(nbs, dim=0)
664 | nbs = nbs.detach()
665 | nbs.requires_grad = False
666 |
667 | return nbs
668 |
669 |
670 | def formSquad_resnet(method, model, sp, K, r=[2], direction='both', device='cpu', includeOriginal=True):
671 | if method == 'random':
672 | nbs = findNeighbors_random(sp, K, r, direction=direction)
673 | if includeOriginal:
674 | nbs = torch.cat([sp, nbs], dim=0)
675 | elif method == 'approx':
676 | nbs = findNeighbors_approx_resnet(model, sp, K, r, direction=direction, device=device)
677 | if includeOriginal:
678 | nbs = torch.cat([sp, nbs], dim=0)
679 | elif method == 'approx_cifar10':
680 | nbs = findNeighbors_approx_resnet_small(model, sp, K, r, direction=direction, device=device)
681 | if includeOriginal:
682 | nbs = torch.cat([sp, nbs], dim=0)
683 | else:
684 | # if invalid method, use default setting. (actually should raise error here)
685 | nbs = findNeighbors_random(sp, K, r, direction=direction)
686 | if includeOriginal:
687 | nbs = torch.cat([sp, nbs], dim=0)
688 |
689 | return nbs
690 |
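691 | # A sketch of how the pieces above fit together (hypothetical standalone use;
692 | # a GPU is assumed since findNeighbors_random allocates its shifts on 'cuda'):
693 | # given a single input sp of shape (1, C, H, W) on the GPU, formSquad_resnet
694 | # gathers a batch of shifted copies and integratedForward averages their
695 | # softmax scores into a single prediction:
696 | #
697 | #   squad = formSquad_resnet('random', model, sp, K=15, r=[2, 4, 6])
698 | #   scores, _ = integratedForward(model, squad, batchSize=100, nClasses=10,
699 | #                                 device='cuda', voteMethod='avg_softmax')
700 | #   prediction = scores.argmax(dim=1)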
--------------------------------------------------------------------------------
/post_avg/README.md:
--------------------------------------------------------------------------------
1 | # Post-Average Adversarial Defense
2 | Implementation of the Post-Average adversarial defense method as described in [Bandlimiting Neural Networks Against Adversarial Attacks](https://arxiv.org/abs/1905.12797).
3 |
4 | This implementation is based on PyTorch and uses the [Foolbox](https://github.com/bethgelab/foolbox) toolbox to provide attacking methods.
5 |
6 | ## [robustml](https://github.com/robust-ml/robustml) evaluation
7 | This implementation supports the robustml API for evaluation.
8 |
9 | To evaluate on CIFAR-10:
10 | ```
11 | python robustml_test_cifar10.py
12 | ```
13 |
14 | To evaluate on ImageNet:
15 | ```
16 | python robustml_test_imagenet.py
17 | ```
18 |
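19 | Both scripts wrap the defended network behind the robustml API. A minimal
20 | sketch of querying the CIFAR-10 defense directly (run from the repository
21 | root; assumes the pre-trained `post_avg/trainedModel/resnet110.th` checkpoint
22 | and a CUDA device):
23 | ```
24 | import torch
25 | from post_avg.postAveragedModels import pa_resnet110_config1
26 | 
27 | model = pa_resnet110_config1()   # K=15 neighbors, radius R=6, eps=8/255
28 | model.eval()
29 | x = torch.rand(3, 32, 32).cuda() # a CIFAR-10-sized image with values in [0, 1]
30 | scores = model.classify(x)       # softmax scores averaged over the neighborhood
31 | print(scores.argmax(dim=1))
32 | ```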
--------------------------------------------------------------------------------
/post_avg/attacks.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import robustml
4 | import numpy as np
5 |
6 | import foolbox.criteria as crt
7 | import foolbox.attacks as attacks
8 | import foolbox.distances as distances
9 | import foolbox.adversarial as adversarial
10 |
11 | class NullAttack(robustml.attack.Attack):
12 | def run(self, x, y, target):
13 | return x
14 |
15 | class FoolboxAttackWrapper(robustml.attack.Attack):
16 | def __init__(self, attack):
17 | self._attacker = attack
18 |
19 | def run(self, x, y, target):
20 | # model requires image in (C, H, W), but robustml provides (H, W, C)
21 | # transpose x to accommodate pytorch's axis arrangement convention
22 | x = np.transpose(x, (2, 0, 1))
23 | if target is not None:
24 | adv_criterion = crt.TargetClass(target)
25 | adv_obj = adversarial.Adversarial(self._attacker._default_model, adv_criterion, x, y, distance=self._attacker._default_distance)
26 | adv_x = self._attacker(adv_obj)
27 | else:
28 | adv_x = self._attacker(x, y)
29 |
30 | # transpose back to data provider's convention
31 | return np.transpose(adv_x, (1, 2, 0))
32 |
33 | def fgsmAttack(victim_model): # victim_model should be a model wrapped as a foolbox model
34 | attacker = attacks.GradientSignAttack(victim_model, crt.Misclassification())
35 | return FoolboxAttackWrapper(attacker)
36 |
37 | def pgdAttack(victim_model): # victim_model should be a model wrapped as a foolbox model
38 | attacker = attacks.RandomStartProjectedGradientDescentAttack(victim_model, crt.Misclassification(), distance=distances.Linfinity)
39 | return FoolboxAttackWrapper(attacker)
40 |
41 | def dfAttack(victim_model): # victim_model should be a model wrapped as a foolbox model
42 | attacker = attacks.DeepFoolAttack(victim_model, crt.Misclassification())
43 | return FoolboxAttackWrapper(attacker)
44 |
45 | def cwAttack(victim_model): # victim_model should be a model wrapped as a foolbox model
46 | attacker = attacks.CarliniWagnerL2Attack(victim_model, crt.Misclassification())
47 | return FoolboxAttackWrapper(attacker)
48 |
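49 | # Usage sketch (assumes victim_model is a foolbox PyTorchModel, as set up in
50 | # the robustml_test_*.py scripts): each helper returns a robustml-compatible
51 | # attack whose run() takes and returns (H, W, C) numpy images.
52 | #
53 | #   attack = pgdAttack(victim_model)
54 | #   x_adv = attack.run(x, y, target=None)  # x: (H, W, C) in [0, 1]; y: int label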
--------------------------------------------------------------------------------
/post_avg/postAveragedModels.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import robustml
4 | import numpy as np
5 | from collections import OrderedDict
6 | import post_avg.PADefense as padef
7 | import post_avg.resnetSmall as rnsmall
8 |
9 | import torch
10 | import torchvision.models as mdl
11 | import torchvision.transforms as transforms
12 |
13 | class PostAveragedResNet152(robustml.model.Model):
14 | def __init__(self, K, R, eps, device='cuda'):
15 | self._model = mdl.resnet152(pretrained=True).to(device)
16 | self._dataset = robustml.dataset.ImageNet((224, 224, 3))
17 | self._threat_model = robustml.threat_model.Linf(epsilon=eps)
18 | self._K = K
19 | self._r = [R/3, 2*R/3, R]
20 | self._sample_method = 'random'
21 | self._vote_method = 'avg_softmax'
22 | self._device = device
23 |
24 | @property
25 | def model(self):
26 | return self._model
27 |
28 | @property
29 | def dataset(self):
30 | return self._dataset
31 |
32 | @property
33 | def threat_model(self):
34 | return self._threat_model
35 |
36 | def classify(self, x):
37 | x = x.unsqueeze(0)
38 |
39 | # gather neighbor samples
40 | x_squad = padef.formSquad_resnet(self._sample_method, self._model, x, self._K, self._r, device=self._device)
41 |
42 | # forward with a batch of neighbors
43 | logits, _ = padef.integratedForward(self._model, x_squad, batchSize=100, nClasses=1000, device=self._device, voteMethod=self._vote_method)
44 |
45 | return torch.as_tensor(logits)
46 |
47 | def __call__(self, x):
48 | logits_list = []
49 | for img in x:
50 | logits = self.classify(img)
51 | logits_list.append(logits)
52 | return torch.cat(logits_list, dim=0)
53 |
54 | def _preprocess(self, image):
55 | # normalization used by pre-trained model
56 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
57 |
58 | return normalize(image)
59 |
60 | def to(self, device):
61 | self._model = self._model.to(device)
62 | self._device = device
63 |
64 | def eval(self):
65 | self._model = self._model.eval()
66 |
67 |
68 | def pa_resnet152_config1():
69 | return PostAveragedResNet152(K=15, R=30, eps=8/255)
70 |
71 |
72 | class PostAveragedResNet110(robustml.model.Model):
73 | def __init__(self, K, R, eps, device='cuda'):
74 | # load model state dict
75 | checkpoint = torch.load('post_avg/trainedModel/resnet110.th')
76 | paramDict = OrderedDict()
77 | for k, v in checkpoint['state_dict'].items():
78 |             # remove the 'module.' prefix introduced by DataParallel, if any;
79 |             # keys without the prefix are kept unchanged
80 |             paramDict[k[7:] if k.startswith('module.') else k] = v
81 | self._model = rnsmall.resnet110()
82 | self._model.load_state_dict(paramDict)
83 | self._model = self._model.to(device)
84 |
85 | self._dataset = robustml.dataset.CIFAR10()
86 | self._threat_model = robustml.threat_model.Linf(epsilon=eps)
87 | self._K = K
88 | self._r = [R/3, 2*R/3, R]
89 | self._sample_method = 'random'
90 | self._vote_method = 'avg_softmax'
91 | self._device = device
92 |
93 | @property
94 | def model(self):
95 | return self._model
96 |
97 | @property
98 | def dataset(self):
99 | return self._dataset
100 |
101 | @property
102 | def threat_model(self):
103 | return self._threat_model
104 |
105 | def classify(self, x):
106 | x = x.unsqueeze(0)
107 |
108 | # gather neighbor samples
109 | x_squad = padef.formSquad_resnet(self._sample_method, self._model, x, self._K, self._r, device=self._device)
110 |
111 | # forward with a batch of neighbors
112 | logits, _ = padef.integratedForward(self._model, x_squad, batchSize=1000, nClasses=10, device=self._device, voteMethod=self._vote_method)
113 |
114 | return torch.as_tensor(logits)
115 |
116 | def __call__(self, x):
117 | logits_list = []
118 | for img in x:
119 | logits = self.classify(img)
120 | logits_list.append(logits)
121 | return torch.cat(logits_list, dim=0)
122 |
123 | def _preprocess(self, image):
124 | # normalization used by pre-trained model
125 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
126 |
127 | return normalize(image)
128 |
129 | def to(self, device):
130 | self._model = self._model.to(device)
131 | self._device = device
132 |
133 | def eval(self):
134 | self._model = self._model.eval()
135 |
136 |
137 | def pa_resnet110_config1():
138 | return PostAveragedResNet110(K=15, R=6, eps=8/255)
139 |
--------------------------------------------------------------------------------
/post_avg/resnetSmall.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | '''
4 | Properly implemented ResNet-s for CIFAR10 as described in paper [1].
5 |
6 | The implementation and structure of this file are heavily influenced by [2],
7 | which is implemented for ImageNet and doesn't have option A for the identity shortcut.
8 | Moreover, most of the implementations on the web are copy-pasted from
9 | torchvision's resnet and have the wrong number of params.
10 |
11 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following
12 | numbers of layers and parameters:
13 |
14 | name | layers | params
15 | ResNet20 | 20 | 0.27M
16 | ResNet32 | 32 | 0.46M
17 | ResNet44 | 44 | 0.66M
18 | ResNet56 | 56 | 0.85M
19 | ResNet110 | 110 | 1.7M
20 | ResNet1202| 1202 | 19.4M
21 |
22 | which this implementation indeed has.
23 |
24 | Reference:
25 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
26 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
27 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
28 |
29 | If you use this implementation in your work, please don't forget to mention the
30 | author, Yerlan Idelbayev.
31 | '''
32 | import torch
33 | import torch.nn as nn
34 | import torch.nn.functional as F
35 | import torch.nn.init as init
36 |
37 | from torch.autograd import Variable
38 |
39 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
40 |
41 | def _weights_init(m):
42 | classname = m.__class__.__name__
43 | # print(classname)
44 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
45 |         init.kaiming_normal_(m.weight)
46 |
47 | class LambdaLayer(nn.Module):
48 | def __init__(self, lambd):
49 | super(LambdaLayer, self).__init__()
50 | self.lambd = lambd
51 |
52 | def forward(self, x):
53 | return self.lambd(x)
54 |
55 |
56 | class BasicBlock(nn.Module):
57 | expansion = 1
58 |
59 | def __init__(self, in_planes, planes, stride=1, option='A'):
60 | super(BasicBlock, self).__init__()
61 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
62 | self.bn1 = nn.BatchNorm2d(planes)
63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
64 | self.bn2 = nn.BatchNorm2d(planes)
65 |
66 | self.shortcut = nn.Sequential()
67 | if stride != 1 or in_planes != planes:
68 | if option == 'A':
69 | """
70 | For CIFAR10 ResNet paper uses option A.
71 | """
72 | self.shortcut = LambdaLayer(lambda x:
73 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
74 | elif option == 'B':
75 | self.shortcut = nn.Sequential(
76 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
77 | nn.BatchNorm2d(self.expansion * planes)
78 | )
79 |
80 | def forward(self, x):
81 | out = F.relu(self.bn1(self.conv1(x)))
82 | out = self.bn2(self.conv2(out))
83 | out += self.shortcut(x)
84 | out = F.relu(out)
85 | return out
86 |
87 |
88 | class ResNet(nn.Module):
89 | def __init__(self, block, num_blocks, num_classes=10):
90 | super(ResNet, self).__init__()
91 | self.in_planes = 16
92 |
93 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
94 | self.bn1 = nn.BatchNorm2d(16)
95 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
96 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
97 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
98 | self.linear = nn.Linear(64, num_classes)
99 |
100 | self.apply(_weights_init)
101 |
102 | def _make_layer(self, block, planes, num_blocks, stride):
103 | strides = [stride] + [1]*(num_blocks-1)
104 | layers = []
105 | for stride in strides:
106 | layers.append(block(self.in_planes, planes, stride))
107 | self.in_planes = planes * block.expansion
108 |
109 | return nn.Sequential(*layers)
110 |
111 | def forward(self, x):
112 | out = F.relu(self.bn1(self.conv1(x)))
113 | out = self.layer1(out)
114 | out = self.layer2(out)
115 | out = self.layer3(out)
116 | out = F.avg_pool2d(out, out.size()[3])
117 | out = out.view(out.size(0), -1)
118 | out = self.linear(out)
119 | return out
120 |
121 |
122 | def resnet20():
123 | return ResNet(BasicBlock, [3, 3, 3])
124 |
125 |
126 | def resnet32():
127 | return ResNet(BasicBlock, [5, 5, 5])
128 |
129 |
130 | def resnet44():
131 | return ResNet(BasicBlock, [7, 7, 7])
132 |
133 |
134 | def resnet56():
135 | return ResNet(BasicBlock, [9, 9, 9])
136 |
137 |
138 | def resnet110():
139 | return ResNet(BasicBlock, [18, 18, 18])
140 |
141 |
142 | def resnet1202():
143 | return ResNet(BasicBlock, [200, 200, 200])
144 |
145 |
146 | def test(net):
147 | import numpy as np
148 | total_params = 0
149 |
150 | for x in filter(lambda p: p.requires_grad, net.parameters()):
151 | total_params += np.prod(x.data.numpy().shape)
152 | print("Total number of params", total_params)
153 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters()))))
154 |
155 |
156 | if __name__ == "__main__":
157 | for net_name in __all__:
158 | if net_name.startswith('resnet'):
159 | print(net_name)
160 | test(globals()[net_name]())
161 | print()
--------------------------------------------------------------------------------
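
A quick shape check for the factories above (a minimal sketch; it assumes the definitions in resnetSmall.py are in scope):

    import torch

    net = resnet20()
    net.eval()
    x = torch.randn(2, 3, 32, 32)        # a batch of two CIFAR-sized RGB images
    with torch.no_grad():
        logits = net(x)
    assert logits.shape == (2, 10)       # one logit per CIFAR-10 class
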
/post_avg/robustml_test_cifar10.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import argparse
5 | import robustml
6 | import numpy as np
7 | from foolbox.models import PyTorchModel
8 | from robustml_portal import attacks as atk
9 | from robustml_portal import postAveragedModels as pamdl
10 |
11 |
12 | # argument parsing
13 | parser = argparse.ArgumentParser(description="robustml evaluation on CIFAR-10")
14 | parser.add_argument("datasetPath", help="path to the 'test_batch' file")
15 | parser.add_argument("--start", type=int, default=0, help="inclusive starting index for data. default: 0")
16 | parser.add_argument("--end", type=int, help="exclusive ending index for data. default: dataset size")
17 | parser.add_argument("--attack", choices=["pgd", "fgsm", "df", "cw", "none"], default="pgd", help="attack method to be used. default: pgd")
18 | parser.add_argument("--device", help="compuation device to be used. 'cpu' or 'cuda:'")
19 | args = parser.parse_args()
20 |
21 | if args.device is None:
22 | if torch.cuda.is_available():
23 | device = torch.device("cuda")
24 | else:
25 | device = torch.device("cpu")
26 | else:
27 | device = torch.device(args.device)
28 |
29 | # setup test model
30 | model = pamdl.pa_resnet110_config1()
31 | model.to(device)
32 | model.eval()
33 |
34 | # setup attacker
35 | nClasses = 10
36 | victim_model = PyTorchModel(model.model, (0,1), nClasses, device=device, preprocessing=(np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)), np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))))
37 | if args.attack == "pgd":
38 | attack = atk.pgdAttack(victim_model)
39 | elif args.attack == "fgsm":
40 | attack = atk.fgsmAttack(victim_model)
41 | elif args.attack == "df":
42 | attack = atk.dfAttack(victim_model)
43 | elif args.attack == "cw":
44 | attack = atk.cwAttack(victim_model)
45 | else:
46 | attack = atk.NullAttack()
47 |
48 | # setup data provider
49 | prov = robustml.provider.CIFAR10(args.datasetPath)
50 |
51 | # evaluate performance
52 | if args.end is None:
53 | args.end = len(prov)
54 | atk_success_rate = robustml.evaluate.evaluate(model, attack, prov, start=args.start, end=args.end)
55 | print('Overall attack success rate: %.4f' % atk_success_rate)
--------------------------------------------------------------------------------
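
The if/elif ladder above can be collapsed into a dict dispatch; a sketch using only names already defined in this script:

    # maps the --attack choice to its constructor; "none" falls through to NullAttack
    ATTACK_BUILDERS = {
        "pgd": atk.pgdAttack,
        "fgsm": atk.fgsmAttack,
        "df": atk.dfAttack,
        "cw": atk.cwAttack,
    }
    builder = ATTACK_BUILDERS.get(args.attack)
    attack = builder(victim_model) if builder is not None else atk.NullAttack()
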
/post_avg/robustml_test_imagenet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import argparse
5 | import robustml
6 | import numpy as np
7 | from foolbox.models import PyTorchModel
8 | from robustml_portal import attacks as atk
9 | from robustml_portal import postAveragedModels as pamdl
10 |
11 |
12 | # argument parsing
13 | parser = argparse.ArgumentParser(description="robustml evaluation on ImageNet")
14 | parser.add_argument("datasetPath", help="directory containing 'val.txt' and 'val/' folder")
15 | parser.add_argument("--start", type=int, default=0, help="inclusive starting index for data. default: 0")
16 | parser.add_argument("--end", type=int, help="exclusive ending index for data. default: dataset size")
17 | parser.add_argument("--attack", choices=["pgd", "fgsm", "df", "cw", "none"], default="pgd", help="attack method to be used. default: pgd")
18 | parser.add_argument("--device", help="compuation device to be used. 'cpu' or 'cuda:'")
19 | args = parser.parse_args()
20 |
21 | if args.device is None:
22 | if torch.cuda.is_available():
23 | device = torch.device("cuda")
24 | else:
25 | device = torch.device("cpu")
26 | else:
27 | device = torch.device(args.device)
28 |
29 | # setup test model
30 | model = pamdl.pa_resnet152_config1()
31 | model.to(device)
32 | model.eval()
33 |
34 | # setup attacker
35 | nClasses = 1000
36 | victim_model = PyTorchModel(model.model, (0,1), nClasses, device=device, preprocessing=(np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)), np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))))
37 | if args.attack == "pgd":
38 | attack = atk.pgdAttack(victim_model)
39 | elif args.attack == "fgsm":
40 | attack = atk.fgsmAttack(victim_model)
41 | elif args.attack == "df":
42 | attack = atk.dfAttack(victim_model)
43 | elif args.attack == "cw":
44 | attack = atk.cwAttack(victim_model)
45 | else:
46 | attack = atk.NullAttack()
47 |
48 | # setup data provider
49 | prov = robustml.provider.ImageNet(args.datasetPath, (224, 224, 3))
50 |
51 | # evaluate performance
52 | if args.end is None:
53 | args.end = len(prov)
54 | atk_success_rate = robustml.evaluate.evaluate(model, attack, prov, start=args.start, end=args.end)
55 | print('Overall attack success rate: %.4f' % atk_success_rate)
--------------------------------------------------------------------------------
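
The preprocessing tuple passed to PyTorchModel in both scripts is foolbox's (mean, std) convention, under which the wrapped model sees (x - mean) / std. A sketch of the equivalent NumPy operation (the random input is illustrative):

    import numpy as np

    mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    x = np.random.rand(3, 224, 224).astype(np.float32)   # an image in [0, 1], CHW layout
    x_preprocessed = (x - mean) / std                    # what the wrapped model receives
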
/post_avg/visualHelper.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 |
7 | import matplotlib; matplotlib.use('agg')
8 | import matplotlib.pyplot as plt
9 |
10 | def plotPredStats(feats, lb, K=10, image=None, noiseImage=None, savePath=None):
11 |
12 | # score by averaging
13 | scores = torch.mean(feats, dim=0)
14 |
15 | # sort and select the top K scores
16 | hScores, hCates = torch.sort(scores, dim=0, descending=True)
17 | hScores = hScores[:K].numpy()
18 | hCates = hCates[:K].numpy()
19 |
20 |     # get individual predictions
21 | _, preds = torch.max(feats, dim=1)
22 |
23 | # count votes
24 | preds_count = {lb: 0}
25 | for i in range(feats.size(0)):
26 | if preds[i].item() in preds_count:
27 | preds_count[preds[i].item()] = preds_count[preds[i].item()] + 1
28 | else:
29 | preds_count[preds[i].item()] = 1
30 |
31 | candidates = sorted(preds_count.keys())
32 | votes = [preds_count[x] for x in candidates]
33 |
34 | # generate figure
35 | fig = plt.figure()
36 | if image is None and noiseImage is None:
37 | ax1, ax2, ax3 = fig.subplots(3, 1)
38 | else:
39 | axes = fig.subplots(2, 2)
40 | ax1 = axes[0, 0]
41 | ax2 = axes[1, 0]
42 | ax3 = axes[0, 1]
43 | ax4 = axes[1, 1]
44 |
45 | # chart 1, votes distribution
46 | inx1 = list(range(len(candidates)))
47 | clr1 = []
48 | for i in inx1:
49 | if candidates[i] == lb:
50 | clr1.append('Red')
51 | else:
52 | clr1.append('SkyBlue')
53 | rects1 = ax1.bar(inx1, votes, color=clr1)
54 | for rect in rects1:
55 | h = rect.get_height()
56 | ax1.text(rect.get_x() + 0.5 * rect.get_width(), 1.01 * h, '{}'.format(h), ha='center', va='bottom')
57 | ax1.set_ylim(top=1.1 * ax1.get_ylim()[1])
58 | ax1.set_xticks(inx1)
59 | ax1.set_xticklabels([str(x) for x in candidates], rotation=30)
60 | ax1.set_ylabel('votes')
61 | ax1.set_title('Votes Distribution')
62 |
63 | # chart 2, top prediction scores
64 | inx2 = list(range(len(hCates)))
65 | clr2 = []
66 | for i in inx2:
67 | if hCates[i] == lb:
68 | clr2.append('Red')
69 | else:
70 | clr2.append('SkyBlue')
71 | rects2 = ax2.bar(inx2, hScores, color=clr2)
72 | for rect in rects2:
73 | h = rect.get_height()
74 | ax2.text(rect.get_x() + 0.5 * rect.get_width(), 1.01 * h, '{:.2f}'.format(h), ha='center', va='bottom')
75 | ax2.set_ylim(top=1.1 * ax2.get_ylim()[1])
76 | ax2.set_xticks(inx2)
77 | ax2.set_xticklabels([str(x) for x in hCates], rotation=30)
78 | ax2.set_ylabel('score')
79 | ax2.set_xlabel('Top Predictions')
80 |
81 | # axis 3, the noise image
82 | if noiseImage is not None:
83 | ax3.imshow(noiseImage)
84 | ax3.set_xlabel('Noise Image')
85 | ax3.set_axis_off()
86 | else:
87 | # if noise image is not given, show prediction event plot
88 | clr3 = []
89 | for i in range(preds.size(0)):
90 | if preds[i] == lb:
91 | clr3.append('Red')
92 | else:
93 | clr3.append('Green')
94 | ax3.eventplot(preds.unsqueeze(1).numpy(), orientation='vertical', colors=clr3)
95 | ax3.set_yticks(candidates)
96 | ax3.set_yticklabels([str(x) for x in candidates])
97 | ax3.set_xlabel('sample index')
98 | ax3.set_ylabel('class')
99 |
100 | # axis 4, the input image
101 | if image is not None:
102 | ax4.imshow(image)
103 | ax4.set_title('Input Image')
104 | ax4.set_axis_off()
105 |
106 | # save figure and close
107 | if savePath is not None:
108 | fig.savefig(savePath)
109 |
110 | plt.close(fig)
111 |
112 |
113 | def plotPerturbationDistribution(perturbations, savePath=None):
114 |
115 | # generate figure
116 | fig = plt.figure()
117 | ax1, ax2, ax3 = fig.subplots(3, 1)
118 |
119 | # plot scatter chart
120 | perts = np.asarray(perturbations)
121 | ax1.scatter(perts[:, 0], perts[:, 1], c='SkyBlue')
122 | ax1.autoscale(axis='x')
123 | ax1.set_ylim((-1, 2))
124 | ax1.set_yticks([0, 1])
125 |     ax1.set_yticklabels(['missed', 'defended'])
126 | ax1.set_xlabel('Perturbation distance')
127 | ax1.set_title('Perturbations Distribution')
128 |
129 |     # plot histogram of perturbation distances for defended adversarial samples
130 | x = [e[0] for e in perturbations if e[1] == 1]
131 | ax2.hist(x, bins=20, color='SkyBlue')
132 | ax2.set_xlabel('Perturbation distance')
133 |     ax2.set_ylabel('Defended')
134 |
135 |     # plot histogram of perturbation distances for missed adversarial samples
136 | x = [e[0] for e in perturbations if e[1] == 0]
137 | ax3.hist(x, bins=20, color='Red')
138 | ax3.set_xlabel('Perturbation distance')
139 | ax3.set_ylabel('Missed')
140 |
141 | # save figure and close
142 | if savePath is not None:
143 | fig.savefig(savePath)
144 |
145 | plt.close(fig)
146 |
147 |
--------------------------------------------------------------------------------
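
A sketch of how plotPredStats might be driven, assuming feats holds one row of class scores per noise sample (the random logits and file name are illustrative):

    import torch

    feats = torch.randn(50, 10)   # 50 noisy forward passes, 10 classes
    plotPredStats(feats, lb=3, K=5, savePath='pred_stats.png')
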
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 |
4 |
5 | class Logger:
6 | def __init__(self, path):
7 | self.path = path
8 | if path != '':
9 |             folder = '/'.join(path.split('/')[:-1])
10 |             if folder != '' and not os.path.exists(folder):  # guard against a bare filename with no directory part
11 |                 os.makedirs(folder)
12 |
13 | def print(self, message):
14 | print(message)
15 | if self.path != '':
16 | with open(self.path, 'a') as f:
17 | f.write(message + '\n')
18 | f.flush()
19 |
20 |
21 | def dense_to_onehot(y_test, n_cls):
22 | y_test_onehot = np.zeros([len(y_test), n_cls], dtype=bool)
23 | y_test_onehot[np.arange(len(y_test)), y_test] = True
24 | return y_test_onehot
25 |
26 |
27 | def random_classes_except_current(y_test, n_cls):
28 | y_test_new = np.zeros_like(y_test)
29 | for i_img in range(y_test.shape[0]):
30 | lst_classes = list(range(n_cls))
31 | lst_classes.remove(y_test[i_img])
32 | y_test_new[i_img] = np.random.choice(lst_classes)
33 | return y_test_new
34 |
35 |
36 | def softmax(x):
37 | e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
38 | return e_x / e_x.sum(axis=1, keepdims=True)
39 |
--------------------------------------------------------------------------------
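
The helpers above in action (a minimal sketch; the labels and logits are illustrative):

    import numpy as np

    y = np.array([0, 2, 1])
    dense_to_onehot(y, n_cls=3)
    # -> [[ True, False, False],
    #     [False, False,  True],
    #     [False,  True, False]]

    logits = np.array([[3.0, 1.0, 0.5]])
    softmax(logits).sum(axis=1)   # rows of a softmax sum to 1.0
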