├── .gitignore
├── LICENSE
├── README.md
├── datasets
│   ├── __init__.py
│   └── imagenet_input.py
├── examples.png
├── model_lib.py
├── robustml_attack.py
├── robustml_eval.py
└── robustml_model.py
/.gitignore: -------------------------------------------------------------------------------- 1 | *.tar.gz 2 | *.ckpt.* 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2016 The TensorFlow Authors. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner.
For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright 2016, The Authors. 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Evaluating and Understanding the Robustness of Adversarial Logit Pairing 2 | 3 | The code in this repository, forked from the [official 4 | implementation](https://github.com/tensorflow/models/tree/master/research/adversarial_logit_pairing), 5 | evaluates the robustness of [Adversarial Logit 6 | Pairing](https://arxiv.org/abs/1803.06373), a proposed defense against 7 | adversarial examples. 
8 | 9 | On the ImageNet 64x64 dataset, with an L-infinity perturbation of 16/255 (the 10 | threat model considered in the original paper), we can reduce the classifier's 11 | accuracy to 0.1% and generate targeted adversarial examples (with randomly chosen 12 | target labels) with a 98.6% success rate using the provided code and models. 13 | 14 | See our writeup [here](https://arxiv.org/abs/1807.10272) for our analysis, including visualizations of the loss landscape induced by Adversarial Logit Pairing. 15 | 16 | ## Pictures 17 | 18 | Obligatory pictures of adversarial examples (with randomly chosen target 19 | classes). 20 | 21 | ![](examples.png) 22 | 23 | ## Setup 24 | 25 | Download and untar the [ALP-trained 26 | ResNet-v2-50](http://download.tensorflow.org/models/adversarial_logit_pairing/imagenet64_alp025_2018_06_26.ckpt.tar.gz) 27 | model into the root of the repository. 28 | 29 | ## [RobustML](https://github.com/robust-ml/robustml) evaluation 30 | 31 | Run with: 32 | 33 | ``` 34 | python robustml_eval.py --imagenet-path <path to ImageNet validation data> 35 | ``` 36 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/labsix/adversarial-logit-pairing-analysis/185b2d2288f163c2d05beef6d99edf4798d92227/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/imagenet_input.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Imagenet input.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | from absl import flags 24 | import tensorflow as tf 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | 29 | flags.DEFINE_string('imagenet_data_dir', None, 30 | 'Directory with Imagenet dataset in TFRecord format.') 31 | 32 | 33 | def _decode_and_random_crop(image_buffer, bbox, image_size): 34 | """Randomly crops image and then scales to target size.""" 35 | with tf.name_scope('distorted_bounding_box_crop', 36 | values=[image_buffer, bbox]): 37 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 38 | tf.image.extract_jpeg_shape(image_buffer), 39 | bounding_boxes=bbox, 40 | min_object_covered=0.1, 41 | aspect_ratio_range=[0.75, 1.33], 42 | area_range=[0.08, 1.0], 43 | max_attempts=10, 44 | use_image_if_no_bounding_boxes=True) 45 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box 46 | 47 | # Crop the image to the specified bounding box.
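# Note: tf.image.decode_and_crop_jpeg decodes only the crop window rather than the full image, so fusing the crop into the decode below is cheaper than decoding and then cropping.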
48 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 49 | target_height, target_width, _ = tf.unstack(bbox_size) 50 | crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) 51 | image = tf.image.decode_and_crop_jpeg(image_buffer, crop_window, channels=3) 52 | image = tf.image.convert_image_dtype( 53 | image, dtype=tf.float32) 54 | 55 | image = tf.image.resize_bicubic([image], 56 | [image_size, image_size])[0] 57 | 58 | return image 59 | 60 | 61 | def _decode_and_center_crop(image_buffer, image_size): 62 | """Crops to center of image with padding then scales to target size.""" 63 | shape = tf.image.extract_jpeg_shape(image_buffer) 64 | image_height = shape[0] 65 | image_width = shape[1] 66 | 67 | padded_center_crop_size = tf.cast( 68 | 0.875 * tf.cast(tf.minimum(image_height, image_width), tf.float32), 69 | tf.int32) 70 | 71 | offset_height = ((image_height - padded_center_crop_size) + 1) // 2 72 | offset_width = ((image_width - padded_center_crop_size) + 1) // 2 73 | crop_window = tf.stack([offset_height, offset_width, 74 | padded_center_crop_size, padded_center_crop_size]) 75 | image = tf.image.decode_and_crop_jpeg(image_buffer, crop_window, channels=3) 76 | image = tf.image.convert_image_dtype( 77 | image, dtype=tf.float32) 78 | 79 | image = tf.image.resize_bicubic([image], 80 | [image_size, image_size])[0] 81 | 82 | return image 83 | 84 | 85 | def _normalize(image): 86 | """Rescale image to [-1, 1] range.""" 87 | return tf.multiply(tf.subtract(image, 0.5), 2.0) 88 | 89 | 90 | def image_preprocessing(image_buffer, bbox, image_size, is_training): 91 | """Does image decoding and preprocessing. 92 | 93 | Args: 94 | image_buffer: string tensor with encoded image. 95 | bbox: bounding box of the object in the image. 96 | image_size: image size. 97 | is_training: whether to do training or eval preprocessing. 98 | 99 | Returns: 100 | Tensor with the image. 101 | """ 102 | if is_training: 103 | image = _decode_and_random_crop(image_buffer, bbox, image_size) 104 | image = _normalize(image) 105 | image = tf.image.random_flip_left_right(image) 106 | else: 107 | image = _decode_and_center_crop(image_buffer, image_size) 108 | image = _normalize(image) 109 | image = tf.reshape(image, [image_size, image_size, 3]) 110 | return image 111 | 112 | 113 | def imagenet_parser(value, image_size, is_training): 114 | """Parse an ImageNet record from a serialized string Tensor. 115 | 116 | Args: 117 | value: encoded example. 118 | image_size: size of the output image. 119 | is_training: if True then do training preprocessing, 120 | otherwise do eval preprocessing. 121 | 122 | Returns: 123 | image: tensor with the image. 124 | label: true label of the image.
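Note: labels are in the [1, 1000] range; index 0 is reserved for the background class.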
125 | """ 126 | keys_to_features = { 127 | 'image/encoded': 128 | tf.FixedLenFeature((), tf.string, ''), 129 | 'image/format': 130 | tf.FixedLenFeature((), tf.string, 'jpeg'), 131 | 'image/class/label': 132 | tf.FixedLenFeature([], tf.int64, -1), 133 | 'image/class/text': 134 | tf.FixedLenFeature([], tf.string, ''), 135 | 'image/object/bbox/xmin': 136 | tf.VarLenFeature(dtype=tf.float32), 137 | 'image/object/bbox/ymin': 138 | tf.VarLenFeature(dtype=tf.float32), 139 | 'image/object/bbox/xmax': 140 | tf.VarLenFeature(dtype=tf.float32), 141 | 'image/object/bbox/ymax': 142 | tf.VarLenFeature(dtype=tf.float32), 143 | 'image/object/class/label': 144 | tf.VarLenFeature(dtype=tf.int64), 145 | } 146 | 147 | parsed = tf.parse_single_example(value, keys_to_features) 148 | 149 | image_buffer = tf.reshape(parsed['image/encoded'], shape=[]) 150 | 151 | xmin = tf.expand_dims(parsed['image/object/bbox/xmin'].values, 0) 152 | ymin = tf.expand_dims(parsed['image/object/bbox/ymin'].values, 0) 153 | xmax = tf.expand_dims(parsed['image/object/bbox/xmax'].values, 0) 154 | ymax = tf.expand_dims(parsed['image/object/bbox/ymax'].values, 0) 155 | # Note that ordering is (y, x) 156 | bbox = tf.concat([ymin, xmin, ymax, xmax], 0) 157 | # Force the variable number of bounding boxes into the shape 158 | # [1, num_boxes, coords]. 159 | bbox = tf.expand_dims(bbox, 0) 160 | bbox = tf.transpose(bbox, [0, 2, 1]) 161 | 162 | image = image_preprocessing( 163 | image_buffer=image_buffer, 164 | bbox=bbox, 165 | image_size=image_size, 166 | is_training=is_training 167 | ) 168 | 169 | # Labels are in [1, 1000] range 170 | label = tf.cast( 171 | tf.reshape(parsed['image/class/label'], shape=[]), dtype=tf.int32) 172 | 173 | return image, label 174 | 175 | 176 | def imagenet_input(split, batch_size, image_size, is_training): 177 | """Returns ImageNet dataset. 178 | 179 | Args: 180 | split: name of the split, "train" or "validation". 181 | batch_size: size of the minibatch. 182 | image_size: size of one side of the image. Output images will be 183 | resized to a square of shape image_size*image_size. 184 | is_training: if True then training preprocessing is done, otherwise eval 185 | preprocessing is done. 186 | 187 | Raises: 188 | ValueError: if name of the split is incorrect. 189 | 190 | Returns: 191 | Instance of tf.data.Dataset with the dataset.
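For example (TF 1.x, with --imagenet_data_dir set): images, labels = imagenet_input('validation', 50, 64, False).make_one_shot_iterator().get_next()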
192 | """ 193 | if split.lower().startswith('train'): 194 | file_pattern = os.path.join(FLAGS.imagenet_data_dir, 'train-*') 195 | elif split.lower().startswith('validation'): 196 | file_pattern = os.path.join(FLAGS.imagenet_data_dir, 'validation-*') 197 | else: 198 | raise ValueError('Invalid split: %s' % split) 199 | 200 | dataset = tf.data.Dataset.list_files(file_pattern, shuffle=is_training) 201 | 202 | if is_training: 203 | dataset = dataset.repeat() 204 | 205 | def fetch_dataset(filename): 206 | return tf.data.TFRecordDataset(filename, buffer_size=8*1024*1024) 207 | 208 | # Read the data from disk in parallel 209 | dataset = dataset.apply( 210 | tf.contrib.data.parallel_interleave( 211 | fetch_dataset, cycle_length=4, sloppy=True)) 212 | dataset = dataset.shuffle(1024) 213 | 214 | # Parse, preprocess, and batch the data in parallel 215 | dataset = dataset.apply( 216 | tf.contrib.data.map_and_batch( 217 | lambda value: imagenet_parser(value, image_size, is_training), 218 | batch_size=batch_size, 219 | num_parallel_batches=4, 220 | drop_remainder=True)) 221 | 222 | def set_shapes(images, labels): 223 | """Statically set the batch_size dimension.""" 224 | images.set_shape(images.get_shape().merge_with( 225 | tf.TensorShape([batch_size, None, None, None]))) 226 | labels.set_shape(labels.get_shape().merge_with( 227 | tf.TensorShape([batch_size]))) 228 | return images, labels 229 | 230 | # Assign static batch size dimension 231 | dataset = dataset.map(set_shapes) 232 | 233 | # Prefetch overlaps in-feed with training 234 | dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE) 235 | return dataset 236 | 237 | 238 | def num_examples_per_epoch(split): 239 | """Returns the number of examples in the data set. 240 | 241 | Args: 242 | split: name of the split, "train" or "validation". 243 | 244 | Raises: 245 | ValueError: if split name is incorrect. 246 | 247 | Returns: 248 | Number of examples in the split. 249 | """ 250 | if split.lower().startswith('train'): 251 | return 1281167 252 | elif split.lower().startswith('validation'): 253 | return 50000 254 | else: 255 | raise ValueError('Invalid split: %s' % split) 256 | -------------------------------------------------------------------------------- /examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/labsix/adversarial-logit-pairing-analysis/185b2d2288f163c2d05beef6d99edf4798d92227/examples.png -------------------------------------------------------------------------------- /model_lib.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Library with common functions for training and eval.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import six 23 | 24 | import tensorflow as tf 25 | 26 | from tensorflow.contrib.slim.nets import resnet_v2 27 | 28 | 29 | def default_hparams(): 30 | """Returns default hyperparameters.""" 31 | return tf.contrib.training.HParams( 32 | # Batch size for training and evaluation. 33 | batch_size=32, 34 | eval_batch_size=50, 35 | 36 | # General training parameters. 37 | weight_decay=0.0001, 38 | label_smoothing=0.1, 39 | 40 | # Parameters of the adversarial training. 41 | train_adv_method='clean', # adversarial training method 42 | train_lp_weight=0.0, # Weight of adversarial logit pairing loss 43 | 44 | # Parameters of the optimizer. 45 | optimizer='rms', # possible values are: 'rms', 'momentum', 'adam' 46 | momentum=0.9, # momentum 47 | rmsprop_decay=0.9, # Decay term for RMSProp 48 | rmsprop_epsilon=1.0, # Epsilon term for RMSProp 49 | 50 | # Parameters of learning rate schedule. 51 | lr_schedule='exp_decay', # Possible values: 'exp_decay', 'step', 'fixed' 52 | learning_rate=0.045, 53 | lr_decay_factor=0.94, # Learning exponential decay 54 | lr_num_epochs_per_decay=2.0, # Number of epochs per lr decay 55 | lr_list=[1.0 / 6, 2.0 / 6, 3.0 / 6, 56 | 4.0 / 6, 5.0 / 6, 1.0, 0.1, 0.01, 57 | 0.001, 0.0001], 58 | lr_decay_epochs=[1, 2, 3, 4, 5, 30, 60, 80, 59 | 90]) 60 | 61 | 62 | def get_lr_schedule(hparams, examples_per_epoch, replicas_to_aggregate=1): 63 | """Returns a TensorFlow op which computes the learning rate. 64 | 65 | Args: 66 | hparams: hyperparameters. 67 | examples_per_epoch: number of training examples per epoch. 68 | replicas_to_aggregate: number of training replicas running in parallel. 69 | 70 | Raises: 71 | ValueError: if learning rate schedule specified in hparams is incorrect. 72 | 73 | Returns: 74 | learning_rate: tensor with learning rate. 75 | steps_per_epoch: number of training steps per epoch. 76 | """ 77 | global_step = tf.train.get_or_create_global_step() 78 | steps_per_epoch = float(examples_per_epoch) / float(hparams.batch_size) 79 | if replicas_to_aggregate > 0: 80 | steps_per_epoch /= replicas_to_aggregate 81 | 82 | if hparams.lr_schedule == 'exp_decay': 83 | decay_steps = int(steps_per_epoch * hparams.lr_num_epochs_per_decay) 84 | learning_rate = tf.train.exponential_decay( 85 | hparams.learning_rate, 86 | global_step, 87 | decay_steps, 88 | hparams.lr_decay_factor, 89 | staircase=True) 90 | elif hparams.lr_schedule == 'step': 91 | lr_decay_steps = [int(epoch * steps_per_epoch) 92 | for epoch in hparams.lr_decay_epochs] 93 | learning_rate = tf.train.piecewise_constant( 94 | global_step, lr_decay_steps, hparams.lr_list) 95 | elif hparams.lr_schedule == 'fixed': 96 | learning_rate = hparams.learning_rate 97 | else: 98 | raise ValueError('Invalid value of lr_schedule: %s' % hparams.lr_schedule) 99 | 100 | if replicas_to_aggregate > 0: 101 | learning_rate *= replicas_to_aggregate 102 | 103 | return learning_rate, steps_per_epoch 104 | 105 | 106 | def get_optimizer(hparams, learning_rate): 107 | """Returns optimizer. 108 | 109 | Args: 110 | hparams: hyperparameters. 111 | learning_rate: learning rate tensor. 112 | 113 | Raises: 114 | ValueError: if type of optimizer specified in hparams is incorrect. 115 | 116 | Returns: 117 | Instance of optimizer class.
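For example, the default hparams (optimizer='rms') yield a tf.train.RMSPropOptimizer parameterized by rmsprop_decay, momentum, and rmsprop_epsilon.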
118 | """ 119 | if hparams.optimizer == 'rms': 120 | optimizer = tf.train.RMSPropOptimizer(learning_rate, 121 | hparams.rmsprop_decay, 122 | hparams.momentum, 123 | hparams.rmsprop_epsilon) 124 | elif hparams.optimizer == 'momentum': 125 | optimizer = tf.train.MomentumOptimizer(learning_rate, 126 | hparams.momentum) 127 | elif hparams.optimizer == 'adam': 128 | optimizer = tf.train.AdamOptimizer(learning_rate) 129 | else: 130 | raise ValueError('Invalid value of optimizer: %s' % hparams.optimizer) 131 | return optimizer 132 | 133 | 134 | RESNET_MODELS = {'resnet_v2_50': resnet_v2.resnet_v2_50} 135 | 136 | 137 | def get_model(model_name, num_classes): 138 | """Returns function which creates model. 139 | 140 | Args: 141 | model_name: Name of the model. 142 | num_classes: Number of classes. 143 | 144 | Raises: 145 | ValueError: If model_name is invalid. 146 | 147 | Returns: 148 | Function, which creates model when called. 149 | """ 150 | if model_name.startswith('resnet'): 151 | def resnet_model(images, is_training, reuse=tf.AUTO_REUSE): 152 | with tf.contrib.framework.arg_scope(resnet_v2.resnet_arg_scope()): 153 | resnet_fn = RESNET_MODELS[model_name] 154 | logits, _ = resnet_fn(images, num_classes, is_training=is_training, 155 | reuse=reuse) 156 | logits = tf.reshape(logits, [-1, num_classes]) 157 | return logits 158 | return resnet_model 159 | else: 160 | raise ValueError('Invalid model: %s' % model_name) 161 | 162 | 163 | def filter_trainable_variables(trainable_scopes): 164 | """Keep only trainable variables which are prefixed with given scopes. 165 | 166 | Args: 167 | trainable_scopes: either list of trainable scopes or string with comma 168 | separated list of trainable scopes. 169 | 170 | This function removes all variables which are not prefixed with given 171 | trainable_scopes from collection of trainable variables. 172 | Useful during network fine-tuning, when you only need to train a subset of 173 | variables.
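For example, filter_trainable_variables('resnet_v2_50/logits') would keep only variables under the final logits scope trainable (matching is by variable-name prefix; the scope name here is illustrative).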
174 | """ 175 | if not trainable_scopes: 176 | return 177 | if isinstance(trainable_scopes, six.string_types): 178 | trainable_scopes = [scope.strip() for scope in trainable_scopes.split(',')] 179 | trainable_scopes = {scope for scope in trainable_scopes if scope} 180 | if not trainable_scopes: 181 | return 182 | trainable_collection = tf.get_collection_ref( 183 | tf.GraphKeys.TRAINABLE_VARIABLES) 184 | non_trainable_vars = [ 185 | v for v in trainable_collection 186 | if not any([v.op.name.startswith(s) for s in trainable_scopes]) 187 | ] 188 | for v in non_trainable_vars: 189 | trainable_collection.remove(v) 190 | -------------------------------------------------------------------------------- /robustml_attack.py: -------------------------------------------------------------------------------- 1 | import robustml 2 | import tensorflow as tf 3 | import numpy as np 4 | import sys 5 | 6 | class NullAttack(robustml.attack.Attack): 7 | def run(self, x, y, target): 8 | return x 9 | 10 | class PGDAttack(robustml.attack.Attack): 11 | def __init__(self, sess, model, epsilon, max_steps=100, step_size=0.01, quantize=False, debug=False): 12 | self._sess = sess 13 | self._model = model 14 | self._epsilon = epsilon 15 | self._max_steps = max_steps 16 | self._step_size = step_size 17 | self._quantize = quantize 18 | self._debug = debug 19 | 20 | self._label = tf.placeholder(tf.int32, ()) 21 | one_hot = tf.expand_dims(tf.one_hot(self._label, 1000), axis=0) 22 | self._loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=model.logits, labels=one_hot) 23 | self._grad, = tf.gradients(self._loss, model.input) 24 | 25 | def run(self, x, y, target): 26 | mult = -1 27 | untargeted = target is None 28 | if target is None: 29 | target = y 30 | mult = 1 31 | lower = np.clip(x - self._epsilon, 0, 1) 32 | upper = np.clip(x + self._epsilon, 0, 1) 33 | adv = x + np.random.uniform(low=-self._epsilon, high=self._epsilon, size=x.shape) 34 | adv = np.clip(adv, lower, upper) 35 | for i in range(self._max_steps): 36 | if self._quantize: 37 | adv_eval = (adv*255).astype(np.uint8).astype(np.float32)/255.0 38 | else: 39 | adv_eval = adv 40 | p, l, g = self._sess.run( 41 | [self._model.predictions, self._loss, self._grad], 42 | {self._model.input: [adv_eval], self._label: target} 43 | ) 44 | if self._debug: 45 | print( 46 | 'attack: step %d/%d, loss = %g (true %d, predicted %d, target %d)' % ( 47 | i+1, 48 | self._max_steps, 49 | l, 50 | y, 51 | p, 52 | target 53 | ), 54 | file=sys.stderr 55 | ) 56 | if (untargeted and p != y) or (not untargeted and p == target): 57 | # we're done 58 | if self._debug: 59 | print('returning early', file=sys.stderr) 60 | break 61 | adv += mult * self._step_size * np.sign(g[0]) 62 | adv = np.clip(adv, lower, upper) 63 | return adv 64 | -------------------------------------------------------------------------------- /robustml_eval.py: -------------------------------------------------------------------------------- 1 | from robustml_model import * 2 | from robustml_attack import * 3 | import tensorflow as tf 4 | import argparse 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--imagenet-path', type=str, required=True, 9 | help='directory containing `val.txt` and `val/` folder') 10 | parser.add_argument('--checkpoint-path', type=str, default='imagenet64_alp025_2018_06_26.ckpt', 11 | help='path to imagenet64 checkpoint') 12 | parser.add_argument('--start', type=int, default=0) 13 | parser.add_argument('--end', type=int, default=100) 14 | parser.add_argument('--attack',
type=str, default='pgd', help='none | pgd') 15 | parser.add_argument('--attack-iterations', type=int, default=1000) 16 | parser.add_argument('--attack-step-size', type=float, default=0.005) 17 | parser.add_argument('--debug', action='store_true') 18 | parser.add_argument('--quantize', action='store_true') 19 | args = parser.parse_args() 20 | 21 | sess = tf.Session() 22 | 23 | model = ALP(sess, args.checkpoint_path, quantize=args.quantize) 24 | 25 | if args.attack == 'none': 26 | attack = NullAttack() 27 | elif args.attack == 'pgd': 28 | attack = PGDAttack( 29 | sess, 30 | model, 31 | model.threat_model.epsilon, 32 | debug=args.debug, 33 | max_steps=args.attack_iterations, 34 | step_size=args.attack_step_size, 35 | quantize=args.quantize 36 | ) 37 | else: 38 | raise ValueError('invalid attack: %s' % args.attack) 39 | 40 | provider = robustml.provider.ImageNet(args.imagenet_path, shape=(64, 64, 3)) 41 | 42 | success_rate = robustml.evaluate.evaluate( 43 | model, 44 | attack, 45 | provider, 46 | start=args.start, 47 | end=args.end, 48 | deterministic=True, 49 | debug=True 50 | ) 51 | 52 | print('attack success rate: %.2f%% (over %d data points)' % (success_rate*100, args.end - args.start)) 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /robustml_model.py: -------------------------------------------------------------------------------- 1 | import robustml 2 | import tensorflow as tf 3 | import numpy as np 4 | import model_lib 5 | from datasets import imagenet_input 6 | 7 | class ALP(robustml.model.Model): 8 | ''' 9 | ALP for ImageNet 64x64 10 | ''' 11 | def __init__(self, sess, checkpoint_path, quantize=False): 12 | self._sess = sess 13 | self._input = tf.placeholder(tf.float32, (None, 64, 64, 3)) 14 | self._logits = _model(sess, self._input, checkpoint_path) 15 | self._logits = self._logits[:, 1:] # ignore background class 16 | self._predictions = tf.argmax(self._logits, 1) 17 | self._dataset = robustml.dataset.ImageNet((64, 64, 3)) 18 | self._threat_model = robustml.threat_model.Linf(epsilon=16.0/255.0, targeted=True) 19 | self._quantize = quantize 20 | 21 | @property 22 | def dataset(self): 23 | return self._dataset 24 | 25 | @property 26 | def threat_model(self): 27 | return self._threat_model 28 | 29 | def classify(self, x): 30 | if self._quantize: 31 | x = (x*255).astype(np.uint8).astype(np.float32)/255.0 32 | return self._sess.run(self._predictions, {self._input: [x]})[0] 33 | 34 | # exposing some internals to make it less annoying for attackers to do a 35 | # white-box attack 36 | 37 | @property 38 | def input(self): 39 | return self._input 40 | 41 | @property 42 | def logits(self): 43 | return self._logits 44 | 45 | @property 46 | def predictions(self): 47 | return self._predictions 48 | 49 | def _model(sess, input_, checkpoint_path): 50 | model_fn_two_args = model_lib.get_model('resnet_v2_50', 1001) 51 | model_fn = lambda x: model_fn_two_args(x, is_training=False) 52 | preprocessed = imagenet_input._normalize(input_) 53 | logits = model_fn(preprocessed) 54 | variables_to_restore = tf.contrib.framework.get_variables_to_restore() 55 | saver = tf.train.Saver(variables_to_restore) 56 | saver.restore(sess, checkpoint_path) 57 | return logits 58 | --------------------------------------------------------------------------------
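For reference, here is a minimal sketch of driving these pieces directly, without the `robustml.evaluate` harness used by `robustml_eval.py`. It assumes the checkpoint from the Setup section has been untarred into the working directory; the image and label below are placeholders standing in for a real data point loaded elsewhere.

```
import numpy as np
import tensorflow as tf

from robustml_model import ALP
from robustml_attack import PGDAttack

sess = tf.Session()

# Build the defended model and restore the ALP-trained weights.
model = ALP(sess, 'imagenet64_alp025_2018_06_26.ckpt')

# PGD under the model's own threat model (L-infinity, epsilon = 16/255).
attack = PGDAttack(sess, model, model.threat_model.epsilon,
                   max_steps=1000, step_size=0.005)

x = np.zeros((64, 64, 3), dtype=np.float32)  # placeholder image in [0, 1]
y = 0                                        # placeholder true label

# target=None runs the untargeted attack; pass a class index for a targeted one.
x_adv = attack.run(x, y, None)
print('prediction on adversarial example:', model.classify(x_adv))
```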