├── .DS_Store ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── _init_paths.py ├── blackbox.py ├── datasets ├── README.md ├── __init__.py ├── celeba.py ├── dataset.py ├── fmnist.py ├── mnist.py └── utils.py ├── download_dataset.py ├── experiments └── cfgs │ ├── gans │ ├── celeba.yml │ ├── default.yml │ ├── fmnist.yml │ └── mnist.yml │ └── key_doc.yml ├── figures ├── .DS_Store ├── defensegan.png └── defensegan_gd.png ├── models ├── __init__.py ├── base_model.py ├── dataset_models.py └── gan.py ├── requirements.txt ├── tflib ├── __init__.py ├── checkpoint.py ├── cifar10.py ├── inception_score.py ├── mnist.py ├── ops │ ├── __init__.py │ ├── batchnorm.py │ ├── cond_batchnorm.py │ ├── conv1d.py │ ├── conv2d.py │ ├── deconv2d.py │ ├── layernorm.py │ └── linear.py ├── plot.py ├── save_images.py └── small_imagenet.py ├── train.py ├── utils ├── __init__.py ├── config.py ├── dummy.py ├── gan_defense.py ├── misc.py ├── network_builder.py └── visualize.py └── whitebox.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kabkabm/defensegan/7e3feaebf7b9bbf08b1364e400119ef596cd78fd/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.so 3 | .ipynb_checkpoints 4 | .idea 5 | cache 6 | output 7 | **/logs 8 | /output/* 9 | /external 10 | /external/* 11 | data 12 | debug 13 | lib/build 14 | lib/build/* 15 | matplotlibrc 16 | *.out 17 | *.pdf 18 | *.esp 19 | *.mat 20 | *.pckl 21 | *.pkl 22 | *.swp -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "cleverhans"] 2 | path = cleverhans 3 | url = https://github.com/tensorflow/cleverhans 4 | -------------------------------------------------------------------------------- 
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | Copyright (c) 2018 Pouya Samangouei and Maya Kabkab 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 
41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 
124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. 
In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 
177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | MIT License 181 | 182 | Copyright (c) 2017 Ashish Bora 183 | 184 | Permission is hereby granted, free of charge, to any person obtaining a copy 185 | of this software and associated documentation files (the "Software"), to deal 186 | in the Software without restriction, including without limitation the rights 187 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 188 | copies of the Software, and to permit persons to whom the Software is 189 | furnished to do so, subject to the following conditions: 190 | 191 | The above copyright notice and this permission notice shall be included in all 192 | copies or substantial portions of the Software. 193 | 194 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 195 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 196 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 197 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 198 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 199 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 200 | SOFTWARE. 201 | 202 | The MIT License (MIT) 203 | 204 | Copyright (c) 2016 Taehoon Kim 205 | 206 | Permission is hereby granted, free of charge, to any person obtaining a copy 207 | of this software and associated documentation files (the "Software"), to deal 208 | in the Software without restriction, including without limitation the rights 209 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 210 | copies of the Software, and to permit persons to whom the Software is 211 | furnished to do so, subject to the following conditions: 212 | 213 | The above copyright notice and this permission notice shall be included in all 214 | copies or substantial portions of the Software. 
215 | 216 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 217 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 218 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 219 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 220 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 221 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 222 | SOFTWARE. 223 | 224 | MIT License 225 | 226 | Copyright (c) 2017 Google Inc., OpenAI and Pennsylvania State University 227 | 228 | Permission is hereby granted, free of charge, to any person obtaining a copy 229 | of this software and associated documentation files (the "Software"), to deal 230 | in the Software without restriction, including without limitation the rights 231 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 232 | copies of the Software, and to permit persons to whom the Software is 233 | furnished to do so, subject to the following conditions: 234 | 235 | The above copyright notice and this permission notice shall be included in all 236 | copies or substantial portions of the Software. 237 | 238 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 239 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 240 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 241 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 242 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 243 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 244 | SOFTWARE. 
245 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Defense-GAN: Protecting Classifiers Against Adversarial Attacks Using Generative Models 2 | 3 | Pouya Samangouei*, Maya Kabkab*, Rama Chellappa 4 | 5 | [*: authors contributed equally] 6 | 7 | This repository contains the implementation of our ICLR-18 paper: 8 | [**Defense-GAN: Protecting Classifiers Against Adversarial Attacks Using Generative Models**](https://openreview.net/pdf?id=BkJ3ibb0-) 9 | 10 | If you find this code or the paper useful, please consider citing: 11 | 12 | ``` 13 | @inproceedings{defensegan, 14 | title={Defense-GAN: Protecting classifiers against adversarial attacks using generative models}, 15 | author={Samangouei, Pouya and Kabkab, Maya and Chellappa, Rama}, 16 | booktitle={International Conference on Learning Representations}, 17 | year={2018} 18 | } 19 | ``` 20 | 21 | ![alt text](figures/defensegan.png "The Overview of the Defense-GAN Algorithm") 22 | ![alt text](figures/defensegan_gd.png "The gradient descent steps at inferece time.") 23 | 24 | 25 | ## Contents 26 | 27 | 1. [Installation](#installation) 28 | 2. [Usage](#usage) 29 | - [Train a GAN model](#train-a-gan-model) 30 | - [Black-box attacks](#black-box-attacks) 31 | - [White-box attacks](#white-box-attacks) 32 | 33 | 34 | ## Installation 35 | 1. Clone this repository: 36 | ``` 37 | git clone --recursive https://github.com/kabkabm/defensegan 38 | cd defensegan 39 | git submodule update --init --recursive 40 | ``` 41 | 42 | 2. Install requirements: 43 | ``` 44 | pip install -r requirements.txt 45 | ``` 46 | Note: if you don't have a GPU install the cpu version of TensorFlow 1.7. 47 | 48 | 3. Download the dataset and prepare `data` directory: 49 | ``` 50 | python download_dataset.py [mnist|f-mnist|celeba] 51 | ``` 52 | 53 | 4. 
Create or link `output` and `debug` directories: 54 | ``` 55 | mkdir output 56 | mkdir debug 57 | ``` 58 | or 59 | ``` 60 | ln -s output 61 | ln -s debug 62 | ``` 63 | 64 | 65 | ## Usage 66 | 67 | ### Train a GAN model 68 | ``` 69 | python train.py --cfg --is_train 70 | ``` 71 | - `--cfg` This can be set to either a `.yml` configuration file like the ones in 72 | `experiments/cfgs`, or an output directory path. 73 | - `` can be any parameter that is defined in the config file. 74 | 75 | The training will create a directory in the `output` directory per experiment 76 | with the same name as to save the model checkpoints. If 77 | `` are different from the ones that are defined in ``, 78 | the output directory name will reflect the difference. 79 | 80 | A config file is saved into each experiment directory so that they can be 81 | loaded if `` is the address to that directory. 82 | 83 | #### Example 84 | 85 | After running 86 | ``` 87 | python train.py --cfg experiments/cfgs/gans/mnist.yml --is_train 88 | ``` 89 | `output/gans/mnist` will be created. 90 | 91 | #### [optional] Save reconstructions and datasets into cache: 92 | ``` 93 | python train.py --cfg experiments/cfgs/ --save_recs 94 | python train.py --cfg experiments/cfgs/ --save_ds 95 | ``` 96 | 97 | #### Example 98 | After running the training code for `mnist`, the reconstructions and the 99 | dataset can be saved with: 100 | ``` 101 | python train.py --cfg output/gans/mnist --save_recs 102 | python train.py --cfg output/gans/mnist --save_ds 103 | ``` 104 | 105 | As training goes on, sample outputs of the generator are written to `debug/gans/`. 
106 | 107 | ### Black-box attacks 108 | 109 | To perform black-box experiments run `blackbox.py` [Table 1 and 2 of the 110 | paper]: 111 | ``` 112 | python blackbox.py --cfg \ 113 | --results_dir \ 114 | --bb_model {A, B, C, D, E} \ 115 | --sub_model {A, B, C, D, E} \ 116 | --fgsm_eps \ 117 | --defense_type {none|defense_gan|adv_tr} 118 | [--train_on_recs or --online_training] 119 | 120 | ``` 121 | - `--cfg` is the path to the config file for training the iWGAN. This can 122 | also be the path to the output directory of the model. 123 | - `--results_dir` The path where the final results are saved in text files. 124 | - `--bb_model` The black-box model architectures that are used in Table 1 and 125 | Table 2. 126 | - `--sub_model` The substitute model architectures that are used in Table 1 and 127 | Table 2. 128 | - `--defense_type` specifies the type of defense to protect the classifier. 129 | - `--train_on_recs or --online_training` These parameters are optional. If they 130 | are set, the classifier will be trained on the reconstructions of 131 | Defense-GAN (e.g. in column `Defense-GAN-Rec` of Table 1 and 2). Otherwise, the 132 | results are for `Defense-GAN-Orig`. Note `--online_training` will take 133 | a while if `--rec_iters`, or L in the paper, is set to a large value. 134 | - `` A list of `-- ` that are the same 135 | as the hyperparemeters that are defined in config files (all lower case), and 136 | also a list of flags in `blackbox.py`. The most important ones are: 137 | - `--rec_iters` The number of GD reconstruction iterations for Defense-GAN, or L in 138 | the paper. 139 | - `--rec_lr` The learning rate of the reconstruction step. 140 | - `--rec_rr` The number of random restarts for the reconstruction step, or 141 | R in the paper. 142 | - `--num_train` The number of images to train the black-box model on. For debugging 143 | purposes set this to a small value. 144 | - `--num_test` The number of images to test on. 
For debugging purposes set this 145 | to a small value. 146 | - `--debug` This will save qualitative attack and reconstruction results in 147 | `debug` directory and will not run the adversarial attack part of the code. 148 | 149 | - Refer to `blackbox.py` for more flag descriptions. 150 | 151 | #### Example 152 | 153 | - Row 1 of Table 1 `Defense-GAN-Orig`: 154 | ``` 155 | python blackbox.py --cfg output/gans/mnist \ 156 | --results_dir defensegan \ 157 | --bb_model A \ 158 | --sub_model B \ 159 | --fgsm_eps 0.3 \ 160 | --defense_type defense_gan 161 | ``` 162 | - If you set `--nb_epochs 1 --nb_epochs_s 1 --data_aug 1` you will get a quick glance of how the script works. 163 | 164 | ### White-box attacks 165 | 166 | To test Defense-GAN for white-box attacks run `whitebox.py` [Tables 4, 5, 12 167 | of the paper]: 168 | ``` 169 | python whitebox.py --cfg \ 170 | --results_dir \ 171 | --attack_type {fgsm, rand_fgsm, cw} \ 172 | --defense_type {none|defense_gan|adv_tr} \ 173 | --model {A, B, C, D} \ 174 | [--train_on_recs or --online_training] 175 | 176 | ``` 177 | - `--cfg` is the path to the config file for training the iWGAN. This can 178 | also be the path to the output directory of the model. 179 | - `--results_dir` The path where the final results are saved in text files. 180 | - `--defense_type` specifies the type of defense to protect the classifier. 181 | - `--train_on_recs or --online_training` These parameters are optional. If they 182 | are set, the classifier will be trained on the reconstructions of 183 | Defense-GAN (e.g. in column `Defense-GAN-Rec` of Table 1 and 2). Otherwise, the 184 | results are for `Defense-GAN-Orig`. Note `--online_training` will take 185 | a while if `--rec_iters`, or L in the paper, is set to a large value. 186 | - `` A list of `-- ` that are the same 187 | as the hyperparemeters that are defined in config files (all lower case), and 188 | also a list of flags in `whitebox.py`. 
The most important ones are: 189 | - `--rec_iters` The number of GD reconstruction iterations for Defense-GAN, or L in 190 | the paper. 191 | - `--rec_lr` The learning rate of the reconstruction step. 192 | - `--rec_rr` The number of random restarts for the reconstruction step, or 193 | R in the paper. 194 | - `--num_test` The number of images to test on. For debugging purposes set this 195 | to a small value. 196 | - Refer to `whitebox.py` for more flag descriptions. 197 | 198 | #### Example 199 | 200 | First row of Table 4: 201 | ``` 202 | python whitebox.py --cfg \ 203 | --results_dir whitebox \ 204 | --attack_type fgsm \ 205 | --defense_type defense_gan \ 206 | --model A 207 | ``` 208 | - If you want to quickly see how the scripts work, add the following flags: 209 | ``` 210 | --nb_epochs 1 --num_tests 400 211 | ``` 212 | -------------------------------------------------------------------------------- /_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | def add_path(path): 5 | if path not in sys.path: 6 | sys.path.insert(0, path) 7 | 8 | this_dir = osp.dirname(__file__) 9 | cleverhans_path = osp.join(this_dir,'cleverhans') 10 | add_path(cleverhans_path) 11 | 12 | 13 | -------------------------------------------------------------------------------- /blackbox.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Defense-GAN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Testing blackbox Defense-GAN models. This module is based on MNIST tutorial 17 | of cleverhans.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | from __future__ import unicode_literals 23 | 24 | import argparse 25 | import cPickle 26 | import logging 27 | import os 28 | import re 29 | import sys 30 | 31 | import keras.backend as K 32 | import numpy as np 33 | import tensorflow as tf 34 | from six.moves import xrange 35 | from tensorflow.python.platform import flags 36 | 37 | from cleverhans.attacks import FastGradientMethod 38 | from cleverhans.attacks_tf import jacobian_graph, jacobian_augmentation 39 | from cleverhans.utils import set_log_level, to_categorical 40 | from cleverhans.utils_tf import model_train, model_eval, batch_eval 41 | from datasets.celeba import CelebA 42 | from datasets.dataset import PickleLazyDataset 43 | from models.gan import MnistDefenseGAN, FmnistDefenseDefenseGAN, \ 44 | CelebADefenseGAN 45 | from utils.config import load_config 46 | from utils.gan_defense import model_eval_gan 47 | from utils.misc import ensure_dir 48 | from utils.network_builder import model_a, model_b, model_c, model_d, \ 49 | model_e, model_f, model_z, model_q 50 | from utils.visualize import save_images_files 51 | 52 | FLAGS = flags.FLAGS 53 | dataset_gan_dict = { 54 | 'mnist': MnistDefenseGAN, 55 | 'f-mnist': FmnistDefenseDefenseGAN, 56 | 
'celeba': CelebADefenseGAN, 57 | } 58 | 59 | # orig_ refers to original images and not reconstructed ones. 60 | # To prepare these cache files run "python main.py --save_ds". 61 | orig_data_path = {k: 'data/cache/{}_pkl'.format(k) for k in 62 | dataset_gan_dict.keys()} 63 | 64 | 65 | def prep_bbox(sess, images, labels, images_train, labels_train, images_test, 66 | labels_test, nb_epochs, batch_size, learning_rate, rng, gan=None, 67 | adv_training=False, cnn_arch=None): 68 | """Defines and trains a model that simulates the "remote" 69 | black-box oracle described in https://arxiv.org/abs/1602.02697. 70 | 71 | Args: 72 | sess: the TF session 73 | images: the input placeholder 74 | labels: the ouput placeholder 75 | images_train: the training data for the oracle 76 | labels_train: the training labels for the oracle 77 | images_test: the testing data for the oracle 78 | labels_test: the testing labels for the oracle 79 | nb_epochs: number of epochs to train model 80 | batch_size: size of training batches 81 | learning_rate: learning rate for training 82 | rng: numpy.random.RandomState 83 | 84 | Returns: 85 | model: The blackbox model function. 86 | predictions: The predictions tensor. 87 | accuracy: Accuracy of the model. 88 | """ 89 | 90 | # Define TF model graph (for the black-box model). 
91 | model = cnn_arch 92 | if gan: 93 | x_rec = tf.stop_gradient( 94 | gan.reconstruct(images, batch_size=batch_size)) 95 | predictions = model(x_rec) 96 | else: 97 | predictions = model(images) 98 | print("Defined TensorFlow model graph.") 99 | 100 | train_params = { 101 | 'nb_epochs': nb_epochs, 102 | 'batch_size': batch_size, 103 | 'learning_rate': learning_rate, 104 | } 105 | preds_adv = None 106 | 107 | if adv_training: 108 | 109 | fgsm_par = {'eps': FLAGS.fgsm_eps_tr, 'ord': np.inf, 'clip_min': 0., 110 | 'clip_max': 1.} 111 | if gan: 112 | if any([xx in gan.dataset_name for xx in ['celeba']]): 113 | fgsm_par['clip_min'] = -1.0 114 | fgsm_params = fgsm_par 115 | 116 | fgsm = FastGradientMethod(model, sess=sess) 117 | adv_x = fgsm.generate(images, **fgsm_params) 118 | adv_x = tf.stop_gradient(adv_x) 119 | preds_adv = model(adv_x) 120 | 121 | model_train( 122 | sess, images, labels, predictions, images_train, labels_train, 123 | args=train_params, rng=rng, predictions_adv=preds_adv, 124 | init_all=False, feed={K.learning_phase(): 1} 125 | ) 126 | 127 | # Print out the accuracy on legitimate test data. 
128 | eval_params = {'batch_size': batch_size} 129 | 130 | accuracy = model_eval( 131 | sess, images, labels, predictions, images_test, 132 | labels_test, args=eval_params, feed={K.learning_phase(): 0}, 133 | ) 134 | 135 | print( 136 | 'Test accuracy of black-box on legitimate test examples: ' + 137 | str(accuracy) 138 | ) 139 | 140 | return model, predictions, accuracy 141 | 142 | 143 | def train_sub(sess, x, y, bbox_preds, X_sub, Y_sub, nb_classes, 144 | nb_epochs_s, batch_size, learning_rate, data_aug, lmbda, 145 | rng, substitute_model=None): 146 | """This function trains the substitute model as described in 147 | arxiv.org/abs/1602.02697 148 | 149 | Args: 150 | sess: TF session 151 | x: input TF placeholder 152 | y: output TF placeholder 153 | bbox_preds: output of black-box model predictions 154 | X_sub: initial substitute training data 155 | Y_sub: initial substitute training labels 156 | nb_classes: number of output classes 157 | nb_epochs_s: number of epochs to train substitute model 158 | batch_size: size of training batches 159 | learning_rate: learning rate for training 160 | data_aug: number of times substitute training data is augmented 161 | lmbda: lambda from arxiv.org/abs/1602.02697 162 | rng: numpy.random.RandomState instance 163 | 164 | Returns: 165 | model_sub: The substitute model function. 166 | preds_sub: The substitute prediction tensor. 167 | """ 168 | # Define TF model graph (for the black-box model). 169 | model_sub = substitute_model 170 | preds_sub = model_sub(x) 171 | print("Defined TensorFlow model graph for the substitute.") 172 | 173 | # Define the Jacobian symbolically using TensorFlow. 174 | grads = jacobian_graph(preds_sub, x, nb_classes) 175 | 176 | # Train the substitute and augment dataset alternatively. 
177 | for rho in xrange(data_aug): 178 | print("Substitute training epoch #" + str(rho)) 179 | train_params = { 180 | 'nb_epochs': nb_epochs_s, 181 | 'batch_size': batch_size, 182 | 'learning_rate': learning_rate 183 | } 184 | model_train(sess, x, y, preds_sub, X_sub, to_categorical(Y_sub), 185 | init_all=False, args=train_params, 186 | rng=rng, feed={K.learning_phase(): 1}) 187 | 188 | # If we are not at last substitute training iteration, augment dataset. 189 | if rho < data_aug - 1: 190 | 191 | print("Augmenting substitute training data.") 192 | # Perform the Jacobian augmentation. 193 | X_sub = jacobian_augmentation(sess, x, X_sub, Y_sub, grads, lmbda, 194 | feed={K.learning_phase(): 0}) 195 | 196 | print("Labeling substitute training data.") 197 | # Label the newly generated synthetic points using the black-box. 198 | Y_sub = np.hstack([Y_sub, Y_sub]) 199 | X_sub_prev = X_sub[int(len(X_sub) / 2):] 200 | eval_params = {'batch_size': batch_size} 201 | 202 | # To initialize the local variables of Defense-GAN. 203 | sess.run(tf.local_variables_initializer()) 204 | 205 | bbox_val = batch_eval(sess, [x], [bbox_preds], [X_sub_prev], 206 | args=eval_params, 207 | feed={K.learning_phase(): 0})[0] 208 | # Note here that we take the argmax because the adversary 209 | # only has access to the label (not the probabilities) output 210 | # by the black-box model. 211 | Y_sub[int(len(X_sub) / 2):] = np.argmax(bbox_val, axis=1) 212 | 213 | return model_sub, preds_sub 214 | 215 | 216 | def convert_to_onehot(ys): 217 | """Converts the labels to one-hot vectors.""" 218 | max_y = int(np.max(ys)) 219 | y_one_hat = np.zeros([len(ys), max_y + 1], np.float32) 220 | for (i, y) in enumerate(ys): 221 | y_one_hat[i, int(y)] = 1.0 222 | return y_one_hat 223 | 224 | 225 | def get_celeba(data_path, test_on_dev=True, orig_data=False): 226 | """Generates the CelebA dataset from Pickle files. 227 | 228 | Args: 229 | data_path: The path to where pickles are saved. 
230 | //pickles/ 231 | test_on_dev: Test on the development set. 232 | orig_data: Original data flag. `True` for returning the original 233 | dataset. 234 | 235 | Returns: 236 | images: Images of the dataset. 237 | labels: Labels of the loaded images. 238 | """ 239 | dev_name = 'val' 240 | if not test_on_dev: 241 | dev_name = 'test' 242 | ds = CelebA(attribute=FLAGS.attribute) 243 | ds.load() 244 | ds_test = CelebA(attribute=FLAGS.attribute) 245 | ds_test.load(split=dev_name) 246 | train_labels = ds.labels 247 | test_labels = ds_test.labels 248 | 249 | def get_pickeldb(split): 250 | train_data_path = os.path.join(data_path, split, 'pickles') 251 | assert os.path.exists(train_data_path) 252 | pkl_files = os.listdir(train_data_path) 253 | pkl_labels = np.array( 254 | [int(re.findall('.*_l(\d+).pkl', pf)[0]) for pf in pkl_files], 255 | np.int32) 256 | pkl_paths = [os.path.join(train_data_path, pf) for pf in 257 | sorted(pkl_files)] 258 | pkl_ds = PickleLazyDataset(pkl_paths, [64, 64, 3]) 259 | return pkl_ds, pkl_labels 260 | 261 | if orig_data: 262 | train_images = ds.images 263 | test_images = ds_test.images 264 | else: 265 | train_images, train_labels = get_pickeldb('train') 266 | test_images, test_labels = get_pickeldb(dev_name) 267 | 268 | return train_images, convert_to_onehot(train_labels), test_images, \ 269 | convert_to_onehot(test_labels) 270 | 271 | 272 | def get_train_test(data_path, test_on_dev=True, model=None, 273 | orig_data=False, max_num=-1): 274 | """Loads the datasets. 275 | 276 | Args: 277 | data_path: The path that contains train,dev,[test] directories 278 | test_on_dev: Test on the development set 279 | model: An instance of `GAN`. 280 | orig_data: `True` for loading original data, `False` to load the 281 | reconstructed images. 282 | 283 | Returns: 284 | train_images: Training images. 285 | train_labels: Training labels. 286 | test_images: Testing images. 287 | test_labels: Testing labels. 
288 | """ 289 | 290 | data_dict = None 291 | if model and not orig_data: 292 | data_dict = model.reconstruct_dataset(max_num_load=max_num) 293 | 294 | def get_images_labels_from_pickle(data_path, split): 295 | data_path = os.path.join(data_path, split, 'feats.pkl') 296 | could_load = False 297 | try: 298 | if os.path.exists(data_path): 299 | with open(data_path) as f: 300 | train_images_gan = cPickle.load(f) 301 | train_labels_gan = cPickle.load(f) 302 | could_load = True 303 | else: 304 | print( 305 | '[!] Run python train.py --cfg --save_ds ' 306 | 'to prepare the dataset cache files.' 307 | ) 308 | exit(1) 309 | 310 | except Exception as e: 311 | print( 312 | '[!] Found feats.pkl but could not load it because {}'.format( 313 | str(e))) 314 | 315 | if not could_load and not data_dict is None: 316 | train_images_gan, train_labels_gan, train_images_orig = data_dict[ 317 | split] 318 | if orig_data: 319 | train_images_gan = train_images_orig 320 | 321 | return train_images_gan, convert_to_onehot(train_labels_gan) 322 | 323 | train_images, train_lables = \ 324 | get_images_labels_from_pickle(data_path, 'train') 325 | test_split = 'test' if test_on_dev else 'dev' 326 | test_images, test_labels = \ 327 | get_images_labels_from_pickle(data_path, test_split) 328 | 329 | return train_images, train_lables, test_images, test_labels 330 | 331 | 332 | def get_cached_gan_data(gan, test_on_dev, orig_data_flag=None): 333 | """Fetches the dataset of a GAN model. 334 | 335 | Args: 336 | gan: The GAN model. 337 | test_on_dev: `True` for loading the dev set instead of the test set. 338 | orig_data_flag: `True` for loading the original images not the 339 | reconstructions. 340 | 341 | Returns: 342 | train_images: Training images. 343 | train_labels: Training labels. 344 | test_images: Testing images. 345 | test_labels: Testing labels. 
346 | """ 347 | FLAGS = flags.FLAGS 348 | if orig_data_flag is None: 349 | if not FLAGS.train_on_recs or FLAGS.defense_type != 'defense_gan': 350 | orig_data_flag = True 351 | else: 352 | orig_data_flag = False 353 | 354 | if 'celeba' in gan.dataset_name: 355 | train_images, train_labels, test_images, test_labels = get_celeba( 356 | FLAGS.rec_path, 357 | orig_data=orig_data_flag, 358 | ) 359 | if FLAGS.num_train > 0: 360 | train_images = train_images[:FLAGS.num_train] 361 | train_labels = train_labels[:FLAGS.num_train] 362 | else: 363 | train_images, train_labels, test_images, test_labels = \ 364 | get_train_test( 365 | orig_data_path[gan.dataset_name], test_on_dev=test_on_dev, 366 | model=gan, orig_data=orig_data_flag, max_num=FLAGS.num_train) 367 | return train_images, train_labels, test_images, test_labels 368 | 369 | 370 | def blackbox(gan, rec_data_path=None, batch_size=128, 371 | learning_rate=0.001, nb_epochs=10, holdout=150, data_aug=6, 372 | nb_epochs_s=10, lmbda=0.1, online_training=False, 373 | train_on_recs=False, test_on_dev=True, 374 | defense_type='none'): 375 | """MNIST tutorial for the black-box attack from arxiv.org/abs/1602.02697 376 | 377 | Args: 378 | train_start: index of first training set example 379 | train_end: index of last training set example 380 | test_start: index of first test set example 381 | test_end: index of last test set example 382 | defense_type: Type of defense against blackbox attacks 383 | 384 | Returns: 385 | a dictionary with: 386 | * black-box model accuracy on test set 387 | * substitute model accuracy on test set 388 | * black-box model accuracy on adversarial examples transferred 389 | from the substitute model 390 | """ 391 | FLAGS = flags.FLAGS 392 | 393 | # Set logging level to see debug information. 394 | set_log_level(logging.WARNING) 395 | 396 | # Dictionary used to keep track and return key accuracies. 397 | accuracies = {} 398 | 399 | # Create TF session. 
400 | adv_training = False 401 | if defense_type: 402 | if defense_type == 'defense_gan' and gan: 403 | sess = gan.sess 404 | gan_defense_flag = True 405 | else: 406 | gan_defense_flag = False 407 | config = tf.ConfigProto() 408 | config.gpu_options.allow_growth = True 409 | sess = tf.Session(config=config) 410 | if 'adv_tr' in defense_type: 411 | adv_training = True 412 | else: 413 | gan_defense_flag = False 414 | config = tf.ConfigProto() 415 | config.gpu_options.allow_growth = True 416 | sess = tf.Session(config=config) 417 | 418 | train_images, train_labels, test_images, test_labels = \ 419 | get_cached_gan_data(gan, test_on_dev, orig_data_flag=True) 420 | 421 | x_shape, classes = list(train_images.shape[1:]), train_labels.shape[1] 422 | nb_classes = classes 423 | 424 | type_to_models = { 425 | 'A': model_a, 'B': model_b, 'C': model_c, 'D': model_d, 'E': model_e, 426 | 'F': model_f, 'Q': model_q, 'Z': model_z 427 | } 428 | 429 | bb_model = type_to_models[FLAGS.bb_model]( 430 | input_shape=[None] + x_shape, nb_classes=train_labels.shape[1], 431 | ) 432 | sub_model = type_to_models[FLAGS.sub_model]( 433 | input_shape=[None] + x_shape, nb_classes=train_labels.shape[1], 434 | ) 435 | 436 | if FLAGS.debug: 437 | train_images = train_images[:20 * batch_size] 438 | train_labels = train_labels[:20 * batch_size] 439 | debug_dir = os.path.join('debug', 'blackbox', FLAGS.debug_dir) 440 | ensure_dir(debug_dir) 441 | x_debug_test = test_images[:batch_size] 442 | 443 | # Initialize substitute training set reserved for adversary 444 | images_sub = test_images[:holdout] 445 | labels_sub = np.argmax(test_labels[:holdout], axis=1) 446 | 447 | # Redefine test set as remaining samples unavailable to adversaries 448 | if FLAGS.num_tests > 0: 449 | test_images = test_images[:FLAGS.num_tests] 450 | test_labels = test_labels[:FLAGS.num_tests] 451 | 452 | test_images = test_images[holdout:] 453 | test_labels = test_labels[holdout:] 454 | 455 | # Define input and output TF placeholders 
    # If the image dimensions flag is channel-first ([3, H, W]), convert it
    # to channel-last ([H, W, 3]) to match the data layout.
    if FLAGS.image_dim[0] == 3:
        FLAGS.image_dim = [FLAGS.image_dim[1], FLAGS.image_dim[2],
                           FLAGS.image_dim[0]]

    images_tensor = tf.placeholder(tf.float32, shape=[None] + x_shape)
    labels_tensor = tf.placeholder(tf.float32, shape=(None, classes))

    # Fixed seeds for reproducibility.
    rng = np.random.RandomState([11, 24, 1990])
    tf.set_random_seed(11241990)

    # By default the black-box model is trained/tested on the original data.
    train_images_bb, train_labels_bb, test_images_bb, test_labels_bb = \
        train_images, train_labels, test_images, \
        test_labels

    cur_gan = None

    if defense_type:
        if 'gan' in defense_type:
            # Load cached dataset reconstructions.
            if online_training and not train_on_recs:
                # Reconstruct inputs online during black-box training.
                cur_gan = gan
            elif not online_training and rec_data_path:
                # Use the pre-computed (cached) reconstructions instead.
                train_images_bb, train_labels_bb, test_images_bb, \
                    test_labels_bb = get_cached_gan_data(
                        gan, test_on_dev, orig_data_flag=False)
            else:
                assert not train_on_recs

    if FLAGS.debug:
        train_images_bb = train_images_bb[:20 * batch_size]
        train_labels_bb = train_labels_bb[:20 * batch_size]

    # Prepare the black_box model.
490 | prep_bbox_out = prep_bbox( 491 | sess, images_tensor, labels_tensor, train_images_bb, 492 | train_labels_bb, test_images_bb, test_labels_bb, nb_epochs, 493 | batch_size, learning_rate, rng=rng, gan=cur_gan, 494 | adv_training=adv_training, 495 | cnn_arch=bb_model) 496 | else: 497 | prep_bbox_out = prep_bbox(sess, images_tensor, labels_tensor, 498 | train_images_bb, train_labels_bb, 499 | test_images_bb, test_labels_bb, 500 | nb_epochs, batch_size, learning_rate, 501 | rng=rng, gan=cur_gan, 502 | adv_training=adv_training, 503 | cnn_arch=bb_model) 504 | 505 | model, bbox_preds, accuracies['bbox'] = prep_bbox_out 506 | 507 | # Train substitute using method from https://arxiv.org/abs/1602.02697 508 | print("Training the substitute model.") 509 | reconstructed_tensors = tf.stop_gradient( 510 | gan.reconstruct(images_tensor, batch_size=batch_size, 511 | reconstructor_id=1)) 512 | model_sub, preds_sub = train_sub( 513 | sess, images_tensor, labels_tensor, 514 | model(reconstructed_tensors), images_sub, 515 | labels_sub, 516 | nb_classes, nb_epochs_s, batch_size, 517 | learning_rate, data_aug, lmbda, rng=rng, 518 | substitute_model=sub_model, 519 | ) 520 | 521 | accuracies['sub'] = 0 522 | # Initialize the Fast Gradient Sign Method (FGSM) attack object. 523 | fgsm_par = { 524 | 'eps': FLAGS.fgsm_eps, 'ord': np.inf, 'clip_min': 0., 'clip_max': 1. 525 | } 526 | if gan: 527 | if gan.dataset_name == 'celeba': 528 | fgsm_par['clip_min'] = -1.0 529 | 530 | fgsm = FastGradientMethod(model_sub, sess=sess) 531 | 532 | # Craft adversarial examples using the substitute. 533 | eval_params = {'batch_size': batch_size} 534 | x_adv_sub = fgsm.generate(images_tensor, **fgsm_par) 535 | 536 | if FLAGS.debug and gan is not None: # To see some qualitative results. 
        # Qualitative debug dump: save the adversarial samples and their
        # GAN reconstructions, then return early (skips the evaluation).
        reconstructed_tensors = gan.reconstruct(x_adv_sub, batch_size=batch_size,
                                                reconstructor_id=2)

        x_rec_orig = gan.reconstruct(images_tensor, batch_size=batch_size,
                                     reconstructor_id=3)
        x_adv_sub_val = sess.run(x_adv_sub,
                                 feed_dict={images_tensor: x_debug_test,
                                            K.learning_phase(): 0})
        # Defense-GAN's reconstruction ops use TF local variables.
        sess.run(tf.local_variables_initializer())
        x_rec_debug_val, x_rec_orig_val = sess.run(
            [reconstructed_tensors, x_rec_orig],
            feed_dict={
                images_tensor: x_debug_test,
                K.learning_phase(): 0})

        save_images_files(x_adv_sub_val, output_dir=debug_dir,
                          postfix='adv')

        postfix = 'gen_rec'
        save_images_files(x_rec_debug_val, output_dir=debug_dir,
                          postfix=postfix)
        save_images_files(x_debug_test, output_dir=debug_dir,
                          postfix='orig')
        save_images_files(x_rec_orig_val, output_dir=debug_dir,
                          postfix='orig_rec')
        return

    if gan_defense_flag:
        # Defense-GAN: evaluate the black-box on reconstructions of the
        # adversarial examples; also collect the per-example reconstruction
        # error, which is used for attack detection (ROC info).
        reconstructed_tensors = gan.reconstruct(
            x_adv_sub, batch_size=batch_size, reconstructor_id=4,
        )

        num_dims = len(images_tensor.get_shape())
        avg_inds = list(range(1, num_dims))
        # Mean squared reconstruction error per example.
        diff_op = tf.reduce_mean(tf.square(x_adv_sub - reconstructed_tensors),
                                 axis=avg_inds)

        outs = model_eval_gan(sess, images_tensor, labels_tensor,
                              predictions=model(reconstructed_tensors),
                              test_images=test_images, test_labels=test_labels,
                              args=eval_params, diff_op=diff_op,
                              feed={K.learning_phase(): 0})

        accuracies['bbox_on_sub_adv_ex'] = outs[0]
        accuracies['roc_info'] = outs[1]
        print('Test accuracy of oracle on adversarial examples generated '
              'using the substitute: ' + str(outs[0]))
    else:
        # No GAN defense: evaluate the black-box directly on the
        # adversarial examples.
        accuracy = model_eval(sess, images_tensor, labels_tensor,
                              model(x_adv_sub), test_images,
                              test_labels,
                              args=eval_params, feed={K.learning_phase(): 0})
        print('Test accuracy of oracle on adversarial examples generated '
              'using the substitute: ' + str(accuracy))
        accuracies['bbox_on_sub_adv_ex'] = accuracy

    return accuracies


def _get_results_dir_filename(gan):
    # Builds the results directory and a result file name that encodes the
    # relevant hyper-parameters (substitute augmentations, FGSM epsilon,
    # reconstruction settings, model architectures, sample counts).
    result_file_name = 'sub={:d}_eps={:.2f}.txt'.format(FLAGS.data_aug,
                                                        FLAGS.fgsm_eps)

    results_dir = os.path.join('results', '{}_{}'.format(
        FLAGS.defense_type, FLAGS.dataset_name))

    if FLAGS.rec_path and FLAGS.defense_type == 'defense_gan':
        # Mirror the checkpoint directory under results/ and encode the
        # test-time reconstruction hyper-parameters in the file name.
        results_dir = gan.checkpoint_dir.replace('output', 'results')
        result_file_name = \
            'teRR={:d}_teLR={:.4f}_teIter={:d}_sub={:d}_eps={:.2f}.txt'.format(
                gan.rec_rr,
                gan.rec_lr,
                gan.rec_iters,
                FLAGS.data_aug,
                FLAGS.fgsm_eps)

        if not FLAGS.train_on_recs:
            result_file_name = 'orig_' + result_file_name
    elif FLAGS.defense_type == 'adv_tr':
        result_file_name = 'sub={:d}_trEps={:.2f}_eps={:.2f}.txt'.format(
            FLAGS.data_aug, FLAGS.fgsm_eps_tr,
            FLAGS.fgsm_eps)
    if FLAGS.num_tests > -1:
        result_file_name = 'numtest={}_'.format(
            FLAGS.num_tests) + result_file_name

    if FLAGS.num_train > -1:
        result_file_name = 'numtrain={}_'.format(
            FLAGS.num_train) + result_file_name

    result_file_name = 'bbModel={}_subModel={}_'.format(FLAGS.bb_model,
                                                        FLAGS.sub_model) \
        + result_file_name
    return results_dir, result_file_name


def main(cfg, argv=None):
    FLAGS = tf.app.flags.FLAGS
    # Pick the GAN class matching the configured dataset.
    GAN = dataset_gan_dict[FLAGS.dataset_name]

    gan = GAN(cfg=cfg, test_mode=True)
    gan.load_generator()
    # Setting test time reconstruction hyper parameters.
    [tr_rr, tr_lr, tr_iters] = [FLAGS.rec_rr, FLAGS.rec_lr, FLAGS.rec_iters]
    if FLAGS.defense_type.lower() != 'none':
        if FLAGS.rec_path and FLAGS.defense_type == 'defense_gan':

            # extract hyper parameters from reconstruction path.
645 | if FLAGS.rec_path: 646 | train_param_re = re.compile('recs_rr(.*)_lr(.*)_iters(.*)') 647 | [tr_rr, tr_lr, tr_iters] = \ 648 | train_param_re.findall(FLAGS.rec_path)[0] 649 | gan.rec_rr = int(tr_rr) 650 | gan.rec_lr = float(tr_lr) 651 | gan.rec_iters = int(tr_iters) 652 | elif FLAGS.defense_type == 'defense_gan': 653 | assert FLAGS.online_training or not FLAGS.train_on_recs 654 | 655 | if FLAGS.override: 656 | gan.rec_rr = int(tr_rr) 657 | gan.rec_lr = float(tr_lr) 658 | gan.rec_iters = int(tr_iters) 659 | 660 | # Setting the reuslts directory 661 | results_dir, result_file_name = _get_results_dir_filename(gan) 662 | 663 | # Result file name. The counter makes sure we are not overwriting the 664 | # results. 665 | counter = 0 666 | temp_fp = str(counter) + '_' + result_file_name 667 | results_dir = os.path.join(results_dir, FLAGS.results_dir) 668 | temp_final_fp = os.path.join(results_dir, temp_fp) 669 | while os.path.exists(temp_final_fp): 670 | counter += 1 671 | temp_fp = str(counter) + '_' + result_file_name 672 | temp_final_fp = os.path.join(results_dir, temp_fp) 673 | result_file_name = temp_fp 674 | sub_result_path = os.path.join(results_dir, result_file_name) 675 | 676 | accuracies = blackbox(gan, rec_data_path=FLAGS.rec_path, 677 | batch_size=FLAGS.batch_size, 678 | learning_rate=FLAGS.learning_rate, 679 | nb_epochs=FLAGS.nb_epochs, holdout=FLAGS.holdout, 680 | data_aug=FLAGS.data_aug, 681 | nb_epochs_s=FLAGS.nb_epochs_s, 682 | lmbda=FLAGS.lmbda, 683 | online_training=FLAGS.online_training, 684 | train_on_recs=FLAGS.train_on_recs, 685 | defense_type=FLAGS.defense_type) 686 | 687 | ensure_dir(results_dir) 688 | 689 | with open(sub_result_path, 'a') as f: 690 | f.writelines([str(accuracies[x]) + ' ' for x in 691 | ['bbox', 'sub', 'bbox_on_sub_adv_ex']]) 692 | f.write('\n') 693 | print('[*] saved accuracy in {}'.format(sub_result_path)) 694 | 695 | if 'roc_info' in accuracies.keys(): # For attack detection. 
696 | pkl_result_path = sub_result_path.replace('.txt', '_roc.pkl') 697 | with open(pkl_result_path, 'w') as f: 698 | cPickle.dump(accuracies['roc_info'], f, cPickle.HIGHEST_PROTOCOL) 699 | print('[*] saved roc_info in {}'.format(sub_result_path)) 700 | 701 | 702 | def parse_args(): 703 | parser = argparse.ArgumentParser() 704 | 705 | parser.add_argument('--cfg', required=True, help='Config file') 706 | 707 | if len(sys.argv) == 1: 708 | parser.print_help() 709 | sys.exit(1) 710 | args, _ = parser.parse_known_args() 711 | return args 712 | 713 | 714 | if __name__ == '__main__': 715 | args = parse_args() 716 | 717 | # Note: The load_config() call will convert all the parameters that are defined in 718 | # experiments/config files into FLAGS.param_name and can be passed in from command line. 719 | # arguments : python blackbox.py --cfg -- 720 | cfg = load_config(args.cfg) 721 | flags = tf.app.flags 722 | 723 | flags.DEFINE_integer('nb_classes', 10, 'Number of classes.') 724 | flags.DEFINE_float('learning_rate', 0.001, 'Learning rate for training ' 725 | 'the black-box model.') 726 | flags.DEFINE_integer('nb_epochs', 10, 'Number of epochs to train the ' 727 | 'blackbox model.') 728 | flags.DEFINE_integer('holdout', 150, 'Test set holdout for adversary.') 729 | flags.DEFINE_integer('data_aug', 6, 'Number of substitute data augmentations.') 730 | flags.DEFINE_integer('nb_epochs_s', 10, 'Training epochs for substitute.') 731 | flags.DEFINE_float('lmbda', 0.1, 'Lambda from arxiv.org/abs/1602.02697') 732 | flags.DEFINE_float('fgsm_eps', 0.3, 'FGSM epsilon.') 733 | flags.DEFINE_float('fgsm_eps_tr', 0.15, 'FGSM epsilon for adversarial ' 734 | 'training.') 735 | flags.DEFINE_string('rec_path', None, 'Path to Defense-GAN ' 736 | 'reconstructions.') 737 | flags.DEFINE_integer('num_tests', 2000, 'Number of test samples.') 738 | flags.DEFINE_integer('random_test_iter', -1, 739 | 'Number of random sampling for testing the ' 740 | 'classifier.') 741 | 
flags.DEFINE_boolean("online_training", False, 742 | 'Train the base classifier based on online ' 743 | 'reconstructions from Defense-GAN, as opposed to ' 744 | 'using the cached reconstructions.') 745 | flags.DEFINE_string("defense_type", "none", "Type of defense " 746 | "[defense_gan|adv_tr|none]") 747 | flags.DEFINE_string("results_dir", None, "The path to results.") 748 | flags.DEFINE_boolean("train_on_recs", False, 749 | "Train the black-box model on Defense-GAN " 750 | "reconstructions.") 751 | flags.DEFINE_integer('num_train', -1, 'Number of training samples for ' 752 | 'the black-box model.') 753 | flags.DEFINE_string("bb_model", 'F', 754 | "The architecture of the classifier model.") 755 | flags.DEFINE_string("sub_model", 'E', "The architecture of the " 756 | "substitute model.") 757 | flags.DEFINE_string("debug_dir", None, "Directory for debug outputs.") 758 | flags.DEFINE_boolean("debug", None, "Directory for debug outputs.") 759 | flags.DEFINE_boolean("override", None, "Overrides the test hyperparams.") 760 | 761 | main_cfg = lambda x: main(cfg, x) 762 | tf.app.run(main=main_cfg) 763 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | ## Datasets 2 | 3 | The datasets are handled in this module. 4 | 5 | There is a base Dataset class that is implemented for each dataset: 6 | - MNIST 7 | - Fashion-MNIST 8 | - CelebA 9 | 10 | The factory.py is responsible for generating the appropriate class based 11 | on the set flags. 
-------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kabkabm/defensegan/7e3feaebf7b9bbf08b1364e400119ef596cd78fd/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/celeba.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Defense-GAN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Contains the class for handling the CelebA dataset.""" 17 | 18 | import os 19 | 20 | import numpy as np 21 | 22 | from datasets.dataset import Dataset, LazyDataset 23 | 24 | 25 | class CelebA(Dataset): 26 | """CelebA class implementing Dataset.""" 27 | 28 | def __init__(self, center_crop_size=108, resize_size=64, attribute=None): 29 | """CelebA constructor. 30 | 31 | Args: 32 | center_crop_size: An integer defining the center crop square 33 | dimensions. 34 | resize_size: An integer for the final size of the cropped image. 35 | attribute: A string which is the attribute name according to the 36 | CelebA's label file header. 
37 | """ 38 | 39 | super(CelebA, self).__init__('celebA') 40 | self.y_dim = 0 41 | self.split_data = {} 42 | self.image_size = center_crop_size 43 | self.resize_size = resize_size 44 | # The attribute represents which attribute to use in case of 45 | # classification. 46 | self.attribute = attribute 47 | # Only gender classification is supported. 48 | self.attr_dict = {'gender': ['male']} 49 | 50 | def load(self, split='train', lazy=True, randomize=False): 51 | """Loads the dataset according to split. 52 | 53 | Args: 54 | split: A string [train|val|test] referring to the dataset split. 55 | lazy (optional): If True, only loads the file paths and creates a 56 | LazyDataset object (default True). 57 | randomize (optional): `True` will randomize the data. 58 | 59 | Returns: 60 | A LazyDataset (if lazy is True) or a numpy array containing all 61 | the images, labels, and image ids. 62 | 63 | Raises: 64 | ValueError: If split is not one of [train|val|test]. 65 | """ 66 | 67 | attribute = self.attribute 68 | 69 | # If split data has already been loaded, return it. 70 | if split in self.split_data.keys(): 71 | return self.split_data[split] 72 | 73 | # Start and end indices of different CelebA splits. 74 | if split == 'train': 75 | start = 1 76 | end = 162770 77 | elif split == 'val': 78 | start = 162771 79 | end = 182637 80 | elif split == 'test': 81 | start = 182638 82 | end = 202599 83 | else: 84 | raise ValueError('[!] Invalid split {}.'.format(split)) 85 | 86 | data_dir = self.data_dir 87 | 88 | # Lazy dataset loading. 89 | fps = [os.path.join(data_dir, '{:06d}.jpg'.format(i)) for i in 90 | range(start, end + 1)] 91 | 92 | if randomize: 93 | rng_state = np.random.get_state() 94 | np.random.shuffle(fps) 95 | np.random.set_state(rng_state) 96 | 97 | images = LazyDataset(fps, self.image_size, self.resize_size) 98 | # Access images if not lazy. 99 | if not lazy: 100 | images = images[:len(images)] 101 | 102 | if attribute is None: # No class information needed. 
103 | labels = None # Labels set to None. 104 | ids = np.array(range(0, end - start + 1), 105 | dtype=int) # All indices are valid. 106 | else: 107 | # If attribute is valid. 108 | if self.attr_dict.has_key(attribute): 109 | # Get list of classes to consider. 110 | attr_list = self.attr_dict[attribute] 111 | with open(os.path.join(self.data_dir, 'list_attr_celeba.txt'), 112 | 'r') as f: 113 | flines = f.readlines() 114 | class_names = [s.lower().replace(' ', '_') for s in 115 | flines[1].strip().split()] 116 | # Get indices of relevant columns. 117 | cols = [i for i, x in enumerate(class_names) if 118 | x in attr_list] 119 | cols = np.asarray(cols, dtype=int) 120 | face_attributes = [[int(x) for x in l.split()[1:]] for l in 121 | [ll.strip() for ll in flines[2:]]] 122 | face_attributes = (np.asarray(face_attributes, 123 | dtype=int) + 1) // 2 124 | face_attributes = face_attributes[start - 1:end, cols] 125 | labels = face_attributes.reshape(-1) 126 | else: 127 | raise ValueError( 128 | '[!] Invalid attribute {} for CelebA dataset.'.format( 129 | attribute)) 130 | 131 | self.split_data[split] = [images, labels] 132 | 133 | self.images = images 134 | self.labels = labels 135 | 136 | return images, labels 137 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Defense-GAN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Contains the classes:
    Dataset: All datasets used in the project implement this class.
    LazyDataset: A class for loading data in a lazy manner from file paths.
    LazyPickledDataset: A class for loading pickled data from filepaths.

defined here."""

import cPickle
import os

import numpy as np
import scipy
import scipy.misc


class Dataset(object):
    """The abstract class for handling datasets.

    Attributes:
        name: Name of the dataset.
        data_dir: The directory where the dataset resides.
    """

    def __init__(self, name, data_dir='./data'):
        """The dataset default constructor.

        Args:
            name: A string, name of the dataset.
            data_dir (optional): The path of the datasets on disk.
        """

        # Each dataset lives in its own subdirectory: <data_dir>/<name>.
        self.data_dir = os.path.join(data_dir, name)
        self.name = name
        # Populated by the subclasses' load() implementations.
        self.images = None
        self.labels = None

    def __len__(self):
        """Gives the number of images in the dataset.

        Returns:
            Number of images in the dataset.
        """

        return len(self.images)

    def load(self, split):
        """Abstract function specific to each dataset."""
        pass


class LazyDataset(object):
    """The Lazy Dataset class.
    Instead of loading the whole dataset into memory, this class loads
    images only when their index is accessed.

    Attributes:
        fps: String list of file paths.
        center_crop_dim: An integer for the size of center crop (after
            loading the images).
        resize_size: The final resize size (after loading the images).
    """

    def __init__(self, filepaths, center_crop_dim, resize_size,
                 transform_type=None):
        """LazyDataset constructor.

        Args:
            filepaths: File paths.
            center_crop_dim: The dimension of the center cropped square.
            resize_size: Final size to resize the center crop of the images.
            transform_type (optional): Stored but not used in this class --
                TODO confirm intended use by subclasses/callers.
        """

        self.filepaths = filepaths
        self.center_crop_dim = center_crop_dim
        self.resize_size = resize_size
        self.transform_type = transform_type

    def _get_image(self, image_path):
        """Retrieves an image at a given path and resizes it to the
        specified size.

        Args:
            image_path: Path to image.

        Returns:
            Loaded and transformed image.
        """

        # Read image at image_path.
        # NOTE(review): scipy.misc.imread is deprecated and removed in
        # modern SciPy; migrate to imageio.imread when upgrading.
        image = scipy.misc.imread(image_path).astype(np.float)

        # Center-crop to center_crop_dim, then resize to resize_size.
        return _prepare_image(image, self.center_crop_dim,
                              self.center_crop_dim,
                              resize_height=self.resize_size,
                              resize_width=self.resize_size,
                              is_crop=True)

    def __len__(self):
        """Gives the number of images in the dataset.

        Returns:
            Number of images in the dataset.
        """

        return len(self.filepaths)

    def __getitem__(self, index):
        """Loads and returns images specified by index.

        Args:
            index: Indices of images to load.

        Returns:
            Loaded images.

        Raises:
            TypeError: If index is neither of: int, slice, np.ndarray.
        """

        # Case of a single integer index.
        if isinstance(index, int):
            return self._get_image(self.filepaths[index])
        # Case of a slice or array of indices.
140 | elif isinstance(index, slice): 141 | if isinstance(index, slice): 142 | if index.start is None: 143 | index = range(index.stop) 144 | elif index.step is None: 145 | index = range(index.start, index.stop) 146 | else: 147 | index = range(index.start, index.stop, index.step) 148 | return np.array( 149 | [self._get_image(self.filepaths[i]) for i in index] 150 | ) 151 | else: 152 | try: 153 | inds = [int(i) for i in index] 154 | return np.array( 155 | [self._get_image(self.filepaths[i]) for i in inds] 156 | ) 157 | except TypeError: 158 | raise TypeError("Index must be an integer, a slice, a container or an integer generator.") 159 | 160 | def get_subset(self, indices): 161 | """Gets a subset of the images 162 | 163 | Args: 164 | indices: The indices of the images that are needed. It's like 165 | lazy indexing without loading. 166 | 167 | Raises: 168 | TypeError if index is not a slice. 169 | """ 170 | if isinstance(indices, int): 171 | self.filepaths = self.filepaths[indices] 172 | elif isinstance(indices, slice) or isinstance(indices, np.ndarray): 173 | self.filepaths = [self.filepaths[i] for i in indices] 174 | else: 175 | raise TypeError("Index must be an integer or a slice.") 176 | 177 | @property 178 | def shape(self): 179 | return tuple([None] + list(self._get_image(self.filepaths[0]).shape)) 180 | 181 | @property 182 | def dtype(self): 183 | return self._get_image(self.filepaths[0]).dtype 184 | 185 | 186 | class PickleLazyDataset(LazyDataset): 187 | """This dataset is a lazy dataset for working with saved pickle files 188 | (of typically generated images) on disk without loading them. 189 | """ 190 | 191 | def __init__(self, filepaths, shape=None): 192 | """The constructor for instances of this class. 193 | 194 | Args: 195 | filepaths: List of strings. The list of file paths. 196 | shape (optional): Shape of the loaded images in case the images 197 | are saved as a vector. 
198 | """ 199 | self.filepaths = filepaths 200 | self.image_shape = shape 201 | 202 | def __len__(self): 203 | return len(self.filepaths) 204 | 205 | def _get_image(self, filepath): 206 | with open(filepath) as f: 207 | return cPickle.load(f).reshape(self.image_shape) 208 | 209 | @property 210 | def shape(self): 211 | im = self.__getitem__(0) 212 | return [len(self.filepaths)] + list(im.shape) 213 | 214 | 215 | def _prepare_image(image, crop_height, crop_width, resize_height=64, 216 | resize_width=64, is_crop=True): 217 | """Prepares an image by first applying an optional center 218 | crop, then resizing it. 219 | 220 | Args: 221 | image: Input image. 222 | crop_height: The height of the crop. 223 | crop_width: The width of the crop. 224 | resize_height: The resize height after cropping. 225 | resize_width: The resize width after cropping. 226 | is_crop: If True, first apply a center crop. 227 | 228 | Returns: 229 | The cropped and resized image. 230 | """ 231 | 232 | def center_crop(image, crop_h, crop_w, resize_h=64, resize_w=64): 233 | """Performs a center crop followed by a resize. 234 | 235 | Args: 236 | image: Image of type np.ndarray 237 | crop_h: The height of the crop. 238 | crop_w: The width of the crop. 239 | resize_h: The resize height after cropping. 240 | resize_w: The resize width after cropping. 241 | 242 | Returns: 243 | The cropped and resized image of type np.ndarray. 244 | """ 245 | if crop_w is None: 246 | crop_w = crop_h 247 | h, w = image.shape[:2] 248 | j = int(round((h - crop_h) / 2.)) 249 | i = int(round((w - crop_w) / 2.)) 250 | # Crop then resize. 251 | return scipy.misc.imresize(image[j:j + crop_h, i:i + crop_w], 252 | [resize_h, resize_w]) 253 | 254 | # Optionally crop the image. Then resize it. 
# --- datasets/dataset.py continued (original lines 255-262) ---
# Body of _prepare_image: crop+resize, or a plain resize when is_crop is False.
255 | if is_crop: 256 | cropped_image = center_crop(image, crop_height, crop_width, 257 | resize_height, resize_width) 258 | else: 259 | cropped_image = scipy.misc.imresize(image, [resize_height, 260 | resize_width]) 261 | return cropped_image 262 | -------------------------------------------------------------------------------- /datasets/fmnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Defense-GAN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | 16 | """Contains the class for handling the F-MNIST dataset.""" 17 | 18 | import os 19 | 20 | import numpy as np 21 | 22 | from datasets.dataset import Dataset 23 | 24 | 
# FMnist: Dataset subclass for Fashion-MNIST (10 classes, 28x28x1 images).
25 | class FMnist(Dataset): 26 | """Implements the Dataset class to handle F-MNIST. 27 | 28 | Attributes: 29 | y_dim: The dimension of label vectors (number of classes). 30 | split_data: A dictionary of 31 | { 32 | 'train': Images of np.ndarray, Int array of labels, and int 33 | array of ids. 34 | 'val': Images of np.ndarray, Int array of labels, and int 35 | array of ids. 36 | 'test': Images of np.ndarray, Int array of labels, and int 37 | array of ids.
# --- datasets/fmnist.py continued (original lines 38-95) ---
38 | } 39 | """ 40 | 41 | def __init__(self): 42 | """F-MNIST Constructor.""" 43 | 44 | super(FMnist, self).__init__('f-mnist') 45 | self.y_dim = 10 46 | self.split_data = {} 47 | 
# load reads the raw idx-ubyte files, carves train/val out of the 60k training
# set (50k/10k), and caches the result in self.split_data.
# NOTE(review): the idx files are opened in text mode before np.fromfile — fine
# on Python 2 / Linux, but Python 3 requires open(..., 'rb'); TODO confirm.
48 | def load(self, split='train', lazy=False, randomize=True): 49 | """Implements the load function. 50 | 51 | Args: 52 | split: Dataset split, can be [train|val|test], default: train. 53 | lazy: Not used for F-MNIST. 54 | 55 | Returns: 56 | Images of np.ndarray, Int array of labels, and int array of ids. 57 | 58 | Raises: 59 | ValueError: If split is not one of [train|val|test]. 60 | """ 61 | 62 | if split in self.split_data.keys(): 63 | return self.split_data[split] 64 | 65 | data_dir = self.data_dir 66 | 67 | fd = open(os.path.join(data_dir, 'train-images-idx3-ubyte')) 68 | loaded = np.fromfile(file=fd, dtype=np.uint8) 69 | train_images = loaded[16:].reshape((60000, 28, 28, 1)).astype(np.float) 70 | 71 | fd = open(os.path.join(data_dir, 'train-labels-idx1-ubyte')) 72 | loaded = np.fromfile(file=fd, dtype=np.uint8) 73 | train_labels = loaded[8:].reshape((60000)).astype(np.float) 74 | 75 | fd = open(os.path.join(data_dir, 't10k-images-idx3-ubyte')) 76 | loaded = np.fromfile(file=fd, dtype=np.uint8) 77 | test_images = loaded[16:].reshape((10000, 28, 28, 1)).astype(np.float) 78 | 79 | fd = open(os.path.join(data_dir, 't10k-labels-idx1-ubyte')) 80 | loaded = np.fromfile(file=fd, dtype=np.uint8) 81 | test_labels = loaded[8:].reshape((10000)).astype(np.float) 82 | 83 | train_labels = np.asarray(train_labels) 84 | test_labels = np.asarray(test_labels) 85 | if split == 'train': 86 | images = train_images[:50000] 87 | labels = train_labels[:50000] 88 | elif split == 'val': 89 | images = train_images[50000:60000] 90 | labels = train_labels[50000:60000] 91 | elif split == 'test': 92 | images = test_images 93 | labels = test_labels 94 | else: 95 | raise ValueError('[!] 
# --- datasets/fmnist.py continued (96-107), then datasets/mnist.py header ---
# Shuffling images and labels under the same saved RNG state keeps the two
# arrays aligned (identical permutation applied to both).
Invalid split {}.'.format(split)) 96 | 97 | if randomize: 98 | rng_state = np.random.get_state() 99 | np.random.shuffle(images) 100 | np.random.set_state(rng_state) 101 | np.random.shuffle(labels) 102 | images = np.reshape(images, [-1, 28, 28, 1]) 
# NOTE(review): split_data caches only [images, labels] even though the class
# docstring mentions an ids array — doc and code disagree; verify which is right.
103 | self.split_data[split] = [images, labels] 104 | self.images = images 105 | self.labels = labels 106 | return images, labels 107 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Defense-GAN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | 16 | """Contains the class for handling the MNIST dataset.""" 17 | 18 | import os 19 | 20 | import numpy as np 21 | 22 | from datasets.dataset import Dataset 23 | 24 | 
# Mnist: Dataset subclass for MNIST; mirrors FMnist above.
25 | class Mnist(Dataset): 26 | """Implements the Dataset class to handle MNIST. 27 | 28 | Attributes: 29 | y_dim: The dimension of label vectors (number of classes). 30 | split_data: A dictionary of 31 | { 32 | 'train': Images of np.ndarray, Int array of labels, and int 33 | array of ids. 34 | 'val': Images of np.ndarray, Int array of labels, and int 35 | array of ids. 36 | 'test': Images of np.ndarray, Int array of labels, and int 37 | array of ids.
38 | } 39 | """ 40 | 41 | def __init__(self): 42 | super(Mnist, self).__init__('mnist') 43 | self.y_dim = 10 44 | self.split_data = {} 45 | 46 | def load(self, split='train', lazy=True, randomize=True): 47 | """Implements the load function. 48 | 49 | Args: 50 | split: Dataset split, can be [train|dev|test], default: train. 51 | lazy: Not used for MNIST. 52 | 53 | Returns: 54 | Images of np.ndarray, Int array of labels, and int array of ids. 55 | 56 | Raises: 57 | ValueError: If split is not one of [train|val|test]. 58 | """ 59 | 60 | if split in self.split_data.keys(): 61 | return self.split_data[split] 62 | 63 | data_dir = self.data_dir 64 | 65 | fd = open(os.path.join(data_dir, 'train-images-idx3-ubyte')) 66 | loaded = np.fromfile(file=fd, dtype=np.uint8) 67 | train_images = loaded[16:].reshape((60000, 28, 28, 1)).astype(np.float) 68 | 69 | fd = open(os.path.join(data_dir, 'train-labels-idx1-ubyte')) 70 | loaded = np.fromfile(file=fd, dtype=np.uint8) 71 | train_labels = loaded[8:].reshape((60000)).astype(np.float) 72 | 73 | fd = open(os.path.join(data_dir, 't10k-images-idx3-ubyte')) 74 | loaded = np.fromfile(file=fd, dtype=np.uint8) 75 | test_images = loaded[16:].reshape((10000, 28, 28, 1)).astype(np.float) 76 | 77 | fd = open(os.path.join(data_dir, 't10k-labels-idx1-ubyte')) 78 | loaded = np.fromfile(file=fd, dtype=np.uint8) 79 | test_labels = loaded[8:].reshape((10000)).astype(np.float) 80 | 81 | train_labels = np.asarray(train_labels) 82 | test_labels = np.asarray(test_labels) 83 | if split == 'train': 84 | images = train_images[:50000] 85 | labels = train_labels[:50000] 86 | elif split == 'val': 87 | images = train_images[50000:60000] 88 | labels = train_labels[50000:60000] 89 | elif split == 'test': 90 | images = test_images 91 | labels = test_labels 92 | 93 | if randomize: 94 | rng_state = np.random.get_state() 95 | np.random.shuffle(images) 96 | np.random.set_state(rng_state) 97 | np.random.shuffle(labels) 98 | images = np.reshape(images, [-1, 28, 28, 1]) 
# --- datasets/mnist.py continued (99-103), then datasets/utils.py header ---
99 | self.split_data[split] = [images, labels] 100 | self.images = images 101 | self.labels = labels 102 | 103 | return images, labels 104 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Defense-GAN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | 16 | """Data handling related utilities live in this module.""" 17 | 18 | import tensorflow as tf 19 | 20 | from datasets.celeba import CelebA 21 | from datasets.fmnist import FMnist 22 | from datasets.mnist import Mnist 23 | 24 | 
# create_generator: loads the requested dataset split eagerly, then returns a
# zero-argument factory that yields (image_batch, label_batch) pairs.
25 | def create_generator(dataset_name, split, batch_size, randomize, 26 | attribute=None): 27 | """Creates a batch generator for the dataset. 28 | 29 | Args: 30 | dataset_name: `str`. The name of the dataset. 31 | split: `str`. The split of data. It can be `train`, `val`, or `test`. 32 | batch_size: An integer. The batch size. 33 | randomize: `bool`. Whether to randomize the order of images before 34 | batching. 35 | attribute (optional): For celeba, the attribute whose values are 36 | used as labels. 37 | 38 | Returns: 39 | image_batch: A Python generator for the images. 40 | label_batch: A Python generator for the labels.
40 | """ 41 | flags = tf.app.flags.FLAGS 42 | 43 | if dataset_name.lower() == 'mnist': 44 | ds = Mnist() 45 | elif dataset_name.lower() == 'f-mnist': 46 | ds = FMnist() 47 | elif dataset_name.lower() == 'celeba': 48 | ds = CelebA(attribute=attribute) 49 | else: 50 | raise ValueError("Dataset {} is not supported.".format(dataset_name)) 51 | 52 | ds.load(split=split, randomize=randomize) 53 | 54 | def get_gen(): 55 | for i in range(0, len(ds) - batch_size, batch_size): 56 | image_batch, label_batch = ds.images[ 57 | i:i + batch_size], \ 58 | ds.labels[i:i + batch_size] 59 | yield image_batch, label_batch 60 | 61 | return get_gen 62 | 63 | 64 | def get_generators(dataset_name, batch_size, randomize=True, attribute='gender'): 65 | """Creates batch generators for datasets. 66 | 67 | Args: 68 | dataset_name: A `string`. Name of the dataset. 69 | batch_size: An `integer`. The size of each batch. 70 | randomize: A `boolean`. 71 | attribute: A `string`. If the dataset name is `celeba`, this will 72 | indicate the attribute name that labels should be returned for. 73 | 74 | Returns: 75 | Training, validation, and test dataset generators which are the 76 | return values of `create_generator`. 
# --- datasets/utils.py continued (77-87), then download_dataset.py header ---
77 | """ 78 | splits = ['train', 'val', 'test'] 79 | gens = [] 80 | for i in range(3): 81 | if i > 0: 82 | randomize = False 83 | gens.append( 84 | create_generator(dataset_name, splits[i], batch_size, randomize, 85 | attribute=attribute)) 86 | 87 | return gens -------------------------------------------------------------------------------- /download_dataset.py: -------------------------------------------------------------------------------- 1 | """Modification of https://github.com/stanfordnlp/treelstm/blob/master/scripts/download.py 2 | 3 | Downloads the following: 4 | - Celeb-A dataset 5 | - FMNIST dataset 6 | - MNIST dataset 7 | """ 8 | 9 | from __future__ import print_function 10 | import os 11 | import sys 12 | import json 13 | import zipfile 14 | import argparse 15 | import requests 16 | import subprocess 17 | from tqdm import tqdm 18 | from six.moves import urllib 19 | 
# NOTE(review): the help text advertises "fmnist" but the accepted choice is
# 'f-mnist' — the help string should be brought in line with `choices`.
20 | parser = argparse.ArgumentParser(description='Download dataset for DCGAN.') 21 | parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist','f-mnist'], 22 | help='name of dataset to download [celebA, lsun, mnist, fmnist]') 23 | 
# download: stream a URL to dirpath with a hand-rolled console progress bar.
24 | def download(url, dirpath): 25 | filename = url.split('/')[-1] 26 | filepath = os.path.join(dirpath, filename) 27 | u = urllib.request.urlopen(url) 28 | f = open(filepath, 'wb') 29 | filesize = int(u.headers["Content-Length"]) 30 | print("Downloading: %s Bytes: %s" % (filename, filesize)) 31 | 32 | downloaded = 0 33 | block_sz = 8192 34 | status_width = 70 35 | while True: 36 | buf = u.read(block_sz) 37 | if not buf: 38 | print('') 39 | break 40 | else: 41 | print('', end='\r') 42 | downloaded += len(buf) 43 | f.write(buf) 44 | status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") % 45 | ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. 
# --- download_dataset.py continued (original lines 45-104) ---
/ filesize)) 46 | print(status, end='') 47 | sys.stdout.flush() 48 | f.close() 49 | return filepath 50 | 
# Google Drive download: a first GET may return a confirmation token (large
# files); a second GET with the token streams the real payload.
# SECURITY NOTE(review): verify=False disables TLS certificate verification for
# both requests — acceptable only for this known public archive; flagging it.
# NOTE(review): the parameter `id` shadows the builtin of the same name.
51 | def download_file_from_google_drive(id, destination): 52 | URL = "https://docs.google.com/uc?export=download" 53 | session = requests.Session() 54 | 55 | response = session.get(URL, params={ 'id': id }, stream=True,verify=False) 56 | token = get_confirm_token(response) 57 | 58 | if token: 59 | params = { 'id' : id, 'confirm' : token } 60 | response = session.get(URL, params=params, stream=True,verify=False) 61 | 62 | save_response_content(response, destination) 63 | 64 | def get_confirm_token(response): 65 | for key, value in response.cookies.items(): 66 | if key.startswith('download_warning'): 67 | return value 68 | return None 69 | 
# NOTE(review): tqdm's `total` is the byte count but the loop iterates chunks
# of chunk_size bytes, so the progress fraction is off by a factor of
# chunk_size — presumably total should be total_size // chunk_size; verify.
70 | def save_response_content(response, destination, chunk_size=32*1024): 71 | total_size = int(response.headers.get('content-length', 0)) 72 | with open(destination, "wb") as f: 73 | for chunk in tqdm(response.iter_content(chunk_size), total=total_size, 74 | unit='B', unit_scale=True, desc=destination): 75 | if chunk: # filter out keep-alive new chunks 76 | f.write(chunk) 77 | 78 | def unzip(filepath): 79 | print("Extracting: " + filepath) 80 | dirpath = os.path.dirname(filepath) 81 | with zipfile.ZipFile(filepath) as zf: 82 | zf.extractall(dirpath) 83 | os.remove(filepath) 84 | 
# download_celeb_a: fetch the aligned CelebA zip from Drive, extract, and
# rename the extracted folder to data/celebA. Skips work already done.
85 | def download_celeb_a(dirpath): 86 | data_dir = 'celebA' 87 | if os.path.exists(os.path.join(dirpath, data_dir)): 88 | print('Found Celeb-A - skip') 89 | return 90 | 91 | filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" 92 | save_path = os.path.join(dirpath, filename) 93 | 94 | if os.path.exists(save_path): 95 | print('[*] {} already exists'.format(save_path)) 96 | else: 97 | download_file_from_google_drive(drive_id, save_path) 98 | 99 | zip_dir = '' 100 | with zipfile.ZipFile(save_path) as zf: 101 | zip_dir = zf.namelist()[0] 102 | zf.extractall(dirpath) 103 | os.remove(save_path) 104 | os.rename(os.path.join(dirpath, 
# --- download_dataset.py continued (original lines 104-161) ---
zip_dir), os.path.join(dirpath, data_dir)) 105 | 
# LSUN helpers: build the princeton download URL and shell out to curl.
# Subprocess calls use list form (shell=False), so the URLs are not shell-interpreted.
106 | def _list_categories(tag): 107 | url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag 108 | f = urllib.request.urlopen(url) 109 | return json.loads(f.read()) 110 | 111 | def _download_lsun(out_dir, category, set_name, tag): 112 | url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \ 113 | '&category={category}&set={set_name}'.format(**locals()) 114 | print(url) 115 | if set_name == 'test': 116 | out_name = 'test_lmdb.zip' 117 | else: 118 | out_name = '{category}_{set_name}_lmdb.zip'.format(**locals()) 119 | out_path = os.path.join(out_dir, out_name) 120 | cmd = ['curl', url, '-o', out_path] 121 | print('Downloading', category, set_name, 'set') 122 | subprocess.call(cmd) 123 | 124 | def download_lsun(dirpath): 125 | data_dir = os.path.join(dirpath, 'lsun') 126 | if os.path.exists(data_dir): 127 | print('Found LSUN - skip') 128 | return 129 | else: 130 | os.mkdir(data_dir) 131 | 132 | tag = 'latest' 133 | categories = ['bedroom'] 134 | 135 | for category in categories: 136 | _download_lsun(data_dir, category, 'train', tag) 137 | _download_lsun(data_dir, category, 'val', tag) 138 | _download_lsun(data_dir, '', 'test', tag) 139 | 
# download_mnist: curl the four idx-ubyte archives and gunzip them in place,
# producing exactly the filenames datasets/mnist.py expects.
140 | def download_mnist(dirpath): 141 | data_dir = os.path.join(dirpath, 'mnist') 142 | if os.path.exists(data_dir): 143 | print('Found MNIST - skip') 144 | return 145 | else: 146 | os.mkdir(data_dir) 147 | url_base = 'http://yann.lecun.com/exdb/mnist/' 148 | file_names = ['train-images-idx3-ubyte.gz', 149 | 'train-labels-idx1-ubyte.gz', 150 | 't10k-images-idx3-ubyte.gz', 151 | 't10k-labels-idx1-ubyte.gz'] 152 | for file_name in file_names: 153 | url = (url_base+file_name).format(**locals()) 154 | print(url) 155 | out_path = os.path.join(data_dir,file_name) 156 | cmd = ['curl', url, '-o', out_path] 157 | print('Downloading ', file_name) 158 | subprocess.call(cmd) 159 | cmd = ['gzip', '-d', out_path] 160 | print('Decompressing ', file_name) 161 | 
subprocess.call(cmd) 162 | 163 | 164 | def download_fmnist(dirpath): 165 | data_dir = os.path.join(dirpath, 'f-mnist') 166 | if os.path.exists(data_dir): 167 | print('Found F-MNIST - skip') 168 | return 169 | else: 170 | os.mkdir(data_dir) 171 | url_base = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' 172 | file_names = ['train-images-idx3-ubyte.gz', 173 | 'train-labels-idx1-ubyte.gz', 174 | 't10k-images-idx3-ubyte.gz', 175 | 't10k-labels-idx1-ubyte.gz'] 176 | for file_name in file_names: 177 | url = (url_base+file_name).format(**locals()) 178 | print(url) 179 | out_path = os.path.join(data_dir,file_name) 180 | cmd = ['curl', url, '-o', out_path] 181 | print('Downloading ', file_name) 182 | subprocess.call(cmd) 183 | cmd = ['gzip', '-d', out_path] 184 | print('Decompressing ', file_name) 185 | subprocess.call(cmd) 186 | 187 | def prepare_data_dir(path = './data'): 188 | if not os.path.exists(path): 189 | os.mkdir(path) 190 | 191 | if __name__ == '__main__': 192 | args = parser.parse_args() 193 | prepare_data_dir() 194 | 195 | if any(name in args.datasets for name in ['CelebA', 'celebA', 'celebA']): 196 | download_celeb_a('./data') 197 | if 'mnist' in args.datasets: 198 | download_mnist('./data') 199 | if 'fmnist' in args.datasets: 200 | download_fmnist('./data') 201 | -------------------------------------------------------------------------------- /experiments/cfgs/gans/celeba.yml: -------------------------------------------------------------------------------- 1 | DATASET_NAME: 'celeba' 2 | ARCH_TYPE: 'celeba' 3 | MODE: 'wgan-gp' 4 | CRITIC_ITERS: 5 5 | REC_ITERS: 200 6 | REC_LR: 10.0 7 | REC_RR: 2 8 | IMAGE_DIM: [64,64,3] 9 | ATTRIBUTE: gender 10 | NET_DIM: 64 -------------------------------------------------------------------------------- /experiments/cfgs/gans/default.yml: -------------------------------------------------------------------------------- 1 | MODE: 'wgan-gp' 2 | BATCH_SIZE: 50 3 | USE_BN: False 4 | LATENT_DIM: 128 5 | 
# --- experiments/cfgs/gans/*.yml continued ---
GRADIENT_PENALTY_LAMBDA: 10.0 6 | OUTPUT_DIR: output 7 | NET_DIM: 64 8 | TRAIN_ITERS: 200000 9 | DISC_LAMBDA: 0.0 10 | TV_LAMBDA: 0.0 11 | ATTRIBUTE: 12 | TEST_BATCH_SIZE: 20 13 | NUM_GPUS: 1 14 | ENC_TRAIN_ITERS: 100000 15 | INPUT_TRANSFORM_TYPE: 0 16 | -------------------------------------------------------------------------------- /experiments/cfgs/gans/fmnist.yml: 
# FIX: both mnist.yml and fmnist.yml spelled the key INPUR_TRANSFORM_TYPE.
# default.yml and key_doc.yml both use INPUT_TRANSFORM_TYPE, so the misspelled
# key was dead and the intended override (1 = normalize input to [0,1]) was
# silently ignored in favour of the default (0 = [-1,1]).
-------------------------------------------------------------------------------- 1 | DATASET_NAME: 'f-mnist' 2 | ARCH_TYPE: 'f-mnist' 3 | MODE: 'wgan-gp' 4 | CRITIC_ITERS: 5 5 | REC_ITERS: 200 6 | REC_LR: 10.0 7 | REC_RR: 10 8 | IMAGE_DIM: [28,28,1] 9 | INPUT_TRANSFORM_TYPE: 1 -------------------------------------------------------------------------------- /experiments/cfgs/gans/mnist.yml: -------------------------------------------------------------------------------- 1 | DATASET_NAME: 'mnist' 2 | ARCH_TYPE: 'mnist' 3 | MODE: 'wgan-gp' 4 | CRITIC_ITERS: 5 5 | REC_ITERS: 200 6 | REC_LR: 10.0 7 | REC_RR: 10 8 | IMAGE_DIM: [28,28,1] 9 | INPUT_TRANSFORM_TYPE: 1 -------------------------------------------------------------------------------- /experiments/cfgs/key_doc.yml: -------------------------------------------------------------------------------- 1 | BATCH_SIZE: 'batch size [64]' 2 | MODEL_DIM: 'The width of the generator and discriminator will be proportional to this' 3 | CRITIC_ITERS: 'The number of critic iterations per generator iterations' 4 | ITERS: 'The number of iterations' 5 | OUTPUT_DIM: 'The dimension of the output of the generator' 6 | DATASET_NAME: 'Name of the dataset' 7 | ARCH_TYPE: 'Architecture type' 8 | GRADIENT_PENALTY_LAMBDA: 'Gradient penalty weight' 9 | LATENT_DIM: 'Dimension of latent variables' 10 | INPUT_TRANSFORM_TYPE: 'The way to normalize the input 0:[-1,1] 1:[0,1] 2:mean subtraction' 11 | REC_ITERS: 'Number of iterations for reconstruction step' 12 | IMAGE_DIM: 'The image dimension [width,height,channels]' 13 | DISC_LAMBDA: 'The lambda for discriminator loss in 
reconstruction' 14 | TV_LAMBDA: 'TV denoising lambda for reconstruction' 15 | ATTRIBUTE: 'For celeba dataset' 16 | NUM_GPUS: 'number of gpus' -------------------------------------------------------------------------------- /figures/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kabkabm/defensegan/7e3feaebf7b9bbf08b1364e400119ef596cd78fd/figures/.DS_Store -------------------------------------------------------------------------------- /figures/defensegan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kabkabm/defensegan/7e3feaebf7b9bbf08b1364e400119ef596cd78fd/figures/defensegan.png -------------------------------------------------------------------------------- /figures/defensegan_gd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kabkabm/defensegan/7e3feaebf7b9bbf08b1364e400119ef596cd78fd/figures/defensegan_gd.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kabkabm/defensegan/7e3feaebf7b9bbf08b1364e400119ef596cd78fd/models/__init__.py -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Defense-GAN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# --- models/base_model.py (original lines 5-49): license, imports, and the
# head of AbstractModel.__init__ ---
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | 16 | """Contains the abstract class for models.""" 17 | 18 | import os 19 | 20 | import tensorflow as tf 21 | import yaml 22 | from utils.misc import ensure_dir 23 | from tensorflow.contrib import slim 24 | 25 | from utils.dummy import DummySummaryWriter 26 | 27 | 
# AbstractModel: template base class; subclasses override _build/_loss/train.
# NOTE(review): the catch-all parameter is named **args rather than the
# conventional **kwargs — kept, since the docstring refers to it by name.
28 | class AbstractModel(object): 29 | def __init__(self, default_properties, test_mode=False, verbose=True, 30 | cfg=None, **args): 31 | """The abstract model that the other models extend. 32 | 33 | Args: 34 | default_properties: The attributes of an experiment, read from a 35 | config file 36 | test_mode: If in the test mode, computation graph for loss will 37 | not be constructed, config will be saved in the output directory 38 | verbose: If true, prints debug information 39 | cfg: Config dictionary 40 | args: The rest of the arguments which can become object attributes 41 | """ 42 | 43 | # Set attributes either from FLAGS or **args. 44 | self.cfg = cfg 45 | 46 | # Active session parameter. 47 | self.active_sess = None 48 | 49 | # Object attributes.
# --- models/base_model.py continued (original lines 50-122) ---
# __init__ tail: every name in default_properties becomes an attribute via
# _set_attr, sourced from **args, FLAGS, or cfg (in that order).
# NOTE(review): args.update(locals()) deliberately folds the explicit ctor
# locals (cfg, verbose, ...) into args so the loop below can pick them up —
# it also injects `self` and `default_properties` themselves; harmless only as
# long as those names never appear in default_properties.
50 | default_properties.extend( 51 | ['tensorboard_log', 'output_dir', 'num_gpus']) 52 | self.default_properties = default_properties 53 | self.initialized = False 54 | self.verbose = verbose 55 | self.output_dir = 'output' 56 | 57 | local_vals = locals() 58 | args.update(local_vals) 59 | for attr in default_properties: 60 | if attr in args.keys(): 61 | self._set_attr(attr, args[attr]) 62 | else: 63 | self._set_attr(attr, None) 64 | 65 | # Runtime attributes. 66 | self.saver = None 67 | self.global_step = tf.train.get_or_create_global_step() 68 | self.global_step_inc = \ 69 | tf.assign(self.global_step, tf.add(self.global_step, 1)) 70 | 71 | # Phase: 1 train 0 test. 72 | self.is_training = tf.placeholder(dtype=tf.bool) 73 | self.save_vars = {} 74 | self.save_var_prefixes = [] 75 | self.dataset = None 76 | self.test_mode = test_mode 77 | 78 | self._set_checkpoint_dir() 79 | self._build() 80 | if not test_mode: 81 | self._save_cfg_in_ckpt() 82 | self._loss() 83 | 84 | self._initialize_summary_writer() 85 | 86 | 
# Template-method stubs for subclasses to override.
# NOTE(review): `test(self, input)` shadows the builtin input(); renaming would
# change the public signature, so it is only flagged here.
87 | def _load_dataset(self): 88 | pass 89 | 90 | def _build(self): 91 | pass 92 | 93 | def _loss(self): 94 | pass 95 | 96 | def test(self, input): 97 | pass 98 | 99 | def train(self): 100 | pass 101 | 102 | def _verbose_print(self, message): 103 | """Handy verbose print function""" 104 | if self.verbose: 105 | print(message) 106 | 
# _save_cfg_in_ckpt: writes back only cfg keys that materialized as non-None
# attributes, preserving any command-line overrides in the saved cfg.yml.
107 | def _save_cfg_in_ckpt(self): 108 | """Saves the configuration in the experiment's output directory.""" 109 | final_cfg = {} 110 | if hasattr(self, 'cfg'): 111 | for k in self.cfg.keys(): 112 | if hasattr(self, k.lower()): 113 | if getattr(self, k.lower()) is not None: 114 | final_cfg[k] = getattr(self, k.lower()) 115 | if not self.test_mode: 116 | with open(os.path.join(self.checkpoint_dir, 'cfg.yml'), 117 | 'w') as f: 118 | yaml.dump(final_cfg, f) 119 | 120 | def _set_attr(self, attr_name, val): 121 | """Sets an object attribute from FLAGS if it exists, if not it 122 | prints out an error.
# --- models/base_model.py continued (original lines 122-170) ---
Note that FLAGS is set from config and command 123 | line inputs. 124 | 125 | 126 | Args: 127 | attr_name: The name of the field. 128 | val: The value, if None it will set it from tf.apps.flags.FLAGS 129 | """ 130 | 
# Resolution order when val is None: FLAGS.<name>, then cfg[NAME], then
# cfg[name]. The attribute is set even if the value stays None (with a
# verbose warning), so lookups on self never AttributeError later.
131 | FLAGS = tf.app.flags.FLAGS 132 | 133 | if val is None: 134 | if hasattr(FLAGS, attr_name): 135 | val = getattr(FLAGS, attr_name) 136 | elif hasattr(self, 'cfg'): 137 | if attr_name.upper() in self.cfg.keys(): 138 | val = self.cfg[attr_name.upper()] 139 | elif attr_name.lower() in self.cfg.keys(): 140 | val = self.cfg[attr_name.lower()] 141 | if val is None and self.verbose: 142 | print( 143 | '[-] {}.{} is not set.'.format(type(self).__name__, attr_name)) 144 | 145 | setattr(self, attr_name, val) 146 | if self.verbose: 147 | print('[#] {}.{} is set to {}.'.format(type(self).__name__, 148 | attr_name, val)) 149 | 
# imsave_transform: identity hook; subclasses map model space -> image space.
150 | def imsave_transform(self, imgs): 151 | return imgs 152 | 
# get_learning_rate: constant LR, or staircase exponential decay driven either
# by an explicit decay_iter or by decay_epoch * iters_per_epoch.
153 | def get_learning_rate(self, init_lr=None, decay_epoch=None, 154 | decay_mult=None, iters_per_epoch=None, 155 | decay_iter=None, 156 | global_step=None, decay_lr=True): 157 | """Prepares the learning rate. 158 | 159 | Args: 160 | init_lr: The initial learning rate 161 | decay_epoch: The epoch of decay 162 | decay_mult: The decay factor 163 | iters_per_epoch: Number of iterations per epoch 164 | decay_iter: The iteration of decay [either this or decay_epoch 165 | should be set] 166 | global_step: 167 | decay_lr: 168 | 169 | Returns: 170 | `tf.Tensor` of the learning rate.
# --- models/base_model.py continued (original lines 171-235) ---
# NOTE(review): when decay_lr is False the constant is built from
# self.learning_rate, ignoring an explicitly passed init_lr — presumably an
# oversight, but callers may depend on it; flagging rather than changing.
171 | """ 172 | if init_lr is None: 173 | init_lr = self.learning_rate 174 | if global_step is None: 175 | global_step = self.global_step 176 | 177 | if decay_epoch: 178 | assert iters_per_epoch 179 | 180 | if iters_per_epoch is None: 181 | iters_per_epoch = self.iters_per_epoch 182 | else: 183 | assert decay_iter 184 | 185 | if decay_lr: 186 | if decay_epoch: 187 | decay_iter = decay_epoch * iters_per_epoch 188 | return tf.train.exponential_decay(init_lr, 189 | global_step, 190 | decay_iter, 191 | decay_mult, 192 | staircase=True) 193 | else: 194 | return tf.constant(self.learning_rate) 195 | 196 | 
# _set_checkpoint_dir: derives the checkpoint directory from the cfg path
# (mirroring experiments/cfgs/ under output/) and appends a "-prop=value"
# postfix for every property that differs from the cfg file, so distinct
# hyper-parameter runs land in distinct directories.
197 | def _set_checkpoint_dir(self): 198 | """Sets the directory containing snapshots of the model.""" 199 | 200 | self.cfg_file = self.cfg['cfg_path'] 201 | if 'cfg.yml' in self.cfg_file: 202 | ckpt_dir = os.path.dirname(self.cfg_file) 203 | 204 | else: 205 | ckpt_dir = os.path.join(self.output_dir, 206 | self.cfg_file.replace('experiments/cfgs/', 207 | '').replace( 208 | 'cfg.yml', '').replace( 209 | '.yml', '')) 210 | if not self.test_mode: 211 | postfix = '' 212 | ignore_list = ['dataset', 'cfg_file', 'batch_size'] 213 | if hasattr(self, 'cfg'): 214 | if self.cfg is not None: 215 | for prop in self.default_properties: 216 | if prop in ignore_list: 217 | continue 218 | 219 | if prop.upper() in self.cfg.keys(): 220 | self_val = getattr(self, prop) 221 | if self_val is not None: 222 | if getattr(self, prop) != self.cfg[ 223 | prop.upper()]: 224 | postfix += '-{}={}'.format( 225 | prop, self_val).replace('.', '_') 226 | 227 | ckpt_dir += postfix 228 | ensure_dir(ckpt_dir) 229 | 230 | self.checkpoint_dir = ckpt_dir 231 | self.debug_dir = self.checkpoint_dir.replace('output', 'debug') 232 | ensure_dir(self.debug_dir) 233 | 234 | def _initialize_summary_writer(self): 235 | # Setup the summary writer.
236 | if not self.tensorboard_log: 237 | self.summary_writer = DummySummaryWriter() 238 | else: 239 | sum_dir = os.path.join(self.checkpoint_dir, 'tb_logs') 240 | if not os.path.exists(sum_dir): 241 | os.makedirs(sum_dir) 242 | 243 | self.summary_writer = tf.summary.FileWriter(sum_dir) 244 | 245 | def _initialize_saver(self, prefixes=None, force=False, max_to_keep=5): 246 | """Initializes the saver object. 247 | 248 | Args: 249 | prefixes: The prefixes that the saver should take care of. 250 | force (optional): Even if saver is set, reconstruct the saver 251 | object. 252 | max_to_keep (optional): 253 | """ 254 | if self.saver is not None and not force: 255 | return 256 | else: 257 | if prefixes is None or not ( 258 | type(prefixes) != list or type(prefixes) != tuple): 259 | raise ValueError( 260 | 'Prefix of variables that needs saving are not defined') 261 | 262 | prefixes_str = '' 263 | for pref in prefixes: 264 | prefixes_str = prefixes_str + pref + ' ' 265 | 266 | print('[#] Initializing it with variable prefixes: {}'.format( 267 | prefixes_str)) 268 | saved_vars = [] 269 | for pref in prefixes: 270 | saved_vars.extend(slim.get_variables(pref)) 271 | 272 | self.saver = tf.train.Saver(saved_vars, max_to_keep=max_to_keep) 273 | 274 | def set_session(self, sess): 275 | """""" 276 | if self.active_sess is None: 277 | self.active_sess = sess 278 | else: 279 | raise EnvironmentError("Session is already set.") 280 | 281 | @property 282 | def sess(self): 283 | if self.active_sess is None: 284 | config = tf.ConfigProto() 285 | config.gpu_options.allow_growth = True 286 | self.active_sess = tf.Session(config=config) 287 | 288 | return self.active_sess 289 | 290 | def close_session(self): 291 | if self.active_sess: 292 | self.active_sess.close() 293 | 294 | def load(self, checkpoint_dir=None, prefixes=None, saver=None): 295 | """Loads the saved weights to the model from the checkpoint directory 296 | 297 | Args: 298 | checkpoint_dir: The path to saved models 299 | """ 
# --- models/base_model.py continued (original lines 300-342) ---
# load body: checkpoint_dir may be either a concrete checkpoint file path
# (restored directly) or a directory (resolved via get_checkpoint_state).
# NOTE(review): the bare `except:` below swallows every error from restore —
# including shape mismatches — and only prints; callers cannot distinguish
# "no checkpoint" from "corrupt checkpoint". Flagging, not changing, since the
# best-effort behavior may be relied upon.
300 | if prefixes is None: 301 | prefixes = self.save_var_prefixes 302 | if self.saver is None: 303 | print('[!] Saver is not initialized') 304 | self._initialize_saver(prefixes=prefixes) 305 | 306 | if saver is None: 307 | saver = self.saver 308 | 309 | if checkpoint_dir is None: 310 | checkpoint_dir = self.checkpoint_dir 311 | 312 | if not os.path.isdir(checkpoint_dir): 313 | try: 314 | saver.restore(self.sess, checkpoint_dir) 315 | except: 316 | print(" [!] Failed to find a checkpoint at {}".format( 317 | checkpoint_dir)) 318 | else: 319 | print(" [-] Reading checkpoints... {} ".format(checkpoint_dir)) 320 | 321 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 322 | if ckpt and ckpt.model_checkpoint_path: 323 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 324 | saver.restore(self.sess, 325 | os.path.join(checkpoint_dir, ckpt_name)) 326 | else: 327 | print( 328 | " [!] Failed to find a checkpoint " 329 | "within directory {}".format(checkpoint_dir)) 330 | return False 331 | 332 | print(" [*] Checkpoint is read successfully from {}".format( 333 | checkpoint_dir)) 334 | 335 | return True 336 | 337 | def add_save_vars(self, prefixes): 338 | """Prepares the list of variables that should be saved based on 339 | their name prefix. 340 | 341 | Args: 342 | prefixes: Variable name prefixes to find and save.
343 | """ 344 | 345 | for pre in prefixes: 346 | pre_vars = slim.get_variables(pre) 347 | self.save_vars.update(pre_vars) 348 | 349 | var_list = '' 350 | for var in self.save_vars: 351 | var_list = var_list + var.name + ' ' 352 | 353 | print ('Saving these variables: {}'.format(var_list)) 354 | 355 | def input_transform(self, images): 356 | pass 357 | 358 | def input_pl_transform(self): 359 | self.real_data = self.input_transform(self.real_data_pl) 360 | self.real_data_test = self.input_transform(self.real_data_test_pl) 361 | 362 | def initialize_uninitialized(self, ): 363 | """Only initializes the variables of a TensorFlow session that were not 364 | already initialized. 365 | """ 366 | # List all global variables. 367 | sess = self.sess 368 | global_vars = tf.global_variables() 369 | 370 | # Find initialized status for all variables. 371 | is_var_init = [tf.is_variable_initialized(var) for var in global_vars] 372 | is_initialized = sess.run(is_var_init) 373 | 374 | # List all variables that were not previously initialized. 375 | not_initialized_vars = [var for (var, init) in 376 | zip(global_vars, is_initialized) if not init] 377 | for v in not_initialized_vars: 378 | print('[!] not init: {}'.format(v.name)) 379 | # Initialize all uninitialized variables found, if any. 
380 | if len(not_initialized_vars): 381 | sess.run(tf.variables_initializer(not_initialized_vars)) 382 | 383 | def save(self, prefixes=None, global_step=None, checkpoint_dir=None): 384 | if global_step is None: 385 | global_step = self.global_step 386 | if checkpoint_dir is None: 387 | checkpoint_dir = self._set_checkpoint_dir 388 | 389 | ensure_dir(checkpoint_dir) 390 | self._initialize_saver(prefixes) 391 | self.saver.save(self.sess, 392 | os.path.join(checkpoint_dir, self.model_save_name), 393 | global_step=global_step) 394 | print('Saved at iter {} to {}'.format(self.sess.run(global_step), 395 | checkpoint_dir)) 396 | 397 | def initialize(self, dir): 398 | self.load(dir) 399 | self.initialized = True 400 | -------------------------------------------------------------------------------- /models/dataset_models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import tflib as lib 4 | import tflib.ops.batchnorm 5 | import tflib.ops.conv2d 6 | import tflib.ops.deconv2d 7 | import tflib.ops.linear 8 | 9 | 10 | def LeakyReLU(x, alpha=0.2): 11 | return tf.maximum(alpha * x, x) 12 | 13 | 14 | def ReLULayer(name, n_in, n_out, inputs): 15 | output = lib.ops.linear.Linear( 16 | name + '.Linear', 17 | n_in, 18 | n_out, 19 | inputs, 20 | initialization='he' 21 | ) 22 | return tf.nn.relu(output) 23 | 24 | 25 | def LeakyReLULayer(name, n_in, n_out, inputs): 26 | output = lib.ops.linear.Linear( 27 | name + '.Linear', 28 | n_in, 29 | n_out, 30 | inputs, 31 | initialization='he' 32 | ) 33 | return LeakyReLU(output) 34 | 35 | 36 | def mnist_generator(n_samples, noise=None, use_bn=False, 37 | net_dim=64, output_dim=64, is_training=False, 38 | latent_dim=128): 39 | if noise is None: 40 | noise = tf.random_normal([n_samples, latent_dim]) 41 | 42 | output = lib.ops.linear.Linear('Generator.Input', latent_dim, 43 | 4 * 4 * 4 * net_dim, noise) 44 | if use_bn: 45 | output = lib.ops.batchnorm.Batchnorm('Generator.BN1', 
[0], output, 46 | is_training=is_training) 47 | 48 | output = tf.nn.relu(output) 49 | 50 | output = tf.reshape(output, [-1, 4, 4, 4 * net_dim, ]) 51 | 52 | output = lib.ops.deconv2d.Deconv2D('Generator.2', 4 * net_dim, 2 * net_dim, 53 | 5, output) 54 | if use_bn: 55 | output = lib.ops.batchnorm.Batchnorm('Generator.BN2', [0, 1, 2], 56 | output, is_training=is_training) 57 | output = tf.nn.relu(output) 58 | 59 | output = output[:, :7, :7, :] 60 | 61 | output = lib.ops.deconv2d.Deconv2D('Generator.3', 2 * net_dim, net_dim, 5, 62 | output) 63 | if use_bn: 64 | output = lib.ops.batchnorm.Batchnorm('Generator.BN3', [0, 1, 2], 65 | output, is_training=is_training) 66 | output = tf.nn.relu(output) 67 | 68 | output = lib.ops.deconv2d.Deconv2D('Generator.5', net_dim, 1, 5, output) 69 | output = tf.nn.sigmoid(output) 70 | 71 | return output 72 | 73 | 74 | def mnist_discriminator(inputs, use_bn=False, net_dim=128, is_training=False): 75 | output = lib.ops.conv2d.Conv2D('Discriminator.1', 1, net_dim, 5, inputs, 76 | stride=2) 77 | output = LeakyReLU(output) 78 | 79 | output = lib.ops.conv2d.Conv2D('Discriminator.2', net_dim, 2 * net_dim, 5, 80 | output, stride=2) 81 | if use_bn: 82 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN2', [0, 1, 2], 83 | output, is_training=is_training) 84 | output = LeakyReLU(output) 85 | 86 | output = lib.ops.conv2d.Conv2D('Discriminator.3', 2 * net_dim, 4 * net_dim, 87 | 5, output, stride=2) 88 | if use_bn: 89 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN3', [0, 1, 2], 90 | output, is_training=is_training) 91 | output = LeakyReLU(output) 92 | 93 | output = tf.reshape(output, [-1, 4 * 4 * 4 * net_dim]) 94 | output = lib.ops.linear.Linear('Discriminator.Output', 4 * 4 * 4 * net_dim, 95 | 1, output) 96 | 97 | return tf.reshape(output, [-1]) 98 | 99 | 100 | def MnistEncoder(inputs, use_bn=False, net_dim=128, is_training=False, 101 | latent_dim=None): 102 | output = lib.ops.conv2d.Conv2D('Encoder.1', 1, net_dim, 5, inputs, 103 | 
stride=2) 104 | output = LeakyReLU(output) 105 | 106 | output = lib.ops.conv2d.Conv2D('Encoder.2', net_dim, 2 * net_dim, 5, 107 | output, stride=2) 108 | if use_bn: 109 | output = lib.ops.batchnorm.Batchnorm('Encoder.BN2', [0, 1, 2], output, 110 | is_training=is_training) 111 | output = LeakyReLU(output) 112 | 113 | output = lib.ops.conv2d.Conv2D('Encoder.3', 2 * net_dim, 4 * net_dim, 5, 114 | output, stride=2) 115 | if use_bn: 116 | output = lib.ops.batchnorm.Batchnorm('Encoder.BN3', [0, 1, 2], output, 117 | is_training=is_training) 118 | output = LeakyReLU(output) 119 | 120 | output = tf.reshape(output, [-1, 4 * 4 * 4 * net_dim]) 121 | output = lib.ops.linear.Linear('Encoder.Output', 4 * 4 * 4 * net_dim, 122 | latent_dim, output) 123 | 124 | return tf.tanh(output) 125 | 126 | 127 | def celeba_generator(n_samples, noise=None, use_bn=False, 128 | net_dim=64, output_dim=64, is_training=False, 129 | latent_dim=128, stats_iter=None): 130 | if noise is None: 131 | noise = tf.random_normal([n_samples, latent_dim]) 132 | 133 | output = lib.ops.linear.Linear('Generator.Input', latent_dim, 134 | 4 * 4 * 4 * net_dim, noise) 135 | if use_bn: 136 | output = lib.ops.batchnorm.Batchnorm('Generator.BN1', [0], output, 137 | is_training=is_training, 138 | stats_iter=stats_iter) 139 | output = tf.nn.relu(output) 140 | output = tf.reshape(output, [-1, 4, 4, 4 * net_dim]) 141 | 142 | output = lib.ops.deconv2d.Deconv2D('Generator.2', 4 * net_dim, 2 * net_dim, 143 | 5, output) 144 | if use_bn: 145 | output = lib.ops.batchnorm.Batchnorm('Generator.BN2', [0, 1, 2], 146 | output, is_training=is_training, 147 | stats_iter=stats_iter) 148 | output = tf.nn.relu(output) 149 | 150 | output = lib.ops.deconv2d.Deconv2D('Generator.3', 2 * net_dim, net_dim, 5, 151 | output) 152 | if use_bn: 153 | output = lib.ops.batchnorm.Batchnorm('Generator.BN3', [0, 1, 2], 154 | output, is_training=is_training, 155 | stats_iter=stats_iter) 156 | output = tf.nn.relu(output) 157 | 158 | output = 
lib.ops.deconv2d.Deconv2D('Generator.5', net_dim, net_dim, 5, 159 | output) 160 | 161 | output = lib.ops.deconv2d.Deconv2D('Generator.6', net_dim, 3, 5, output) 162 | 163 | output = tf.tanh(output) 164 | 165 | return output 166 | 167 | 168 | def celeba_discriminator(inputs, use_bn=False, net_dim=128, is_training=False, 169 | stats_iter=None, data_format='NCHW'): 170 | output = lib.ops.conv2d.Conv2D('Discriminator.1', 3, net_dim, 5, inputs, 171 | stride=2, data_format=data_format) 172 | output = LeakyReLU(output) 173 | 174 | output = lib.ops.conv2d.Conv2D('Discriminator.2', net_dim, 2 * net_dim, 5, 175 | output, stride=2, 176 | data_format=data_format) 177 | if use_bn: 178 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN2', [0, 2, 3], 179 | output, is_training=is_training, 180 | stats_iter=stats_iter, 181 | data_format=data_format) 182 | output = LeakyReLU(output) 183 | 184 | output = lib.ops.conv2d.Conv2D('Discriminator.3', 2 * net_dim, 4 * net_dim, 185 | 5, output, stride=2, 186 | data_format=data_format) 187 | if use_bn: 188 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN3', [0, 2, 3], 189 | output, is_training=is_training, 190 | stats_iter=stats_iter) 191 | output = LeakyReLU(output) 192 | 193 | output = tf.reshape(output, [-1, 4 * 4 * 4 * net_dim]) 194 | output = lib.ops.linear.Linear('Discriminator.Output', 4 * 4 * 4 * net_dim, 195 | 1, output) 196 | 197 | return tf.reshape(output, [-1]) 198 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.14.2 2 | scipy==1.0.1 3 | requests==2.20.0 4 | keras==2.1.5 5 | opencv-python==3.4.0.12 6 | scikit-image==0.13.1 7 | matplotlib==2.1.2 8 | tqdm=4.28.1 9 | -------------------------------------------------------------------------------- /tflib/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 
import tensorflow as tf
import locale

locale.setlocale(locale.LC_ALL, '')

# Registry of every variable created through `param`, keyed by name, plus a
# table of aliases used to transparently substitute one variable for another.
_params = {}
_param_aliases = {}


def param(name, *args, **kwargs):
    """
    A wrapper for `tf.Variable` which enables parameter sharing in models.

    Creates and returns theano shared variables similarly to `tf.Variable`,
    except if you try to create a param with the same name as a
    previously-created one, `param(...)` will just return the old one instead
    of making a new one.

    This constructor also adds a `param` attribute to the shared variables it
    creates, so that you can easily search a graph for all params.
    """
    if name not in _params:
        kwargs['name'] = name
        new_var = tf.Variable(*args, **kwargs)
        new_var.param = True
        _params[name] = new_var
    result = _params[name]
    # Follow alias links until an un-aliased variable is reached.
    while result in _param_aliases:
        result = _param_aliases[result]
    return result


def params_with_name(name):
    """Returns every registered param whose name contains `name`."""
    return [p for n, p in _params.items() if name in n]


def delete_all_params():
    _params.clear()


def alias_params(replace_dict):
    for old, new in replace_dict.items():
        _param_aliases[old] = new


def delete_param_aliases():
    _param_aliases.clear()

# def search(node, critereon):
#     """
#     Traverse the Theano graph starting at `node` and return a list of all nodes
#     which match the `critereon` function.
When optimizing a cost function, you 53 | # can use this to get a list of all of the trainable params in the graph, like 54 | # so: 55 | 56 | # `lib.search(cost, lambda x: hasattr(x, "param"))` 57 | # """ 58 | 59 | # def _search(node, critereon, visited): 60 | # if node in visited: 61 | # return [] 62 | # visited.add(node) 63 | 64 | # results = [] 65 | # if isinstance(node, T.Apply): 66 | # for inp in node.inputs: 67 | # results += _search(inp, critereon, visited) 68 | # else: # Variable node 69 | # if critereon(node): 70 | # results.append(node) 71 | # if node.owner is not None: 72 | # results += _search(node.owner, critereon, visited) 73 | # return results 74 | 75 | # return _search(node, critereon, set()) 76 | 77 | # def print_params_info(params): 78 | # """Print information about the parameters in the given param set.""" 79 | 80 | # params = sorted(params, key=lambda p: p.name) 81 | # values = [p.get_value(borrow=True) for p in params] 82 | # shapes = [p.shape for p in values] 83 | # print "Params for cost:" 84 | # for param, value, shape in zip(params, values, shapes): 85 | # print "\t{0} ({1})".format( 86 | # param.name, 87 | # ",".join([str(x) for x in shape]) 88 | # ) 89 | 90 | # total_param_count = 0 91 | # for shape in shapes: 92 | # param_count = 1 93 | # for dim in shape: 94 | # param_count *= dim 95 | # total_param_count += param_count 96 | # print "Total parameter count: {0}".format( 97 | # locale.format("%d", total_param_count, grouping=True) 98 | # ) 99 | 100 | def print_model_settings(locals_): 101 | print "Uppercase local vars:" 102 | all_vars = [(k,v) for (k,v) in locals_.items() if (k.isupper() and k!='T' and k!='SETTINGS' and k!='ALL_SETTINGS')] 103 | all_vars = sorted(all_vars, key=lambda x: x[0]) 104 | for var_name, var_value in all_vars: 105 | print "\t{}: {}".format(var_name, var_value) 106 | 107 | 108 | def print_model_settings_dict(settings): 109 | print "Settings dict:" 110 | all_vars = [(k,v) for (k,v) in settings.items()] 111 | 
all_vars = sorted(all_vars, key=lambda x: x[0]) 112 | for var_name, var_value in all_vars: 113 | print "\t{}: {}".format(var_name, var_value) -------------------------------------------------------------------------------- /tflib/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | def save_model(saver,sess,checkpoint_dir, step,model_name="GAN.model"): 5 | ''' 6 | Saves to output 7 | :param checkpoint_dir: 8 | :param step: 9 | :return: 10 | ''' 11 | saver.save(sess, 12 | os.path.join(checkpoint_dir, model_name), 13 | global_step=step) 14 | 15 | import re 16 | 17 | def load_model(saver,sess,checkpoint_dir=None): 18 | ''' 19 | Loads the saved model 20 | :param checkpoint_dir: root of all the checkpoints 21 | :return: 22 | ''' 23 | FLAGS = tf.app.flags.FLAGS 24 | 25 | def load_from_path(ckpt_path): 26 | ckpt_name = os.path.basename(ckpt_path) 27 | saver.restore(sess, ckpt_path) 28 | counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0)) 29 | print(" [*] Success to read {}".format(ckpt_name)) 30 | 31 | return True, counter 32 | 33 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 34 | try: 35 | if ckpt and ckpt.model_checkpoint_path: 36 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 37 | saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name)) 38 | return load_from_path(os.path.join(checkpoint_dir, ckpt_name)) 39 | else: 40 | print(" [*] Failed to find a checkpoint within directory {}".format(FLAGS.ckpt_path)) 41 | return False, 0 42 | except Exception as e: 43 | print(e) 44 | print(" [*] Failed to find a checkpoint, Exception!") 45 | return False, 0 46 | -------------------------------------------------------------------------------- /tflib/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import os 4 | import urllib 5 | import gzip 6 | import cPickle as pickle 7 | 8 | 
def unpickle(file): 9 | fo = open(file, 'rb') 10 | dict = pickle.load(fo) 11 | fo.close() 12 | return dict['data'], dict['labels'] 13 | 14 | def cifar_generator(filenames, batch_size, data_dir,randomize=True): 15 | all_data = [] 16 | all_labels = [] 17 | for filename in filenames: 18 | data, labels = unpickle(data_dir + '/' + filename) 19 | all_data.append(data) 20 | all_labels.append(labels) 21 | 22 | images = np.concatenate(all_data, axis=0) 23 | labels = np.concatenate(all_labels, axis=0) 24 | 25 | def get_epoch(): 26 | if randomize: 27 | rng_state = np.random.get_state() 28 | np.random.shuffle(images) 29 | np.random.set_state(rng_state) 30 | np.random.shuffle(labels) 31 | 32 | for i in xrange(len(images) / batch_size): 33 | yield (images[i*batch_size:(i+1)*batch_size], labels[i*batch_size:(i+1)*batch_size]) 34 | 35 | return get_epoch 36 | 37 | 38 | def load(batch_size, data_dir,randomize=True): 39 | return ( 40 | cifar_generator(['data_batch_1','data_batch_2','data_batch_3','data_batch_4','data_batch_5'], batch_size, data_dir,randomize=randomize), 41 | cifar_generator(['test_batch'], batch_size, data_dir,randomize=randomize) 42 | ) -------------------------------------------------------------------------------- /tflib/inception_score.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/openai/improved-gan/blob/master/inception_score/model.py 2 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os.path 8 | import sys 9 | import tarfile 10 | 11 | import numpy as np 12 | from six.moves import urllib 13 | import tensorflow as tf 14 | import glob 15 | import scipy.misc 16 | import math 17 | import sys 18 | 19 | MODEL_DIR = '/tmp/imagenet' 20 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 21 | 
softmax = None


# Call this function with a list of images. Each element should be a
# numpy array with values ranging from 0 to 255.
def get_inception_score(images, splits=10):
    assert type(images) == list
    assert type(images[0]) == np.ndarray
    assert len(images[0].shape) == 3
    assert np.max(images[0]) > 10
    assert np.min(images[0]) >= 0.0

    # Give every image a leading batch dimension.
    inps = [np.expand_dims(img.astype(np.float32), 0) for img in images]
    bs = 100

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        preds = []
        n_batches = int(math.ceil(float(len(inps)) / float(bs)))
        for i in range(n_batches):
            inp = np.concatenate(
                inps[(i * bs):min((i + 1) * bs, len(inps))], 0)
            preds.append(sess.run(softmax, {'ExpandDims:0': inp}))
        preds = np.concatenate(preds, 0)

        # Score = exp(mean KL(p(y|x) || p(y))), averaged over `splits`
        # equally sized chunks of the predictions.
        scores = []
        for i in range(splits):
            part = preds[(i * preds.shape[0] // splits):
                         ((i + 1) * preds.shape[0] // splits), :]
            kl = part * (np.log(part) -
                         np.log(np.expand_dims(np.mean(part, 0), 0)))
            scores.append(np.exp(np.mean(np.sum(kl, 1))))
        return np.mean(scores), np.std(scores)

# This function is called automatically.
59 | def _init_inception(): 60 | global softmax 61 | if not os.path.exists(MODEL_DIR): 62 | os.makedirs(MODEL_DIR) 63 | filename = DATA_URL.split('/')[-1] 64 | filepath = os.path.join(MODEL_DIR, filename) 65 | if not os.path.exists(filepath): 66 | def _progress(count, block_size, total_size): 67 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 68 | filename, float(count * block_size) / float(total_size) * 100.0)) 69 | sys.stdout.flush() 70 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 71 | print() 72 | statinfo = os.stat(filepath) 73 | print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') 74 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) 75 | with tf.gfile.FastGFile(os.path.join( 76 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: 77 | graph_def = tf.GraphDef() 78 | graph_def.ParseFromString(f.read()) 79 | _ = tf.import_graph_def(graph_def, name='') 80 | # Works with an arbitrary minibatch size. 81 | config = tf.ConfigProto(allow_soft_placement=True) 82 | config.gpu_options.allow_growth = True 83 | with tf.Session(config=config) as sess: 84 | pool3 = sess.graph.get_tensor_by_name('pool_3:0') 85 | ops = pool3.graph.get_operations() 86 | for op_idx, op in enumerate(ops): 87 | for o in op.outputs: 88 | shape = o.get_shape() 89 | shape = [s.value for s in shape] 90 | new_shape = [] 91 | for j, s in enumerate(shape): 92 | if s == 1 and j == 0: 93 | new_shape.append(None) 94 | else: 95 | new_shape.append(s) 96 | o._shape = tf.TensorShape(new_shape) 97 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] 98 | logits = tf.matmul(tf.squeeze(pool3), w) 99 | softmax = tf.nn.softmax(logits) 100 | 101 | if softmax is None: 102 | _init_inception() 103 | -------------------------------------------------------------------------------- /tflib/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | import os 4 | import urllib 5 | import gzip 
6 | import cPickle as pickle 7 | 8 | def mnist_generator(data, batch_size, n_labelled, limit=None,randomize=True): 9 | images, targets = data 10 | 11 | rng_state = numpy.random.get_state() 12 | numpy.random.shuffle(images) 13 | numpy.random.set_state(rng_state) 14 | numpy.random.shuffle(targets) 15 | if limit is not None: 16 | print "WARNING ONLY FIRST {} MNIST DIGITS".format(limit) 17 | images = images.astype('float32')[:limit] 18 | targets = targets.astype('int32')[:limit] 19 | if n_labelled is not None: 20 | labelled = numpy.zeros(len(images), dtype='int32') 21 | labelled[:n_labelled] = 1 22 | 23 | def get_epoch(): 24 | if randomize: 25 | rng_state = numpy.random.get_state() 26 | numpy.random.shuffle(images) 27 | numpy.random.set_state(rng_state) 28 | numpy.random.shuffle(targets) 29 | 30 | if n_labelled is not None: 31 | numpy.random.set_state(rng_state) 32 | numpy.random.shuffle(labelled) 33 | 34 | image_batches = images.reshape(-1, batch_size, 784) 35 | target_batches = targets.reshape(-1, batch_size) 36 | 37 | if n_labelled is not None: 38 | labelled_batches = labelled.reshape(-1, batch_size) 39 | 40 | for i in xrange(len(image_batches)): 41 | yield (numpy.copy(image_batches[i]), numpy.copy(target_batches[i]), numpy.copy(labelled)) 42 | 43 | else: 44 | 45 | for i in xrange(len(image_batches)): 46 | yield (numpy.copy(image_batches[i]), numpy.copy(target_batches[i])) 47 | 48 | return get_epoch 49 | 50 | def load(batch_size, test_batch_size, n_labelled=None,randomize=True): 51 | filepath = '/tmp/mnist.pkl.gz' 52 | url = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' 53 | 54 | if not os.path.isfile(filepath): 55 | print "Couldn't find MNIST dataset in /tmp, downloading..." 
56 | urllib.urlretrieve(url, filepath) 57 | 58 | with gzip.open('/tmp/mnist.pkl.gz', 'rb') as f: 59 | train_data, dev_data, test_data = pickle.load(f) 60 | 61 | return ( 62 | mnist_generator(train_data, batch_size, n_labelled,randomize=randomize), 63 | mnist_generator(dev_data, test_batch_size, n_labelled,randomize=randomize), 64 | mnist_generator(test_data, test_batch_size, n_labelled,randomize=randomize) 65 | ) -------------------------------------------------------------------------------- /tflib/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kabkabm/defensegan/7e3feaebf7b9bbf08b1364e400119ef596cd78fd/tflib/ops/__init__.py -------------------------------------------------------------------------------- /tflib/ops/batchnorm.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | def Batchnorm(name, axes, inputs, is_training=None, stats_iter=None, 7 | update_moving_stats=True, fused=True, 8 | data_format='NHWC'): 9 | if ((axes == [0,2,3]) or (axes == [0,2])) and fused==True: 10 | if axes==[0,2]: 11 | inputs = tf.expand_dims(inputs, 3) 12 | # Old (working but pretty slow) implementation: 13 | ########## 14 | 15 | # inputs = tf.transpose(inputs, [0,2,3,1]) 16 | 17 | # mean, var = tf.nn.moments(inputs, [0,1,2], keep_dims=False) 18 | # offset = lib.param(name+'.offset', np.zeros(mean.get_shape()[-1], dtype='float32')) 19 | # scale = lib.param(name+'.scale', np.ones(var.get_shape()[-1], dtype='float32')) 20 | # result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-4) 21 | 22 | # return tf.transpose(result, [0,3,1,2]) 23 | 24 | # New (super fast but untested) implementation: 25 | offset = lib.param(name+'.offset', np.zeros(inputs.get_shape()[1], dtype='float32')) 26 | scale = lib.param(name+'.scale', np.ones(inputs.get_shape()[1], dtype='float32')) 27 | 
28 | moving_mean = lib.param(name+'.moving_mean', np.zeros(inputs.get_shape()[1], dtype='float32'), trainable=False) 29 | moving_variance = lib.param(name+'.moving_variance', np.ones(inputs.get_shape()[1], dtype='float32'), trainable=False) 30 | 31 | def _fused_batch_norm_training(): 32 | return tf.nn.fused_batch_norm(inputs, scale, offset, 33 | epsilon=1e-5, 34 | data_format=data_format, 35 | is_training=True) 36 | 37 | def _fused_batch_norm_inference(): 38 | # Version which blends in the current item's statistics 39 | # batch_size = tf.cast(tf.shape(inputs)[0], 'float32') 40 | # mean, var = tf.nn.moments(inputs, [2,3], keep_dims=True) 41 | # mean = ((1./batch_size)*mean) + (((batch_size-1.)/batch_size)*moving_mean)[None,:,None,None] 42 | # var = ((1./batch_size)*var) + (((batch_size-1.)/batch_size)*moving_variance)[None,:,None,None] 43 | # return tf.nn.batch_normalization(inputs, mean, var, offset[None,:,None,None], scale[None,:,None,None], 1e-5), mean, var 44 | 45 | #Standard version 46 | return tf.nn.fused_batch_norm( 47 | inputs, 48 | scale, 49 | offset, 50 | epsilon=1e-5, 51 | mean=moving_mean, 52 | variance=moving_variance, 53 | is_training=False, 54 | data_format=data_format 55 | ) 56 | 57 | if is_training is None: 58 | outputs, batch_mean, batch_var = _fused_batch_norm_training() 59 | else: 60 | outputs, batch_mean, batch_var = tf.cond(is_training, 61 | _fused_batch_norm_training, 62 | _fused_batch_norm_inference) 63 | if update_moving_stats: 64 | no_updates = lambda: outputs 65 | def _force_updates(): 66 | """Internal function forces updates moving_vars if is_training.""" 67 | float_stats_iter = tf.cast(stats_iter, tf.float32) 68 | 69 | update_moving_mean = tf.assign(moving_mean, ((float_stats_iter/(float_stats_iter+1))*moving_mean) + ((1/(float_stats_iter+1))*batch_mean)) 70 | update_moving_variance = tf.assign(moving_variance, ((float_stats_iter/(float_stats_iter+1))*moving_variance) + ((1/(float_stats_iter+1))*batch_var)) 71 | 72 | with 
tf.control_dependencies([update_moving_mean, update_moving_variance]): 73 | return tf.identity(outputs) 74 | outputs = tf.cond(is_training, _force_updates, no_updates) 75 | 76 | if axes == [0,2]: 77 | return outputs[:,:,:,0] # collapse last dim 78 | else: 79 | return outputs 80 | else: 81 | # raise Exception('old BN') 82 | # TODO we can probably use nn.fused_batch_norm here too for speedup 83 | mean, var = tf.nn.moments(inputs, axes, keep_dims=True) 84 | shape = mean.get_shape().as_list() 85 | if 0 not in axes: 86 | print "WARNING ({}): didn't find 0 in axes, but not using separate BN params for each item in batch".format(name) 87 | shape[0] = 1 88 | offset = lib.param(name+'.offset', np.zeros(shape, dtype='float32')) 89 | scale = lib.param(name+'.scale', np.ones(shape, dtype='float32')) 90 | result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5) 91 | 92 | 93 | return result 94 | -------------------------------------------------------------------------------- /tflib/ops/cond_batchnorm.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | def Batchnorm(name, axes, inputs, is_training=None, stats_iter=None, update_moving_stats=True, fused=True, labels=None, n_labels=None): 7 | """conditional batchnorm (dumoulin et al 2016) for BCHW conv filtermaps""" 8 | if axes != [0,2,3]: 9 | raise Exception('unsupported') 10 | mean, var = tf.nn.moments(inputs, axes, keep_dims=True) 11 | shape = mean.get_shape().as_list() # shape is [1,n,1,1] 12 | offset_m = lib.param(name+'.offset', np.zeros([n_labels,shape[1]], dtype='float32')) 13 | scale_m = lib.param(name+'.scale', np.ones([n_labels,shape[1]], dtype='float32')) 14 | offset = tf.nn.embedding_lookup(offset_m, labels) 15 | scale = tf.nn.embedding_lookup(scale_m, labels) 16 | result = tf.nn.batch_normalization(inputs, mean, var, offset[:,:,None,None], scale[:,:,None,None], 1e-5) 17 | return 
result -------------------------------------------------------------------------------- /tflib/ops/conv1d.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | _default_weightnorm = False 7 | def enable_default_weightnorm(): 8 | global _default_weightnorm 9 | _default_weightnorm = True 10 | 11 | def Conv1D(name, input_dim, output_dim, filter_size, inputs, he_init=True, 12 | mask_type=None, stride=1, weightnorm=None, biases=True, gain=1., 13 | data_format='NHWC'): 14 | """ 15 | inputs: tensor of shape (batch size, num channels, width) 16 | mask_type: one of None, 'a', 'b' 17 | 18 | returns: tensor of shape (batch size, num channels, width) 19 | """ 20 | with tf.name_scope(name) as scope: 21 | 22 | if mask_type is not None: 23 | mask_type, mask_n_channels = mask_type 24 | 25 | mask = np.ones( 26 | (filter_size, input_dim, output_dim), 27 | dtype='float32' 28 | ) 29 | center = filter_size // 2 30 | 31 | # Mask out future locations 32 | # filter shape is (width, input channels, output channels) 33 | mask[center+1:, :, :] = 0. 34 | 35 | # Mask out future channels 36 | for i in xrange(mask_n_channels): 37 | for j in xrange(mask_n_channels): 38 | if (mask_type=='a' and i >= j) or (mask_type=='b' and i > j): 39 | mask[ 40 | center, 41 | i::mask_n_channels, 42 | j::mask_n_channels 43 | ] = 0. 44 | 45 | 46 | def uniform(stdev, size): 47 | return np.random.uniform( 48 | low=-stdev * np.sqrt(3), 49 | high=stdev * np.sqrt(3), 50 | size=size 51 | ).astype('float32') 52 | 53 | fan_in = input_dim * filter_size 54 | fan_out = output_dim * filter_size / stride 55 | 56 | if mask_type is not None: # only approximately correct 57 | fan_in /= 2. 58 | fan_out /= 2. 
59 | 60 | if he_init: 61 | filters_stdev = np.sqrt(4./(fan_in+fan_out)) 62 | else: # Normalized init (Glorot & Bengio) 63 | filters_stdev = np.sqrt(2./(fan_in+fan_out)) 64 | 65 | filter_values = uniform( 66 | filters_stdev, 67 | (filter_size, input_dim, output_dim) 68 | ) 69 | # print "WARNING IGNORING GAIN" 70 | filter_values *= gain 71 | 72 | filters = lib.param(name+'.Filters', filter_values) 73 | 74 | if weightnorm==None: 75 | weightnorm = _default_weightnorm 76 | if weightnorm: 77 | norm_values = np.sqrt(np.sum(np.square(filter_values), axis=(0,1))) 78 | target_norms = lib.param( 79 | name + '.g', 80 | norm_values 81 | ) 82 | with tf.name_scope('weightnorm') as scope: 83 | norms = tf.sqrt(tf.reduce_sum(tf.square(filters), reduction_indices=[0,1])) 84 | filters = filters * (target_norms / norms) 85 | 86 | if mask_type is not None: 87 | with tf.name_scope('filter_mask'): 88 | filters = filters * mask 89 | 90 | result = tf.nn.conv1d( 91 | value=inputs, 92 | filters=filters, 93 | stride=stride, 94 | padding='SAME', 95 | data_format=data_format 96 | ) 97 | 98 | if biases: 99 | _biases = lib.param( 100 | name+'.Biases', 101 | np.zeros([output_dim], dtype='float32') 102 | ) 103 | 104 | # result = result + _biases 105 | 106 | result = tf.expand_dims(result, 3) 107 | result = tf.nn.bias_add(result, _biases, data_format=data_format) 108 | result = tf.squeeze(result) 109 | 110 | return result 111 | -------------------------------------------------------------------------------- /tflib/ops/conv2d.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | _default_weightnorm = False 7 | 8 | 9 | def enable_default_weightnorm(): 10 | global _default_weightnorm 11 | _default_weightnorm = True 12 | 13 | 14 | _weights_stdev = None 15 | 16 | 17 | def set_weights_stdev(weights_stdev): 18 | global _weights_stdev 19 | _weights_stdev = weights_stdev 20 | 21 | 22 | def 
def Conv2D(name, input_dim, output_dim, filter_size, inputs, he_init=True, mask_type=None, stride=1, weightnorm=None,
           biases=True, gain=1., data_format='NHWC'):
    """2-D convolution with He/Glorot init, optional weight-norm and masking.

    inputs: 4-D tensor laid out according to `data_format` (default 'NHWC').
        NOTE(review): the original docstring said (batch, channels, height,
        width), i.e. NCHW, which contradicts the NHWC default — confirm the
        layout callers actually pass.
    mask_type: one of None, ('a', n_channels), ('b', n_channels) —
        PixelCNN-style autoregressive masks.

    returns: 4-D tensor with `output_dim` channels, same spatial layout.
    """
    with tf.name_scope(name) as scope:

        if mask_type is not None:
            mask_type, mask_n_channels = mask_type

            # Filter shape is (height, width, input channels, output channels).
            mask = np.ones(
                (filter_size, filter_size, input_dim, output_dim),
                dtype='float32'
            )
            center = filter_size // 2

            # Mask out future locations:
            # everything strictly below the center row, and everything to the
            # right of the center within the center row.
            mask[center + 1:, :, :, :] = 0.
            mask[center, center + 1:, :, :] = 0.

            # Mask out future channels at the center pixel. Type 'a' also
            # blocks the connection to the same channel group (i >= j);
            # type 'b' allows it (i > j).
            for i in xrange(mask_n_channels):
                for j in xrange(mask_n_channels):
                    if (mask_type == 'a' and i >= j) or (mask_type == 'b' and i > j):
                        mask[
                            center,
                            center,
                            i::mask_n_channels,
                            j::mask_n_channels
                        ] = 0.

        def uniform(stdev, size):
            # Uniform distribution with the given standard deviation:
            # U(-stdev*sqrt(3), stdev*sqrt(3)) has std == stdev.
            return np.random.uniform(
                low=-stdev * np.sqrt(3),
                high=stdev * np.sqrt(3),
                size=size
            ).astype('float32')

        fan_in = input_dim * filter_size ** 2
        # NOTE(review): Python-2 integer division here — fan_out is truncated
        # when stride does not divide the product; the original WGAN-GP code
        # behaves the same way.
        fan_out = output_dim * filter_size ** 2 / (stride ** 2)

        if mask_type is not None:  # only approximately correct
            fan_in /= 2.
            fan_out /= 2.

        if he_init:
            filters_stdev = np.sqrt(4. / (fan_in + fan_out))
        else:  # Normalized init (Glorot & Bengio)
            filters_stdev = np.sqrt(2. / (fan_in + fan_out))

        # A module-level override (set via set_weights_stdev) takes priority
        # over the computed He/Glorot stdev.
        if _weights_stdev is not None:
            filter_values = uniform(
                _weights_stdev,
                (filter_size, filter_size, input_dim, output_dim)
            )
        else:
            filter_values = uniform(
                filters_stdev,
                (filter_size, filter_size, input_dim, output_dim)
            )

        # print "WARNING IGNORING GAIN"
        filter_values *= gain

        filters = lib.param(name + '.Filters', filter_values)

        if weightnorm == None:
            weightnorm = _default_weightnorm
        if weightnorm:
            # Per-output-channel norms of the initial filters become the
            # trainable norm targets g (Salimans & Kingma weight norm).
            norm_values = np.sqrt(np.sum(np.square(filter_values), axis=(0, 1, 2)))
            target_norms = lib.param(
                name + '.g',
                norm_values
            )
            with tf.name_scope('weightnorm') as scope:
                norms = tf.sqrt(tf.reduce_sum(tf.square(filters), reduction_indices=[0, 1, 2]))
                filters = filters * (target_norms / norms)

        if mask_type is not None:
            with tf.name_scope('filter_mask'):
                filters = filters * mask
        # Stride vector must match the data layout.
        strides = [1, 1, stride, stride]
        if data_format == 'NHWC':
            strides = [1, stride, stride, 1]
        result = tf.nn.conv2d(
            input=inputs,
            filter=filters,
            strides=strides,
            padding='SAME',
            data_format=data_format
        )

        if biases:
            _biases = lib.param(
                name + '.Biases',
                np.zeros(output_dim, dtype='float32')
            )

            result = tf.nn.bias_add(result, _biases, data_format=data_format)

        return result
def Deconv2D(
        name,
        input_dim,
        output_dim,
        filter_size,
        inputs,
        he_init=True,
        weightnorm=None,
        biases=True,
        gain=1.,
        mask_type=None,
        data_format='NHWC',
):
    """Fixed 2x-upsampling transposed convolution.

    inputs: tensor of shape (batch size, height, width, input_dim)
    returns: tensor of shape (batch size, 2*height, 2*width, output_dim)

    NCHW inputs are transposed to NHWC for the op and transposed back at the
    end, because conv2d_transpose below is invoked without a data_format.
    """
    with tf.name_scope(name) as scope:

        # Autoregressive masking is not implemented for deconvolutions.
        if mask_type != None:
            raise Exception('Unsupported configuration')

        def uniform(stdev, size):
            # U(-stdev*sqrt(3), stdev*sqrt(3)) has std == stdev.
            return np.random.uniform(
                low=-stdev * np.sqrt(3),
                high=stdev * np.sqrt(3),
                size=size
            ).astype('float32')

        # Stride is hard-coded: this layer always doubles height and width.
        stride = 2
        fan_in = input_dim * filter_size ** 2 / (stride ** 2)
        fan_out = output_dim * filter_size ** 2

        if he_init:
            filters_stdev = np.sqrt(4. / (fan_in + fan_out))
        else:  # Normalized init (Glorot & Bengio)
            filters_stdev = np.sqrt(2. / (fan_in + fan_out))

        # Transposed-conv filters are (h, w, output_dim, input_dim) — note the
        # swapped channel order relative to a forward convolution.
        if _weights_stdev is not None:
            filter_values = uniform(
                _weights_stdev,
                (filter_size, filter_size, output_dim, input_dim)
            )
        else:
            filter_values = uniform(
                filters_stdev,
                (filter_size, filter_size, output_dim, input_dim)
            )

        filter_values *= gain

        filters = lib.param(
            name + '.Filters',
            filter_values
        )

        if weightnorm == None:
            weightnorm = _default_weightnorm
        if weightnorm:
            # Norm over (h, w, input) axes gives one target norm per output
            # channel (axis 2 of the transposed-conv filter).
            norm_values = np.sqrt(
                np.sum(np.square(filter_values), axis=(0, 1, 3)))
            target_norms = lib.param(
                name + '.g',
                norm_values
            )
            with tf.name_scope('weightnorm') as scope:
                norms = tf.sqrt(tf.reduce_sum(tf.square(filters),
                                              reduction_indices=[0, 1, 3]))
                # expand_dims makes the per-channel scale (output_dim, 1) so it
                # broadcasts over the trailing input_dim axis of the filters.
                filters = filters * tf.expand_dims(target_norms / norms, 1)

        if data_format == 'NCHW':
            inputs = tf.transpose(inputs, [0, 2, 3, 1], name='NCHW_to_NHWC')

        # conv2d_transpose needs the output shape stated explicitly; this
        # requires a statically-known batch size (input_shape[0] not None).
        input_shape = inputs.get_shape().as_list()
        output_shape = [input_shape[0], 2 * input_shape[1], 2 * input_shape[2],
                        output_dim]

        result = tf.nn.conv2d_transpose(
            value=inputs,
            filter=filters,
            output_shape=output_shape,
            strides=[1, 2, 2, 1],
            padding='SAME'
        )

        if biases:
            _biases = lib.param(
                name + '.Biases',
                np.zeros(output_dim, dtype='float32')
            )
            result = tf.nn.bias_add(result, _biases)

        if data_format == 'NCHW':
            result = tf.transpose(result, [0, 3, 1, 2], name='NHWC_to_NCHW')

        return result
def Linear(
        name,
        input_dim,
        output_dim,
        inputs,
        biases=True,
        initialization=None,
        weightnorm=None,
        gain=1.
):
    """Fully-connected layer: inputs @ W (+ b), with selectable initialization.

    initialization: None, `lecun`, 'glorot', `he`, 'glorot_he', `orthogonal`,
        `("uniform", range)`

    Inputs with more than 2 dims are flattened to (-1, input_dim) before the
    matmul, so the result is 2-D regardless of the input rank.
    """
    with tf.name_scope(name) as scope:

        def uniform(stdev, size):
            # Module-level override (set_weights_stdev) takes priority.
            if _weights_stdev is not None:
                stdev = _weights_stdev
            return np.random.uniform(
                low=-stdev * np.sqrt(3),
                high=stdev * np.sqrt(3),
                size=size
            ).astype('float32')

        if initialization == 'lecun':  # and input_dim != output_dim):
            # disabling orth. init for now because it's too slow
            weight_values = uniform(
                np.sqrt(1. / input_dim),
                (input_dim, output_dim)
            )

        elif initialization == 'glorot' or (initialization == None):

            weight_values = uniform(
                np.sqrt(2. / (input_dim + output_dim)),
                (input_dim, output_dim)
            )

        elif initialization == 'he':

            weight_values = uniform(
                np.sqrt(2. / input_dim),
                (input_dim, output_dim)
            )

        elif initialization == 'glorot_he':

            weight_values = uniform(
                np.sqrt(4. / (input_dim + output_dim)),
                (input_dim, output_dim)
            )

        # NOTE(review): the `initialization == None and input_dim == output_dim`
        # half of this condition is unreachable — the glorot branch above
        # already consumes every `None`. Square layers therefore get glorot,
        # not orthogonal, init; preserved as-is since training results depend
        # on it.
        elif initialization == 'orthogonal' or \
                (initialization == None and input_dim == output_dim):

            # From lasagne
            def sample(shape):
                if len(shape) < 2:
                    raise RuntimeError("Only shapes of length 2 or more are "
                                       "supported.")
                flat_shape = (shape[0], np.prod(shape[1:]))
                # TODO: why normal and not uniform?
                a = np.random.normal(0.0, 1.0, flat_shape)
                u, _, v = np.linalg.svd(a, full_matrices=False)
                # pick the one with the correct shape
                q = u if u.shape == flat_shape else v
                q = q.reshape(shape)
                return q.astype('float32')
            weight_values = sample((input_dim, output_dim))

        elif initialization[0] == 'uniform':

            weight_values = np.random.uniform(
                low=-initialization[1],
                high=initialization[1],
                size=(input_dim, output_dim)
            ).astype('float32')

        else:

            raise Exception('Invalid initialization!')

        weight_values *= gain

        weight = lib.param(
            name + '.W',
            weight_values
        )

        if weightnorm == None:
            weightnorm = _default_weightnorm
        if weightnorm:
            # Initial per-column norms become the trainable norm targets g.
            norm_values = np.sqrt(np.sum(np.square(weight_values), axis=0))
            # norm_values = np.linalg.norm(weight_values, axis=0)

            target_norms = lib.param(
                name + '.g',
                norm_values
            )

            with tf.name_scope('weightnorm') as scope:
                norms = tf.sqrt(tf.reduce_sum(tf.square(weight), reduction_indices=[0]))
                weight = weight * (target_norms / norms)

        if inputs.get_shape().ndims == 2:
            result = tf.matmul(inputs, weight)
        else:
            # Collapse all leading dims so higher-rank inputs still work.
            reshaped_inputs = tf.reshape(inputs, [-1, input_dim])
            result = tf.matmul(reshaped_inputs, weight)

        if biases:
            result = tf.nn.bias_add(
                result,
                lib.param(
                    name + '.b',
                    np.zeros((output_dim,), dtype='float32')
                )
            )

        return result
_since_beginning = collections.defaultdict(lambda: {}) 13 | _since_last_flush = collections.defaultdict(lambda: {}) 14 | 15 | _iter = [0] 16 | 17 | import os 18 | 19 | def tick(): 20 | _iter[0] += 1 21 | 22 | 23 | def plot(name, value): 24 | _since_last_flush[name][_iter[0]] = value 25 | 26 | 27 | def flush(): 28 | prints = [] 29 | 30 | for name, vals in _since_last_flush.items(): 31 | prints.append("{}\t{}".format(name, np.mean(vals.values()))) 32 | _since_beginning[name].update(vals) 33 | 34 | x_vals = np.sort(_since_beginning[name].keys()) 35 | y_vals = [_since_beginning[name][x] for x in x_vals] 36 | 37 | plt.clf() 38 | plt.plot(x_vals, y_vals) 39 | plt.xlabel('iteration') 40 | plt.ylabel(name) 41 | fpath=os.path.join('debug',name.replace(' ', '_') + '.jpg') 42 | base_dir = os.path.dirname(fpath) 43 | ensure_dir(base_dir) 44 | plt.savefig(fpath) 45 | 46 | print "iter {}\t{}".format(_iter[0], "\t".join(prints)) 47 | _since_last_flush.clear() 48 | 49 | with open('debug/log.pkl', 'wb') as f: 50 | pickle.dump(dict(_since_beginning), f, pickle.HIGHEST_PROTOCOL) 51 | -------------------------------------------------------------------------------- /tflib/save_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image grid saver, based on color_grid_vis from github.com/Newmu 3 | """ 4 | 5 | import numpy as np 6 | from scipy.misc import imsave 7 | 8 | def save_images(X, save_path): 9 | # [0, 1] -> [0,255] 10 | if isinstance(X.flatten()[0], np.floating): 11 | X = (255.99*X).astype('uint8') 12 | 13 | n_samples = X.shape[0] 14 | rows = int(np.sqrt(n_samples)) 15 | while n_samples % rows != 0: 16 | rows -= 1 17 | 18 | nh, nw = rows, n_samples/rows 19 | 20 | if X.ndim == 2: 21 | X = np.reshape(X, (X.shape[0], int(np.sqrt(X.shape[1])), int(np.sqrt(X.shape[1])))) 22 | 23 | if X.ndim == 4: 24 | # BCHW -> BHWC 25 | if X.shape[1] == 3: 26 | X = X.transpose(0,2,3,1) 27 | 28 | h, w = X[0].shape[:2] 29 | img = np.zeros((h*nh, w*nw, 
3)) 30 | elif X.ndim == 3: 31 | h, w = X[0].shape[:2] 32 | img = np.zeros((h*nh, w*nw)) 33 | 34 | for n, x in enumerate(X): 35 | j = n/nw 36 | i = n%nw 37 | img[j*h:j*h+h, i*w:i*w+w] = x 38 | 39 | imsave(save_path, img) -------------------------------------------------------------------------------- /tflib/small_imagenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import time 4 | 5 | def make_generator(path, n_files, batch_size): 6 | epoch_count = [1] 7 | def get_epoch(): 8 | images = np.zeros((batch_size, 3, 64, 64), dtype='int32') 9 | files = range(n_files) 10 | random_state = np.random.RandomState(epoch_count[0]) 11 | random_state.shuffle(files) 12 | epoch_count[0] += 1 13 | for n, i in enumerate(files): 14 | image = scipy.misc.imread("{}/{}.png".format(path, str(i+1).zfill(len(str(n_files))))) 15 | images[n % batch_size] = image.transpose(2,0,1) 16 | if n > 0 and n % batch_size == 0: 17 | yield (images,) 18 | return get_epoch 19 | 20 | def load(batch_size, data_dir='/home/ishaan/data/imagenet64'): 21 | return ( 22 | make_generator(data_dir+'/train_64x64', 1281149, batch_size), 23 | make_generator(data_dir+'/valid_64x64', 49999, batch_size) 24 | ) 25 | 26 | if __name__ == '__main__': 27 | train_gen, valid_gen = load(64) 28 | t0 = time.time() 29 | for i, batch in enumerate(train_gen(), start=1): 30 | print "{}\t{}".format(str(time.time() - t0), batch[0][0,0,0,0]) 31 | if i == 1000: 32 | break 33 | t0 = time.time() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Defense-GAN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | 16 | """The main class for training GANs.""" 17 | 18 | import argparse 19 | import sys 20 | 21 | import tensorflow as tf 22 | 23 | from models.gan import MnistDefenseGAN, FmnistDefenseDefenseGAN, \ 24 | CelebADefenseGAN 25 | from utils.config import load_config 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser() 30 | 31 | parser.add_argument('--cfg', required=True, help='Config file') 32 | 33 | if len(sys.argv) == 1: 34 | parser.print_help() 35 | sys.exit(1) 36 | args, _ = parser.parse_known_args() 37 | return args 38 | 39 | 40 | def main(cfg, *args): 41 | FLAGS = tf.app.flags.FLAGS 42 | ds_gan = { 43 | 'mnist': MnistDefenseGAN, 'f-mnist': FmnistDefenseDefenseGAN, 44 | 'celeba': CelebADefenseGAN, 45 | } 46 | GAN = ds_gan[FLAGS.dataset_name] 47 | 48 | gan = GAN(cfg=cfg, test_mode=not FLAGS.is_train) 49 | 50 | if FLAGS.is_train: 51 | gan.train() 52 | 53 | if FLAGS.train_encoder: 54 | gan.load(checkpoint_dir=FLAGS.init_path) 55 | gan.train(phase='just_enc') 56 | 57 | if FLAGS.save_recs: 58 | gan.reconstruct_dataset(ckpt_path=FLAGS.init_path, 59 | max_num=FLAGS.max_num) 60 | 61 | if FLAGS.test_generator: 62 | gan.load_generator(ckpt_path=FLAGS.init_path) 63 | gan.sess.run(gan.global_step.initializer) 64 | gan.generate_image(iteration=0) 65 | 66 | if FLAGS.test_batch: 67 | gan.test_batch() 68 | 69 | if FLAGS.save_ds: 70 | gan.save_ds() 71 | 72 | 73 | if __name__ == '__main__': 74 | args = parse_args() 75 | 76 | # Note: The 
load_config() call will convert all the parameters that are defined in 77 | # experiments/config files into FLAGS.param_name and can be passed in from command line. 78 | # arguments : python train.py --cfg -- 79 | cfg = load_config(args.cfg) 80 | flags = tf.app.flags 81 | 82 | flags.DEFINE_boolean("is_train", False, 83 | "True for training, False for testing. [False]") 84 | flags.DEFINE_boolean("save_recs", False, 85 | "True for saving reconstructions. [False]") 86 | flags.DEFINE_boolean("debug", False, 87 | "True for debug. [False]") 88 | flags.DEFINE_boolean("test_generator", False, 89 | "True for generator samples. [False]") 90 | flags.DEFINE_boolean("test_decoder", False, 91 | "True for decoder samples. [False]") 92 | flags.DEFINE_boolean("test_again", False, 93 | "True for not using cache. [False]") 94 | flags.DEFINE_boolean("test_batch", False, 95 | "True for visualizing the batches and labels. [False]") 96 | flags.DEFINE_boolean("save_ds", False, 97 | "True for saving the dataset in a pickle file. [" 98 | "False]") 99 | flags.DEFINE_boolean("tensorboard_log", True, "True for saving " 100 | "tensorboard logs. [True]") 101 | flags.DEFINE_boolean("train_encoder", False, 102 | "Add an encoder to a pretrained model. [" 103 | "False]") 104 | flags.DEFINE_boolean("init_with_enc", False, 105 | "Initializes the z with an encoder, must run " 106 | "--train_encoder first. [False]") 107 | flags.DEFINE_integer("max_num", -1, 108 | "True for saving the dataset in a pickle file [" 109 | "False]") 110 | flags.DEFINE_string("init_path", None, "Checkpoint path. 
[None]") 111 | 112 | main_cfg = lambda x: main(cfg, x) 113 | tf.app.run(main=main_cfg) 114 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kabkabm/defensegan/7e3feaebf7b9bbf08b1364e400119ef596cd78fd/utils/__init__.py -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Defense-GAN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def load_config(cfg_path, set_flag=False, verbose=False):
    """Loads the configuration files into the global flags.

    Args:
        cfg_path: The path to the config yaml file.
        set_flag: If True, does not create new flag attributes, only sets
            existing ones.
        verbose: Verbose mode.

    Returns:
        The loaded configuration dictionary.

    Raises:
        RuntimeError: If the configuration path does not exist.
    """
    flags = tf.app.flags.FLAGS

    if not os.path.exists(cfg_path):
        raise RuntimeError(
            "[!] Configuration path {} does not exist.".format(cfg_path))

    # A directory is expected to hold a self-contained cfg.yml; a file is
    # overlaid on top of the default.yml living in the same directory.
    # NOTE(review): yaml.load on untrusted input is unsafe; prefer
    # yaml.safe_load if configs can come from outside the repo.
    if os.path.isdir(cfg_path):
        cfg_path = os.path.join(cfg_path, 'cfg.yml')
        with open(cfg_path, 'r') as f:
            cfg = yaml.load(f)
    else:
        with open(cfg_path, 'r') as f:
            loaded_cfg = yaml.load(f)
        base_dir = os.path.dirname(cfg_path)
        with open(os.path.join(base_dir, 'default.yml'), 'r') as f:
            cfg = yaml.load(f)

        cfg.update(loaded_cfg)

    with open(os.path.join('experiments/cfgs', 'key_doc.yml')) as f:
        docs = yaml.load(f)

    tf.app.flags.DEFINE_string('cfg_path', cfg_path, 'config path.')
    for (k, v) in cfg.items():
        if set_flag:
            setattr(flags, k.lower(), v)
        else:
            if hasattr(flags, k.lower()):
                setattr(flags, k.lower(), v)
            else:
                def_func = type_to_define_fn[type(v)]

                try:
                    doc = docs[k]
                except KeyError:
                    # Bug fix: the original built this warning string as a
                    # bare expression and never printed it, so missing doc
                    # keys went completely unreported.
                    print('Doc for the key {} is not found in '
                          'experiments/cfgs/key_doc.yml'.format(k))
                    doc = 'No doc'
                def_func(k.lower(), v, doc)
        if verbose:
            # Bug fix: the original printed v['val'], which raises TypeError
            # for scalar config values — v itself is the configured value.
            print('[#] set {} to {} type: {}'.format(k.lower(), v,
                                                     str(type(v))))
    cfg['cfg_path'] = cfg_path
    return cfg
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Dummy summary writer for not saving the tensorboard log.""" 17 | 18 | class DummySummaryWriter(object): 19 | def write(self, *args, **arg_dicts): 20 | pass 21 | 22 | def add_summary(self, summary_str, counter): 23 | pass -------------------------------------------------------------------------------- /utils/gan_defense.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Defense-GAN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def model_eval_gan(
        sess,
        images,
        labels,
        predictions=None,
        predictions_rec=None,
        test_images=None,
        test_labels=None,
        feed=None,
        args=None,
        model=None,
        diff_op=None,
):
    """Computes the accuracy of a model on test data as well as the
    reconstruction errors for attack detection.

    Args:
        sess: TF session to use when training the graph.
        images: input placeholder.
        labels: output placeholder (for labels).
        predictions: model output predictions.
        predictions_rec: model output prediction for reconstructions.
        test_images: numpy array with training inputs
        test_labels: numpy array with training outputs
        feed: An optional dictionary that is appended to the feeding
            dictionary before the session runs. Can be used to feed
            the learning phase of a Keras model for instance.
        args: dict or argparse `Namespace` object.
            Should contain `batch_size`
        model: (deprecated) if not None, holds model output predictions.
        diff_op: The operation that calculates the difference between input
            and attack.

    Returns:
        accuracy: The accuracy on the test data.
        accuracy_rec: The accuracy on the reconstructed test data (if
            predictions_rec is provided)
        roc_info: The differences between input and reconstruction for
            attack detection.
    """
    args = _ArgsWrapper(args or {})

    assert args.batch_size, "Batch size was not given in args dict"
    if test_images is None or test_labels is None:
        raise ValueError("X_test argument and Y_test argument "
                         "must be supplied.")
    if model is None and predictions is None:
        raise ValueError("One of model argument "
                         "or predictions argument must be supplied.")
    if model is not None:
        warnings.warn("model argument is deprecated. "
                      "Switch to predictions argument. "
                      "model argument will be removed after 2018-01-05.")
        if predictions is None:
            predictions = model
        else:
            raise ValueError("Exactly one of model argument"
                             " and predictions argument should be specified.")

    # Define accuracy symbolically.
    correct_preds = tf.equal(tf.argmax(labels, axis=-1),
                             tf.argmax(predictions, axis=-1))

    if predictions_rec is not None:
        correct_preds_rec = tf.equal(tf.argmax(labels, axis=-1),
                                     tf.argmax(predictions_rec, axis=-1))
        acc_value_rec = tf.reduce_sum(tf.to_float(correct_preds_rec))

    accuracy_rec = 0.0
    # NOTE(review): the trailing comma makes cur_labels a 1-tuple of a tensor.
    # sess.run propagates the structure, so outs[1] below is a 1-tuple and is
    # unwrapped with outs[1][0]. Fragile but internally consistent — do not
    # remove the comma without also changing the unwrap.
    cur_labels = tf.argmax(labels, axis=-1),
    cur_preds = tf.argmax(predictions, axis=-1)

    acc_value = tf.reduce_sum(tf.to_float(correct_preds))

    diffs = []
    all_labels = []
    preds = []

    accuracy = 0.0

    # Compute number of batches.
    nb_batches = int(math.ceil(float(len(test_images)) / args.batch_size))
    assert nb_batches * args.batch_size >= len(test_images)

    for batch in range(nb_batches):
        # To initialize the variables of Defense-GAN at test time.
        sess.run(tf.local_variables_initializer())
        print("[#] Eval batch {}/{}".format(batch, nb_batches))

        # Must not use the `batch_indices` function here, because it
        # repeats some examples.
        # It's acceptable to repeat during training, but not eval.
        start = batch * args.batch_size
        end = min(len(test_images), start + args.batch_size)
        cur_batch_size = end - start

        # The last batch may be smaller than all others, so we need to
        # account for variable batch size here.
        feed_dict = {images: test_images[start:end], labels: test_labels[start:end]}
        if feed is not None:
            feed_dict.update(feed)

        # Fetch list layout: [acc, labels-tuple, preds, (diff_op?), (acc_rec?)]
        # — the index arithmetic below depends on this exact append order.
        run_list = [acc_value, cur_labels, cur_preds]

        if diff_op is not None:
            run_list += [diff_op]

        if predictions_rec is not None:
            run_list += [acc_value_rec]
            acc_val_ind = len(run_list) - 1;

        outs = sess.run(run_list, feed_dict=feed_dict)
        cur_acc = outs[0]

        if diff_op is not None:
            # diff_op, when present, is always appended right after the three
            # base fetches, hence the fixed index 3.
            cur_diffs_val = outs[3]
            diffs.append(cur_diffs_val)

        if predictions_rec is not None:
            cur_acc_rec = outs[acc_val_ind]
            accuracy_rec += cur_acc_rec

        # outs[1] is the 1-tuple produced by cur_labels above; unwrap it.
        cur_labels_val = outs[1][0]
        cur_preds_val = outs[2]
        all_labels.append(cur_labels_val)
        preds.append(cur_preds_val)

        accuracy += cur_acc

    assert end >= len(test_images)

    # Divide by number of examples to get final value.
    accuracy /= len(test_images)
    # Harmless no-op (0.0 / n) when predictions_rec is None.
    accuracy_rec /= len(test_images)
    preds = np.concatenate(preds)
    all_labels = np.concatenate(all_labels)

    if diff_op is not None:
        diffs = np.concatenate(diffs)

    # diffs stays an empty list when diff_op is None.
    roc_info = [all_labels, preds, diffs]
    if predictions_rec is not None:
        return accuracy, accuracy_rec, roc_info
    else:
        return accuracy, roc_info
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Defense-GAN utility functions.""" 17 | 18 | import os 19 | 20 | 21 | def static_vars(**kwargs): 22 | def decorate(func): 23 | for k in kwargs: 24 | setattr(func, k, kwargs[k]) 25 | return func 26 | 27 | return decorate 28 | 29 | 30 | def make_dir(dir_path): 31 | if not os.path.exists(dir_path): 32 | os.makedirs(dir_path) 33 | print('[+] Created the directory: {}'.format(dir_path)) 34 | 35 | 36 | ensure_dir = make_dir 37 | -------------------------------------------------------------------------------- /utils/network_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Defense-GAN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class Model(object):
    """
    An abstract interface for model wrappers that exposes model symbols
    needed for making an attack. This abstraction removes the dependency on
    any specific neural network package (e.g. Keras) from the core
    code of CleverHans. It can also simplify exposing the hidden features of a
    model when a specific package does not directly expose them.
    """
    # Python 2 style abstract base class declaration.
    __metaclass__ = ABCMeta

    def __init__(self):
        pass

    def __call__(self, *args, **kwargs):
        """
        For compatibility with functions used as model definitions (taking
        an input tensor and returning the tensor giving the output
        of the model on that input).
        """
        return self.get_probs(*args, **kwargs)

    def get_layer(self, x, layer):
        """
        Expose the hidden features of a model given a layer name.
        :param x: A symbolic representation of the network input
        :param layer: The name of the hidden layer to return features at.
        :return: A symbolic representation of the hidden features
        :raise: NoSuchLayerError if `layer` is not in the model.
        """
        # Return the symbolic representation for this layer.
        # Note: fprop rebuilds the whole forward graph on every call.
        output = self.fprop(x)
        try:
            requested = output[layer]
        except KeyError:
            raise NoSuchLayerError()
        return requested

    def get_logits(self, x):
        """
        :param x: A symbolic representation of the network input
        :return: A symbolic representation of the output logits (i.e., the
            values fed as inputs to the softmax layer).
        """
        return self.get_layer(x, 'logits')

    def get_probs(self, x):
        """
        :param x: A symbolic representation of the network input
        :return: A symbolic representation of the output probabilities (i.e.,
            the output values produced by the softmax layer).
        """
        # Fall back to softmax(logits) for models without an explicit
        # 'probs' layer.
        try:
            return self.get_layer(x, 'probs')
        except NoSuchLayerError:
            import tensorflow as tf
            return tf.nn.softmax(self.get_logits(x))

    def get_layer_names(self):
        """
        :return: a list of names for the layers that can be exposed by this
            model abstraction.
        """
        if hasattr(self, 'layer_names'):
            return self.layer_names

        raise NotImplementedError('`get_layer_names` not implemented.')

    def fprop(self, x):
        """
        Exposes all the layers of the model returned by get_layer_names.
        :param x: A symbolic representation of the network input
        :return: A dictionary mapping layer names to the symbolic
            representation of their output.
        """
        raise NotImplementedError('`fprop` not implemented.')
123 | """ 124 | 125 | self.output_layer = output_layer 126 | self.callable_fn = callable_fn 127 | 128 | def get_layer_names(self): 129 | return [self.output_layer] 130 | 131 | def fprop(self, x): 132 | return {self.output_layer: self.callable_fn(x)} 133 | 134 | 135 | class NoSuchLayerError(ValueError): 136 | """Raised when a layer that does not exist is requested.""" 137 | 138 | 139 | class MLP(Model): 140 | """ 141 | An example of a bare bones multilayer perceptron (MLP) class. 142 | """ 143 | 144 | def __init__(self, layers, input_shape, rec_model=None): 145 | super(MLP, self).__init__() 146 | self.layer_names = [] 147 | self.layers = layers 148 | self.input_shape = input_shape 149 | if isinstance(layers[-1], Softmax): 150 | layers[-1].name = 'probs' 151 | layers[-2].name = 'logits' 152 | else: 153 | layers[-1].name = 'logits' 154 | for i, layer in enumerate(self.layers): 155 | if hasattr(layer, 'name'): 156 | name = layer.name 157 | else: 158 | name = layer.__class__.__name__ + str(i) 159 | self.layer_names.append(name) 160 | 161 | layer.set_input_shape(input_shape) 162 | input_shape = layer.get_output_shape() 163 | 164 | def fprop(self, x, set_ref=False, no_rec=False): 165 | states = [] 166 | start = 0 167 | if no_rec: 168 | start = 1 169 | 170 | for layer in self.layers[start:]: 171 | if set_ref: 172 | layer.ref = x 173 | x = layer.fprop(x) 174 | assert x is not None 175 | states.append(x) 176 | states = dict(zip(self.get_layer_names(), states)) 177 | return states 178 | 179 | def add_rec_model(self, model, z_init, batch_size): 180 | rec_layer = ReconstructionLayer(model, z_init, self.input_shape, batch_size) 181 | rec_layer.set_input_shape(self.input_shape) 182 | self.layers = [rec_layer] + self.layers 183 | self.layer_names = ['reconstruction'] + self.layer_names 184 | 185 | 186 | class Layer(object): 187 | def get_output_shape(self): 188 | return self.output_shape 189 | 190 | 191 | class Linear(Layer): 192 | def __init__(self, num_hid): 193 | self.num_hid = 
num_hid 194 | 195 | def set_input_shape(self, input_shape): 196 | batch_size, dim = input_shape 197 | self.input_shape = [batch_size, dim] 198 | self.output_shape = [batch_size, self.num_hid] 199 | init = tf.random_normal([dim, self.num_hid], dtype=tf.float32) 200 | init = init / tf.sqrt(1e-7 + tf.reduce_sum(tf.square(init), axis=0, 201 | keep_dims=True)) 202 | self.W = tf.Variable(init) 203 | self.b = tf.Variable(np.zeros((self.num_hid,)).astype('float32')) 204 | 205 | def fprop(self, x): 206 | return tf.matmul(x, self.W) + self.b 207 | 208 | 209 | class Conv2D(Layer): 210 | def __init__(self, output_channels, kernel_shape, strides, padding): 211 | self.__dict__.update(locals()) 212 | del self.self 213 | 214 | def set_input_shape(self, input_shape): 215 | batch_size, rows, cols, input_channels = input_shape 216 | kernel_shape = tuple(self.kernel_shape) + (input_channels, 217 | self.output_channels) 218 | assert len(kernel_shape) == 4 219 | assert all(isinstance(e, int) for e in kernel_shape), kernel_shape 220 | init = tf.random_normal(kernel_shape, dtype=tf.float32) 221 | init = init / tf.sqrt(1e-7 + tf.reduce_sum(tf.square(init), 222 | axis=(0, 1, 2))) 223 | self.kernels = tf.Variable(init) 224 | self.b = tf.Variable( 225 | np.zeros((self.output_channels,)).astype('float32')) 226 | input_shape = list(input_shape) 227 | input_shape[0] = 1 228 | dummy_batch = tf.zeros(input_shape) 229 | dummy_output = self.fprop(dummy_batch) 230 | output_shape = [int(e) for e in dummy_output.get_shape()] 231 | output_shape[0] = 1 232 | self.output_shape = tuple(output_shape) 233 | 234 | def fprop(self, x): 235 | return tf.nn.conv2d(x, self.kernels, (1,) + tuple(self.strides) + (1,), 236 | self.padding) + self.b 237 | 238 | 239 | class ReconstructionLayer(Layer): 240 | """This layer is used as a wrapper for Defense-GAN's reconstruction 241 | part. 242 | """ 243 | 244 | def __init__(self, model, z_init, input_shape, batch_size): 245 | """Constructor of the layer. 
246 | 247 | Args: 248 | model: `Callable`. The generator model that gets an input and 249 | reconstructs it. `def gen(Tensor) -> Tensor.` 250 | z_init: `tf.Tensor'. 251 | input_shape: `List[int]`. 252 | batch_size: int. 253 | """ 254 | self.z_init = z_init 255 | self.rec_model = model 256 | self.input_shape = input_shape 257 | self.batch_size = batch_size 258 | 259 | def set_input_shape(self, shape): 260 | self.input_shape = shape 261 | self.output_shape = shape 262 | 263 | def get_output_shape(self): 264 | return self.output_shape 265 | 266 | def fprop(self, x): 267 | x.set_shape(self.input_shape) 268 | self.rec = self.rec_model.reconstruct( 269 | x, batch_size=self.batch_size, back_prop=True, z_init_val=self.z_init, 270 | reconstructor_id=123) 271 | return self.rec 272 | 273 | 274 | class ReLU(Layer): 275 | def __init__(self): 276 | pass 277 | 278 | def set_input_shape(self, shape): 279 | self.input_shape = shape 280 | self.output_shape = shape 281 | 282 | def get_output_shape(self): 283 | return self.output_shape 284 | 285 | def fprop(self, x): 286 | return tf.nn.relu(x) 287 | 288 | 289 | class Dropout(Layer): 290 | def __init__(self, prob): 291 | self.prob = prob 292 | pass 293 | 294 | def set_input_shape(self, shape): 295 | self.input_shape = shape 296 | self.output_shape = shape 297 | 298 | def get_output_shape(self): 299 | return self.output_shape 300 | 301 | def fprop(self, x): 302 | return tf.cond(K.learning_phase(), lambda: tf.nn.dropout(x, self.prob), lambda: x) 303 | 304 | 305 | class Softmax(Layer): 306 | def __init__(self): 307 | pass 308 | 309 | def set_input_shape(self, shape): 310 | self.input_shape = shape 311 | self.output_shape = shape 312 | 313 | def fprop(self, x): 314 | return tf.nn.softmax(x) 315 | 316 | 317 | class Flatten(Layer): 318 | def __init__(self): 319 | pass 320 | 321 | def set_input_shape(self, shape): 322 | self.input_shape = shape 323 | output_width = 1 324 | for factor in shape[1:]: 325 | output_width *= factor 326 | 
self.output_width = output_width 327 | self.output_shape = [None, output_width] 328 | 329 | def fprop(self, x): 330 | return tf.reshape(x, [-1, self.output_width]) 331 | 332 | 333 | def model_f(nb_filters=64, nb_classes=10, 334 | input_shape=(None, 28, 28, 1), rec_model=None): 335 | layers = [Conv2D(nb_filters, (8, 8), (2, 2), "SAME"), 336 | ReLU(), 337 | Conv2D(nb_filters * 2, (6, 6), (2, 2), "VALID"), 338 | ReLU(), 339 | Conv2D(nb_filters * 2, (5, 5), (1, 1), "VALID"), 340 | ReLU(), 341 | Flatten(), 342 | Linear(nb_classes), 343 | Softmax()] 344 | 345 | model = MLP(layers, input_shape, rec_model=rec_model) 346 | return model 347 | 348 | 349 | def model_e(input_shape=(None, 28, 28, 1), nb_classes=10): 350 | """ 351 | Defines the model architecture to be used by the substitute. Use 352 | the example model interface. 353 | :param img_rows: number of rows in input 354 | :param img_cols: number of columns in input 355 | :param nb_classes: number of classes in output 356 | :return: tensorflow model 357 | """ 358 | 359 | # Define a fully connected model (it's different than the black-box). 360 | layers = [Flatten(), 361 | Linear(200), 362 | ReLU(), 363 | Linear(200), 364 | ReLU(), 365 | Linear(nb_classes), 366 | Softmax()] 367 | 368 | return MLP(layers, input_shape) 369 | 370 | 371 | def model_d(input_shape=(None, 28, 28, 1), nb_classes=10): 372 | """ 373 | Defines the model architecture to be used by the substitute. Use 374 | the example model interface. 
375 | :param img_rows: number of rows in input 376 | :param img_cols: number of columns in input 377 | :param nb_classes: number of classes in output 378 | :return: tensorflow model 379 | """ 380 | 381 | # Define a fully connected model (it's different than the black-box) 382 | layers = [Flatten(), 383 | Linear(200), 384 | ReLU(), 385 | Dropout(0.5), 386 | Linear(200), 387 | ReLU(), 388 | Linear(nb_classes), 389 | Softmax()] 390 | 391 | return MLP(layers, input_shape) 392 | 393 | 394 | def model_b(nb_filters=64, nb_classes=10, 395 | input_shape=(None, 28, 28, 1), rec_model=None): 396 | layers = [Dropout(0.2), 397 | Conv2D(nb_filters, (8, 8), (2, 2), "SAME"), 398 | ReLU(), 399 | Conv2D(nb_filters * 2, (6, 6), (2, 2), "VALID"), 400 | ReLU(), 401 | Conv2D(nb_filters * 2, (5, 5), (1, 1), "VALID"), 402 | ReLU(), 403 | Dropout(0.5), 404 | Flatten(), 405 | Linear(nb_classes), 406 | Softmax()] 407 | 408 | model = MLP(layers, input_shape, rec_model=rec_model) 409 | return model 410 | 411 | 412 | def model_a(nb_filters=64, nb_classes=10, 413 | input_shape=(None, 28, 28, 1), rec_model=None): 414 | layers = [Conv2D(nb_filters, (5, 5), (1, 1), "SAME"), 415 | ReLU(), 416 | Conv2D(nb_filters, (5, 5), (2, 2), "VALID"), 417 | ReLU(), 418 | Flatten(), 419 | Dropout(0.25), 420 | Linear(128), 421 | ReLU(), 422 | Dropout(0.5), 423 | Linear(nb_classes), 424 | Softmax()] 425 | 426 | model = MLP(layers, input_shape, rec_model=rec_model) 427 | return model 428 | 429 | 430 | def model_c(nb_filters=64, nb_classes=10, 431 | input_shape=(None, 28, 28, 1), rec_model=None): 432 | layers = [Conv2D(nb_filters * 2, (3, 3), (1, 1), "SAME"), 433 | ReLU(), 434 | Conv2D(nb_filters, (5, 5), (2, 2), "VALID"), 435 | ReLU(), 436 | Flatten(), 437 | Dropout(0.25), 438 | Linear(128), 439 | ReLU(), 440 | Dropout(0.5), 441 | Linear(nb_classes), 442 | Softmax()] 443 | 444 | model = MLP(layers, input_shape, rec_model=rec_model) 445 | return model 446 | 447 | 448 | def model_y(nb_filters=64, nb_classes=10, 449 | 
input_shape=(None, 28, 28, 1), rec_model=None): 450 | layers = [Conv2D(nb_filters, (3, 3), (1, 1), "SAME"), 451 | ReLU(), 452 | Conv2D(nb_filters, (3, 3), (2, 2), "VALID"), 453 | ReLU(), 454 | Conv2D(2 * nb_filters, (3, 3), (2, 2), "VALID"), 455 | ReLU(), 456 | Conv2D(2 * nb_filters, (3, 3), (2, 2), "VALID"), 457 | ReLU(), 458 | Flatten(), 459 | Linear(256), 460 | ReLU(), 461 | Dropout(0.5), 462 | Linear(256), 463 | ReLU(), 464 | Dropout(0.5), 465 | Linear(nb_classes), 466 | Softmax()] 467 | 468 | model = MLP(layers, input_shape, rec_model=rec_model) 469 | return model 470 | 471 | 472 | def model_q(nb_filters=32, nb_classes=10, 473 | input_shape=(None, 28, 28, 1), rec_model=None): 474 | layers = [Conv2D(nb_filters, (3, 3), (1, 1), "SAME"), 475 | ReLU(), 476 | Conv2D(nb_filters, (3, 3), (2, 2), "VALID"), 477 | ReLU(), 478 | Conv2D(2 * nb_filters, (3, 3), (1, 1), "VALID"), 479 | ReLU(), 480 | Conv2D(2 * nb_filters, (3, 3), (2, 2), "VALID"), 481 | ReLU(), 482 | Flatten(), 483 | Linear(256), 484 | ReLU(), 485 | Dropout(0.5), 486 | Linear(256), 487 | ReLU(), 488 | Dropout(0.5), 489 | Linear(nb_classes), 490 | Softmax()] 491 | 492 | model = MLP(layers, input_shape, rec_model=rec_model) 493 | return model 494 | 495 | 496 | def model_z(nb_filters=32, nb_classes=10, 497 | input_shape=(None, 28, 28, 1), rec_model=None): 498 | layers = [Conv2D(nb_filters, (3, 3), (1, 1), "SAME"), 499 | ReLU(), 500 | Conv2D(nb_filters, (3, 3), (2, 2), "VALID"), 501 | ReLU(), 502 | Conv2D(2 * nb_filters, (3, 3), (1, 1), "VALID"), 503 | ReLU(), 504 | Conv2D(2 * nb_filters, (3, 3), (2, 2), "VALID"), 505 | ReLU(), 506 | Conv2D(4 * nb_filters, (3, 3), (1, 1), "VALID"), 507 | ReLU(), 508 | Conv2D(4 * nb_filters, (3, 3), (2, 2), "VALID"), 509 | ReLU(), 510 | Flatten(), 511 | Linear(600), 512 | ReLU(), 513 | Dropout(0.5), 514 | Linear(600), 515 | ReLU(), 516 | Dropout(0.5), 517 | Linear(nb_classes), 518 | Softmax()] 519 | 520 | model = MLP(layers, input_shape, rec_model=rec_model) 521 | return model 
# ==============================================================================

"""Visualization utilities."""

import gc
import os

import numpy as np
import scipy.misc

from utils.misc import static_vars, make_dir


@static_vars(plt_counter=0)
def save_plot(plt, fname=None, save_dir='debug/plots/'):
    """Save the current matplotlib figure to `save_dir` and close it.

    Args:
        plt: The matplotlib.pyplot module (or a compatible object).
        fname: Output file name. Auto-numbered 'plot_<n>.png' when None.
        save_dir: Directory the plot is written to (created if missing).
    """
    plt.tight_layout()
    plt.draw()
    if fname is None:
        fname = 'plot_{}.format'.replace('format', 'png').format(
            save_plot.plt_counter) if False else \
            'plot_{}.png'.format(save_plot.plt_counter)
        save_plot.plt_counter = save_plot.plt_counter + 1

    make_dir(save_dir)
    # Substring check (not endswith) kept to preserve original behavior.
    if 'png' not in fname and 'pdf' not in fname:
        fname = fname + '.png'

    save_path = os.path.join(save_dir, fname)
    plt.savefig(save_path)
    print('[-] Saved plot to {}'.format(save_path))
    plt.clf()
    # Close and collect to keep long experiment loops from leaking figures.
    plt.close()
    gc.collect()


def save_images_files(images, prefix='im', labels=None, output_dir=None,
                      postfix=''):
    """Save a 4D batch of images as individually numbered PNG files.

    Args:
        images: 4D numpy array [batch, height, width, channels].
        prefix: File-name prefix; auto pattern used when None.
        labels: Optional per-image labels appended to each file name.
        output_dir: Directory passed through to save_image().
        postfix: Extra string inserted before the '.png' extension.
    """
    if prefix is None and labels is None:
        # BUG FIX: the base pattern must not already carry '.png' --
        # '.png' (and `postfix`) are appended below, which previously
        # produced names like '0_image.png.png'.
        prefix = '{}_image'
    else:
        prefix = prefix + '_{:03d}'
        if labels is not None:
            prefix = prefix + '_{:03d}'

    prefix = prefix + postfix + '.png'

    assert len(images.shape) == 4, 'images should be a 4D np array uint8'
    for i in range(images.shape[0]):
        image = images[i]
        if labels is None:
            save_image(image, fname=prefix.format(i), dir_path=output_dir)
        else:
            save_image(image, fname=prefix.format(i, int(labels[i])),
                       dir_path=output_dir)


@static_vars(image_counter=0)
def save_image(image, fname=None, dir_path='debug/images/'):
    """Save one image array under `dir_path`, auto-numbering when unnamed."""
    if fname is None:
        fname = 'image_{}.png'.format(save_image.image_counter)
        save_image.image_counter = save_image.image_counter + 1
    make_dir(dir_path)
    fpath = os.path.join(dir_path, fname)
    save_image_core(image, fpath)


def save_image_core(image, path):
    """Save an image as a png file.

    Channel-first (1xHxW or 3xHxW) inputs are transposed to channel-last,
    then the image is min-max normalized to [0, 255] uint8.
    """
    if image.shape[0] == 3 or image.shape[0] == 1:
        image = image.transpose([1, 2, 0])
    image = ((image.squeeze() * 1.0 - image.min()) / (
        image.max() - image.min() + 1e-7)) * 255
    image = image.astype(np.uint8)
    # NOTE(review): scipy.misc.imsave was removed in SciPy >= 1.2; this
    # codebase pins an old SciPy. If upgrading, switch to imageio.imwrite.
    scipy.misc.imsave(path, image)

    print('[#] saved image to: {}'.format(path))
14 | # ============================================================================== 15 | 16 | """Testing white-box attacks Defense-GAN models. This module is based on MNIST 17 | tutorial of cleverhans.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | from __future__ import unicode_literals 23 | 24 | import _init_paths 25 | 26 | import argparse 27 | import cPickle 28 | import logging 29 | import os 30 | import sys 31 | 32 | import keras.backend as K 33 | import numpy as np 34 | import tensorflow as tf 35 | 36 | from blackbox import dataset_gan_dict, get_cached_gan_data 37 | from cleverhans.attacks import CarliniWagnerL2 38 | from cleverhans.attacks import FastGradientMethod 39 | from cleverhans.utils import AccuracyReport 40 | from cleverhans.utils import set_log_level 41 | from cleverhans.utils_tf import model_train, model_eval 42 | from models.gan import MnistDefenseGAN, FmnistDefenseDefenseGAN, CelebADefenseGAN 43 | from utils.config import load_config 44 | from utils.gan_defense import model_eval_gan 45 | from utils.misc import ensure_dir 46 | from utils.network_builder import model_a, model_b, model_c, model_d, model_e, model_f 47 | 48 | ds_gan = { 49 | 'mnist': MnistDefenseGAN, 50 | 'f-mnist': FmnistDefenseDefenseGAN, 51 | 'celeba': CelebADefenseGAN, 52 | } 53 | orig_data_paths = {k: 'data/cache/{}_pkl'.format(k) for k in ds_gan.keys()} 54 | 55 | 56 | def whitebox(gan, rec_data_path=None, batch_size=128, learning_rate=0.001, 57 | nb_epochs=10, eps=0.3, online_training=False, 58 | test_on_dev=True, attack_type='fgsm', defense_type='gan', 59 | num_tests=-1, num_train=-1): 60 | """Based on MNIST tutorial from cleverhans. 61 | 62 | Args: 63 | gan: A `GAN` model. 64 | rec_data_path: A string to the directory. 65 | batch_size: The size of the batch. 66 | learning_rate: The learning rate for training the target models. 
67 | nb_epochs: Number of epochs for training the target model. 68 | eps: The epsilon of FGSM. 69 | online_training: Training Defense-GAN with online reconstruction. The 70 | faster but less accurate way is to reconstruct the dataset once and use 71 | it to train the target models with: 72 | `python train.py --cfg --save_recs` 73 | attack_type: Type of the white-box attack. It can be `fgsm`, 74 | `rand+fgsm`, or `cw`. 75 | defense_type: String representing the type of attack. Can be `none`, 76 | `defense_gan`, or `adv_tr`. 77 | """ 78 | 79 | FLAGS = tf.flags.FLAGS 80 | 81 | # Set logging level to see debug information. 82 | set_log_level(logging.WARNING) 83 | 84 | if defense_type == 'defense_gan': 85 | assert gan is not None 86 | 87 | # Create TF session. 88 | if defense_type == 'defense_gan': 89 | sess = gan.sess 90 | if FLAGS.train_on_recs: 91 | assert rec_data_path is not None or online_training 92 | else: 93 | config = tf.ConfigProto() 94 | config.gpu_options.allow_growth = True 95 | sess = tf.Session(config=config) 96 | 97 | train_images, train_labels, test_images, test_labels = \ 98 | get_cached_gan_data(gan, test_on_dev) 99 | 100 | rec_test_images = test_images 101 | rec_test_labels = test_labels 102 | 103 | _, _, test_images, test_labels = \ 104 | get_cached_gan_data(gan, test_on_dev, orig_data_flag=True) 105 | 106 | x_shape = [None] + list(train_images.shape[1:]) 107 | images_pl = tf.placeholder(tf.float32, shape=[None] + list(train_images.shape[1:])) 108 | labels_pl = tf.placeholder(tf.float32, shape=[None] + [train_labels.shape[1]]) 109 | 110 | if num_tests > 0: 111 | test_images = test_images[:num_tests] 112 | rec_test_images = rec_test_images[:num_tests] 113 | test_labels = test_labels[:num_tests] 114 | 115 | if num_train > 0: 116 | train_images = train_images[:num_train] 117 | train_labels = train_labels[:num_train] 118 | 119 | # GAN defense flag. 
120 | models = {'A': model_a, 'B': model_b, 'C': model_c, 'D': model_d, 'E': model_e, 'F': model_f} 121 | model = models[FLAGS.model](input_shape=x_shape, nb_classes=train_labels.shape[1]) 122 | 123 | preds = model.get_probs(images_pl) 124 | report = AccuracyReport() 125 | 126 | def evaluate(): 127 | # Evaluate the accuracy of the MNIST model on legitimate test 128 | # examples. 129 | eval_params = {'batch_size': batch_size} 130 | acc = model_eval( 131 | sess, images_pl, labels_pl, preds, rec_test_images, 132 | rec_test_labels, args=eval_params, 133 | feed={K.learning_phase(): 0}) 134 | report.clean_train_clean_eval = acc 135 | print('Test accuracy on legitimate examples: %0.4f' % acc) 136 | 137 | train_params = { 138 | 'nb_epochs': nb_epochs, 139 | 'batch_size': batch_size, 140 | 'learning_rate': learning_rate, 141 | } 142 | 143 | rng = np.random.RandomState([11, 24, 1990]) 144 | tf.set_random_seed(11241990) 145 | 146 | preds_adv = None 147 | if FLAGS.defense_type == 'adv_tr': 148 | attack_params = {'eps': FLAGS.fgsm_eps_tr, 149 | 'clip_min': 0., 150 | 'clip_max': 1.} 151 | if gan: 152 | if gan.dataset_name == 'celeba': 153 | attack_params['clip_min'] = -1.0 154 | 155 | attack_obj = FastGradientMethod(model, sess=sess) 156 | adv_x_tr = attack_obj.generate(images_pl, **attack_params) 157 | adv_x_tr = tf.stop_gradient(adv_x_tr) 158 | preds_adv = model(adv_x_tr) 159 | 160 | model_train(sess, images_pl, labels_pl, preds, train_images, train_labels, 161 | args=train_params, rng=rng, predictions_adv=preds_adv, 162 | init_all=False, feed={K.learning_phase(): 1}, 163 | evaluate=evaluate) 164 | 165 | # Calculate training error. 
166 | eval_params = {'batch_size': batch_size} 167 | acc = model_eval( 168 | sess, images_pl, labels_pl, preds, train_images, train_labels, 169 | args=eval_params, feed={K.learning_phase(): 0}, 170 | ) 171 | print('[#] Accuracy on clean examples {}'.format(acc)) 172 | if attack_type is None: 173 | return acc, 0, None 174 | 175 | # Initialize the Fast Gradient Sign Method (FGSM) attack object and 176 | # graph. 177 | 178 | if FLAGS.defense_type == 'defense_gan': 179 | z_init_val = None 180 | 181 | if FLAGS.same_init: 182 | z_init_val = tf.constant( 183 | np.random.randn(batch_size * gan.rec_rr, gan.latent_dim).astype(np.float32)) 184 | 185 | model.add_rec_model(gan, z_init_val, batch_size) 186 | 187 | min_val = 0.0 188 | if gan: 189 | if gan.dataset_name == 'celeba': 190 | min_val = -1.0 191 | 192 | if 'rand' in FLAGS.attack_type: 193 | test_images = np.clip( 194 | test_images + args.alpha * np.sign(np.random.randn(*test_images.shape)), 195 | min_val, 1.0) 196 | eps -= args.alpha 197 | 198 | if 'fgsm' in FLAGS.attack_type: 199 | attack_params = {'eps': eps, 'ord': np.inf, 'clip_min': min_val, 'clip_max': 1.} 200 | attack_obj = FastGradientMethod(model, sess=sess) 201 | elif FLAGS.attack_type == 'cw': 202 | attack_obj = CarliniWagnerL2(model, back='tf', sess=sess) 203 | attack_iterations = 100 204 | attack_params = {'binary_search_steps': 1, 205 | 'max_iterations': attack_iterations, 206 | 'learning_rate': 10.0, 207 | 'batch_size': batch_size, 208 | 'initial_const': 100, 209 | 'feed': {K.learning_phase(): 0}} 210 | adv_x = attack_obj.generate(images_pl, **attack_params) 211 | 212 | eval_par = {'batch_size': batch_size} 213 | if FLAGS.defense_type == 'defense_gan': 214 | preds_adv = model.get_probs(adv_x) 215 | 216 | num_dims = len(images_pl.get_shape()) 217 | avg_inds = list(range(1, num_dims)) 218 | diff_op = tf.reduce_mean(tf.square(adv_x - images_pl), axis=avg_inds) 219 | acc_adv, roc_info = model_eval_gan( 220 | sess, images_pl, labels_pl, preds_adv, None, 221 | 
test_images=test_images, test_labels=test_labels, args=eval_par, 222 | feed={K.learning_phase(): 0}, diff_op=diff_op, 223 | ) 224 | print('Test accuracy on adversarial examples: %0.4f\n' % acc_adv) 225 | return acc_adv, 0, roc_info 226 | else: 227 | preds_adv = model(adv_x) 228 | acc_adv = model_eval(sess, images_pl, labels_pl, preds_adv, test_images, test_labels, 229 | args=eval_par, 230 | feed={K.learning_phase(): 0}) 231 | print('Test accuracy on adversarial examples: %0.4f\n' % acc_adv) 232 | 233 | return acc_adv, 0, None 234 | 235 | 236 | import re 237 | 238 | 239 | def main(cfg, argv=None): 240 | FLAGS = tf.app.flags.FLAGS 241 | GAN = dataset_gan_dict[FLAGS.dataset_name] 242 | 243 | gan = GAN(cfg=cfg, test_mode=True) 244 | gan.load_generator() 245 | # Setting test time reconstruction hyper parameters. 246 | [tr_rr, tr_lr, tr_iters] = [FLAGS.rec_rr, FLAGS.rec_lr, FLAGS.rec_iters] 247 | if FLAGS.defense_type.lower() != 'none': 248 | if FLAGS.rec_path and FLAGS.defense_type == 'defense_gan': 249 | 250 | # Extract hyperparameters from reconstruction path. 251 | if FLAGS.rec_path: 252 | train_param_re = re.compile('recs_rr(.*)_lr(.*)_iters(.*)') 253 | [tr_rr, tr_lr, tr_iters] = \ 254 | train_param_re.findall(FLAGS.rec_path)[0] 255 | gan.rec_rr = int(tr_rr) 256 | gan.rec_lr = float(tr_lr) 257 | gan.rec_iters = int(tr_iters) 258 | elif FLAGS.defense_type == 'defense_gan': 259 | assert FLAGS.online_training or not FLAGS.train_on_recs 260 | 261 | if FLAGS.override: 262 | gan.rec_rr = int(tr_rr) 263 | gan.rec_lr = float(tr_lr) 264 | gan.rec_iters = int(tr_iters) 265 | 266 | # Setting the results directory. 267 | results_dir, result_file_name = _get_results_dir_filename(gan) 268 | 269 | # Result file name. The counter ensures we are not overwriting the 270 | # results. 
271 | counter = 0 272 | temp_fp = str(counter) + '_' + result_file_name 273 | results_dir = os.path.join(results_dir, FLAGS.results_dir) 274 | temp_final_fp = os.path.join(results_dir, temp_fp) 275 | while os.path.exists(temp_final_fp): 276 | counter += 1 277 | temp_fp = str(counter) + '_' + result_file_name 278 | temp_final_fp = os.path.join(results_dir, temp_fp) 279 | result_file_name = temp_fp 280 | sub_result_path = os.path.join(results_dir, result_file_name) 281 | 282 | accuracies = whitebox( 283 | gan, rec_data_path=FLAGS.rec_path, 284 | batch_size=FLAGS.batch_size, 285 | learning_rate=FLAGS.learning_rate, 286 | nb_epochs=FLAGS.nb_epochs, 287 | eps=FLAGS.fgsm_eps, 288 | online_training=FLAGS.online_training, 289 | defense_type=FLAGS.defense_type, 290 | num_tests=FLAGS.num_tests, 291 | attack_type=FLAGS.attack_type, 292 | num_train=FLAGS.num_train, 293 | ) 294 | 295 | ensure_dir(results_dir) 296 | 297 | with open(sub_result_path, 'a') as f: 298 | f.writelines([str(accuracies[i]) + ' ' for i in range(2)]) 299 | f.write('\n') 300 | print('[*] saved accuracy in {}'.format(sub_result_path)) 301 | 302 | if accuracies[2]: # For attack detection. 
303 | pkl_result_path = sub_result_path.replace('.txt', '_roc.pkl') 304 | with open(pkl_result_path, 'w') as f: 305 | cPickle.dump(accuracies[2], f, cPickle.HIGHEST_PROTOCOL) 306 | print('[*] saved roc_info in {}'.format(pkl_result_path)) 307 | 308 | 309 | def _get_results_dir_filename(gan): 310 | FLAGS = tf.flags.FLAGS 311 | 312 | results_dir = os.path.join('results', 'whitebox_{}_{}'.format( 313 | FLAGS.defense_type, FLAGS.dataset_name)) 314 | 315 | if FLAGS.rec_path and FLAGS.defense_type == 'defense_gan': 316 | results_dir = gan.checkpoint_dir.replace('output', 'results') 317 | result_file_name = \ 318 | 'Iter={}_RR={:d}_LR={:.4f}_defense=gan'.format( 319 | gan.rec_rr, 320 | gan.rec_lr, 321 | gan.rec_iters, 322 | FLAGS.attack_type, 323 | ) 324 | 325 | if not FLAGS.train_on_recs: 326 | result_file_name = 'orig_' + result_file_name 327 | elif FLAGS.defense_type == 'adv_tr': 328 | result_file_name = 'advTrEps={:.2f}'.format(FLAGS.fgsm_eps_tr) 329 | else: 330 | result_file_name = 'nodefense_' 331 | if FLAGS.num_tests > -1: 332 | result_file_name = 'numtest={}_'.format( 333 | FLAGS.num_tests) + result_file_name 334 | 335 | if FLAGS.num_train > -1: 336 | result_file_name = 'numtrain={}_'.format( 337 | FLAGS.num_train) + result_file_name 338 | 339 | result_file_name = 'model={}_'.format(FLAGS.model) + result_file_name 340 | result_file_name += 'attack={}.txt'.format(FLAGS.attack_type) 341 | return results_dir, result_file_name 342 | 343 | 344 | def parse_args(): 345 | parser = argparse.ArgumentParser() 346 | 347 | parser.add_argument('--cfg', required=True, help='Config file') 348 | parser.add_argument("--alpha", type=float, default=0.05, 349 | help="RAND+FGSM random perturbation scale") 350 | 351 | if len(sys.argv) == 1: 352 | parser.print_help() 353 | sys.exit(1) 354 | args, _ = parser.parse_known_args() 355 | return args 356 | 357 | 358 | if __name__ == '__main__': 359 | args = parse_args() 360 | 361 | # Note: The load_config() call will convert all the parameters 
that are defined in 362 | # experiments/config files into FLAGS.param_name and can be passed in from command line. 363 | # arguments : python whitebox.py --cfg -- 364 | cfg = load_config(args.cfg) 365 | flags = tf.app.flags 366 | 367 | flags.DEFINE_integer('nb_classes', 10, 'Number of classes.') 368 | flags.DEFINE_float('learning_rate', 0.001, 'Learning rate for training.') 369 | flags.DEFINE_integer('nb_epochs', 10, 'Number of epochs to train model.') 370 | flags.DEFINE_float('lmbda', 0.1, 'Lambda from arxiv.org/abs/1602.02697.') 371 | flags.DEFINE_float('fgsm_eps', 0.3, 'FGSM epsilon.') 372 | flags.DEFINE_string('rec_path', None, 'Path to reconstructions.') 373 | flags.DEFINE_integer('num_tests', -1, 'Number of test samples.') 374 | flags.DEFINE_integer('random_test_iter', -1, 375 | 'Number of random sampling for testing the classifier.') 376 | flags.DEFINE_boolean("online_training", False, 377 | "Train the base classifier on reconstructions.") 378 | flags.DEFINE_string("defense_type", "none", "Type of defense [none|defense_gan|adv_tr]") 379 | flags.DEFINE_string("attack_type", "none", "Type of attack [fgsm|cw|rand_fgsm]") 380 | flags.DEFINE_string("results_dir", None, "The final subdirectory of the results.") 381 | flags.DEFINE_boolean("same_init", False, "Same initialization for z_hats.") 382 | flags.DEFINE_string("model", "F", "The classifier model.") 383 | flags.DEFINE_string("debug_dir", "temp", "The debug directory.") 384 | flags.DEFINE_integer("num_train", -1, 'Number of training data to load.') 385 | flags.DEFINE_boolean("debug", False, "True for saving reconstructions [False]") 386 | flags.DEFINE_boolean("override", False, "Overriding the config values of reconstruction " 387 | "hyperparameters. 
It has to be true if either " 388 | "`--rec_rr`, `--rec_lr`, or `--rec_iters` is passed " 389 | "from command line.") 390 | flags.DEFINE_boolean("train_on_recs", False, 391 | "Train the classifier on the reconstructed samples " 392 | "using Defense-GAN.") 393 | 394 | main_cfg = lambda x: main(cfg, x) 395 | tf.app.run(main=main_cfg) 396 | --------------------------------------------------------------------------------