├── .gitignore
├── LICENSE
├── README.md
├── attack.py
├── data.py
├── images
│   ├── ablation_study.png
│   ├── adv_examples.png
│   ├── algorithm_rs.png
│   ├── ezgif.com-gif-maker-50-conf-small.gif
│   ├── ezgif.com-gif-maker-img-53-l2-2.gif
│   ├── main_results_imagenet.png
│   ├── main_results_imagenet_l2_commonly_successful.png
│   ├── repository_picture.png
│   ├── sensitivity_wrt_p.png
│   ├── success_rate_curves_full.png
│   ├── table_clp_lsq.png
│   ├── table_madry_mnist_l2.png
│   ├── table_madry_trades_mnist_linf.png
│   └── table_post_avg.png
├── logit_pairing
│   └── models.py
├── madry_cifar10
│   ├── LICENSE
│   ├── README.md
│   ├── cifar10_input.py
│   ├── config.json
│   ├── eval.py
│   ├── fetch_model.py
│   ├── model.py
│   ├── model_robustml.py
│   ├── pgd_attack.py
│   ├── run_attack.py
│   └── train.py
├── madry_mnist
│   ├── LICENSE
│   ├── config.json
│   ├── eval.py
│   ├── fetch_model.py
│   ├── model.py
│   ├── run_attack.py
│   └── train.py
├── metrics
│   ├── 2019-11-10 15:57:14 model=pt_inception dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy
│   ├── 2019-11-10 15:57:14 model=pt_resnet dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy
│   ├── 2019-11-10 15:57:14 model=pt_vgg dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy
│   ├── square_l2_inceptionv3_queries.npy
│   ├── square_l2_resnet50_queries.npy
│   └── square_l2_vgg16_queries.npy
├── models.py
├── post_avg
│   ├── LICENSE.txt
│   ├── PADefense.py
│   ├── README.md
│   ├── attacks.py
│   ├── postAveragedModels.py
│   ├── resnetSmall.py
│   ├── robustml_test_cifar10.py
│   ├── robustml_test_imagenet.py
│   └── visualHelper.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store

# data files
MNIST_DATA

fast_mnist/

# model files
models
secret.zip

# compiled python files
*.pyc

.idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2019, Maksym Andriushchenko, Francesco Croce, Nicolas Flammarion, Matthias Hein
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
    * Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    * Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    * Neither the name of the copyright holder nor the
      names of its contributors may be used to endorse or promote products
      derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY 19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Square Attack: a query-efficient black-box adversarial attack via random search 2 | **ECCV 2020** 3 | 4 | **Maksym Andriushchenko\*, Francesco Croce\*, Nicolas Flammarion, Matthias Hein** 5 | 6 | **EPFL, University of Tübingen** 7 | 8 | **Paper:** [https://arxiv.org/abs/1912.00049](https://arxiv.org/abs/1912.00049) 9 | 10 | \* denotes equal contribution 11 | 12 | 13 | ## News 14 | + [Jul 2020] The paper is accepted at **ECCV 2020**! Please stop by our virtual poster for the latest insights in black-box adversarial attacks (also check out our recent preprint [Sparse-RS paper](https://arxiv.org/abs/2006.12834) where we use random search for sparse attacks). 15 | + [Mar 2020] Our attack is now part of [AutoAttack](https://github.com/fra31/auto-attack), an ensemble of attacks used 16 | for automatic (i.e., no hyperparameter tuning needed) robustness evaluation. Table 2 in the [AutoAttack paper](https://arxiv.org/abs/2003.01690) 17 | shows that at least on 6 models our **black-box** attack outperforms gradient-based methods. Always useful to have a black-box attack to prevent inaccurate robustness claims! 18 | + [Mar 2020] We also achieve the best results on [TRADES MNIST benchmark](https://github.com/yaodongyu/TRADES)! 19 | + [Jan 2020] The Square Attack achieves the best results on [MadryLab's MNIST challenge](https://github.com/MadryLab/mnist_challenge), 20 | outperforming all white-box attacks! In this case we used 50 random restarts of our attack, each with a query limit of 20000. 21 | + [Nov 2019] The Square Attack breaks the recently proposed defense from "Bandlimiting Neural Networks Against Adversarial Attacks" 22 | ([https://github.com/robust-ml/robust-ml.github.io/issues/15](https://github.com/robust-ml/robust-ml.github.io/issues/15)). 23 | 24 | 25 | ## Abstract 26 | We propose the *Square Attack*, a score-based black-box L2- and Linf-adversarial attack that does not 27 | rely on local gradient information and thus is not affected by gradient masking. Square Attack is based on a randomized 28 | search scheme which selects localized square-shaped updates at random positions so that at each iteration the perturbation 29 | is situated approximately at the boundary of the feasible set. Our method is significantly more query efficient and achieves a higher success rate compared to the state-of-the-art 30 | methods, especially in the untargeted setting. In particular, on ImageNet we improve the average query efficiency in the untargeted setting for various deep networks 31 | by a factor of at least 1.8 and up to 3 compared to the recent state-of-the-art Linf-attack of Al-Dujaili & O’Reilly. 32 | Moreover, although our attack is *black-box*, it can also outperform gradient-based *white-box* attacks 33 | on the standard benchmarks achieving a new state-of-the-art in terms of the success rate. 
34 | 35 | ----- 36 | 37 | The code of the Square Attack can be found in `square_attack_linf(...)` and `square_attack_l2(...)` in `attack.py`.\ 38 | Below we show adversarial examples generated by our method for Linf and L2 perturbations: 39 |

40 | 41 | 45 | 46 | 47 | ## About the paper 48 | The general algorithm of the attack is extremely simple and relies on the random search algorithm: we try some update and 49 | accept it only if it helps to improve the loss: 50 |

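For illustration, here is a minimal sketch of this accept/reject loop (simplified pseudocode rather than the repository implementation: `loss_fn` and `sample_square_update` are placeholder functions, and the actual code in `square_attack_linf()` / `square_attack_l2()` is vectorized over a whole batch of images):

```
import numpy as np

def random_search_attack(loss_fn, x, eps, n_iters, sample_square_update):
    # start at the boundary of the Linf ball (the actual attack uses a stripe initialization)
    delta = np.random.choice([-eps, eps], size=x.shape)
    best_loss = loss_fn(np.clip(x + delta, 0, 1))  # 1 query
    for _ in range(n_iters):
        delta_new = sample_square_update(delta, eps)  # change one random square-shaped window
        loss_new = loss_fn(np.clip(x + delta_new, 0, 1))  # 1 query per iteration
        if loss_new < best_loss:  # accept the update only if it improves the loss
            delta, best_loss = delta_new, loss_new
    return np.clip(x + delta, 0, 1)
```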
The only thing we customize is the sampling distribution P (see the paper for details). The main ideas behind the choice
of the sampling distributions are:
- We start at the boundary of the feasible set with a good initialization that helps to improve the query efficiency (particularly for the Linf-attack).
- On every iteration we stay at the boundary of the feasible set by changing square-shaped regions of the image.

In the paper we also provide a convergence analysis of a variant of our attack in the non-convex setting, and justify
the main algorithmic choices such as modifying squares and using the same sign of the update.

This simple algorithm is sufficient to significantly outperform much more complex approaches in terms of the success rate
and query efficiency:

63 |

Here are the complete success rate curves as a function of the number of queries. We note that the Square Attack
also outperforms the competing approaches in the low-query regime.

The Square Attack also performs very well on adversarially trained models on MNIST, achieving results competitive with or
better than *white-box* attacks despite the fact that our attack is *black-box*:

Interestingly, L2 perturbations of the Linf adversarially trained model are challenging for many attacks, including
white-box PGD and other black-box attacks. However, the Square Attack is able to assess the
robustness in this setting much more accurately:

## Running the code
`attack.py` is the main module that implements the Square Attack; see the command line arguments there.
The main functions that implement the attack are `square_attack_linf()` and `square_attack_l2()`.

In order to run the untargeted Linf Square Attack on ImageNet models from the PyTorch repository, you need to specify the correct path
to the validation set (see `IMAGENET_PATH` in `data.py`) and then run:
- ``` python attack.py --attack=square_linf --model=pt_vgg --n_ex=1000 --eps=12.75 --p=0.05 --n_iter=10000 ```
- ``` python attack.py --attack=square_linf --model=pt_resnet --n_ex=1000 --eps=12.75 --p=0.05 --n_iter=10000 ```
- ``` python attack.py --attack=square_linf --model=pt_inception --n_ex=1000 --eps=12.75 --p=0.05 --n_iter=10000 ```

Note that `eps=12.75` is then divided by 255, so in the end it is equal to 0.05.

For targeted attacks, one should additionally use the flag `--targeted`, use a lower `p`, and specify more
iterations (`--n_iter=100000`), since it usually takes more iterations to achieve misclassification into some particular,
randomly chosen class.

The rest of the models have to be downloaded first (see the instructions below) and can then be evaluated in the following way:

Post-averaging models:
- ``` python attack.py --attack=square_linf --model=pt_post_avg_cifar10 --n_ex=1000 --eps=8.0 --p=0.3 --n_iter=20000 ```
- ``` python attack.py --attack=square_linf --model=pt_post_avg_imagenet --n_ex=1000 --eps=8.0 --p=0.3 --n_iter=20000 ```

Clean logit pairing and logit squeezing models:
- ``` python attack.py --attack=square_linf --model=clp_mnist --n_ex=1000 --eps=0.3 --p=0.3 --n_iter=20000 ```
- ``` python attack.py --attack=square_linf --model=lsq_mnist --n_ex=1000 --eps=0.3 --p=0.3 --n_iter=20000 ```
- ``` python attack.py --attack=square_linf --model=clp_cifar10 --n_ex=1000 --eps=16.0 --p=0.3 --n_iter=20000 ```
- ``` python attack.py --attack=square_linf --model=lsq_cifar10 --n_ex=1000 --eps=16.0 --p=0.3 --n_iter=20000 ```

Adversarially trained model (with only 1 restart; note that the results in the paper are based on 50 restarts):
- ``` python attack.py --attack=square_linf --model=madry_mnist_robust --n_ex=10000 --eps=0.3 --p=0.8 --n_iter=20000 ```

The L2 Square Attack can be run similarly, but please check the recommended hyperparameters in the paper (Section B of the supplement)
and make sure to specify the right value of `eps` taking into account whether the pixels are in [0, 1] or in [0, 255]
for a particular dataset and model.
For example, for the standard ImageNet models, the correct L2 eps to specify is 1275, since after division by 255 it becomes 5.0.



## Saved statistics
In the folder `metrics`, we provide saved statistics of the attack on 3 models: Inception-v3, ResNet-50, and VGG-16-BN.\
Here are simple examples of how to load the metrics files.
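For reference, each row of a `*.metrics.npy` file stores the seven per-iteration statistics written by `attack.py` (the naming below follows the code; column 4 is `median_nq_ae` for the Linf attack and `median_nq` for the L2 attack):

```
# Columns of each row in *.metrics.npy (one row per iteration):
#   0: acc           robust accuracy w.r.t. all n_ex points
#   1: acc_corr      robust accuracy w.r.t. the initially correctly classified points
#   2: mean_nq       mean number of queries over all attacked points
#   3: mean_nq_ae    mean number of queries over successful adversarial examples
#   4: median_nq_ae  (Linf attack) / median_nq (L2 attack)
#   5: avg_loss      average margin loss
#   6: time_total    total elapsed time in seconds
```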
### Linf attack
To print the statistics from the last iteration:
```
import numpy as np

metrics = np.load('metrics/2019-11-10 15:57:14 model=pt_resnet dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy')
iteration = np.argmax(metrics[:, -1])  # the row with the maximal time is the last iteration that was written
acc, acc_corr, mean_nq, mean_nq_ae, median_nq_ae, avg_loss, time_total = metrics[iteration]
print('[iter {}] acc={:.2%} acc_corr={:.2%} avg#q={:.2f} avg#q_ae={:.2f} med#q_ae={:.2f} ({:.2f}min)'.
      format(iteration + 1, acc, acc_corr, mean_nq, mean_nq_ae, median_nq_ae, time_total / 60))
```

Then one can also create different plots based on the data contained in `metrics`. For example, one can use `1 - acc_corr`
to plot the success rate of the Square Attack as a function of the number of queries.

### L2 attack
In this case we provide the number of queries necessary to achieve misclassification (`n_queries[i] = 0` means that the image `i` was initially misclassified, `n_queries[i] = 10001` indicates that the attack could not find an adversarial example for the image `i`).
To load the metrics and compute the success rate of the Square Attack after `k` queries, you can run:
```
import numpy as np

k = 1000  # e.g., compute the success rate after a budget of 1000 queries
n_queries = np.load('metrics/square_l2_resnet50_queries.npy')['n_queries']
success_rate = float(((n_queries > 0) * (n_queries <= k)).sum()) / (n_queries > 0).sum()
```


## Models
Note that in order to evaluate other models, one has to first download them and move them to the folders specified in
`model_path_dict` from `models.py`:
- [Clean Logit Pairing on MNIST](https://oc.cs.uni-saarland.de/owncloud/index.php/s/w2yegcfx8mc8kNa)
- [Logit Squeezing on MNIST](https://oc.cs.uni-saarland.de/owncloud/index.php/s/a5ZY72BDCPEtb2S)
- [Clean Logit Pairing on CIFAR-10](https://oc.cs.uni-saarland.de/owncloud/index.php/s/odcd7FgFdbqq6zL)
- [Logit Squeezing on CIFAR-10](https://oc.cs.uni-saarland.de/owncloud/index.php/s/EYnbHDeMbe4mq5M)
- MNIST, Madry adversarial training: run `python madry_mnist/fetch_model.py secret`
- MNIST, TRADES: download the [models](https://drive.google.com/file/d/1scTd9-YO3-5Ul3q5SJuRrTNX__LYLD_M) and see their [repository](https://github.com/yaodongyu/TRADES)
- [Post-averaging defense](https://github.com/YupingLin171/PostAvgDefense/blob/master/trainedModel/resnet110.th): the model can be downloaded directly from the repository

For the first 4 models, one additionally has to update the paths in the `checkpoint` file in the following way:
```
model_checkpoint_path: "model.ckpt"
all_model_checkpoint_paths: "model.ckpt"
```



## Requirements
- PyTorch 1.0.0
- TensorFlow 1.12.0



## Contact
Do you have a problem or question regarding the code?
Please don't hesitate to open an issue or contact [Maksym Andriushchenko](https://github.com/max-andr) or
[Francesco Croce](https://github.com/fra31) directly.
189 | 190 | 191 | ## Citation 192 | ``` 193 | @article{ACFH2020square, 194 | title={Square Attack: a query-efficient black-box adversarial attack via random search}, 195 | author={Andriushchenko, Maksym and Croce, Francesco and Flammarion, Nicolas and Hein, Matthias}, 196 | conference={ECCV}, 197 | year={2020} 198 | } 199 | ``` 200 | -------------------------------------------------------------------------------- /attack.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | import data 5 | import models 6 | import os 7 | import utils 8 | from datetime import datetime 9 | np.set_printoptions(precision=5, suppress=True) 10 | 11 | 12 | def p_selection(p_init, it, n_iters): 13 | """ Piece-wise constant schedule for p (the fraction of pixels changed on every iteration). """ 14 | it = int(it / n_iters * 10000) 15 | 16 | if 10 < it <= 50: 17 | p = p_init / 2 18 | elif 50 < it <= 200: 19 | p = p_init / 4 20 | elif 200 < it <= 500: 21 | p = p_init / 8 22 | elif 500 < it <= 1000: 23 | p = p_init / 16 24 | elif 1000 < it <= 2000: 25 | p = p_init / 32 26 | elif 2000 < it <= 4000: 27 | p = p_init / 64 28 | elif 4000 < it <= 6000: 29 | p = p_init / 128 30 | elif 6000 < it <= 8000: 31 | p = p_init / 256 32 | elif 8000 < it <= 10000: 33 | p = p_init / 512 34 | else: 35 | p = p_init 36 | 37 | return p 38 | 39 | 40 | def pseudo_gaussian_pert_rectangles(x, y): 41 | delta = np.zeros([x, y]) 42 | x_c, y_c = x // 2 + 1, y // 2 + 1 43 | 44 | counter2 = [x_c - 1, y_c - 1] 45 | for counter in range(0, max(x_c, y_c)): 46 | delta[max(counter2[0], 0):min(counter2[0] + (2 * counter + 1), x), 47 | max(0, counter2[1]):min(counter2[1] + (2 * counter + 1), y)] += 1.0 / (counter + 1) ** 2 48 | 49 | counter2[0] -= 1 50 | counter2[1] -= 1 51 | 52 | delta /= np.sqrt(np.sum(delta ** 2, keepdims=True)) 53 | 54 | return delta 55 | 56 | 57 | def meta_pseudo_gaussian_pert(s): 58 | delta = np.zeros([s, s]) 59 | n_subsquares = 2 60 | if n_subsquares == 2: 61 | delta[:s // 2] = pseudo_gaussian_pert_rectangles(s // 2, s) 62 | delta[s // 2:] = pseudo_gaussian_pert_rectangles(s - s // 2, s) * (-1) 63 | delta /= np.sqrt(np.sum(delta ** 2, keepdims=True)) 64 | if np.random.rand(1) > 0.5: delta = np.transpose(delta) 65 | 66 | elif n_subsquares == 4: 67 | delta[:s // 2, :s // 2] = pseudo_gaussian_pert_rectangles(s // 2, s // 2) * np.random.choice([-1, 1]) 68 | delta[s // 2:, :s // 2] = pseudo_gaussian_pert_rectangles(s - s // 2, s // 2) * np.random.choice([-1, 1]) 69 | delta[:s // 2, s // 2:] = pseudo_gaussian_pert_rectangles(s // 2, s - s // 2) * np.random.choice([-1, 1]) 70 | delta[s // 2:, s // 2:] = pseudo_gaussian_pert_rectangles(s - s // 2, s - s // 2) * np.random.choice([-1, 1]) 71 | delta /= np.sqrt(np.sum(delta ** 2, keepdims=True)) 72 | 73 | return delta 74 | 75 | 76 | def square_attack_l2(model, x, y, corr_classified, eps, n_iters, p_init, metrics_path, targeted, loss_type): 77 | """ The L2 square attack """ 78 | np.random.seed(0) 79 | 80 | min_val, max_val = 0, 1 81 | c, h, w = x.shape[1:] 82 | n_features = c * h * w 83 | n_ex_total = x.shape[0] 84 | x, y = x[corr_classified], y[corr_classified] 85 | 86 | ### initialization 87 | delta_init = np.zeros(x.shape) 88 | s = h // 5 89 | log.print('Initial square side={} for bumps'.format(s)) 90 | sp_init = (h - s * 5) // 2 91 | center_h = sp_init + 0 92 | for counter in range(h // s): 93 | center_w = sp_init + 0 94 | for counter2 in range(w // s): 95 | delta_init[:, :, center_h:center_h + s, 
center_w:center_w + s] += meta_pseudo_gaussian_pert(s).reshape( 96 | [1, 1, s, s]) * np.random.choice([-1, 1], size=[x.shape[0], c, 1, 1]) 97 | center_w += s 98 | center_h += s 99 | 100 | x_best = np.clip(x + delta_init / np.sqrt(np.sum(delta_init ** 2, axis=(1, 2, 3), keepdims=True)) * eps, 0, 1) 101 | 102 | logits = model.predict(x_best) 103 | loss_min = model.loss(y, logits, targeted, loss_type=loss_type) 104 | margin_min = model.loss(y, logits, targeted, loss_type='margin_loss') 105 | n_queries = np.ones(x.shape[0]) # ones because we have already used 1 query 106 | 107 | time_start = time.time() 108 | s_init = int(np.sqrt(p_init * n_features / c)) 109 | metrics = np.zeros([n_iters, 7]) 110 | for i_iter in range(n_iters): 111 | idx_to_fool = (margin_min > 0.0) 112 | 113 | x_curr, x_best_curr = x[idx_to_fool], x_best[idx_to_fool] 114 | y_curr, margin_min_curr = y[idx_to_fool], margin_min[idx_to_fool] 115 | loss_min_curr = loss_min[idx_to_fool] 116 | delta_curr = x_best_curr - x_curr 117 | 118 | p = p_selection(p_init, i_iter, n_iters) 119 | s = max(int(round(np.sqrt(p * n_features / c))), 3) 120 | 121 | if s % 2 == 0: 122 | s += 1 123 | 124 | s2 = s + 0 125 | ### window_1 126 | center_h = np.random.randint(0, h - s) 127 | center_w = np.random.randint(0, w - s) 128 | new_deltas_mask = np.zeros(x_curr.shape) 129 | new_deltas_mask[:, :, center_h:center_h + s, center_w:center_w + s] = 1.0 130 | 131 | ### window_2 132 | center_h_2 = np.random.randint(0, h - s2) 133 | center_w_2 = np.random.randint(0, w - s2) 134 | new_deltas_mask_2 = np.zeros(x_curr.shape) 135 | new_deltas_mask_2[:, :, center_h_2:center_h_2 + s2, center_w_2:center_w_2 + s2] = 1.0 136 | norms_window_2 = np.sqrt( 137 | np.sum(delta_curr[:, :, center_h_2:center_h_2 + s2, center_w_2:center_w_2 + s2] ** 2, axis=(-2, -1), 138 | keepdims=True)) 139 | 140 | ### compute total norm available 141 | curr_norms_window = np.sqrt( 142 | np.sum(((x_best_curr - x_curr) * new_deltas_mask) ** 2, axis=(2, 3), keepdims=True)) 143 | curr_norms_image = np.sqrt(np.sum((x_best_curr - x_curr) ** 2, axis=(1, 2, 3), keepdims=True)) 144 | mask_2 = np.maximum(new_deltas_mask, new_deltas_mask_2) 145 | norms_windows = np.sqrt(np.sum((delta_curr * mask_2) ** 2, axis=(2, 3), keepdims=True)) 146 | 147 | ### create the updates 148 | new_deltas = np.ones([x_curr.shape[0], c, s, s]) 149 | new_deltas = new_deltas * meta_pseudo_gaussian_pert(s).reshape([1, 1, s, s]) 150 | new_deltas *= np.random.choice([-1, 1], size=[x_curr.shape[0], c, 1, 1]) 151 | old_deltas = delta_curr[:, :, center_h:center_h + s, center_w:center_w + s] / (1e-10 + curr_norms_window) 152 | new_deltas += old_deltas 153 | new_deltas = new_deltas / np.sqrt(np.sum(new_deltas ** 2, axis=(2, 3), keepdims=True)) * ( 154 | np.maximum(eps ** 2 - curr_norms_image ** 2, 0) / c + norms_windows ** 2) ** 0.5 155 | delta_curr[:, :, center_h_2:center_h_2 + s2, center_w_2:center_w_2 + s2] = 0.0 # set window_2 to 0 156 | delta_curr[:, :, center_h:center_h + s, center_w:center_w + s] = new_deltas + 0 # update window_1 157 | 158 | hps_str = 's={}->{}'.format(s_init, s) 159 | x_new = x_curr + delta_curr / np.sqrt(np.sum(delta_curr ** 2, axis=(1, 2, 3), keepdims=True)) * eps 160 | x_new = np.clip(x_new, min_val, max_val) 161 | curr_norms_image = np.sqrt(np.sum((x_new - x_curr) ** 2, axis=(1, 2, 3), keepdims=True)) 162 | 163 | logits = model.predict(x_new) 164 | loss = model.loss(y_curr, logits, targeted, loss_type=loss_type) 165 | margin = model.loss(y_curr, logits, targeted, loss_type='margin_loss') 166 | 167 | 
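        # random-search acceptance step: keep the new candidate only where the loss improved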
        idx_improved = loss < loss_min_curr
        loss_min[idx_to_fool] = idx_improved * loss + ~idx_improved * loss_min_curr
        margin_min[idx_to_fool] = idx_improved * margin + ~idx_improved * margin_min_curr

        idx_improved = np.reshape(idx_improved, [-1, *[1] * len(x.shape[:-1])])
        x_best[idx_to_fool] = idx_improved * x_new + ~idx_improved * x_best_curr
        n_queries[idx_to_fool] += 1

        acc = (margin_min > 0.0).sum() / n_ex_total
        acc_corr = (margin_min > 0.0).mean()
        mean_nq, mean_nq_ae, median_nq, median_nq_ae = np.mean(n_queries), np.mean(
            n_queries[margin_min <= 0]), np.median(n_queries), np.median(n_queries[margin_min <= 0])

        time_total = time.time() - time_start
        log.print(
            '{}: acc={:.2%} acc_corr={:.2%} avg#q_ae={:.1f} med#q_ae={:.1f} {}, n_ex={}, {:.0f}s, loss={:.3f}, max_pert={:.1f}, impr={:.0f}'.
            format(i_iter + 1, acc, acc_corr, mean_nq_ae, median_nq_ae, hps_str, x.shape[0], time_total,
                   np.mean(margin_min), np.amax(curr_norms_image), np.sum(idx_improved)))
        metrics[i_iter] = [acc, acc_corr, mean_nq, mean_nq_ae, median_nq, margin_min.mean(), time_total]
        # save the metrics periodically (every 500 iterations), at the last iteration, and on full success
        if i_iter % 500 == 0 or i_iter + 1 == n_iters or acc == 0:
            np.save(metrics_path, metrics)
        if acc == 0:
            curr_norms_image = np.sqrt(np.sum((x_best - x) ** 2, axis=(1, 2, 3), keepdims=True))
            print('Maximal norm of the perturbations: {:.5f}'.format(np.amax(curr_norms_image)))
            break

    curr_norms_image = np.sqrt(np.sum((x_best - x) ** 2, axis=(1, 2, 3), keepdims=True))
    print('Maximal norm of the perturbations: {:.5f}'.format(np.amax(curr_norms_image)))

    return n_queries, x_best


def square_attack_linf(model, x, y, corr_classified, eps, n_iters, p_init, metrics_path, targeted, loss_type):
    """ The Linf square attack """
    np.random.seed(0)  # important to leave it here as well
    min_val, max_val = 0, 1 if x.max() <= 1 else 255
    c, h, w = x.shape[1:]
    n_features = c*h*w
    n_ex_total = x.shape[0]
    x, y = x[corr_classified], y[corr_classified]

    # [c, 1, w], i.e.
vertical stripes work best for untargeted attacks 209 | init_delta = np.random.choice([-eps, eps], size=[x.shape[0], c, 1, w]) 210 | x_best = np.clip(x + init_delta, min_val, max_val) 211 | 212 | logits = model.predict(x_best) 213 | loss_min = model.loss(y, logits, targeted, loss_type=loss_type) 214 | margin_min = model.loss(y, logits, targeted, loss_type='margin_loss') 215 | n_queries = np.ones(x.shape[0]) # ones because we have already used 1 query 216 | 217 | time_start = time.time() 218 | metrics = np.zeros([n_iters, 7]) 219 | for i_iter in range(n_iters - 1): 220 | idx_to_fool = margin_min > 0 221 | x_curr, x_best_curr, y_curr = x[idx_to_fool], x_best[idx_to_fool], y[idx_to_fool] 222 | loss_min_curr, margin_min_curr = loss_min[idx_to_fool], margin_min[idx_to_fool] 223 | deltas = x_best_curr - x_curr 224 | 225 | p = p_selection(p_init, i_iter, n_iters) 226 | for i_img in range(x_best_curr.shape[0]): 227 | s = int(round(np.sqrt(p * n_features / c))) 228 | s = min(max(s, 1), h-1) # at least c x 1 x 1 window is taken and at most c x h-1 x h-1 229 | center_h = np.random.randint(0, h - s) 230 | center_w = np.random.randint(0, w - s) 231 | 232 | x_curr_window = x_curr[i_img, :, center_h:center_h+s, center_w:center_w+s] 233 | x_best_curr_window = x_best_curr[i_img, :, center_h:center_h+s, center_w:center_w+s] 234 | # prevent trying out a delta if it doesn't change x_curr (e.g. an overlapping patch) 235 | while np.sum(np.abs(np.clip(x_curr_window + deltas[i_img, :, center_h:center_h+s, center_w:center_w+s], min_val, max_val) - x_best_curr_window) < 10**-7) == c*s*s: 236 | deltas[i_img, :, center_h:center_h+s, center_w:center_w+s] = np.random.choice([-eps, eps], size=[c, 1, 1]) 237 | 238 | x_new = np.clip(x_curr + deltas, min_val, max_val) 239 | 240 | logits = model.predict(x_new) 241 | loss = model.loss(y_curr, logits, targeted, loss_type=loss_type) 242 | margin = model.loss(y_curr, logits, targeted, loss_type='margin_loss') 243 | 244 | idx_improved = loss < loss_min_curr 245 | loss_min[idx_to_fool] = idx_improved * loss + ~idx_improved * loss_min_curr 246 | margin_min[idx_to_fool] = idx_improved * margin + ~idx_improved * margin_min_curr 247 | idx_improved = np.reshape(idx_improved, [-1, *[1]*len(x.shape[:-1])]) 248 | x_best[idx_to_fool] = idx_improved * x_new + ~idx_improved * x_best_curr 249 | n_queries[idx_to_fool] += 1 250 | 251 | acc = (margin_min > 0.0).sum() / n_ex_total 252 | acc_corr = (margin_min > 0.0).mean() 253 | mean_nq, mean_nq_ae, median_nq_ae = np.mean(n_queries), np.mean(n_queries[margin_min <= 0]), np.median(n_queries[margin_min <= 0]) 254 | avg_margin_min = np.mean(margin_min) 255 | time_total = time.time() - time_start 256 | log.print('{}: acc={:.2%} acc_corr={:.2%} avg#q_ae={:.2f} med#q={:.1f}, avg_margin={:.2f} (n_ex={}, eps={:.3f}, {:.2f}s)'. 
257 | format(i_iter+1, acc, acc_corr, mean_nq_ae, median_nq_ae, avg_margin_min, x.shape[0], eps, time_total)) 258 | 259 | metrics[i_iter] = [acc, acc_corr, mean_nq, mean_nq_ae, median_nq_ae, margin_min.mean(), time_total] 260 | if (i_iter <= 500 and i_iter % 20 == 0) or (i_iter > 100 and i_iter % 50 == 0) or i_iter + 1 == n_iters or acc == 0: 261 | np.save(metrics_path, metrics) 262 | if acc == 0: 263 | break 264 | 265 | return n_queries, x_best 266 | 267 | 268 | if __name__ == '__main__': 269 | parser = argparse.ArgumentParser(description='Define hyperparameters.') 270 | parser.add_argument('--model', type=str, default='pt_resnet', choices=models.all_model_names, help='Model name.') 271 | parser.add_argument('--attack', type=str, default='square_linf', choices=['square_linf', 'square_l2'], help='Attack.') 272 | parser.add_argument('--exp_folder', type=str, default='exps', help='Experiment folder to store all output.') 273 | parser.add_argument('--gpu', type=str, default='7', help='GPU number. Multiple GPUs are possible for PT models.') 274 | parser.add_argument('--n_ex', type=int, default=10000, help='Number of test ex to test on.') 275 | parser.add_argument('--p', type=float, default=0.05, 276 | help='Probability of changing a coordinate. Note: check the paper for the best values. ' 277 | 'Linf standard: 0.05, L2 standard: 0.1. But robust models require higher p.') 278 | parser.add_argument('--eps', type=float, default=0.05, help='Radius of the Lp ball.') 279 | parser.add_argument('--n_iter', type=int, default=10000) 280 | parser.add_argument('--targeted', action='store_true', help='Targeted or untargeted attack.') 281 | args = parser.parse_args() 282 | args.loss = 'margin_loss' if not args.targeted else 'cross_entropy' 283 | 284 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 285 | dataset = 'mnist' if 'mnist' in args.model else 'cifar10' if 'cifar10' in args.model else 'imagenet' 286 | timestamp = str(datetime.now())[:-7] 287 | hps_str = '{} model={} dataset={} attack={} n_ex={} eps={} p={} n_iter={}'.format( 288 | timestamp, args.model, dataset, args.attack, args.n_ex, args.eps, args.p, args.n_iter) 289 | args.eps = args.eps / 255.0 if dataset == 'imagenet' else args.eps # for mnist and cifar10 we leave as it is 290 | batch_size = data.bs_dict[dataset] 291 | model_type = 'pt' if 'pt_' in args.model else 'tf' 292 | n_cls = 1000 if dataset == 'imagenet' else 10 293 | gpu_memory = 0.5 if dataset == 'mnist' and args.n_ex > 1000 else 0.15 if dataset == 'mnist' else 0.99 294 | 295 | log_path = '{}/{}.log'.format(args.exp_folder, hps_str) 296 | metrics_path = '{}/{}.metrics'.format(args.exp_folder, hps_str) 297 | 298 | log = utils.Logger(log_path) 299 | log.print('All hps: {}'.format(hps_str)) 300 | 301 | if args.model != 'pt_inception': 302 | x_test, y_test = data.datasets_dict[dataset](args.n_ex) 303 | else: # exception for inception net on imagenet -- 299x299 images instead of 224x224 304 | x_test, y_test = data.datasets_dict[dataset](args.n_ex, size=299) 305 | x_test, y_test = x_test[:args.n_ex], y_test[:args.n_ex] 306 | 307 | if args.model == 'pt_post_avg_cifar10': 308 | x_test /= 255.0 309 | args.eps = args.eps / 255.0 310 | 311 | models_class_dict = {'tf': models.ModelTF, 'pt': models.ModelPT} 312 | model = models_class_dict[model_type](args.model, batch_size, gpu_memory) 313 | 314 | logits_clean = model.predict(x_test) 315 | corr_classified = logits_clean.argmax(1) == y_test 316 | # important to check that the model was restored correctly and the clean accuracy is high 317 | 
log.print('Clean accuracy: {:.2%}'.format(np.mean(corr_classified))) 318 | 319 | square_attack = square_attack_linf if args.attack == 'square_linf' else square_attack_l2 320 | y_target = utils.random_classes_except_current(y_test, n_cls) if args.targeted else y_test 321 | y_target_onehot = utils.dense_to_onehot(y_target, n_cls=n_cls) 322 | # Note: we count the queries only across correctly classified images 323 | n_queries, x_adv = square_attack(model, x_test, y_target_onehot, corr_classified, args.eps, args.n_iter, 324 | args.p, metrics_path, args.targeted, args.loss) 325 | 326 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torchvision import transforms 4 | from torchvision.datasets import ImageFolder 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | def load_mnist(n_ex): 9 | from tensorflow.keras.datasets import mnist as mnist_keras 10 | 11 | x_test, y_test = mnist_keras.load_data()[1] 12 | x_test = x_test.astype(np.float64) / 255.0 13 | x_test = x_test[:, None, :, :] 14 | 15 | return x_test[:n_ex], y_test[:n_ex] 16 | 17 | 18 | def load_cifar10(n_ex): 19 | from madry_cifar10.cifar10_input import CIFAR10Data 20 | cifar = CIFAR10Data('madry_cifar10/cifar10_data') 21 | x_test, y_test = cifar.eval_data.xs.astype(np.float32), cifar.eval_data.ys 22 | x_test = np.transpose(x_test, axes=[0, 3, 1, 2]) 23 | return x_test[:n_ex], y_test[:n_ex] 24 | 25 | 26 | def load_imagenet(n_ex, size=224): 27 | IMAGENET_SL = size 28 | IMAGENET_PATH = "/scratch/maksym/imagenet/val_orig" 29 | imagenet = ImageFolder(IMAGENET_PATH, 30 | transforms.Compose([ 31 | transforms.Resize(IMAGENET_SL), 32 | transforms.CenterCrop(IMAGENET_SL), 33 | transforms.ToTensor() 34 | ])) 35 | torch.manual_seed(0) 36 | 37 | imagenet_loader = DataLoader(imagenet, batch_size=n_ex, shuffle=True, num_workers=1) 38 | x_test, y_test = next(iter(imagenet_loader)) 39 | 40 | return np.array(x_test, dtype=np.float32), np.array(y_test) 41 | 42 | 43 | datasets_dict = {'mnist': load_mnist, 44 | 'cifar10': load_cifar10, 45 | 'imagenet': load_imagenet, 46 | } 47 | bs_dict = {'mnist': 10000, 48 | 'cifar10': 4096, # 4096 is the maximum that fits 49 | 'imagenet': 100, 50 | } 51 | -------------------------------------------------------------------------------- /images/ablation_study.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/ablation_study.png -------------------------------------------------------------------------------- /images/adv_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/adv_examples.png -------------------------------------------------------------------------------- /images/algorithm_rs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/algorithm_rs.png -------------------------------------------------------------------------------- /images/ezgif.com-gif-maker-50-conf-small.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/ezgif.com-gif-maker-50-conf-small.gif -------------------------------------------------------------------------------- /images/ezgif.com-gif-maker-img-53-l2-2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/ezgif.com-gif-maker-img-53-l2-2.gif -------------------------------------------------------------------------------- /images/main_results_imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/main_results_imagenet.png -------------------------------------------------------------------------------- /images/main_results_imagenet_l2_commonly_successful.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/main_results_imagenet_l2_commonly_successful.png -------------------------------------------------------------------------------- /images/repository_picture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/repository_picture.png -------------------------------------------------------------------------------- /images/sensitivity_wrt_p.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/sensitivity_wrt_p.png -------------------------------------------------------------------------------- /images/success_rate_curves_full.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/success_rate_curves_full.png -------------------------------------------------------------------------------- /images/table_clp_lsq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/table_clp_lsq.png -------------------------------------------------------------------------------- /images/table_madry_mnist_l2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/table_madry_mnist_l2.png -------------------------------------------------------------------------------- /images/table_madry_trades_mnist_linf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/table_madry_trades_mnist_linf.png -------------------------------------------------------------------------------- /images/table_post_avg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/images/table_post_avg.png -------------------------------------------------------------------------------- 
/logit_pairing/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from collections import OrderedDict 3 | 4 | 5 | # ------------------------------------------------------------- 6 | # Models 7 | # ------------------------------------------------------------- 8 | 9 | class LeNet: 10 | def __init__(self): 11 | super().__init__() 12 | self.nb_classes = 10 13 | self.input_shape = [28, 28, 3] 14 | self.weights_init = 'He' 15 | self.filters = 32 # 32 is the default here for all our pre-trained models 16 | self.is_training = False 17 | self.bn = False 18 | self.bn_scale = False 19 | self.bn_bias = False 20 | self.parameters = 0 21 | 22 | # Create variables 23 | with tf.variable_scope('conv1_vars'): 24 | self.W_conv1 = create_conv2d_weights(kernel_size=3, filter_in=1, filter_out=self.filters, 25 | init=self.weights_init) 26 | self.parameters += 3 * 3 * self.input_shape[-1] * self.filters 27 | 28 | self.b_conv1 = create_biases(size=self.filters) 29 | self.parameters += self.filters 30 | 31 | with tf.variable_scope('conv2_vars'): 32 | self.W_conv2 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters * 2, 33 | init=self.weights_init) 34 | self.parameters += 3 * 3 * self.filters * (self.filters * 2) 35 | 36 | self.b_conv2 = create_biases(size=self.filters * 2) 37 | self.parameters += self.filters * 2 38 | 39 | with tf.variable_scope('fc1_vars'): 40 | self.W_fc1 = create_weights(units_in=7 * 7 * self.filters * 2, units_out=1024, init=self.weights_init) 41 | self.parameters += (7 * 7 * self.filters * 2) * 1024 42 | 43 | self.b_fc1 = create_biases(size=1024) 44 | self.parameters += 1024 45 | 46 | with tf.variable_scope('fc2_vars'): 47 | self.W_fc2 = create_weights(units_in=1024, units_out=self.nb_classes, init=self.weights_init) 48 | self.parameters += 1024 * self.nb_classes 49 | 50 | self.b_fc2 = create_biases(size=self.nb_classes) 51 | self.parameters += self.nb_classes 52 | 53 | self.x_input = tf.placeholder(tf.float32, shape=[None, 784]) 54 | self.y_input = tf.placeholder(tf.int64, shape=[None]) 55 | 56 | x = tf.reshape(self.x_input, [-1, 28, 28, 1]) 57 | 58 | with tf.name_scope('conv-block-1'): 59 | conv1 = conv_layer(x, self.is_training, self.W_conv1, stride=1, padding='SAME', bn=self.bn, 60 | bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv1', bias=self.b_conv1) 61 | 62 | with tf.name_scope('max-pool-1'): 63 | conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID') 64 | 65 | with tf.name_scope('conv-block-2'): 66 | conv2 = conv_layer(conv1, self.is_training, self.W_conv2, stride=1, padding='SAME', bn=self.bn, 67 | bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv2', bias=self.b_conv2) 68 | 69 | with tf.name_scope('max-pool-2'): 70 | conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID') 71 | 72 | with tf.name_scope('fc-block'): 73 | conv2 = tf.layers.flatten(conv2) 74 | fc1 = fc_layer(conv2, self.is_training, self.W_fc1, bn=self.bn, bn_scale=self.bn_scale, 75 | bn_bias=self.bn_bias, name='fc1', non_linearity='relu', bias=self.b_fc1) 76 | 77 | logits = fc_layer(fc1, self.is_training, self.W_fc2, bn=self.bn, bn_scale=self.bn_scale, 78 | bn_bias=self.bn_bias, name='fc2', non_linearity='linear', bias=self.b_fc2) 79 | 80 | self.summaries = False 81 | self.logits = logits 82 | 83 | 84 | class ResNet20_v2: 85 | def __init__(self): 86 | super().__init__() 87 | self.nb_classes = 10 88 | self.input_shape = [32, 32, 3] 89 
| self.weights_init = 'He' 90 | self.filters = 64 # 64 is the default here for all our pre-trained models 91 | self.is_training = False 92 | self.bn = True 93 | self.bn_scale = True 94 | self.bn_bias = True 95 | self.parameters = 0 96 | 97 | # Create variables 98 | with tf.variable_scope('conv1_vars'): 99 | self.W_conv1 = create_conv2d_weights(kernel_size=3, filter_in=self.input_shape[-1], filter_out=self.filters, 100 | init=self.weights_init) 101 | self.parameters += 3 * 3 * self.input_shape[-1] * self.filters 102 | 103 | with tf.variable_scope('conv2_vars'): 104 | self.W_conv2 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters, 105 | init=self.weights_init) 106 | self.parameters += 3 * 3 * self.filters * self.filters 107 | 108 | with tf.variable_scope('conv3_vars'): 109 | self.W_conv3 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters, 110 | init=self.weights_init) 111 | self.parameters += 3 * 3 * self.filters * self.filters 112 | 113 | with tf.variable_scope('conv4_vars'): 114 | self.W_conv4 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters, 115 | init=self.weights_init) 116 | self.parameters += 3 * 3 * self.filters * self.filters 117 | 118 | with tf.variable_scope('conv5_vars'): 119 | self.W_conv5 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters, 120 | init=self.weights_init) 121 | self.parameters += 3 * 3 * self.filters * self.filters 122 | 123 | with tf.variable_scope('conv6_vars'): 124 | self.W_conv6 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters, 125 | init=self.weights_init) 126 | self.parameters += 3 * 3 * self.filters * self.filters 127 | 128 | with tf.variable_scope('conv7_vars'): 129 | self.W_conv7 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters, 130 | init=self.weights_init) 131 | self.parameters += 3 * 3 * self.filters * self.filters 132 | 133 | with tf.variable_scope('conv8_vars'): 134 | self.W_conv8 = create_conv2d_weights(kernel_size=3, filter_in=self.filters, filter_out=self.filters * 2, 135 | init=self.weights_init) 136 | self.parameters += 3 * 3 * self.filters * (self.filters * 2) 137 | 138 | with tf.variable_scope('conv9_vars'): 139 | self.W_conv9 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 2, filter_out=self.filters * 2, 140 | init=self.weights_init) 141 | self.parameters += 3 * 3 * (self.filters * 2) * (self.filters * 2) 142 | 143 | with tf.variable_scope('conv10_vars'): 144 | self.W_conv10 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 2, 145 | filter_out=self.filters * 2, init=self.weights_init) 146 | self.parameters += 3 * 3 * (self.filters * 2) * (self.filters * 2) 147 | 148 | with tf.variable_scope('conv11_vars'): 149 | self.W_conv11 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 2, 150 | filter_out=self.filters * 2, init=self.weights_init) 151 | self.parameters += 3 * 3 * (self.filters * 2) * (self.filters * 2) 152 | 153 | with tf.variable_scope('conv12_vars'): 154 | self.W_conv12 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 2, 155 | filter_out=self.filters * 2, init=self.weights_init) 156 | self.parameters += 3 * 3 * (self.filters * 2) * (self.filters * 2) 157 | 158 | with tf.variable_scope('conv13_vars'): 159 | self.W_conv13 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 2, 160 | filter_out=self.filters * 2, init=self.weights_init) 161 | 
self.parameters += 3 * 3 * (self.filters * 2) * (self.filters * 2) 162 | 163 | with tf.variable_scope('conv14_vars'): 164 | self.W_conv14 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 2, 165 | filter_out=self.filters * 4, init=self.weights_init) 166 | self.parameters += 3 * 3 * (self.filters * 2) * (self.filters * 4) 167 | 168 | with tf.variable_scope('conv15_vars'): 169 | self.W_conv15 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 4, 170 | filter_out=self.filters * 4, init=self.weights_init) 171 | self.parameters += 3 * 3 * (self.filters * 4) * (self.filters * 4) 172 | 173 | with tf.variable_scope('conv16_vars'): 174 | self.W_conv16 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 4, 175 | filter_out=self.filters * 4, init=self.weights_init) 176 | self.parameters += 3 * 3 * (self.filters * 4) * (self.filters * 4) 177 | 178 | with tf.variable_scope('conv17_vars'): 179 | self.W_conv17 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 4, 180 | filter_out=self.filters * 4, init=self.weights_init) 181 | self.parameters += 3 * 3 * (self.filters * 4) * (self.filters * 4) 182 | 183 | with tf.variable_scope('conv18_vars'): 184 | self.W_conv18 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 4, 185 | filter_out=self.filters * 4, init=self.weights_init) 186 | self.parameters += 3 * 3 * (self.filters * 4) * (self.filters * 4) 187 | 188 | with tf.variable_scope('conv19_vars'): 189 | self.W_conv19 = create_conv2d_weights(kernel_size=3, filter_in=self.filters * 4, 190 | filter_out=self.filters * 4, init=self.weights_init) 191 | self.parameters += 3 * 3 * (self.filters * 4) * (self.filters * 4) 192 | 193 | with tf.variable_scope('fc1_vars'): 194 | self.W_fc1 = create_weights(units_in=self.filters * 4, units_out=self.nb_classes, init=self.weights_init) 195 | self.parameters += (self.filters * 4) * self.nb_classes 196 | 197 | self.b_fc1 = create_biases(size=self.nb_classes) 198 | self.parameters += self.nb_classes 199 | 200 | with tf.variable_scope('scip1_vars'): 201 | self.W_scip1 = create_conv2d_weights(kernel_size=1, filter_in=self.filters, filter_out=self.filters, 202 | init=self.weights_init) 203 | self.parameters += 1 * 1 * self.filters * self.filters 204 | 205 | with tf.variable_scope('scip2_vars'): 206 | self.W_scip2 = create_conv2d_weights(kernel_size=1, filter_in=self.filters, filter_out=self.filters * 2, 207 | init=self.weights_init) 208 | self.parameters += 1 * 1 * self.filters * (self.filters * 2) 209 | 210 | with tf.variable_scope('scip3_vars'): 211 | self.W_scip3 = create_conv2d_weights(kernel_size=1, filter_in=self.filters * 2, filter_out=self.filters * 4, 212 | init=self.weights_init) 213 | self.parameters += 1 * 1 * (self.filters * 2) * (self.filters * 4) 214 | 215 | self.x_input = tf.placeholder(tf.float32, shape=[None, 32, 32, 3]) 216 | self.y_input = tf.placeholder(tf.int64, shape=None) 217 | x = self.x_input / 255.0 218 | 219 | # Specify forward pass 220 | with tf.name_scope('input-block'): 221 | conv1 = conv_layer(x, self.is_training, self.W_conv1, stride=1, padding='SAME', 222 | bn=False, bn_scale=self.bn_scale, bn_bias=self.bn_bias, 223 | name='conv1', 224 | non_linearity='linear') 225 | 226 | with tf.name_scope('conv-block-1'): 227 | conv2 = pre_act_conv_layer(conv1, self.is_training, self.W_conv2, stride=1, padding='SAME', 228 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv2') 229 | 230 | conv3 = pre_act_conv_layer(conv2, self.is_training, self.W_conv3, stride=1, padding='SAME', 
231 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv3') 232 | 233 | # skip connection 234 | conv3 += tf.nn.conv2d(conv1, self.W_scip1, strides=[1, 1, 1, 1], padding='SAME', name='conv-skip1') 235 | 236 | conv4 = pre_act_conv_layer(conv3, self.is_training, self.W_conv4, stride=1, padding='SAME', 237 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv4') 238 | 239 | conv5 = pre_act_conv_layer(conv4, self.is_training, self.W_conv5, stride=1, padding='SAME', 240 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv5') 241 | 242 | # skip connection 243 | conv5 += conv3 244 | 245 | conv6 = pre_act_conv_layer(conv5, self.is_training, self.W_conv6, stride=1, padding='SAME', 246 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv6') 247 | 248 | conv7 = pre_act_conv_layer(conv6, self.is_training, self.W_conv7, stride=1, padding='SAME', 249 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv7') 250 | 251 | # skip connection 252 | conv7 += conv5 253 | 254 | with tf.name_scope('conv-block-2'): 255 | conv8 = pre_act_conv_layer(conv7, self.is_training, self.W_conv8, stride=2, padding='SAME', 256 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv8') 257 | 258 | conv9 = pre_act_conv_layer(conv8, self.is_training, self.W_conv9, stride=1, padding='SAME', 259 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv9') 260 | 261 | # skip connection 262 | conv9 += tf.nn.conv2d(conv7, self.W_scip2, strides=[1, 2, 2, 1], padding='SAME', name='conv-skip2') 263 | 264 | conv10 = pre_act_conv_layer(conv9, self.is_training, self.W_conv10, stride=1, padding='SAME', 265 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv10') 266 | 267 | conv11 = pre_act_conv_layer(conv10, self.is_training, self.W_conv11, stride=1, padding='SAME', 268 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv11') 269 | 270 | # skip connection 271 | conv11 += conv9 272 | 273 | conv12 = pre_act_conv_layer(conv11, self.is_training, self.W_conv12, stride=1, padding='SAME', 274 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv12') 275 | 276 | conv13 = pre_act_conv_layer(conv12, self.is_training, self.W_conv13, stride=1, padding='SAME', 277 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv13') 278 | 279 | # skip connection 280 | conv13 += conv11 281 | 282 | with tf.name_scope('conv-block-3'): 283 | conv14 = pre_act_conv_layer(conv13, self.is_training, self.W_conv14, stride=2, padding='SAME', 284 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv14') 285 | 286 | conv15 = pre_act_conv_layer(conv14, self.is_training, self.W_conv15, stride=1, padding='SAME', 287 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv15') 288 | 289 | # skip connection 290 | conv15 += tf.nn.conv2d(conv13, self.W_scip3, strides=[1, 2, 2, 1], padding='SAME', name='conv-skip3') 291 | 292 | conv16 = pre_act_conv_layer(conv15, self.is_training, self.W_conv16, stride=1, padding='SAME', 293 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv16') 294 | 295 | conv17 = pre_act_conv_layer(conv16, self.is_training, self.W_conv17, stride=1, padding='SAME', 296 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv17') 297 | 298 | # skip connection 299 | conv17 += conv15 300 | 301 | conv18 = pre_act_conv_layer(conv17, self.is_training, self.W_conv18, stride=1, padding='SAME', 302 | bn=self.bn, 
bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv18') 303 | 304 | conv19 = pre_act_conv_layer(conv18, self.is_training, self.W_conv19, stride=1, padding='SAME', 305 | bn=self.bn, bn_scale=self.bn_scale, bn_bias=self.bn_bias, name='conv19') 306 | 307 | # skip connection 308 | conv19 += conv17 309 | conv19 = nonlinearity(conv19) 310 | 311 | with tf.name_scope('output-block'): 312 | with tf.name_scope('global-average-pooling'): 313 | fc1 = tf.reduce_mean(conv19, axis=[1, 2]) 314 | 315 | logits = fc_layer(fc1, self.is_training, self.W_fc1, bn=False, bn_scale=self.bn_scale, bn_bias=self.bn_bias, 316 | name='fc1', 317 | non_linearity='linear', bias=self.b_fc1) 318 | 319 | self.summaries = False 320 | self.logits = logits 321 | 322 | 323 | # ------------------------------------------------------------- 324 | # Helpers 325 | # ------------------------------------------------------------- 326 | 327 | def create_weights(units_in, units_out, init='Xavier', seed=None): 328 | if init == 'Xavier': 329 | initializer = tf.variance_scaling_initializer(scale=1.0, 330 | mode='fan_in', 331 | distribution='normal', 332 | seed=None, 333 | dtype=tf.float32) 334 | elif init == 'He': 335 | initializer = tf.variance_scaling_initializer(scale=2.0, 336 | mode='fan_in', 337 | distribution='normal', 338 | seed=None, 339 | dtype=tf.float32) 340 | else: 341 | initializer = tf.truncated_normal_initializer(mean=0.0, stddev=0.01, seed=seed, dtype=tf.float32) 342 | 343 | weights = tf.get_variable(name='weights', 344 | shape=[units_in, units_out], 345 | dtype=tf.float32, 346 | initializer=initializer) 347 | return weights 348 | 349 | 350 | def create_conv2d_weights(kernel_size, filter_in, filter_out, init='Xavier', seed=None): 351 | if init == 'Xavier': 352 | initializer = tf.variance_scaling_initializer(scale=1.0, 353 | mode='fan_in', 354 | distribution='normal', 355 | seed=None, 356 | dtype=tf.float32) 357 | elif init == 'He': 358 | initializer = tf.variance_scaling_initializer(scale=2.0, 359 | mode='fan_in', 360 | distribution='normal', 361 | seed=None, 362 | dtype=tf.float32) 363 | else: 364 | initializer = tf.truncated_normal_initializer(mean=0.0, stddev=0.01, seed=seed, dtype=tf.float32) 365 | 366 | weights = tf.get_variable(name='weights', 367 | shape=[kernel_size, kernel_size, filter_in, filter_out], 368 | dtype=tf.float32, 369 | initializer=initializer) 370 | return weights 371 | 372 | 373 | def create_biases(size): 374 | return tf.get_variable(name='biases', shape=[size], dtype=tf.float32, initializer=tf.zeros_initializer()) 375 | 376 | 377 | def batch_norm(x, is_training, scale, bias, name, reuse): 378 | return tf.contrib.layers.batch_norm( 379 | x, 380 | decay=0.999, 381 | center=bias, 382 | scale=scale, 383 | epsilon=0.001, 384 | param_initializers=None, 385 | updates_collections=tf.GraphKeys.UPDATE_OPS, 386 | is_training=is_training, 387 | reuse=reuse, 388 | variables_collections=['batch-norm'], 389 | outputs_collections=None, 390 | trainable=True, 391 | batch_weights=None, 392 | fused=False, 393 | zero_debias_moving_mean=False, 394 | scope=name, 395 | renorm=False, 396 | renorm_clipping=None, 397 | renorm_decay=0.99 398 | ) 399 | 400 | 401 | def nonlinearity(x, non_linearity='relu'): 402 | if non_linearity == 'linear': 403 | return tf.identity(x) 404 | if non_linearity == 'sigmoid': 405 | return tf.nn.sigmoid(x) 406 | if non_linearity == 'tanh': 407 | return tf.nn.tanh(x) 408 | if non_linearity == 'relu': 409 | return tf.nn.relu(x) 410 | if non_linearity == 'elu': 411 | return tf.nn.elu(x) 412 | if 
non_linearity == 'selu': 413 | return tf.nn.selu(x) 414 | 415 | 416 | def conv_layer(inputs, is_training, weights, stride, padding, bn, bn_scale, bn_bias, name, 417 | non_linearity='relu', bias=None): 418 | if bias is not None: 419 | inputs = tf.nn.conv2d(inputs, weights, strides=[1, stride, stride, 1], padding=padding) + bias 420 | else: 421 | inputs = tf.nn.conv2d(inputs, weights, strides=[1, stride, stride, 1], padding=padding) 422 | 423 | if bn: 424 | inputs = batch_norm(inputs, is_training=is_training, scale=bn_scale, bias=bn_bias, 425 | name='batch-norm-{:s}'.format(name), 426 | reuse=tf.AUTO_REUSE) 427 | 428 | activations = nonlinearity(inputs, non_linearity=non_linearity) 429 | 430 | return activations 431 | 432 | 433 | def pre_act_conv_layer(inputs, is_training, weights, stride, padding, bn, bn_scale, bn_bias, name, 434 | non_linearity='relu'): 435 | if bn: 436 | inputs = batch_norm(inputs, is_training=is_training, scale=bn_scale, bias=bn_bias, 437 | name='batch-norm-{:s}'.format(name), 438 | reuse=tf.AUTO_REUSE) 439 | 440 | activations = nonlinearity(inputs, non_linearity=non_linearity) 441 | 442 | outputs = tf.nn.conv2d(activations, weights, strides=[1, stride, stride, 1], padding=padding) 443 | 444 | return outputs 445 | 446 | 447 | def fc_layer(inputs, is_training, weights, bn, bn_scale, bn_bias, name, non_linearity='relu', bias=None): 448 | if bias is not None: 449 | inputs = tf.matmul(inputs, weights) + bias 450 | else: 451 | inputs = tf.matmul(inputs, weights) 452 | 453 | if bn: 454 | inputs = batch_norm(inputs, is_training=is_training, scale=bn_scale, bias=bn_bias, 455 | name='batch-norm-{:s}'.format(name), 456 | reuse=tf.AUTO_REUSE) 457 | 458 | activations = nonlinearity(inputs, non_linearity) 459 | 460 | return activations 461 | -------------------------------------------------------------------------------- /madry_cifar10/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, and Adrian Vladu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /madry_cifar10/README.md: -------------------------------------------------------------------------------- 1 | # CIFAR10 Adversarial Examples Challenge 2 | 3 | Recently, there has been much progress on adversarial *attacks* against neural networks, such as the [cleverhans](https://github.com/tensorflow/cleverhans) library and the code by [Carlini and Wagner](https://github.com/carlini/nn_robust_attacks). 4 | We now complement these advances by proposing an *attack challenge* for the 5 | [CIFAR10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) which follows the 6 | format of [our earlier MNIST challenge](https://github.com/MadryLab/mnist_challenge). 7 | We have trained a robust network, and the objective is to find a set of adversarial examples on which this network achieves only a low accuracy. 8 | To train an adversarially-robust network, we followed the approach from our recent paper: 9 | 10 | **Towards Deep Learning Models Resistant to Adversarial Attacks**
11 | *Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, Adrian Vladu*
12 | https://arxiv.org/abs/1706.06083. 13 | 14 | As part of the challenge, we release both the training code and the network architecture, but keep the network weights secret. 15 | We invite any researcher to submit attacks against our model (see the detailed instructions below). 16 | We will maintain a leaderboard of the best attacks for the next two months and then publish our secret network weights. 17 | 18 | Analogously to our MNIST challenge, the goal of this challenge is to clarify the state-of-the-art for adversarial robustness on CIFAR10. Moreover, we hope that future work on defense mechanisms will adopt a similar challenge format in order to improve reproducibility and empirical comparisons. 19 | 20 | **Update 2017-12-10**: We released our secret model. You can download it by running `python fetch_model.py secret`. As of Dec 10 we are no longer accepting black-box challenge submissions. We have set up a leaderboard for white-box attacks on the (now released) secret model. The submission format is the same as before. We plan to continue evaluating submissions and maintaining the leaderboard for the foreseeable future. 21 | 22 | ## Black-Box Leaderboard (Original Challenge) 23 | 24 | | Attack | Submitted by | Accuracy | Submission Date | 25 | | -------------------------------------- | ------------- | -------- | ---- | 26 | | PGD on the cross-entropy loss for the
adversarially trained public network | (initial entry) | **63.39%** | Jul 12, 2017 | 27 | | PGD on the [CW](https://github.com/carlini/nn_robust_attacks) loss for the
adversarially trained public network | (initial entry) | 64.38% | Jul 12, 2017 | 28 | | FGSM on the [CW](https://github.com/carlini/nn_robust_attacks) loss for the
adversarially trained public network | (initial entry) | 67.25% | Jul 12, 2017 | 29 | | FGSM on the [CW](https://github.com/carlini/nn_robust_attacks) loss for the
naturally trained public network | (initial entry) | 85.23% | Jul 12, 2017 | 30 | 31 | ## White-Box Leaderboard 32 | 33 | | Attack | Submitted by | Accuracy | Submission Date | 34 | | -------------------------------------- | ------------- | -------- | ---- | 35 | | [FAB: Fast Adaptive Boundary Attack](https://github.com/fra31/fab-attack) | Francesco Croce | **44.51%** | Jun 7, 2019 | 36 | | [Distributionally Adversarial Attack](https://github.com/tianzheng4/Distributionally-Adversarial-Attack) | Tianhang Zheng | 44.71% | Aug 21, 2018 | 37 | | 20-step PGD on the cross-entropy loss
with 10 random restarts | Tianhang Zheng | 45.21% | Aug 24, 2018 |
38 | | 20-step PGD on the cross-entropy loss | (initial entry) | 47.04% | Dec 10, 2017 |
39 | | 20-step PGD on the [CW](https://github.com/carlini/nn_robust_attacks) loss | (initial entry) | 47.76% | Dec 10, 2017 |
40 | | FGSM on the [CW](https://github.com/carlini/nn_robust_attacks) loss | (initial entry) | 54.92% | Dec 10, 2017 |
41 | | FGSM on the cross-entropy loss | (initial entry) | 55.55% | Dec 10, 2017 |
42 |
43 |
44 |
45 |
46 |
47 | ## Format and Rules
48 |
49 | The objective of the challenge is to find black-box (transfer) attacks that are effective against our CIFAR10 model.
50 | Attacks are allowed to perturb each pixel of the input image by at most `epsilon=8.0` on a `0-255` pixel scale.
51 | To ensure that the attacks are indeed black-box, we release our training code and model architecture, but keep the actual network weights secret.
52 |
53 | We invite any interested researchers to submit attacks against our model.
54 | The most successful attacks will be listed in the leaderboard above.
55 | As a reference point, we have seeded the leaderboard with the results of some standard attacks.
56 |
57 | ### The CIFAR10 Model
58 |
59 | We used the code published in this repository to produce an adversarially robust model for CIFAR10 classification. The model is a residual convolutional neural network consisting of five residual units and a fully connected layer. This architecture is derived from the "w32-10 wide" variant of the [Tensorflow model repository](https://github.com/tensorflow/models/blob/master/resnet/resnet_model.py).
60 | The network was trained against an iterative adversary that is allowed to perturb each pixel by at most `epsilon=8.0`.
61 |
62 | The random seed used for training and the trained network weights will be kept secret.
63 |
64 | The `sha256()` digest of our model file is:
65 | ```
66 | 555be6e892372599380c9da5d5f9802f9cbd098be8a47d24d96937a002305fd4
67 | ```
68 | We will release the corresponding model file on September 15, 2017, which is roughly two months after the start of this competition. **Edit: We are extending the deadline for submitting attacks to October 15th due to requests.**
69 |
70 | ### The Attack Model
71 |
72 | We are interested in adversarial inputs that are derived from the CIFAR10 test set.
73 | Each pixel can be perturbed by at most `epsilon=8.0` from its initial value on the `0-255` pixel scale.
74 | All pixels can be perturbed independently, so this is an l_infinity attack.
75 |
76 | ### Submitting an Attack
77 |
78 | Each attack should consist of a perturbed version of the CIFAR10 test set.
79 | Each perturbed image in this test set should follow the above attack model.
80 |
81 | The adversarial test set should be formatted as a numpy array with one row per example and each row containing a 32x32x3
82 | array of pixels.
83 | Hence the overall dimensions are 10,000x32x32x3.
84 | Each pixel must be in the [0, 255] range.
85 | See the script `pgd_attack.py` for an attack that generates an adversarial test set in this format.
86 |
87 | In order to submit your attack, save the matrix containing your adversarial examples with `numpy.save` and email the resulting file to cifar10.challenge@gmail.com.
88 | We will then run the `run_attack.py` script on your file to verify that the attack is valid and to evaluate the accuracy of our secret model on your examples.
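For concreteness, a minimal sketch of preparing and sanity-checking a submission in this format (the random-sign perturbation is only a toy stand-in for a real attack such as `pgd_attack.py`):

```
import numpy as np
import cifar10_input

cifar = cifar10_input.CIFAR10Data('cifar10_data')
x_nat = cifar.eval_data.xs.astype(np.float32)   # shape (10000, 32, 32, 3), values in [0, 255]

# Toy "attack" for illustration only: a random sign perturbation of size epsilon=8.
x_adv = np.clip(x_nat + 8.0 * np.sign(np.random.randn(*x_nat.shape)), 0, 255)

# Sanity checks mirroring what run_attack.py verifies.
assert x_adv.shape == (10000, 32, 32, 3)
assert x_adv.min() >= 0 and x_adv.max() <= 255        # valid pixel range
assert np.abs(x_adv - x_nat).max() <= 8.0 + 0.0001    # l_infinity constraint
np.save('attack.npy', x_adv)                          # this is the file to submit
```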
89 | After that, we will reply with the predictions of our model on each of your examples and the overall accuracy of our model on your evaluation set.
90 |
91 | If the attack is valid and outperforms all current attacks in the leaderboard, it will appear at the top of the leaderboard.
92 | Novel types of attacks might be included in the leaderboard even if they do not perform best.
93 |
94 | We strongly encourage you to disclose your attack method.
95 | We would be happy to add a link to your code in our leaderboard.
96 |
97 | ## Overview of the Code
98 | The code consists of seven Python scripts and the file `config.json` that contains various parameter settings.
99 |
100 | ### Running the code
101 | - `python train.py`: trains the network, storing checkpoints along
102 | the way.
103 | - `python eval.py`: an infinite evaluation loop, processing each new
104 | checkpoint as it is created while logging summaries. It is intended
105 | to be run in parallel with the `train.py` script.
106 | - `python pgd_attack.py`: applies the attack to the CIFAR10 eval set and
107 | stores the resulting adversarial eval set in a `.npy` file. This file is
108 | in a valid attack format for our challenge.
109 | - `python run_attack.py`: evaluates the model on the examples in
110 | the `.npy` file specified in config, while ensuring that the adversarial examples
111 | are indeed a valid attack. The script also saves the network predictions in `pred.npy`.
112 | - `python fetch_model.py name`: downloads the pre-trained model with the
113 | specified name (`natural`, `adv_trained`, or `secret`), prints the sha256
114 | hash, and places it in the models directory.
115 | - `cifar10_input.py` provides utility functions and classes for loading the CIFAR10 dataset.
116 |
117 | ### Parameters in `config.json`
118 |
119 | Model configuration:
120 | - `model_dir`: contains the path to the directory of the currently
121 | trained/evaluated model.
122 |
123 | Training configuration:
124 | - `tf_random_seed`: the seed for the RNG used to initialize the network
125 | weights.
126 | - `np_random_seed`: the seed for the RNG used to pass over the dataset in random order.
127 | - `max_num_training_steps`: the number of training steps.
128 | - `num_output_steps`: the number of training steps between printing
129 | progress in standard output.
130 | - `num_summary_steps`: the number of training steps between storing
131 | tensorboard summaries.
132 | - `num_checkpoint_steps`: the number of training steps between storing
133 | model checkpoints.
134 | - `training_batch_size`: the size of the training batch.
135 |
136 | Evaluation configuration:
137 | - `num_eval_examples`: the number of CIFAR10 examples to evaluate the
138 | model on.
139 | - `eval_batch_size`: the size of the evaluation batches.
140 | - `eval_on_cpu`: forces the `eval.py` script to run on the CPU so it does not compete with `train.py` for GPU resources.
141 |
142 | Adversarial examples configuration:
143 | - `epsilon`: the maximum allowed perturbation per pixel.
144 | - `num_steps`: the number of PGD iterations used by the adversary.
145 | - `step_size`: the size of the PGD adversary steps.
146 | - `random_start`: specifies whether the adversary will start iterating
147 | from the natural example or a random perturbation of it.
148 | - `loss_func`: the loss function used to run PGD on. `xent` corresponds to the
149 | standard cross-entropy loss, `cw` corresponds to the loss function
150 | of [Carlini and Wagner](https://arxiv.org/abs/1608.04644).
151 | - `store_adv_path`: the file in which adversarial examples are stored. 152 | Relevant for the `pgd_attack.py` and `run_attack.py` scripts. 153 | 154 | ## Example usage 155 | After cloning the repository you can either train a new network or evaluate/attack one of our pre-trained networks. 156 | #### Training a new network 157 | * Start training by running: 158 | ``` 159 | python train.py 160 | ``` 161 | * (Optional) Evaluation summaries can be logged by simultaneously 162 | running: 163 | ``` 164 | python eval.py 165 | ``` 166 | #### Download a pre-trained network 167 | * For an adversarially trained network, run 168 | ``` 169 | python fetch_model.py adv_trained 170 | ``` 171 | and use the `config.json` file to set `"model_dir": "models/adv_trained"`. 172 | * For a naturally trained network, run 173 | ``` 174 | python fetch_model.py natural 175 | ``` 176 | and use the `config.json` file to set `"model_dir": "models/naturally_trained"`. 177 | #### Test the network 178 | * Create an attack file by running 179 | ``` 180 | python pgd_attack.py 181 | ``` 182 | * Evaluate the network with 183 | ``` 184 | python run_attack.py 185 | ``` 186 | -------------------------------------------------------------------------------- /madry_cifar10/cifar10_input.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for importing the CIFAR10 dataset. 3 | 4 | Each image in the dataset is a numpy array of shape (32, 32, 3), with the values 5 | being unsigned integers (i.e., in the range 0,1,...,255). 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os 13 | import pickle 14 | import sys 15 | import tensorflow as tf 16 | from tensorflow.examples.tutorials.mnist import input_data 17 | version = sys.version_info 18 | 19 | import numpy as np 20 | 21 | class CIFAR10Data(object): 22 | """ 23 | Unpickles the CIFAR10 dataset from a specified folder containing a pickled 24 | version following the format of Krizhevsky which can be found 25 | [here](https://www.cs.toronto.edu/~kriz/cifar.html). 26 | 27 | Inputs to constructor 28 | ===================== 29 | 30 | - path: path to the pickled dataset. The training data must be pickled 31 | into five files named data_batch_i for i = 1, ..., 5, containing 10,000 32 | examples each, the test data 33 | must be pickled into a single file called test_batch containing 10,000 34 | examples, and the 10 class names must be 35 | pickled into a file called batches.meta. The pickled examples should 36 | be stored as a tuple of two objects: an array of 10,000 32x32x3-shaped 37 | arrays, and an array of their 10,000 true labels. 38 | 39 | """ 40 | def __init__(self, path): 41 | train_filenames = ['data_batch_{}'.format(ii + 1) for ii in range(5)] 42 | eval_filename = 'test_batch' 43 | metadata_filename = 'batches.meta' 44 | 45 | train_images = np.zeros((50000, 32, 32, 3), dtype='uint8') 46 | train_labels = np.zeros(50000, dtype='int32') 47 | for ii, fname in enumerate(train_filenames): 48 | cur_images, cur_labels = self._load_datafile(os.path.join(path, fname)) 49 | train_images[ii * 10000 : (ii+1) * 10000, ...] = cur_images 50 | train_labels[ii * 10000 : (ii+1) * 10000, ...] 
= cur_labels 51 | eval_images, eval_labels = self._load_datafile( 52 | os.path.join(path, eval_filename)) 53 | 54 | with open(os.path.join(path, metadata_filename), 'rb') as fo: 55 | if version.major == 3: 56 | data_dict = pickle.load(fo, encoding='bytes') 57 | else: 58 | data_dict = pickle.load(fo) 59 | 60 | self.label_names = data_dict[b'label_names'] 61 | for ii in range(len(self.label_names)): 62 | self.label_names[ii] = self.label_names[ii].decode('utf-8') 63 | 64 | self.train_data = DataSubset(train_images, train_labels) 65 | self.eval_data = DataSubset(eval_images, eval_labels) 66 | 67 | @staticmethod 68 | def _load_datafile(filename): 69 | with open(filename, 'rb') as fo: 70 | if version.major == 3: 71 | data_dict = pickle.load(fo, encoding='bytes') 72 | else: 73 | data_dict = pickle.load(fo) 74 | 75 | assert data_dict[b'data'].dtype == np.uint8 76 | image_data = data_dict[b'data'] 77 | image_data = image_data.reshape((10000, 3, 32, 32)).transpose(0, 2, 3, 1) 78 | return image_data, np.array(data_dict[b'labels']) 79 | 80 | class AugmentedCIFAR10Data(object): 81 | """ 82 | Data augmentation wrapper over a loaded dataset. 83 | 84 | Inputs to constructor 85 | ===================== 86 | - raw_cifar10data: the loaded CIFAR10 dataset, via the CIFAR10Data class 87 | - sess: current tensorflow session 88 | - model: current model (needed for input tensor) 89 | """ 90 | def __init__(self, raw_cifar10data, sess, model): 91 | assert isinstance(raw_cifar10data, CIFAR10Data) 92 | self.image_size = 32 93 | 94 | # create augmentation computational graph 95 | self.x_input_placeholder = tf.placeholder(tf.float32, shape=[None, 32, 32, 3]) 96 | padded = tf.map_fn(lambda img: tf.image.resize_image_with_crop_or_pad( 97 | img, self.image_size + 4, self.image_size + 4), 98 | self.x_input_placeholder) 99 | cropped = tf.map_fn(lambda img: tf.random_crop(img, [self.image_size, 100 | self.image_size, 101 | 3]), padded) 102 | flipped = tf.map_fn(lambda img: tf.image.random_flip_left_right(img), cropped) 103 | self.augmented = flipped 104 | 105 | self.train_data = AugmentedDataSubset(raw_cifar10data.train_data, sess, 106 | self.x_input_placeholder, 107 | self.augmented) 108 | self.eval_data = AugmentedDataSubset(raw_cifar10data.eval_data, sess, 109 | self.x_input_placeholder, 110 | self.augmented) 111 | self.label_names = raw_cifar10data.label_names 112 | 113 | 114 | class DataSubset(object): 115 | def __init__(self, xs, ys): 116 | self.xs = xs 117 | self.n = xs.shape[0] 118 | self.ys = ys 119 | self.batch_start = 0 120 | self.cur_order = np.random.permutation(self.n) 121 | 122 | def get_next_batch(self, batch_size, multiple_passes=False, reshuffle_after_pass=True): 123 | if self.n < batch_size: 124 | raise ValueError('Batch size can be at most the dataset size') 125 | if not multiple_passes: 126 | actual_batch_size = min(batch_size, self.n - self.batch_start) 127 | if actual_batch_size <= 0: 128 | raise ValueError('Pass through the dataset is complete.') 129 | batch_end = self.batch_start + actual_batch_size 130 | batch_xs = self.xs[self.cur_order[self.batch_start : batch_end], ...] 131 | batch_ys = self.ys[self.cur_order[self.batch_start : batch_end], ...] 
132 | self.batch_start += actual_batch_size
133 | return batch_xs, batch_ys
134 | actual_batch_size = min(batch_size, self.n - self.batch_start)
135 | if actual_batch_size < batch_size:
136 | if reshuffle_after_pass:
137 | self.cur_order = np.random.permutation(self.n)
138 | self.batch_start = 0
139 | batch_end = self.batch_start + batch_size
140 | batch_xs = self.xs[self.cur_order[self.batch_start : batch_end], ...]
141 | batch_ys = self.ys[self.cur_order[self.batch_start : batch_end], ...]
142 | self.batch_start += batch_size
143 | return batch_xs, batch_ys
144 |
145 |
146 | class AugmentedDataSubset(object):
147 | def __init__(self, raw_datasubset, sess, x_input_placeholder,
148 | augmented):
149 | self.sess = sess
150 | self.raw_datasubset = raw_datasubset
151 | self.x_input_placeholder = x_input_placeholder
152 | self.augmented = augmented
153 |
154 | def get_next_batch(self, batch_size, multiple_passes=False, reshuffle_after_pass=True):
155 | raw_batch = self.raw_datasubset.get_next_batch(batch_size, multiple_passes,
156 | reshuffle_after_pass)
157 | images = raw_batch[0].astype(np.float32)
158 | return self.sess.run(self.augmented, feed_dict={self.x_input_placeholder:
159 | images}), raw_batch[1]
160 |
161 |
-------------------------------------------------------------------------------- /madry_cifar10/config.json: --------------------------------------------------------------------------------
1 | {
2 | "_comment": "===== MODEL CONFIGURATION =====",
3 | "model_dir": "models/secret",
4 |
5 | "_comment": "===== DATASET CONFIGURATION =====",
6 | "data_path": "cifar10_data",
7 |
8 | "_comment": "===== TRAINING CONFIGURATION =====",
9 | "tf_random_seed": 451760341,
10 | "np_random_seed": 216105420,
11 | "max_num_training_steps": 80000,
12 | "num_output_steps": 100,
13 | "num_summary_steps": 100,
14 | "num_checkpoint_steps": 1000,
15 | "training_batch_size": 128,
16 | "step_size_schedule": [[0, 0.1], [40000, 0.01], [60000, 0.001]],
17 | "weight_decay": 0.0002,
18 | "momentum": 0.9,
19 |
20 | "_comment": "===== EVAL CONFIGURATION =====",
21 | "num_eval_examples": 100,
22 | "eval_batch_size": 100,
23 | "eval_on_cpu": false,
24 |
25 | "_comment": "=====ADVERSARIAL EXAMPLES CONFIGURATION=====",
26 | "epsilon": 8.0,
27 | "num_steps": 10,
28 | "step_size": 2.0,
29 | "random_start": true,
30 | "loss_func": "xent",
31 | "store_adv_path": "attack.npy"
32 | }
33 |
-------------------------------------------------------------------------------- /madry_cifar10/eval.py: --------------------------------------------------------------------------------
1 | """
2 | Infinite evaluation loop going through the checkpoints in the model directory
3 | as they appear and evaluating them. Accuracy and average loss are printed and
4 | added as tensorboard summaries.
5 | """ 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | from datetime import datetime 11 | import json 12 | import math 13 | import os 14 | import sys 15 | import time 16 | 17 | import tensorflow as tf 18 | 19 | import cifar10_input 20 | from model import Model 21 | from pgd_attack import LinfPGDAttack 22 | 23 | # Global constants 24 | with open('config.json') as config_file: 25 | config = json.load(config_file) 26 | num_eval_examples = config['num_eval_examples'] 27 | eval_batch_size = config['eval_batch_size'] 28 | eval_on_cpu = config['eval_on_cpu'] 29 | data_path = config['data_path'] 30 | 31 | model_dir = config['model_dir'] 32 | 33 | # Set upd the data, hyperparameters, and the model 34 | cifar = cifar10_input.CIFAR10Data(data_path) 35 | 36 | if eval_on_cpu: 37 | with tf.device("/cpu:0"): 38 | model = Model(mode='eval') 39 | attack = LinfPGDAttack(model, 40 | config['epsilon'], 41 | config['num_steps'], 42 | config['step_size'], 43 | config['random_start'], 44 | config['loss_func']) 45 | else: 46 | model = Model(mode='eval') 47 | attack = LinfPGDAttack(model, 48 | config['epsilon'], 49 | config['num_steps'], 50 | config['step_size'], 51 | config['random_start'], 52 | config['loss_func']) 53 | 54 | global_step = tf.contrib.framework.get_or_create_global_step() 55 | 56 | # Setting up the Tensorboard and checkpoint outputs 57 | if not os.path.exists(model_dir): 58 | os.makedirs(model_dir) 59 | eval_dir = os.path.join(model_dir, 'eval') 60 | if not os.path.exists(eval_dir): 61 | os.makedirs(eval_dir) 62 | 63 | last_checkpoint_filename = '' 64 | already_seen_state = False 65 | 66 | saver = tf.train.Saver() 67 | summary_writer = tf.summary.FileWriter(eval_dir) 68 | 69 | # A function for evaluating a single checkpoint 70 | def evaluate_checkpoint(filename): 71 | with tf.Session() as sess: 72 | # Restore the checkpoint 73 | saver.restore(sess, filename) 74 | 75 | # Iterate over the samples batch-by-batch 76 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size)) 77 | total_xent_nat = 0. 78 | total_xent_adv = 0. 
79 | total_corr_nat = 0 80 | total_corr_adv = 0 81 | 82 | for ibatch in range(num_batches): 83 | bstart = ibatch * eval_batch_size 84 | bend = min(bstart + eval_batch_size, num_eval_examples) 85 | 86 | x_batch = cifar.eval_data.xs[bstart:bend, :] 87 | y_batch = cifar.eval_data.ys[bstart:bend] 88 | 89 | dict_nat = {model.x_input: x_batch, 90 | model.y_input: y_batch} 91 | 92 | x_batch_adv = attack.perturb(x_batch, y_batch, sess) 93 | 94 | dict_adv = {model.x_input: x_batch_adv, 95 | model.y_input: y_batch} 96 | 97 | cur_corr_nat, cur_xent_nat = sess.run( 98 | [model.num_correct,model.xent], 99 | feed_dict = dict_nat) 100 | cur_corr_adv, cur_xent_adv = sess.run( 101 | [model.num_correct,model.xent], 102 | feed_dict = dict_adv) 103 | 104 | print(eval_batch_size) 105 | print("Correctly classified natural examples: {}".format(cur_corr_nat)) 106 | print("Correctly classified adversarial examples: {}".format(cur_corr_adv)) 107 | total_xent_nat += cur_xent_nat 108 | total_xent_adv += cur_xent_adv 109 | total_corr_nat += cur_corr_nat 110 | total_corr_adv += cur_corr_adv 111 | 112 | avg_xent_nat = total_xent_nat / num_eval_examples 113 | avg_xent_adv = total_xent_adv / num_eval_examples 114 | acc_nat = total_corr_nat / num_eval_examples 115 | acc_adv = total_corr_adv / num_eval_examples 116 | 117 | summary = tf.Summary(value=[ 118 | tf.Summary.Value(tag='xent adv eval', simple_value= avg_xent_adv), 119 | tf.Summary.Value(tag='xent adv', simple_value= avg_xent_adv), 120 | tf.Summary.Value(tag='xent nat', simple_value= avg_xent_nat), 121 | tf.Summary.Value(tag='accuracy adv eval', simple_value= acc_adv), 122 | tf.Summary.Value(tag='accuracy adv', simple_value= acc_adv), 123 | tf.Summary.Value(tag='accuracy nat', simple_value= acc_nat)]) 124 | summary_writer.add_summary(summary, global_step.eval(sess)) 125 | 126 | print('natural: {:.2f}%'.format(100 * acc_nat)) 127 | print('adversarial: {:.2f}%'.format(100 * acc_adv)) 128 | print('avg nat loss: {:.4f}'.format(avg_xent_nat)) 129 | print('avg adv loss: {:.4f}'.format(avg_xent_adv)) 130 | 131 | # Infinite eval loop 132 | while True: 133 | cur_checkpoint = tf.train.latest_checkpoint(model_dir) 134 | 135 | # Case 1: No checkpoint yet 136 | if cur_checkpoint is None: 137 | if not already_seen_state: 138 | print('No checkpoint yet, waiting ...', end='') 139 | already_seen_state = True 140 | else: 141 | print('.', end='') 142 | sys.stdout.flush() 143 | time.sleep(10) 144 | # Case 2: Previously unseen checkpoint 145 | elif cur_checkpoint != last_checkpoint_filename: 146 | print('\nCheckpoint {}, evaluating ... ({})'.format(cur_checkpoint, 147 | datetime.now())) 148 | sys.stdout.flush() 149 | last_checkpoint_filename = cur_checkpoint 150 | already_seen_state = False 151 | evaluate_checkpoint(cur_checkpoint) 152 | # Case 3: Previously evaluated checkpoint 153 | else: 154 | if not already_seen_state: 155 | print('Waiting for the next checkpoint ... 
({}) '.format(
156 | datetime.now()),
157 | end='')
158 | already_seen_state = True
159 | else:
160 | print('.', end='')
161 | sys.stdout.flush()
162 | time.sleep(10)
163 |
-------------------------------------------------------------------------------- /madry_cifar10/fetch_model.py: --------------------------------------------------------------------------------
1 | """Downloads a model, computes its SHA256 hash and unzips it
2 | at the proper location."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import sys
8 | import zipfile
9 | import hashlib
10 |
11 | if len(sys.argv) == 1 or sys.argv[1] not in ['natural',
12 | 'adv_trained',
13 | 'secret']:
14 | print('Usage: python fetch_model.py [natural, adv_trained, secret]')
15 | sys.exit(1)
16 |
17 | if sys.argv[1] == 'natural':
18 | url = 'https://www.dropbox.com/s/cgzd5odqoojvxzk/natural.zip?dl=1'
19 | elif sys.argv[1] == 'adv_trained':
20 | url = 'https://www.dropbox.com/s/g4b6ntrp8zrudbz/adv_trained.zip?dl=1'
21 | else: # fetch secret model
22 | url = 'https://www.dropbox.com/s/ywc0hg8lr5ba8zd/secret.zip?dl=1'
23 |
24 | fname = url.split('/')[-1].split('?')[0] # get the name of the file
25 |
26 | # model download
27 | print('Downloading models')
28 | if sys.version_info >= (3,):
29 | import urllib.request
30 | urllib.request.urlretrieve(url, fname)
31 | else:
32 | import urllib
33 | urllib.urlretrieve(url, fname)
34 |
35 | # computing model hash
36 | sha256 = hashlib.sha256()
37 | with open(fname, 'rb') as f:
38 | data = f.read()
39 | sha256.update(data)
40 | print('SHA256 hash: {}'.format(sha256.hexdigest()))
41 |
42 | # extracting model
43 | print('Extracting model')
44 | with zipfile.ZipFile(fname, 'r') as model_zip:
45 | model_zip.extractall()
46 | print('Extracted model in {}'.format(model_zip.namelist()[0]))
47 |
-------------------------------------------------------------------------------- /madry_cifar10/model.py: --------------------------------------------------------------------------------
1 | # based on https://github.com/tensorflow/models/tree/master/resnet
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | class Model(object):
10 | """ResNet model."""
11 |
12 | def __init__(self, mode='eval'):
13 | """ResNet constructor.
14 |
15 | Args:
16 | mode: One of 'train' and 'eval'.
17 | """
18 | self.mode = mode
19 | self._build_model()
20 |
21 | def add_internal_summaries(self):
22 | pass
23 |
24 | def _stride_arr(self, stride):
25 | """Map a stride scalar to the stride array for tf.nn.conv2d."""
26 | return [1, stride, stride, 1]
27 |
28 | def _build_model(self):
29 | assert self.mode == 'train' or self.mode == 'eval'
30 | """Build the core model within the graph."""
31 | with tf.variable_scope('input'):
32 |
33 | self.x_input = tf.placeholder(
34 | tf.float32,
35 | shape=[None, 32, 32, 3])
36 |
37 | self.y_input = tf.placeholder(tf.int64, shape=None)
38 |
39 |
40 | input_standardized = tf.map_fn(lambda img: tf.image.per_image_standardization(img),
41 | self.x_input)
42 | x = self._conv('init_conv', input_standardized, 3, 3, 16, self._stride_arr(1))
43 |
44 |
45 |
46 | strides = [1, 2, 2]
47 | activate_before_residual = [True, False, False]
48 | res_func = self._residual
49 |
50 | # Uncomment the following code to use the w28-10 wide residual network.
51 | # It is more memory efficient than very deep residual network and has 52 | # comparably good performance. 53 | # https://arxiv.org/pdf/1605.07146v1.pdf 54 | filters = [16, 160, 320, 640] 55 | 56 | 57 | # Update hps.num_residual_units to 9 58 | 59 | with tf.variable_scope('unit_1_0'): 60 | x = res_func(x, filters[0], filters[1], self._stride_arr(strides[0]), 61 | activate_before_residual[0]) 62 | for i in range(1, 5): 63 | with tf.variable_scope('unit_1_%d' % i): 64 | x = res_func(x, filters[1], filters[1], self._stride_arr(1), False) 65 | 66 | with tf.variable_scope('unit_2_0'): 67 | x = res_func(x, filters[1], filters[2], self._stride_arr(strides[1]), 68 | activate_before_residual[1]) 69 | for i in range(1, 5): 70 | with tf.variable_scope('unit_2_%d' % i): 71 | x = res_func(x, filters[2], filters[2], self._stride_arr(1), False) 72 | 73 | with tf.variable_scope('unit_3_0'): 74 | x = res_func(x, filters[2], filters[3], self._stride_arr(strides[2]), 75 | activate_before_residual[2]) 76 | for i in range(1, 5): 77 | with tf.variable_scope('unit_3_%d' % i): 78 | x = res_func(x, filters[3], filters[3], self._stride_arr(1), False) 79 | 80 | with tf.variable_scope('unit_last'): 81 | x = self._batch_norm('final_bn', x) 82 | x = self._relu(x, 0.1) 83 | x = self._global_avg_pool(x) 84 | 85 | with tf.variable_scope('logit'): 86 | self.pre_softmax = self._fully_connected(x, 10) 87 | 88 | self.predictions = tf.argmax(self.pre_softmax, 1) 89 | self.correct_prediction = tf.equal(self.predictions, self.y_input) 90 | self.num_correct = tf.reduce_sum( 91 | tf.cast(self.correct_prediction, tf.int64)) 92 | self.accuracy = tf.reduce_mean( 93 | tf.cast(self.correct_prediction, tf.float32)) 94 | 95 | with tf.variable_scope('costs'): 96 | self.y_xent = tf.nn.sparse_softmax_cross_entropy_with_logits( 97 | logits=self.pre_softmax, labels=self.y_input) 98 | self.xent_per_point = self.y_xent 99 | self.xent = tf.reduce_sum(self.y_xent, name='y_xent') 100 | self.mean_xent = tf.reduce_mean(self.y_xent) 101 | self.weight_decay_loss = self._decay() 102 | 103 | def _batch_norm(self, name, x): 104 | """Batch normalization.""" 105 | with tf.name_scope(name): 106 | return tf.contrib.layers.batch_norm( 107 | inputs=x, 108 | decay=.9, 109 | center=True, 110 | scale=True, 111 | activation_fn=None, 112 | updates_collections=None, 113 | is_training=(self.mode == 'train')) 114 | 115 | def _residual(self, x, in_filter, out_filter, stride, 116 | activate_before_residual=False): 117 | """Residual unit with 2 sub layers.""" 118 | if activate_before_residual: 119 | with tf.variable_scope('shared_activation'): 120 | x = self._batch_norm('init_bn', x) 121 | x = self._relu(x, 0.1) 122 | orig_x = x 123 | else: 124 | with tf.variable_scope('residual_only_activation'): 125 | orig_x = x 126 | x = self._batch_norm('init_bn', x) 127 | x = self._relu(x, 0.1) 128 | 129 | with tf.variable_scope('sub1'): 130 | x = self._conv('conv1', x, 3, in_filter, out_filter, stride) 131 | 132 | with tf.variable_scope('sub2'): 133 | x = self._batch_norm('bn2', x) 134 | x = self._relu(x, 0.1) 135 | x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1]) 136 | 137 | with tf.variable_scope('sub_add'): 138 | if in_filter != out_filter: 139 | orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID') 140 | orig_x = tf.pad( 141 | orig_x, [[0, 0], [0, 0], [0, 0], 142 | [(out_filter-in_filter)//2, (out_filter-in_filter)//2]]) 143 | x += orig_x 144 | 145 | tf.logging.debug('image after unit %s', x.get_shape()) 146 | return x 147 | 148 | def _decay(self): 149 
| """L2 weight decay loss.""" 150 | costs = [] 151 | for var in tf.trainable_variables(): 152 | if var.op.name.find('DW') > 0: 153 | costs.append(tf.nn.l2_loss(var)) 154 | return tf.add_n(costs) 155 | 156 | def _conv(self, name, x, filter_size, in_filters, out_filters, strides): 157 | """Convolution.""" 158 | with tf.variable_scope(name): 159 | n = filter_size * filter_size * out_filters 160 | kernel = tf.get_variable( 161 | 'DW', [filter_size, filter_size, in_filters, out_filters], 162 | tf.float32, initializer=tf.random_normal_initializer( 163 | stddev=np.sqrt(2.0/n))) 164 | return tf.nn.conv2d(x, kernel, strides, padding='SAME') 165 | 166 | def _relu(self, x, leakiness=0.0): 167 | """Relu, with optional leaky support.""" 168 | return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu') 169 | 170 | def _fully_connected(self, x, out_dim): 171 | """FullyConnected layer for final output.""" 172 | num_non_batch_dimensions = len(x.shape) 173 | prod_non_batch_dimensions = 1 174 | for ii in range(num_non_batch_dimensions - 1): 175 | prod_non_batch_dimensions *= int(x.shape[ii + 1]) 176 | x = tf.reshape(x, [tf.shape(x)[0], -1]) 177 | w = tf.get_variable( 178 | 'DW', [prod_non_batch_dimensions, out_dim], 179 | initializer=tf.uniform_unit_scaling_initializer(factor=1.0)) 180 | b = tf.get_variable('biases', [out_dim], 181 | initializer=tf.constant_initializer()) 182 | return tf.nn.xw_plus_b(x, w, b) 183 | 184 | def _global_avg_pool(self, x): 185 | assert x.get_shape().ndims == 4 186 | return tf.reduce_mean(x, [1, 2]) 187 | 188 | 189 | 190 | -------------------------------------------------------------------------------- /madry_cifar10/model_robustml.py: -------------------------------------------------------------------------------- 1 | import robustml 2 | import tensorflow as tf 3 | 4 | import model 5 | 6 | class Model(robustml.model.Model): 7 | def __init__(self, sess): 8 | self._model = model.Model('eval') 9 | 10 | saver = tf.train.Saver() 11 | checkpoint = tf.train.latest_checkpoint('models/secret') 12 | saver.restore(sess, checkpoint) 13 | 14 | self._sess = sess 15 | self._input = self._model.x_input 16 | self._logits = self._model.pre_softmax 17 | self._predictions = self._model.predictions 18 | self._dataset = robustml.dataset.CIFAR10() 19 | self._threat_model = robustml.threat_model.Linf(epsilon=0.03) 20 | 21 | @property 22 | def dataset(self): 23 | return self._dataset 24 | 25 | @property 26 | def threat_model(self): 27 | return self._threat_model 28 | 29 | def classify(self, x): 30 | return self._sess.run(self._predictions, 31 | {self._input: x})[0] 32 | 33 | # expose attack interface 34 | 35 | @property 36 | def input(self): 37 | return self._input 38 | 39 | @property 40 | def logits(self): 41 | return self._logits 42 | 43 | @property 44 | def predictions(self): 45 | return self._predictions 46 | -------------------------------------------------------------------------------- /madry_cifar10/pgd_attack.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of attack methods. Running this file as a program will 3 | apply the attack to the model specified by the config file and store 4 | the examples in an .npy file. 
5 | """ 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import tensorflow as tf 11 | import numpy as np 12 | 13 | import cifar10_input 14 | 15 | 16 | class LinfPGDAttack: 17 | def __init__(self, model, epsilon, num_steps, step_size, random_start, loss_func): 18 | """Attack parameter initialization. The attack performs k steps of 19 | size a, while always staying within epsilon from the initial 20 | point.""" 21 | self.model = model 22 | self.epsilon = epsilon 23 | self.num_steps = num_steps 24 | self.step_size = step_size 25 | self.rand = random_start 26 | 27 | if loss_func == 'xent': 28 | loss = model.xent 29 | elif loss_func == 'cw': 30 | label_mask = tf.one_hot(model.y_input, 31 | 10, 32 | on_value=1.0, 33 | off_value=0.0, 34 | dtype=tf.float32) 35 | correct_logit = tf.reduce_sum(label_mask * model.pre_softmax, axis=1) 36 | wrong_logit = tf.reduce_max((1 - label_mask) * model.pre_softmax - 1e4 * label_mask, axis=1) 37 | loss = -tf.nn.relu(correct_logit - wrong_logit + 50) 38 | else: 39 | print('Unknown loss function. Defaulting to cross-entropy') 40 | loss = model.xent 41 | 42 | self.grad = tf.gradients(loss, model.x_input)[0] 43 | 44 | def perturb(self, x_nat, y, sess): 45 | """Given a set of examples (x_nat, y), returns a set of adversarial 46 | examples within epsilon of x_nat in l_infinity norm.""" 47 | if self.rand: 48 | x = x_nat + np.random.uniform(-self.epsilon, self.epsilon, x_nat.shape) 49 | x = np.clip(x, 0, 255) # ensure valid pixel range 50 | else: 51 | x = np.copy(x_nat) 52 | 53 | for i in range(self.num_steps): 54 | grad = sess.run(self.grad, feed_dict={self.model.x_input: x, 55 | self.model.y_input: y}) 56 | 57 | x = np.add(x, self.step_size * np.sign(grad), out=x, casting='unsafe') 58 | 59 | x = np.clip(x, x_nat - self.epsilon, x_nat + self.epsilon) 60 | x = np.clip(x, 0, 255) # ensure valid pixel range 61 | 62 | return x 63 | 64 | 65 | if __name__ == '__main__': 66 | import json 67 | import sys 68 | import math 69 | 70 | from model import Model 71 | 72 | with open('config.json') as config_file: 73 | config = json.load(config_file) 74 | 75 | model_file = tf.train.latest_checkpoint(config['model_dir']) 76 | if model_file is None: 77 | print('No model found') 78 | sys.exit() 79 | 80 | model = Model(mode='eval') 81 | attack = LinfPGDAttack(model, 82 | config['epsilon'], 83 | config['num_steps'], 84 | config['step_size'], 85 | config['random_start'], 86 | config['loss_func']) 87 | saver = tf.train.Saver() 88 | 89 | data_path = config['data_path'] 90 | cifar = cifar10_input.CIFAR10Data(data_path) 91 | 92 | gpu_options = tf.GPUOptions(visible_device_list='7', per_process_gpu_memory_fraction=0.5) 93 | tf_config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options) 94 | with tf.Session(config=tf_config) as sess: 95 | # Restore the checkpoint 96 | saver.restore(sess, model_file) 97 | 98 | # Iterate over the samples batch-by-batch 99 | num_eval_examples = config['num_eval_examples'] 100 | eval_batch_size = config['eval_batch_size'] 101 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size)) 102 | 103 | x_adv = [] # adv accumulator 104 | 105 | print('Iterating over {} batches'.format(num_batches)) 106 | 107 | for ibatch in range(num_batches): 108 | bstart = ibatch * eval_batch_size 109 | bend = min(bstart + eval_batch_size, num_eval_examples) 110 | print('batch size: {}'.format(bend - bstart)) 111 | 112 | x_batch = cifar.eval_data.xs[bstart:bend, :] 113 | y_batch = 
cifar.eval_data.ys[bstart:bend] 114 | 115 | x_batch_adv = attack.perturb(x_batch, y_batch, sess) 116 | 117 | x_adv.append(x_batch_adv) 118 | 119 | print('Storing examples') 120 | path = config['store_adv_path'] 121 | x_adv = np.concatenate(x_adv, axis=0) 122 | np.save(path, x_adv) 123 | print('Examples stored in {}'.format(path)) 124 | -------------------------------------------------------------------------------- /madry_cifar10/run_attack.py: -------------------------------------------------------------------------------- 1 | """Evaluates a model against examples from a .npy file as specified 2 | in config.json""" 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from datetime import datetime 8 | import json 9 | import math 10 | import os 11 | import sys 12 | import time 13 | 14 | import tensorflow as tf 15 | import numpy as np 16 | 17 | from model import Model 18 | import cifar10_input 19 | 20 | with open('config.json') as config_file: 21 | config = json.load(config_file) 22 | 23 | data_path = config['data_path'] 24 | 25 | def run_attack(checkpoint, x_adv, epsilon): 26 | cifar = cifar10_input.CIFAR10Data(data_path) 27 | 28 | model = Model(mode='eval') 29 | 30 | saver = tf.train.Saver() 31 | 32 | num_eval_examples = 10000 33 | eval_batch_size = 100 34 | 35 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size)) 36 | total_corr = 0 37 | 38 | x_nat = cifar.eval_data.xs 39 | l_inf = np.amax(np.abs(x_nat - x_adv)) 40 | 41 | if l_inf > epsilon + 0.0001: 42 | print('maximum perturbation found: {}'.format(l_inf)) 43 | print('maximum perturbation allowed: {}'.format(epsilon)) 44 | return 45 | 46 | y_pred = [] # label accumulator 47 | 48 | with tf.Session() as sess: 49 | # Restore the checkpoint 50 | saver.restore(sess, checkpoint) 51 | 52 | # Iterate over the samples batch-by-batch 53 | for ibatch in range(num_batches): 54 | bstart = ibatch * eval_batch_size 55 | bend = min(bstart + eval_batch_size, num_eval_examples) 56 | 57 | x_batch = x_adv[bstart:bend, :] 58 | y_batch = cifar.eval_data.ys[bstart:bend] 59 | 60 | dict_adv = {model.x_input: x_batch, 61 | model.y_input: y_batch} 62 | cur_corr, y_pred_batch = sess.run([model.num_correct, model.predictions], 63 | feed_dict=dict_adv) 64 | 65 | total_corr += cur_corr 66 | y_pred.append(y_pred_batch) 67 | 68 | accuracy = total_corr / num_eval_examples 69 | 70 | print('Accuracy: {:.2f}%'.format(100.0 * accuracy)) 71 | y_pred = np.concatenate(y_pred, axis=0) 72 | np.save('pred.npy', y_pred) 73 | print('Output saved at pred.npy') 74 | 75 | if __name__ == '__main__': 76 | import json 77 | 78 | with open('config.json') as config_file: 79 | config = json.load(config_file) 80 | 81 | model_dir = config['model_dir'] 82 | 83 | checkpoint = tf.train.latest_checkpoint(model_dir) 84 | x_adv = np.load(config['store_adv_path']) 85 | 86 | if checkpoint is None: 87 | print('No checkpoint found') 88 | elif x_adv.shape != (10000, 32, 32, 3): 89 | print('Invalid shape: expected (10000, 32, 32, 3), found {}'.format(x_adv.shape)) 90 | elif np.amax(x_adv) > 255.0001 or np.amin(x_adv) < -0.0001: 91 | print('Invalid pixel range. 
Expected [0, 255], found [{}, {}]'.format( 92 | np.amin(x_adv), 93 | np.amax(x_adv))) 94 | else: 95 | run_attack(checkpoint, x_adv, config['epsilon']) 96 | -------------------------------------------------------------------------------- /madry_cifar10/train.py: -------------------------------------------------------------------------------- 1 | """Trains a model, saving checkpoints and tensorboard summaries along 2 | the way.""" 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from datetime import datetime 8 | import json 9 | import os 10 | import shutil 11 | from timeit import default_timer as timer 12 | 13 | import tensorflow as tf 14 | import numpy as np 15 | 16 | from model import Model 17 | import cifar10_input 18 | from pgd_attack import LinfPGDAttack 19 | 20 | with open('config.json') as config_file: 21 | config = json.load(config_file) 22 | 23 | # seeding randomness 24 | tf.set_random_seed(config['tf_random_seed']) 25 | np.random.seed(config['np_random_seed']) 26 | 27 | # Setting up training parameters 28 | max_num_training_steps = config['max_num_training_steps'] 29 | num_output_steps = config['num_output_steps'] 30 | num_summary_steps = config['num_summary_steps'] 31 | num_checkpoint_steps = config['num_checkpoint_steps'] 32 | step_size_schedule = config['step_size_schedule'] 33 | weight_decay = config['weight_decay'] 34 | data_path = config['data_path'] 35 | momentum = config['momentum'] 36 | batch_size = config['training_batch_size'] 37 | 38 | # Setting up the data and the model 39 | raw_cifar = cifar10_input.CIFAR10Data(data_path) 40 | global_step = tf.contrib.framework.get_or_create_global_step() 41 | model = Model(mode='train') 42 | 43 | # Setting up the optimizer 44 | boundaries = [int(sss[0]) for sss in step_size_schedule] 45 | boundaries = boundaries[1:] 46 | values = [sss[1] for sss in step_size_schedule] 47 | learning_rate = tf.train.piecewise_constant( 48 | tf.cast(global_step, tf.int32), 49 | boundaries, 50 | values) 51 | total_loss = model.mean_xent + weight_decay * model.weight_decay_loss 52 | train_step = tf.train.MomentumOptimizer(learning_rate, momentum).minimize( 53 | total_loss, 54 | global_step=global_step) 55 | 56 | # Set up adversary 57 | attack = LinfPGDAttack(model, 58 | config['epsilon'], 59 | config['num_steps'], 60 | config['step_size'], 61 | config['random_start'], 62 | config['loss_func']) 63 | 64 | # Setting up the Tensorboard and checkpoint outputs 65 | model_dir = config['model_dir'] 66 | if not os.path.exists(model_dir): 67 | os.makedirs(model_dir) 68 | 69 | # We add accuracy and xent twice so we can easily make three types of 70 | # comparisons in Tensorboard: 71 | # - train vs eval (for a single run) 72 | # - train of different runs 73 | # - eval of different runs 74 | 75 | saver = tf.train.Saver(max_to_keep=3) 76 | tf.summary.scalar('accuracy adv train', model.accuracy) 77 | tf.summary.scalar('accuracy adv', model.accuracy) 78 | tf.summary.scalar('xent adv train', model.xent / batch_size) 79 | tf.summary.scalar('xent adv', model.xent / batch_size) 80 | tf.summary.image('images adv train', model.x_input) 81 | merged_summaries = tf.summary.merge_all() 82 | 83 | # keep the configuration file with the model for reproducibility 84 | shutil.copy('config.json', model_dir) 85 | 86 | with tf.Session() as sess: 87 | 88 | # initialize data augmentation 89 | cifar = cifar10_input.AugmentedCIFAR10Data(raw_cifar, sess, model) 90 | 91 | # Initialize the summary writer, global 
variables, and our time counter. 92 | summary_writer = tf.summary.FileWriter(model_dir, sess.graph) 93 | sess.run(tf.global_variables_initializer()) 94 | training_time = 0.0 95 | 96 | # Main training loop 97 | for ii in range(max_num_training_steps): 98 | x_batch, y_batch = cifar.train_data.get_next_batch(batch_size, 99 | multiple_passes=True) 100 | 101 | # Compute Adversarial Perturbations 102 | start = timer() 103 | x_batch_adv = attack.perturb(x_batch, y_batch, sess) 104 | end = timer() 105 | training_time += end - start 106 | 107 | nat_dict = {model.x_input: x_batch, 108 | model.y_input: y_batch} 109 | 110 | adv_dict = {model.x_input: x_batch_adv, 111 | model.y_input: y_batch} 112 | 113 | # Output to stdout 114 | if ii % num_output_steps == 0: 115 | nat_acc = sess.run(model.accuracy, feed_dict=nat_dict) 116 | adv_acc = sess.run(model.accuracy, feed_dict=adv_dict) 117 | print('Step {}: ({})'.format(ii, datetime.now())) 118 | print(' training nat accuracy {:.4}%'.format(nat_acc * 100)) 119 | print(' training adv accuracy {:.4}%'.format(adv_acc * 100)) 120 | if ii != 0: 121 | print(' {} examples per second'.format( 122 | num_output_steps * batch_size / training_time)) 123 | training_time = 0.0 124 | # Tensorboard summaries 125 | if ii % num_summary_steps == 0: 126 | summary = sess.run(merged_summaries, feed_dict=adv_dict) 127 | summary_writer.add_summary(summary, global_step.eval(sess)) 128 | 129 | # Write a checkpoint 130 | if ii % num_checkpoint_steps == 0: 131 | saver.save(sess, 132 | os.path.join(model_dir, 'checkpoint'), 133 | global_step=global_step) 134 | 135 | # Actual training step 136 | start = timer() 137 | sess.run(train_step, feed_dict=adv_dict) 138 | end = timer() 139 | training_time += end - start 140 | -------------------------------------------------------------------------------- /madry_mnist/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, and Adrian Vladu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
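One subtlety in `madry_cifar10/train.py` above is the `step_size_schedule` handling. A worked sketch with the values from `madry_cifar10/config.json` (plain Python, no TF needed; the comments state the `tf.train.piecewise_constant` semantics):

```
# step_size_schedule from madry_cifar10/config.json
step_size_schedule = [[0, 0.1], [40000, 0.01], [60000, 0.001]]

boundaries = [int(sss[0]) for sss in step_size_schedule]
boundaries = boundaries[1:]                      # [40000, 60000]; the first boundary (0) is dropped
values = [sss[1] for sss in step_size_schedule]  # [0.1, 0.01, 0.001]

# tf.train.piecewise_constant(global_step, boundaries, values) then yields:
#   0.1   for global_step <= 40000
#   0.01  for 40000 < global_step <= 60000
#   0.001 for global_step > 60000
```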
-------------------------------------------------------------------------------- /madry_mnist/config.json: --------------------------------------------------------------------------------
1 | {
2 | "_comment": "===== MODEL CONFIGURATION =====",
3 | "model_dir": "models/secret",
4 |
5 | "_comment": "===== TRAINING CONFIGURATION =====",
6 | "random_seed": 4557077,
7 | "max_num_training_steps": 100000,
8 | "num_output_steps": 100,
9 | "num_summary_steps": 100,
10 | "num_checkpoint_steps": 300,
11 | "training_batch_size": 50,
12 |
13 | "_comment": "===== EVAL CONFIGURATION =====",
14 | "num_eval_examples": 10000,
15 | "eval_batch_size": 200,
16 | "eval_on_cpu": false,
17 |
18 | "_comment": "=====ADVERSARIAL EXAMPLES CONFIGURATION=====",
19 | "epsilon": 0.3,
20 | "k": 100,
21 | "a": 0.01,
22 | "random_start": true,
23 | "loss_func": "xent",
24 | "store_adv_path": "attack.npy"
25 | }
26 |
-------------------------------------------------------------------------------- /madry_mnist/eval.py: --------------------------------------------------------------------------------
1 | """
2 | Infinite evaluation loop going through the checkpoints in the model directory
3 | as they appear and evaluating them. Accuracy and average loss are printed and
4 | added as tensorboard summaries.
5 | """
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | from datetime import datetime
11 | import json
12 | import math
13 | import os
14 | import sys
15 | import time
16 |
17 | import tensorflow as tf
18 | from tensorflow.examples.tutorials.mnist import input_data
19 |
20 | from model import Model
21 | from attack import LinfPGDAttack
22 |
23 | # Global constants
24 | with open('config.json') as config_file:
25 | config = json.load(config_file)
26 | num_eval_examples = config['num_eval_examples']
27 | eval_batch_size = config['eval_batch_size']
28 | eval_on_cpu = config['eval_on_cpu']
29 |
30 | model_dir = config['model_dir']
31 |
32 | # Set up the data, hyperparameters, and the model
33 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
34 |
35 | if eval_on_cpu:
36 | with tf.device("/cpu:0"):
37 | model = Model()
38 | attack = LinfPGDAttack(model,
39 | config['epsilon'],
40 | config['k'],
41 | config['a'],
42 | config['random_start'],
43 | config['loss_func'])
44 | else:
45 | model = Model()
46 | attack = LinfPGDAttack(model,
47 | config['epsilon'],
48 | config['k'],
49 | config['a'],
50 | config['random_start'],
51 | config['loss_func'])
52 |
53 | global_step = tf.contrib.framework.get_or_create_global_step()
54 |
55 | # Setting up the Tensorboard and checkpoint outputs
56 | if not os.path.exists(model_dir):
57 | os.makedirs(model_dir)
58 | eval_dir = os.path.join(model_dir, 'eval')
59 | if not os.path.exists(eval_dir):
60 | os.makedirs(eval_dir)
61 |
62 | last_checkpoint_filename = ''
63 | already_seen_state = False
64 |
65 | saver = tf.train.Saver()
66 | summary_writer = tf.summary.FileWriter(eval_dir)
67 |
68 | # A function for evaluating a single checkpoint
69 | def evaluate_checkpoint(filename):
70 | with tf.Session() as sess:
71 | # Restore the checkpoint
72 | saver.restore(sess, filename)
73 |
74 | # Iterate over the samples batch-by-batch
75 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size))
76 | total_xent_nat = 0.
77 | total_xent_adv = 0.
78 | total_corr_nat = 0 79 | total_corr_adv = 0 80 | 81 | for ibatch in range(num_batches): 82 | bstart = ibatch * eval_batch_size 83 | bend = min(bstart + eval_batch_size, num_eval_examples) 84 | 85 | x_batch = mnist.test.images[bstart:bend, :] 86 | y_batch = mnist.test.labels[bstart:bend] 87 | 88 | dict_nat = {model.x_input: x_batch, 89 | model.y_input: y_batch} 90 | 91 | x_batch_adv = attack.perturb(x_batch, y_batch, sess) 92 | 93 | dict_adv = {model.x_input: x_batch_adv, 94 | model.y_input: y_batch} 95 | 96 | cur_corr_nat, cur_xent_nat = sess.run( 97 | [model.num_correct,model.xent], 98 | feed_dict = dict_nat) 99 | cur_corr_adv, cur_xent_adv = sess.run( 100 | [model.num_correct,model.xent], 101 | feed_dict = dict_adv) 102 | 103 | total_xent_nat += cur_xent_nat 104 | total_xent_adv += cur_xent_adv 105 | total_corr_nat += cur_corr_nat 106 | total_corr_adv += cur_corr_adv 107 | 108 | avg_xent_nat = total_xent_nat / num_eval_examples 109 | avg_xent_adv = total_xent_adv / num_eval_examples 110 | acc_nat = total_corr_nat / num_eval_examples 111 | acc_adv = total_corr_adv / num_eval_examples 112 | 113 | summary = tf.Summary(value=[ 114 | tf.Summary.Value(tag='xent adv eval', simple_value= avg_xent_adv), 115 | tf.Summary.Value(tag='xent adv', simple_value= avg_xent_adv), 116 | tf.Summary.Value(tag='xent nat', simple_value= avg_xent_nat), 117 | tf.Summary.Value(tag='accuracy adv eval', simple_value= acc_adv), 118 | tf.Summary.Value(tag='accuracy adv', simple_value= acc_adv), 119 | tf.Summary.Value(tag='accuracy nat', simple_value= acc_nat)]) 120 | summary_writer.add_summary(summary, global_step.eval(sess)) 121 | 122 | print('natural: {:.2f}%'.format(100 * acc_nat)) 123 | print('adversarial: {:.2f}%'.format(100 * acc_adv)) 124 | print('avg nat loss: {:.4f}'.format(avg_xent_nat)) 125 | print('avg adv loss: {:.4f}'.format(avg_xent_adv)) 126 | 127 | # Infinite eval loop 128 | while True: 129 | cur_checkpoint = tf.train.latest_checkpoint(model_dir) 130 | 131 | # Case 1: No checkpoint yet 132 | if cur_checkpoint is None: 133 | if not already_seen_state: 134 | print('No checkpoint yet, waiting ...', end='') 135 | already_seen_state = True 136 | else: 137 | print('.', end='') 138 | sys.stdout.flush() 139 | time.sleep(10) 140 | # Case 2: Previously unseen checkpoint 141 | elif cur_checkpoint != last_checkpoint_filename: 142 | print('\nCheckpoint {}, evaluating ... ({})'.format(cur_checkpoint, 143 | datetime.now())) 144 | sys.stdout.flush() 145 | last_checkpoint_filename = cur_checkpoint 146 | already_seen_state = False 147 | evaluate_checkpoint(cur_checkpoint) 148 | # Case 3: Previously evaluated checkpoint 149 | else: 150 | if not already_seen_state: 151 | print('Waiting for the next checkpoint ... 
({}) '.format( 152 | datetime.now()), 153 | end='') 154 | already_seen_state = True 155 | else: 156 | print('.', end='') 157 | sys.stdout.flush() 158 | time.sleep(10) 159 | -------------------------------------------------------------------------------- /madry_mnist/fetch_model.py: -------------------------------------------------------------------------------- 1 | """Downloads a model, computes its SHA256 hash and unzips it 2 | at the proper location.""" 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import sys 8 | import zipfile 9 | import hashlib 10 | 11 | if len(sys.argv) != 2 or sys.argv[1] not in ['natural', 12 | 'adv_trained', 13 | 'secret']: 14 | print('Usage: python fetch_model.py [natural, adv_trained, secret]') 15 | sys.exit(1) 16 | 17 | if sys.argv[1] == 'natural': 18 | url = 'https://github.com/MadryLab/mnist_challenge_models/raw/master/natural.zip' 19 | elif sys.argv[1] == 'secret': 20 | url = 'https://github.com/MadryLab/mnist_challenge_models/raw/master/secret.zip' 21 | else: # fetch adv_trained model 22 | url = 'https://github.com/MadryLab/mnist_challenge_models/raw/master/adv_trained.zip' 23 | 24 | fname = url.split('/')[-1] # get the name of the file 25 | 26 | # model download 27 | print('Downloading models') 28 | if sys.version_info >= (3,): 29 | import urllib.request 30 | urllib.request.urlretrieve(url, fname) 31 | else: 32 | import urllib 33 | urllib.urlretrieve(url, fname) 34 | 35 | # computing model hash 36 | sha256 = hashlib.sha256() 37 | with open(fname, 'rb') as f: 38 | data = f.read() 39 | sha256.update(data) 40 | print('SHA256 hash: {}'.format(sha256.hexdigest())) 41 | 42 | # extracting model 43 | print('Extracting model') 44 | with zipfile.ZipFile(fname, 'r') as model_zip: 45 | model_zip.extractall() 46 | print('Extracted model in {}'.format(model_zip.namelist()[0])) 47 | -------------------------------------------------------------------------------- /madry_mnist/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | The model is adapted from the tensorflow tutorial: 3 | https://www.tensorflow.org/get_started/mnist/pros 4 | """ 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import tensorflow as tf 10 | 11 | class Model(object): 12 | def __init__(self): 13 | self.x_input = tf.placeholder(tf.float32, shape = [None, 784]) 14 | self.y_input = tf.placeholder(tf.int64, shape = [None]) 15 | 16 | self.x_image = tf.reshape(self.x_input, [-1, 28, 28, 1]) 17 | 18 | # first convolutional layer 19 | W_conv1 = self._weight_variable([5,5,1,32]) 20 | b_conv1 = self._bias_variable([32]) 21 | 22 | h_conv1 = tf.nn.relu(self._conv2d(self.x_image, W_conv1) + b_conv1) 23 | h_pool1 = self._max_pool_2x2(h_conv1) 24 | 25 | # second convolutional layer 26 | W_conv2 = self._weight_variable([5,5,32,64]) 27 | b_conv2 = self._bias_variable([64]) 28 | 29 | h_conv2 = tf.nn.relu(self._conv2d(h_pool1, W_conv2) + b_conv2) 30 | h_pool2 = self._max_pool_2x2(h_conv2) 31 | 32 | # first fully connected layer 33 | W_fc1 = self._weight_variable([7 * 7 * 64, 1024]) 34 | b_fc1 = self._bias_variable([1024]) 35 | 36 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) 37 | h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 38 | 39 | # output layer 40 | W_fc2 = self._weight_variable([1024,10]) 41 | b_fc2 = self._bias_variable([10]) 42 | 43 | self.pre_softmax = tf.matmul(h_fc1, W_fc2) + 
b_fc2 44 | 45 | y_xent = tf.nn.sparse_softmax_cross_entropy_with_logits( 46 | labels=self.y_input, logits=self.pre_softmax) 47 | 48 | self.xent_per_point = y_xent 49 | self.xent = tf.reduce_sum(y_xent) 50 | 51 | self.y_pred = tf.argmax(self.pre_softmax, 1) 52 | 53 | correct_prediction = tf.equal(self.y_pred, self.y_input) 54 | 55 | self.num_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.int64)) 56 | self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 57 | 58 | @staticmethod 59 | def _weight_variable(shape): 60 | initial = tf.truncated_normal(shape, stddev=0.1) 61 | return tf.Variable(initial) 62 | 63 | @staticmethod 64 | def _bias_variable(shape): 65 | initial = tf.constant(0.1, shape = shape) 66 | return tf.Variable(initial) 67 | 68 | @staticmethod 69 | def _conv2d(x, W): 70 | return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME') 71 | 72 | @staticmethod 73 | def _max_pool_2x2( x): 74 | return tf.nn.max_pool(x, 75 | ksize = [1,2,2,1], 76 | strides=[1,2,2,1], 77 | padding='SAME') 78 | -------------------------------------------------------------------------------- /madry_mnist/run_attack.py: -------------------------------------------------------------------------------- 1 | """Evaluates a model against examples from a .npy file as specified 2 | in config.json""" 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from datetime import datetime 8 | import json 9 | import math 10 | import os 11 | import sys 12 | import time 13 | 14 | import tensorflow as tf 15 | from tensorflow.examples.tutorials.mnist import input_data 16 | 17 | import numpy as np 18 | 19 | from model import Model 20 | 21 | 22 | def run_attack(checkpoint, x_adv, epsilon): 23 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False) 24 | 25 | model = Model() 26 | 27 | saver = tf.train.Saver() 28 | 29 | num_eval_examples = 10000 30 | eval_batch_size = 64 31 | 32 | num_batches = int(math.ceil(num_eval_examples / eval_batch_size)) 33 | total_corr = 0 34 | 35 | x_nat = mnist.test.images 36 | l_inf = np.amax(np.abs(x_nat - x_adv)) 37 | 38 | if l_inf > epsilon + 0.0001: 39 | print('maximum perturbation found: {}'.format(l_inf)) 40 | print('maximum perturbation allowed: {}'.format(epsilon)) 41 | return 42 | 43 | y_pred = [] # label accumulator 44 | 45 | with tf.Session() as sess: 46 | # Restore the checkpoint 47 | saver.restore(sess, checkpoint) 48 | 49 | # Iterate over the samples batch-by-batch 50 | for ibatch in range(num_batches): 51 | bstart = ibatch * eval_batch_size 52 | bend = min(bstart + eval_batch_size, num_eval_examples) 53 | 54 | x_batch = x_adv[bstart:bend, :] 55 | y_batch = mnist.test.labels[bstart:bend] 56 | 57 | dict_adv = {model.x_input: x_batch, 58 | model.y_input: y_batch} 59 | cur_corr, y_pred_batch = sess.run([model.num_correct, model.y_pred], 60 | feed_dict=dict_adv) 61 | 62 | total_corr += cur_corr 63 | y_pred.append(y_pred_batch) 64 | 65 | accuracy = total_corr / num_eval_examples 66 | 67 | print('Accuracy: {:.2f}%'.format(100.0 * accuracy)) 68 | y_pred = np.concatenate(y_pred, axis=0) 69 | np.save('pred.npy', y_pred) 70 | print('Output saved at pred.npy') 71 | 72 | 73 | if __name__ == '__main__': 74 | import json 75 | 76 | with open('config.json') as config_file: 77 | config = json.load(config_file) 78 | 79 | model_dir = config['model_dir'] 80 | 81 | checkpoint = tf.train.latest_checkpoint(model_dir) 82 | x_adv = np.load(config['store_adv_path']) 83 | 84 | if checkpoint is None: 85 | 
print('No checkpoint found') 86 | elif x_adv.shape != (10000, 784): 87 | print('Invalid shape: expected (10000,784), found {}'.format(x_adv.shape)) 88 | elif np.amax(x_adv) > 1.0001 or \ 89 | np.amin(x_adv) < -0.0001 or \ 90 | np.isnan(np.amax(x_adv)): 91 | print('Invalid pixel range. Expected [0, 1], found [{}, {}]'.format( 92 | np.amin(x_adv), 93 | np.amax(x_adv))) 94 | else: 95 | run_attack(checkpoint, x_adv, config['epsilon']) 96 | -------------------------------------------------------------------------------- /madry_mnist/train.py: -------------------------------------------------------------------------------- 1 | """Trains a model, saving checkpoints and tensorboard summaries along 2 | the way.""" 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from datetime import datetime 8 | import json 9 | import os 10 | import shutil 11 | from timeit import default_timer as timer 12 | 13 | import tensorflow as tf 14 | import numpy as np 15 | from tensorflow.examples.tutorials.mnist import input_data 16 | 17 | from model import Model 18 | from attack import LinfPGDAttack 19 | 20 | with open('config.json') as config_file: 21 | config = json.load(config_file) 22 | 23 | # Setting up training parameters 24 | tf.set_random_seed(config['random_seed']) 25 | 26 | max_num_training_steps = config['max_num_training_steps'] 27 | num_output_steps = config['num_output_steps'] 28 | num_summary_steps = config['num_summary_steps'] 29 | num_checkpoint_steps = config['num_checkpoint_steps'] 30 | 31 | batch_size = config['training_batch_size'] 32 | 33 | # Setting up the data and the model 34 | mnist = input_data.read_data_sets('MNIST_data', one_hot=False) 35 | global_step = tf.contrib.framework.get_or_create_global_step() 36 | model = Model() 37 | 38 | # Setting up the optimizer 39 | train_step = tf.train.AdamOptimizer(1e-4).minimize(model.xent, 40 | global_step=global_step) 41 | 42 | # Set up adversary 43 | attack = LinfPGDAttack(model, 44 | config['epsilon'], 45 | config['k'], 46 | config['a'], 47 | config['random_start'], 48 | config['loss_func']) 49 | 50 | # Setting up the Tensorboard and checkpoint outputs 51 | model_dir = config['model_dir'] 52 | if not os.path.exists(model_dir): 53 | os.makedirs(model_dir) 54 | 55 | # We add accuracy and xent twice so we can easily make three types of 56 | # comparisons in Tensorboard: 57 | # - train vs eval (for a single run) 58 | # - train of different runs 59 | # - eval of different runs 60 | 61 | saver = tf.train.Saver(max_to_keep=3) 62 | tf.summary.scalar('accuracy adv train', model.accuracy) 63 | tf.summary.scalar('accuracy adv', model.accuracy) 64 | tf.summary.scalar('xent adv train', model.xent / batch_size) 65 | tf.summary.scalar('xent adv', model.xent / batch_size) 66 | tf.summary.image('images adv train', model.x_image) 67 | merged_summaries = tf.summary.merge_all() 68 | 69 | shutil.copy('config.json', model_dir) 70 | 71 | with tf.Session() as sess: 72 | # Initialize the summary writer, global variables, and our time counter. 
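# Note that the loop below implements adversarial training: each batch is first
# perturbed with the Linf PGD attack (attack.perturb) and the optimizer step is
# taken on the adversarial batch (adv_dict) only.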
73 | summary_writer = tf.summary.FileWriter(model_dir, sess.graph) 74 | sess.run(tf.global_variables_initializer()) 75 | training_time = 0.0 76 | 77 | # Main training loop 78 | for ii in range(max_num_training_steps): 79 | x_batch, y_batch = mnist.train.next_batch(batch_size) 80 | 81 | # Compute Adversarial Perturbations 82 | start = timer() 83 | x_batch_adv = attack.perturb(x_batch, y_batch, sess) 84 | end = timer() 85 | training_time += end - start 86 | 87 | nat_dict = {model.x_input: x_batch, 88 | model.y_input: y_batch} 89 | 90 | adv_dict = {model.x_input: x_batch_adv, 91 | model.y_input: y_batch} 92 | 93 | # Output to stdout 94 | if ii % num_output_steps == 0: 95 | nat_acc = sess.run(model.accuracy, feed_dict=nat_dict) 96 | adv_acc = sess.run(model.accuracy, feed_dict=adv_dict) 97 | print('Step {}: ({})'.format(ii, datetime.now())) 98 | print(' training nat accuracy {:.4}%'.format(nat_acc * 100)) 99 | print(' training adv accuracy {:.4}%'.format(adv_acc * 100)) 100 | if ii != 0: 101 | print(' {} examples per second'.format( 102 | num_output_steps * batch_size / training_time)) 103 | training_time = 0.0 104 | # Tensorboard summaries 105 | if ii % num_summary_steps == 0: 106 | summary = sess.run(merged_summaries, feed_dict=adv_dict) 107 | summary_writer.add_summary(summary, global_step.eval(sess)) 108 | 109 | # Write a checkpoint 110 | if ii % num_checkpoint_steps == 0: 111 | saver.save(sess, 112 | os.path.join(model_dir, 'checkpoint'), 113 | global_step=global_step) 114 | 115 | # Actual training step 116 | start = timer() 117 | sess.run(train_step, feed_dict=adv_dict) 118 | end = timer() 119 | training_time += end - start 120 | -------------------------------------------------------------------------------- /metrics/2019-11-10 15:57:14 model=pt_inception dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/metrics/2019-11-10 15:57:14 model=pt_inception dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy -------------------------------------------------------------------------------- /metrics/2019-11-10 15:57:14 model=pt_resnet dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/metrics/2019-11-10 15:57:14 model=pt_resnet dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy -------------------------------------------------------------------------------- /metrics/2019-11-10 15:57:14 model=pt_vgg dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/metrics/2019-11-10 15:57:14 model=pt_vgg dataset=imagenet n_ex=1000 eps=12.75 p=0.05 n_iter=10000.metrics.npy -------------------------------------------------------------------------------- /metrics/square_l2_inceptionv3_queries.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/metrics/square_l2_inceptionv3_queries.npy -------------------------------------------------------------------------------- 
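The `.npy` files under `metrics/` are binary NumPy arrays tracked in the repository; judging by the file names, the `*_queries.npy` files store per-image query counts. A minimal sketch for inspecting one of them (the 1-D layout is an assumption, not documented here):

```
import numpy as np

# assumption: a 1-D array holding the number of queries spent per image
queries = np.load('metrics/square_l2_inceptionv3_queries.npy')
print('n_images={}, median queries={:.0f}, mean queries={:.1f}'.format(
    queries.shape[0], np.median(queries), queries.mean()))
```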
/metrics/square_l2_resnet50_queries.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/metrics/square_l2_resnet50_queries.npy -------------------------------------------------------------------------------- /metrics/square_l2_vgg16_queries.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/max-andr/square-attack/ea95eebb5aca62ec790a927b5aa985ba4e87245c/metrics/square_l2_vgg16_queries.npy -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tensorflow as tf 3 | import numpy as np 4 | import math 5 | import utils 6 | from torchvision import models as torch_models 7 | from torch.nn import DataParallel 8 | from madry_mnist.model import Model as madry_model_mnist 9 | from madry_cifar10.model import Model as madry_model_cifar10 10 | from logit_pairing.models import LeNet as lp_model_mnist, ResNet20_v2 as lp_model_cifar10 11 | from post_avg.postAveragedModels import pa_resnet110_config1 as post_avg_cifar10_resnet 12 | from post_avg.postAveragedModels import pa_resnet152_config1 as post_avg_imagenet_resnet 13 | 14 | 15 | class Model: 16 | def __init__(self, batch_size, gpu_memory): 17 | self.batch_size = batch_size 18 | self.gpu_memory = gpu_memory 19 | 20 | def predict(self, x): 21 | raise NotImplementedError('use ModelTF or ModelPT') 22 | 23 | def loss(self, y, logits, targeted=False, loss_type='margin_loss'): 24 | """ Implements the margin loss (difference between the correct class and the second-best class). """ 25 | if loss_type == 'margin_loss': 26 | preds_correct_class = (logits * y).sum(1, keepdims=True) 27 | diff = preds_correct_class - logits # difference between the correct class and all other classes 28 | diff[y] = np.inf # to exclude zeros coming from f_correct - f_correct 29 | margin = diff.min(1, keepdims=True) 30 | loss = margin * -1 if targeted else margin 31 | elif loss_type == 'cross_entropy': 32 | probs = utils.softmax(logits) 33 | loss = -np.log(probs[y]) 34 | loss = loss * -1 if not targeted else loss 35 | else: 36 | raise ValueError('Wrong loss.') 37 | return loss.flatten() 38 | 39 | 40 | class ModelTF(Model): 41 | """ 42 | Wrapper class around TensorFlow models. 43 | 44 | In order to incorporate a new model, one has to ensure that self.model has a TF variable `logits`, 45 | and that the preprocessing of the inputs is done correctly (e.g. subtracting the mean and dividing by the 46 | standard deviation).
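For the Madry-style TF models used in this repository, the constructor below aliases the existing `pre_softmax` tensor to `logits` whenever the model does not define `logits` itself.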
47 | """ 48 | def __init__(self, model_name, batch_size, gpu_memory): 49 | super().__init__(batch_size, gpu_memory) 50 | model_folder = model_path_dict[model_name] 51 | model_file = tf.train.latest_checkpoint(model_folder) 52 | self.model = model_class_dict[model_name]() 53 | self.batch_size = batch_size 54 | self.model_name = model_name 55 | self.model_file = model_file 56 | if 'logits' not in self.model.__dict__: 57 | self.model.logits = self.model.pre_softmax 58 | 59 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory) 60 | config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options) 61 | self.sess = tf.Session(config=config) 62 | tf.train.Saver().restore(self.sess, model_file) 63 | 64 | def predict(self, x): 65 | if 'mnist' in self.model_name: 66 | shape = self.model.x_input.shape[1:].as_list() 67 | x = np.reshape(x, [-1, *shape]) 68 | elif 'cifar10' in self.model_name: 69 | x = np.transpose(x, axes=[0, 2, 3, 1]) 70 | 71 | n_batches = math.ceil(x.shape[0] / self.batch_size) 72 | logits_list = [] 73 | for i in range(n_batches): 74 | x_batch = x[i*self.batch_size:(i+1)*self.batch_size] 75 | logits = self.sess.run(self.model.logits, feed_dict={self.model.x_input: x_batch}) 76 | logits_list.append(logits) 77 | logits = np.vstack(logits_list) 78 | return logits 79 | 80 | 81 | class ModelPT(Model): 82 | """ 83 | Wrapper class around PyTorch models. 84 | 85 | In order to incorporate a new model, one has to ensure that self.model is a callable object that returns logits, 86 | and that the preprocessing of the inputs is done correctly (e.g. subtracting the mean and dividing by the 87 | standard deviation). 88 | """ 89 | def __init__(self, model_name, batch_size, gpu_memory): 90 | super().__init__(batch_size, gpu_memory) 91 | if model_name in ['pt_vgg', 'pt_resnet', 'pt_inception', 'pt_densenet']: 92 | model = model_class_dict[model_name](pretrained=True) 93 | self.mean = np.reshape([0.485, 0.456, 0.406], [1, 3, 1, 1]) 94 | self.std = np.reshape([0.229, 0.224, 0.225], [1, 3, 1, 1]) 95 | model = DataParallel(model.cuda()) 96 | else: 97 | model = model_class_dict[model_name]() 98 | if model_name in ['pt_post_avg_cifar10', 'pt_post_avg_imagenet']: 99 | # checkpoint = torch.load(model_path_dict[model_name]) 100 | self.mean = np.reshape([0.485, 0.456, 0.406], [1, 3, 1, 1]) 101 | self.std = np.reshape([0.229, 0.224, 0.225], [1, 3, 1, 1]) 102 | else: 103 | model = DataParallel(model).cuda() 104 | checkpoint = torch.load(model_path_dict[model_name] + '.pth') 105 | self.mean = np.reshape([0.485, 0.456, 0.406], [1, 3, 1, 1]) 106 | self.std = np.reshape([0.225, 0.225, 0.225], [1, 3, 1, 1]) 107 | model.load_state_dict(checkpoint) 108 | model.float() 109 | self.mean, self.std = self.mean.astype(np.float32), self.std.astype(np.float32) 110 | 111 | model.eval() 112 | self.model = model 113 | 114 | def predict(self, x): 115 | x = (x - self.mean) / self.std 116 | x = x.astype(np.float32) 117 | 118 | n_batches = math.ceil(x.shape[0] / self.batch_size) 119 | logits_list = [] 120 | with torch.no_grad(): # otherwise consumes too much memory and leads to a slowdown 121 | for i in range(n_batches): 122 | x_batch = x[i*self.batch_size:(i+1)*self.batch_size] 123 | x_batch_torch = torch.as_tensor(x_batch, device=torch.device('cuda')) 124 | logits = self.model(x_batch_torch).cpu().numpy() 125 | logits_list.append(logits) 126 | logits = np.vstack(logits_list) 127 | return logits 128 | 129 | 130 | model_path_dict = {'madry_mnist_robust': 'madry_mnist/models/robust',
'madry_cifar10_robust': 'madry_cifar10/models/robust', 132 | 'clp_mnist': 'logit_pairing/models/clp_mnist', 133 | 'lsq_mnist': 'logit_pairing/models/lsq_mnist', 134 | 'clp_cifar10': 'logit_pairing/models/clp_cifar10', 135 | 'lsq_cifar10': 'logit_pairing/models/lsq_cifar10', 136 | 'pt_post_avg_cifar10': 'post_avg/trainedModel/resnet110.th' 137 | } 138 | model_class_dict = {'pt_vgg': torch_models.vgg16_bn, 139 | 'pt_resnet': torch_models.resnet50, 140 | 'pt_inception': torch_models.inception_v3, 141 | 'pt_densenet': torch_models.densenet121, 142 | 'madry_mnist_robust': madry_model_mnist, 143 | 'madry_cifar10_robust': madry_model_cifar10, 144 | 'clp_mnist': lp_model_mnist, 145 | 'lsq_mnist': lp_model_mnist, 146 | 'clp_cifar10': lp_model_cifar10, 147 | 'lsq_cifar10': lp_model_cifar10, 148 | 'pt_post_avg_cifar10': post_avg_cifar10_resnet, 149 | 'pt_post_avg_imagenet': post_avg_imagenet_resnet, 150 | } 151 | all_model_names = list(model_class_dict.keys()) 152 | 153 | -------------------------------------------------------------------------------- /post_avg/LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2019] [Yuping Lin] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /post_avg/PADefense.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.utils.data as data 8 | import torch.cuda as cuda 9 | import torchvision.transforms as transforms 10 | import torchvision.utils as utl 11 | import torch.backends.cudnn as cudnn 12 | 13 | import torchvision.datasets as datasets 14 | import torchvision.models as mdl 15 | 16 | def checkEntropy(scores): 17 | scores = scores.squeeze() 18 | scr = scores.clone() 19 | scr[scr <= 0] = 1.0 20 | return - torch.sum(scores * torch.log(scr)) 21 | 22 | 23 | def checkConfidence(scores, K=10): 24 | scores = scores.squeeze() 25 | hScores, _ = torch.sort(scores, dim=0, descending=True) 26 | 27 | return hScores[0] / torch.sum(hScores[:K]) 28 | 29 | 30 | def integratedForward(model, sps, batchSize, nClasses, device='cpu', voteMethod='avg_softmax'): 31 | N = sps.size(0) 32 | feats = torch.empty(N, nClasses) 33 | model = model.to(device) 34 | 35 | with torch.no_grad(): 36 | baseInx = 0 37 | while baseInx < N: 38 | cuda.empty_cache() 39 | endInx = min(baseInx + batchSize, N) 40 | y = model(sps[baseInx:endInx, :].to(device)).detach().to('cpu') 41 | feats[baseInx:endInx, :] = y 42 | baseInx = endInx 43 | 44 | if voteMethod == 'avg_feat': 45 | feat = torch.mean(feats, dim=0, keepdim=True) 46 | elif voteMethod == 'most_vote': 47 | maxV, _ = torch.max(feats, dim=1, keepdim=True) 48 | feat = torch.sum(feats == maxV, dim=0, keepdim=True) 49 | elif voteMethod == 'weighted_feat': 50 | feat = torch.mean(feats, dim=0, keepdim=True) 51 | maxV, _ = torch.max(feats, dim=1, keepdim=True) 52 | feat = feat * torch.sum(feats == maxV, dim=0, keepdim=True).float() 53 | elif voteMethod == 'avg_softmax': 54 | feats = nn.functional.softmax(feats, dim=1) 55 | feat = torch.mean(feats, dim=0, keepdim=True) 56 | else: 57 | # default method: avg_softmax 58 | feats = nn.functional.softmax(feats, dim=1) 59 | feat = torch.mean(feats, dim=0, keepdim=True) 60 | 61 | return feat, feats 62 | 63 | # not updated, deprecated 64 | def integratedForward_cls(model, sps, batchSize, nClasses, device='cpu', count_votes=False): 65 | N = sps.size(0) 66 | feats = torch.empty(N, nClasses) 67 | model = model.to(device) 68 | 69 | with torch.no_grad(): 70 | baseInx = 0 71 | while baseInx < N: 72 | cuda.empty_cache() 73 | endInx = min(baseInx + batchSize, N) 74 | y = model.classifier(sps[baseInx:endInx, :].to(device)).detach().to('cpu') 75 | feats[baseInx:endInx, :] = y 76 | baseInx = endInx 77 | 78 | if count_votes: 79 | maxV, _ = torch.max(feats, dim=1, keepdim=True) 80 | feat = torch.sum(feats == maxV, dim=0, keepdim=True) 81 | else: 82 | feat = torch.mean(feats, dim=0, keepdim=True) 83 | 84 | return feat, feats 85 | 86 | 87 | def findNeighbors_random(sp, K, r=[2], direction='both'): 88 | # only accept single sample 89 | if sp.size(0) != 1: 90 | return None 91 | 92 | if isinstance(K, list): 93 | K = sum(K) 94 | 95 | # randomly select directions 96 | shifts = torch.randn(K, sp.size(1) * sp.size(2) * sp.size(3)).to('cuda') 97 | shifts = nn.functional.normalize(shifts, p=2, dim=1) 98 | shifts = shifts.view(K, sp.size(1), sp.size(2), sp.size(3)).contiguous() 99 | 100 | if direction == 'both': 101 | shifts = torch.cat([shifts, -shifts], dim=0) 102 | 103 | nbs = [] 104 | for rInx in range(len(r)): 105 | nbs.append(sp.to('cuda') + r[rInx] * shifts) 106 | 107 | return 
torch.cat(nbs, dim=0) 108 | 109 | 110 | def findNeighbors_plain_vgg(model, sp, K, r=[2], direction='both', device='cpu'): 111 | # only accept single sample 112 | if sp.size(0) != 1: 113 | return None 114 | 115 | # storages for K selected distances and linear mapping 116 | selected_list = [] 117 | 118 | # set model to evaluation mode 119 | model = model.to(device) 120 | model = model.eval() 121 | 122 | # place holder for input, and set to require gradient 123 | x = sp.clone().to(device) 124 | x.requires_grad = True 125 | 126 | # forward through the feature part 127 | y = model.features(x) 128 | y = model.avgpool(y) 129 | y = y.view(y.size(0), -1) 130 | 131 | # forward through classifier layer by layer 132 | for lyInx, module in model.classifier.named_children(): 133 | # forward 134 | y = module(y) 135 | 136 | # at each layer activation 137 | if isinstance(module, nn.Linear): 138 | # for each neuron 139 | for i in range(y.size(1)): 140 | # clear previous gradients 141 | x.grad = None 142 | 143 | # compute gradients 144 | goal = torch.abs(y[0, i]) 145 | goal.backward(retain_graph=True) # retain graph for future computation 146 | 147 | # compute distance 148 | d = torch.abs(y[0, i]) / torch.norm(x.grad) 149 | 150 | # keep K shortest distances 151 | selected_list.append((d.clone().detach().to('cpu'), x.grad.clone().detach().to('cpu'))) 152 | selected_list = sorted(selected_list, key=lambda x:x[0], reverse=False) 153 | selected_list = selected_list[0:K] 154 | 155 | # generate neighboring samples 156 | grad_list = [e[1] / torch.norm(e[1]) for e in selected_list] 157 | unit_shifts = torch.cat(grad_list, dim=0) 158 | nbs = [] 159 | for rInx in range(len(r)): 160 | if direction == 'inc': 161 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts) 162 | elif direction == 'dec': 163 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts) 164 | else: 165 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts) 166 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts) 167 | nbs = torch.cat(nbs, dim=0) 168 | nbs = nbs.detach() 169 | nbs.requires_grad = False 170 | 171 | return nbs 172 | 173 | 174 | def findNeighbors_lastLy_vgg(model, sp, K, r=[2], direction='both', device='cpu'): 175 | # only accept single sample 176 | if sp.size(0) != 1: 177 | return None 178 | 179 | # storages for K selected distances and linear mapping 180 | selected_list = [] 181 | 182 | # set model to evaluation mode 183 | model = model.to(device) 184 | model = model.eval() 185 | 186 | # place holder for input, and set to require gradient 187 | x = sp.clone().to(device) 188 | x.requires_grad = True 189 | 190 | # forward through the feature part 191 | y = model(x) 192 | y = y.view(y.size(0), -1) 193 | 194 | for i in range(y.size(1)): 195 | # clear previous gradients 196 | x.grad = None 197 | 198 | # compute gradients 199 | goal = torch.abs(y[0, i]) 200 | if i < y.size(1) - 1: 201 | goal.backward(retain_graph=True) # retain graph for future computation 202 | else: 203 | goal.backward(retain_graph=False) 204 | 205 | # compute distance 206 | d = torch.abs(y[0, i]) / torch.norm(x.grad) 207 | 208 | # keep K shortest distances 209 | selected_list.append((d.clone().detach().to('cpu'), x.grad.clone().detach().to('cpu'))) 210 | selected_list = sorted(selected_list, key=lambda x:x[0], reverse=False) 211 | selected_list = selected_list[0:K] 212 | 213 | # generate neighboring samples 214 | grad_list = [e[1] / torch.norm(e[1]) for e in selected_list] 215 | unit_shifts = torch.cat(grad_list, dim=0) 216 | nbs = [] 217 | for rInx in range(len(r)): 218 | if direction 
== 'inc': 219 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts) 220 | elif direction == 'dec': 221 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts) 222 | else: 223 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts) 224 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts) 225 | nbs = torch.cat(nbs, dim=0) 226 | nbs = nbs.detach() 227 | nbs.requires_grad = False 228 | 229 | return nbs 230 | 231 | 232 | def findNeighbors_approx_vgg(model, sp, K, r=[2], direction='both', device='cpu'): 233 | # only accept single sample 234 | if sp.size(0) != 1: 235 | return None 236 | 237 | # storages for K selected distances and linear mapping 238 | selected_list = [] 239 | 240 | # set model to evaluation mode 241 | model = model.to(device) 242 | model = model.eval() 243 | 244 | # place holder for input, and set to require gradient 245 | x = sp.clone().to(device) 246 | x.requires_grad = True 247 | 248 | # forward through the feature part 249 | y = model.features(x) 250 | y = model.avgpool(y) 251 | y = y.view(y.size(0), -1) 252 | 253 | # forward through classifier layer by layer 254 | lnLy_inx = 0 255 | for lyInx, module in model.classifier.named_children(): 256 | # forward 257 | y = module(y) 258 | 259 | # at each layer activation 260 | if isinstance(module, nn.Linear): 261 | KInx = min(lnLy_inx, len(K)-1) 262 | if K[KInx] > 0: 263 | with torch.no_grad(): 264 | # compute weight norm 265 | w_norm = torch.norm(module.weight, dim=1, keepdim=True) 266 | 267 | # compute distance 268 | d = torch.abs(y) / w_norm.t() 269 | _, sortedInx = torch.sort(d, dim=1, descending=False) 270 | 271 | # for each selected neuron 272 | for i in range(K[KInx]): 273 | 274 | # clear previous gradients 275 | x.grad = None 276 | 277 | # compute gradients 278 | goal = torch.abs(y[0, sortedInx[0, i]]) 279 | goal.backward(retain_graph=True) # retain graph for future computation 280 | 281 | # record gradients 282 | selected_list.append(x.grad.clone().detach().to('cpu') / torch.norm(x.grad).detach().to('cpu')) 283 | 284 | # update number of linear layer sampled 285 | lnLy_inx = lnLy_inx + 1 286 | 287 | # generate neighboring samples 288 | unit_shifts = torch.cat(selected_list, dim=0) 289 | nbs = [] 290 | for rInx in range(len(r)): 291 | if direction == 'inc': 292 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts) 293 | elif direction == 'dec': 294 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts) 295 | else: 296 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts) 297 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts) 298 | nbs = torch.cat(nbs, dim=0) 299 | nbs = nbs.detach() 300 | nbs.requires_grad = False 301 | 302 | return nbs 303 | 304 | 305 | def findNeighbors_randPick_vgg(model, sp, K, r=[2], direction='both', device='cpu'): 306 | # only accept single sample 307 | if sp.size(0) != 1: 308 | return None 309 | 310 | # storages for K selected distances and linear mapping 311 | selected_list = [] 312 | 313 | # set model to evaluation mode 314 | model = model.to(device) 315 | model = model.eval() 316 | 317 | # place holder for input, and set to require gradient 318 | x = sp.clone().to(device) 319 | x.requires_grad = True 320 | 321 | # forward through the feature part 322 | y = model.features(x) 323 | y = model.avgpool(y) 324 | y = y.view(y.size(0), -1) 325 | 326 | # forward through classifier layer by layer 327 | lnLy_inx = 0 328 | for lyInx, module in model.classifier.named_children(): 329 | # forward 330 | y = module(y) 331 | 332 | # at each layer activation 333 | if isinstance(module, nn.Linear): 334 | KInx = min(lnLy_inx, len(K)-1) 335 | 
if K[KInx] > 0: 336 | # randomly permute indices 337 | pickInx = torch.randperm(y.size(1)) 338 | 339 | # for each selected neuron 340 | for i in range(K[KInx]): 341 | 342 | # clear previous gradients 343 | x.grad = None 344 | 345 | # compute gradients 346 | goal = torch.abs(y[0, pickInx[i]]) 347 | goal.backward(retain_graph=True) # retain graph for future computation 348 | 349 | # record gradients 350 | selected_list.append(x.grad.clone().detach().to('cpu') / torch.norm(x.grad).detach().to('cpu')) 351 | 352 | # update number of linear layer sampled 353 | lnLy_inx = lnLy_inx + 1 354 | 355 | # generate neighboring samples 356 | unit_shifts = torch.cat(selected_list, dim=0) 357 | nbs = [] 358 | for rInx in range(len(r)): 359 | if direction == 'inc': 360 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts) 361 | elif direction == 'dec': 362 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts) 363 | else: 364 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts) 365 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts) 366 | nbs = torch.cat(nbs, dim=0) 367 | nbs = nbs.detach() 368 | nbs.requires_grad = False 369 | 370 | return nbs 371 | 372 | # not updated, deprecated 373 | def findNeighbors_feats_lastLy_vgg(model, sp, K, r=[2], direction='both', device='cpu', includeOriginal=True): 374 | # only accept single sample 375 | if sp.size(0) != 1: 376 | return None 377 | 378 | # storages for K selected distances and linear mapping 379 | selected_list = [] 380 | 381 | # set model to evaluation mode 382 | model = model.to(device) 383 | model = model.eval() 384 | 385 | # forward through the feature part 386 | with torch.no_grad(): 387 | feat = model.features(sp.to(device)) 388 | feat = feat.view(feat.size(0), -1).contiguous().detach() 389 | 390 | # place holder for feature, and set to require gradient 391 | x = feat.clone().detach() 392 | x.requires_grad = True 393 | 394 | # forward through the classifier part 395 | y = model.classifier(x) 396 | y = y.view(y.size(0), -1) 397 | 398 | for i in range(y.size(1)): 399 | # clear previous gradients 400 | x.grad = None 401 | 402 | # compute gradients 403 | goal = torch.abs(y[0, i]) 404 | if i < y.size(1) - 1: 405 | goal.backward(retain_graph=True) # retain graph for future computation 406 | else: 407 | goal.backward(retain_graph=False) 408 | 409 | # compute distance 410 | d = torch.abs(y[0, i]) / torch.norm(x.grad) 411 | 412 | # keep K shortest distances 413 | selected_list.append((d.clone().detach().to('cpu'), x.grad.clone().detach().to('cpu'))) 414 | selected_list = sorted(selected_list, key=lambda x:x[0], reverse=False) 415 | selected_list = selected_list[0:K] 416 | 417 | # generate neighboring samples 418 | grad_list = [e[1] / torch.norm(e[1]) for e in selected_list] 419 | unit_shifts = torch.cat(grad_list, dim=0) 420 | if includeOriginal: 421 | nbs = [feat.to('cpu')] 422 | else: 423 | nbs = [] 424 | for rInx in range(len(r)): 425 | if direction == 'inc': 426 | nbs.append(feat.to('cpu') + r[rInx] * unit_shifts) 427 | elif direction == 'dec': 428 | nbs.append(feat.to('cpu') - r[rInx] * unit_shifts) 429 | else: 430 | nbs.append(feat.to('cpu') + r[rInx] * unit_shifts) 431 | nbs.append(feat.to('cpu') - r[rInx] * unit_shifts) 432 | nbs = torch.cat(nbs, dim=0) 433 | nbs = nbs.detach() 434 | nbs.requires_grad = False 435 | 436 | return nbs 437 | 438 | # not updated, deprecated 439 | def findNeighbors_feats_approx_vgg(model, sp, K, r=[2], direction='both', device='cpu', includeOriginal=True): 440 | # only accept single sample 441 | if sp.size(0) != 1: 442 | return None 443 | 
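# Note: like findNeighbors_approx_vgg above, the code below approximates each
# linear neuron's distance to its decision boundary as |activation| / ||w|| and
# backpropagates only for the K closest neurons, but it operates on the
# flattened feature-space representation instead of the input image.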
444 | # storages for K selected distances and linear mapping 445 | selected_list = [] 446 | 447 | # set model to evaluation mode 448 | model = model.to(device) 449 | model = model.eval() 450 | 451 | # forward through the feature part 452 | with torch.no_grad(): 453 | feat = model.features(sp.to(device)) 454 | feat = feat.view(feat.size(0), -1).contiguous().detach() 455 | 456 | # place holder for feature, and set to require gradient 457 | x = feat.clone().detach() 458 | x.requires_grad = True 459 | y = x 460 | 461 | # forward through classifier layer by layer 462 | lnLy_inx = 0 463 | for lyInx, module in model.classifier.named_children(): 464 | # forward 465 | y = module(y) 466 | 467 | # at each layer activation 468 | if isinstance(module, nn.Linear): 469 | KInx = min(lnLy_inx, len(K)-1) 470 | if K[KInx] > 0: 471 | with torch.no_grad(): 472 | # compute weight norm 473 | w_norm = torch.norm(module.weight, dim=1, keepdim=True) 474 | 475 | # compute distance 476 | d = torch.abs(y) / w_norm.t() 477 | _, sortedInx = torch.sort(d, dim=1, descending=False) 478 | 479 | # for each selected neuron 480 | for i in range(K[KInx]): 481 | 482 | # clear previous gradients 483 | x.grad = None 484 | 485 | # compute gradients 486 | goal = torch.abs(y[0, sortedInx[0, i]]) 487 | goal.backward(retain_graph=True) # retain graph for future computation 488 | 489 | # record gradients 490 | selected_list.append(x.grad.clone().detach().to('cpu') / torch.norm(x.grad).detach().to('cpu')) 491 | 492 | # update number of linear layer sampled 493 | lnLy_inx = lnLy_inx + 1 494 | 495 | # generate neighboring samples 496 | unit_shifts = torch.cat(selected_list, dim=0) 497 | if includeOriginal: 498 | nbs = [feat.to('cpu')] 499 | else: 500 | nbs = [] 501 | for rInx in range(len(r)): 502 | if direction == 'inc': 503 | nbs.append(feat.to('cpu') + r[rInx] * unit_shifts) 504 | elif direction == 'dec': 505 | nbs.append(feat.to('cpu') - r[rInx] * unit_shifts) 506 | else: 507 | nbs.append(feat.to('cpu') + r[rInx] * unit_shifts) 508 | nbs.append(feat.to('cpu') - r[rInx] * unit_shifts) 509 | nbs = torch.cat(nbs, dim=0) 510 | nbs = nbs.detach() 511 | nbs.requires_grad = False 512 | 513 | return nbs 514 | 515 | 516 | def formSquad_vgg(method, model, sp, K, r=[2], direction='both', device='cpu', includeOriginal=True): 517 | if method == 'random': 518 | nbs = findNeighbors_random(sp, K, r, direction=direction) 519 | if includeOriginal: 520 | nbs = torch.cat([sp, nbs], dim=0) 521 | elif method == 'plain': 522 | nbs = findNeighbors_plain_vgg(model, sp, K, r, direction=direction, device=device) 523 | if includeOriginal: 524 | nbs = torch.cat([sp, nbs], dim=0) 525 | elif method == 'lastLy': 526 | nbs = findNeighbors_lastLy_vgg(model, sp, K, r, direction=direction, device=device) 527 | if includeOriginal: 528 | nbs = torch.cat([sp, nbs], dim=0) 529 | elif method == 'approx': 530 | nbs = findNeighbors_approx_vgg(model, sp, K, r, direction=direction, device=device) 531 | if includeOriginal: 532 | nbs = torch.cat([sp, nbs], dim=0) 533 | elif method == 'randPick': 534 | nbs = findNeighbors_randPick_vgg(model, sp, K, r, direction=direction, device=device) 535 | if includeOriginal: 536 | nbs = torch.cat([sp, nbs], dim=0) 537 | elif method == 'feats_lastLy': 538 | nbs = findNeighbors_feats_lastLy_vgg(model, sp, K, r, direction=direction, device=device, includeOriginal=includeOriginal) 539 | elif method == 'feats_approx': 540 | nbs = findNeighbors_feats_approx_vgg(model, sp, K, r, direction=direction, device=device, includeOriginal=includeOriginal) 541 | 
else: 542 | # if invalid method, use default setting. (actually should raise error here) 543 | nbs = findNeighbors_random(sp, K, r, direction=direction) 544 | if includeOriginal: 545 | nbs = torch.cat([sp, nbs], dim=0) 546 | 547 | return nbs 548 | 549 | 550 | def findNeighbors_approx_resnet(model, sp, K, r=[2], direction='both', device='cpu'): 551 | # only accept single sample 552 | if sp.size(0) != 1: 553 | return None 554 | 555 | # storages for K selected distances and linear mapping 556 | selected_list = [] 557 | 558 | # set model to evaluation mode 559 | model = model.to(device) 560 | model = model.eval() 561 | 562 | # place holder for input, and set to require gradient 563 | x = sp.clone().to(device) 564 | x.requires_grad = True 565 | 566 | # forward through the model 567 | y = model(x) 568 | y = y.view(y.size(0), -1) 569 | 570 | if K > 0: 571 | with torch.no_grad(): 572 | # compute weight norm 573 | w_norm = torch.norm(model.fc.weight, dim=1, keepdim=True) 574 | 575 | # compute distance 576 | d = torch.abs(y) / w_norm.t() 577 | _, sortedInx = torch.sort(d, dim=1, descending=False) 578 | 579 | # for each selected neuron 580 | for i in range(K): 581 | 582 | # clear previous gradients 583 | x.grad = None 584 | 585 | # compute gradients 586 | goal = torch.abs(y[0, sortedInx[0, i]]) 587 | goal.backward(retain_graph=True) # retain graph for future computation 588 | 589 | # record gradients 590 | selected_list.append(x.grad.clone().detach().to('cpu') / torch.norm(x.grad).detach().to('cpu')) 591 | 592 | # generate neighboring samples 593 | unit_shifts = torch.cat(selected_list, dim=0) 594 | nbs = [] 595 | for rInx in range(len(r)): 596 | if direction == 'inc': 597 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts) 598 | elif direction == 'dec': 599 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts) 600 | else: 601 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts) 602 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts) 603 | nbs = torch.cat(nbs, dim=0) 604 | nbs = nbs.detach() 605 | nbs.requires_grad = False 606 | 607 | return nbs 608 | 609 | 610 | def findNeighbors_approx_resnet_small(model, sp, K, r=[2], direction='both', device='cpu'): 611 | # only accept single sample 612 | if sp.size(0) != 1: 613 | return None 614 | 615 | # storages for K selected distances and linear mapping 616 | selected_list = [] 617 | 618 | # set model to evaluation mode 619 | model = model.to(device) 620 | model = model.eval() 621 | 622 | # place holder for input, and set to require gradient 623 | x = sp.clone().to(device) 624 | x.requires_grad = True 625 | 626 | # forward through the model 627 | y = model(x) 628 | y = y.view(y.size(0), -1) 629 | 630 | if K > 0: 631 | with torch.no_grad(): 632 | # compute weight norm 633 | w_norm = torch.norm(model.linear.weight, dim=1, keepdim=True) 634 | 635 | # compute distance 636 | d = torch.abs(y) / w_norm.t() 637 | _, sortedInx = torch.sort(d, dim=1, descending=False) 638 | 639 | # for each selected neuron 640 | for i in range(K): 641 | 642 | # clear previous gradients 643 | x.grad = None 644 | 645 | # compute gradients 646 | goal = torch.abs(y[0, sortedInx[0, i]]) 647 | goal.backward(retain_graph=True) # retain graph for future computation 648 | 649 | # record gradients 650 | selected_list.append(x.grad.clone().detach().to('cpu') / torch.norm(x.grad).detach().to('cpu')) 651 | 652 | # generate neighboring samples 653 | unit_shifts = torch.cat(selected_list, dim=0) 654 | nbs = [] 655 | for rInx in range(len(r)): 656 | if direction == 'inc': 657 | nbs.append(sp.to('cpu') + 
r[rInx] * unit_shifts) 658 | elif direction == 'dec': 659 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts) 660 | else: 661 | nbs.append(sp.to('cpu') + r[rInx] * unit_shifts) 662 | nbs.append(sp.to('cpu') - r[rInx] * unit_shifts) 663 | nbs = torch.cat(nbs, dim=0) 664 | nbs = nbs.detach() 665 | nbs.requires_grad = False 666 | 667 | return nbs 668 | 669 | 670 | def formSquad_resnet(method, model, sp, K, r=[2], direction='both', device='cpu', includeOriginal=True): 671 | if method == 'random': 672 | nbs = findNeighbors_random(sp, K, r, direction=direction) 673 | if includeOriginal: 674 | nbs = torch.cat([sp, nbs], dim=0) 675 | elif method == 'approx': 676 | nbs = findNeighbors_approx_resnet(model, sp, K, r, direction=direction, device=device) 677 | if includeOriginal: 678 | nbs = torch.cat([sp, nbs], dim=0) 679 | elif method == 'approx_cifar10': 680 | nbs = findNeighbors_approx_resnet_small(model, sp, K, r, direction=direction, device=device) 681 | if includeOriginal: 682 | nbs = torch.cat([sp, nbs], dim=0) 683 | else: 684 | # if invalid method, use default setting. (actually should raise error here) 685 | nbs = findNeighbors_random(sp, K, r, direction=direction) 686 | if includeOriginal: 687 | nbs = torch.cat([sp, nbs], dim=0) 688 | 689 | return nbs 690 | -------------------------------------------------------------------------------- /post_avg/README.md: -------------------------------------------------------------------------------- 1 | # Post-Average Adversarial Defense 2 | Implementation of the Post-Average adversarial defense method as described in [Bandlimiting Neural Networks Against Adversarial Attacks](https://arxiv.org/abs/1905.12797). 3 | 4 | This implementation is based on PyTorch and uses the [Foolbox](https://github.com/bethgelab/foolbox) toolbox to provide attacking methods. 5 | 6 | ## [robustml](https://github.com/robust-ml/robustml) evaluation 7 | This implementation supports the robustml API for evaluation. 
8 | 9 | To evaluate on CIFAR-10: 10 | ``` 11 | python robustml_test_cifar10.py 12 | ``` 13 | 14 | To evaluate on ImageNet: 15 | ``` 16 | python robustml_test_imagenet.py 17 | ``` 18 | -------------------------------------------------------------------------------- /post_avg/attacks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import robustml 4 | import numpy as np 5 | 6 | import foolbox.criteria as crt 7 | import foolbox.attacks as attacks 8 | import foolbox.distances as distances 9 | import foolbox.adversarial as adversarial 10 | 11 | class NullAttack(robustml.attack.Attack): 12 | def run(self, x, y, target): 13 | return x 14 | 15 | class FoolboxAttackWrapper(robustml.attack.Attack): 16 | def __init__(self, attack): 17 | self._attacker = attack 18 | 19 | def run(self, x, y, target): 20 | # model requires image in (C, H, W), but robustml provides (H, W, C) 21 | # transpose x to accommodate pytorch's axis arrangement convention 22 | x = np.transpose(x, (2, 0, 1)) 23 | if target is not None: 24 | adv_criterion = crt.TargetClass(target) 25 | adv_obj = adversarial.Adversarial(self._attacker._default_model, adv_criterion, x, y, distance=self._attacker._default_distance) 26 | adv_x = self._attacker(adv_obj) 27 | else: 28 | adv_x = self._attacker(x, y) 29 | 30 | # transpose back to data provider's convention 31 | return np.transpose(adv_x, (1, 2, 0)) 32 | 33 | def fgsmAttack(victim_model): # victim_model should be model wrapped with foolbox model 34 | attacker = attacks.GradientSignAttack(victim_model, crt.Misclassification()) 35 | return FoolboxAttackWrapper(attacker) 36 | 37 | def pgdAttack(victim_model): # victim_model should be model wrapped with foolbox model 38 | attacker = attacks.RandomStartProjectedGradientDescentAttack(victim_model, crt.Misclassification(), distance=distances.Linfinity) 39 | return FoolboxAttackWrapper(attacker) 40 | 41 | def dfAttack(victim_model): # victim_model should be model wrapped with foolbox model 42 | attacker = attacks.DeepFoolAttack(victim_model, crt.Misclassification()) 43 | return FoolboxAttackWrapper(attacker) 44 | 45 | def cwAttack(victim_model): # victim_model should be model wrapped with foolbox model 46 | attacker = attacks.CarliniWagnerL2Attack(victim_model, crt.Misclassification()) 47 | return FoolboxAttackWrapper(attacker) 48 | -------------------------------------------------------------------------------- /post_avg/postAveragedModels.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import robustml 4 | import numpy as np 5 | from collections import OrderedDict 6 | import post_avg.PADefense as padef 7 | import post_avg.resnetSmall as rnsmall 8 | 9 | import torch 10 | import torchvision.models as mdl 11 | import torchvision.transforms as transforms 12 | 13 | class PostAveragedResNet152(robustml.model.Model): 14 | def __init__(self, K, R, eps, device='cuda'): 15 | self._model = mdl.resnet152(pretrained=True).to(device) 16 | self._dataset = robustml.dataset.ImageNet((224, 224, 3)) 17 | self._threat_model = robustml.threat_model.Linf(epsilon=eps) 18 | self._K = K 19 | self._r = [R/3, 2*R/3, R] 20 | self._sample_method = 'random' 21 | self._vote_method = 'avg_softmax' 22 | self._device = device 23 | 24 | @property 25 | def model(self): 26 | return self._model 27 | 28 | @property 29 | def dataset(self): 30 | return self._dataset 31 | 32 | @property 33 | def threat_model(self): 34 | return self._threat_model 
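# classify() applies the post-averaging defense: formSquad_resnet draws random
# neighbors of x at the radii in self._r, and integratedForward averages the
# softmax outputs over the whole neighborhood ('avg_softmax') to score x.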
35 | 36 | def classify(self, x): 37 | x = x.unsqueeze(0) 38 | 39 | # gather neighbor samples 40 | x_squad = padef.formSquad_resnet(self._sample_method, self._model, x, self._K, self._r, device=self._device) 41 | 42 | # forward with a batch of neighbors 43 | logits, _ = padef.integratedForward(self._model, x_squad, batchSize=100, nClasses=1000, device=self._device, voteMethod=self._vote_method) 44 | 45 | return torch.as_tensor(logits) 46 | 47 | def __call__(self, x): 48 | logits_list = [] 49 | for img in x: 50 | logits = self.classify(img) 51 | logits_list.append(logits) 52 | return torch.cat(logits_list, dim=0) 53 | 54 | def _preprocess(self, image): 55 | # normalization used by pre-trained model 56 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 57 | 58 | return normalize(image) 59 | 60 | def to(self, device): 61 | self._model = self._model.to(device) 62 | self._device = device 63 | 64 | def eval(self): 65 | self._model = self._model.eval() 66 | 67 | 68 | def pa_resnet152_config1(): 69 | return PostAveragedResNet152(K=15, R=30, eps=8/255) 70 | 71 | 72 | class PostAveragedResNet110(robustml.model.Model): 73 | def __init__(self, K, R, eps, device='cuda'): 74 | # load model state dict 75 | checkpoint = torch.load('post_avg/trainedModel/resnet110.th') 76 | paramDict = OrderedDict() 77 | for k, v in checkpoint['state_dict'].items(): 78 | # remove 'module.' prefix introduced by DataParallel, if any 79 | if k.startswith('module.'): 80 | paramDict[k[7:]] = v 81 | self._model = rnsmall.resnet110() 82 | self._model.load_state_dict(paramDict) 83 | self._model = self._model.to(device) 84 | 85 | self._dataset = robustml.dataset.CIFAR10() 86 | self._threat_model = robustml.threat_model.Linf(epsilon=eps) 87 | self._K = K 88 | self._r = [R/3, 2*R/3, R] 89 | self._sample_method = 'random' 90 | self._vote_method = 'avg_softmax' 91 | self._device = device 92 | 93 | @property 94 | def model(self): 95 | return self._model 96 | 97 | @property 98 | def dataset(self): 99 | return self._dataset 100 | 101 | @property 102 | def threat_model(self): 103 | return self._threat_model 104 | 105 | def classify(self, x): 106 | x = x.unsqueeze(0) 107 | 108 | # gather neighbor samples 109 | x_squad = padef.formSquad_resnet(self._sample_method, self._model, x, self._K, self._r, device=self._device) 110 | 111 | # forward with a batch of neighbors 112 | logits, _ = padef.integratedForward(self._model, x_squad, batchSize=1000, nClasses=10, device=self._device, voteMethod=self._vote_method) 113 | 114 | return torch.as_tensor(logits) 115 | 116 | def __call__(self, x): 117 | logits_list = [] 118 | for img in x: 119 | logits = self.classify(img) 120 | logits_list.append(logits) 121 | return torch.cat(logits_list, dim=0) 122 | 123 | def _preprocess(self, image): 124 | # normalization used by pre-trained model 125 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 126 | 127 | return normalize(image) 128 | 129 | def to(self, device): 130 | self._model = self._model.to(device) 131 | self._device = device 132 | 133 | def eval(self): 134 | self._model = self._model.eval() 135 | 136 | 137 | def pa_resnet110_config1(): 138 | return PostAveragedResNet110(K=15, R=6, eps=8/255) 139 | -------------------------------------------------------------------------------- /post_avg/resnetSmall.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | Properly implemented ResNet-s for CIFAR10 as 
described in paper [1]. 5 | 6 | The implementation and structure of this file are heavily influenced by [2], 7 | which is implemented for ImageNet and doesn't have option A for identity. 8 | Moreover, most of the implementations on the web are copy-pasted from 9 | torchvision's resnet and have the wrong number of params. 10 | 11 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following 12 | number of layers and parameters: 13 | 14 | name | layers | params 15 | ResNet20 | 20 | 0.27M 16 | ResNet32 | 32 | 0.46M 17 | ResNet44 | 44 | 0.66M 18 | ResNet56 | 56 | 0.85M 19 | ResNet110 | 110 | 1.7M 20 | ResNet1202| 1202 | 19.4M 21 | 22 | which this implementation indeed has. 23 | 24 | Reference: 25 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 26 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 27 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 28 | 29 | If you use this implementation in your work, please don't forget to mention the 30 | author, Yerlan Idelbayev. 31 | ''' 32 | import torch 33 | import torch.nn as nn 34 | import torch.nn.functional as F 35 | import torch.nn.init as init 36 | 37 | from torch.autograd import Variable 38 | 39 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 40 | 41 | def _weights_init(m): 42 | classname = m.__class__.__name__ 43 | # print(classname) 44 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 45 | init.kaiming_normal(m.weight) 46 | 47 | class LambdaLayer(nn.Module): 48 | def __init__(self, lambd): 49 | super(LambdaLayer, self).__init__() 50 | self.lambd = lambd 51 | 52 | def forward(self, x): 53 | return self.lambd(x) 54 | 55 | 56 | class BasicBlock(nn.Module): 57 | expansion = 1 58 | 59 | def __init__(self, in_planes, planes, stride=1, option='A'): 60 | super(BasicBlock, self).__init__() 61 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 64 | self.bn2 = nn.BatchNorm2d(planes) 65 | 66 | self.shortcut = nn.Sequential() 67 | if stride != 1 or in_planes != planes: 68 | if option == 'A': 69 | """ 70 | For CIFAR10 the ResNet paper uses option A.
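Option A keeps the shortcut parameter-free: the input is subsampled spatially by a stride of 2 and zero-padded along the channel dimension (implemented by the LambdaLayer below).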
71 | """ 72 | self.shortcut = LambdaLayer(lambda x: 73 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 74 | elif option == 'B': 75 | self.shortcut = nn.Sequential( 76 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 77 | nn.BatchNorm2d(self.expansion * planes) 78 | ) 79 | 80 | def forward(self, x): 81 | out = F.relu(self.bn1(self.conv1(x))) 82 | out = self.bn2(self.conv2(out)) 83 | out += self.shortcut(x) 84 | out = F.relu(out) 85 | return out 86 | 87 | 88 | class ResNet(nn.Module): 89 | def __init__(self, block, num_blocks, num_classes=10): 90 | super(ResNet, self).__init__() 91 | self.in_planes = 16 92 | 93 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 94 | self.bn1 = nn.BatchNorm2d(16) 95 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 96 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 97 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 98 | self.linear = nn.Linear(64, num_classes) 99 | 100 | self.apply(_weights_init) 101 | 102 | def _make_layer(self, block, planes, num_blocks, stride): 103 | strides = [stride] + [1]*(num_blocks-1) 104 | layers = [] 105 | for stride in strides: 106 | layers.append(block(self.in_planes, planes, stride)) 107 | self.in_planes = planes * block.expansion 108 | 109 | return nn.Sequential(*layers) 110 | 111 | def forward(self, x): 112 | out = F.relu(self.bn1(self.conv1(x))) 113 | out = self.layer1(out) 114 | out = self.layer2(out) 115 | out = self.layer3(out) 116 | out = F.avg_pool2d(out, out.size()[3]) 117 | out = out.view(out.size(0), -1) 118 | out = self.linear(out) 119 | return out 120 | 121 | 122 | def resnet20(): 123 | return ResNet(BasicBlock, [3, 3, 3]) 124 | 125 | 126 | def resnet32(): 127 | return ResNet(BasicBlock, [5, 5, 5]) 128 | 129 | 130 | def resnet44(): 131 | return ResNet(BasicBlock, [7, 7, 7]) 132 | 133 | 134 | def resnet56(): 135 | return ResNet(BasicBlock, [9, 9, 9]) 136 | 137 | 138 | def resnet110(): 139 | return ResNet(BasicBlock, [18, 18, 18]) 140 | 141 | 142 | def resnet1202(): 143 | return ResNet(BasicBlock, [200, 200, 200]) 144 | 145 | 146 | def test(net): 147 | import numpy as np 148 | total_params = 0 149 | 150 | for x in filter(lambda p: p.requires_grad, net.parameters()): 151 | total_params += np.prod(x.data.numpy().shape) 152 | print("Total number of params", total_params) 153 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 154 | 155 | 156 | if __name__ == "__main__": 157 | for net_name in __all__: 158 | if net_name.startswith('resnet'): 159 | print(net_name) 160 | test(globals()[net_name]()) 161 | print() -------------------------------------------------------------------------------- /post_avg/robustml_test_cifar10.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import argparse 5 | import robustml 6 | import numpy as np 7 | from foolbox.models import PyTorchModel 8 | from robustml_portal import attacks as atk 9 | from robustml_portal import postAveragedModels as pamdl 10 | 11 | 12 | # argument parsing 13 | parser = argparse.ArgumentParser(description="robustml evaluation on CIFAR-10") 14 | parser.add_argument("datasetPath", help="path to the 'test_batch' file") 15 | parser.add_argument("--start", type=int, default=0, help="inclusive starting index for data. 
default: 0") 16 | parser.add_argument("--end", type=int, help="exclusive ending index for data. default: dataset size") 17 | parser.add_argument("--attack", choices=["pgd", "fgsm", "df", "cw", "none"], default="pgd", help="attack method to be used. default: pgd") 18 | parser.add_argument("--device", help="computation device to be used. 'cpu' or 'cuda:'") 19 | args = parser.parse_args() 20 | 21 | if args.device is None: 22 | if torch.cuda.is_available(): 23 | device = torch.device("cuda") 24 | else: 25 | device = torch.device("cpu") 26 | else: 27 | device = torch.device(args.device) 28 | 29 | # setup test model 30 | model = pamdl.pa_resnet110_config1() 31 | model.to(device) 32 | model.eval() 33 | 34 | # setup attacker 35 | nClasses = 10 36 | victim_model = PyTorchModel(model.model, (0,1), nClasses, device=device, preprocessing=(np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)), np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)))) 37 | if args.attack == "pgd": 38 | attack = atk.pgdAttack(victim_model) 39 | elif args.attack == "fgsm": 40 | attack = atk.fgsmAttack(victim_model) 41 | elif args.attack == "df": 42 | attack = atk.dfAttack(victim_model) 43 | elif args.attack == "cw": 44 | attack = atk.cwAttack(victim_model) 45 | else: 46 | attack = atk.NullAttack() 47 | 48 | # setup data provider 49 | prov = robustml.provider.CIFAR10(args.datasetPath) 50 | 51 | # evaluate performance 52 | if args.end is None: 53 | args.end = len(prov) 54 | atk_success_rate = robustml.evaluate.evaluate(model, attack, prov, start=args.start, end=args.end) 55 | print('Overall attack success rate: %.4f' % atk_success_rate) -------------------------------------------------------------------------------- /post_avg/robustml_test_imagenet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import argparse 5 | import robustml 6 | import numpy as np 7 | from foolbox.models import PyTorchModel 8 | from robustml_portal import attacks as atk 9 | from robustml_portal import postAveragedModels as pamdl 10 | 11 | 12 | # argument parsing 13 | parser = argparse.ArgumentParser(description="robustml evaluation on ImageNet") 14 | parser.add_argument("datasetPath", help="directory containing 'val.txt' and 'val/' folder") 15 | parser.add_argument("--start", type=int, default=0, help="inclusive starting index for data. default: 0") 16 | parser.add_argument("--end", type=int, help="exclusive ending index for data. default: dataset size") 17 | parser.add_argument("--attack", choices=["pgd", "fgsm", "df", "cw", "none"], default="pgd", help="attack method to be used. default: pgd") 18 | parser.add_argument("--device", help="computation device to be used.
'cpu' or 'cuda:'") 19 | args = parser.parse_args() 20 | 21 | if args.device is None: 22 | if torch.cuda.is_available(): 23 | device = torch.device("cuda") 24 | else: 25 | device = torch.device("cpu") 26 | else: 27 | device = torch.device(args.device) 28 | 29 | # setup test model 30 | model = pamdl.pa_resnet152_config1() 31 | model.to(device) 32 | model.eval() 33 | 34 | # setup attacker 35 | nClasses = 1000 36 | victim_model = PyTorchModel(model.model, (0,1), nClasses, device=device, preprocessing=(np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)), np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)))) 37 | if args.attack == "pgd": 38 | attack = atk.pgdAttack(victim_model) 39 | elif args.attack == "fgsm": 40 | attack = atk.fgsmAttack(victim_model) 41 | elif args.attack == "df": 42 | attack = atk.dfAttack(victim_model) 43 | elif args.attack == "cw": 44 | attack = atk.cwAttack(victim_model) 45 | else: 46 | attack = atk.NullAttack() 47 | 48 | # setup data provider 49 | prov = robustml.provider.ImageNet(args.datasetPath, (224, 224, 3)) 50 | 51 | # evaluate performance 52 | if args.end is None: 53 | args.end = len(prov) 54 | atk_success_rate = robustml.evaluate.evaluate(model, attack, prov, start=args.start, end=args.end) 55 | print('Overall attack success rate: %.4f' % atk_success_rate) -------------------------------------------------------------------------------- /post_avg/visualHelper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | import matplotlib; matplotlib.use('agg') 8 | import matplotlib.pyplot as plt 9 | 10 | def plotPredStats(feats, lb, K=10, image=None, noiseImage=None, savePath=None): 11 | 12 | # score by averaging 13 | scores = torch.mean(feats, dim=0) 14 | 15 | # sort and select the top K scores 16 | hScores, hCates = torch.sort(scores, dim=0, descending=True) 17 | hScores = hScores[:K].numpy() 18 | hCates = hCates[:K].numpy() 19 | 20 | # get individual preditions 21 | _, preds = torch.max(feats, dim=1) 22 | 23 | # count votes 24 | preds_count = {lb: 0} 25 | for i in range(feats.size(0)): 26 | if preds[i].item() in preds_count: 27 | preds_count[preds[i].item()] = preds_count[preds[i].item()] + 1 28 | else: 29 | preds_count[preds[i].item()] = 1 30 | 31 | candidates = sorted(preds_count.keys()) 32 | votes = [preds_count[x] for x in candidates] 33 | 34 | # generate figure 35 | fig = plt.figure() 36 | if image is None and noiseImage is None: 37 | ax1, ax2, ax3 = fig.subplots(3, 1) 38 | else: 39 | axes = fig.subplots(2, 2) 40 | ax1 = axes[0, 0] 41 | ax2 = axes[1, 0] 42 | ax3 = axes[0, 1] 43 | ax4 = axes[1, 1] 44 | 45 | # chart 1, votes distribution 46 | inx1 = list(range(len(candidates))) 47 | clr1 = [] 48 | for i in inx1: 49 | if candidates[i] == lb: 50 | clr1.append('Red') 51 | else: 52 | clr1.append('SkyBlue') 53 | rects1 = ax1.bar(inx1, votes, color=clr1) 54 | for rect in rects1: 55 | h = rect.get_height() 56 | ax1.text(rect.get_x() + 0.5 * rect.get_width(), 1.01 * h, '{}'.format(h), ha='center', va='bottom') 57 | ax1.set_ylim(top=1.1 * ax1.get_ylim()[1]) 58 | ax1.set_xticks(inx1) 59 | ax1.set_xticklabels([str(x) for x in candidates], rotation=30) 60 | ax1.set_ylabel('votes') 61 | ax1.set_title('Votes Distribution') 62 | 63 | # chart 2, top prediction scores 64 | inx2 = list(range(len(hCates))) 65 | clr2 = [] 66 | for i in inx2: 67 | if hCates[i] == lb: 68 | clr2.append('Red') 69 | else: 70 | clr2.append('SkyBlue') 71 | rects2 = 
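# Note on the `preprocessing` argument above: in foolbox's PyTorchModel, a
# (mean, std) pair makes the wrapper normalize inputs channel-wise before each
# forward pass, so the attack itself can operate directly in [0, 1] pixel
# space. Roughly:
#
#   mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
#   std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
#   x_model_input = (x - mean) / std    # x: a (3, 224, 224) image in [0, 1]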
--------------------------------------------------------------------------------
/post_avg/visualHelper.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | 
  3 | import torch
  4 | import torch.nn as nn
  5 | import numpy as np
  6 | 
  7 | import matplotlib; matplotlib.use('agg')
  8 | import matplotlib.pyplot as plt
  9 | 
 10 | def plotPredStats(feats, lb, K=10, image=None, noiseImage=None, savePath=None):
 11 | 
 12 |     # score by averaging
 13 |     scores = torch.mean(feats, dim=0)
 14 | 
 15 |     # sort and select the top K scores
 16 |     hScores, hCates = torch.sort(scores, dim=0, descending=True)
 17 |     hScores = hScores[:K].numpy()
 18 |     hCates = hCates[:K].numpy()
 19 | 
 20 |     # get individual predictions
 21 |     _, preds = torch.max(feats, dim=1)
 22 | 
 23 |     # count votes
 24 |     preds_count = {lb: 0}
 25 |     for i in range(feats.size(0)):
 26 |         if preds[i].item() in preds_count:
 27 |             preds_count[preds[i].item()] = preds_count[preds[i].item()] + 1
 28 |         else:
 29 |             preds_count[preds[i].item()] = 1
 30 | 
 31 |     candidates = sorted(preds_count.keys())
 32 |     votes = [preds_count[x] for x in candidates]
 33 | 
 34 |     # generate figure
 35 |     fig = plt.figure()
 36 |     if image is None and noiseImage is None:
 37 |         ax1, ax2, ax3 = fig.subplots(3, 1)
 38 |     else:
 39 |         axes = fig.subplots(2, 2)
 40 |         ax1 = axes[0, 0]
 41 |         ax2 = axes[1, 0]
 42 |         ax3 = axes[0, 1]
 43 |         ax4 = axes[1, 1]
 44 | 
 45 |     # chart 1, votes distribution
 46 |     inx1 = list(range(len(candidates)))
 47 |     clr1 = []
 48 |     for i in inx1:
 49 |         if candidates[i] == lb:
 50 |             clr1.append('Red')
 51 |         else:
 52 |             clr1.append('SkyBlue')
 53 |     rects1 = ax1.bar(inx1, votes, color=clr1)
 54 |     for rect in rects1:
 55 |         h = rect.get_height()
 56 |         ax1.text(rect.get_x() + 0.5 * rect.get_width(), 1.01 * h, '{}'.format(h), ha='center', va='bottom')
 57 |     ax1.set_ylim(top=1.1 * ax1.get_ylim()[1])
 58 |     ax1.set_xticks(inx1)
 59 |     ax1.set_xticklabels([str(x) for x in candidates], rotation=30)
 60 |     ax1.set_ylabel('votes')
 61 |     ax1.set_title('Votes Distribution')
 62 | 
 63 |     # chart 2, top prediction scores
 64 |     inx2 = list(range(len(hCates)))
 65 |     clr2 = []
 66 |     for i in inx2:
 67 |         if hCates[i] == lb:
 68 |             clr2.append('Red')
 69 |         else:
 70 |             clr2.append('SkyBlue')
 71 |     rects2 = ax2.bar(inx2, hScores, color=clr2)
 72 |     for rect in rects2:
 73 |         h = rect.get_height()
 74 |         ax2.text(rect.get_x() + 0.5 * rect.get_width(), 1.01 * h, '{:.2f}'.format(h), ha='center', va='bottom')
 75 |     ax2.set_ylim(top=1.1 * ax2.get_ylim()[1])
 76 |     ax2.set_xticks(inx2)
 77 |     ax2.set_xticklabels([str(x) for x in hCates], rotation=30)
 78 |     ax2.set_ylabel('score')
 79 |     ax2.set_xlabel('Top Predictions')
 80 | 
 81 |     # axis 3, the noise image
 82 |     if noiseImage is not None:
 83 |         ax3.imshow(noiseImage)
 84 |         ax3.set_xlabel('Noise Image')
 85 |         ax3.set_axis_off()
 86 |     else:
 87 |         # if noise image is not given, show prediction event plot
 88 |         clr3 = []
 89 |         for i in range(preds.size(0)):
 90 |             if preds[i] == lb:
 91 |                 clr3.append('Red')
 92 |             else:
 93 |                 clr3.append('Green')
 94 |         ax3.eventplot(preds.unsqueeze(1).numpy(), orientation='vertical', colors=clr3)
 95 |         ax3.set_yticks(candidates)
 96 |         ax3.set_yticklabels([str(x) for x in candidates])
 97 |         ax3.set_xlabel('sample index')
 98 |         ax3.set_ylabel('class')
 99 | 
100 |     # axis 4, the input image
101 |     if image is not None:
102 |         ax4.imshow(image)
103 |         ax4.set_title('Input Image')
104 |         ax4.set_axis_off()
105 | 
106 |     # save figure and close
107 |     if savePath is not None:
108 |         fig.savefig(savePath)
109 | 
110 |     plt.close(fig)
111 | 
112 | 
113 | def plotPerturbationDistribution(perturbations, savePath=None):
114 | 
115 |     # generate figure
116 |     fig = plt.figure()
117 |     ax1, ax2, ax3 = fig.subplots(3, 1)
118 | 
119 |     # plot scatter chart
120 |     perts = np.asarray(perturbations)
121 |     ax1.scatter(perts[:, 0], perts[:, 1], c='SkyBlue')
122 |     ax1.autoscale(axis='x')
123 |     ax1.set_ylim((-1, 2))
124 |     ax1.set_yticks([0, 1])
125 |     ax1.set_yticklabels(['missed', 'defended'])
126 |     ax1.set_xlabel('Perturbation distance')
127 |     ax1.set_title('Perturbations Distribution')
128 | 
129 |     # plot bin chart for defended adversarial samples
130 |     x = [e[0] for e in perturbations if e[1] == 1]
131 |     ax2.hist(x, bins=20, color='SkyBlue')
132 |     ax2.set_xlabel('Perturbation distance')
133 |     ax2.set_ylabel('Defended')
134 | 
135 |     # plot bin chart for missed adversarial samples
136 |     x = [e[0] for e in perturbations if e[1] == 0]
137 |     ax3.hist(x, bins=20, color='Red')
138 |     ax3.set_xlabel('Perturbation distance')
139 |     ax3.set_ylabel('Missed')
140 | 
141 |     # save figure and close
142 |     if savePath is not None:
143 |         fig.savefig(savePath)
144 | 
145 |     plt.close(fig)
146 | 
147 | 
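# A minimal usage sketch (shapes are assumptions, not taken from the original
# repo): `feats` holds one row of class scores per averaged sample, `lb` is
# the true label, and `perturbations` is a list of (distance, defended?) pairs.
#
#   import torch
#   feats = torch.randn(32, 10)    # 32 sampled predictions, 10 classes
#   plotPredStats(feats, lb=3, K=5, savePath='pred_stats.png')
#   plotPerturbationDistribution([(0.5, 1), (1.2, 0), (0.8, 1)], savePath='perts.png')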
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import os
 3 | 
 4 | 
 5 | class Logger:
 6 |     def __init__(self, path):
 7 |         self.path = path
 8 |         if path != '':
 9 |             folder = '/'.join(path.split('/')[:-1])
10 |             if folder != '' and not os.path.exists(folder):
11 |                 os.makedirs(folder)
12 | 
13 |     def print(self, message):
14 |         print(message)
15 |         if self.path != '':
16 |             with open(self.path, 'a') as f:
17 |                 f.write(message + '\n')
18 |                 f.flush()
19 | 
20 | 
21 | def dense_to_onehot(y_test, n_cls):
22 |     y_test_onehot = np.zeros([len(y_test), n_cls], dtype=bool)
23 |     y_test_onehot[np.arange(len(y_test)), y_test] = True
24 |     return y_test_onehot
25 | 
26 | 
27 | def random_classes_except_current(y_test, n_cls):
28 |     y_test_new = np.zeros_like(y_test)
29 |     for i_img in range(y_test.shape[0]):
30 |         lst_classes = list(range(n_cls))
31 |         lst_classes.remove(y_test[i_img])
32 |         y_test_new[i_img] = np.random.choice(lst_classes)
33 |     return y_test_new
34 | 
35 | 
36 | def softmax(x):
37 |     e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
38 |     return e_x / e_x.sum(axis=1, keepdims=True)
39 | 
--------------------------------------------------------------------------------
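For orientation, a short sketch of how these helpers might be combined (the values below are illustrative, not taken from attack.py):

    import numpy as np
    from utils import softmax, dense_to_onehot, random_classes_except_current

    logits = np.array([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 0.3]])                  # (n_examples, n_cls)
    probs = softmax(logits)                               # each row sums to 1
    y = np.array([0, 2])
    y_onehot = dense_to_onehot(y, n_cls=3)                # boolean one-hot labels
    y_target = random_classes_except_current(y, n_cls=3)  # random wrong labels, e.g. for targeted attacks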