├── LICENSE
├── README.md
├── adv_reg_experiment.py
├── dataset.py
├── dataset_paths.ini
├── diff_ops.py
├── fwt.py
├── imagenet_data.py
├── imagenet_example.ini
├── imagenet_labels.csv
├── mnist_data.py
├── mnist_example.ini
├── preprocessing.py
├── resnet50.py
├── robust_model.py
├── smallnet.py
├── summary_utils.py
├── training.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 cetmann
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Description
2 | Code for the paper 'On the Connection Between Adversarial Robustness and Saliency Map Interpretability' by C. Etmann, S. Lunz, P. Maass, C.-B. Schönlieb, accepted at ICML 2019.
3 | 
4 | More in-depth documentation to follow.
5 | 
6 | # Execution
7 | 
8 | An experiment is run by passing an INI configuration file to the training script, e.g.
9 | ```
10 | python3 adv_reg_experiment.py imagenet_example.ini
11 | ```
12 | This code was tested with Python 3.6 and the following libraries:
13 | 
14 | * TensorFlow 1.11
15 | 
16 | * Foolbox 1.8.0
17 | 
18 | * Keras 2.2.2
19 | 
20 | A conda environment file (yml) is to follow.
21 | 
--------------------------------------------------------------------------------
/adv_reg_experiment.py:
--------------------------------------------------------------------------------
1 | """
2 | All models are trained using this base script. It can be run via
3 | python adv_reg_experiment.py experiment.ini
4 | """
5 | 
6 | 
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 | 
11 | import os
12 | import configparser
13 | import argparse
14 | import pathlib
15 | 
16 | import numpy as np
17 | import tensorflow as tf
18 | import tensorflow.keras as keras
19 | import tensorflow.keras.backend as K
20 | 
21 | import utils
22 | import dataset
23 | import robust_model
24 | import training
25 | import summary_utils
26 | 
27 | 
28 | parser = argparse.ArgumentParser(
29 |     description='Train an adversarially robust network.')
30 | parser.add_argument('ini_file', type=str)
31 | args = parser.parse_args()
32 | 
33 | ini_file = args.ini_file
34 | 
35 | 
36 | # Load the desired specifications from the provided
37 | # INI-file.
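# The file must provide the four sections that are read below.
# A minimal sketch (see imagenet_example.ini and mnist_example.ini
# for complete, working examples):
#
#   [PATHS]
#   dataset_name = ImageNet
#   tensorboard_logdir = /localdata/logs/
#   saved_model_folder = /localdata/models/
#
#   [ARCHITECTURE]
#   model = ResNet50
#   pretrained = True
#
#   [HYPERPARAMETERS]
#   num_epochs = 20
#   learning_rate_at_start = 0.00001
#   ...
#
#   [LOGGING]
#   train_summary_period = 500
#   ...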
38 | config = configparser.ConfigParser() 39 | config.read(ini_file) 40 | paths = config['PATHS'] 41 | hyperparameters = config['HYPERPARAMETERS'] 42 | logging = config['LOGGING'] 43 | architecture = config['ARCHITECTURE'] 44 | 45 | # PATHS 46 | base_name = os.path.splitext( 47 | os.path.basename(ini_file))[0] 48 | 49 | tensorboard_logdir = paths.get( 50 | 'tensorboard_logdir') + \ 51 | base_name + '/' 52 | dataset_name = paths.get( 53 | 'dataset_name') 54 | saved_model_folder = paths.get( 55 | 'saved_model_folder') 56 | saved_model_path = saved_model_folder + \ 57 | base_name + '/model.ckpt' 58 | 59 | # HYPERPARAMETERS 60 | lr_decrease_interval = int(hyperparameters.get( 61 | 'lr_decrease_interval')) 62 | lr_decrease_factor = float(hyperparameters.get( 63 | 'lr_decrease_factor')) 64 | batch_size_per_gpu = int(hyperparameters.get( 65 | 'batch_size_per_gpu')) 66 | robust_regularization = hyperparameters.getboolean( 67 | 'robust_regularization', 68 | True) 69 | use_wavelet_decomposition = hyperparameters.getboolean( 70 | 'use_wavelet_decomposition', 71 | True) 72 | sensitivity_mode = hyperparameters.get( 73 | 'sensitivity_mode', 74 | 'NLL') 75 | 76 | wavelet_weights = [float(i) 77 | for i in eval( 78 | hyperparameters.get( 79 | 'wavelet_weights', 80 | '[4,2,1,0]' 81 | ) 82 | ) 83 | ] 84 | decomp_type = hyperparameters.get( 85 | 'decomp_type') 86 | lp_wavelet_parameter= float(hyperparameters.get( 87 | 'lp_wavelet_parameter')) 88 | p_norm = float(hyperparameters.get( 89 | 'p_norm')) 90 | learning_rate_at_start= float(hyperparameters.get( 91 | 'learning_rate_at_start')) 92 | weight_decay_parameter = float(hyperparameters.get( 93 | 'weight_decay_parameter')) 94 | bn_momentum_value = float(hyperparameters.get( 95 | 'bn_momentum_value')) 96 | num_epochs = int(hyperparameters.get( 97 | 'num_epochs')) 98 | learning_phase = int(hyperparameters.get( 99 | 'learning_phase', 100 | 1)) 101 | 102 | # LOGGING 103 | train_summary_period = int(logging.get( 104 | 'train_summary_period')) 105 | val_summary_period = int(logging.get( 106 | 'val_summary_period')) 107 | adversarial_test_period = int(logging.get( 108 | 'adversarial_test_period')) 109 | num_adversarial_batches = int(logging.get( 110 | 'num_adversarial_batches')) 111 | 112 | # ARCHITECTURE 113 | Model = architecture['model'] 114 | pretrained = architecture.getboolean( 115 | 'pretrained', 116 | False) 117 | 118 | tf.logging.set_verbosity(tf.logging.INFO) 119 | 120 | # Create the saved_model_folder if it does not 121 | # exist yet. Otherwise, tf.Saver throws an error. 122 | pathlib.Path(saved_model_folder).mkdir(parents=True, 123 | exist_ok=True) 124 | summary_writer = tf.summary.FileWriter( 125 | tensorboard_logdir, 126 | None, 127 | flush_secs=30) 128 | 129 | 130 | d_y, d_x = dataset.image_resolution[dataset_name] 131 | num_classes = dataset.num_classes[dataset_name] 132 | num_parallel_calls = np.int32(512) 133 | 134 | 135 | # Create the necessary placeholders, which represent 136 | # the inputs and hyperparameters of the network. 
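# All of these placeholders are only bound to concrete values at
# session.run time, via the train_feed_dict and val_feed_dict that
# are constructed at the bottom of this script. Note that for MNIST,
# the 'files' placeholder carries the raw image tensors themselves,
# whereas for the ImageNet-style datasets it carries file name
# strings which the input pipeline decodes on the fly.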
137 | if dataset_name == 'MNIST': 138 | files = tf.placeholder(tf.float32,name='files') 139 | else: 140 | files = tf.placeholder(tf.string,name='files') 141 | labels = tf.placeholder(tf.int32,name='labels') 142 | 143 | if dataset_name == 'MNIST': 144 | x = tf.placeholder(tf.float32,[None,d_y,d_x,1],name='x') 145 | d_c = 1 146 | else: 147 | x = tf.placeholder(tf.float32,[None,d_y,d_x,3],name='x') 148 | d_c = 3 149 | l1_p = tf.placeholder(tf.float32,(),name='l1_parameter') 150 | l2_p = tf.placeholder(tf.float32,(),name='l2_parameter') 151 | lp_wavelet_p = tf.placeholder(tf.float32,(),name='lp_wavelet_parameter') 152 | weight_decay_p = tf.placeholder(tf.float32,(),name='weight_decay_parameter') 153 | starter_learning_rate = tf.placeholder(tf.float32,(),name='starter_learning_rate') 154 | bn_momentum = tf.placeholder(tf.float32,(),name='bn_momentum') 155 | batch_size = tf.placeholder(tf.int64,(),name='batch_size') 156 | epoch_step = tf.Variable(0, trainable=False) 157 | batch_step = tf.Variable(0, trainable=False) 158 | learning_rate = tf.train.exponential_decay(starter_learning_rate, 159 | epoch_step, 160 | lr_decrease_interval, 161 | lr_decrease_factor, 162 | staircase=True) 163 | 164 | config=tf.ConfigProto( 165 | allow_soft_placement = True, 166 | log_device_placement = True) 167 | graph = tf.get_default_graph() 168 | session = tf.Session(graph=graph, 169 | config=config) 170 | K.set_session(session) 171 | K.set_image_data_format('channels_last') 172 | session.run(K.learning_phase(), feed_dict={K.learning_phase(): 0}) 173 | session.as_default() 174 | 175 | 176 | # Depending on the model specified in the INI-file, 177 | # load the preferred model. The create_model method 178 | # has to output a tf.keras.models.Model object. If 179 | # the model has BatchNorm layers, the output of this 180 | # model should be a list where the the first entry is 181 | # the logit output of the network and the remaining entries 182 | # should be the Keras tensors which are the incoming nodes 183 | # to these BatchNorm layers. When using BatchNorm with 184 | # the proposed robust regularization, the used_fused 185 | # parameter needs to be 'False'. If no BatchNorm is used, 186 | # the output of the model should just be the logit output 187 | # of the network. 188 | if Model == 'VGG16': 189 | from vgg16 import create_model 190 | if Model == 'ResNet6': 191 | from resnet6 import create_model 192 | if Model == 'ResNet18': 193 | from resnet18 import create_model 194 | if Model == 'ResNet50': 195 | from resnet50 import create_model 196 | if Model == 'SmallNet': 197 | from smallnet import create_model 198 | model = create_model(input_tensor=x, 199 | input_shape = [d_y,d_x,d_c], 200 | num_classes = num_classes, 201 | pretrained = pretrained) 202 | 203 | 204 | # Create the optimizer. 205 | optimizer = tf.train.MomentumOptimizer( 206 | learning_rate=learning_rate, 207 | momentum=.9) 208 | 209 | # Load the dataset object which automatically 210 | # handles the data loading and preprocessing 211 | # and create iterator objects. 212 | dataset = dataset.dataset( 213 | files_ph = files, 214 | labels_ph = labels, 215 | batch_size_ph = batch_size, 216 | dataset_name = dataset_name) 217 | 218 | 219 | # We use a string training handle to allow for quick switching 220 | # between the different batch iterators. 
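# The handles are obtained by evaluating string_handle() on the
# respective batch iterators (see get_train_handle etc. in
# dataset.py). Switching between training and validation data then
# amounts to feeding a different handle, e.g. (sketch):
#
#   session.run(robust_model.train_op,
#               feed_dict={handle: dataset.train_handle, ...})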
221 | handle = tf.placeholder(tf.string) 222 | iterator = tf.data.Iterator.from_string_handle( 223 | handle, 224 | dataset.train_batch_iterator.output_types, 225 | dataset.train_batch_iterator.output_shapes) 226 | 227 | 228 | # Create a robust model according to our specifications. 229 | robust_model = robust_model.robust_model(iterator = iterator, 230 | session = session, 231 | model = model, 232 | num_classes = num_classes, 233 | optimizer = optimizer, 234 | dataset = dataset, 235 | p_norm = p_norm, 236 | decomp_type = decomp_type, 237 | learning_rate = learning_rate, 238 | weight_decay_p = weight_decay_p, 239 | lp_wavelet_p = lp_wavelet_p, 240 | batch_size = batch_size, 241 | bn_momentum = bn_momentum, 242 | robust_regularization = robust_regularization, 243 | use_wavelet_decomposition = use_wavelet_decomposition, 244 | wavelet_weights = wavelet_weights, 245 | sensitivity_mode = sensitivity_mode) 246 | GPU_collections = robust_model.GPU_collections 247 | 248 | 249 | # Apply these attacks every now and then in order to get an impression 250 | # of the adversarial robustness during training. 251 | attack_types = ['GradientAttack'] 252 | 253 | # Create the training and logging loop. 254 | training_procedure = training.training( 255 | handle = handle, 256 | dataset = dataset, 257 | batch_size_placeholder = batch_size, 258 | train_op = robust_model.train_op, 259 | session = session, 260 | pretrained = pretrained, 261 | epoch_step = epoch_step, 262 | batch_step = batch_step, 263 | summary_writer = summary_writer, 264 | train_summary_op = robust_model.summary_op, 265 | img_summary_op = robust_model.img_summary_op, 266 | optimizer = optimizer, 267 | GPU_collections = robust_model.GPU_collections, 268 | adversarial_model = robust_model.adversarial_model, 269 | adversarial_attacks = attack_types, 270 | saver_path = saved_model_path, 271 | num_adversarial_batches = num_adversarial_batches, 272 | batch_size = batch_size_per_gpu, 273 | num_epochs = num_epochs, 274 | train_summary_period = train_summary_period, 275 | val_summary_period = val_summary_period, 276 | adv_summary_period = adversarial_test_period) 277 | 278 | train_feed_dict = {starter_learning_rate: learning_rate_at_start, 279 | K.learning_phase(): learning_phase, 280 | weight_decay_p: weight_decay_parameter, 281 | batch_size: batch_size_per_gpu, 282 | lp_wavelet_p: lp_wavelet_parameter, 283 | bn_momentum: bn_momentum_value} 284 | val_feed_dict = {K.learning_phase(): 0, 285 | batch_size: batch_size_per_gpu} 286 | training_procedure.train(train_feed_dict, 287 | val_feed_dict) 288 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This base class defines a dataset API which is used for both ImageNet and MNIST. 
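
A typical usage sketch (with the placeholder tensors created in
adv_reg_experiment.py):

    data = dataset(files_ph = files,
                   labels_ph = labels,
                   batch_size_ph = batch_size,
                   dataset_name = 'MNIST')
    data.initialize_train_batch_iterator(session, batch_size=16)
    # data.train_handle can now be fed to the string-handle
    # placeholder of a tf.data.Iterator.from_string_handle iterator.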
3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import numpy as np 10 | 11 | class dataset: 12 | 13 | # Method for train data shuffling 14 | 15 | def __init__(self, 16 | files_ph, 17 | labels_ph, 18 | batch_size_ph, 19 | dataset_name, 20 | img_indices = None, 21 | num_parallel_calls = np.int32(512)): 22 | # tf.placeholders go here 23 | self.files_ph = files_ph 24 | self.labels_ph = labels_ph 25 | self.batch_size_ph = batch_size_ph 26 | 27 | # This is a string 28 | self.dataset_name = dataset_name 29 | 30 | self.img_indices = img_indices 31 | 32 | self.num_parallel_calls = num_parallel_calls 33 | 34 | 35 | if dataset_name == 'ImageNet': 36 | import imagenet_data as data 37 | if dataset_name == 'ImageNetSingleFiles': 38 | import imagenet_data_from_single_files as data 39 | if dataset_name == 'TinyImageNet': 40 | import tiny_imagenet_data as data 41 | if dataset_name == 'MNIST': 42 | import mnist_data as data 43 | 44 | 45 | # These are either file names or tensors. 46 | # (Tiny)ImageNet has file names, MNIST 47 | # has tensors. 48 | self.train_data, self.train_labels = \ 49 | data.collect_train_data() 50 | self.val_data, self.val_labels = \ 51 | data.collect_val_data() 52 | 53 | self.bounds = data.bounds() 54 | self.image_range = data.image_range 55 | 56 | # If necessary, turn lists into arrays to allow 57 | # for more elaborate subindexing 58 | if type(self.train_data) == list : 59 | self.train_data = np.array(self.train_data) 60 | if type(self.train_labels) == list : 61 | self.train_labels = np.array(self.train_labels) 62 | if type(self.val_data) == list : 63 | self.val_data = np.array(self.val_data) 64 | if type(self.val_labels) == list : 65 | self.val_labels = np.array(self.val_labels) 66 | 67 | self.num_train_samples = data.num_train_samples() 68 | self.num_val_samples = data.num_val_samples() 69 | 70 | if img_indices is None: 71 | self.img_indices = img_indices 72 | self.img_data = self.val_data 73 | self.img_labels = self.val_labels 74 | else: 75 | self.img_data = self.val_data[self.img_indices] 76 | self.img_labels = self.val_labels[self.img_indices] 77 | 78 | 79 | 80 | 81 | 82 | # batch iterators 83 | self.train_batch_iterator = data.train_BI( 84 | self.files_ph, 85 | self.labels_ph, 86 | self.batch_size_ph, 87 | num_parallel_calls) 88 | self.val_batch_iterator = data.val_BI( 89 | self.files_ph, 90 | self.labels_ph, 91 | self.batch_size_ph, 92 | num_parallel_calls) 93 | self.img_batch_iterator = data.img_BI( 94 | self.files_ph, 95 | self.labels_ph, 96 | self.batch_size_ph, 97 | num_parallel_calls) 98 | 99 | self.train_handle = None 100 | self.val_handle = None 101 | self.img_handle = None 102 | 103 | self.interpret_as_image = data.interpret_as_image 104 | 105 | self.n_classes = num_classes[dataset_name] 106 | 107 | # Shuffles (training) data on numpy-level 108 | def shuffle_input(self, 109 | data, 110 | labels): 111 | num_samples = len(labels) 112 | shuffle_indices = np.arange(num_samples) 113 | np.random.shuffle(shuffle_indices) 114 | labels = labels[shuffle_indices] 115 | data = data[shuffle_indices] 116 | return data, labels 117 | 118 | def initialize_train_batch_iterator(self, 119 | session, 120 | batch_size = 16, 121 | shuffle = True): 122 | self.train_data, self.train_labels = self.shuffle_input( 123 | self.train_data, self.train_labels) 124 | session.run(self.train_batch_iterator.initializer, 125 | feed_dict={self.files_ph: self.train_data, 126 | self.labels_ph: self.train_labels, 127 
| self.batch_size_ph: batch_size})
128 |         if not self.train_handle:
129 |             self.get_train_handle(session)
130 | 
131 | 
132 |     def initialize_val_batch_iterator(self,
133 |                                       session,
134 |                                       batch_size = 16):
135 |         session.run(self.val_batch_iterator.initializer,
136 |                     feed_dict={self.files_ph: self.val_data,
137 |                                self.labels_ph: self.val_labels,
138 |                                self.batch_size_ph: batch_size})
139 |         if not self.val_handle:
140 |             self.get_val_handle(session)
141 | 
142 |     def initialize_img_batch_iterator(self,
143 |                                       session,
144 |                                       batch_size = 16):
145 |         session.run(self.img_batch_iterator.initializer,
146 |                     feed_dict={self.files_ph: self.img_data,
147 |                                self.labels_ph: self.img_labels,
148 |                                self.batch_size_ph: batch_size})
149 |         if not self.img_handle:
150 |             self.get_img_handle(session)
151 | 
152 |     def get_train_handle(self,session):
153 |         self.train_handle = session.run(
154 |             self.train_batch_iterator.string_handle())
155 | 
156 |     def get_val_handle(self,session):
157 |         self.val_handle = session.run(
158 |             self.val_batch_iterator.string_handle())
159 | 
160 |     def get_img_handle(self,session):
161 |         self.img_handle = session.run(
162 |             self.img_batch_iterator.string_handle())
163 | 
164 | 
165 | 
166 | 
167 | image_resolution = {
168 |     'ImageNet' : [256,256],
169 |     'TinyImageNet' : [64,64],
170 |     'MNIST' : [28,28],
171 |     'CIFAR-10' : [32,32],
172 |     'CIFAR-100' : [32,32]
173 |     }
174 | 
175 | num_classes = {
176 |     'ImageNet' : 1000,
177 |     'TinyImageNet' : 200,
178 |     'MNIST' : 10,
179 |     'CIFAR-10' : 10,
180 |     'CIFAR-100' : 100
181 |     }
--------------------------------------------------------------------------------
/dataset_paths.ini:
--------------------------------------------------------------------------------
1 | [PATHS]
2 | TinyImageNet = /localdata/TinyImageNet/tiny-imagenet-200/
3 | ImageNet = /localdata/ImageNet/
4 | MNIST = /localdata/MNIST/
--------------------------------------------------------------------------------
/diff_ops.py:
--------------------------------------------------------------------------------
1 | """
2 | This file defines the L-Op and R-Op, based on https://j-towns.github.io/2017/06/12/A-new-trick.html
3 | """
4 | 
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 | 
9 | import tensorflow as tf
10 | 
11 | def Lop(nodes, x, v):
12 |     # L-operator: the vector-Jacobian product v^T (d nodes / d x),
13 |     # which tf.gradients provides directly through grad_ys.
14 |     lop_out = tf.gradients(nodes, x, grad_ys=v)
15 |     return lop_out
16 | 
17 | 
18 | def Rop(nodes, x, v):
19 |     # R-operator: the Jacobian-vector product (d nodes / d x) v,
20 |     # obtained by differentiating the L-op a second time with
21 |     # respect to the auxiliary vector u.
22 |     if isinstance(nodes, list):
23 |         u = [tf.ones_like(node) for node in nodes]
24 |     else:
25 |         u = tf.ones_like(nodes)
26 |     rop_out = tf.gradients(
27 |         Lop(nodes, x, u), u, grad_ys=v)
28 |     return rop_out
29 | 
--------------------------------------------------------------------------------
/fwt.py:
--------------------------------------------------------------------------------
1 | """
2 | This file defines the fast wavelet transform, which did not end up being used in the paper.
3 | """
4 | 
5 | 
6 | 
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 | 
11 | import tensorflow as tf
12 | import numpy as np
13 | import pywt as pywt
14 | 
15 | def create_filter_bank(wavelet_type='bior2.2'):
16 |     """
17 |     This function implements a 2D filter bank.
18 |     wavelet_type -- str. A pywt-compatible wavelet type.
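
    Returns a tuple (filter_bank_dec, filter_bank_rec) of [H,W,1,4]
    tensors holding the decomposition and reconstruction filters
    (ordered LL, LH, HL, HH), suitable for use with tf.nn.conv2d.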
19 |     """
20 | 
21 |     # Load the wavelet from pywt
22 |     w = pywt.Wavelet(wavelet_type)
23 | 
24 |     # Tensorflow implements convolutional operators
25 |     # as cross-correlations, which is equivalent to
26 |     # convolutions with a flipped kernel. Flip the
27 |     # wavelets in order to be compatible with
28 |     # Tensorflow-style convolution.
29 |     dec_hi = w.dec_hi[::-1]
30 |     dec_lo = w.dec_lo[::-1]
31 |     rec_hi = w.rec_hi
32 |     rec_lo = w.rec_lo
33 | 
34 |     # The filter banks need to be at least of size 4
35 |     # in order to work with the below-implemented
36 |     # padding. Fill up with leading zeros.
37 |     dec_hi, dec_lo, rec_hi, rec_lo = [
38 |         [0.] * (4 - len(l)) + l
39 |         for l in [dec_hi, dec_lo, rec_hi, rec_lo]]
40 | 
41 |     # Turn the lists of numbers into Tensorflow constants
42 |     dec_hi = tf.constant(dec_hi)
43 |     dec_lo = tf.constant(dec_lo)
44 |     rec_hi = tf.constant(rec_hi)
45 |     rec_lo = tf.constant(rec_lo)
46 | 
47 |     # Separable 2D scaling functions and wavelets are realized through
48 |     # tensor products of the 1D scaling functions and wavelets.
49 |     lo_lo_dec = tf.expand_dims(dec_lo,0)*tf.expand_dims(dec_lo,1)
50 |     lo_hi_dec = tf.expand_dims(dec_lo,0)*tf.expand_dims(dec_hi,1)
51 |     hi_lo_dec = tf.expand_dims(dec_hi,0)*tf.expand_dims(dec_lo,1)
52 |     hi_hi_dec = tf.expand_dims(dec_hi,0)*tf.expand_dims(dec_hi,1)
53 | 
54 |     lo_lo_rec = tf.expand_dims(rec_lo,0)*tf.expand_dims(rec_lo,1)
55 |     lo_hi_rec = tf.expand_dims(rec_lo,0)*tf.expand_dims(rec_hi,1)
56 |     hi_lo_rec = tf.expand_dims(rec_hi,0)*tf.expand_dims(rec_lo,1)
57 |     hi_hi_rec = tf.expand_dims(rec_hi,0)*tf.expand_dims(rec_hi,1)
58 | 
59 |     # Turn this filter bank into a Tensorflow-compatible
60 |     # convolutional kernel. The convention is
61 |     # conv2d shape = [H,W,I,O] and
62 |     # conv2d_transpose shape = [H,W,O,I].
63 |     # This means that
64 |     # filter_bank_dec.shape = [H,W,1,4] and also
65 |     # filter_bank_rec.shape = [H,W,1,4]
66 | 
67 |     filter_bank_dec = tf.stack([lo_lo_dec,lo_hi_dec,hi_lo_dec,hi_hi_dec],
68 |                                axis=2)
69 |     filter_bank_dec = tf.expand_dims(filter_bank_dec,2)
70 | 
71 |     filter_bank_rec = tf.stack([lo_lo_rec,lo_hi_rec,hi_lo_rec,hi_hi_rec],
72 |                                axis=2)
73 |     filter_bank_rec = tf.expand_dims(filter_bank_rec,2)
74 | 
75 |     return filter_bank_dec, filter_bank_rec
76 | 
77 | def fwt(f_img,
78 |         filter_bank,
79 |         scale = 1,
80 |         pad_mode = 'REFLECT',
81 |         output_type = 'image',
82 |         method = "strided"):
83 |     """
84 |     Implements the fast wavelet transform for orthogonal or
85 |     biorthogonal wavelets.
86 |     f_img -- Tensorflow Tensor. The input image.
87 |     filter_bank -- Tensorflow Tensor. The filter_bank.
88 |     scale -- int. The scale up to which the FWT is computed.
89 |     pad_mode -- 'REFLECT', 'SYMMETRIC' or 'CONSTANT' for boundary
90 |                 handling.
91 |     output_type -- 'image' or 'list'. 'list' is incompatible with
92 |                    multi_channel_fwt so far.
93 |     method -- 'strided' or 'downsampling'. Only for troubleshooting.
94 |     """
95 | 
96 |     # The convolution and subsequent subsampling are realized
97 |     # through a single strided convolution with padding.
98 |     filter_shape = filter_bank.get_shape().as_list()
99 |     h = filter_shape[0]//2-1
100 |     w = filter_shape[1]//2-1
101 | 
102 | 
103 |     if method == "strided":
104 |         f_img = tf.pad(f_img,[[0,0],[h,h],[w,w],[0,0]],pad_mode)
105 |         filtered = tf.nn.conv2d(f_img,
106 |                                 filter_bank,
107 |                                 strides=[1,2,2,1],
108 |                                 padding='VALID')
109 |     if method == "downsampling":
110 |         # gives EXACTLY the same result as "strided" with pad_mode="CONSTANT".
111 |         # to be deleted.
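        # The equivalence holds because a stride-2 convolution simply
        # computes every second output of the corresponding stride-1
        # convolution. Below, this is made explicit: a full-resolution
        # convolution is followed by a fixed [[1,0],[0,0]] 'selection'
        # kernel that keeps only the top-left sample of each 2x2 block,
        # which is why the result matches "strided" only under zero
        # ('CONSTANT') padding.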
112 | filtered = tf.nn.conv2d(f_img, 113 | filter_bank, 114 | strides=[1,1,1,1], 115 | padding='SAME') 116 | downsampling_kernel = np.zeros((2,2,4,4),dtype=np.float32) 117 | for i in range(4): 118 | downsampling_kernel[:,:,i,i] = np.array([[1,0],[0,0]]) 119 | filtered = tf.nn.conv2d(filtered, 120 | downsampling_kernel, 121 | strides=[1,2,2,1], 122 | padding='SAME') 123 | 124 | 125 | 126 | # filtered is a tensor of dimension [N,H,W,4]. 127 | # Turn this into 4 tensors of dimension [N,H,W,1] in 128 | # order to be processable by the recursive function. 129 | filtered = tf.unstack(filtered, axis=-1) 130 | a = tf.expand_dims(filtered[0],axis=-1) 131 | d_1 = tf.expand_dims(filtered[1],axis=-1) 132 | d_2 = tf.expand_dims(filtered[2],axis=-1) 133 | d_3 = tf.expand_dims(filtered[3],axis=-1) 134 | 135 | # FWT recursion of a. 136 | if scale>1: 137 | a = fwt(a, 138 | filter_bank, 139 | scale-1, 140 | pad_mode, 141 | output_type) 142 | 143 | if output_type == 'image': 144 | filtered_upper = tf.concat([a,d_2],1) 145 | filtered_lower = tf.concat([d_1,d_3],1) 146 | filtered = tf.concat([filtered_upper, filtered_lower],2) 147 | return filtered 148 | elif output_type == 'list': 149 | return [a,d_1,d_2,d_3] 150 | 151 | def multi_channel_fwt(f_img, 152 | filter_bank, 153 | scale = 1, 154 | pad_mode = 'REFLECT', 155 | output_type = 'image'): 156 | 157 | """ 158 | This function implements the multi-channel FWT. 159 | Currently only works with 'image' 160 | """ 161 | 162 | img_channels = tf.unstack(f_img, axis=-1) 163 | transformed_channels = [] 164 | for channel in img_channels: 165 | transformed_channels.append( 166 | fwt(tf.expand_dims(channel,axis=-1), 167 | filter_bank, 168 | scale, 169 | pad_mode, 170 | output_type = output_type)) 171 | # only works with 'image' so far 172 | if output_type == 'image': 173 | combined_img = tf.stack(transformed_channels, 174 | axis = -1) 175 | return tf.squeeze(combined_img, [3]) 176 | if output_type == 'list': 177 | return transformed_channels -------------------------------------------------------------------------------- /imagenet_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines how to handle the ImageNet dataset from tfrecord files. The tfrecord files used in this work were 3 | created using the code from 4 | https://github.com/tensorflow/models/blob/master/research/inception/inception/data/build_imagenet_data.py 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import tensorflow as tf 12 | import numpy as np 13 | import os 14 | import csv 15 | import preprocessing 16 | from preprocessing import image_preprocessing 17 | import configparser 18 | 19 | 20 | this_folder = os.path.dirname(os.path.abspath(__file__)) + '/' 21 | 22 | config = configparser.ConfigParser() 23 | config.read(this_folder + 'dataset_paths.ini') 24 | base_folder = config['PATHS'].get('ImageNet') 25 | 26 | # Load a class dictionary that matches the pre-trained 27 | # encoding. 
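# Each row of imagenet_labels.csv maps a WordNet synset ID to the
# class index used by the pre-trained Keras models, e.g.
# 'n01440764,0'. The resulting labels_dict hence maps synset strings
# such as 'n01440764' to the corresponding class index.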
28 | labels_dict = {} 29 | with open(this_folder + 'imagenet_labels.csv', 'rt') as csvfile: 30 | file_contents = csv.reader(csvfile, delimiter=',') 31 | for row in file_contents: 32 | labels_dict[row[0]] = row[1] 33 | 34 | ##### Training ########### 35 | def collect_train_data(num_samples_per_class=0): 36 | print("Collecting training data...") 37 | tfrecord_folder = base_folder + 'tfrecord/' 38 | file_list = os.listdir(tfrecord_folder) 39 | train_data = [(tfrecord_folder + f) for f in file_list if 'train-' in f] 40 | # Create dummy labels, because the label information is contained 41 | # in the train_data files. 42 | train_labels = np.zeros_like(train_data,dtype=np.int32) 43 | return train_data, train_labels 44 | 45 | def collect_val_data(): 46 | print("Collecting validation data...") 47 | tfrecord_folder = base_folder + 'tfrecord/' 48 | file_list = os.listdir(tfrecord_folder) 49 | val_data = [(tfrecord_folder + f) for f in file_list if 'validation-' in f] 50 | # Create dummy labels, because the label information is contained 51 | # in the train_data files. 52 | val_labels = np.zeros_like(val_data,dtype=np.int32) 53 | return val_data, val_labels 54 | 55 | # tf.data batch iterator for the training data 56 | def train_BI(filenames, 57 | labels, 58 | batch_size, 59 | num_parallel_calls): 60 | dataset = tf.data.TFRecordDataset(filenames) 61 | batch_prepare = lambda image: image_preprocessing(image, 62 | None, 63 | file_type = 'tfrecord', 64 | shape = [256,256], 65 | random_events = True, 66 | data_augmentation = True, 67 | additive_noise = False, 68 | subtract_by = 'ImageNet', 69 | divide_by = 1., 70 | colorspace = 'BGR', 71 | min_rescale = 258, 72 | rescale = True, 73 | noise_level = 10., 74 | clip_values = bounds()) 75 | dataset = dataset.map(batch_prepare,num_parallel_calls=num_parallel_calls) 76 | batched_dataset = dataset.batch(batch_size, 77 | drop_remainder = False) 78 | train_batch_iterator = batched_dataset.make_initializable_iterator() 79 | return train_batch_iterator 80 | 81 | # tf.data batch iterator for the validation data 82 | def val_BI(filenames, 83 | labels, 84 | batch_size, 85 | num_parallel_calls): 86 | dataset = tf.data.TFRecordDataset(filenames) 87 | batch_prepare = lambda image: image_preprocessing(image, 88 | None, 89 | file_type = 'tfrecord', 90 | shape = [256,256], 91 | random_events = False, 92 | data_augmentation = False, 93 | additive_noise = False, 94 | subtract_by = 'ImageNet', 95 | divide_by = 1., 96 | colorspace = 'BGR', 97 | min_rescale = 258, 98 | rescale = True) 99 | dataset = dataset.map(batch_prepare, 100 | num_parallel_calls=num_parallel_calls) 101 | batched_dataset = dataset.batch(batch_size, 102 | drop_remainder = False) 103 | batch_iterator = batched_dataset.make_initializable_iterator() 104 | return batch_iterator 105 | 106 | # Additional tf.data batch iterator for the data that is used just for the propagation 107 | # of a few images for visualization. 
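# The images come from the img_data subset defined in dataset.py,
# which defaults to the full validation set unless img_indices is
# specified there.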
108 | def img_BI(filenames, 109 | labels, 110 | batch_size, 111 | num_parallel_calls): 112 | dataset = tf.data.TFRecordDataset(filenames) 113 | batch_prepare = lambda image: image_preprocessing(image, 114 | None, 115 | file_type = 'tfrecord', 116 | shape = [256,256], 117 | random_events = False, 118 | data_augmentation = False, 119 | additive_noise = False, 120 | subtract_by = 'ImageNet', 121 | divide_by = 1., 122 | colorspace = 'BGR', 123 | min_rescale = 258, 124 | rescale = True) 125 | dataset = dataset.map(batch_prepare, 126 | num_parallel_calls=num_parallel_calls) 127 | batched_dataset = dataset.batch(batch_size, 128 | drop_remainder = False) 129 | batch_iterator = batched_dataset.make_initializable_iterator() 130 | return batch_iterator 131 | 132 | def interpret_as_image(image): 133 | return preprocessing.interpret_as_image(image, 134 | add_by='ImageNet', 135 | colorspace='BGR') 136 | 137 | def num_train_samples(): 138 | return 1281167 139 | 140 | def num_val_samples(): 141 | return 50000 142 | 143 | def bounds(): 144 | # This is a little problematic here. Foolbox only allows 145 | # for scalar bounds, not bounds per channel. For this reason, 146 | # we use the worst-case bounds here. 147 | return (-130., 255.-100.) 148 | 149 | min_values = np.array([0.,0.,0.],np.float32) 150 | max_values = np.array([1.,1.,1.],np.float32) 151 | def image_range(): 152 | return [0.,255.] -------------------------------------------------------------------------------- /imagenet_example.ini: -------------------------------------------------------------------------------- 1 | [PATHS] 2 | # The used dataset. Currently only supports 'TinyImageNet', 'MNIST' 3 | # and 'ImageNet'. 4 | dataset_name = ImageNet 5 | # The folder in which to save the different tensorboard summaries. 6 | # For a cohort of experiments, this should ideally be the same. 7 | tensorboard_logdir = /localdata/logs/ 8 | # The folder in which to save the different tensorboard summaries. 9 | # For a cohort of experiments, this should ideally be the same. 10 | saved_model_folder = /localdata/models/ 11 | 12 | [ARCHITECTURE] 13 | # Possible values: VGG16, ResNet18 or ResNet50 14 | model = ResNet50 15 | # Do we want to use a pretrained model? ONLY works 16 | # if we use ResNet50 AND ImageNet data. 17 | pretrained = True 18 | 19 | [HYPERPARAMETERS] 20 | # The total number of epochs. 21 | num_epochs = 20 22 | # The learning rate at the beginning of training. 23 | learning_rate_at_start = 0.00001 24 | # Every lr_decrease_interval epochs, multiply the 25 | # current learning rate by lr_decrease_factor. 26 | # If you don't want this, simply set lr_decrease_interval 27 | # higher than num_epochs. 28 | lr_decrease_interval = 15 29 | lr_decrease_factor = .1 30 | # Batch size per GPU. 31 | batch_size_per_gpu = 16 32 | # Whether to penalize a weighted sum of 1-norms of 33 | # wavelet coefficients of the saliency. 'False' is faster 34 | # than just setting the multipliers to zero. 35 | robust_regularization = True 36 | # Whether to decompose \nabla_x L into its wavelet coefficients. 37 | # If 'False', the regularization is applied to just the image. 38 | use_wavelet_decomposition = False 39 | # Define with respect to which output the gradients are 40 | # calculated. 'logits' or 'NLL'. 41 | sensitivity_mode = NLL 42 | # Multiplier in front of the penalty term. The name is currently a misnomer. 43 | lp_wavelet_parameter = 10000.0 44 | # Which p-norm to use 45 | p_norm = 2 46 | # Squared 2-norm weight penalty parameter. 
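# That is, a term of the form
# weight_decay_parameter * ||weights||_2^2 is added to the loss.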
47 | weight_decay_parameter = 0.0001 48 | # Exponential moving average multiplier for the batch 49 | # normalization layers, if present. 50 | bn_momentum_value = .999 51 | # Whether to set K.learning_phase() to 0 or 1 when 52 | # training. If we train batch normalized networks, 53 | # setting this to 0 has the effect of using the 54 | # running mean statistics instead of the batch 55 | # statistics, which is more stable but slower in the 56 | # beginning. 57 | learning_phase = 1 58 | 59 | [LOGGING] 60 | # Execute the training summary operator every 61 | # train_summary_period batches. 62 | train_summary_period = 500 63 | # Execute the training summary operator every 64 | # val_summary_period batches. 65 | val_summary_period = 20000 66 | # When checking the adversarial vulnerability, 67 | # try the different attacks on num_adversarial_batches 68 | # batches of the img_data set. This may take a considerable 69 | # amount of time so don't set this too high. 70 | num_adversarial_batches = 8 71 | # Execute the training summary operator every 72 | # train_summary_period batches. 73 | adversarial_test_period = 50000 74 | -------------------------------------------------------------------------------- /imagenet_labels.csv: -------------------------------------------------------------------------------- 1 | n04081281,762 2 | n02747177,412 3 | n02114855,272 4 | n02097474,200 5 | n04447861,861 6 | n04127249,772 7 | n03770679,656 8 | n01704323,51 9 | n01984695,123 10 | n02879718,456 11 | n03903868,708 12 | n02074367,149 13 | n01748264,63 14 | n03690938,631 15 | n03255030,543 16 | n01558993,15 17 | n02113978,268 18 | n04023962,747 19 | n02790996,422 20 | n04125021,771 21 | n04606251,913 22 | n07714571,936 23 | n03782006,664 24 | n07880968,965 25 | n02093256,179 26 | n01531178,11 27 | n01582220,18 28 | n02105251,226 29 | n01742172,61 30 | n03786901,666 31 | n02268853,320 32 | n04317175,823 33 | n02100236,210 34 | n02094258,186 35 | n02099712,208 36 | n03110669,513 37 | n02174001,306 38 | n03929855,715 39 | n01695060,48 40 | n02965783,475 41 | n03961711,729 42 | n02129604,292 43 | n02807133,433 44 | n03271574,545 45 | n06596364,917 46 | n03594734,608 47 | n03843555,686 48 | n03160309,525 49 | n04141076,776 50 | n03187595,528 51 | n11879895,984 52 | n02095889,190 53 | n04330267,827 54 | n02117135,276 55 | n04252077,802 56 | n01768244,69 57 | n03530642,599 58 | n02226429,311 59 | n02493509,380 60 | n02480855,366 61 | n04070727,760 62 | n03452741,579 63 | n04486054,873 64 | n02808440,435 65 | n03630383,617 66 | n12998815,992 67 | n01751748,65 68 | n03134739,522 69 | n03854065,687 70 | n07742313,948 71 | n04141327,777 72 | n07730033,946 73 | n02033041,142 74 | n02641379,395 75 | n02119022,277 76 | n02328150,332 77 | n04548362,893 78 | n03393912,565 79 | n02123045,281 80 | n03000134,489 81 | n02002556,127 82 | n02119789,278 83 | n02099849,209 84 | n01806143,84 85 | n07932039,969 86 | n01697457,49 87 | n03498962,596 88 | n01616318,23 89 | n03492542,592 90 | n02172182,305 91 | n02105412,227 92 | n03792782,671 93 | n03938244,721 94 | n03482405,588 95 | n04264628,810 96 | n04179913,786 97 | n01843383,96 98 | n04461696,864 99 | n02795169,427 100 | n03763968,652 101 | n01873310,103 102 | n03967562,730 103 | n04399382,850 104 | n04507155,879 105 | n02097298,199 106 | n03100240,511 107 | n01496331,5 108 | n02454379,363 109 | n04120489,770 110 | n04275548,815 111 | n04592741,908 112 | n01560419,16 113 | n03447447,576 114 | n03534580,601 115 | n07753113,952 116 | n09332890,975 117 | n02783161,418 118 | n02108422,243 
119 | n02098105,202 120 | n04465501,866 121 | n02526121,390 122 | n04310018,820 123 | n02978881,481 124 | n02027492,140 125 | n04548280,892 126 | n04597913,910 127 | n01774384,75 128 | n03877472,697 129 | n02930766,468 130 | n10565667,983 131 | n03494278,593 132 | n07697313,933 133 | n02097658,201 134 | n02093428,180 135 | n02102040,217 136 | n02727426,410 137 | n02077923,150 138 | n07802026,958 139 | n01860187,100 140 | n02123394,283 141 | n04370456,841 142 | n04044716,755 143 | n12057211,986 144 | n02037110,143 145 | n02113799,267 146 | n03876231,696 147 | n02101556,216 148 | n04033995,750 149 | n07693725,931 150 | n03888257,701 151 | n02865351,451 152 | n07747607,950 153 | n04423845,855 154 | n02112350,261 155 | n02667093,399 156 | n02927161,467 157 | n02108089,242 158 | n04613696,915 159 | n04254680,805 160 | n04483307,871 161 | n02279972,323 162 | n02088238,161 163 | n04116512,767 164 | n02869837,452 165 | n02091032,171 166 | n02093991,184 167 | n02133161,295 168 | n04131690,773 169 | n13040303,994 170 | n02966193,476 171 | n03485407,590 172 | n07684084,930 173 | n02165105,300 174 | n03450230,578 175 | n03417042,569 176 | n03724870,643 177 | n09835506,981 178 | n04204238,790 179 | n04037443,751 180 | n03127925,519 181 | n03670208,627 182 | n03980874,735 183 | n04041544,754 184 | n01817953,87 185 | n02111277,256 186 | n03930313,716 187 | n02028035,141 188 | n03216828,536 189 | n02111500,257 190 | n03447721,577 191 | n01917289,109 192 | n03692522,633 193 | n03976467,732 194 | n04162706,785 195 | n04326547,825 196 | n07583066,924 197 | n07613480,927 198 | n03891251,703 199 | n03584829,606 200 | n02107312,237 201 | n02098286,203 202 | n03924679,713 203 | n04086273,763 204 | n02536864,391 205 | n03045698,501 206 | n04584207,903 207 | n03532672,600 208 | n03877845,698 209 | n02089078,165 210 | n03871628,692 211 | n01807496,86 212 | n13052670,996 213 | n03832673,681 214 | n04517823,882 215 | n02786058,419 216 | n02134418,297 217 | n03197337,531 218 | n03095699,510 219 | n01440764,0 220 | n02883205,457 221 | n04265275,811 222 | n02704792,408 223 | n01614925,22 224 | n02877765,455 225 | n03929660,714 226 | n03868242,690 227 | n07871810,962 228 | n01855032,98 229 | n02417914,350 230 | n04111531,766 231 | n04251144,801 232 | n02280649,324 233 | n07920052,967 234 | n09193705,970 235 | n03075370,507 236 | n02112018,259 237 | n02127052,287 238 | n01667778,36 239 | n01498041,6 240 | n02791270,424 241 | n04192698,787 242 | n01629819,25 243 | n02090379,168 244 | n04008634,744 245 | n04589890,904 246 | n02190166,308 247 | n01514668,7 248 | n03803284,676 249 | n01688243,43 250 | n03788365,669 251 | n03495258,594 252 | n04254777,806 253 | n02481823,367 254 | n04509417,880 255 | n01983481,122 256 | n03272562,547 257 | n03017168,494 258 | n01833805,94 259 | n09468604,979 260 | n04355933,836 261 | n03995372,740 262 | n02408429,346 263 | n02094114,185 264 | n02102177,218 265 | n03792972,672 266 | n02834397,443 267 | n04270147,813 268 | n02640242,394 269 | n02011460,133 270 | n03888605,702 271 | n03933933,718 272 | n02950826,471 273 | n03249569,541 274 | n02111889,258 275 | n03018349,495 276 | n01955084,116 277 | n03443371,572 278 | n04026417,748 279 | n03676483,629 280 | n04435653,858 281 | n02361337,336 282 | n02690373,404 283 | n01534433,13 284 | n02281406,325 285 | n03781244,663 286 | n04153751,783 287 | n02804414,431 288 | n03180011,527 289 | n03874599,695 290 | n02492660,379 291 | n01737021,58 292 | n03047690,502 293 | n03259280,544 294 | n02669723,400 295 | n03935335,719 296 | n03146219,524 297 | 
n01943899,112 298 | n01770081,70 299 | n02410509,347 300 | n02422106,351 301 | n01796340,81 302 | n04380533,846 303 | n02841315,447 304 | n02096177,192 305 | n04392985,848 306 | n02971356,478 307 | n02437616,355 308 | n02106550,234 309 | n04418357,854 310 | n02643566,396 311 | n07754684,955 312 | n02977058,480 313 | n01847000,97 314 | n03062245,503 315 | n07875152,964 316 | n03814639,678 317 | n01944390,113 318 | n03250847,542 319 | n02981792,484 320 | n02445715,361 321 | n03220513,538 322 | n02788148,421 323 | n02840245,446 324 | n02105162,225 325 | n03372029,558 326 | n01667114,35 327 | n07930864,968 328 | n02112706,262 329 | n02782093,417 330 | n03297495,550 331 | n02966687,477 332 | n01755581,67 333 | n03388043,562 334 | n02510455,388 335 | n04371774,843 336 | n04254120,804 337 | n03658185,623 338 | n03874293,694 339 | n04019541,746 340 | n01775062,77 341 | n04118538,768 342 | n07579787,923 343 | n01978455,119 344 | n03179701,526 345 | n02490219,377 346 | n03976657,733 347 | n03837869,682 348 | n07860988,961 349 | n04141975,778 350 | n04442312,859 351 | n04417672,853 352 | n04040759,753 353 | n03461385,582 354 | n09229709,971 355 | n02979186,482 356 | n03950228,725 357 | n02835271,444 358 | n09472597,980 359 | n02100735,212 360 | n03794056,674 361 | n03445924,575 362 | n04590129,905 363 | n03769881,654 364 | n07614500,928 365 | n03584254,605 366 | n07760859,956 367 | n03908618,709 368 | n02708093,409 369 | n01729322,54 370 | n02321529,329 371 | n03838899,683 372 | n02797295,428 373 | n04515003,881 374 | n03729826,644 375 | n02066245,147 376 | n03796401,675 377 | n02124075,285 378 | n02749479,413 379 | n03535780,602 380 | n04118776,769 381 | n02939185,469 382 | n01985128,124 383 | n04462240,865 384 | n03857828,688 385 | n02100877,213 386 | n02403003,345 387 | n04554684,897 388 | n03657121,622 389 | n02138441,299 390 | n02089973,167 391 | n03124043,514 392 | n03196217,530 393 | n03425413,571 394 | n06785654,918 395 | n01630670,26 396 | n04252225,803 397 | n02871525,454 398 | n07753592,954 399 | n07615774,929 400 | n02012849,134 401 | n01824575,91 402 | n01945685,114 403 | n02443114,358 404 | n03201208,532 405 | n01806567,85 406 | n02480495,365 407 | n02108551,244 408 | n02098413,204 409 | n01871265,101 410 | n02058221,146 411 | n02115641,273 412 | n03538406,603 413 | n04033901,749 414 | n02112137,260 415 | n04229816,796 416 | n03983396,737 417 | n02825657,442 418 | n03602883,613 419 | n02106166,232 420 | n01877812,104 421 | n03902125,707 422 | n04389033,847 423 | n03445777,574 424 | n02085936,153 425 | n02948072,470 426 | n03764736,653 427 | n04356056,837 428 | n02113712,266 429 | n04152593,782 430 | n03743016,649 431 | n02134084,296 432 | n02229544,312 433 | n02177972,307 434 | n01580077,17 435 | n02097047,196 436 | n01784675,79 437 | n04522168,883 438 | n02093754,182 439 | n01689811,44 440 | n01729977,55 441 | n02504013,385 442 | n03089624,509 443 | n02504458,386 444 | n01532829,12 445 | n03394916,566 446 | n04557648,898 447 | n04344873,831 448 | n02107574,238 449 | n12267677,988 450 | n02483708,369 451 | n01530575,10 452 | n02412080,348 453 | n03529860,598 454 | n09421951,977 455 | n01855672,99 456 | n02132136,294 457 | n02676566,402 458 | n02397096,343 459 | n01608432,21 460 | n09246464,972 461 | n02099601,207 462 | n03424325,570 463 | n02051845,144 464 | n01818515,88 465 | n03720891,641 466 | n03344393,554 467 | n04065272,757 468 | n03444034,573 469 | n02992529,487 470 | n02701002,407 471 | n15075141,999 472 | n01843065,95 473 | n01749939,64 474 | n02096294,193 475 | n02086079,154 
476 | n04147183,780 477 | n03873416,693 478 | n02802426,430 479 | n04479046,869 480 | n02415577,349 481 | n04335435,829 482 | n02091244,173 483 | n03884397,699 484 | n01914609,108 485 | n01981276,121 486 | n04049303,756 487 | n04074963,761 488 | n01773797,74 489 | n07836838,960 490 | n13044778,995 491 | n02168699,303 492 | n03527444,597 493 | n01664065,33 494 | n01698640,50 495 | n03127747,518 496 | n03773504,657 497 | n04456115,862 498 | n03775071,658 499 | n01910747,107 500 | n02108000,241 501 | n02109961,248 502 | n03125729,516 503 | n04542943,891 504 | n03032252,498 505 | n02992211,486 506 | n01829413,93 507 | n03920288,712 508 | n02120079,279 509 | n02396427,342 510 | n06794110,919 511 | n02892201,458 512 | n01744401,62 513 | n02892767,459 514 | n04266014,812 515 | n03345487,555 516 | n03476991,585 517 | n02980441,483 518 | n01677366,39 519 | n02088094,160 520 | n03956157,727 521 | n02492035,378 522 | n02099429,206 523 | n04485082,872 524 | n03788195,668 525 | n04612504,914 526 | n03977966,734 527 | n02106382,233 528 | n09288635,974 529 | n04505470,878 530 | n02999410,488 531 | n04467665,867 532 | n04067472,758 533 | n02114548,270 534 | n04146614,779 535 | n03697007,634 536 | n04532670,888 537 | n02018795,138 538 | n03291819,549 539 | n02115913,274 540 | n01797886,82 541 | n03662601,625 542 | n04311174,822 543 | n04286575,818 544 | n04560804,899 545 | n01644373,31 546 | n03476684,584 547 | n03868863,691 548 | n04069434,759 549 | n02128385,288 550 | n02894605,460 551 | n01632458,28 552 | n02486410,372 553 | n02233338,314 554 | n01924916,110 555 | n03891332,704 556 | n03887697,700 557 | n07716906,940 558 | n13054560,997 559 | n02917067,466 560 | n01687978,42 561 | n02104029,222 562 | n02356798,335 563 | n03290653,548 564 | n03627232,616 565 | n01770393,71 566 | n03942813,722 567 | n01774750,76 568 | n02236044,315 569 | n07720875,945 570 | n02106662,235 571 | n07718472,943 572 | n02206856,309 573 | n02097209,198 574 | n07749582,951 575 | n02281787,326 576 | n04540053,890 577 | n01773157,72 578 | n02056570,145 579 | n12768682,990 580 | n02086910,157 581 | n02087046,158 582 | n03481172,587 583 | n01601694,20 584 | n02817516,439 585 | n04228054,795 586 | n02129165,291 587 | n04136333,775 588 | n03000684,491 589 | n02108915,245 590 | n01644900,32 591 | n03131574,520 592 | n03000247,490 593 | n07734744,947 594 | n03899768,706 595 | n02804610,432 596 | n04532106,887 597 | n03642806,620 598 | n02096437,194 599 | n02097130,197 600 | n04552348,895 601 | n04428191,856 602 | n04404412,851 603 | n03991062,738 604 | n02444819,360 605 | n07718747,944 606 | n02096051,191 607 | n03028079,497 608 | n03637318,619 609 | n04536866,889 610 | n09256479,973 611 | n01753488,66 612 | n04099969,765 613 | n02268443,319 614 | n02018207,137 615 | n02110185,250 616 | n02423022,353 617 | n02655020,397 618 | n01819313,89 619 | n04553703,896 620 | n02096585,195 621 | n02107142,236 622 | n02231487,313 623 | n01675722,38 624 | n07697537,934 625 | n07892512,966 626 | n02102480,220 627 | n01694178,47 628 | n03388183,563 629 | n04372370,844 630 | n02799071,429 631 | n02110806,253 632 | n02002724,128 633 | n03916031,711 634 | n04398044,849 635 | n02514041,389 636 | n03710637,638 637 | n01883070,106 638 | n03400231,567 639 | n02113624,265 640 | n04357314,838 641 | n03937543,720 642 | n02489166,376 643 | n02494079,382 644 | n01494475,4 645 | n02089867,166 646 | n03804744,677 647 | n03063689,505 648 | n02488702,375 649 | n02443484,359 650 | n03590841,607 651 | n04523525,884 652 | n07717556,942 653 | n02607072,393 654 | 
n13037406,993 655 | n04482393,870 656 | n04259630,808 657 | n02666196,398 658 | n04562935,900 659 | n03272010,546 660 | n01986214,125 661 | n01740131,60 662 | n02099267,205 663 | n04487394,875 664 | n02325366,330 665 | n02895154,461 666 | n01930112,111 667 | n02787622,420 668 | n01443537,1 669 | n04239074,799 670 | n01820546,90 671 | n03992509,739 672 | n03223299,539 673 | n02276258,321 674 | n02110958,254 675 | n03188531,529 676 | n03759954,650 677 | n03733131,645 678 | n02319095,328 679 | n07714990,937 680 | n07590611,926 681 | n04579432,902 682 | n02457408,364 683 | n07745940,949 684 | n03709823,636 685 | n04493381,876 686 | n01665541,34 687 | n06359193,916 688 | n02776631,415 689 | n02963159,474 690 | n02815834,438 691 | n04525305,886 692 | n02095314,188 693 | n02107908,240 694 | n03691459,632 695 | n03207941,534 696 | n06874185,920 697 | n04149813,781 698 | n03944341,723 699 | n01990800,126 700 | n02085620,151 701 | n07753275,953 702 | n03085013,508 703 | n02169497,304 704 | n03717622,640 705 | n02442845,357 706 | n02088632,164 707 | n02092002,177 708 | n02364673,338 709 | n03841143,685 710 | n07716358,939 711 | n03733281,646 712 | n02487347,373 713 | n03775546,659 714 | n04005630,743 715 | n02101388,215 716 | n07873807,963 717 | n03042490,500 718 | n12985857,991 719 | n03065424,506 720 | n04550184,894 721 | n02025239,139 722 | n02088364,162 723 | n04591713,907 724 | n02859443,449 725 | n02488291,374 726 | n02692877,405 727 | n04296562,819 728 | n01693334,46 729 | n02687172,403 730 | n02256656,316 731 | n04346328,832 732 | n04350905,834 733 | n01756291,68 734 | n02130308,293 735 | n01734418,56 736 | n01882714,105 737 | n09428293,978 738 | n03483316,589 739 | n03599486,612 740 | n04458633,863 741 | n13133613,998 742 | n02277742,322 743 | n04487081,874 744 | n03776460,660 745 | n07768694,957 746 | n03710193,637 747 | n02086240,155 748 | n03908714,710 749 | n01491361,3 750 | n02092339,178 751 | n02109047,246 752 | n01776313,78 753 | n02606052,392 754 | n04238763,798 755 | n02493793,381 756 | n02422699,352 757 | n03496892,595 758 | n04263257,809 759 | n02017213,136 760 | n02837789,445 761 | n02497673,383 762 | n03982430,736 763 | n02116738,275 764 | n04200800,788 765 | n02814860,437 766 | n02109525,247 767 | n02091635,175 768 | n12144580,987 769 | n02167151,302 770 | n02013706,135 771 | n03825788,680 772 | n02102973,221 773 | n02860847,450 774 | n04204347,791 775 | n02326432,331 776 | n04599235,911 777 | n04209239,794 778 | n03379051,560 779 | n04371430,842 780 | n02085782,152 781 | n02114712,271 782 | n02137549,298 783 | n03325584,552 784 | n04325704,824 785 | n02870880,453 786 | n12620546,989 787 | n02114367,269 788 | n03337140,553 789 | n02110341,251 790 | n07695742,932 791 | n02090622,169 792 | n02101006,214 793 | n02437312,354 794 | n01641577,30 795 | n01728920,53 796 | n04332243,828 797 | n02093647,181 798 | n02091134,172 799 | n02009912,132 800 | n03840681,684 801 | n02988304,485 802 | n07717410,941 803 | n04355338,835 804 | n07711569,935 805 | n02113023,263 806 | n01692333,45 807 | n02486261,371 808 | n02095570,189 809 | n03208938,535 810 | n02793495,425 811 | n02843684,448 812 | n04596742,909 813 | n02441942,356 814 | n03710721,639 815 | n02094433,187 816 | n07565083,922 817 | n04409515,852 818 | n02102318,219 819 | n03388549,564 820 | n01950731,115 821 | n04476259,868 822 | n03649909,621 823 | n03478589,586 824 | n04004767,742 825 | n02910353,464 826 | n02974003,479 827 | n02395406,341 828 | n04285008,817 829 | n07715103,938 830 | n02342885,333 831 | n02093859,183 832 | 
n03314780,551 833 | n02951358,472 834 | n03895866,705 835 | n04367480,840 836 | n02007558,130 837 | n02128757,289 838 | n03947888,724 839 | n03384352,561 840 | n02123159,282 841 | n03793489,673 842 | n04277352,816 843 | n03673027,628 844 | n02088466,163 845 | n03706229,635 846 | n07584110,925 847 | n02120505,280 848 | n04090263,764 849 | n02814533,436 850 | n09399592,976 851 | n07248320,921 852 | n02105855,230 853 | n03761084,651 854 | n02777292,416 855 | n02105641,229 856 | n03109150,512 857 | n02006656,129 858 | n03124170,515 859 | n04347754,833 860 | n04376876,845 861 | n03785016,665 862 | n01537544,14 863 | n03467068,583 864 | n01798484,83 865 | n07831146,959 866 | n03126707,517 867 | n04154565,784 868 | n04201297,789 869 | n02264363,318 870 | n02259212,317 871 | n03240683,540 872 | n03770439,655 873 | n02699494,406 874 | n03376595,559 875 | n02165456,301 876 | n03459775,581 877 | n04579145,901 878 | n04208210,792 879 | n02794156,426 880 | n03721384,642 881 | n01682714,40 882 | n02730930,411 883 | n03595614,610 884 | n03814906,679 885 | n04243546,800 886 | n02113186,264 887 | n02009229,131 888 | n02090721,170 889 | n03594945,609 890 | n02106030,231 891 | n03954731,726 892 | n04258138,807 893 | n03014705,492 894 | n03958227,728 895 | n02091467,174 896 | n02219486,310 897 | n02086646,156 898 | n03063599,504 899 | n02823428,440 900 | n03218198,537 901 | n03016953,493 902 | n01631663,27 903 | n02125311,286 904 | n04525038,885 905 | n03791053,670 906 | n03207743,533 907 | n04311004,821 908 | n01518878,9 909 | n01828970,92 910 | n02909870,463 911 | n02105505,228 912 | n02363005,337 913 | n11939491,985 914 | n04039381,752 915 | n02123597,284 916 | n03970156,731 917 | n03404251,568 918 | n04133789,774 919 | n03998194,741 920 | n03661043,624 921 | n02105056,224 922 | n02484975,370 923 | n01484850,2 924 | n02110063,249 925 | n03041632,499 926 | n01622779,24 927 | n10148035,982 928 | n02104365,223 929 | n03866082,689 930 | n01980166,120 931 | n04366367,839 932 | n03133878,521 933 | n03623198,615 934 | n02391049,340 935 | n01773549,73 936 | n03457902,580 937 | n04591157,906 938 | n02346627,334 939 | n02483362,368 940 | n04336792,830 941 | n02398521,344 942 | n02672831,401 943 | n03355925,557 944 | n03777568,661 945 | n03598930,611 946 | n04328186,826 947 | n03026506,496 948 | n01669191,37 949 | n01978287,118 950 | n03141823,523 951 | n02808304,434 952 | n04209133,793 953 | n03633091,618 954 | n02906734,462 955 | n04604644,912 956 | n01735189,57 957 | n04235860,797 958 | n03742115,648 959 | n04273569,814 960 | n02107683,239 961 | n03544143,604 962 | n01968897,117 963 | n03680355,630 964 | n02951585,473 965 | n01632777,29 966 | n04501370,877 967 | n04009552,745 968 | n03347037,556 969 | n03777754,662 970 | n02916936,465 971 | n02091831,176 972 | n01685808,41 973 | n03617480,614 974 | n03666591,626 975 | n01728572,52 976 | n02071294,148 977 | n03733805,647 978 | n02317335,327 979 | n03930630,717 980 | n01795545,80 981 | n02823750,441 982 | n03485794,591 983 | n02100583,211 984 | n01739381,59 985 | n02447366,362 986 | n02110627,252 987 | n01872401,102 988 | n02500267,384 989 | n04429376,857 990 | n04443257,860 991 | n01514859,8 992 | n02791124,423 993 | n01592084,19 994 | n02128925,290 995 | n03787032,667 996 | n02509815,387 997 | n02087394,159 998 | n02111129,255 999 | n02389026,339 1000 | n02769748,414 1001 | -------------------------------------------------------------------------------- /mnist_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
This file defines how to handle the MNIST dataset. 3 | """ 4 | 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import tensorflow as tf 11 | import numpy as np 12 | import os 13 | import csv 14 | import preprocessing 15 | from preprocessing import image_preprocessing 16 | import configparser 17 | import gzip 18 | import pickle 19 | 20 | this_folder = os.path.dirname(os.path.abspath(__file__)) + '/' 21 | 22 | config = configparser.ConfigParser() 23 | config.read(this_folder + 'dataset_paths.ini') 24 | base_folder = config['PATHS'].get('MNIST') 25 | 26 | 27 | 28 | 29 | ##### Training ########### 30 | def collect_train_data(num_samples_per_class=0): 31 | f = gzip.open(base_folder + 'mnist.pkl.gz', 'rb') 32 | train_set, valid_set, test_set = pickle.load(f, encoding='latin1') 33 | f.close() 34 | 35 | train_data = np.array(train_set[0], dtype='float32').reshape((-1,28,28,1)) 36 | train_labels = np.array(train_set[1], dtype='int32') 37 | return train_data, train_labels 38 | 39 | def collect_val_data(): 40 | print("Collecting validation data...") 41 | f = gzip.open(base_folder + 'mnist.pkl.gz', 'rb') 42 | train_set, valid_set, test_set = pickle.load(f, encoding='latin1') 43 | f.close() 44 | val_data = np.array(valid_set[0], dtype='float32').reshape((-1,28,28,1)) 45 | val_labels = np.array(valid_set[1], dtype='int32') 46 | return val_data, val_labels 47 | 48 | # tf.data batch iterator for the training data 49 | def train_BI(images, 50 | labels, 51 | batch_size, 52 | num_parallel_calls=100): 53 | dataset = tf.data.Dataset.from_tensor_slices( 54 | (images, labels)) 55 | batched_dataset = dataset.batch(batch_size, 56 | drop_remainder = False) 57 | train_batch_iterator = batched_dataset.make_initializable_iterator() 58 | return train_batch_iterator 59 | 60 | # tf.data batch iterator for the validation data 61 | def val_BI(images, 62 | labels, 63 | batch_size, 64 | num_parallel_calls=100): 65 | dataset = tf.data.Dataset.from_tensor_slices( 66 | (images, labels)) 67 | batched_dataset = dataset.batch(batch_size, 68 | drop_remainder = False) 69 | val_batch_iterator = batched_dataset.make_initializable_iterator() 70 | return val_batch_iterator 71 | 72 | # Additional tf.data batch iterator for the data that is used just for the propagation 73 | # of a few images for visualization. 74 | def img_BI(images, 75 | labels, 76 | batch_size, 77 | num_parallel_calls=100): 78 | dataset = tf.data.Dataset.from_tensor_slices( 79 | (images, labels)) 80 | batched_dataset = dataset.batch(batch_size, 81 | drop_remainder = False) 82 | img_batch_iterator = batched_dataset.make_initializable_iterator() 83 | return img_batch_iterator 84 | 85 | def interpret_as_image(image): 86 | return image 87 | 88 | def num_train_samples(): 89 | return 50000 90 | 91 | def num_val_samples(): 92 | return 1000 93 | 94 | def bounds(): 95 | # This is a little problematic here. Foolbox only allows 96 | # for scalar bounds, not bounds per channel. For this reason, 97 | # we use the worst-case bounds here. 98 | return (0,1) 99 | 100 | min_values = np.array([0.,0.,0.],np.float32) 101 | max_values = np.array([1.,1.,1.],np.float32) 102 | def image_range(): 103 | return [0.,1.] -------------------------------------------------------------------------------- /mnist_example.ini: -------------------------------------------------------------------------------- 1 | [PATHS] 2 | # The used dataset. Currently only supports 'TinyImageNet', 'MNIST' 3 | # and 'ImageNet'. 
4 | dataset_name = MNIST
5 | # The folder in which to save the different tensorboard summaries.
6 | # For a cohort of experiments, this should ideally be the same.
7 | tensorboard_logdir = /localdata/logs/
8 | # The folder in which to save the trained model checkpoints.
9 | # For a cohort of experiments, this should ideally be the same.
10 | saved_model_folder = /localdata/models/
11 | 
12 | [ARCHITECTURE]
13 | # Possible values include SmallNet and ResNet50 (cf. smallnet.py, resnet50.py).
14 | model = SmallNet
15 | # Do we want to use a pretrained model? ONLY works
16 | # if we use ResNet50 AND ImageNet data.
17 | pretrained = False
18 | 
19 | [HYPERPARAMETERS]
20 | # The total number of epochs.
21 | num_epochs = 200
22 | # The learning rate at the beginning of training.
23 | learning_rate_at_start = 0.001
24 | # Every lr_decrease_interval epochs, multiply the
25 | # current learning rate by lr_decrease_factor.
26 | # If you don't want this, simply set lr_decrease_interval
27 | # higher than num_epochs.
28 | lr_decrease_interval = 80
29 | lr_decrease_factor = .1
30 | # Batch size per GPU.
31 | batch_size_per_gpu = 100
32 | # Whether to penalize a weighted sum of 1-norms of
33 | # wavelet coefficients of the saliency. 'False' is faster
34 | # than just setting the multipliers to zero.
35 | robust_regularization = True
36 | # Whether to decompose \nabla_x L into its wavelet coefficients.
37 | # If 'False', the regularization is applied to just the image.
38 | use_wavelet_decomposition = False
39 | # Define with respect to which output the gradients are
40 | # calculated. 'logits' or 'NLL'.
41 | sensitivity_mode = NLL
42 | # Multiplier in front of the penalty term. The model starts to degenerate at around 1000. The name is currently a misnomer (here, no wavelet decomposition is used, cf. use_wavelet_decomposition).
43 | lp_wavelet_parameter = 10.
44 | # Which p-norm to use.
45 | p_norm = 2
46 | # Squared 2-norm weight penalty parameter.
47 | weight_decay_parameter = 0.000001
48 | # Exponential moving average multiplier for the batch
49 | # normalization layers, if present.
50 | bn_momentum_value = .99
51 | # Whether to set K.learning_phase() to 0 or 1 when
52 | # training. If we train batch normalized networks,
53 | # setting this to 0 has the effect of using the
54 | # running mean statistics instead of the batch
55 | # statistics, which is more stable but slower in the
56 | # beginning.
57 | learning_phase = 1
58 | 
59 | [LOGGING]
60 | # Execute the training summary operator every
61 | # train_summary_period batches.
62 | train_summary_period = 50
63 | # Execute the validation summary operator every
64 | # val_summary_period batches.
65 | val_summary_period = 300
66 | # When checking the adversarial vulnerability,
67 | # try the different attacks on num_adversarial_batches
68 | # batches of the img_data set. This may take a considerable
69 | # amount of time so don't set this too high.
70 | num_adversarial_batches = 1
71 | # Run the adversarial attacks every
72 | # adversarial_test_period batches.
73 | adversarial_test_period = 5000000
74 | 
--------------------------------------------------------------------------------
/preprocessing.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocessing pipeline for all datasets.
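Typical use is as a tf.data map function; a minimal sketch, where
`filenames` and `labels` are hypothetical in-memory lists:

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(
        lambda f, l: image_preprocessing(f, l,
                                         file_type='filename',
                                         shape=[224, 224]),
        num_parallel_calls=64)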
3 | """ 4 | 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import tensorflow as tf 11 | import numpy as np 12 | 13 | 14 | 15 | def image_preprocessing(file, 16 | label, 17 | file_type = 'filename', 18 | shape = [224,224], 19 | random_events = True, 20 | data_augmentation = True, 21 | additive_noise = False, 22 | subtract_by = 0., 23 | divide_by = 1., 24 | colorspace = 'BGR', 25 | rescale=True, 26 | min_rescale=256, 27 | max_rescale=512, 28 | test_rescale=256, 29 | noise_level=1, 30 | clip_values = None): 31 | """ 32 | Implementation of a flexible image preprocessing pipeline. 33 | """ 34 | if file_type == 'filename': 35 | image = tf.read_file(file) 36 | image = tf.image.decode_jpeg(image,channels=3) 37 | 38 | elif file_type == 'tfrecord': 39 | features = tf.parse_single_example( 40 | file, 41 | features={ 42 | 'image/height': tf.FixedLenFeature([],tf.int64), 43 | 'image/width': tf.FixedLenFeature([],tf.int64), 44 | 'image/colorspace': tf.FixedLenFeature([],tf.string), 45 | 'image/channels': tf.FixedLenFeature([],tf.int64), 46 | 'image/class/label': tf.FixedLenFeature([],tf.int64), 47 | 'image/class/text': tf.FixedLenFeature([],tf.string), 48 | 'image/format': tf.FixedLenFeature([],tf.string), 49 | 'image/filename': tf.FixedLenFeature([],tf.string), 50 | 'image/encoded': tf.FixedLenFeature([],tf.string) 51 | }) 52 | height = tf.cast(features['image/height'],tf.float32) 53 | width = tf.cast(features['image/width'],tf.float32) 54 | channels = tf.cast(features['image/channels'],tf.float32) 55 | img_size = tf.stack([height,width,channels],axis=0) 56 | 57 | # In the TFRecord file, the labels are encoded from 1 to 1000. 58 | # Here, we convert to 0,...,999 59 | label = features['image/class/label'] - 1 60 | 61 | # Convert the serialized byte image files to tensors. 
62 | image = features['image/encoded'] 63 | image = tf.image.decode_jpeg(image,channels=3) 64 | 65 | elif file_type == 'tensor': 66 | image = file 67 | 68 | else: 69 | raise ValueError('file_type must be "filename" or "tensor"') 70 | 71 | image = tf.cast(image,tf.float32) 72 | 73 | 74 | [d_y,d_x] = list(shape) 75 | 76 | 77 | 78 | if rescale: 79 | if random_events: 80 | rescale_size = tf.random_uniform((1,), 81 | min_rescale, 82 | max_rescale, 83 | dtype=tf.float32)[0] 84 | else: 85 | rescale_size = tf.constant(test_rescale, 86 | dtype=tf.float32) 87 | 88 | if file_type != 'tfrecord': 89 | 90 | image_size = tf.shape(image) 91 | height = tf.cast(image_size[0], tf.float32) 92 | width = tf.cast(image_size[1], tf.float32) 93 | 94 | 95 | h_l_w = [tf.cast(rescale_size, tf.int32), 96 | tf.cast(rescale_size/height*width+1, tf.int32)] 97 | w_l_h = [tf.cast(rescale_size/width*height+1, tf.int32), 98 | tf.cast(rescale_size, tf.int32)] 99 | 100 | new_image_size = tf.cond(tf.less(height, width), 101 | lambda: h_l_w, 102 | lambda: w_l_h) 103 | new_height = new_image_size[0] 104 | new_width = new_image_size[1] 105 | 106 | image = tf.image.resize_images(image,new_image_size) 107 | 108 | if random_events: 109 | crop_y = tf.random_uniform((1,),0,new_height-d_y,dtype=tf.int32)[0] 110 | crop_x = tf.random_uniform((1,),0,new_width-d_x,dtype=tf.int32)[0] 111 | crop_location = tf.cast([crop_y,crop_x,0],tf.int32) 112 | crop_size = tf.cast([d_y,d_x,3],tf.int32) 113 | 114 | else: 115 | crop_y = (new_height-d_y)/2 116 | crop_x = (new_width-d_x)/2 117 | crop_location = tf.cast([crop_y,crop_x,0],tf.int32) 118 | crop_size = tf.cast([d_y,d_x,3],tf.int32) 119 | 120 | image = tf.slice(image,crop_location,crop_size) 121 | 122 | 123 | if data_augmentation: 124 | image = tf.image.random_hue(image,.06) 125 | image = tf.image.random_flip_left_right(image) 126 | # Behaviour not tested yet: 127 | #image = tf.image.random_brightness(image,.8,1.2) 128 | #image = tf.image.random_contrast(image,.2) 129 | 130 | 131 | # The original implementations of VGG-nets and 132 | # ResNets use BGR colorspace. If we used pre-trained 133 | # weights, we should switch to the same colorspace. 134 | if colorspace == 'BGR': 135 | image = image[:,:,::-1] 136 | image = tf.cast(image, tf.float32) 137 | 138 | if subtract_by != 0: 139 | if subtract_by == 'ImageNet': 140 | # VGG preprocessing uses this 141 | subtrahend = np.array([103.939,116.779,123.68]) 142 | subtrahend = subtrahend * tf.ones_like(image) 143 | else: 144 | subtrahend = np.array(subtract_by, dtype='float32') 145 | image-= subtrahend 146 | 147 | if divide_by != 1: 148 | image/= divide_by 149 | 150 | if additive_noise: 151 | # We don't add noise of a constant std 152 | # to the images, because that would mean that 153 | # (for sufficiently large images) all images 154 | # that the neural network sees are equally 155 | # noisy. Instead, we randomly vary the noise 156 | # level in every batch. 157 | # This was NOT USED in the paper. 158 | noise_level = tf.random_uniform((1,), 159 | minval = 0, 160 | maxval = tf.abs(noise_level)+1e-6) 161 | image+= tf.random_normal((d_y,d_x,3), 162 | stddev = noise_level) 163 | 164 | # When using additive noise, the tensor values 165 | # can theoretically escape the original image 166 | # value range. 
This can be clipped by setting
167 |     # clip_values = [min_value, max_value]
168 |     if clip_values:
169 |         image = tf.clip_by_value(image,
170 |                                  clip_values[0],
171 |                                  clip_values[1])
172 | 
173 |     return image, label
174 | 
175 | 
176 | def interpret_as_image(image,
177 |                        add_by=0.,
178 |                        multiply_by=1.,
179 |                        colorspace = 'RGB'):
180 |     """
181 |     Here, image is NHWC, not HWC like above.
182 |     """
183 |     if multiply_by !=1.:
184 |         image*= multiply_by
185 |     if add_by != 0.:
186 |         if add_by == "ImageNet":
187 |             summand = np.array([103.939,116.779,123.68])
188 |             summand = summand * np.ones_like(image)
189 |         else:
190 |             summand = np.array(add_by, dtype='float32')
191 |         image+= summand
192 |     if colorspace == 'BGR':
193 |         image = image[:,:,:,::-1]
194 |     return image
195 | 
--------------------------------------------------------------------------------
/resnet50.py:
--------------------------------------------------------------------------------
1 | """ResNet50 model for Keras.
2 | This is an adapted version of the official ResNet50
3 | implementation for Keras, which in turn was adapted
4 | from a contribution by 'BigMoyan'.
5 | """
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 | 
10 | import os
11 | 
12 | #from keras import get_submodules_from_kwargs
13 | 
14 | import tensorflow
15 | #import keras
16 | import keras_applications
17 | from tensorflow.keras.layers import (Conv2D, BatchNormalization, Activation,
18 |     GlobalAveragePooling2D, Dense, add, ZeroPadding2D, Input, MaxPooling2D)
19 | from keras_applications import imagenet_utils
20 | import tensorflow.keras.backend as K
21 | from tensorflow.keras.models import Model
22 | 
23 | 
24 | this_folder = os.path.dirname(os.path.abspath(__file__)) + '/'
25 | preprocess_input = imagenet_utils.preprocess_input
26 | 
27 | WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/'
28 |                 'releases/download/v0.2/'
29 |                 'resnet50_weights_tf_dim_ordering_tf_kernels.h5')
30 | 
31 | use_fused = False
32 | def identity_block(input_tensor, kernel_size, filters, stage, block):
33 |     """The identity block is the block that has no conv layer at shortcut.
34 | 
35 |     # Arguments
36 |         input_tensor: input tensor
37 |         kernel_size: default 3, the kernel size of
38 |             middle conv layer at main path
39 |         filters: list of integers, the filters of 3 conv layer at main path
40 |         stage: integer, current stage label, used for generating layer names
41 |         block: 'a','b'..., current block label, used for generating layer names
42 | 
43 |     # Returns
44 |         Output tensor for the block.
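        Additionally, this adapted version returns `bn_inputs`, the list
        of tensors feeding the BatchNorm layers, so that custom
        moving-average update ops can be attached to them
        (cf. utils.add_bn_ops).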
45 | """ 46 | bn_inputs = [] 47 | filters1, filters2, filters3 = filters 48 | if K.image_data_format() == 'channels_last': 49 | bn_axis = 3 50 | else: 51 | bn_axis = 1 52 | conv_name_base = 'res' + str(stage) + block + '_branch' 53 | bn_name_base = 'bn' + str(stage) + block + '_branch' 54 | 55 | x = Conv2D(filters1, 56 | (1, 1), 57 | kernel_initializer='he_normal', 58 | name=conv_name_base + '2a')(input_tensor) 59 | bn_inputs.append(x) 60 | x = BatchNormalization(axis=bn_axis, 61 | name=bn_name_base + '2a', 62 | fused=use_fused)(x) 63 | x = Activation('relu')(x) 64 | 65 | x = Conv2D(filters2, 66 | kernel_size, 67 | padding='same', 68 | kernel_initializer='he_normal', 69 | name=conv_name_base + '2b')(x) 70 | bn_inputs.append(x) 71 | x = BatchNormalization(axis=bn_axis, 72 | name=bn_name_base + '2b', 73 | fused=use_fused)(x) 74 | x = Activation('relu')(x) 75 | 76 | x = Conv2D(filters3, 77 | (1, 1), 78 | kernel_initializer='he_normal', 79 | name=conv_name_base + '2c')(x) 80 | bn_inputs.append(x) 81 | x = BatchNormalization(axis=bn_axis, 82 | name=bn_name_base + '2c', 83 | fused=use_fused)(x) 84 | x = add([x, input_tensor]) 85 | x = Activation('relu')(x) 86 | return x, bn_inputs 87 | 88 | 89 | def conv_block(input_tensor, 90 | kernel_size, 91 | filters, 92 | stage, 93 | block, 94 | strides=(2, 2)): 95 | """A block that has a conv layer at shortcut. 96 | 97 | # Arguments 98 | input_tensor: input tensor 99 | kernel_size: default 3, the kernel size of 100 | middle conv layer at main path 101 | filters: list of integers, the filters of 3 conv layer at main path 102 | stage: integer, current stage label, used for generating layer names 103 | block: 'a','b'..., current block label, used for generating layer names 104 | strides: Strides for the first conv layer in the block. 105 | 106 | # Returns 107 | Output tensor for the block. 
108 | 109 | Note that from stage 3, 110 | the first conv layer at main path is with strides=(2, 2) 111 | And the shortcut should have strides=(2, 2) as well 112 | """ 113 | bn_inputs = [] 114 | filters1, filters2, filters3 = filters 115 | if K.image_data_format() == 'channels_last': 116 | bn_axis = 3 117 | else: 118 | bn_axis = 1 119 | conv_name_base = 'res' + str(stage) + block + '_branch' 120 | bn_name_base = 'bn' + str(stage) + block + '_branch' 121 | 122 | x = Conv2D(filters1, 123 | (1, 1), 124 | strides=strides, 125 | kernel_initializer='he_normal', 126 | name=conv_name_base + '2a')(input_tensor) 127 | bn_inputs.append(x) 128 | x = BatchNormalization(axis=bn_axis, 129 | name=bn_name_base + '2a', 130 | fused=use_fused)(x) 131 | x = Activation('relu')(x) 132 | 133 | x = Conv2D(filters2, 134 | kernel_size, 135 | padding='same', 136 | kernel_initializer='he_normal', 137 | name=conv_name_base + '2b')(x) 138 | bn_inputs.append(x) 139 | x = BatchNormalization(axis=bn_axis, 140 | name=bn_name_base + '2b', 141 | fused=use_fused)(x) 142 | x = Activation('relu')(x) 143 | 144 | x = Conv2D(filters3, 145 | (1, 1), 146 | kernel_initializer='he_normal', 147 | name=conv_name_base + '2c')(x) 148 | bn_inputs.append(x) 149 | x = BatchNormalization(axis=bn_axis, 150 | name=bn_name_base + '2c', 151 | fused=use_fused)(x) 152 | 153 | shortcut = Conv2D(filters3, 154 | (1, 1), 155 | strides=strides, 156 | kernel_initializer='he_normal', 157 | name=conv_name_base + '1')(input_tensor) 158 | bn_inputs.append(shortcut) 159 | shortcut = BatchNormalization(axis=bn_axis, 160 | name=bn_name_base + '1', 161 | fused=use_fused)(shortcut) 162 | 163 | x = add([x, shortcut]) 164 | x = Activation('relu')(x) 165 | return x, bn_inputs 166 | 167 | 168 | def create_model(input_tensor, 169 | input_shape, 170 | num_classes, 171 | pretrained = False, 172 | **kwargs): 173 | """Instantiates the ResNet50 architecture. 174 | 175 | Optionally loads weights pre-trained on ImageNet. 176 | Note that the data format convention used by the model is 177 | the one specified in your Keras config at `~/.keras/keras.json`. 178 | 179 | # Arguments 180 | include_top: whether to include the fully-connected 181 | layer at the top of the network. 182 | weights: one of `None` (random initialization), 183 | 'imagenet' (pre-training on ImageNet), 184 | or the path to the weights file to be loaded. 185 | input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) 186 | to use as image input for the model. 187 | input_shape: optional shape tuple, only to be specified 188 | if `include_top` is False (otherwise the input shape 189 | has to be `(224, 224, 3)` (with `channels_last` data format) 190 | or `(3, 224, 224)` (with `channels_first` data format). 191 | It should have exactly 3 inputs channels, 192 | and width and height should be no smaller than 197. 193 | E.g. `(200, 200, 3)` would be one valid value. 194 | pooling: Optional pooling mode for feature extraction 195 | when `include_top` is `False`. 196 | - `None` means that the output of the model will be 197 | the 4D tensor output of the 198 | last convolutional layer. 199 | - `avg` means that global average pooling 200 | will be applied to the output of the 201 | last convolutional layer, and thus 202 | the output of the model will be a 2D tensor. 203 | - `max` means that global max pooling will 204 | be applied. 205 | classes: optional number of classes to classify images 206 | into, only to be specified if `include_top` is True, and 207 | if no `weights` argument is specified. 
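        Note: the argument list above is inherited from the stock Keras
        implementation; this adapted version is instead called as
        create_model(input_tensor, input_shape, num_classes, pretrained)
        and always keeps the fully-connected top.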
208 | 209 | # Returns 210 | A Keras model instance. 211 | 212 | # Raises 213 | ValueError: in case of invalid argument for `weights`, 214 | or invalid input shape. 215 | """ 216 | 217 | keras_utils = keras_applications._KERAS_UTILS 218 | 219 | 220 | 221 | if K.image_data_format() == 'channels_last': 222 | bn_axis = 3 223 | else: 224 | bn_axis = 1 225 | 226 | bn_inputs = [] 227 | img_input = Input(tensor=input_tensor, shape=input_shape) 228 | x = ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input) 229 | x = Conv2D(64, (7, 7), 230 | strides=(2, 2), 231 | padding='valid', 232 | kernel_initializer='he_normal', 233 | name='conv1')(x) 234 | bn_inputs.append(x) 235 | x = BatchNormalization(axis=bn_axis, name='bn_conv1', fused=use_fused)(x) 236 | x = Activation('relu')(x) 237 | x = ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x) 238 | x = MaxPooling2D((3, 3), strides=(2, 2))(x) 239 | 240 | x, bn_inputs_ = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) 241 | bn_inputs = bn_inputs + bn_inputs_ 242 | x, bn_inputs_ = identity_block(x, 3, [64, 64, 256], stage=2, block='b') 243 | bn_inputs = bn_inputs + bn_inputs_ 244 | x, bn_inputs_ = identity_block(x, 3, [64, 64, 256], stage=2, block='c') 245 | bn_inputs = bn_inputs + bn_inputs_ 246 | 247 | x, bn_inputs_ = conv_block(x, 3, [128, 128, 512], stage=3, block='a') 248 | bn_inputs = bn_inputs + bn_inputs_ 249 | x, bn_inputs_ = identity_block(x, 3, [128, 128, 512], stage=3, block='b') 250 | bn_inputs = bn_inputs + bn_inputs_ 251 | x, bn_inputs_ = identity_block(x, 3, [128, 128, 512], stage=3, block='c') 252 | bn_inputs = bn_inputs + bn_inputs_ 253 | x, bn_inputs_ = identity_block(x, 3, [128, 128, 512], stage=3, block='d') 254 | bn_inputs = bn_inputs + bn_inputs_ 255 | 256 | x, bn_inputs_ = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') 257 | bn_inputs = bn_inputs + bn_inputs_ 258 | x, bn_inputs_ = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') 259 | bn_inputs = bn_inputs + bn_inputs_ 260 | x, bn_inputs_ = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') 261 | bn_inputs = bn_inputs + bn_inputs_ 262 | x, bn_inputs_ = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') 263 | bn_inputs = bn_inputs + bn_inputs_ 264 | x, bn_inputs_ = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') 265 | bn_inputs = bn_inputs + bn_inputs_ 266 | x, bn_inputs_ = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') 267 | bn_inputs = bn_inputs + bn_inputs_ 268 | 269 | x, bn_inputs_ = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') 270 | bn_inputs = bn_inputs + bn_inputs_ 271 | x, bn_inputs_ = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') 272 | bn_inputs = bn_inputs + bn_inputs_ 273 | x, bn_inputs_ = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') 274 | bn_inputs = bn_inputs + bn_inputs_ 275 | 276 | x = GlobalAveragePooling2D(name='avg_pool')(x) 277 | x = Dense(num_classes, 278 | activation='linear', 279 | name='fc1000')(x) 280 | 281 | # Create model. 282 | model = Model(img_input, [x] + bn_inputs, name='resnet50') 283 | 284 | # Load weights. 
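# (Note: calling the finished model, e.g. model(images), returns
# [logits] + bn_inputs rather than a single tensor, so downstream
# code (cf. robust_model.py) unpacks the logits via output[0].)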
285 | if pretrained == True: 286 | print("Loading pretrained model..") 287 | weights_path = tensorflow.keras.utils.get_file( 288 | 'resnet50_weights_tf_dim_ordering_tf_kernels.h5', 289 | WEIGHTS_PATH, 290 | cache_subdir='models', 291 | md5_hash='a7b3fe01876f51b976af0dea6bc144eb') 292 | model.load_weights(weights_path) 293 | 294 | return model 295 | -------------------------------------------------------------------------------- /robust_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | from tensorflow.python.client import device_lib 7 | import utils 8 | import summary_utils 9 | 10 | from foolbox.models import TensorFlowModel 11 | from diff_ops import Rop 12 | 13 | 14 | def bias_shifted_input(x, b, direction): 15 | direction_norm_squared = tf.reduce_sum( 16 | direction ** 2, 17 | axis=[1, 2, 3], 18 | keepdims=True) 19 | return x + tf.reshape(b,[-1,1,1,1]) * direction / tf.reshape(direction_norm_squared,[-1,1,1,1]) 20 | 21 | class robust_model: 22 | def __init__(self, 23 | iterator, 24 | session, 25 | model, 26 | num_classes, 27 | optimizer, 28 | dataset, 29 | p_norm = 2., 30 | alpha = None, 31 | decomp_type = 'bior2.2', 32 | NUMPY_images = None, 33 | NUMPY_labels = None, 34 | learning_rate = .001, 35 | weight_decay_p = .0001, 36 | lp_wavelet_p = .0001, 37 | batch_size = 32, 38 | bn_momentum = .99, 39 | robust_regularization = True, 40 | use_wavelet_decomposition = True, 41 | wavelet_weights = [0,1], 42 | sensitivity_mode = 'logits', 43 | graph = tf.get_default_graph()): 44 | 45 | self.iterator = iterator 46 | self.session = session 47 | self.model = model 48 | self.num_classes = num_classes 49 | self.optimizer = optimizer 50 | self.dataset = dataset 51 | self.robust_regularization = robust_regularization 52 | self.wavelet_weights = wavelet_weights 53 | self.nested_wavelet_weights = utils.nested_weight_list( 54 | wavelet_weights) 55 | self.sensitivity_mode = sensitivity_mode 56 | self.graph = graph 57 | self.decomp_type = decomp_type 58 | 59 | self.decomp_depth = len(wavelet_weights)-1 60 | self.learning_rate = learning_rate 61 | self.weight_decay_p = weight_decay_p 62 | self.lp_wavelet_p = lp_wavelet_p 63 | self.batch_size = batch_size 64 | self.bn_momentum = bn_momentum 65 | self.graph = tf.get_default_graph() 66 | self.p_norm = p_norm 67 | 68 | 69 | self.alpha = alpha 70 | self.NUMPY_images = NUMPY_images 71 | self.NUMPY_labels = NUMPY_labels 72 | 73 | if use_wavelet_decomposition: 74 | from fwt import multi_channel_fwt, create_filter_bank 75 | self.decomp_filters, self.reconst_filters = create_filter_bank( 76 | decomp_type) 77 | 78 | devices = device_lib.list_local_devices() 79 | GPU_devices = [dev.name for dev in devices 80 | if dev.device_type=='GPU'] 81 | self.num_GPUs = len(GPU_devices) 82 | 83 | tensors = [] 84 | scalars = [] 85 | gradients = [] 86 | summaries = [] 87 | with tf.variable_scope(tf.get_variable_scope()): 88 | with session.as_default(): 89 | for dev in range(self.num_GPUs): 90 | with tf.device('/device:GPU:%d' % dev): 91 | with tf.name_scope('GPU_%d' % dev) as scope: 92 | print("Compiling on GPU %d ..." 
%dev) 93 | 94 | tensors.append(dict()) 95 | scalars.append(dict()) 96 | 97 | # scalars finished converting to dict: 98 | # mean_NLL, sum_of_true_logits, mean_correlations 99 | 100 | # Get the inputs from the iterators 101 | next_element = iterator.get_next() 102 | tensors[-1]['images'] = next_element[0] 103 | tensors[-1]['targets'] = next_element[1] 104 | tensors[-1]['one_hot_targets'] = tf.one_hot( 105 | tensors[-1]['targets'], 106 | self.num_classes) 107 | 108 | # Get the forward propagated output 109 | # for the current batch of this GPU. 110 | network_output = model(tensors[-1]['images']) 111 | tensors[-1]['logits'] = network_output 112 | 113 | 114 | 115 | 116 | # For neural networks that use batch 117 | # normalization, network_output is actually 118 | # a list of tensors, where logits[1:] 119 | # represent the inputs to the BatchNorm 120 | # layers. Here, we handle this situation 121 | # if it arises. 122 | if type(network_output) == list: 123 | tensors[-1]['logits'] = network_output[0] 124 | bn_inputs = network_output[1:] 125 | utils.add_bn_ops(model, 126 | bn_inputs, 127 | bn_momentum=bn_momentum) 128 | 129 | 130 | tensors[-1]['predictions'] = tf.argmax( 131 | tensors[-1]['logits'], 132 | axis=1) 133 | tensors[-1]['predicted_one_hot_targets'] = tf.one_hot( 134 | tensors[-1]['predictions'], 135 | self.num_classes) 136 | tensors[-1]['predicted_logits'] = tf.reduce_max( 137 | tensors[-1]['logits'], 138 | axis=1) 139 | tensors[-1]['probabilities'] = tf.nn.softmax( 140 | tensors[-1]['logits']) 141 | 142 | 143 | 144 | #### x-terms, b-terms #################### 145 | 146 | tensors[-1]['x_terms'] = Rop(tensors[-1]['logits'], 147 | tensors[-1]['images'], 148 | tensors[-1]['images']) 149 | tensors[-1]['b_terms'] = tensors[-1]['logits'] - tensors[-1]['x_terms'] 150 | tensors[-1]['predicted_b_terms'] = utils.select(tensors[-1]['b_terms'], 151 | tensors[-1]['predictions'], 152 | self.num_classes) 153 | 154 | if self.alpha is not None: 155 | tensors[-1]['taus'] = tensors[-1]['logits'] - self.alpha * tensors[-1]['x_terms'] 156 | 157 | 158 | #NUMPY SECTION 159 | if NUMPY_images is not None and NUMPY_labels is not None: 160 | NUMPY_network_output = model(NUMPY_images) 161 | tensors[-1]['NUMPY_logits'] = NUMPY_network_output 162 | if type(NUMPY_network_output) == list: 163 | tensors[-1]['NUMPY_logits'] = NUMPY_network_output[0] 164 | tensors[-1]['NUMPY_predictions'] = tf.argmax( 165 | tensors[-1]['NUMPY_logits'], 166 | axis=1) 167 | 168 | tensors[-1]['NUMPY_x_terms'] = Rop(tensors[-1]['NUMPY_logits'], 169 | NUMPY_images, 170 | NUMPY_images) 171 | tensors[-1]['NUMPY_b_terms'] = tensors[-1]['NUMPY_logits'] - tensors[-1]['NUMPY_x_terms'] 172 | 173 | 174 | tensors[-1]['NUMPY_selected_x_terms'] = utils.select( 175 | tensors[-1]['NUMPY_x_terms'], 176 | NUMPY_labels, 177 | self.num_classes) 178 | tensors[-1]['NUMPY_selected_b_terms'] = utils.select( 179 | tensors[-1]['NUMPY_b_terms'], 180 | NUMPY_labels, 181 | self.num_classes) 182 | 183 | if self.alpha is not None: 184 | NUMPY_taus = tensors[-1]['NUMPY_logits'] - self.alpha * tensors[-1]['NUMPY_x_terms'] 185 | 186 | tensors[-1]['NUMPY_selected_logits'] = utils.select( 187 | tensors[-1]['NUMPY_logits'], 188 | NUMPY_labels, 189 | self.num_classes) 190 | 191 | tensors[-1]['NUMPY_logit_sensitivities'] = tf.gradients( 192 | tf.reduce_sum(tensors[-1]['NUMPY_selected_logits']), 193 | NUMPY_images)[0] 194 | tensors[-1]['NUMPY_bias_shifted_images'] = bias_shifted_input( 195 | NUMPY_images, 196 | tensors[-1]['NUMPY_selected_b_terms'], 197 | 
tensors[-1]['NUMPY_logit_sensitivities']) 198 | 199 | 200 | 201 | ########################################## 202 | 203 | 204 | # Classification loss 205 | tensors[-1]['NLLs'] = tf.nn.softmax_cross_entropy_with_logits_v2( 206 | labels = tensors[-1]['one_hot_targets'], 207 | logits = tensors[-1]['logits'] 208 | ) 209 | scalars[-1]['mean_NLL'] = tf.reduce_mean(tensors[-1]['NLLs']) 210 | 211 | # Setting up the sensitivity penalty. 212 | if sensitivity_mode == 'logits': 213 | scalars[-1]['sum_of_true_logits'] = tf.reduce_sum( 214 | tensors[-1]['logits'] * tensors[-1]['one_hot_targets']) 215 | tensors[-1]['sensitivities'] = tf.gradients( 216 | scalars[-1]['sum_of_true_logits'], 217 | tensors[-1]['images'], 218 | name='input_gradients')[0] 219 | elif sensitivity_mode == 'NLL': 220 | tensors[-1]['sensitivities'] = tf.gradients( 221 | scalars[-1]['mean_NLL'], 222 | tensors[-1]['images'], 223 | name='input_gradients')[0] 224 | 225 | 226 | if use_wavelet_decomposition: 227 | sensitivity_w_decomp = multi_channel_fwt( 228 | tensors[-1]['sensitivities'], 229 | self.decomp_filters, 230 | self.decomp_depth, 231 | output_type = 'list') 232 | 233 | 234 | tensors[-1]['inner_products'] = tf.reduce_sum( 235 | tensors[-1]['images'] * tensors[-1]['sensitivities'], 236 | axis = [1,2,3]) 237 | 238 | tensors[-1]['sensitivity_norms'] = tf.sqrt(tf.reduce_sum( 239 | tensors[-1]['sensitivities']**2, 240 | axis=[1,2,3], 241 | name='sens_norm')) 242 | tensors[-1]['image_norms'] = tf.sqrt(tf.reduce_sum( 243 | tensors[-1]['images']**2, 244 | axis=[1,2,3], 245 | name='im_norm')) 246 | 247 | tensors[-1]['norm_products'] = tensors[-1]['sensitivity_norms'] * tensors[-1]['image_norms'] 248 | 249 | epsilon = 0.0 250 | tensors[-1]['correlations'] = tensors[-1]['inner_products'] / ( 251 | tensors[-1]['norm_products'] + epsilon) 252 | 253 | scalars[-1]['mean_correlation'] = tf.reduce_mean(tensors[-1]['correlations']) 254 | scalars[-1]['mean_inner_product'] = tf.reduce_mean(tensors[-1]['inner_products']) 255 | scalars[-1]['mean_norm_product'] = tf.reduce_mean(tensors[-1]['norm_products']) 256 | 257 | 258 | tensors[-1]['true_logits'] = tf.reduce_sum( 259 | tensors[-1]['logits'] * tensors[-1]['one_hot_targets'],axis=1) 260 | scalars[-1]['sum_of_true_logits'] = tf.reduce_sum( 261 | tensors[-1]['true_logits']) 262 | tensors[-1]['logit_sensitivities'] = tf.gradients( 263 | scalars[-1]['sum_of_true_logits'], 264 | tensors[-1]['images'], 265 | name='logit_input_gradients')[0] 266 | 267 | tensors[-1]['logit_inner_products'] = tf.reduce_sum( 268 | tensors[-1]['images'] * tensors[-1]['logit_sensitivities'], 269 | axis = [1,2,3]) 270 | 271 | tensors[-1]['logit_sensitivity_norms'] = tf.sqrt(tf.reduce_sum( 272 | tensors[-1]['logit_sensitivities']**2, 273 | axis=[1,2,3], 274 | name='sens_norm')) 275 | 276 | tensors[-1]['logit_norm_products'] = tensors[-1]['logit_sensitivity_norms'] * tensors[-1]['image_norms'] 277 | 278 | tensors[-1]['logit_correlations'] = tensors[-1]['logit_inner_products'] / \ 279 | (tensors[-1]['logit_norm_products'] + epsilon) 280 | 281 | scalars[-1]['mean_logit_correlation'] = tf.reduce_mean(tensors[-1]['logit_correlations']) 282 | scalars[-1]['mean_logit_inner_product'] = tf.reduce_mean(tensors[-1]['logit_inner_products']) 283 | scalars[-1]['mean_logit_norm_product'] = tf.reduce_mean(tensors[-1]['logit_norm_products']) 284 | 285 | 286 | 287 | 288 | # Again as a tiled image, for visualization. 289 | # Only do this if the dimensions work out. 
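# (Building the tiled wavelet image can fail for incompatible
# spatial dimensions; in that case the error is caught below and
# only the image and the saliency are logged, cf. `composite_image`
# further down.)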
290 | tiled_image_works = False 291 | if use_wavelet_decomposition: 292 | try: 293 | tensors[-1]['sensitivity_w_decomp_imgs'] = multi_channel_fwt( 294 | tensors[-1]['sensitivities'], 295 | self.decomp_filters, 296 | self.decomp_depth, 297 | output_type = 'image') 298 | tiled_image_works = True 299 | except tf.errors.OpError: 300 | print("Creating a tiled wavelet image failed.") 301 | 302 | 303 | # sum up all the p-norms of the FWTs of 304 | # all channels. 305 | if use_wavelet_decomposition: 306 | sensitivity_w_mean_lp = 0 307 | for decomp in sensitivity_w_decomp: 308 | sensitivity_w_mean_lp+= utils.lp_norm_weighted( 309 | decomp, 310 | self.nested_wavelet_weights, 311 | p_norm = self.p_norm) 312 | else: 313 | # Otherwise, just calculate the p-norm of the 314 | # sensitivity. 315 | sensitivity_w_mean_lp = utils.lp_norm(tensors[-1]['sensitivities'], 316 | p_norm = self.p_norm) 317 | 318 | scalars[-1]['sensitivity_w_mean_lp'] = sensitivity_w_mean_lp 319 | 320 | 321 | ############ ONLY FOR LOGGING PURPOSES ################### 322 | tensors[-1]['random_targets'] = tf.random_uniform(tf.shape(tensors[-1]['targets']), 323 | maxval = self.num_classes-1, 324 | dtype=tf.int32) 325 | 326 | tensors[-1]['random_one_hot_targets'] = tf.one_hot( 327 | tensors[-1]['random_targets'], 328 | self.num_classes) 329 | tensors[-1]['random_logits'] = tf.reduce_sum( 330 | tensors[-1]['logits'] * tensors[-1]['random_one_hot_targets'], 331 | axis=1) 332 | scalars[-1]['sum_of_random_logits'] = tf.reduce_sum( 333 | tensors[-1]['random_logits']) 334 | 335 | tensors[-1]['random_logit_sensitivities'] = tf.gradients( 336 | scalars[-1]['sum_of_random_logits'], 337 | tensors[-1]['images'], 338 | name='random_logit_sensitivities')[0] 339 | tensors[-1]['random_logit_inner_products'] = tf.reduce_sum( 340 | tensors[-1]['images']*tensors[-1]['random_logit_sensitivities'], 341 | axis=[1,2,3]) 342 | tensors[-1]['random_logit_sensitivity_norms'] = tf.sqrt(tf.reduce_sum( 343 | tensors[-1]['random_logit_sensitivities']**2, 344 | axis=[1,2,3])) 345 | 346 | 347 | scalars[-1]['sum_of_predicted_logits'] = tf.reduce_sum( 348 | tensors[-1]['predicted_logits']) 349 | tensors[-1]['predicted_logit_sensitivities'] = tf.gradients( 350 | scalars[-1]['sum_of_predicted_logits'], 351 | tensors[-1]['images'], 352 | name='predicted_logit_sensitivities')[0] 353 | tensors[-1]['predicted_logit_inner_products'] = tf.reduce_sum( 354 | tensors[-1]['images']*tensors[-1]['predicted_logit_sensitivities'], 355 | axis=[1,2,3]) 356 | tensors[-1]['predicted_logit_sensitivity_norms'] = tf.sqrt(tf.reduce_sum( 357 | tensors[-1]['predicted_logit_sensitivities']**2, 358 | axis=[1,2,3])) 359 | 360 | tensors[-1]['true_logit_sensitivities'] = tensors[-1]['logit_sensitivities'] 361 | tensors[-1]['true_logit_inner_products'] = tf.reduce_sum( 362 | tensors[-1]['images'] * tensors[-1]['true_logit_sensitivities'], 363 | axis = [1,2,3]) 364 | tensors[-1]['true_logit_sensitivity_norms'] = tf.sqrt(tf.reduce_sum( 365 | tensors[-1]['true_logit_sensitivities']**2, 366 | axis=[1,2,3])) 367 | 368 | 369 | 370 | # Calculate the bias gradients 371 | flatten = lambda a : tf.reshape(a,(-1,)) 372 | IP = lambda a,b : tf.reduce_sum(a*b) 373 | 374 | biases = [b for b in model.trainable_weights if 'bias' in b.name] 375 | biases+= tf.get_collection('bn_betas') 376 | biases+= tf.get_collection('bn_means') 377 | 378 | random_bias_gradients = tf.gradients( 379 | scalars[-1]['sum_of_random_logits'], 380 | biases, 381 | name='random_bias_gradients') 382 | 383 | 384 | random_bg = 
[IP(flatten(b),flatten(g)) for (b,g) in zip(biases, random_bias_gradients)] 385 | random_bias_inner_products = tf.accumulate_n(random_bg) 386 | 387 | predicted_bias_gradients = tf.gradients( 388 | scalars[-1]['sum_of_predicted_logits'], 389 | biases, 390 | name='predicted_bias_gradients') 391 | predicted_bg = [IP(flatten(b),flatten(g)) for (b,g) in zip(biases, predicted_bias_gradients)] 392 | predicted_bias_inner_products = tf.accumulate_n(predicted_bg) 393 | 394 | true_bias_gradients = tf.gradients( 395 | scalars[-1]['sum_of_true_logits'], 396 | biases, 397 | name='true_bias_gradients') 398 | 399 | 400 | true_bg = [IP(flatten(b),flatten(g)) for (b,g) in zip(biases, true_bias_gradients)] 401 | true_bias_inner_products = tf.add_n(true_bg) 402 | 403 | zero_image = tf.zeros_like(tensors[-1]['images']) 404 | tensors[-1]['zero_output'] = model(zero_image)[0] 405 | 406 | tensors[-1]['random_zero_logits'] = tf.reduce_sum( 407 | tensors[-1]['zero_output'] * tensors[-1]['random_one_hot_targets'], 408 | axis=1) 409 | tensors[-1]['predicted_zero_logits'] = tf.reduce_sum( 410 | tensors[-1]['zero_output'] * tensors[-1]['predicted_one_hot_targets'], 411 | axis=1) 412 | tensors[-1]['true_zero_logits'] = tf.reduce_sum( 413 | tensors[-1]['zero_output'] * tensors[-1]['one_hot_targets'], 414 | axis=1) 415 | 416 | 417 | 418 | # Calculate the approximate random robustness 419 | 420 | tensors[-1]['inner_product_differences'] = (tensors[-1]['predicted_logit_inner_products'] - 421 | tensors[-1]['random_logit_inner_products']) 422 | 423 | tensors[-1]['bias_differences'] = predicted_bias_inner_products - random_bias_inner_products 424 | 425 | numerator = tensors[-1]['inner_product_differences'] - tensors[-1]['bias_differences'] 426 | 427 | tensors[-1]['logit_sensitivity_differences'] = ( 428 | tensors[-1]['predicted_logit_sensitivities'] - 429 | tensors[-1]['random_logit_sensitivities']) 430 | denominator = tf.sqrt(tf.reduce_sum(tensors[-1]['logit_sensitivity_differences']**2)) 431 | 432 | tensors[-1]['approximate_random_robustness'] = numerator/denominator 433 | tensors[-1]['inner_product_differences_normalized'] = ( 434 | tensors[-1]['inner_product_differences'] / denominator) 435 | tensors[-1]['bias_differences_normalized'] = tensors[-1]['bias_differences'] / denominator 436 | 437 | tensors[-1]['bias_difference_shifted_images'] = bias_shifted_input( 438 | tensors[-1]['images'], 439 | tensors[-1]['bias_differences'], 440 | tensors[-1]['logit_sensitivity_differences']) 441 | 442 | 443 | #print(tensors[-1]['bias_differences_normalized']) 444 | #crash() 445 | ####################################################### 446 | 447 | 448 | 449 | # Collect the network's weights and set up 450 | # the weight decay penalty 451 | trainable_weights = model.trainable_weights 452 | scalars[-1]['weight_norm'] = tf.add_n( 453 | [tf.reduce_sum(w**2) for w in trainable_weights]) 454 | 455 | # Assemble the total loss for this GPU 456 | scalars[-1]['total_loss'] = scalars[-1]['mean_NLL'] 457 | scalars[-1]['total_loss']+= weight_decay_p * scalars[-1]['weight_norm'] 458 | if robust_regularization: 459 | scalars[-1]['sensitivity_penalty'] = lp_wavelet_p * scalars[-1]['sensitivity_w_mean_lp'] 460 | scalars[-1]['total_loss']+= scalars[-1]['sensitivity_penalty'] 461 | 462 | # Everything that is tracked during training 463 | # goes here. Top-5 and top-1 accuracies are 464 | # automatically added. 
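# (The accuracies are attached by summary_utils.prepare_summaries,
# which is called on this dict further below.)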
465 | summary_dict={ 466 | 'total_loss': scalars[-1]['total_loss'], 467 | 'mean_NLL': scalars[-1]['mean_NLL'], 468 | 'weight_2_norm_squared': scalars[-1]['weight_norm'], 469 | 'mean_sensitivity_wavelet_coeffs_lp': scalars[-1]['sensitivity_w_mean_lp']} 470 | 471 | # Add some hyperparameters, too. 472 | # Some redundant calculations through averaging 473 | # later, but the computational overhead is negligible. 474 | summary_dict['learning_rate_'] = learning_rate 475 | summary_dict['correlation_'] = scalars[-1]['mean_correlation'] 476 | summary_dict['inner_product_'] = scalars[-1]['mean_inner_product'] 477 | summary_dict['norm_product_'] = scalars[-1]['mean_norm_product'] 478 | summary_dict['logit_correlation_'] = scalars[-1]['mean_logit_correlation'] 479 | summary_dict['logit_inner_product_'] = scalars[-1]['mean_logit_inner_product'] 480 | summary_dict['logit_norm_product_'] = scalars[-1]['mean_logit_norm_product'] 481 | summary_dict['weight_decay_parameter_'] = weight_decay_p 482 | summary_dict['lp_Wavelet_parameter_'] = lp_wavelet_p 483 | summary_dict['total_batch_size'] = batch_size * self.num_GPUs 484 | summary_dict['bn_momentum_'] = bn_momentum 485 | summary_dict['p_norm'] = p_norm 486 | 487 | if robust_regularization: 488 | summary_dict['sensitivity_penalty'] = scalars[-1]['sensitivity_penalty'] 489 | 490 | 491 | summary_dict = summary_utils.prepare_summaries( 492 | summary_dict = summary_dict, 493 | predictions = tensors[-1]['probabilities'], 494 | labels = tensors[-1]['targets']) 495 | summaries.append(summary_dict) 496 | 497 | # Collect the gradients for every GPU 498 | gradients.append( 499 | optimizer.compute_gradients( 500 | scalars[-1]['total_loss'], 501 | var_list=trainable_weights, 502 | colocate_gradients_with_ops=True)) 503 | 504 | # So far, the adversarial attack model is only 505 | # created on one GPU. Different parallelized versions 506 | # always lead to errors. 507 | if dev == 0: 508 | self.adversarial_model = TensorFlowModel( 509 | tensors[-1]['images'], 510 | tensors[-1]['logits'], 511 | bounds=self.dataset.bounds) 512 | 513 | 514 | 515 | print("Done.") 516 | 517 | # Copy the lists 'tensors' and 'scalars' and replace these with an aggregated version: 518 | # Concatenate the tensors and average the scalars. 519 | self.tensors = dict() 520 | self.scalars = dict() 521 | for key in tensors[0].keys(): 522 | print(key) 523 | self.tensors[key] = tf.concat( 524 | [tensors_item[key] for tensors_item in tensors], 525 | axis=0) 526 | for key in scalars[0].keys(): 527 | self.scalars[key] = tf.reduce_mean( 528 | [scalars_item[key] for scalars_item in scalars]) 529 | 530 | # Create self.GPU_collections for backwards compatibility 531 | self.GPU_collections = {**self.tensors, **self.scalars} 532 | self.GPU_collections['top_1'] = tf.concat( 533 | tf.get_collection('top_1'),0) 534 | self.GPU_collections['top_5'] = tf.concat( 535 | tf.get_collection('top_5'),0) 536 | 537 | 538 | # Collection and apply the gradients over all used 539 | # GPUs for synchronous parallel training. 540 | avg_grads = utils.average_gradients(gradients) 541 | gradient_application = optimizer.apply_gradients(avg_grads) 542 | # We combine the gradient update and possibly the 543 | # batch normalization update operators into one. 
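# (tf.group fires all of its arguments as a single op; the BN
# update ops were registered under the 'bn_update_ops' collection
# when the model graph was built, cf. utils.add_bn_ops.)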
544 | self.train_op = tf.group(gradient_application, 545 | *(tf.get_collection('bn_update_ops'))) 546 | 547 | summary_dict = summary_utils.collect_summaries( 548 | summaries) 549 | self.summary_op = summary_utils.create_summary_op( 550 | summary_dict) 551 | 552 | if use_wavelet_decomposition: 553 | wavelet_summary = tf.summary.tensor_summary('wavelet_weights', 554 | self.wavelet_weights) 555 | self.summary_op = tf.summary.merge([self.summary_op, 556 | wavelet_summary]) 557 | 558 | # Here, we create a tiled image summary for Tensorboard. 559 | # We hereby shift the range of the sensitivity and 560 | # possibly its decomposition to the range of the image. 561 | image_range = self.dataset.image_range() 562 | image_max = image_range[1] 563 | image_min = image_range[0] 564 | image_span = image_max - image_min 565 | image_mid = image_span / 2. 566 | 567 | 568 | self.images = self.dataset.interpret_as_image( 569 | self.GPU_collections['images']) 570 | self.saliencies = self.GPU_collections['sensitivities'] 571 | saliencies_max = tf.reduce_max(tf.abs(self.saliencies), 572 | [1,2], 573 | keepdims=True) 574 | normalized_saliencies = image_span * self.saliencies / \ 575 | (2*saliencies_max + 1e-9) + image_mid 576 | 577 | if use_wavelet_decomposition: 578 | self.saliency_decomps = self.GPU_collections[ 579 | 'sensitivity_w_decomp_imgs'] 580 | saliency_decomps_max = tf.reduce_max( 581 | tf.abs(self.saliency_decomps), 582 | [1,2], 583 | keepdims=True) 584 | normalized_decomps = image_span * self.saliency_decomps / \ 585 | (2*saliency_decomps_max + 1e-9) + image_mid 586 | 587 | 588 | composite_image = [self.images, 589 | normalized_saliencies] 590 | 591 | if tiled_image_works: 592 | composite_image.append(normalized_decomps) 593 | 594 | 595 | img_saliency_decomp = tf.concat( 596 | composite_image, 597 | 2) 598 | 599 | self.img_summary_op = tf.summary.image( 600 | 'img_saliency_decomp', 601 | img_saliency_decomp, 602 | max_outputs = 10) 603 | 604 | 605 | -------------------------------------------------------------------------------- /smallnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the small architecture used for the MNIST experiments. 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import tensorflow as tf 10 | import tensorflow.keras as keras 11 | from tensorflow.keras.layers import Conv2D, MaxPool2D, Input, Dense, Flatten, Dropout 12 | from tensorflow.keras.models import Model 13 | 14 | 15 | def create_model(input_tensor, 16 | input_shape, 17 | num_classes, 18 | pretrained=False): 19 | 20 | input_tensor = Input(tensor = input_tensor, 21 | shape = input_shape) 22 | x = Conv2D(32,3,activation='relu',padding='same')(input_tensor) 23 | x = MaxPool2D(2,strides=2)(x) 24 | 25 | x = Conv2D(64,3,activation='relu',padding='same')(x) 26 | x = MaxPool2D(2,strides=2)(x) 27 | 28 | x = Conv2D(128,3,activation='relu',padding='same')(x) 29 | x = MaxPool2D(2,strides=2)(x) 30 | 31 | x = Flatten()(x) 32 | x = Dense(128, activation='relu')(x) 33 | x = Dropout(.5)(x) 34 | x = Dense(num_classes, activation='linear')(x) 35 | 36 | return Model(inputs = input_tensor, 37 | outputs = x) -------------------------------------------------------------------------------- /summary_utils.py: -------------------------------------------------------------------------------- 1 | """" 2 | Some utility functions for creating the summaries. 
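The intended flow: per GPU, build a dict of scalar tensors and pass it
through prepare_summaries (which attaches top-1/top-5 accuracies);
then average the per-GPU dicts with collect_summaries and merge the
result into a single summary op with create_summary_op.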
3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | def prepare_summaries(summary_dict=dict(), 12 | predictions=None, 13 | labels=None): 14 | if (predictions is not None) and (labels is not None): 15 | top_1 = tf.nn.in_top_k(predictions,labels, 1) 16 | top_1 = tf.cast(top_1, tf.float32) 17 | tf.add_to_collection('top_1',top_1) 18 | top_1_acc = tf.reduce_mean(top_1) 19 | summary_dict['top_1_accuracy'] = top_1_acc 20 | 21 | top_5 = tf.nn.in_top_k(predictions,labels, 5) 22 | top_5 = tf.cast(top_5, tf.float32) 23 | tf.add_to_collection('top_5',top_5) 24 | top_5_acc = tf.reduce_mean(top_5) 25 | summary_dict['top_5_accuracy'] = top_5_acc 26 | return summary_dict 27 | 28 | def collect_summaries(summary_dict_list): 29 | summary_dict = dict() 30 | n_dicts = np.float32(len(summary_dict_list)) 31 | for d in summary_dict_list: 32 | for key in d.keys(): 33 | # If the desired key is in the report_dict, append 34 | # the corresponding item to a list. Otherwise create 35 | # this list. 36 | if key in summary_dict.keys(): 37 | summary_dict[key].append(d[key]) 38 | else: 39 | summary_dict[key] = [d[key]] 40 | 41 | for key in summary_dict.keys(): 42 | summary_dict[key] = tf.add_n(summary_dict[key])/n_dicts 43 | return summary_dict 44 | 45 | def create_summary_op(summary_dict): 46 | summaries = [] 47 | for key in summary_dict.keys(): 48 | summary = tf.summary.scalar(key, summary_dict[key]) 49 | summaries.append(summary) 50 | return tf.summary.merge(summaries) -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | """ 2 | This class defines the training and validation pipeline. 
3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from tensorflow.python.client import device_lib 12 | import os 13 | 14 | import foolbox as fb 15 | from foolbox.criteria import Misclassification 16 | from foolbox.adversarial import Adversarial 17 | from foolbox.distances import Linfinity, MSE 18 | 19 | class training: 20 | def __init__(self, 21 | handle, 22 | dataset, 23 | train_op, 24 | session, 25 | epoch_step, 26 | batch_step, 27 | summary_writer, 28 | train_summary_op, 29 | img_summary_op, 30 | optimizer, 31 | GPU_collections, 32 | batch_size_placeholder, 33 | pretrained = False, 34 | adversarial_model = None, 35 | adversarial_attacks = None, 36 | adversarial_criterion = Misclassification(), 37 | saver_path = "model.ckpt", 38 | num_adversarial_batches = 4, 39 | batch_size = 32, 40 | num_epochs = 1000, 41 | train_summary_period = 1000, 42 | val_summary_period = 1000, 43 | adv_summary_period = 1000): 44 | 45 | self.session = session 46 | 47 | self.saver_path = saver_path 48 | 49 | self.epoch = 0 50 | self.batch_i = 0 51 | self.handle = handle 52 | self.dataset = dataset 53 | self.train_op = train_op 54 | self.epoch_step = epoch_step 55 | self.epoch_step_increment = self.epoch_step.assign_add(1) 56 | self.batch_step = batch_step 57 | self.batch_placeholder = tf.placeholder(tf.int32,(), 58 | 'b_ph') 59 | self.batch_step_assign = tf.assign(self.batch_step, 60 | self.batch_placeholder) 61 | self.num_epochs = num_epochs 62 | self.batch_size = batch_size 63 | 64 | self.optimizer = optimizer 65 | self.GPU_collections = GPU_collections 66 | self.batch_size_placeholder = batch_size_placeholder 67 | 68 | # summary ops 69 | self.train_summary_op = train_summary_op 70 | self.img_summary_op = img_summary_op 71 | self.train_summary_period = train_summary_period 72 | self.val_summary_period = val_summary_period 73 | self.adv_summary_period = adv_summary_period 74 | self.summary_writer = summary_writer 75 | 76 | # validation 77 | self.val_top_one_mean = tf.placeholder( 78 | tf.float32, name='val_top_one_mean') 79 | self.val_top_five_mean = tf.placeholder( 80 | tf.float32, name='val_top_five_mean') 81 | val_summaries = [] 82 | val_summaries.append(tf.summary.scalar( 83 | 'top_1_accuracy_validation', 84 | self.val_top_one_mean)) 85 | val_summaries.append(tf.summary.scalar( 86 | 'top_5_accuracy_validation', 87 | self.val_top_five_mean)) 88 | self.val_summary_op = tf.summary.merge( 89 | val_summaries, 90 | name = 'val_summaries_op') 91 | 92 | # Adversarial attacks 93 | self.num_adversarial_batches = num_adversarial_batches 94 | self.adversarial_criterion = adversarial_criterion 95 | 96 | self.adv_result = tf.placeholder( 97 | tf.float32, name='adv_results') 98 | self.adversarial_attacks = adversarial_attacks 99 | self.adversarial_model = adversarial_model 100 | 101 | default_distances = { 102 | 'GradientAttack' : MSE, 103 | 'FGSM' : MSE, 104 | 'LinfinityBasicIterativeAttack' : Linfinity, 105 | 'L2BasicIterativeAttack' : MSE, 106 | 'LinfinityBasicIterativeAttack' : Linfinity, 107 | 'ProjectedGradientDescentAttack' : Linfinity, 108 | 'DeepFoolAttack' : MSE, 109 | 'DeepFoolLinfinityAttack' : Linfinity} 110 | 111 | self.attacks = dict() 112 | self.distances = dict() # add support for custom distances 113 | self.adv_summaries = dict() 114 | 115 | 116 | for attack in self.adversarial_attacks: 117 | self.attacks[attack] = getattr(fb.attacks, 118 | attack)() 119 | if 
attack in default_distances.keys(): 120 | self.distances[attack] = default_distances[attack] 121 | else: 122 | self.distances[attack] = MSE 123 | 124 | key = attack + '_median_dist' 125 | 126 | self.adv_summaries[attack] = tf.summary.scalar( 127 | attack + '_median_dist', 128 | self.adv_result) 129 | 130 | devices = device_lib.list_local_devices() 131 | GPU_devices = [dev.name for dev in devices 132 | if dev.device_type=='GPU'] 133 | self.num_GPUs = len(GPU_devices) 134 | 135 | self.pretrained = pretrained 136 | 137 | 138 | if self.dataset.train_handle is None: 139 | self.dataset.get_train_handle(self.session) 140 | 141 | self.saver = tf.train.Saver(tf.global_variables()) 142 | 143 | def train(self, 144 | training_feed_dict, 145 | val_feed_dict = {}, 146 | do_not_reload_checkpoint = False, 147 | do_not_save = False): 148 | if (os.path.isfile(self.saver_path + '.index') and not 149 | do_not_reload_checkpoint): 150 | self.restore_model() 151 | self.epoch = self.session.run(self.epoch_step) 152 | self.batch_i = self.session.run(self.batch_step)+1 153 | elif not self.pretrained: 154 | print("Initializing variables...") 155 | self.session.run(tf.global_variables_initializer()) 156 | print("Done.") 157 | else: 158 | # When starting from a pretrained network, 159 | # only initialize the variables that we potentially added 160 | # when we implemented the parallelized batch normalization 161 | # update operators. 162 | print("Initializing batch normalization variables...") 163 | for var in tf.global_variables(): 164 | if "biased" in var.name or "local_step" in var.name: 165 | tf.add_to_collection('uninitialized_variables',var) 166 | tf.add_to_collection('uninitialized_variables',self.epoch_step) 167 | tf.add_to_collection('uninitialized_variables',self.batch_step) 168 | uninitialized_vars = tf.get_collection('uninitialized_variables') 169 | self.session.run(tf.variables_initializer( 170 | uninitialized_vars)) 171 | print("Done.") 172 | self.session.run(tf.variables_initializer( 173 | self.optimizer.variables())) 174 | 175 | if val_feed_dict: 176 | self.validate(val_feed_dict) 177 | self.adversarial(self.adversarial_attacks, 178 | val_feed_dict, 179 | num_batches = self.num_adversarial_batches) 180 | 181 | if self.dataset.train_handle is None: 182 | self.dataset.get_train_handle(self.session) 183 | training_feed_dict[self.handle] = self.dataset.train_handle 184 | 185 | 186 | path = self.saver_path 187 | base, ext = os.path.splitext(path) 188 | new_path = base + '_' + str(self.batch_i) + ext 189 | self.saver.save(self.session, 190 | new_path) 191 | 192 | print("Beginning training...") 193 | if self.epoch >= self.num_epochs: 194 | print("End of training reached.") 195 | while self.epoch < self.num_epochs: 196 | try: 197 | self.dataset.initialize_train_batch_iterator( 198 | self.session, 199 | self.batch_size) 200 | while True: 201 | try: 202 | train_output = self.session.run( 203 | self.train_op, 204 | training_feed_dict) 205 | 206 | self.batch_i+= 1 207 | if self.batch_i % self.train_summary_period == 0: 208 | train_output, summary_str = self.session.run( 209 | [self.train_op,self.train_summary_op], 210 | training_feed_dict) 211 | self.summary_writer.add_summary(summary_str, 212 | self.batch_i) 213 | self.summary_writer.flush() 214 | self.update_batch_step() 215 | if not do_not_save: 216 | self.save_model() 217 | if (self.batch_i % self.val_summary_period == 0 218 | and val_feed_dict): 219 | self.validate(val_feed_dict, 220 | do_not_save = do_not_save) 221 | 222 | if (self.batch_i % 
self.adv_summary_period == 0 223 | and val_feed_dict): 224 | self.adversarial(self.adversarial_attacks, 225 | val_feed_dict, 226 | num_batches = self.num_adversarial_batches, 227 | do_not_save = do_not_save) 228 | except tf.errors.OutOfRangeError: 229 | self.epoch+= 1 230 | self.session.run( 231 | self.epoch_step_increment) 232 | break 233 | except KeyboardInterrupt: 234 | print('\nCancelled') 235 | break 236 | 237 | print("Training finished.") 238 | if val_feed_dict: 239 | print("Performing final validation...") 240 | self.validate(val_feed_dict) 241 | print("Performing final adversarial tests...") 242 | self.adversarial(self.adversarial_attacks, 243 | val_feed_dict, 244 | num_batches = self.num_adversarial_batches) 245 | print("Done.") 246 | 247 | if not do_not_save: 248 | print("Saving completed model...") 249 | self.saver.save(self.session, self.saver_path) 250 | print("Successfully saved completed model.") 251 | 252 | 253 | 254 | 255 | def validate(self, val_feed_dict, do_not_save = False): 256 | 257 | self.dataset.initialize_val_batch_iterator(self.session, 258 | self.batch_size) 259 | 260 | val_feed_dict[self.handle] = self.dataset.val_handle 261 | 262 | 263 | # Calculate validation error # 264 | val_in_top_five = np.zeros(self.dataset.num_val_samples, 265 | np.float32) 266 | val_in_top_one = np.zeros(self.dataset.num_val_samples, 267 | np.float32) 268 | step_size = self.batch_size * self.num_GPUs 269 | 270 | l_index = u_index = 0 271 | while u_index < self.dataset.num_val_samples: 272 | try: 273 | top_fives, top_ones = self.session.run( 274 | [self.GPU_collections['top_5'], 275 | self.GPU_collections['top_1']], 276 | feed_dict=val_feed_dict) 277 | n_images = len(top_fives) 278 | u_index = l_index + n_images 279 | val_in_top_five[l_index:u_index] = top_fives 280 | val_in_top_one[l_index:u_index] = top_ones 281 | l_index = u_index 282 | except tf.errors.OutOfRangeError: 283 | # This handles an error that only appears when 284 | # there are 2 GPUs with batch size 16 each... 285 | # More documentation to follow. 
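# Note that the accuracies below still divide by
# num_val_samples, so if the iterator stops early the
# reported values slightly underestimate the true accuracy.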
286 | break 287 | 288 | val_in_top_five = val_in_top_five[:u_index] 289 | val_in_top_one = val_in_top_one[:u_index] 290 | 291 | val_top_five_accuracy = sum( 292 | val_in_top_five)/np.float32( 293 | self.dataset.num_val_samples) 294 | val_top_one_accuracy = sum( 295 | val_in_top_one)/np.float32( 296 | self.dataset.num_val_samples) 297 | 298 | val_summary = self.session.run(self.val_summary_op, 299 | feed_dict={self.val_top_one_mean : val_top_one_accuracy, 300 | self.val_top_five_mean : val_top_five_accuracy}) 301 | self.summary_writer.add_summary(val_summary, 302 | self.batch_i) 303 | 304 | 305 | 306 | # IMG SUMMARIES 307 | self.dataset.initialize_img_batch_iterator(self.session, 308 | self.batch_size) 309 | val_feed_dict[self.handle] = self.dataset.img_handle 310 | img_summary_str = self.session.run(self.img_summary_op, 311 | feed_dict=val_feed_dict) 312 | self.summary_writer.add_summary(img_summary_str, 313 | self.batch_i) 314 | self.summary_writer.flush() 315 | self.update_batch_step() 316 | if not do_not_save: 317 | self.save_model() 318 | 319 | 320 | def adversarial(self, 321 | adversarial_attacks, 322 | adv_feed_dict, 323 | distances_dict = {}, 324 | num_batches = 4, 325 | do_not_save = False): 326 | results = dict() 327 | for attack in adversarial_attacks: 328 | results[attack] = [] 329 | 330 | self.dataset.initialize_img_batch_iterator(self.session, 331 | self.batch_size) 332 | adv_feed_dict[self.handle] = self.dataset.img_handle 333 | 334 | for run in range(num_batches): 335 | [images, labels] = self.session.run( 336 | [self.GPU_collections['images'], 337 | self.GPU_collections['predictions']], 338 | feed_dict=adv_feed_dict) 339 | 340 | for attack_name in adversarial_attacks: 341 | attack = self.attacks[attack_name] 342 | for i in range(len(images)): 343 | adversarial = Adversarial( 344 | self.adversarial_model, 345 | self.adversarial_criterion, 346 | images[i], 347 | labels[i], 348 | distance = self.distances[attack_name]) 349 | att = attack(adversarial) 350 | dist = adversarial.distance.value 351 | if dist > 0: 352 | results[attack_name].append(dist) 353 | 354 | for attack_name in self.adversarial_attacks: 355 | median_dist = np.median(results[attack_name]) 356 | adv_summary_str = self.session.run(self.adv_summaries[attack_name], 357 | feed_dict = {self.adv_result : median_dist}) 358 | 359 | self.summary_writer.add_summary(adv_summary_str, 360 | self.batch_i) 361 | 362 | self.summary_writer.flush() 363 | self.update_batch_step() 364 | if not do_not_save: 365 | self.save_model() 366 | 367 | def update_batch_step(self): 368 | self.session.run( 369 | self.batch_step_assign, 370 | feed_dict = {self.batch_placeholder : self.batch_i}) 371 | 372 | def restore_model(self): 373 | print("Trying to load old model checkpoint...") 374 | self.saver.restore(self.session, 375 | self.saver_path) 376 | print("Successfully loaded old model checkpoint.") 377 | 378 | def save_model(self): 379 | self.saver.save(self.session, 380 | self.saver_path) 381 | print("Model saved.") 382 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some utility functions. Most are not used in the paper, however. 
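Of the helpers below, average_gradients, lp_norm, lp_norm_weighted,
select, nested_weight_list and add_bn_ops are the ones the training
code in robust_model.py relies on.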
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Some utility functions, though not all of them are used in the paper.
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 | 
8 | import tensorflow as tf
9 | import numpy as np
10 | import tensorflow.keras.backend as K
11 | from tensorflow.python.training import moving_averages
12 | 
13 | # The following function is inspired by the TensorFlow multi-GPU
14 | # implementation example:
15 | # https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py
16 | 
17 | def average_gradients(GPU_grads):
18 |     '''
19 |     Averages the lists of (gradient, variable) pairs
20 |     that were computed on the individual GPUs.
21 |     '''
22 |     resulting_averages = []
23 |     for grad_vars in zip(*GPU_grads):
24 |         # Stack this variable's gradients from all GPUs
25 |         # along a new leading axis and average them.
26 |         gradients = []
27 |         for gradient, _ in grad_vars:
28 |             gradients.append(tf.expand_dims(gradient, 0))
29 |         gradients = tf.concat(axis=0,
30 |                               values=gradients)
31 |         gradients = tf.reduce_mean(gradients, 0)
32 |         # The variables are shared across GPUs, so the
33 |         # first tower's reference suffices.
34 |         variables = grad_vars[0][1]
35 |         resulting_averages.append((gradients, variables))
36 |     return resulting_averages
37 | 
38 | 
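`average_gradients` implements the classic multi-tower pattern: each GPU computes `(gradient, variable)` pairs for the same shared variables, and the averaged list drives a single optimizer step. A minimal sketch of the wiring, where `build_loss` and the device count are hypothetical placeholders:

```
# Sketch of a multi-GPU tower loop around average_gradients (TF 1.x).
# build_loss is a placeholder for the per-tower model/loss construction.
import tensorflow as tf

optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
GPU_grads = []
for gpu_i in range(2):
    with tf.device('/gpu:%d' % gpu_i):
        with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
            loss = build_loss()  # placeholder: per-tower loss
            # One list of (gradient, variable) pairs per GPU.
            GPU_grads.append(optimizer.compute_gradients(loss))

# A single update step from the averaged gradients.
train_op = optimizer.apply_gradients(average_gradients(GPU_grads))
```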
39 | def edge_filter(filter_type="simple",
40 |                 edges_per_channel=False,
41 |                 greyscale=False):
42 |     '''
43 |     Helper function for creating edge filters, e.g. for
44 |     calculating the T(G)V of the images in a batch.
45 |     filter_type must be one of 'simple', 'sobel',
46 |     'scharr' or 'laplace'.
47 |     '''
48 |     if filter_type == "simple":
49 |         # One-sided difference operator. Only needs to be
50 |         # square because of the 'valid' convolution.
51 |         edge_filter_x_atom = np.array([[1, -1],
52 |                                        [0, 0]],
53 |                                       dtype=np.float32)
54 |     elif filter_type == "sobel":
55 |         # Sobel operator
56 |         edge_filter_x_atom = np.array([[1, 0, -1],
57 |                                        [2, 0, -2],
58 |                                        [1, 0, -1]],
59 |                                       dtype=np.float32)
60 |     elif filter_type == "scharr":
61 |         # Scharr operator
62 |         edge_filter_x_atom = np.array([[3, 0, -3],
63 |                                        [10, 0, -10],
64 |                                        [3, 0, -3]],
65 |                                       dtype=np.float32)
66 |     elif filter_type == "laplace":
67 |         # Laplace operator
68 |         edge_filter_x_atom = np.array([[0, -1, 0],
69 |                                        [-1, 4, -1],
70 |                                        [0, -1, 0]],
71 |                                       dtype=np.float32)
72 |     else:
73 |         raise ValueError("Unknown filter_type: " + filter_type)
74 |     edge_filter_y_atom = edge_filter_x_atom.T
75 | 
76 |     filter_size_x = edge_filter_x_atom.shape[1]
77 |     filter_size_y = edge_filter_x_atom.shape[0]
78 | 
79 |     # Right format: NHWC images => HWIO filters
80 |     if edges_per_channel:
81 |         # Version 1: edges on each color channel
82 |         edge_filter_x = np.zeros((filter_size_y, filter_size_x, 3, 3),
83 |                                  dtype=np.float32)
84 |         edge_filter_x[:, :, 0, 0] = edge_filter_x_atom
85 |         edge_filter_x[:, :, 1, 1] = edge_filter_x_atom
86 |         edge_filter_x[:, :, 2, 2] = edge_filter_x_atom
87 | 
88 |         edge_filter_y = np.zeros((filter_size_x, filter_size_y, 3, 3),
89 |                                  dtype=np.float32)
90 |         edge_filter_y[:, :, 0, 0] = edge_filter_y_atom
91 |         edge_filter_y[:, :, 1, 1] = edge_filter_y_atom
92 |         edge_filter_y[:, :, 2, 2] = edge_filter_y_atom
93 |     else:
94 |         # Version 2: edges on a grayscale projection.
95 |         # Might lead to differently colored edges
96 |         # cancelling each other out.
97 |         edge_filter_x = np.zeros((filter_size_y, filter_size_x, 3, 1),
98 |                                  dtype=np.float32)
99 |         edge_filter_x[:, :, 0, 0] = edge_filter_x_atom
100 |         edge_filter_x[:, :, 1, 0] = edge_filter_x_atom
101 |         edge_filter_x[:, :, 2, 0] = edge_filter_x_atom
102 | 
103 |         edge_filter_y = np.zeros((filter_size_x, filter_size_y, 3, 1),
104 |                                  dtype=np.float32)
105 |         edge_filter_y[:, :, 0, 0] = edge_filter_y_atom
106 |         edge_filter_y[:, :, 1, 0] = edge_filter_y_atom
107 |         edge_filter_y[:, :, 2, 0] = edge_filter_y_atom
108 |     if greyscale:
109 |         # Version 3: single-channel (greyscale) images
110 |         edge_filter_x = np.zeros((filter_size_y, filter_size_x, 1, 1),
111 |                                  dtype=np.float32)
112 |         edge_filter_x[:, :, 0, 0] = edge_filter_x_atom
113 | 
114 |         edge_filter_y = np.zeros((filter_size_x, filter_size_y, 1, 1),
115 |                                  dtype=np.float32)
116 |         edge_filter_y[:, :, 0, 0] = edge_filter_y_atom
117 | 
118 |     edge_filter_x = tf.constant(edge_filter_x,
119 |                                 dtype=tf.float32)
120 |     edge_filter_y = tf.constant(edge_filter_y,
121 |                                 dtype=tf.float32)
122 | 
123 |     return edge_filter_x, edge_filter_y
124 | 
125 | def isotropic_TV(tensor,
126 |                  normalize=True,
127 |                  filter_type="simple",
128 |                  edges_per_channel=False,
129 |                  eps=1e-6):
130 |     '''
131 |     Isotropic mean total variation of a batch.
132 |     '''
133 |     edge_filter_x, edge_filter_y = edge_filter(filter_type,
134 |                                                edges_per_channel)
135 |     if tensor.shape[-1] == 1:
136 |         # Single-channel inputs need the greyscale filters.
137 |         edge_filter_x, edge_filter_y = edge_filter(filter_type,
138 |                                                    greyscale=True)
139 |     if normalize:
140 |         # l2-normalize each image in the batch.
141 |         tensor_norm = tf.sqrt(tf.reduce_sum(
142 |             tensor**2, axis=[1, 2, 3], keepdims=True) + eps)
143 |         tensor = tensor / tensor_norm
144 |     edges_x_of_grads = tf.nn.conv2d(tensor,
145 |                                     edge_filter_x,
146 |                                     strides=[1, 1, 1, 1],
147 |                                     padding='VALID')
148 |     edges_y_of_grads = tf.nn.conv2d(tensor,
149 |                                     edge_filter_y,
150 |                                     strides=[1, 1, 1, 1],
151 |                                     padding='VALID')
152 |     edge_image = tf.sqrt(edges_x_of_grads**2 + edges_y_of_grads**2 + eps)
153 |     per_image_TV = tf.reduce_sum(edge_image, axis=[1, 2, 3])
154 |     iso_TV = tf.reduce_mean(per_image_TV, name='iso_TV')
155 |     tf.add_to_collection('TV_losses', iso_TV)
156 |     return iso_TV
157 | 
158 | def anisotropic_TV(tensor,
159 |                    normalize=True,
160 |                    filter_type="simple",
161 |                    edges_per_channel=False,
162 |                    eps=1e-6):
163 |     '''
164 |     Anisotropic mean total variation of a batch.
165 |     '''
166 |     edge_filter_x, edge_filter_y = edge_filter(filter_type,
167 |                                                edges_per_channel)
168 |     if tensor.shape[-1] == 1:
169 |         # Single-channel inputs need the greyscale filters.
170 |         edge_filter_x, edge_filter_y = edge_filter(filter_type,
171 |                                                    greyscale=True)
172 |     if normalize:
173 |         # l2-normalize each image in the batch.
174 |         tensor_norm = tf.sqrt(tf.reduce_sum(
175 |             tensor**2, axis=[1, 2, 3], keepdims=True) + eps)
176 |         tensor = tensor / tensor_norm
177 |     edges_x_of_grads = tf.nn.conv2d(tensor,
178 |                                     edge_filter_x,
179 |                                     strides=[1, 1, 1, 1],
180 |                                     padding='VALID')
181 |     edges_y_of_grads = tf.nn.conv2d(tensor,
182 |                                     edge_filter_y,
183 |                                     strides=[1, 1, 1, 1],
184 |                                     padding='VALID')
185 |     edge_image = tf.abs(edges_x_of_grads) + tf.abs(edges_y_of_grads)
186 |     per_image_TV = tf.reduce_sum(edge_image, axis=[1, 2, 3])
187 |     aniso_TV = tf.reduce_mean(per_image_TV, name='aniso_TV')
188 |     tf.add_to_collection('TV_losses', aniso_TV)
189 |     return aniso_TV
190 | 
191 | 
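Both TV functions return a scalar that is also registered in the `TV_losses` collection, so it can be picked up elsewhere as a regularization term. A minimal usage sketch (the placeholder shape is illustrative, not from this repository):

```
# Usage sketch: mean isotropic total variation of a batch (TF 1.x).
import numpy as np
import tensorflow as tf
from utils import isotropic_TV

batch = tf.placeholder(tf.float32, [None, 28, 28, 1])  # illustrative shape
tv = isotropic_TV(batch, normalize=True, filter_type="simple")

with tf.Session() as session:
    tv_value = session.run(
        tv, feed_dict={batch: np.random.randn(8, 28, 28, 1)})
    print(tv_value)  # one scalar, averaged over the batch
```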
192 | def lp_norm(tensor,
193 |             p_norm=2,
194 |             normalize=False,
195 |             eps=1e-6):
196 |     '''
197 |     The mean p-norm (raised to the p-th power) of a
198 |     batch. Can optionally be l2-normalized first. Also
199 |     accepts the nested list structure of a wavelet
200 |     decomposition.
201 |     '''
202 |     if isinstance(tensor, list):
203 |         # Wavelet decomposition: recurse into the
204 |         # approximation and the three detail parts.
205 |         a_0 = lp_norm(tensor[0], p_norm = p_norm)
206 |         d_1 = lp_norm(tensor[1], p_norm = p_norm)
207 |         d_2 = lp_norm(tensor[2], p_norm = p_norm)
208 |         d_3 = lp_norm(tensor[3], p_norm = p_norm)
209 |         return a_0 + d_1 + d_2 + d_3
210 |     if normalize:
211 |         tensor_norm = tf.sqrt(tf.reduce_sum(
212 |             tensor**2, axis=[1, 2, 3], keepdims=True) + eps)
213 |         tensor = tensor / tensor_norm
214 |     norm = tf.reduce_sum(tf.abs(tensor)**p_norm,
215 |                          axis=[1, 2, 3],
216 |                          name='regularization_lp')
217 |     avg_norm = tf.reduce_mean(norm)
218 |     return avg_norm
219 | 
220 | def lp_norm_weighted(tensor,
221 |                      weights,
222 |                      p_norm = 2):
223 |     '''
224 |     A nested weight list and a wavelet decomposition
225 |     are combined to form a weighted p-norm of
226 |     the wavelet decomposition.
227 |     '''
228 |     if isinstance(tensor, list):
229 |         [c_0, c_1, c_2, c_3] = weights
230 | 
231 |         # c_0 and tensor[0] are themselves nested lists.
232 |         a_0 = lp_norm_weighted(tensor[0], c_0, p_norm)
233 |         d_1 = lp_norm_weighted(tensor[1], c_1, p_norm)
234 |         d_2 = lp_norm_weighted(tensor[2], c_2, p_norm)
235 |         d_3 = lp_norm_weighted(tensor[3], c_3, p_norm)
236 | 
237 |         return a_0 + d_1 + d_2 + d_3
238 |     else:
239 |         norm = tf.reduce_sum(tf.abs(tensor)**p_norm,
240 |                              axis=[1, 2, 3],
241 |                              name='regularization_lp')
242 |         avg_norm = tf.reduce_mean(norm)
243 |         if not isinstance(weights, list):
244 |             avg_norm *= weights
245 |         return avg_norm
246 | 
247 | def nested_weight_list(weight_list):
248 |     '''
249 |     This function creates a nested list of weights that
250 |     follows the same list structure as the wavelet
251 |     decomposition.
252 |     '''
253 |     if len(weight_list) > 1:
254 |         w = weight_list[0]
255 |         return [nested_weight_list(weight_list[1:]),
256 |                 w, w, w]
257 |     else:
258 |         return weight_list[0]
259 | 
260 | 
261 | def add_bn_ops(model, bn_inputs, bn_momentum=.999):
262 |     '''
263 |     If we propagate a tf.Tensor through the Keras model
264 |     for different GPUs, the batch norm update operators
265 |     do not get created correctly. This is why we create
266 |     the correct update operators manually here.
267 |     '''
268 |     for l in model.layers:
269 |         # Check for BatchNorm layers
270 |         if hasattr(l, 'gamma'):
271 |             tf.add_to_collection('bn_gammas',
272 |                                  l.gamma)
273 |             tf.add_to_collection('bn_betas',
274 |                                  l.beta)
275 |             tf.add_to_collection('bn_means',
276 |                                  l.moving_mean)
277 |             tf.add_to_collection('bn_vars',
278 |                                  l.moving_variance)
279 |             tf.add_to_collection('bn_initializers',
280 |                                  l.moving_mean.initializer)
281 |             tf.add_to_collection('bn_initializers',
282 |                                  l.moving_variance.initializer)
283 |     bn_means = tf.get_collection('bn_means')
284 |     bn_vars = tf.get_collection('bn_vars')
285 |     for act, m, v in zip(bn_inputs, bn_means, bn_vars):
286 |         # The batch moments are computed over all axes
287 |         # except the channel axis.
288 |         input_shape = K.int_shape(act)
289 |         reduction_axes = list(range(len(input_shape)))
290 |         del reduction_axes[-1]
291 |         mean, variance = tf.nn.moments(act, reduction_axes)
292 |         m_mean_op = moving_averages.assign_moving_average(m,
293 |                                                           mean,
294 |                                                           bn_momentum,
295 |                                                           zero_debias = False)
296 |         tf.add_to_collection('bn_update_ops', m_mean_op)
297 |         m_var_op = moving_averages.assign_moving_average(v,
298 |                                                          variance,
299 |                                                          bn_momentum,
300 |                                                          zero_debias = False)
301 |         tf.add_to_collection('bn_update_ops', m_var_op)
302 | 
303 | def select(tensor, indices, max_value):
304 |     '''
305 |     Selects, for each sample in the batch, the entry of
306 |     a [batch, max_value] tensor given by indices, e.g.
307 |     one logit per sample.
308 |     '''
309 |     sparse_tensor = tensor * tf.one_hot(indices, max_value)
310 |     selected_values = tf.reduce_sum(sparse_tensor,
311 |                                     axis = 1)
312 |     return selected_values
--------------------------------------------------------------------------------
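To see how `nested_weight_list` mirrors the structure of a wavelet decomposition, a small worked example (pure Python; the weights `[4, 2, 1, 0]` match the default of the `wavelet_weights` hyperparameter):

```
from utils import nested_weight_list

print(nested_weight_list([4, 2, 1, 0]))
# [[[0, 1, 1, 1], 2, 2, 2], 4, 4, 4]
#
# The three detail components at the outermost decomposition level are
# weighted with 4, the next level with 2, and so on; the innermost entry
# (here 0) weights the remaining approximation part, i.e. it is not
# penalized at all. lp_norm_weighted walks this nesting in lockstep with
# a wavelet decomposition of the same shape.
```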