├── .gitignore
├── LICENSE.txt
├── ReadMe.md
├── adaptive_instance_norm.py
├── attack.py
├── cifar10
│   ├── __init__.py
│   └── cifar10_input.py
├── dataprep.py
├── decoder.py
├── encoder.py
├── imagenetmod
│   ├── __init__.py
│   ├── adv_model.py
│   ├── interface.py
│   ├── main.py
│   ├── nets.py
│   ├── resnet_model.py
│   └── third_party
│       ├── README.md
│       ├── __init__.py
│       ├── imagenet_utils.py
│       ├── serve-data.py
│       └── utils.py
├── modelprep.py
├── models
│   ├── __init__.py
│   ├── cifar10_class.py
│   ├── pretrained
│   │   ├── __init__.py
│   │   ├── interface.py
│   │   ├── resnet_slim.py
│   │   └── resnet_utils.py
│   └── trade_interface.py
├── samples
│   ├── adv_training.jpg
│   ├── attack1.jpg
│   ├── attack2.jpg
│   ├── attacking_phase.PNG
│   ├── electric_guitar.jpg
│   ├── espresso.jpg
│   ├── human_preference.png
│   ├── llama.jpg
│   ├── printer.jpg
│   ├── samples_diff_attack.jpg
│   └── training_phase.PNG
├── settings.py
├── style_transfer_net.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | ./store
3 | *.swp
4 | *.zip
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) <2020>
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 | 
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 | 
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
--------------------------------------------------------------------------------
/ReadMe.md:
--------------------------------------------------------------------------------
1 | # Feature Space Attack
2 | > This is the official code for the paper Xu et al., [Towards Feature Space Adversarial Attack](https://arxiv.org/abs/2004.12385). The code is developed on the basis of [Ye](https://github.com/elleryqueenhomels/arbitrary_style_transfer)'s arbitrary style transfer implementation.
3 | 
4 | ## Table of contents
5 | * [Description](#description)
6 | * [Adversarial Samples](#Adversarial-Samples)
7 | * [Prerequisites](#Prerequisites)
8 | * [Pretrained Model](#Pretrained-Model)
9 | * [Tutorial](#Tutorial)
10 | * [Result](#Result)
11 | * [Features](#features)
12 | * [Environment](#Environment)
13 | * [Contact](#contact)
14 | * [Citation](#Citation)
15 | 
16 | ## Description
17 | This project provides a general way to construct adversarial attacks in feature space! Unlike common work on pixel-space attacks, in this project we aim to find a suitable feature space for the adversarial attack. In this spirit, we leverage style transfer and propose a two-phase feature-space attack. The first phase (a) ensures that feature-space perturbations can be restored back into pixel space. The second phase (b) searches for such an adversarial perturbation within a proper bound.
18 | 
19 | ![Training Phase](./samples/training_phase.PNG)
20 | 
21 | ![Attacking Phase](./samples/attacking_phase.PNG)
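
For intuition, here is a minimal NumPy sketch of the core operation (illustrative only; the repository's actual implementation is the TensorFlow `AdaIN` in `adaptive_instance_norm.py`). A feature map produced by the fixed VGG encoder is re-normalized so that its per-channel statistics match a set of "style" statistics; phase (a) trains a decoder that can turn such re-normalized features back into images, and phase (b) optimizes those statistics adversarially.

```python
import numpy as np

def adain(content_feat, style_mean, style_sigma, eps=1e-5):
    """Give the content features the style's per-channel mean/std
    (computed over the spatial dimensions of an (H, W, C) feature map)."""
    mean_c = content_feat.mean(axis=(0, 1), keepdims=True)
    sigma_c = np.sqrt(content_feat.var(axis=(0, 1), keepdims=True) + eps)
    return (content_feat - mean_c) * style_sigma / sigma_c + style_mean

# Toy feature maps standing in for VGG relu4_1 activations.
content = np.random.randn(28, 28, 512).astype(np.float32)
style = np.random.randn(28, 28, 512).astype(np.float32)

stylized = adain(content,
                 style.mean(axis=(0, 1), keepdims=True),
                 np.sqrt(style.var(axis=(0, 1), keepdims=True) + 1e-5))
# The decoder trained in phase (a) maps `stylized` back to pixel space;
# phase (b) replaces the style statistics with adversarially optimized ones.
```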
22 | 
23 | ## Adversarial Samples
24 | 
25 | The first row shows the benign images and the second row shows our adversarial samples. The last row visualizes their differences. Notice that the differences can be categorized into color, texture and implicit style changes.
26 | 
27 | | Espresso | Llama | Printer | Guitar |
28 | | :-------------------------: | :----------------------: | -------------------------- | ---------------------------------- |
29 | | ![](./samples/espresso.jpg) | ![](./samples/llama.jpg) | ![](./samples/printer.jpg) | ![](./samples/electric_guitar.jpg) |
30 | 
31 | ## Prerequisites
32 | [ImageNet](http://www.image-net.org/) : Create a subdirectory "imagenet" and extract the ImageNet "train" and "val" splits there.
33 | 
34 | [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) : Create a subdirectory "cifar10_data" and extract the downloaded files there.
35 | 
36 | [VGG19](https://qiulingxu-public.s3.us-east-2.amazonaws.com/FSA/vgg19_normalised.zip) : Extract it in the root directory.
37 | 
38 | ## Pretrained Model
39 | 
40 | - [Pretrained Decoder for ImageNet](https://qiulingxu-public.s3.us-east-2.amazonaws.com/FSA/Imagenet_Decoder.zip)
41 | - [Pretrained Decoder for CIFAR10](https://qiulingxu-public.s3.us-east-2.amazonaws.com/FSA/CIFAR10_Decoder.zip)
42 | - [Default Classifiers for Attack](https://qiulingxu-public.s3.us-east-2.amazonaws.com/FSA/Classifiers.zip)
43 | 
44 | Download and extract them in the root directory, or use your own models instead.
45 | 
46 | ## Tutorial
47 | 
48 | Two model files are needed to perform the attack: a trained decoder and a target classifier. If you have already downloaded the pre-trained decoder, you can skip Step 1.
49 | 
50 | ### Step 1:
51 | 
52 | Run `python train.py --dataset="[dataset]" --decoder=[depth of decoder]`.
53 | 
54 | A deeper decoder injects more harmful perturbation but is less natural-looking. The command looks like the following:
55 | 
56 | - `python train.py --dataset="imagenet" --decoder=3`
57 | - `python train.py --dataset="cifar10" --decoder=3 --scale`
58 | 
59 | ### Step 2:
60 | 
61 | Run `python attack.py --dataset="[dataset]" --decoder=[depth of decoder] --model="[name of classifier]"`, e.g.
62 | 
63 | - `python attack.py --dataset="imagenet" --decoder=3 --model="imagenet_denoise"`
64 | - `python attack.py --dataset="cifar10" --decoder=3 --scale --model="cifar10_adv"`
65 | 
66 | Note that for the CIFAR10 dataset, you need to choose whether to scale the images up to VGG19's input size. Scaling up increases the quality of the attack but consumes more memory. If you do not want to scale up, remove the `--scale` option and set the decoder to 1 in the command.
67 | 
68 | The generated images can be found in the "store" subdirectory.
69 | 
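Under the hood, Step 2 (see `grad_attack` in `attack.py`) repeatedly takes a gradient step on the per-channel feature statistics and then clips them back into the bound. The sketch below is a simplified NumPy illustration of that projection loop, not the project's TensorFlow code: the gradients `g_mean`/`g_sigma` are random stand-ins for what the real attack obtains by back-propagating the adversarial and content losses through the decoder and the target classifier.

```python
import numpy as np

rng = np.random.default_rng(0)

# Per-channel statistics of the benign content features (stand-in values).
mean_c = rng.standard_normal(512).astype(np.float32)
sigma_c = np.abs(rng.standard_normal(512)).astype(np.float32) + 1e-5

def project(stat, ref, p):
    """Clip a perturbed statistic into the multiplicative bound [|ref|/p, |ref|*p],
    keeping the sign of the benign statistic (cf. AdaIN_adv in adaptive_instance_norm.py)."""
    return np.sign(ref) * np.clip(np.abs(stat), np.abs(ref) / p, np.abs(ref) * p)

bound, lr = 1.5, 1e-2                            # --bound, i.e. the allowed scaling factor
mean_s, sigma_s = mean_c.copy(), sigma_c.copy()  # attack variables start at the benign stats

for step in range(500):
    # Stand-ins for d(loss)/d(statistics); the real code differentiates through
    # decoder + classifier and weighs an adversarial loss against a content loss.
    g_mean = rng.standard_normal(512).astype(np.float32)
    g_sigma = rng.standard_normal(512).astype(np.float32)

    mean_s = project(mean_s - lr * g_mean, mean_c, bound)
    sigma_s = project(sigma_s - lr * g_sigma, sigma_c, bound)

# The perturbed statistics are applied to the content features via AdaIN, decoded back
# to pixels, and the samples that fool the classifier are kept (saved under ./store).
```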
70 | ## Result
71 | 
72 | ### Accuracy under Attack
73 | 
74 | | | |
75 | | -------------------------- | -------------------------- |
76 | | ![](./samples/attack1.jpg) | ![](./samples/attack2.jpg) |
77 | 
78 | *The results show that defenses in pixel space can hardly ensure robustness in feature space. We set decoder=1 for the smaller datasets in the first table and decoder=3 for ImageNet. We set bound=1.5 for the untargeted attack and 2 for the targeted attack.*
79 | 
80 | ### Human Preference Rate
81 | 
82 | ![Human_Preference](./samples/human_preference.png)
83 | 
84 | *We run a targeted attack on the ImageNet ResNet50 v1 model and report the success rate and the corresponding human preference rate under different bounds. We use ImageNet and decoder=1 for this experiment.*
85 | 
86 | ### Adversarial Training
87 | 
88 | ![](./samples/adv_training.jpg)
89 | 
90 | *The results show that adversarial training in feature space or in pixel space only helps against the corresponding attacks, not against the other kind. Thus, both cases need to be considered for a well-rounded defense.*
91 | 
92 | ### Different Attacks
93 | 
94 | ![](./samples/samples_diff_attack.jpg)
95 | 
96 | ## Features
97 | 
98 | I have reorganized the code for a better structure. Let me know if you run into errors. Some of the functionality is not yet polished or released.
99 | 
100 | * Implemented two-phase algorithm
101 | * Support for Feature Augmentation Attack
102 | 
103 | To-do list:
104 | * Plug and play on any model
105 | * Feature Interpolation Attack
106 | 
107 | ## Environment
108 | 
109 | - The code is tested with Python 3.6 + TensorFlow 1.15 + Tensorpack on Ubuntu 18.04.
110 | - We test the program on an RTX 2080 Ti. If your card has less memory, please consider decreasing "BATCH_SIZE" in "settings.py".
111 | - To set up the environment, please download the code and the models listed above.
112 | 
113 | ## Contact
114 | Created by [@Qiuling Xu](https://www.cs.purdue.edu/homes/xu1230/) - feel free to contact me!
115 | 
116 | ## Citation
117 | 
118 | >@misc{xu2020feature,
119 | >
120 | >title={Towards Feature Space Adversarial Attack},
121 | >
122 | >author={Qiuling Xu and Guanhong Tao and Siyuan Cheng and Lin Tan and Xiangyu Zhang},
123 | >
124 | >year={2020},
125 | >
126 | >eprint={2004.12385},
127 | >
128 | >archivePrefix={arXiv},
129 | >
130 | >primaryClass={cs.LG}
131 | >}
--------------------------------------------------------------------------------
/adaptive_instance_norm.py:
--------------------------------------------------------------------------------
1 | # Adaptive Instance Normalization
2 | 
3 | import tensorflow as tf
4 | import settings
5 | 
6 | def AdaIN(content, style, epsilon=1e-5):
7 |     meanC, varC = tf.nn.moments(content, [1, 2], keep_dims=True)
8 |     meanS, varS = tf.nn.moments(style, [1, 2], keep_dims=True)
9 | 
10 |     sigmaC = tf.sqrt(tf.add(varC, epsilon))
11 |     sigmaS = tf.sqrt(tf.add(varS, epsilon))
12 | 
13 |     return (content - meanC) * sigmaS / sigmaC + meanS, meanS, sigmaS
14 | 
15 | def AdaIN_adv_tanh(content, epsilon=1e-5):
16 |     meanC, varC = tf.nn.moments(content, [1, 2], keep_dims=True)
17 |     bs = settings.config["BATCH_SIZE"]
18 |     content_shape = content.shape.as_list()
19 |     new_shape = [bs, 1, 1, content_shape[3]]
20 |     with tf.variable_scope("scale"):
21 |         sigmaS = tf.get_variable("sigma_S", shape=new_shape,
22 |                                  initializer=tf.zeros_initializer())
23 |         meanS = tf.get_variable("mean_S", shape=new_shape,
24 |                                 initializer=tf.zeros_initializer())
25 | 
26 | 
27 |     sigmaC = tf.sqrt(tf.add(varC, epsilon))
28 | 
29 | 
30 |     p=tf.sqrt(1.5)
31 | 
32 |     def get_mid_range(l,r):
33 |         _mid=(l+r)/2.0
34 |         _range=(r-l)/2.0
35 |         return _mid,_range
36 | 
37 |     sign=tf.sign(meanC)
38 |     abs_meanC=tf.abs(meanC)
39 | 
40 |     _sigma_mid, _sigma_range = get_mid_range(sigmaC/p, sigmaC*p)
41 |     _mean_mid, _mean_range = get_mid_range(abs_meanC/p, abs_meanC*p)
42 | 
43 |     sigmaSp = _sigma_range*tf.nn.tanh(sigmaS)+_sigma_mid
44 |     meanSp = sign * (_mean_range*tf.nn.tanh(meanS)+_mean_mid)
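    # Note: sigmaSp / meanSp re-parameterize the free attack variables through tanh,
    # so the perturbed statistics always stay inside (sigmaC/p, sigmaC*p) and
    # (|meanC|/p, |meanC|*p) respectively, without needing an explicit clipping step
    # (which is why ops_bound below stays empty, unlike in AdaIN_adv).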
45 | 46 | ops_bound = [] 47 | 48 | ops_asgn = [tf.assign(sigmaS, tf.atanh((sigmaC-_sigma_mid)/ (_sigma_range +1e-4) )), 49 | tf.assign(meanS, tf.atanh((abs_meanC-_mean_mid)/(_mean_range + 1e-4) ))] 50 | 51 | #ops_asgn = [sigmaS.initializer, meanS.initializer]# 52 | #ops_asgn = [tf.assign(sigmaS, sigmaC-_sigma_mid), 53 | # tf.assign(meanS, meanC-_mean_mid)] 54 | 55 | return (content - meanC) * sigmaSp / sigmaC + meanSp , ops_asgn, ops_bound, sigmaSp, meanSp, meanS, sigmaS 56 | 57 | 58 | def AdaIN_adv(content, epsilon=1e-5, p=1.5): 59 | meanC, varC = tf.nn.moments(content, [1, 2], keep_dims=True) 60 | bs = settings.config["BATCH_SIZE"] 61 | content_shape = content.shape.as_list() 62 | new_shape = [bs, 1, 1, content_shape[3]] 63 | with tf.variable_scope("scale"): 64 | meanS = tf.get_variable("mean_S", shape=new_shape, 65 | initializer=tf.zeros_initializer()) 66 | sigmaS = tf.get_variable("sigma_S", shape=new_shape, 67 | initializer=tf.ones_initializer()) 68 | 69 | 70 | sigmaC = tf.sqrt(tf.add(varC, epsilon)) 71 | 72 | #p = 1.5 73 | p_sigma = p 74 | p_mean = p 75 | 76 | sign = tf.sign(meanC) 77 | abs_meanC = tf.abs(meanC) 78 | ops_bound = [tf.assign(sigmaS, tf.clip_by_value(sigmaS, sigmaC/p_sigma, sigmaC*p_sigma)), 79 | tf.assign(meanS, tf.clip_by_value(meanS, abs_meanC/p_mean, abs_meanC*p_mean))] 80 | 81 | sigmaC_rand = tf.random_uniform(tf.shape(sigmaC), sigmaC/p, sigmaC*p) 82 | meanC_rand = tf.random_uniform(tf.shape(meanC), abs_meanC/p, abs_meanC*p) 83 | #sigmaS = tf.sqrt(tf.add(varS, epsilon)) 84 | ops_asgn = [tf.assign(meanS, abs_meanC), tf.assign(sigmaS, sigmaC)] 85 | ops_asgn_rand = [tf.assign(sigmaS, sigmaC_rand), tf.assign(meanS, meanC_rand)] 86 | 87 | return (content - meanC) * sigmaS / sigmaC + sign * meanS, ops_asgn, ops_bound, sigmaS, meanS, meanC, sigmaC, ops_asgn_rand, (content - meanC) / (sigmaC) 88 | 89 | 90 | def normalize(content, epsilon=1e-5): 91 | meanC, varC = tf.nn.moments(content, [1, 2], keep_dims=True) 92 | #meanC_s, varC_s = tf.nn.moments(content, [1, 2]) 93 | bs = settings.config["BATCH_SIZE"] 94 | content_shape = content.shape.as_list() 95 | new_shape = [bs, 1, 1, content_shape[3]] 96 | 97 | sigmaC = tf.sqrt(tf.add(varC, epsilon)) 98 | #sigmaS = tf.sqrt(tf.add(varS, epsilon)) 99 | normalize_content = (content - meanC) / sigmaC 100 | 101 | 102 | 103 | return normalize_content, meanC, sigmaC 104 | -------------------------------------------------------------------------------- /attack.py: -------------------------------------------------------------------------------- 1 | # Train the Style Transfer Net 2 | from __future__ import print_function 3 | 4 | import numpy as np 5 | import sys 6 | import os 7 | import argparse 8 | from PIL import Image 9 | 10 | import tensorflow as tf 11 | 12 | import settings 13 | import dataprep 14 | import modelprep 15 | 16 | from style_transfer_net import StyleTransferNet_adv 17 | from utils import get_scope_var, save_rgb_img 18 | 19 | 20 | np.set_printoptions(threshold=sys.maxsize) 21 | 22 | parser = argparse.ArgumentParser( 23 | description='Training Auto Encoder for Feature Space Attack') 24 | parser.add_argument("--dataset", help="Dataset for training the auto encoder", 25 | choices=["imagenet", "cifar10"] , default="imagenet") 26 | parser.add_argument("--decoder", help="Depth of the decoder to use. The deeper one (e.g. 3) injects more structure change. 
" + 27 | "And it becomes more harmful but less nature-looking.", type=int, choices=[1, 2, 3], default=1) 28 | parser.add_argument( 29 | "--scale", help="Whether to scale up the image size of CIFAR10 to the size of Imagenet", action="store_true") 30 | parser.add_argument("--model", help="Model to attack.", default="imagenet_normal", 31 | choices=["imagenet_normal", "imagenet_denoise", "cifar10_adv", "cifar10_nat", "cifar10_trades"]) 32 | parser.add_argument("--bound", help="Bound for attack, the exponential of sigma described in the paper", type=float, default=1.5) 33 | 34 | args = parser.parse_args() 35 | 36 | data_set = args.dataset 37 | decoder = args.decoder 38 | model_name = args.model 39 | bound = args.bound 40 | 41 | if data_set == "imagenet": 42 | decoder_list = {1: "imagenet_shallowest", 43 | 2: "imagenet_shallow", 44 | 3: "imagenet"} 45 | 46 | decoder_name = decoder_list[decoder] 47 | 48 | elif data_set == "cifar10": 49 | # One can choose to not to scale CIFAR10 to Imagenet for better speed. While for best quality, one need to consider scale the image size up 50 | # The corresponding decoder name is cifar10_unscale 51 | decoder_list = {1: "cifar10_shallowest", 52 | 2: "cifar10_shallow", 53 | 3: "cifar10"} 54 | if args.scale: 55 | decoder_name = decoder_list[decoder] 56 | else: 57 | decoder_name = "cifar10_unscale" 58 | 59 | task_name = "attack" 60 | 61 | # Put all the pre-defined const into settings and fetch them as global variables 62 | settings.common_const_init(data_set, model_name, decoder_name, task_name) 63 | logger = settings.logger 64 | 65 | for k, v in settings.config.items(): 66 | globals()[k] = v 67 | 68 | dataprep.init_data("eval") 69 | get_data = dataprep.get_data 70 | 71 | 72 | # (height, width, color_channels) 73 | TRAINING_IMAGE_SHAPE = settings.config["IMAGE_SHAPE"] 74 | 75 | EPOCHS = 4 76 | EPSILON = 1e-5 77 | BATCH_SIZE = settings.config["BATCH_SIZE"] 78 | if data_set == "cifar10": 79 | LEARNING_RATE = 1e-2 80 | LR_DECAY_RATE = 1e-4 81 | DECAY_STEPS = 1.0 82 | adv_weight = 500 83 | ITER=2000 84 | CLIP_NORM_VALUE = 10.0 85 | else: 86 | if model_name .find("shallowest")>=0: 87 | LEARNING_RATE = 5e-3 88 | else: 89 | LEARNING_RATE = 1e-2 90 | LR_DECAY_RATE = 1e-3 91 | DECAY_STEPS = 1.0 92 | adv_weight = 128 93 | ITER=500 94 | CLIP_NORM_VALUE = 10.0 95 | 96 | style_weight = 1 97 | 98 | 99 | encoder_path = ENCODER_WEIGHTS_PATH 100 | debug = True 101 | if debug: 102 | from datetime import datetime 103 | start_time = datetime.now() 104 | 105 | def grad_attack(): 106 | sess.run(stn.init_style, feed_dict=fdict) 107 | sess.run(global_step.initializer) 108 | rst_img, rst_loss, nat_acc, rst_acc,rst_mean,rst_sigma = sess.run( 109 | [adv_img, content_loss_y, nat_output.acc_y_auto, adv_output.acc_y_auto, stn.meanS, stn.sigmaS], feed_dict=fdict) 110 | print("Nature Acc:", nat_acc) 111 | for i in range(ITER): 112 | # Run an optimization step 113 | _ = sess.run([train_op], feed_dict=fdict) 114 | 115 | # Clip the bound 116 | sess.run(stn.style_bound, feed_dict = fdict) 117 | 118 | # Monitor the progress 119 | _adv_img, acc, aloss, closs, _mean, _sigma = sess.run( 120 | [adv_img, adv_output.acc_y_auto, adv_loss, content_loss_y, stn.meanS, stn.sigmaS], feed_dict=fdict) 121 | for j in range(BATCH_SIZE): 122 | # Save the best samples 123 | if acc[j]=2: 72 | pairs=min(left,sz//2) 73 | else: 74 | i = (i+1) % self.class_num 75 | continue 76 | x1.extend(self.bucket[i][:pairs]) 77 | x2.extend(self.bucket[i][pairs:2*pairs]) 78 | y1.extend([i]*pairs) 79 | y2.extend([i]*pairs) 80 | 
self.bucket[i] = self.bucket[i][2*pairs:] 81 | self.bucket_size[i]-=2*pairs 82 | left-=pairs 83 | i= (i+1)%self.class_num 84 | #print(i) 85 | self.index = i 86 | self.tot_pair-=self.batch_size 87 | x1=np.stack(x1) 88 | x2=np.stack(x2) 89 | y1=np.stack(y1) 90 | y2=np.stack(y2) 91 | return x1,y1,x2,y2 92 | 93 | 94 | def init_data(mode): 95 | global CLASS_NUM, BATCH_SIZE, inet, cifar_data, data_set, dp, config_name, raw_cifar 96 | assert mode in ["train","eval"] 97 | CLASS_NUM = settings.config["CLASS_NUM"] 98 | BATCH_SIZE = settings.config["BATCH_SIZE"] 99 | data_set = settings.config["data_set"] 100 | config_name = settings.config["config_name"] 101 | 102 | assert data_set in ["cifar10","svhn","imagenet"] 103 | data_set = data_set 104 | 105 | if data_set == "imagenet": 106 | if mode == "train": 107 | inet = imagenet(BATCH_SIZE, dataset="train") 108 | elif mode == "eval": 109 | inet = imagenet(BATCH_SIZE, dataset="val") 110 | elif data_set == "cifar10": 111 | 112 | raw_cifar = cifar10_input.CIFAR10Data("cifar10_data") 113 | if mode == "eval": 114 | cifar_data = raw_cifar.eval_data 115 | elif mode == "train": 116 | cifar_data = raw_cifar.train_data 117 | else: 118 | assert False, "Not implemented" 119 | dp = datapair(CLASS_NUM, BATCH_SIZE) 120 | 121 | def init_polygon_data(stack_num, fetch_embed): 122 | global _mean_all, _sigma_all 123 | mean_file = "polygon_mean_%s.npy" % config_name 124 | sigma_file = "polygon_sigma_%s.npy" % config_name 125 | if os.path.exists(mean_file) and os.path.exists(sigma_file): 126 | _mean_all = np.load(mean_file) 127 | _sigma_all = np.load(sigma_file) 128 | else: 129 | ## Populate polygon point 130 | dps = datapairs(CLASS_NUM, BATCH_SIZE, stack_num) 131 | f = True 132 | while f: 133 | x_batch, y_batch = get_data() 134 | f = dp.feed_pair(x_batch, y_batch) 135 | print("datapairs loading") 136 | polygon_arr = np.concatenate(dp.bucket) 137 | len_arr = polygon_arr.shape[0] 138 | _mean = [] 139 | _sigma = [] 140 | for i in range((len_arr - 1) // BATCH_SIZE + 1): 141 | # sess.run([stn.meanC, stn.sigmaC], feed_dict={ 142 | _meanC, _sigmaC = fetch_embed( 143 | polygon_arr[i*BATCH_SIZE:(i+1)*BATCH_SIZE]) 144 | _mean.append(_meanC) 145 | _sigma.append(_sigmaC) 146 | print("datapairs loaded") 147 | _mean_all = np.concatenate(_mean, axis=0) 148 | _sigma_all = np.concatenate(_sigma, axis=0) 149 | np.save(mean_file, _mean_all) 150 | np.save(sigma_file, _sigma_all) 151 | 152 | def popoulate_data(_meanC, _sigmaC, y_batch, include_self=True): 153 | 154 | res_mean = [] 155 | res_sigma = [] 156 | 157 | if include_self: 158 | real_num = INTERPOLATE_NUM - 1 159 | for i in range(BATCH_SIZE): 160 | y = y_batch[i] 161 | meanCi = _meanC[i: i+1] 162 | meanC_pop = _mean_all[y*real_num:(y+1)*real_num] 163 | res_mean.append(np.concatenate([meanCi, meanC_pop])) 164 | sigmaCi = _sigmaC[i: i+1] 165 | sigmaC_pop = _sigma_all[y*real_num:(y+1)*real_num] 166 | res_sigma.append(np.concatenate([sigmaCi, sigmaC_pop])) 167 | else: 168 | real_num = INTERPOLATE_NUM 169 | for i in range(BATCH_SIZE): 170 | y = y_batch[i] 171 | meanC_pop = _mean_all[y*real_num:(y+1)*real_num] 172 | res_mean.append(meanC_pop) 173 | sigmaCi = _sigmaC[i: i+1] 174 | sigmaC_pop = _sigma_all[y*real_num:(y+1)*real_num] 175 | res_sigma.append(sigmaC_pop) 176 | return np.stack(res_mean), np.stack(res_sigma) 177 | 178 | def get_fetch_func(sess, content, pred): 179 | return functools.partial(_fetch_embed, sess=sess, content=content, pred=pred) 180 | 181 | def _fetch_embed(sess, content, pred ): 182 | _pred = sess.run(pred, 
feed_dict={content: content}) 183 | return _pred 184 | 185 | def _get_data(): 186 | 187 | if data_set =="cifar10": 188 | x_batch, y_batch = cifar_data.get_next_batch( 189 | batch_size=BATCH_SIZE, multiple_passes=True) 190 | elif data_set == "imagenet": 191 | x_batch, y_batch = inet.get_next_batch() 192 | 193 | return x_batch,y_batch 194 | 195 | def get_data(): 196 | return _get_data() 197 | 198 | def get_data_pair(): 199 | mode = settings.config["data_mode"] 200 | if mode == 1: 201 | ret_list = [] 202 | for _ in range(2): 203 | ret_list.extend(get_data()) 204 | return ret_list 205 | 206 | else: 207 | res = dp.get_pair() 208 | while res is None: 209 | x_batch, y_batch = get_data() 210 | dp.feed_pair(x_batch, y_batch) 211 | res = dp.get_pair() 212 | return res 213 | -------------------------------------------------------------------------------- /decoder.py: -------------------------------------------------------------------------------- 1 | # Decoder mostly mirrors the encoder with all pooling layers replaced by nearest 2 | # up-sampling to reduce checker-board effects. 3 | # Decoder has no BN/IN layers. 4 | 5 | import tensorflow as tf 6 | import settings 7 | 8 | class Decoder(object): 9 | 10 | def __init__(self): 11 | self.weight_vars = [] 12 | 13 | if "Decoder_Layer" in settings.config: 14 | self.decoder_layer = settings.config["Decoder_Layer"] 15 | else: 16 | self.decoder_layer = "conv" 17 | 18 | with tf.variable_scope('decoder'): 19 | self._create_variables(512, 256, 3, scope='conv4_1') 20 | 21 | self._create_variables(256, 256, 3, scope='conv3_4') 22 | self._create_variables(256, 256, 3, scope='conv3_3') 23 | self._create_variables(256, 256, 3, scope='conv3_2') 24 | self._create_variables(256, 128, 3, scope='conv3_1') 25 | 26 | self._create_variables(128, 128, 3, scope='conv2_2') 27 | self._create_variables(128, 64, 3, scope='conv2_1') 28 | 29 | self._create_variables( 64, 64, 3, scope='conv1_2') 30 | self._create_variables( 64, 3, 3, scope='conv1_1') 31 | 32 | def _create_variables(self, input_filters, output_filters, kernel_size, scope): 33 | if self.decoder_layer == "conv": 34 | self._create_variables_c( 35 | input_filters, output_filters, kernel_size, scope) 36 | elif self.decoder_layer == "deconv": 37 | self._create_variables_t( 38 | input_filters, output_filters, kernel_size, scope) 39 | else: 40 | assert False 41 | 42 | def _create_variables_c(self, input_filters, output_filters, kernel_size, scope): 43 | if scope in settings.config["DECODER_LAYERS"]: 44 | 45 | with tf.variable_scope(scope): 46 | shape = [kernel_size, kernel_size, 47 | input_filters, output_filters] 48 | kernel = tf.get_variable(initializer=tf.contrib.layers.xavier_initializer( 49 | uniform=False), shape=shape, name='kernel') 50 | bias = tf.get_variable(initializer=tf.contrib.layers.xavier_initializer( 51 | uniform=False), shape=[output_filters], name='bias') 52 | pack = (kernel, bias) 53 | self.weight_vars.append(pack) 54 | 55 | def _create_variables_t(self, input_filters, output_filters, kernel_size, scope): 56 | if scope in settings.config["DECODER_LAYERS"]: 57 | with tf.variable_scope(scope): 58 | shape = [kernel_size, kernel_size, 59 | output_filters, input_filters] 60 | kernel = tf.get_variable(initializer=tf.contrib.layers.xavier_initializer( 61 | uniform=False), shape=shape, name='kernel') 62 | bias = tf.get_variable(initializer=tf.contrib.layers.xavier_initializer( 63 | uniform=False), shape=[output_filters], name='bias') 64 | pack = (kernel, bias) 65 | self.weight_vars.append(pack) 66 | 67 | def 
decode(self, image): 68 | # upsampling after 'conv4_1', 'conv3_1', 'conv2_1' 69 | upsample_indices = settings.config["upsample_indices"] 70 | final_layer_idx = len(self.weight_vars) - 1 71 | 72 | if self.decoder_layer == "conv": 73 | func = conv2d 74 | else: 75 | func = transconv2d 76 | 77 | out = image 78 | for i in range(len(self.weight_vars)): 79 | #print("decoder in %d shape: " % i, out.shape.as_list()) 80 | kernel, bias = self.weight_vars[i] 81 | #if i in upsample_indices: 82 | # out=transconv2d(out,kernel,bias) 83 | #else: 84 | if i == final_layer_idx: 85 | out = func(out, kernel, bias, use_relu=False) 86 | else: 87 | out = func(out, kernel, bias) 88 | 89 | if i in upsample_indices: 90 | out = upsample(out) 91 | #print("decoder out %d shape: "%i, out.shape.as_list()) 92 | 93 | return out 94 | 95 | 96 | def conv2d(x, kernel, bias, use_relu=True): 97 | # padding image with reflection mode 98 | x_padded = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT') 99 | 100 | # conv and add bias 101 | out = tf.nn.conv2d(x_padded, kernel, strides=[1, 1, 1, 1], padding='VALID') 102 | out = tf.nn.bias_add(out, bias) 103 | 104 | if use_relu: 105 | out = tf.nn.relu(out) 106 | 107 | return out 108 | 109 | 110 | def transconv2d(x, kernel, bias, use_relu=True, stride=1): 111 | 112 | bs = tf.shape(x)[0] 113 | img_sz = x.shape.as_list()[1] 114 | #print(img_sz) 115 | filter_size = kernel.shape.as_list()[2] 116 | # conv and add bias 117 | g_deconv = tf.nn.conv2d_transpose(x, kernel, output_shape=[ 118 | bs, img_sz*stride, img_sz*stride, filter_size], strides=[1, stride, stride, 1], padding='SAME') 119 | out = g_deconv + bias 120 | 121 | if use_relu: 122 | out = tf.nn.relu(out) 123 | #print(out.shape.as_list()) 124 | return out 125 | 126 | 127 | def upsample(x, scale=2): 128 | height = x.shape.as_list()[1]*scale#tf.shape(x)[1] * scale 129 | width = x.shape.as_list()[2]*scale # tf.shape(x)[2] * scale 130 | output = tf.image.resize_images(x, [height, width], 131 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 132 | 133 | return output 134 | 135 | -------------------------------------------------------------------------------- /encoder.py: -------------------------------------------------------------------------------- 1 | # Encoder is fixed to the first few layers (up to relu4_1) 2 | # of VGG-19 (pre-trained on ImageNet) 3 | # This code is a modified version of Anish Athalye's vgg.py 4 | # https://github.com/anishathalye/neural-style/blob/master/vgg.py 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | import settings 9 | 10 | 11 | class Encoder(object): 12 | 13 | def __init__(self, weights_path): 14 | # load weights (kernel and bias) from npz file 15 | weights = np.load(weights_path) 16 | 17 | idx = 0 18 | self.weight_vars = [] 19 | ENCODER_LAYERS = settings.config["ENCODER_LAYERS"] 20 | # create the TensorFlow variables 21 | with tf.variable_scope('encoder'): 22 | for layer in ENCODER_LAYERS: 23 | kind = layer[:4] 24 | 25 | if kind == 'conv': 26 | kernel = weights['arr_%d' % idx].transpose([2, 3, 1, 0]) 27 | bias = weights['arr_%d' % (idx + 1)] 28 | kernel = kernel.astype(np.float32) 29 | bias = bias.astype(np.float32) 30 | idx += 2 31 | 32 | with tf.variable_scope(layer): 33 | W = tf.Variable(kernel, trainable=False, name='kernel') 34 | b = tf.Variable(bias, trainable=False, name='bias') 35 | 36 | self.weight_vars.append((W, b)) 37 | 38 | def encode(self, image): 39 | 40 | # create the computational graph 41 | idx = 0 42 | layers = {} 43 | current = image 44 | ENCODER_LAYERS = 
settings.config["ENCODER_LAYERS"] 45 | for i,layer in enumerate(ENCODER_LAYERS): 46 | kind = layer[:4] 47 | 48 | if kind == 'conv': 49 | kernel, bias = self.weight_vars[idx] 50 | idx += 1 51 | current = conv2d(current, kernel, bias) 52 | 53 | elif kind == 'relu': 54 | current = tf.nn.relu(current) 55 | 56 | elif kind == 'pool': 57 | current = pool2d(current) 58 | 59 | layers[layer] = current 60 | print("encoder %d shape: " % i, current.shape.as_list()) 61 | 62 | assert(len(layers) == len(ENCODER_LAYERS)) 63 | 64 | enc = layers[ENCODER_LAYERS[-1]] 65 | 66 | return enc, layers 67 | 68 | def preprocess(self, image, mode='RGB'): 69 | assert mode == "RGB" 70 | # preprocess 71 | if settings.config["IMAGE_SHAPE"][0] != 224 and "NO_SCALE" not in settings.config: 72 | if "pre_scale" in image.__dict__: 73 | image = image.pre_scale 74 | else: 75 | image = tf.image.resize(image, size=[224, 224]) 76 | 77 | # To BGR 78 | image = tf.reverse(image, axis=[-1]) 79 | 80 | return image - np.array([103.939, 116.779, 123.68]) 81 | 82 | def deprocess(self, image, mode='BGR'): 83 | assert mode == "BGR" 84 | image = image + np.array([103.939, 116.779, 123.68]) 85 | 86 | image = tf.reverse(image, axis=[-1]) 87 | image = tf.clip_by_value(image, 0.0, 255.0) 88 | 89 | pre_scale = image 90 | if settings.config["IMAGE_SHAPE"][0] != 224 and "NO_SCALE" not in settings.config: 91 | image = tf.image.resize( 92 | image, size=settings.config["IMAGE_SHAPE"][:2]) 93 | image.pre_scale = pre_scale 94 | return image 95 | 96 | def conv2d(x, kernel, bias): 97 | # padding image with reflection mode 98 | x_padded = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT') 99 | 100 | # conv and add bias 101 | out = tf.nn.conv2d(x_padded, kernel, strides=[1, 1, 1, 1], padding='VALID') 102 | out = tf.nn.bias_add(out, bias) 103 | 104 | return out 105 | 106 | 107 | def pool2d(x): 108 | return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 109 | 110 | -------------------------------------------------------------------------------- /imagenetmod/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/imagenetmod/__init__.py -------------------------------------------------------------------------------- /imagenetmod/adv_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import tensorflow as tf 8 | 9 | from tensorpack.models import regularize_cost, BatchNorm 10 | from tensorpack.tfutils.summary import add_moving_summary 11 | from tensorpack.tfutils import argscope 12 | from tensorpack.tfutils.tower import get_current_tower_context, TowerFuncWrapper 13 | from tensorpack.utils import logger 14 | from tensorpack.utils.argtools import log_once 15 | from tensorpack.tfutils.collection import freeze_collection 16 | from tensorpack.tfutils.varreplace import custom_getter_scope 17 | 18 | from .third_party.imagenet_utils import ImageNetModel 19 | 20 | 21 | IMAGE_SCALE = 2.0 / 255 22 | 23 | 24 | class NoOpAttacker(): 25 | """ 26 | A placeholder attacker which does nothing. 
27 | """ 28 | def attack(self, image, label, model_func): 29 | return image, -tf.ones_like(label) 30 | 31 | 32 | class PGDAttacker(): 33 | """ 34 | A PGD white-box attacker with random target label. 35 | """ 36 | 37 | USE_FP16 = False 38 | """ 39 | Use FP16 to run PGD iterations. 40 | This has about 2~3x speedup for most types of models 41 | if used together with XLA on Volta GPUs. 42 | """ 43 | 44 | USE_XLA = False 45 | """ 46 | Use XLA to optimize the graph of PGD iterations. 47 | This requires CUDA>=9.2 and TF>=1.12. 48 | """ 49 | 50 | 51 | def __init__(self, num_iter, epsilon, step_size, prob_start_from_clean=0.0): 52 | """ 53 | Args: 54 | num_iter (int): 55 | epsilon (float): 56 | step_size (int): 57 | prob_start_from_clean (float): The probability to initialize with 58 | the original image, rather than a randomly perturbed one. 59 | """ 60 | step_size = max(step_size, epsilon / num_iter) 61 | """ 62 | Feature Denoising, Sec 6.1: 63 | We set its step size α = 1, except for 10-iteration attacks where α is set to α/10= 1.6 64 | """ 65 | self.num_iter = num_iter 66 | # rescale the attack epsilon and attack step size 67 | self.epsilon = epsilon * IMAGE_SCALE 68 | self.step_size = step_size * IMAGE_SCALE 69 | self.prob_start_from_clean = prob_start_from_clean 70 | 71 | def _create_random_target(self, label): 72 | """ 73 | Feature Denoising Sec 6: 74 | we consider targeted attacks when 75 | evaluating under the white-box settings, where the targeted 76 | class is selected uniformly at random 77 | """ 78 | label_offset = tf.random_uniform(tf.shape(label), minval=1, maxval=1000, dtype=tf.int32) 79 | return tf.floormod(label + label_offset, tf.constant(1000, tf.int32)) 80 | 81 | def attack(self, image_clean, label, model_func): 82 | target_label = self._create_random_target(label) 83 | 84 | def fp16_getter(getter, *args, **kwargs): 85 | name = args[0] if len(args) else kwargs['name'] 86 | if not name.endswith('/W') and not name.endswith('/b'): 87 | """ 88 | Following convention, convolution & fc are quantized. 89 | BatchNorm (gamma & beta) are not quantized. 90 | """ 91 | return getter(*args, **kwargs) 92 | else: 93 | if kwargs['dtype'] == tf.float16: 94 | kwargs['dtype'] = tf.float32 95 | ret = getter(*args, **kwargs) 96 | ret = tf.cast(ret, tf.float16) 97 | log_once("Variable {} casted to fp16 ...".format(name)) 98 | return ret 99 | else: 100 | return getter(*args, **kwargs) 101 | 102 | def one_step_attack(adv): 103 | if not self.USE_FP16: 104 | logits = model_func(adv) 105 | else: 106 | adv16 = tf.cast(adv, tf.float16) 107 | with custom_getter_scope(fp16_getter): 108 | logits = model_func(adv16) 109 | logits = tf.cast(logits, tf.float32) 110 | # Note we don't add any summaries here when creating losses, because 111 | # summaries don't work in conditionals. 112 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits( 113 | logits=logits, labels=target_label) # we want to minimize it in targeted attack 114 | if not self.USE_FP16: 115 | g, = tf.gradients(losses, adv) 116 | else: 117 | """ 118 | We perform loss scaling to prevent underflow: 119 | https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html 120 | (We have not yet tried training without scaling) 121 | """ 122 | g, = tf.gradients(losses * 128., adv) 123 | g = g / 128. 
124 | 125 | """ 126 | Feature Denoising, Sec 5: 127 | We use the Projected Gradient Descent (PGD) 128 | (implemented at https://github.com/MadryLab/cifar10_challenge ) 129 | as the white-box attacker for adversarial training 130 | """ 131 | adv = tf.clip_by_value(adv - tf.sign(g) * self.step_size, lower_bound, upper_bound) 132 | return adv 133 | 134 | """ 135 | Feature Denoising, Sec 6: 136 | Adversarial perturbation is considered under L∞ norm (i.e., maximum difference for each pixel). 137 | """ 138 | lower_bound = tf.clip_by_value(image_clean - self.epsilon, -1., 1.) 139 | upper_bound = tf.clip_by_value(image_clean + self.epsilon, -1., 1.) 140 | 141 | """ 142 | Feature Denoising Sec. 5: 143 | We randomly choose from both initializations in the 144 | PGD attacker during adversarial training: 20% of training 145 | batches use clean images to initialize PGD, and 80% use 146 | random points within the allowed . 147 | """ 148 | init_start = tf.random_uniform(tf.shape(image_clean), minval=-self.epsilon, maxval=self.epsilon) 149 | 150 | start_from_noise_index = tf.cast(tf.greater(tf.random_uniform(shape=[]), self.prob_start_from_clean), tf.float32) 151 | start_adv = image_clean + start_from_noise_index * init_start 152 | 153 | 154 | if self.USE_XLA: 155 | assert tuple(map(int, tf.__version__.split('.')[:2])) >= (1, 12) 156 | from tensorflow.contrib.compiler import xla 157 | with tf.name_scope('attack_loop'): 158 | adv_final = tf.while_loop( 159 | lambda _: True, 160 | one_step_attack if not self.USE_XLA else \ 161 | lambda adv: xla.compile(lambda: one_step_attack(adv))[0], 162 | [start_adv], 163 | back_prop=False, 164 | maximum_iterations=self.num_iter, 165 | parallel_iterations=1) 166 | return adv_final, target_label 167 | 168 | 169 | class AdvImageNetModel(ImageNetModel): 170 | 171 | """ 172 | Feature Denoising, Sec 5: 173 | A label smoothing of 0.1 is used. 174 | """ 175 | label_smoothing = 0.1 176 | 177 | def set_attacker(self, attacker): 178 | self.attacker = attacker 179 | 180 | def build_graph(self, image, label): 181 | """ 182 | The default tower function. 183 | """ 184 | image = self.image_preprocess(image) 185 | assert self.data_format == 'NCHW' 186 | image = tf.transpose(image, [0, 3, 1, 2]) 187 | ctx = get_current_tower_context() 188 | 189 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): 190 | # BatchNorm always comes with trouble. We use the testing mode of it during attack. 191 | with freeze_collection([tf.GraphKeys.UPDATE_OPS]), argscope(BatchNorm, training=False): 192 | image, target_label = self.attacker.attack(image, label, self.get_logits) 193 | image = tf.stop_gradient(image, name='adv_training_sample') 194 | 195 | logits = self.get_logits(image) 196 | 197 | loss = ImageNetModel.compute_loss_and_error( 198 | logits, label, label_smoothing=self.label_smoothing) 199 | AdvImageNetModel.compute_attack_success(logits, target_label) 200 | if not ctx.is_training: 201 | return 202 | 203 | wd_loss = regularize_cost(self.weight_decay_pattern, 204 | tf.contrib.layers.l2_regularizer(self.weight_decay), 205 | name='l2_regularize_loss') 206 | add_moving_summary(loss, wd_loss) 207 | total_cost = tf.add_n([loss, wd_loss], name='cost') 208 | 209 | if self.loss_scale != 1.: 210 | logger.info("Scaling the total loss by {} ...".format(self.loss_scale)) 211 | return total_cost * self.loss_scale 212 | else: 213 | return total_cost 214 | 215 | def get_inference_func(self, attacker): 216 | """ 217 | Returns a tower function to be used for inference. 
It generates adv 218 | images with the given attacker and runs classification on it. 219 | """ 220 | 221 | def tower_func(image, label): 222 | assert not get_current_tower_context().is_training 223 | image = self.image_preprocess(image) 224 | image = tf.transpose(image, [0, 3, 1, 2]) 225 | image, target_label = attacker.attack(image, label, self.get_logits) 226 | logits = self.get_logits(image) 227 | ImageNetModel.compute_loss_and_error(logits, label) # compute top-1 and top-5 228 | AdvImageNetModel.compute_attack_success(logits, target_label) 229 | 230 | return TowerFuncWrapper(tower_func, self.get_inputs_desc()) 231 | 232 | def image_preprocess(self, image): 233 | with tf.name_scope('image_preprocess'): 234 | if image.dtype.base_dtype != tf.float32: 235 | image = tf.cast(image, tf.float32) 236 | # For the purpose of adversarial training, normalize images to [-1, 1] 237 | image = image * IMAGE_SCALE - 1.0 238 | return image 239 | 240 | @staticmethod 241 | def compute_attack_success(logits, target_label): 242 | """ 243 | Compute the attack success rate. 244 | """ 245 | pred = tf.argmax(logits, axis=1, output_type=tf.int32) 246 | equal_target = tf.equal(pred, target_label) 247 | success = tf.cast(equal_target, tf.float32, name='attack_success') 248 | add_moving_summary(tf.reduce_mean(success, name='attack_success_rate')) 249 | -------------------------------------------------------------------------------- /imagenetmod/interface.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorpack.tfutils import SmartInit 3 | from .nets import ResNeXtDenoiseAllModel 4 | from .third_party.imagenet_utils import get_val_dataflow 5 | from tensorpack.tfutils.tower import TowerContext 6 | 7 | def restore_parameter(sess): 8 | file_path = "X101-DenoiseAll.npz" 9 | sessinit = SmartInit(file_path) 10 | sessinit.init(sess) 11 | 12 | class container: 13 | def __init__(self): 14 | pass 15 | 16 | 17 | def build_imagenet_model(image, label, reuse=False, conf=1): 18 | args = container() 19 | args.depth = 101 20 | with TowerContext(tower_name='', is_training=False): 21 | with tf.variable_scope("", auxiliary_name_scope=False, reuse=reuse): 22 | model = ResNeXtDenoiseAllModel(args) 23 | model.build_graph(image, label) 24 | return model.logits 25 | 26 | def build_imagenet_model_old(image,label,reuse=False,conf =1): 27 | args=container() 28 | args.depth=101 29 | with TowerContext(tower_name='', is_training=False): 30 | with tf.variable_scope("", auxiliary_name_scope=False, reuse=reuse): 31 | model=ResNeXtDenoiseAllModel(args) 32 | model.build_graph(image,label) 33 | cont = container 34 | cont.logits = model.logits 35 | cont.label = tf.argmax(cont.logits, axis=-1) 36 | cont.acc_y = 1-model.wrong_1 37 | cont.acc_y_5 = 1-model.wrong_5 38 | cont.accuracy = tf.reduce_mean(1-model.wrong_1) # wrong_5 39 | cont.rev_xent = tf.reduce_mean(tf.log( 40 | 1 - tf.reduce_sum(tf.nn.softmax(model.logits) * 41 | tf.one_hot(label, depth=1000), axis=-1) 42 | )) 43 | cont.poss_loss = 1 - tf.reduce_mean( 44 | tf.reduce_sum(tf.nn.softmax(model.logits) * 45 | tf.one_hot(label, depth=1000), axis=-1) 46 | ) 47 | 48 | label_one_hot = tf.one_hot(label, depth=1000) 49 | wrong_logit = tf.reduce_max(model.logits * (1-label_one_hot) -label_one_hot * 1e7, axis=-1) 50 | true_logit = tf.reduce_sum(model.logits * label_one_hot, axis=-1) 51 | #wrong_logit = tf.contrib.nn.nth_element(model.logits * (1-label_one_hot) - label_one_hot * 1e7, n=5, reverse=True) 52 | wrong_logit5, _idx = 
tf.nn.top_k( 53 | model.logits * (1-label_one_hot) - label_one_hot * 1e7, k=5, sorted=False) 54 | true_logit5 = tf.reduce_sum(model.logits * label_one_hot, axis=-1, keep_dims=True) 55 | cont.target_loss5 = - tf.reduce_sum(tf.nn.relu(true_logit5 - wrong_logit5 + conf), axis=1) 56 | cont.target_loss = - tf.nn.relu(true_logit - wrong_logit + conf) 57 | cont.xent_filter = tf.reduce_mean((1.0-model.wrong_1)* 58 | tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=model.logits), axis=-1) 59 | 60 | cont.xent = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( 61 | labels=label, logits=model.logits), axis=-1) 62 | #cont.target_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 63 | # labels=label, logits=model.logits) * tf.nn.relu(tf.minimum(1.0, true_logit - wrong_logit + conf)) 64 | return cont 65 | 66 | class imagenet: 67 | 68 | def __init__(self, batchsize, dataset="val"): 69 | self.batchsize=batchsize 70 | self.dataset=dataset 71 | self.init() 72 | 73 | def init(self, ): 74 | self.data = get_val_dataflow( 75 | "imagenet", self.batchsize, dataname=self.dataset) 76 | self.data.reset_state() 77 | self.iter=iter(self.data) 78 | #self.data = tf.transpose(data, [0, 3, 1, 2]) 79 | 80 | 81 | def get_next_batch(self): 82 | pack=next(self.iter,None) 83 | if pack is None: 84 | self.data.reset_state() 85 | self.iter = iter(self.data) 86 | pack = next(self.iter, None) 87 | return pack 88 | -------------------------------------------------------------------------------- /imagenetmod/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | 10 | import argparse 11 | import glob 12 | import os 13 | import sys 14 | import socket 15 | import numpy as np 16 | 17 | import horovod.tensorflow as hvd 18 | 19 | from tensorpack import * 20 | from tensorpack.tfutils import get_model_loader 21 | 22 | import nets 23 | from adv_model import NoOpAttacker, PGDAttacker 24 | from third_party.imagenet_utils import get_val_dataflow, eval_on_ILSVRC12 25 | from third_party.utils import HorovodClassificationError 26 | 27 | 28 | def create_eval_callback(name, tower_func, condition): 29 | """ 30 | Create a distributed evaluation callback. 31 | 32 | Args: 33 | name (str): a prefix 34 | tower_func (TowerFuncWrapper): the inference tower function 35 | condition: a function(epoch number) that returns whether this epoch should evaluate or not 36 | """ 37 | dataflow = get_val_dataflow( 38 | args.data, args.batch, 39 | num_splits=hvd.size(), split_index=hvd.rank()) 40 | # We eval both the classification error rate (for comparison with defenders) 41 | # and the attack success rate (for comparison with attackers). 
42 | infs = [HorovodClassificationError('wrong-top1', '{}-top1-error'.format(name)), 43 | HorovodClassificationError('wrong-top5', '{}-top5-error'.format(name)), 44 | HorovodClassificationError('attack_success', '{}-attack-success-rate'.format(name)) 45 | ] 46 | cb = InferenceRunner( 47 | QueueInput(dataflow), infs, 48 | tower_name=name, 49 | tower_func=tower_func).set_chief_only(False) 50 | cb = EnableCallbackIf( 51 | cb, lambda self: condition(self.epoch_num)) 52 | return cb 53 | 54 | 55 | def do_train(model): 56 | batch = args.batch 57 | total_batch = batch * hvd.size() 58 | 59 | if args.fake: 60 | data = FakeData( 61 | [[batch, 224, 224, 3], [batch]], 1000, 62 | random=False, dtype=['uint8', 'int32']) 63 | data = StagingInput(QueueInput(data)) 64 | callbacks = [] 65 | steps_per_epoch = 50 66 | else: 67 | logger.info("#Tower: {}; Batch size per tower: {}".format(hvd.size(), batch)) 68 | zmq_addr = 'ipc://@imagenet-train-b{}'.format(batch) 69 | if args.no_zmq_ops: 70 | dataflow = RemoteDataZMQ(zmq_addr, hwm=150, bind=False) 71 | data = QueueInput(dataflow) 72 | else: 73 | data = ZMQInput(zmq_addr, 30, bind=False) 74 | data = StagingInput(data) 75 | 76 | steps_per_epoch = int(np.round(1281167 / total_batch)) 77 | 78 | BASE_LR = 0.1 * (total_batch // 256) 79 | """ 80 | ImageNet in 1 Hour, Sec 2.1: 81 | Linear Scaling Rule: When the minibatch size is 82 | multiplied by k, multiply the learning rate by k. 83 | """ 84 | logger.info("Base LR: {}".format(BASE_LR)) 85 | callbacks = [ 86 | ModelSaver(max_to_keep=10), 87 | EstimatedTimeLeft(), 88 | ScheduledHyperParamSetter( 89 | 'learning_rate', [(0, BASE_LR), (35, BASE_LR * 1e-1), (70, BASE_LR * 1e-2), 90 | (95, BASE_LR * 1e-3)]) 91 | ] 92 | """ 93 | Feature Denoising, Sec 5: 94 | Our models are trained for a total of 95 | 110 epochs; we decrease the learning rate by 10× at the 35- 96 | th, 70-th, and 95-th epoch 97 | """ 98 | max_epoch = 110 99 | 100 | if BASE_LR > 0.1: 101 | callbacks.append( 102 | ScheduledHyperParamSetter( 103 | 'learning_rate', [(0, 0.1), (5 * steps_per_epoch, BASE_LR)], 104 | interp='linear', step_based=True)) 105 | """ 106 | ImageNet in 1 Hour, Sec 2.2: 107 | we start from a learning rate of η and increment it by a constant amount at 108 | each iteration such that it reaches ηˆ = kη after 5 epochs 109 | """ 110 | 111 | if not args.fake: 112 | # add distributed evaluation, for various attackers that we care. 
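        # (Schedule below: clean and 10-step PGD evaluation run every epoch, 50-step every
        # 20 epochs, 100-step every 10 epochs; every listed attacker additionally runs in
        # the last two epochs, see the condition lambda passed in add_eval_callback.)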
113 | def add_eval_callback(name, attacker, condition): 114 | cb = create_eval_callback( 115 | name, 116 | model.get_inference_func(attacker), 117 | # always eval in the last 2 epochs no matter what 118 | lambda epoch_num: condition(epoch_num) or epoch_num > max_epoch - 2) 119 | callbacks.append(cb) 120 | 121 | add_eval_callback('eval-clean', NoOpAttacker(), lambda e: True) 122 | add_eval_callback('eval-10step', PGDAttacker(10, args.attack_epsilon, args.attack_step_size), 123 | lambda e: True) 124 | add_eval_callback('eval-50step', PGDAttacker(50, args.attack_epsilon, args.attack_step_size), 125 | lambda e: e % 20 == 0) 126 | add_eval_callback('eval-100step', PGDAttacker(100, args.attack_epsilon, args.attack_step_size), 127 | lambda e: e % 10 == 0) 128 | for k in [20, 30, 40, 60, 70, 80, 90]: 129 | add_eval_callback('eval-{}step'.format(k), 130 | PGDAttacker(k, args.attack_epsilon, args.attack_step_size), 131 | lambda e: False) 132 | 133 | trainer = HorovodTrainer(average=True) 134 | trainer.setup_graph(model.get_inputs_desc(), data, model.build_graph, model.get_optimizer) 135 | trainer.train_with_defaults( 136 | callbacks=callbacks, 137 | steps_per_epoch=steps_per_epoch, 138 | session_init=get_model_loader(args.load) if args.load is not None else None, 139 | max_epoch=max_epoch, 140 | starting_epoch=args.starting_epoch) 141 | 142 | 143 | if __name__ == '__main__': 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument('--load', help='Path to a model to load for evaluation or resuming training.') 146 | parser.add_argument('--starting-epoch', help='The epoch to start with. Useful when resuming training.', type=int, default=1) 147 | parser.add_argument('--logdir', help='Directory suffix for models and training stats.') 148 | parser.add_argument('--eval', action='store_true', help='Evaluate a model instead of training.') 149 | parser.add_argument('--eval-directory', help='Path to a directory of images to classify.') 150 | 151 | parser.add_argument('--data', help='ILSVRC dataset dir') 152 | parser.add_argument('--fake', help='Use fakedata to test or benchmark this model', action='store_true') 153 | parser.add_argument('--no-zmq-ops', help='Use pure python to send/receive data', 154 | action='store_true') 155 | parser.add_argument('--batch', help='Per-GPU batch size', default=32, type=int) 156 | 157 | # attacker flags: 158 | parser.add_argument('--attack-iter', help='Adversarial attack iteration', 159 | type=int, default=30) 160 | parser.add_argument('--attack-epsilon', help='Adversarial attack maximal perturbation', 161 | type=float, default=16.0) 162 | parser.add_argument('--attack-step-size', help='Adversarial attack step size', 163 | type=float, default=1.0) 164 | parser.add_argument('--use-fp16xla', 165 | help='Optimize PGD with fp16+XLA in training or evaluation. 
(Evaluation during training will still use FP32, for fair comparison)', 166 | action='store_true') 167 | 168 | # architecture flags: 169 | parser.add_argument('-d', '--depth', help='ResNet depth', 170 | type=int, default=50, choices=[50, 101, 152]) 171 | parser.add_argument('--arch', help='Name of architectures defined in nets.py', 172 | default='ResNet') 173 | args = parser.parse_args() 174 | 175 | # Define model 176 | model = getattr(nets, args.arch + 'Model')(args) 177 | 178 | # Define attacker 179 | if args.attack_iter == 0 or args.eval_directory: 180 | attacker = NoOpAttacker() 181 | else: 182 | attacker = PGDAttacker( 183 | args.attack_iter, args.attack_epsilon, args.attack_step_size, 184 | prob_start_from_clean=0.2 if not args.eval else 0.0) 185 | if args.use_fp16xla: 186 | attacker.USE_FP16 = True 187 | attacker.USE_XLA = True 188 | model.set_attacker(attacker) 189 | 190 | os.system("nvidia-smi") 191 | hvd.init() 192 | 193 | if args.eval: 194 | sessinit = get_model_loader(args.load) 195 | if hvd.size() == 1: 196 | # single-GPU eval, slow 197 | ds = get_val_dataflow(args.data, args.batch) 198 | eval_on_ILSVRC12(model, sessinit, ds) 199 | else: 200 | logger.info("CMD: " + " ".join(sys.argv)) 201 | cb = create_eval_callback( 202 | "eval", 203 | model.get_inference_func(attacker), 204 | lambda e: True) 205 | trainer = HorovodTrainer() 206 | trainer.setup_graph(model.get_inputs_desc(), PlaceholderInput(), model.build_graph, model.get_optimizer) 207 | # train for an empty epoch, to reuse the distributed evaluation code 208 | trainer.train_with_defaults( 209 | callbacks=[cb], 210 | monitors=[ScalarPrinter()] if hvd.rank() == 0 else [], 211 | session_init=sessinit, 212 | steps_per_epoch=0, max_epoch=1) 213 | elif args.eval_directory: 214 | assert hvd.size() == 1 215 | files = glob.glob(os.path.join(args.eval_directory, '*.*')) 216 | ds = ImageFromFile(files, resize=224) 217 | ds = BatchData(ds, 32, remainder=True) 218 | ds = MapData(ds, lambda dp: [dp[0][:, :, :, ::-1]]) 219 | # Our model expects BGR images instead of RGB 220 | 221 | pred_config = PredictConfig( 222 | model=model, 223 | session_init=get_model_loader(args.load), 224 | input_names=['input'], 225 | output_names=['linear/output'] # the logits 226 | ) 227 | predictor = SimpleDatasetPredictor(pred_config, ds) 228 | 229 | logger.info("Running inference on {} images in {}".format(len(files), args.eval_directory)) 230 | results = [] 231 | for logits, in predictor.get_result(): 232 | predictions = list(np.argmax(logits, axis=1)) 233 | results.extend(predictions) 234 | assert len(results) == len(files) 235 | output_filename = "predictions.txt" 236 | with open(output_filename, "w") as f: 237 | for filename, label in zip(files, results): 238 | f.write("{}\t{}\n".format(filename, label)) 239 | logger.info("Outputs saved to " + output_filename) 240 | else: 241 | logger.info("Training on {}".format(socket.gethostname())) 242 | logdir = os.path.join( 243 | 'train_log', 244 | 'PGD-{}{}-Batch{}-{}GPUs-iter{}-epsilon{}-step{}{}'.format( 245 | args.arch, args.depth, args.batch, hvd.size(), 246 | args.attack_iter, args.attack_epsilon, args.attack_step_size, 247 | '-' + args.logdir if args.logdir else '')) 248 | 249 | if hvd.rank() == 0: 250 | # old log directory will be automatically removed. 
251 | logger.set_logger_dir(logdir, 'd') 252 | logger.info("CMD: " + " ".join(sys.argv)) 253 | logger.info("Rank={}, Local Rank={}, Size={}".format(hvd.rank(), hvd.local_rank(), hvd.size())) 254 | 255 | do_train(model) 256 | -------------------------------------------------------------------------------- /imagenetmod/nets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .adv_model import AdvImageNetModel 8 | from .third_party.imagenet_utils import ImageNetModel 9 | from .resnet_model import ( 10 | resnet_group, resnet_bottleneck, resnet_backbone) 11 | from .resnet_model import denoising 12 | 13 | 14 | NUM_BLOCKS = { 15 | 50: [3, 4, 6, 3], 16 | 101: [3, 4, 23, 3], 17 | 152: [3, 8, 36, 3] 18 | } 19 | 20 | 21 | class ResNetModel(AdvImageNetModel): 22 | def __init__(self, args): 23 | self.num_blocks = NUM_BLOCKS[args.depth] 24 | 25 | def get_logits(self, image): 26 | return resnet_backbone(image, self.num_blocks, resnet_group, resnet_bottleneck) 27 | 28 | 29 | class ResNetDenoiseModel(ImageNetModel): 30 | def __init__(self, args): 31 | self.num_blocks = NUM_BLOCKS[args.depth] 32 | 33 | def get_logits(self, image): 34 | 35 | def group_func(name, *args): 36 | """ 37 | Feature Denoising, Sec 6: 38 | we add 4 denoising blocks to a ResNet: each is added after the 39 | last residual block of res2, res3, res4, and res5, respectively. 40 | """ 41 | l = resnet_group(name, *args) 42 | l = denoising(name + '_denoise', l, embed=True, softmax=True) 43 | return l 44 | 45 | return resnet_backbone(image, self.num_blocks, group_func, resnet_bottleneck) 46 | 47 | 48 | class ResNeXtDenoiseAllModel(ImageNetModel): 49 | """ 50 | ResNeXt 32x8d that performs denoising after every residual block. 51 | """ 52 | def __init__(self, args): 53 | self.num_blocks = NUM_BLOCKS[args.depth] 54 | 55 | def get_logits(self, image): 56 | 57 | print(image.shape.as_list()) 58 | def block_func(l, ch_out, stride): 59 | """ 60 | Feature Denoising, Sec 6.2: 61 | The winning entry, shown in the blue bar, was based on our method by using 62 | a ResNeXt101-32×8 backbone 63 | with non-local denoising blocks added to all residual blocks. 64 | """ 65 | l = resnet_bottleneck(l, ch_out, stride, group=32, res2_bottleneck=8) 66 | l = denoising('non_local', l, embed=False, softmax=False) 67 | return l 68 | 69 | return resnet_backbone(image, self.num_blocks, resnet_group, block_func) 70 | -------------------------------------------------------------------------------- /imagenetmod/resnet_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import tensorflow as tf 8 | from tensorpack.tfutils.argscope import argscope 9 | from tensorpack.models import ( 10 | Conv2D, MaxPooling, AvgPooling, GlobalAvgPooling, BatchNorm, FullyConnected, BNReLU) 11 | 12 | 13 | def resnet_shortcut(l, n_out, stride, activation=tf.identity): 14 | n_in = l.get_shape().as_list()[1] 15 | if n_in != n_out: # change dimension when channel is not the same 16 | return Conv2D('convshortcut', l, n_out, 1, strides=stride, activation=activation) 17 | else: 18 | return l 19 | 20 | 21 | def get_bn(zero_init=False): 22 | if zero_init: 23 | return lambda x, name=None: BatchNorm('bn', x, gamma_initializer=tf.zeros_initializer()) 24 | else: 25 | return lambda x, name=None: BatchNorm('bn', x) 26 | 27 | 28 | def resnet_bottleneck(l, ch_out, stride, group=1, res2_bottleneck=64): 29 | """ 30 | Args: 31 | group (int): the number of groups for resnext 32 | res2_bottleneck (int): the number of channels in res2 bottleneck. 33 | The default corresponds to ResNeXt 1x64d, i.e. vanilla ResNet. 34 | """ 35 | ch_factor = res2_bottleneck * group // 64 36 | shortcut = l 37 | l = Conv2D('conv1', l, ch_out * ch_factor, 1, strides=1, activation=BNReLU) 38 | l = Conv2D('conv2', l, ch_out * ch_factor, 3, strides=stride, activation=BNReLU, split=group) 39 | """ 40 | ImageNet in 1 Hour, Sec 5.1: 41 | the stride-2 convolutions are on 3×3 layers instead of on 1×1 layers 42 | """ 43 | l = Conv2D('conv3', l, ch_out * 4, 1, activation=get_bn(zero_init=True)) 44 | """ 45 | ImageNet in 1 Hour, Sec 5.1: each residual block's last BN where γ is initialized to be 0 46 | """ 47 | ret = l + resnet_shortcut(shortcut, ch_out * 4, stride, activation=get_bn(zero_init=False)) 48 | return tf.nn.relu(ret, name='block_output') 49 | 50 | 51 | def resnet_group(name, l, block_func, features, count, stride): 52 | with tf.variable_scope(name): 53 | for i in range(0, count): 54 | with tf.variable_scope('block{}'.format(i)): 55 | current_stride = stride if i == 0 else 1 56 | l = block_func(l, features, current_stride) 57 | return l 58 | 59 | 60 | def resnet_backbone(image, num_blocks, group_func, block_func): 61 | with argscope([Conv2D, MaxPooling, AvgPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'), \ 62 | argscope(Conv2D, use_bias=False, 63 | kernel_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')): 64 | l = Conv2D('conv0', image, 64, 7, strides=2, activation=BNReLU) 65 | l = MaxPooling('pool0', l, pool_size=3, strides=2, padding='SAME') 66 | l = group_func('group0', l, block_func, 64, num_blocks[0], 1) 67 | l = group_func('group1', l, block_func, 128, num_blocks[1], 2) 68 | l = group_func('group2', l, block_func, 256, num_blocks[2], 2) 69 | l = group_func('group3', l, block_func, 512, num_blocks[3], 2) 70 | l = GlobalAvgPooling('gap', l) 71 | logits = FullyConnected('linear', l, 1000, 72 | kernel_initializer=tf.random_normal_initializer(stddev=0.01)) 73 | """ 74 | ImageNet in 1 Hour, Sec 5.1: 75 | The 1000-way fully-connected layer is initialized by 76 | drawing weights from a zero-mean Gaussian with standard deviation of 0.01 77 | """ 78 | return logits 79 | 80 | 81 | def denoising(name, l, embed=True, softmax=True): 82 | """ 83 | Feature Denoising, Fig 4 & 5. 84 | """ 85 | with tf.variable_scope(name): 86 | f = non_local_op(l, embed=embed, softmax=softmax) 87 | f = Conv2D('conv', f, l.shape[1], 1, strides=1, activation=get_bn(zero_init=True)) 88 | l = l + f 89 | return l 90 | 91 | 92 | def non_local_op(l, embed, softmax): 93 | """ 94 | Feature Denoising, Sec 4.2 & Fig 5. 
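    Each output position is a weighted combination of the features at all spatial
    positions, with weights derived from the theta/phi similarity (softmax-normalized
    in the "gaussian" version, plain dot product otherwise).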
95 | Args: 96 | embed (bool): whether to use embedding on theta & phi 97 | softmax (bool): whether to use gaussian (softmax) version or the dot-product version. 98 | """ 99 | n_in, H, W = l.shape.as_list()[1:] 100 | if embed: 101 | theta = Conv2D('embedding_theta', l, n_in / 2, 1, 102 | strides=1, kernel_initializer=tf.random_normal_initializer(stddev=0.01)) 103 | phi = Conv2D('embedding_phi', l, n_in / 2, 1, 104 | strides=1, kernel_initializer=tf.random_normal_initializer(stddev=0.01)) 105 | g = l 106 | else: 107 | theta, phi, g = l, l, l 108 | if n_in > H * W or softmax: 109 | f = tf.einsum('niab,nicd->nabcd', theta, phi) 110 | if softmax: 111 | orig_shape = tf.shape(f) 112 | f = tf.reshape(f, [-1, H * W, H * W]) 113 | f = f / tf.sqrt(tf.cast(theta.shape[1], theta.dtype)) 114 | f = tf.nn.softmax(f) 115 | f = tf.reshape(f, orig_shape) 116 | f = tf.einsum('nabcd,nicd->niab', f, g) 117 | else: 118 | f = tf.einsum('nihw,njhw->nij', phi, g) 119 | f = tf.einsum('nij,nihw->njhw', f, theta) 120 | if not softmax: 121 | f = f / tf.cast(H * W, f.dtype) 122 | return tf.reshape(f, tf.shape(l)) 123 | -------------------------------------------------------------------------------- /imagenetmod/third_party/README.md: -------------------------------------------------------------------------------- 1 | 2 | Utilities for ImageNet training & distributed evaluation. 3 | 4 | Copied from https://github.com/tensorpack/benchmarks/tree/master/ResNet-Horovod. 5 | -------------------------------------------------------------------------------- /imagenetmod/third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/imagenetmod/third_party/__init__.py -------------------------------------------------------------------------------- /imagenetmod/third_party/imagenet_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: imagenet_utils.py 4 | 5 | 6 | import multiprocessing 7 | import numpy as np 8 | from abc import abstractmethod 9 | 10 | import cv2 11 | import tensorflow as tf 12 | 13 | from tensorpack import imgaug, dataset, ModelDesc 14 | from tensorpack.dataflow import ( 15 | BatchData, MultiThreadMapData, DataFromList, RepeatedData, MultiProcessMapData) 16 | from tensorpack.predict import PredictConfig, SimpleDatasetPredictor 17 | from tensorpack.utils.stats import RatioCounter 18 | from tensorpack.models import regularize_cost 19 | from tensorpack.tfutils.summary import add_moving_summary 20 | from tensorpack.utils import logger 21 | 22 | 23 | def fbresnet_augmentor(isTrain): 24 | """ 25 | Augmentor used in fb.resnet.torch, for BGR images in range [0,255]. 26 | """ 27 | if isTrain: 28 | augmentors = [ 29 | imgaug.GoogleNetRandomCropAndResize(), 30 | # It's OK to remove the following augs if your CPU is not fast enough. 31 | # Removing brightness/contrast/saturation does not have a significant effect on accuracy. 32 | # Removing lighting leads to a tiny drop in accuracy. 
33 | imgaug.RandomOrderAug( 34 | [imgaug.BrightnessScale((0.6, 1.4), clip=False), 35 | imgaug.Contrast((0.6, 1.4), clip=False), 36 | imgaug.Saturation(0.4, rgb=False), 37 | # rgb-bgr conversion for the constants copied from fb.resnet.torch 38 | imgaug.Lighting(0.1, 39 | eigval=np.asarray( 40 | [0.2175, 0.0188, 0.0045][::-1]) * 255.0, 41 | eigvec=np.array( 42 | [[-0.5675, 0.7192, 0.4009], 43 | [-0.5808, -0.0045, -0.8140], 44 | [-0.5836, -0.6948, 0.4203]], 45 | dtype='float32')[::-1, ::-1] 46 | )]), 47 | imgaug.Flip(horiz=True), 48 | ] 49 | else: 50 | augmentors = [ 51 | imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC), 52 | imgaug.CenterCrop((224, 224)), 53 | ] 54 | return augmentors 55 | 56 | 57 | def get_val_dataflow( 58 | datadir, batch_size, 59 | augmentors=None, parallel=None, 60 | num_splits=None, split_index=None, dataname="val"): 61 | if augmentors is None: 62 | augmentors = fbresnet_augmentor(False) 63 | assert datadir is not None 64 | assert isinstance(augmentors, list) 65 | if parallel is None: 66 | parallel = min(40, multiprocessing.cpu_count()) 67 | 68 | if num_splits is None: 69 | ds = dataset.ILSVRC12Files(datadir, dataname, shuffle=True) 70 | else: 71 | # shard validation data 72 | assert False 73 | assert split_index < num_splits 74 | files = dataset.ILSVRC12Files(datadir, dataname, shuffle=True) 75 | files.reset_state() 76 | files = list(files.get_data()) 77 | logger.info("Number of validation data = {}".format(len(files))) 78 | split_size = len(files) // num_splits 79 | start, end = split_size * split_index, split_size * (split_index + 1) 80 | end = min(end, len(files)) 81 | logger.info("Local validation split = {} - {}".format(start, end)) 82 | files = files[start: end] 83 | ds = DataFromList(files, shuffle=True) 84 | 85 | aug = imgaug.AugmentorList(augmentors) 86 | 87 | def mapf(dp): 88 | fname, cls = dp 89 | im = cv2.imread(fname, cv2.IMREAD_COLOR) 90 | #from BGR to RGB 91 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 92 | im = aug.augment(im) 93 | return im, cls 94 | ds = MultiThreadMapData(ds, parallel, mapf, 95 | buffer_size=min(2000, ds.size()), strict=True) 96 | ds = BatchData(ds, batch_size, remainder=False) 97 | ds = RepeatedData(ds, num=-1) 98 | # do not fork() under MPI 99 | return ds 100 | 101 | 102 | def eval_on_ILSVRC12(model, sessinit, dataflow): 103 | pred_config = PredictConfig( 104 | model=model, 105 | session_init=sessinit, 106 | input_names=['input', 'label'], 107 | output_names=['wrong-top1', 'wrong-top5'] 108 | ) 109 | pred = SimpleDatasetPredictor(pred_config, dataflow) 110 | acc1, acc5 = RatioCounter(), RatioCounter() 111 | for top1, top5 in pred.get_result(): 112 | batch_size = top1.shape[0] 113 | acc1.feed(top1.sum(), batch_size) 114 | acc5.feed(top5.sum(), batch_size) 115 | print("Top1 Error: {}".format(acc1.ratio)) 116 | print("Top5 Error: {}".format(acc5.ratio)) 117 | 118 | 119 | class ImageNetModel(ModelDesc): 120 | image_shape = 224 121 | 122 | """ 123 | uint8 instead of float32 is used as input type to reduce copy overhead. 124 | It might hurt the performance a liiiitle bit. 125 | """ 126 | image_dtype = tf.uint8 127 | 128 | """ 129 | Either 'NCHW' or 'NHWC' 130 | """ 131 | data_format = 'NCHW' 132 | 133 | """ 134 | Whether the image is BGR or RGB. If using DataFlow, then it should be BGR. 
135 | """ 136 | image_bgr = True 137 | 138 | weight_decay = 1e-4 139 | 140 | """ 141 | To apply on normalization parameters, use '.*/W|.*/gamma|.*/beta' 142 | """ 143 | weight_decay_pattern = '.*/W' 144 | 145 | """ 146 | Scale the loss, for whatever reasons (e.g., gradient averaging, fp16 training, etc) 147 | """ 148 | loss_scale = 1. 149 | 150 | """ 151 | Label smoothing (See tf.losses.softmax_cross_entropy) 152 | """ 153 | label_smoothing = 0. 154 | 155 | def inputs(self): 156 | return [tf.placeholder(self.image_dtype, [None, self.image_shape, self.image_shape, 3], 'input'), 157 | tf.placeholder(tf.int32, [None], 'label')] 158 | 159 | def build_graph(self, image, label): 160 | image = self.image_preprocess(image) 161 | assert self.data_format == 'NCHW' 162 | image = tf.transpose(image, [0, 3, 1, 2]) 163 | logits = self.get_logits(image) 164 | 165 | self.logits = logits 166 | loss, self.wrong_1, self.wrong_5 = ImageNetModel.compute_loss_and_error( 167 | logits, label, label_smoothing=self.label_smoothing) 168 | 169 | if self.weight_decay > 0: 170 | wd_loss = regularize_cost(self.weight_decay_pattern, 171 | tf.contrib.layers.l2_regularizer(self.weight_decay), 172 | name='l2_regularize_loss') 173 | add_moving_summary(loss, wd_loss) 174 | total_cost = tf.add_n([loss, wd_loss], name='cost') 175 | else: 176 | total_cost = tf.identity(loss, name='cost') 177 | add_moving_summary(total_cost) 178 | 179 | if self.loss_scale != 1.: 180 | logger.info("Scaling the total loss by {} ...".format(self.loss_scale)) 181 | return total_cost * self.loss_scale 182 | else: 183 | return total_cost 184 | 185 | @abstractmethod 186 | def get_logits(self, image): 187 | """ 188 | Args: 189 | image: 4D tensor of ``self.input_shape`` in ``self.data_format`` 190 | 191 | Returns: 192 | Nx#class logits 193 | """ 194 | 195 | def optimizer(self): 196 | lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False) 197 | tf.summary.scalar('learning_rate-summary', lr) 198 | return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True) 199 | 200 | def image_preprocess(self, image): 201 | with tf.name_scope('image_preprocess'): 202 | if image.dtype.base_dtype != tf.float32: 203 | image = tf.cast(image, tf.float32) 204 | mean = [0.485, 0.456, 0.406] # rgb 205 | std = [0.229, 0.224, 0.225] 206 | if self.image_bgr: 207 | mean = mean[::-1] 208 | std = std[::-1] 209 | image_mean = tf.constant(mean, dtype=tf.float32) * 255. 210 | image_std = tf.constant(std, dtype=tf.float32) * 255. 
211 | image = (image - image_mean) / image_std 212 | return image 213 | 214 | @staticmethod 215 | def compute_loss_and_error(logits, label, label_smoothing=0.): 216 | if label_smoothing == 0.: 217 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label) 218 | else: 219 | nclass = logits.shape[-1] 220 | loss = tf.losses.softmax_cross_entropy( 221 | tf.one_hot(label, nclass), 222 | logits, label_smoothing=label_smoothing, 223 | reduction=tf.losses.Reduction.NONE) 224 | loss = tf.reduce_mean(loss, name='xentropy-loss') 225 | 226 | def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'): 227 | with tf.name_scope('prediction_incorrect'): 228 | x = tf.logical_not(tf.nn.in_top_k(logits, label, topk)) 229 | return tf.cast(x, tf.float32, name=name) 230 | 231 | wrong_1 = prediction_incorrect(logits, label, 1, name='wrong-top1') 232 | 233 | wrong_5 = prediction_incorrect(logits, label, 5, name='wrong-top5') 234 | return loss, wrong_1, wrong_5 235 | -------------------------------------------------------------------------------- /imagenetmod/third_party/serve-data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: serve-data.py 4 | 5 | import argparse 6 | import os 7 | import multiprocessing as mp 8 | import socket 9 | 10 | from tensorpack.dataflow import ( 11 | send_dataflow_zmq, MapData, TestDataSpeed, FakeData, dataset, 12 | AugmentImageComponent, BatchData, PrefetchDataZMQ) 13 | from tensorpack.utils import logger 14 | from imagenet_utils import fbresnet_augmentor 15 | 16 | from zmq_ops import dump_arrays 17 | 18 | 19 | def get_data(batch, augmentors): 20 | """ 21 | Sec 3, Remark 4: 22 | Use a single random shuffling of the training data (per epoch) that is divided amongst all k workers. 23 | 24 | NOTE: Here we do not follow the paper, but it makes little differences. 
25 | """ 26 | ds = dataset.ILSVRC12(args.data, 'train', shuffle=True) 27 | ds = AugmentImageComponent(ds, augmentors, copy=False) 28 | ds = BatchData(ds, batch, remainder=False) 29 | ds = PrefetchDataZMQ(ds, min(50, mp.cpu_count())) 30 | return ds 31 | 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--data', help='ILSVRC dataset dir') 36 | parser.add_argument('--fake', action='store_true') 37 | parser.add_argument('--batch', help='per-GPU batch size', 38 | default=32, type=int) 39 | parser.add_argument('--benchmark', action='store_true') 40 | parser.add_argument('--no-zmq-ops', action='store_true') 41 | args = parser.parse_args() 42 | 43 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 44 | 45 | if args.fake: 46 | ds = FakeData( 47 | [[args.batch, 224, 224, 3], [args.batch]], 48 | 1000, random=False, dtype=['uint8', 'int32']) 49 | else: 50 | augs = fbresnet_augmentor(True) 51 | ds = get_data(args.batch, augs) 52 | 53 | logger.info("Serving data on {}".format(socket.gethostname())) 54 | 55 | if args.benchmark: 56 | ds = MapData(ds, dump_arrays) 57 | TestDataSpeed(ds, warmup=300).start() 58 | else: 59 | format = None if args.no_zmq_ops else 'zmq_ops' 60 | send_dataflow_zmq( 61 | ds, 'ipc://@imagenet-train-b{}'.format(args.batch), 62 | hwm=150, format=format, bind=True) 63 | -------------------------------------------------------------------------------- /imagenetmod/third_party/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import horovod.tensorflow as hvd 3 | import tensorflow as tf 4 | 5 | from tensorpack.callbacks import Inferencer 6 | from tensorpack.utils.stats import RatioCounter 7 | 8 | 9 | class HorovodClassificationError(Inferencer): 10 | """ 11 | Like ClassificationError, it evaluates total samples & count of incorrect or correct samples. 12 | But in the end we aggregate the total&count by horovod. 13 | """ 14 | def __init__(self, wrong_tensor_name, summary_name='validation_error'): 15 | """ 16 | Args: 17 | wrong_tensor_name(str): name of the ``wrong`` binary vector tensor. 18 | summary_name(str): the name to log the error with. 19 | """ 20 | self.wrong_tensor_name = wrong_tensor_name 21 | self.summary_name = summary_name 22 | 23 | def _setup_graph(self): 24 | self._placeholder = tf.placeholder(tf.float32, shape=[2], name='to_be_reduced') 25 | self._reduced = hvd.allreduce(self._placeholder, average=False) 26 | 27 | def _before_inference(self): 28 | self.err_stat = RatioCounter() 29 | 30 | def _get_fetches(self): 31 | return [self.wrong_tensor_name] 32 | 33 | def _on_fetches(self, outputs): 34 | vec = outputs[0] 35 | batch_size = len(vec) 36 | wrong = np.sum(vec) 37 | self.err_stat.feed(wrong, batch_size) 38 | 39 | def _after_inference(self): 40 | tot = self.err_stat.total 41 | cnt = self.err_stat.count 42 | tot, cnt = self._reduced.eval(feed_dict={self._placeholder: [tot, cnt]}) 43 | return {self.summary_name: cnt * 1. 
/ tot} 44 | -------------------------------------------------------------------------------- /modelprep.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | import settings 5 | import utils 6 | import dataprep 7 | 8 | import imagenetmod.interface as imagenet_denoise_interface 9 | import models.pretrained.interface as imagenet_normal_interface 10 | from models import cifar10_class as resnet_cifar10 11 | from models import trade_interface as cifar_wrn_trades_interface 12 | 13 | 14 | def init_classifier(conf = 1): 15 | global build_model, restore_model 16 | model_name=settings.config["model_name"] 17 | assert model_name in ["imagenet_denoise", "imagenet_normal", "cifar10_nat", "cifar10_adv", "cifar10_trades"] 18 | if model_name in ["imagenet_denoise"]: 19 | 20 | def _build_model(input,label,reuse): 21 | input = tf.reverse(input, axis=[-1]) # rgb to bgr 22 | logits = imagenet_denoise_interface.build_imagenet_model( 23 | input, label, reuse, conf=conf) 24 | container = utils.build_logits (logits, label, conf) 25 | return container 26 | 27 | _restore_model = imagenet_denoise_interface.restore_parameter 28 | 29 | elif model_name in ["imagenet_normal"]: 30 | def _build_model(input, label, reuse): 31 | # refer to https://github.com/tensorflow/models/blob/6e63dfee4118df6e889227b1a32badf7d0a09e3b/research/slim/preprocessing/vgg_preprocessing.py 32 | _R_MEAN = 123.68 33 | _G_MEAN = 116.78 34 | _B_MEAN = 103.94 35 | _mean = np.array([_R_MEAN, _G_MEAN, _B_MEAN]).reshape([1,1,1,-1]) 36 | input = input - _mean 37 | 38 | logits = imagenet_normal_interface.build_imagenet_model( 39 | input, label, reuse, conf=conf) 40 | container = utils.build_logits(logits, label, conf) 41 | return container 42 | 43 | _restore_model = imagenet_normal_interface.restore_parameter 44 | 45 | elif model_name in ["cifar10_nat","cifar10_adv"]: 46 | def _build_model(input, label, reuse): 47 | model = resnet_cifar10.Model("eval", dataprep.raw_cifar.train_images) 48 | model._build_model(input, label, reuse, conf = conf) 49 | container = utils.build_logits(model.logits, label, conf) 50 | return container 51 | 52 | def _restore_model(sess): 53 | classifier_vars = utils.get_scope_var("model") 54 | classifier_saver = tf.train.Saver(classifier_vars, max_to_keep=1) 55 | if model_name == "cifar10_nat": 56 | classifier_saver.restore(sess, "./pretrained/pretrained.ckpt") 57 | elif model_name == "cifar10_adv": 58 | classifier_saver.restore(sess, "./pretrained/hardened.ckpt") 59 | 60 | elif model_name in ["cifar10_trades"]: 61 | def _build_model(input, label, reuse): 62 | assert settings.config["BATCH_SIZE"] == 64 , "Graph is static and the batch size must be 64" 63 | logits = cifar_wrn_trades_interface.get_model(input) 64 | container = utils.build_logits(logits, label, conf) 65 | return container 66 | 67 | def _restore_model(sess): 68 | pass 69 | 70 | restore_model = _restore_model 71 | build_model = _build_model 72 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/models/__init__.py -------------------------------------------------------------------------------- /models/cifar10_class.py: -------------------------------------------------------------------------------- 1 | # based on 
https://github.com/tensorflow/models/tree/master/resnet 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | """with tf.variable_scope('input'): 10 | 11 | self.x_input = tf.placeholder( 12 | tf.float32, 13 | shape=[None, 32, 32, 3]) 14 | 15 | self.y_input = tf.placeholder(tf.int64, shape=None) 16 | 17 | assert self.statistics_enable, "Please provide training data statistics" 18 | input_standardized=(self.x_input-self.tr_mean)/self.tr_std 19 | #input_standardized = tf.map_fn(lambda img: tf.image.per_image_standardization(img), 20 | # self.x_input) 21 | x = self._conv('init_conv', input_standardized, 3, 3, 16, self._stride_arr(1)) 22 | """ 23 | 24 | class Model(object): 25 | """ResNet model.""" 26 | 27 | def __init__(self, mode, data=None): 28 | """ResNet constructor. 29 | 30 | Args: 31 | mode: One of 'train' and 'eval'. 32 | """ 33 | self.mode = mode 34 | 35 | self.provide_statistics() 36 | #self._build_model() 37 | 38 | def add_internal_summaries(self): 39 | pass 40 | 41 | def _stride_arr(self, stride): 42 | """Map a stride scalar to the stride array for tf.nn.conv2d.""" 43 | return [1, stride, stride, 1] 44 | 45 | def provide_statistics(self):#,data,label): 46 | self.statistics_enable=True 47 | self.tr_mean = np.array([125.3, 123.0, 113.9]).reshape([1,1,1,-1]) #np.mean(data, axis=(0, 1, 2), keepdims=True) 48 | self.tr_std = np.array([63.0, 62.1, 66.7]).reshape([1,1,1,-1]) #np.std(data,axis=(0,1,2), keepdims=True) 49 | 50 | def _build_model_easy(self): 51 | self.x_input = tf.placeholder( 52 | tf.float32, 53 | shape=[None, 32, 32, 3]) 54 | 55 | self.y_input = tf.placeholder(tf.int64, shape=None) 56 | self._build_model(self.x_input,self.y_input, reuse=False) 57 | 58 | 59 | def _build_model(self,x, y, reuse, conf=1): 60 | assert self.mode == 'train' or self.mode == 'eval' 61 | """Build the core model within the graph.""" 62 | with tf.variable_scope('model', reuse=reuse): 63 | 64 | input_standardized = (x-self.tr_mean)/self.tr_std 65 | x = self._conv('init_conv', input_standardized, 66 | 3, 3, 16, self._stride_arr(1)) 67 | strides = [1, 2, 2] 68 | activate_before_residual = [True, False, False] 69 | res_func = self._residual 70 | 71 | # Uncomment the following codes to use w28-10 wide residual network. 72 | # It is more memory efficient than very deep residual network and has 73 | # comparably good performance. 
74 | # https://arxiv.org/pdf/1605.07146v1.pdf 75 | filters = [16, 32, 64, 128] 76 | 77 | 78 | # Update hps.num_residual_units to 9 79 | 80 | with tf.variable_scope('unit_1_0'): 81 | x = res_func(x, filters[0], filters[1], self._stride_arr(strides[0]), 82 | activate_before_residual[0]) 83 | for i in range(1, 2): 84 | with tf.variable_scope('unit_1_%d' % i): 85 | x = res_func(x, filters[1], filters[1], self._stride_arr(1), False) 86 | 87 | with tf.variable_scope('unit_2_0'): 88 | x = res_func(x, filters[1], filters[2], self._stride_arr(strides[1]), 89 | activate_before_residual[1]) 90 | for i in range(1, 2): 91 | with tf.variable_scope('unit_2_%d' % i): 92 | x = res_func(x, filters[2], filters[2], self._stride_arr(1), False) 93 | 94 | with tf.variable_scope('unit_3_0'): 95 | x = res_func(x, filters[2], filters[3], self._stride_arr(strides[2]), 96 | activate_before_residual[2]) 97 | for i in range(1, 2): 98 | with tf.variable_scope('unit_3_%d' % i): 99 | x = res_func(x, filters[3], filters[3], self._stride_arr(1), False) 100 | 101 | with tf.variable_scope('unit_last'): 102 | x = self._batch_norm('final_bn', x) 103 | x = self._relu(x, 0.1) 104 | x = self._global_avg_pool(x) 105 | 106 | 107 | with tf.variable_scope('logit'): 108 | self.pre_softmax = self._fully_connected(x, 10) 109 | 110 | self.predictions = tf.argmax(self.pre_softmax, 1) 111 | self.correct_prediction = tf.equal(self.predictions, y) 112 | self.num_correct = tf.reduce_sum( 113 | tf.cast(self.correct_prediction, tf.int64)) 114 | self.accuracy = tf.reduce_mean( 115 | tf.cast(self.correct_prediction, tf.float32)) 116 | 117 | 118 | with tf.variable_scope('costs'): 119 | self.logits = self.pre_softmax 120 | self.y_xent = tf.nn.sparse_softmax_cross_entropy_with_logits( 121 | logits=self.pre_softmax, labels=y) 122 | self.relaxed_y_xent = tf.reduce_mean( tf.log ( \ 123 | 1- tf.reduce_sum( tf.nn.softmax(self.pre_softmax) * tf.one_hot(y,depth=10) ,axis=-1) \ 124 | ) ) 125 | 126 | label_one_hot = tf.one_hot(y, depth=10) 127 | wrong_logit = tf.reduce_max( 128 | self.pre_softmax * (1-label_one_hot) - label_one_hot * 1e7, axis=-1) 129 | true_logit = tf.reduce_sum( 130 | self.pre_softmax * label_one_hot, axis=-1) 131 | self.target_loss = - tf.reduce_sum(tf.nn.relu(true_logit - wrong_logit + conf) ) 132 | 133 | self.xent = tf.reduce_sum(self.y_xent, name='y_xent') 134 | self.mean_xent = tf.reduce_mean(self.y_xent) 135 | self.weight_decay_loss = self._decay() 136 | 137 | def _batch_norm(self, name, x): 138 | """Batch normalization.""" 139 | with tf.name_scope(name): 140 | return tf.contrib.layers.batch_norm( 141 | inputs=x, 142 | decay=.9, 143 | center=True, 144 | scale=True, 145 | activation_fn=None, 146 | updates_collections=None, 147 | is_training=(self.mode == 'train')) 148 | 149 | 150 | def _instance_norm(self,name,x): 151 | with tf.name_scope(name): 152 | return tf.contrib.layers.instance_norm( 153 | inputs=x, 154 | center=True, 155 | scale=True, 156 | )#is_training=(self.mode=="train")) 157 | 158 | def _norm(self,name,x,norm="batch"): 159 | if norm=="batch": 160 | return self._instance_norm(name,x) 161 | else: 162 | return self._batch_norm(name,x) 163 | 164 | def _residual(self, x, in_filter, out_filter, stride, 165 | activate_before_residual=False): 166 | """Residual unit with 2 sub layers.""" 167 | if activate_before_residual: 168 | with tf.variable_scope('shared_activation'): 169 | x = self._norm('init_bn', x) 170 | x = self._relu(x, 0.1) 171 | orig_x = x 172 | else: 173 | with tf.variable_scope('residual_only_activation'): 174 | 
orig_x = x 175 | x = self._norm('init_bn', x) 176 | x = self._relu(x, 0.1) 177 | 178 | with tf.variable_scope('sub1'): 179 | x = self._conv('conv1', x, 3, in_filter, out_filter, stride) 180 | 181 | with tf.variable_scope('sub2'): 182 | x = self._norm('bn2', x) 183 | x = self._relu(x, 0.1) 184 | x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1]) 185 | 186 | with tf.variable_scope('sub_add'): 187 | if in_filter != out_filter: 188 | orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID') 189 | orig_x = tf.pad( 190 | orig_x, [[0, 0], [0, 0], [0, 0], 191 | [(out_filter-in_filter)//2, (out_filter-in_filter)//2]]) 192 | x += orig_x 193 | 194 | tf.logging.debug('image after unit %s', x.get_shape()) 195 | return x 196 | 197 | def _decay(self): 198 | """L2 weight decay loss.""" 199 | costs = [] 200 | for var in tf.trainable_variables(): 201 | if var.op.name.find('DW') > 0: 202 | costs.append(tf.nn.l2_loss(var)) 203 | return tf.add_n(costs) 204 | 205 | def _conv(self, name, x, filter_size, in_filters, out_filters, strides): 206 | """Convolution.""" 207 | with tf.variable_scope(name): 208 | n = filter_size * filter_size * out_filters 209 | kernel = tf.get_variable( 210 | 'DW', [filter_size, filter_size, in_filters, out_filters], 211 | tf.float32, initializer=tf.random_normal_initializer( 212 | stddev=np.sqrt(2.0/n))) 213 | return tf.nn.conv2d(x, kernel, strides, padding='SAME') 214 | 215 | def _relu(self, x, leakiness=0.0): 216 | """Relu, with optional leaky support.""" 217 | return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu') 218 | 219 | def _fully_connected(self, x, out_dim): 220 | """FullyConnected layer for final output.""" 221 | num_non_batch_dimensions = len(x.shape) 222 | prod_non_batch_dimensions = 1 223 | for ii in range(num_non_batch_dimensions - 1): 224 | prod_non_batch_dimensions *= int(x.shape[ii + 1]) 225 | x = tf.reshape(x, [tf.shape(x)[0], -1]) 226 | w = tf.get_variable( 227 | 'DW', [prod_non_batch_dimensions, out_dim], 228 | initializer=tf.uniform_unit_scaling_initializer(factor=1.0)) 229 | b = tf.get_variable('biases', [out_dim], 230 | initializer=tf.constant_initializer()) 231 | return tf.nn.xw_plus_b(x, w, b) 232 | 233 | def _global_avg_pool(self, x): 234 | assert x.get_shape().ndims == 4 235 | return tf.reduce_mean(x, [1, 2]) 236 | 237 | 238 | 239 | -------------------------------------------------------------------------------- /models/pretrained/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/models/pretrained/__init__.py -------------------------------------------------------------------------------- /models/pretrained/interface.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from . 
import resnet_slim 3 | slim = tf.contrib.slim 4 | 5 | def get_scope_var(scope_name): 6 | var_list = tf.get_collection( 7 | tf.GraphKeys.GLOBAL_VARIABLES, scope=scope_name) 8 | assert (len(var_list) >= 1) 9 | return var_list 10 | 11 | def restore_parameter(sess): 12 | file_path = "imagenet_resnet_v1_50.ckpt" 13 | var_list = get_scope_var("resnet_v1") 14 | saver = tf.train.Saver(var_list) 15 | saver.restore(sess,file_path) 16 | 17 | 18 | 19 | class container: 20 | def __init__(self): 21 | pass 22 | 23 | def compute_loss_and_error(logits, label, label_smoothing=0.): 24 | if label_smoothing == 0.: 25 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 26 | logits=logits, labels=label) 27 | else: 28 | nclass = logits.shape[-1] 29 | loss = tf.losses.softmax_cross_entropy( 30 | tf.one_hot(label, nclass), 31 | logits, label_smoothing=label_smoothing, 32 | reduction=tf.losses.Reduction.NONE) 33 | loss = tf.reduce_mean(loss, name='xentropy-loss') 34 | 35 | def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'): 36 | with tf.name_scope('prediction_incorrect'): 37 | x = tf.logical_not(tf.nn.in_top_k(logits, label, topk)) 38 | return tf.cast(x, tf.float32, name=name) 39 | 40 | wrong_1 = prediction_incorrect(logits, label, 1, name='wrong-top1') 41 | 42 | wrong_5 = prediction_incorrect(logits, label, 5, name='wrong-top5') 43 | return loss, wrong_1, wrong_5 44 | 45 | def build_imagenet_model(image, label, reuse=False, conf=1, shrink_class = 1000): 46 | 47 | with slim.arg_scope(resnet_slim.resnet_arg_scope()): 48 | logits, desc = resnet_slim.resnet_v1_50(image, num_classes=shrink_class, is_training= False, reuse=reuse) 49 | return logits 50 | 51 | -------------------------------------------------------------------------------- /models/pretrained/resnet_slim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for the original form of Residual Networks. 16 | The 'v1' residual networks (ResNets) implemented in this module were proposed 17 | by: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | Other variants were introduced in: 21 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 22 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 23 | The networks defined in this module utilize the bottleneck building block of 24 | [1] with projection shortcuts only for increasing depths. They employ batch 25 | normalization *after* every weight layer. This is the architecture used by 26 | MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and 27 | ResNet-152. See [2; Fig. 
1a] for a comparison between the current 'v1' 28 | architecture and the alternative 'v2' architecture of [2] which uses batch 29 | normalization *before* every weight layer in the so-called full pre-activation 30 | units. 31 | Typical use: 32 | from tensorflow.contrib.slim.nets import resnet_v1 33 | ResNet-101 for image classification into 1000 classes: 34 | # inputs has shape [batch, 224, 224, 3] 35 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 36 | net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False) 37 | ResNet-101 for semantic segmentation into 21 classes: 38 | # inputs has shape [batch, 513, 513, 3] 39 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 40 | net, end_points = resnet_v1.resnet_v1_101(inputs, 41 | 21, 42 | is_training=False, 43 | global_pool=False, 44 | output_stride=16) 45 | """ 46 | from __future__ import absolute_import 47 | from __future__ import division 48 | from __future__ import print_function 49 | 50 | import tensorflow as tf 51 | 52 | from . import resnet_utils 53 | 54 | 55 | resnet_arg_scope = resnet_utils.resnet_arg_scope 56 | slim = tf.contrib.slim 57 | 58 | 59 | class NoOpScope(object): 60 | """No-op context manager.""" 61 | 62 | def __enter__(self): 63 | return None 64 | 65 | def __exit__(self, exc_type, exc_value, traceback): 66 | return False 67 | 68 | 69 | @slim.add_arg_scope 70 | def bottleneck(inputs, 71 | depth, 72 | depth_bottleneck, 73 | stride, 74 | rate=1, 75 | outputs_collections=None, 76 | scope=None, 77 | use_bounded_activations=False): 78 | """Bottleneck residual unit variant with BN after convolutions. 79 | This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for 80 | its definition. Note that we use here the bottleneck variant which has an 81 | extra bottleneck layer. 82 | When putting together two consecutive ResNet blocks that use this unit, one 83 | should use stride = 2 in the last unit of the first block. 84 | Args: 85 | inputs: A tensor of size [batch, height, width, channels]. 86 | depth: The depth of the ResNet unit output. 87 | depth_bottleneck: The depth of the bottleneck layers. 88 | stride: The ResNet unit's stride. Determines the amount of downsampling of 89 | the units output compared to its input. 90 | rate: An integer, rate for atrous convolution. 91 | outputs_collections: Collection to add the ResNet unit output. 92 | scope: Optional variable_scope. 93 | use_bounded_activations: Whether or not to use bounded activations. Bounded 94 | activations better lend themselves to quantized inference. 95 | Returns: 96 | The ResNet unit's output. 97 | """ 98 | with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: 99 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 100 | if depth == depth_in: 101 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') 102 | else: 103 | shortcut = slim.conv2d( 104 | inputs, 105 | depth, [1, 1], 106 | stride=stride, 107 | activation_fn=tf.nn.relu6 if use_bounded_activations else None, 108 | scope='shortcut') 109 | 110 | residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1, 111 | scope='conv1') 112 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride, 113 | rate=rate, scope='conv2') 114 | residual = slim.conv2d(residual, depth, [1, 1], stride=1, 115 | activation_fn=None, scope='conv3') 116 | 117 | if use_bounded_activations: 118 | # Use clip_by_value to simulate bandpass activation. 
119 | residual = tf.clip_by_value(residual, -6.0, 6.0) 120 | output = tf.nn.relu6(shortcut + residual) 121 | else: 122 | output = tf.nn.relu(shortcut + residual) 123 | 124 | return slim.utils.collect_named_outputs(outputs_collections, 125 | sc.name, 126 | output) 127 | 128 | 129 | def resnet_v1(inputs, 130 | blocks, 131 | num_classes=None, 132 | is_training=True, 133 | global_pool=True, 134 | output_stride=None, 135 | include_root_block=True, 136 | spatial_squeeze=True, 137 | store_non_strided_activations=False, 138 | reuse=None, 139 | scope=None): 140 | """Generator for v1 ResNet models. 141 | This function generates a family of ResNet v1 models. See the resnet_v1_*() 142 | methods for specific model instantiations, obtained by selecting different 143 | block instantiations that produce ResNets of various depths. 144 | Training for image classification on Imagenet is usually done with [224, 224] 145 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 146 | block for the ResNets defined in [1] that have nominal stride equal to 32. 147 | However, for dense prediction tasks we advise that one uses inputs with 148 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 149 | this case the feature maps at the ResNet output will have spatial shape 150 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 151 | and corners exactly aligned with the input image corners, which greatly 152 | facilitates alignment of the features to the image. Using as input [225, 225] 153 | images results in [8, 8] feature maps at the output of the last ResNet block. 154 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 155 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 156 | have nominal stride equal to 32 and a good choice in FCN mode is to use 157 | output_stride=16 in order to increase the density of the computed features at 158 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 159 | Args: 160 | inputs: A tensor of size [batch, height_in, width_in, channels]. 161 | blocks: A list of length equal to the number of ResNet blocks. Each element 162 | is a resnet_utils.Block object describing the units in the block. 163 | num_classes: Number of predicted classes for classification tasks. 164 | If 0 or None, we return the features before the logit layer. 165 | is_training: whether batch_norm layers are in training mode. If this is set 166 | to None, the callers can specify slim.batch_norm's is_training parameter 167 | from an outer slim.arg_scope. 168 | global_pool: If True, we perform global average pooling before computing the 169 | logits. Set to True for image classification, False for dense prediction. 170 | output_stride: If None, then the output will be computed at the nominal 171 | network stride. If output_stride is not None, it specifies the requested 172 | ratio of input to output spatial resolution. 173 | include_root_block: If True, include the initial convolution followed by 174 | max-pooling, if False excludes it. 175 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is 176 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 177 | To use this parameter, the input images must be smaller than 300x300 178 | pixels, in which case the output logit layer does not contain spatial 179 | information and can be removed. 
180 | store_non_strided_activations: If True, we compute non-strided (undecimated) 181 | activations at the last unit of each block and store them in the 182 | `outputs_collections` before subsampling them. This gives us access to 183 | higher resolution intermediate activations which are useful in some 184 | dense prediction problems but increases 4x the computation and memory cost 185 | at the last unit of each block. 186 | reuse: whether or not the network and its variables should be reused. To be 187 | able to reuse 'scope' must be given. 188 | scope: Optional variable_scope. 189 | Returns: 190 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. 191 | If global_pool is False, then height_out and width_out are reduced by a 192 | factor of output_stride compared to the respective height_in and width_in, 193 | else both height_out and width_out equal one. If num_classes is 0 or None, 194 | then net is the output of the last ResNet block, potentially after global 195 | average pooling. If num_classes a non-zero integer, net contains the 196 | pre-softmax activations. 197 | end_points: A dictionary from components of the network to the corresponding 198 | activation. 199 | Raises: 200 | ValueError: If the target output_stride is not valid. 201 | """ 202 | with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc: 203 | end_points_collection = sc.original_name_scope + '_end_points' 204 | with slim.arg_scope([slim.conv2d, bottleneck, 205 | resnet_utils.stack_blocks_dense], 206 | outputs_collections=end_points_collection): 207 | with (slim.arg_scope([slim.batch_norm], is_training=is_training) 208 | if is_training is not None else NoOpScope()): 209 | net = inputs 210 | if include_root_block: 211 | if output_stride is not None: 212 | if output_stride % 4 != 0: 213 | raise ValueError( 214 | 'The output_stride needs to be a multiple of 4.') 215 | output_stride /= 4 216 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') 217 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') 218 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride, 219 | store_non_strided_activations) 220 | # Convert end_points_collection into a dictionary of end_points. 221 | end_points = slim.utils.convert_collection_to_dict( 222 | end_points_collection) 223 | 224 | if global_pool: 225 | # Global average pooling. 226 | net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) 227 | end_points['global_pool'] = net 228 | if num_classes: 229 | 230 | net_2d = tf.squeeze(net, [1, 2]) 231 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 232 | normalizer_fn=None, scope='logits') 233 | matrix, bias = slim.get_variables(scope=sc.name+'/logits') 234 | matrix = tf.squeeze(matrix,[0,1]) 235 | logits = tf.add(tf.matmul(net_2d,matrix),bias) 236 | """ 237 | end_points[sc.name + '/logits'] = net 238 | if spatial_squeeze: 239 | net = tf.squeeze(net, [1, 2], name='SpatialSqueeze') 240 | end_points[sc.name + '/spatial_squeeze'] = net 241 | end_points['predictions'] = slim.softmax(net, scope='predictions') 242 | """ 243 | return logits, end_points 244 | 245 | 246 | resnet_v1.default_image_size = 224 247 | 248 | 249 | def resnet_v1_block(scope, base_depth, num_units, stride): 250 | """Helper function for creating a resnet_v1 bottleneck block. 251 | Args: 252 | scope: The scope of the block. 253 | base_depth: The depth of the bottleneck layer for each unit. 254 | num_units: The number of units in the block. 
255 | stride: The stride of the block, implemented as a stride in the last unit. 256 | All other units have stride=1. 257 | Returns: 258 | A resnet_v1 bottleneck block. 259 | """ 260 | return resnet_utils.Block(scope, bottleneck, [{ 261 | 'depth': base_depth * 4, 262 | 'depth_bottleneck': base_depth, 263 | 'stride': 1 264 | }] * (num_units - 1) + [{ 265 | 'depth': base_depth * 4, 266 | 'depth_bottleneck': base_depth, 267 | 'stride': stride 268 | }]) 269 | 270 | 271 | def resnet_v1_50(inputs, 272 | num_classes=None, 273 | is_training=True, 274 | global_pool=True, 275 | output_stride=None, 276 | spatial_squeeze=True, 277 | store_non_strided_activations=False, 278 | min_base_depth=8, 279 | depth_multiplier=1, 280 | reuse=None, 281 | scope='resnet_v1_50'): 282 | """ResNet-50 model of [1]. See resnet_v1() for arg and return description.""" 283 | def depth_func(d): return max(int(d * depth_multiplier), min_base_depth) 284 | blocks = [ 285 | resnet_v1_block('block1', base_depth=depth_func(64), num_units=3, 286 | stride=2), 287 | resnet_v1_block('block2', base_depth=depth_func(128), num_units=4, 288 | stride=2), 289 | resnet_v1_block('block3', base_depth=depth_func(256), num_units=6, 290 | stride=2), 291 | resnet_v1_block('block4', base_depth=depth_func(512), num_units=3, 292 | stride=1), 293 | ] 294 | return resnet_v1(inputs, blocks, num_classes, is_training, 295 | global_pool=global_pool, output_stride=output_stride, 296 | include_root_block=True, spatial_squeeze=spatial_squeeze, 297 | store_non_strided_activations=store_non_strided_activations, 298 | reuse=reuse, scope=scope) 299 | 300 | 301 | resnet_v1_50.default_image_size = resnet_v1.default_image_size 302 | 303 | 304 | def resnet_v1_101(inputs, 305 | num_classes=None, 306 | is_training=True, 307 | global_pool=True, 308 | output_stride=None, 309 | spatial_squeeze=True, 310 | store_non_strided_activations=False, 311 | min_base_depth=8, 312 | depth_multiplier=1, 313 | reuse=None, 314 | scope='resnet_v1_101'): 315 | """ResNet-101 model of [1]. See resnet_v1() for arg and return description.""" 316 | def depth_func(d): return max(int(d * depth_multiplier), min_base_depth) 317 | blocks = [ 318 | resnet_v1_block('block1', base_depth=depth_func(64), num_units=3, 319 | stride=2), 320 | resnet_v1_block('block2', base_depth=depth_func(128), num_units=4, 321 | stride=2), 322 | resnet_v1_block('block3', base_depth=depth_func(256), num_units=23, 323 | stride=2), 324 | resnet_v1_block('block4', base_depth=depth_func(512), num_units=3, 325 | stride=1), 326 | ] 327 | return resnet_v1(inputs, blocks, num_classes, is_training, 328 | global_pool=global_pool, output_stride=output_stride, 329 | include_root_block=True, spatial_squeeze=spatial_squeeze, 330 | store_non_strided_activations=store_non_strided_activations, 331 | reuse=reuse, scope=scope) 332 | 333 | 334 | resnet_v1_101.default_image_size = resnet_v1.default_image_size 335 | 336 | 337 | def resnet_v1_152(inputs, 338 | num_classes=None, 339 | is_training=True, 340 | global_pool=True, 341 | output_stride=None, 342 | store_non_strided_activations=False, 343 | spatial_squeeze=True, 344 | min_base_depth=8, 345 | depth_multiplier=1, 346 | reuse=None, 347 | scope='resnet_v1_152'): 348 | """ResNet-152 model of [1]. 
See resnet_v1() for arg and return description.""" 349 | def depth_func(d): return max(int(d * depth_multiplier), min_base_depth) 350 | blocks = [ 351 | resnet_v1_block('block1', base_depth=depth_func(64), num_units=3, 352 | stride=2), 353 | resnet_v1_block('block2', base_depth=depth_func(128), num_units=8, 354 | stride=2), 355 | resnet_v1_block('block3', base_depth=depth_func(256), num_units=36, 356 | stride=2), 357 | resnet_v1_block('block4', base_depth=depth_func(512), num_units=3, 358 | stride=1), 359 | ] 360 | return resnet_v1(inputs, blocks, num_classes, is_training, 361 | global_pool=global_pool, output_stride=output_stride, 362 | include_root_block=True, spatial_squeeze=spatial_squeeze, 363 | store_non_strided_activations=store_non_strided_activations, 364 | reuse=reuse, scope=scope) 365 | 366 | 367 | resnet_v1_152.default_image_size = resnet_v1.default_image_size 368 | 369 | 370 | def resnet_v1_200(inputs, 371 | num_classes=None, 372 | is_training=True, 373 | global_pool=True, 374 | output_stride=None, 375 | store_non_strided_activations=False, 376 | spatial_squeeze=True, 377 | min_base_depth=8, 378 | depth_multiplier=1, 379 | reuse=None, 380 | scope='resnet_v1_200'): 381 | """ResNet-200 model of [2]. See resnet_v1() for arg and return description.""" 382 | def depth_func(d): return max(int(d * depth_multiplier), min_base_depth) 383 | blocks = [ 384 | resnet_v1_block('block1', base_depth=depth_func(64), num_units=3, 385 | stride=2), 386 | resnet_v1_block('block2', base_depth=depth_func(128), num_units=24, 387 | stride=2), 388 | resnet_v1_block('block3', base_depth=depth_func(256), num_units=36, 389 | stride=2), 390 | resnet_v1_block('block4', base_depth=depth_func(512), num_units=3, 391 | stride=1), 392 | ] 393 | return resnet_v1(inputs, blocks, num_classes, is_training, 394 | global_pool=global_pool, output_stride=output_stride, 395 | include_root_block=True, spatial_squeeze=spatial_squeeze, 396 | store_non_strided_activations=store_non_strided_activations, 397 | reuse=reuse, scope=scope) 398 | 399 | 400 | resnet_v1_200.default_image_size = resnet_v1.default_image_size 401 | -------------------------------------------------------------------------------- /models/pretrained/resnet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains building blocks for various versions of Residual Networks. 16 | 17 | Residual networks (ResNets) were proposed in: 18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 20 | 21 | More variants were introduced in: 22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 23 | Identity Mappings in Deep Residual Networks. 
arXiv: 1603.05027, 2016 24 | 25 | We can obtain different ResNet variants by changing the network depth, width, 26 | and form of residual unit. This module implements the infrastructure for 27 | building them. Concrete ResNet units and full ResNet networks are implemented in 28 | the accompanying resnet_v1.py and resnet_v2.py modules. 29 | 30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current 31 | implementation we subsample the output activations in the last residual unit of 32 | each block, instead of subsampling the input activations in the first residual 33 | unit of each block. The two implementations give identical results but our 34 | implementation is more memory efficient. 35 | """ 36 | from __future__ import absolute_import 37 | from __future__ import division 38 | from __future__ import print_function 39 | 40 | import collections 41 | import tensorflow as tf 42 | 43 | slim = tf.contrib.slim 44 | 45 | 46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 47 | """A named tuple describing a ResNet block. 48 | 49 | Its parts are: 50 | scope: The scope of the `Block`. 51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and 52 | returns another `Tensor` with the output of the ResNet unit. 53 | args: A list of length equal to the number of units in the `Block`. The list 54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 55 | block to serve as argument to unit_fn. 56 | """ 57 | 58 | 59 | def subsample(inputs, factor, scope=None): 60 | """Subsamples the input along the spatial dimensions. 61 | 62 | Args: 63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 64 | factor: The subsampling factor. 65 | scope: Optional variable_scope. 66 | 67 | Returns: 68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 69 | input, either intact (if factor == 1) or subsampled (if factor > 1). 70 | """ 71 | if factor == 1: 72 | return inputs 73 | else: 74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 75 | 76 | 77 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 78 | """Strided 2-D convolution with 'SAME' padding. 79 | 80 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 81 | 'VALID' padding. 82 | 83 | Note that 84 | 85 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 86 | 87 | is equivalent to 88 | 89 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 90 | net = subsample(net, factor=stride) 91 | 92 | whereas 93 | 94 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 95 | 96 | is different when the input's height or width is even, which is why we add the 97 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 98 | 99 | Args: 100 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 101 | num_outputs: An integer, the number of output filters. 102 | kernel_size: An int with the kernel_size of the filters. 103 | stride: An integer, the output stride. 104 | rate: An integer, rate for atrous convolution. 105 | scope: Scope. 106 | 107 | Returns: 108 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 109 | the convolution output. 
110 | """ 111 | if stride == 1: 112 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 113 | padding='SAME', scope=scope) 114 | else: 115 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 116 | pad_total = kernel_size_effective - 1 117 | pad_beg = pad_total // 2 118 | pad_end = pad_total - pad_beg 119 | inputs = tf.pad(inputs, 120 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 121 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 122 | rate=rate, padding='VALID', scope=scope) 123 | 124 | 125 | @slim.add_arg_scope 126 | def stack_blocks_dense(net, blocks, output_stride=None, 127 | store_non_strided_activations=False, 128 | outputs_collections=None): 129 | """Stacks ResNet `Blocks` and controls output feature density. 130 | 131 | First, this function creates scopes for the ResNet in the form of 132 | 'block_name/unit_1', 'block_name/unit_2', etc. 133 | 134 | Second, this function allows the user to explicitly control the ResNet 135 | output_stride, which is the ratio of the input to output spatial resolution. 136 | This is useful for dense prediction tasks such as semantic segmentation or 137 | object detection. 138 | 139 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 140 | factor of 2 when transitioning between consecutive ResNet blocks. This results 141 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 142 | half the nominal network stride (e.g., output_stride=4), then we compute 143 | responses twice. 144 | 145 | Control of the output feature density is implemented by atrous convolution. 146 | 147 | Args: 148 | net: A `Tensor` of size [batch, height, width, channels]. 149 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 150 | element is a ResNet `Block` object describing the units in the `Block`. 151 | output_stride: If `None`, then the output will be computed at the nominal 152 | network stride. If output_stride is not `None`, it specifies the requested 153 | ratio of input to output spatial resolution, which needs to be equal to 154 | the product of unit strides from the start up to some level of the ResNet. 155 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 156 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 157 | is equivalent to output_stride=24). 158 | store_non_strided_activations: If True, we compute non-strided (undecimated) 159 | activations at the last unit of each block and store them in the 160 | `outputs_collections` before subsampling them. This gives us access to 161 | higher resolution intermediate activations which are useful in some 162 | dense prediction problems but increases 4x the computation and memory cost 163 | at the last unit of each block. 164 | outputs_collections: Collection to add the ResNet block outputs. 165 | 166 | Returns: 167 | net: Output tensor with stride equal to the specified output_stride. 168 | 169 | Raises: 170 | ValueError: If the target output_stride is not valid. 171 | """ 172 | # The current_stride variable keeps track of the effective stride of the 173 | # activations. This allows us to invoke atrous convolution whenever applying 174 | # the next residual unit would result in the activations having stride larger 175 | # than the target output_stride. 176 | current_stride = 1 177 | 178 | # The atrous convolution rate parameter. 
179 | rate = 1 180 | 181 | for block in blocks: 182 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 183 | block_stride = 1 184 | for i, unit in enumerate(block.args): 185 | if store_non_strided_activations and i == len(block.args) - 1: 186 | # Move stride from the block's last unit to the end of the block. 187 | block_stride = unit.get('stride', 1) 188 | unit = dict(unit, stride=1) 189 | 190 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 191 | # If we have reached the target output_stride, then we need to employ 192 | # atrous convolution with stride=1 and multiply the atrous rate by the 193 | # current unit's stride for use in subsequent layers. 194 | if output_stride is not None and current_stride == output_stride: 195 | net = block.unit_fn(net, rate=rate, **dict(unit, stride=1)) 196 | rate *= unit.get('stride', 1) 197 | 198 | else: 199 | net = block.unit_fn(net, rate=1, **unit) 200 | current_stride *= unit.get('stride', 1) 201 | if output_stride is not None and current_stride > output_stride: 202 | raise ValueError('The target output_stride cannot be reached.') 203 | 204 | # Collect activations at the block's end before performing subsampling. 205 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 206 | 207 | # Subsampling of the block's output activations. 208 | if output_stride is not None and current_stride == output_stride: 209 | rate *= block_stride 210 | else: 211 | net = subsample(net, block_stride) 212 | current_stride *= block_stride 213 | if output_stride is not None and current_stride > output_stride: 214 | raise ValueError('The target output_stride cannot be reached.') 215 | 216 | if output_stride is not None and current_stride != output_stride: 217 | raise ValueError('The target output_stride cannot be reached.') 218 | 219 | return net 220 | 221 | 222 | def resnet_arg_scope(weight_decay=0.0001, 223 | batch_norm_decay=0.997, 224 | batch_norm_epsilon=1e-5, 225 | batch_norm_scale=True, 226 | activation_fn=tf.nn.relu, 227 | use_batch_norm=True, 228 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS): 229 | """Defines the default ResNet arg scope. 230 | 231 | TODO(gpapan): The batch-normalization related default values above are 232 | appropriate for use in conjunction with the reference ResNet models 233 | released at https://github.com/KaimingHe/deep-residual-networks. When 234 | training ResNets from scratch, they might need to be tuned. 235 | 236 | Args: 237 | weight_decay: The weight decay to use for regularizing the model. 238 | batch_norm_decay: The moving average decay when estimating layer activation 239 | statistics in batch normalization. 240 | batch_norm_epsilon: Small constant to prevent division by zero when 241 | normalizing activations by their variance in batch normalization. 242 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 243 | activations in the batch normalization layer. 244 | activation_fn: The activation function which is used in ResNet. 245 | use_batch_norm: Whether or not to use batch normalization. 246 | batch_norm_updates_collections: Collection for the update ops for 247 | batch norm. 248 | 249 | Returns: 250 | An `arg_scope` to use for the resnet models. 251 | """ 252 | batch_norm_params = { 253 | 'decay': batch_norm_decay, 254 | 'epsilon': batch_norm_epsilon, 255 | 'scale': batch_norm_scale, 256 | 'updates_collections': batch_norm_updates_collections, 257 | 'fused': None, # Use fused batch norm if possible. 
258 |   }
259 | 
260 |   with slim.arg_scope(
261 |       [slim.conv2d],
262 |       weights_regularizer=slim.l2_regularizer(weight_decay),
263 |       weights_initializer=slim.variance_scaling_initializer(),
264 |       activation_fn=activation_fn,
265 |       normalizer_fn=slim.batch_norm if use_batch_norm else None,
266 |       normalizer_params=batch_norm_params):
267 |     with slim.arg_scope([slim.batch_norm], **batch_norm_params):
268 |       # The following implies padding='SAME' for pool1, which makes feature
269 |       # alignment easier for dense prediction tasks. This is also used in
270 |       # https://github.com/facebook/fb.resnet.torch. However the accompanying
271 |       # code of 'Deep Residual Learning for Image Recognition' uses
272 |       # padding='VALID' for pool1. You can switch to that choice by setting
273 |       # slim.arg_scope([slim.max_pool2d], padding='VALID').
274 |       with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
275 |         return arg_sc
276 | 
--------------------------------------------------------------------------------
/models/trade_interface.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | 
4 | cnt=0
5 | def load_pb(path_to_pb, input):
6 |     input = tf.transpose(input ,[0, 3, 1, 2])
7 |     input = input /255
8 |     with tf.gfile.GFile(path_to_pb, 'rb') as f:
9 |         graph_def = tf.GraphDef()
10 |         graph_def.ParseFromString(f.read())
11 |     #with tf.get_default_graph() as graph:
12 |     global cnt
13 |     cnt+=1
14 |     output, = tf.import_graph_def(graph_def, name='model%d'%cnt, input_map={
15 |         "input:0": input, }, return_elements=['add_16:0'])
16 |     #print(output.shape.as_list())
17 |     return tf.get_default_graph(), output
18 | 
19 | def get_model(input):
20 |     graph, output = load_pb("./pretrained/model_cifar_wrn.pb", input)
21 |     #print([n.name for n in tf.get_default_graph().as_graph_def().node])
22 |     #output_tensor = graph.get_tensor_by_name('model/add_16:0')
23 |     #input_tensor = graph.get_tensor_by_name('model/input:0')
24 |     return output  # input_tensor, output_tensor
25 | 
26 | 
27 | class container:
28 |     def __init__(self):
29 |         pass
30 | 
31 | def build_model(x, y, conf=1):
32 | 
33 |     cont = container()
34 |     logits = get_model(x)
35 |     cont.logits = logits
36 | 
37 |     predictions = tf.argmax(logits, 1)
38 |     correct_prediction = tf.equal(predictions, y)
39 | 
40 |     cont.correct_prediction = correct_prediction
41 |     cont.accuracy = tf.reduce_mean(
42 |         tf.cast(correct_prediction, tf.float32))
43 | 
44 |     with tf.variable_scope('costs'):
45 |         cont.y_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
46 |             logits=logits, labels=y)
47 | 
48 |         label_one_hot = tf.one_hot(y, depth=10)
49 |         wrong_logit = tf.reduce_max(
50 |             logits * (1-label_one_hot) - label_one_hot * 1e7, axis=-1)
51 |         true_logit = tf.reduce_sum(
52 |             logits * label_one_hot, axis=-1)
53 |         cont.target_loss = - \
54 |             tf.reduce_sum(tf.nn.relu(true_logit - wrong_logit + conf))
55 | 
56 |         cont.xent = tf.reduce_sum(cont.y_xent, name='y_xent')
57 |         cont.mean_xent = tf.reduce_mean(cont.y_xent)
58 |     return cont
59 | 
60 | def gen_pb():
61 |     from trades.models.wideresnet import WideResNet
62 |     import torch
63 |     import onnx
64 |     from onnx_tf.backend import prepare
65 | 
66 |     device = torch.device("cuda")
67 |     model = WideResNet().to(device)
68 |     model.load_state_dict(torch.load('./model_cifar_wrn.pt'))
69 |     model.eval()
70 | 
71 |     dummy_input = torch.from_numpy(
72 |         np.zeros((64, 3, 32, 32),)).float().to(device)
73 |     dummy_output = model(dummy_input)
74 | 
75 |     torch.onnx.export(model, dummy_input, 
'./model_cifar_wrn.onnx', 76 | input_names=['input'], output_names=['output']) 77 | 78 | model_onnx = onnx.load('./model_cifar_wrn.onnx') 79 | 80 | tf_rep = prepare(model_onnx) 81 | 82 | # Print out tensors and placeholders in model (helpful during inference in TensorFlow) 83 | print(tf_rep.tensor_dict) 84 | 85 | # Export model as .pb file 86 | tf_rep.export_graph('./model_cifar_wrn.pb') 87 | 88 | if __name__=="__main__": 89 | gen_pb() 90 | """ 91 | input_tensor, output_tensor = get_model() 92 | sess = tf.Session() 93 | init = tf.global_variables_initializer() 94 | output_val = sess.run(output_tensor, feed_dict={ 95 | input_tensor: np.zeros((64, 3, 32, 32),)}) 96 | print(output_val) 97 | """ -------------------------------------------------------------------------------- /samples/adv_training.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/samples/adv_training.jpg -------------------------------------------------------------------------------- /samples/attack1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/samples/attack1.jpg -------------------------------------------------------------------------------- /samples/attack2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/samples/attack2.jpg -------------------------------------------------------------------------------- /samples/attacking_phase.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/samples/attacking_phase.PNG -------------------------------------------------------------------------------- /samples/electric_guitar.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/samples/electric_guitar.jpg -------------------------------------------------------------------------------- /samples/espresso.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/samples/espresso.jpg -------------------------------------------------------------------------------- /samples/human_preference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/samples/human_preference.png -------------------------------------------------------------------------------- /samples/llama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/samples/llama.jpg -------------------------------------------------------------------------------- /samples/printer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/samples/printer.jpg 
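
A minimal smoke test for models/trade_interface.py shown above, assuming the snippet is run from the repository root with ./pretrained/model_cifar_wrn.pb already exported; this is an illustrative sketch, not code from the repository:

    import numpy as np
    import tensorflow as tf
    from models import trade_interface

    # build_model expects NHWC inputs in [0, 255]; load_pb transposes them to
    # NCHW and rescales by 1/255 before feeding the frozen WideResNet graph.
    x = tf.placeholder(tf.float32, [64, 32, 32, 3])
    y = tf.placeholder(tf.int64, [64])
    cont = trade_interface.build_model(x, y)

    with tf.Session() as sess:
        acc = sess.run(cont.accuracy,
                       feed_dict={x: np.zeros((64, 32, 32, 3), dtype=np.float32),
                                  y: np.zeros(64, dtype=np.int64)})
        print("accuracy on an all-zero batch:", acc)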
-------------------------------------------------------------------------------- /samples/samples_diff_attack.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/samples/samples_diff_attack.jpg -------------------------------------------------------------------------------- /samples/training_phase.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiulingxu/FeatureSpaceAttack/121538bbc298a1ee485cf085e39472690ac9ce48/samples/training_phase.PNG -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import logging 3 | import sys 4 | import os 5 | 6 | def init_settings(config_name,suffix="",task_dir=""): 7 | global config 8 | config={} 9 | assert config_name in ["cifar10", "cifar10_shallow", "cifar10_shallowest", "cifar10_unscale", "imagenet", 10 | "imagenet_shallow", "imagenet_shallowest"] 11 | if config_name.find("cifar10")>=0: 12 | data_set = "cifar10" 13 | elif config_name.find("imagenet") >= 0: 14 | data_set = "imagenet" 15 | else: 16 | assert False 17 | 18 | config["config_name"] = config_name 19 | config["data_set"] = data_set 20 | config["style_weight"]=1 21 | 22 | # data mode: 23 | # 1: allowing any pairs to feed into training 24 | # 2: only allowing pairs from the same class to feed into training 25 | config["data_mode"] = 2 26 | config["INTERPOLATE_NUM"] = 50 + 1 27 | 28 | if config_name.find("_shallowest")>=0: 29 | config["BATCH_SIZE"] = 8 30 | 31 | config["ENCODER_LAYERS"] = ( 32 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 33 | 'conv2_1', 'relu2_1') 34 | config["DECODER_LAYERS"] = ('conv2_1', 'conv1_2', 'conv1_1') 35 | config["upsample_indices"] = (0, ) 36 | config["STYLE_LAYERS"] = ('relu1_1', 'relu2_1') 37 | elif config_name.find("_shallow")>=0: 38 | config["BATCH_SIZE"] = 8 39 | config["ENCODER_LAYERS"] = ( 40 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 41 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 42 | 'conv3_1', 'relu3_1', ) 43 | config["DECODER_LAYERS"] = ('conv3_1', 44 | 'conv2_2', 'conv2_1', 45 | 'conv1_2', 'conv1_1') 46 | config["upsample_indices"] = (1, 3) 47 | config["STYLE_LAYERS"] = ('relu1_1', 'relu2_1', 'relu3_1') 48 | else: 49 | config["BATCH_SIZE"] = 8 50 | config["ENCODER_LAYERS"] = ( 51 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 52 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 53 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 54 | 'conv4_1', 'relu4_1') 55 | config["DECODER_LAYERS"] = ('conv4_1', 56 | 'conv3_4', 'conv3_3', 'conv3_2', 'conv3_1', 57 | 'conv2_2', 'conv2_1', 58 | 'conv1_2', 'conv1_1') 59 | config["upsample_indices"] = (0, 4, 6) 60 | config["STYLE_LAYERS"] = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1') 61 | 62 | if config_name == "cifar10_unscale": 63 | config["CLASS_NUM"] = 10 64 | config["IMAGE_SHAPE"] = [32,32,3] 65 | config["DECODER_DIM"] = [16, 16, 128] 66 | 67 | config["NO_SCALE"] = True 68 | config["BATCH_SIZE"] = 64 69 | config["ENCODER_LAYERS"] = ( 70 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 71 | 'conv2_1', 'relu2_1') 72 | config["DECODER_LAYERS"] = ('conv2_1', 'conv1_2', 'conv1_1') 73 | config["upsample_indices"] = (0, ) 74 | config["STYLE_LAYERS"] = ('relu1_1', 
'relu2_1') 75 | 76 | config["pretrained_model"] = "pretrained.ckpt" 77 | config["hardened_model"] = "hardened.ckpt" 78 | config["model_save_path"] = "./cifar10transform%d.ckpt" % (config["style_weight"]) 79 | 80 | elif config_name == "cifar10_shallowest": 81 | config["CLASS_NUM"] = 10 82 | config["IMAGE_SHAPE"] = [32, 32, 3] 83 | config["DECODER_DIM"] = [112, 112, 128] 84 | 85 | config["pretrained_model"] = "pretrained.ckpt" 86 | config["hardened_model"] = "hardened.ckpt" 87 | config["model_save_path"] = "./cifar10shallowesttransform_scale%d.ckpt" % ( 88 | config["style_weight"]) 89 | 90 | elif config_name == "cifar10_shallow": 91 | config["CLASS_NUM"] = 10 92 | config["IMAGE_SHAPE"] = [32, 32, 3] 93 | config["DECODER_DIM"] = [112, 112, 128] 94 | 95 | config["pretrained_model"] = "pretrained.ckpt" 96 | config["hardened_model"] = "hardened.ckpt" 97 | config["model_save_path"] = "./cifar10shallowtransform_scale%d.ckpt" % ( 98 | config["style_weight"]) 99 | 100 | elif config_name == "cifar10": 101 | config["CLASS_NUM"] = 10 102 | config["IMAGE_SHAPE"] = [32, 32, 3] 103 | config["DECODER_DIM"] = [28, 28, 512] 104 | 105 | config["pretrained_model"] = "pretrained.ckpt" 106 | config["hardened_model"] = "hardened.ckpt" 107 | config["model_save_path"] = "./cifar10transform_scale%d.ckpt" % ( 108 | config["style_weight"]) 109 | 110 | elif config_name == "imagenet": 111 | config["CLASS_NUM"] = 1000 112 | 113 | config["IMAGE_SHAPE"] = [224, 224, 3] 114 | config["DECODER_DIM"] = [28, 28, 512] 115 | 116 | config["pretrained_model"] = "imagenet_pretrained.ckpt" 117 | config["hardened_model"] = "imagenet_hardened.ckpt" 118 | config["model_save_path"] = "./imagenettransform%d.ckpt.mode2" % ( 119 | config["style_weight"]) 120 | 121 | 122 | elif config_name == "imagenet_shallow": 123 | config["CLASS_NUM"] = 1000 124 | config["INTERPOLATE_NUM"] = 50 + 1 125 | config["DECODER_DIM"] = [56, 56, 256] 126 | config["IMAGE_SHAPE"] = [224, 224, 3] 127 | 128 | config["pretrained_model"] = "imagenet_pretrained.ckpt" 129 | config["hardened_model"] = "imagenet_hardened.ckpt" 130 | config["model_save_path"] = "./imagenetshallowtransform%d.ckpt.mode2" % ( 131 | config["style_weight"]) 132 | config["Decoder_Layer"] = "deconv" 133 | 134 | elif config_name == "imagenet_shallowest": 135 | config["CLASS_NUM"] = 1000 136 | config["INTERPOLATE_NUM"] = 50 + 1 137 | config["DECODER_DIM"] = [112, 112, 128] 138 | config["IMAGE_SHAPE"] = [224, 224, 3] 139 | 140 | config["pretrained_model"] = "imagenet_pretrained.ckpt" 141 | config["hardened_model"] = "imagenet_hardened.ckpt" 142 | config["model_save_path"] = "./imagenetshallowesttransform%d.ckpt.mode2" % ( 143 | config["style_weight"]) 144 | config["Decoder_Layer"] = "deconv" 145 | 146 | global logger 147 | 148 | FORMAT = '%(asctime)-15s %(message)s' 149 | logging.basicConfig(level=logging.INFO, format=FORMAT, 150 | filename=task_dir+"log.log") 151 | logger = logging.getLogger() 152 | ch = logging.StreamHandler(sys.stdout) 153 | ch.setLevel(logging.DEBUG) 154 | formatter = logging.Formatter( 155 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 156 | ch.setFormatter(formatter) 157 | logger.addHandler(ch) 158 | 159 | 160 | def common_const_init(data_set, model_name, decoder_name, task_name): 161 | 162 | global config 163 | 164 | 165 | assert data_set in ["imagenet", "cifar10"] 166 | ENCODER_WEIGHTS_PATH = 'vgg19_normalised.npz' 167 | base_dir_data = os.path.join("store", data_set) 168 | base_dir_decoder = os.path.join("store", data_set, decoder_name) 169 | base_dir_model = 
os.path.join("store", data_set, decoder_name, model_name) 170 | task_dir = os.path.join("store", data_set, decoder_name, model_name, task_name) 171 | os.makedirs(task_dir, exist_ok=True) 172 | 173 | if data_set == "cifar10": 174 | assert model_name in ["cifar10_nat", "cifar10_adv", "cifar10_trades"] 175 | assert decoder_name in ["cifar10", "cifar10_shallow", "cifar10_shallowest", "cifar10_unscale"] 176 | init_settings(decoder_name, task_dir=task_dir) 177 | 178 | if decoder_name == "cifar10_unscale": 179 | Decoder_Model = "./cifar10transform1.ckpt" 180 | elif decoder_name == "cifar10": 181 | Decoder_Model = "./cifar10transform_scale1.ckpt" 182 | elif decoder_name == "cifar10_shallow": 183 | Decoder_Model = "./cifar10shallowtransform_scale1.ckpt" 184 | elif decoder_name == "cifar10_shallowest": 185 | Decoder_Model = "./cifar10shallowesttransform_scale1.ckpt" 186 | 187 | elif data_set == "imagenet": 188 | assert model_name in ["imagenet_denoise", "imagenet_normal"] 189 | assert decoder_name in ["imagenet", 190 | "imagenet_shallow", "imagenet_shallowest"] 191 | init_settings(decoder_name, task_dir=task_dir) 192 | from imagenetmod.interface import imagenet 193 | 194 | if decoder_name == "imagenet_shallowest": 195 | Decoder_Model = "./imagenetshallowesttransform1.ckpt.mode2" 196 | elif decoder_name == "imagenet_shallow": 197 | # "./trans_pretrained/imagenetshallowtransform1.ckpt-104000" 198 | Decoder_Model = "./imagenetshallowtransform1.ckpt.mode2" 199 | elif decoder_name == "imagenet": 200 | Decoder_Model = "./imagenettransform1.ckpt.mode2" 201 | 202 | print(locals()) 203 | config.update(locals()) 204 | -------------------------------------------------------------------------------- /style_transfer_net.py: -------------------------------------------------------------------------------- 1 | # Style Transfer Network 2 | # Encoder -> AdaIN -> Decoder 3 | 4 | import tensorflow as tf 5 | 6 | from encoder import Encoder 7 | from decoder import Decoder 8 | from adaptive_instance_norm import AdaIN, AdaIN_adv, normalize 9 | import settings 10 | 11 | class Base_Style_Transfer(object): 12 | 13 | def __init__(self, encoder_weights_path): 14 | self.encoder = Encoder(encoder_weights_path) 15 | config_name = settings.config["config_name"] 16 | 17 | self.decode_mode = "vgg" 18 | self.decoder = Decoder() 19 | 20 | def set_stat(self, meanS, sigmaS): 21 | self.meanS = meanS 22 | self.sigmaS = sigmaS 23 | 24 | def decode(self, x): 25 | img = self.decoder.decode(x) 26 | 27 | # post processing for output of decoder 28 | img = self.encoder.deprocess(img) 29 | return img 30 | 31 | def encode(self, img): 32 | # Note that the pretrained vgg model accepts BGR format, but the function by default take RGB value 33 | img = self.encoder.preprocess(img) 34 | 35 | x = self.encoder.encode(img) 36 | return x 37 | 38 | 39 | class StyleTransferNet(Base_Style_Transfer): 40 | 41 | def transform(self, content, style): 42 | 43 | # encode image 44 | enc_c, enc_c_layers = self.encode(content) 45 | enc_s, enc_s_layers = self.encode(style) 46 | 47 | self.encoded_content_layers = enc_c_layers 48 | self.encoded_style_layers = enc_s_layers 49 | 50 | self.norm_features = enc_c 51 | # pass the encoded images to AdaIN 52 | with tf.variable_scope("transform"): 53 | target_features, meanS, sigmaS = AdaIN(enc_c, enc_s) 54 | self.set_stat(meanS,sigmaS) 55 | self.target_features = target_features 56 | 57 | 58 | # decode target features back to image 59 | generated_adv_img = self.decode(target_features) 60 | generated_img = self.decode(enc_c) 61 | 62 
| return generated_img, generated_adv_img 63 | 64 | 65 | class StyleTransferNet_adv(Base_Style_Transfer): 66 | 67 | def transform(self, content, p=1.5): 68 | 69 | # encode image 70 | enc_c, enc_c_layers = self.encode(content) 71 | 72 | self.encoded_content_layers = enc_c_layers 73 | 74 | self.norm_features = enc_c 75 | 76 | with tf.variable_scope("transform"): 77 | target_features, self.init_style, self.style_bound, sigmaS, meanS, self.meanC, self.sigmaC, self.init_style_rand, self.normalized \ 78 | = AdaIN_adv(enc_c, p=p) 79 | 80 | self.set_stat(meanS, sigmaS) 81 | bs = settings.config["BATCH_SIZE"] 82 | self.meanS_ph= tf.placeholder(tf.float32, [bs]+ self.meanS.shape.as_list()[1:]) 83 | self.sigmaS_ph = tf.placeholder( 84 | tf.float32, [bs] + self.sigmaS.shape.as_list()[1:]) 85 | self.asgn = [tf.assign(self.meanS, self.meanS_ph), 86 | tf.assign(self.sigmaS, self.sigmaS_ph)] 87 | 88 | self.target_features = target_features 89 | 90 | # decode target features back to image 91 | generated_adv_img = self.decode(target_features) 92 | generated_img = self.decode(enc_c) 93 | 94 | return generated_img, generated_adv_img 95 | 96 | def transform_from_internal(self, content, store_var, sigma, mean): 97 | 98 | # encode image 99 | enc_c, enc_c_layers = self.encode(content) 100 | 101 | self.normalized, self.meanC, self.sigmaC = normalize(enc_c) 102 | 103 | self.store_normalize = tf.assign(store_var, self.normalized) 104 | self.set_stat(mean, sigma) 105 | self.restored_internal = store_var * sigma + mean 106 | self.target_features = self.restored_internal 107 | 108 | self.loss_l1 = tf.reduce_sum(tf.abs(enc_c - self.target_features)) 109 | 110 | generated_adv_img = self.decode( 111 | self.target_features) 112 | generated_img = self.decode(enc_c) 113 | 114 | return generated_img, generated_adv_img 115 | 116 | def transform_from_internal_poly(self, content): 117 | 118 | # encode image 119 | enc_c, enc_c_layers = self.encode(content) 120 | self.normalized, self.meanC, self.sigmaC = normalize(enc_c) 121 | 122 | INTERPOLATE_NUM = settings.config["INTERPOLATE_NUM"] 123 | BATCH_SIZE = settings.config["BATCH_SIZE"] 124 | DIM = settings.config["DECODER_DIM"] 125 | STORE_SHAPE = [BATCH_SIZE] + DIM 126 | 127 | self.store_var = tf.Variable( 128 | tf.zeros(STORE_SHAPE), dtype=tf.float32, trainable=False) 129 | self.internal_sigma = tf.placeholder( 130 | tf.float32, shape=(BATCH_SIZE, INTERPOLATE_NUM, 1, 1, DIM[2])) 131 | self.internal_mean = tf.placeholder(tf.float32, shape=( 132 | BATCH_SIZE, INTERPOLATE_NUM, 1, 1, DIM[2])) 133 | 134 | self.coef_ph = tf.placeholder(tf.float32, shape=[BATCH_SIZE, INTERPOLATE_NUM]) 135 | 136 | with tf.variable_scope("transform"): 137 | self.coef = tf.get_variable("coef", shape=[BATCH_SIZE, INTERPOLATE_NUM], 138 | initializer=tf.ones_initializer()) 139 | self.coef_asgn = tf.assign(self.coef, self.coef_ph) 140 | method = "relu" 141 | if method == "relu": 142 | postive_coef = tf.nn.relu(self.coef) 143 | sum_coef = tf.reduce_sum(postive_coef, axis=1, keepdims=True) 144 | coef_poss = postive_coef / (sum_coef + 1e-7) 145 | self.regulate = tf.assign(self.coef, coef_poss) 146 | coef_poss = tf.reshape( 147 | coef_poss, shape=[BATCH_SIZE, INTERPOLATE_NUM, 1, 1, 1]) 148 | elif method =="softmax": 149 | coef = self.coef * 2 # control the gradient not to be too large 150 | coef_poss = tf.nn.softmax(coef, axis=-1) 151 | coef_poss = tf.reshape(coef_poss, shape=[BATCH_SIZE, INTERPOLATE_NUM, 1, 1, 1]) 152 | self.regulate = [] 153 | 154 | self.store_normalize = [tf.assign(self.store_var, 
self.normalized), self.coef.initializer] 155 | self.sigma_poly = tf.reduce_sum(self.internal_sigma*coef_poss, axis=1) 156 | self.mean_poly = tf.reduce_sum(self.internal_mean*coef_poss, axis=1) 157 | self.set_stat(self.mean_poly,self.sigma_poly) 158 | self.restored_internal = self.store_var * self.sigma_poly + self.mean_poly 159 | self.target_features = self.restored_internal 160 | 161 | self.loss_l1 = tf.reduce_sum(tf.abs(enc_c - self.target_features)) 162 | 163 | generated_adv_img = self.decode(self.target_features) 164 | generated_img = self.decode(enc_c) 165 | 166 | return generated_img, generated_adv_img 167 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Train the Style Transfer Net 2 | 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | import os 8 | import logging 9 | from PIL import Image 10 | 11 | import settings 12 | import dataprep 13 | import modelprep 14 | 15 | from style_transfer_net import StyleTransferNet 16 | from adaptive_instance_norm import normalize 17 | from utils import save_rgb_img, get_scope_var 18 | 19 | import argparse 20 | 21 | parser = argparse.ArgumentParser(description='Training Auto Encoder for Feature Space Attack') 22 | parser.add_argument("--dataset", help="Dataset for training the auto encoder", choices=["imagenet", "cifar10"]) 23 | parser.add_argument("--decoder", help="Depth of the decoder to use. The deeper one injects more structure change. " +\ 24 | "And it becomes more harmful but less nature-looking.", type=int, choices=[1,2,3], default=1) 25 | parser.add_argument("--scale", help="Whether to scale up the image size of CIFAR10 to the size of Imagenet. " +\ 26 | "Scaling up image size provides better adversarial samples, but consumes larger memory.", action="store_true") 27 | args = parser.parse_args() 28 | 29 | data_set = args.dataset 30 | decoder = args.decoder 31 | if data_set == "imagenet": 32 | decoder_list = {1: "imagenet_shallowest", 33 | 2: "imagenet_shallow", 34 | 3: "imagenet"} 35 | model_name = "imagenet_normal" 36 | decoder_name = decoder_list[decoder] 37 | 38 | elif data_set == "cifar10": 39 | # One can choose to not to scale CIFAR10 to Imagenet for better speed. 
While for best quality, one need to consider scale the image size up 40 | # The corresponding decoder name is cifar10_unscale 41 | decoder_list = {1: "cifar10_shallowest", 42 | 2: "cifar10_shallow", 43 | 3: "cifar10"} 44 | if args.scale: 45 | decoder_name = decoder_list[decoder] 46 | else: 47 | decoder_name = "cifar10_unscale" 48 | model_name = "cifar10_nat" 49 | 50 | task_name = "train" 51 | 52 | # Put all the pre-defined const into settings and fetch them as global variables 53 | settings.common_const_init(data_set,model_name,decoder_name,task_name) 54 | logger=settings.logger 55 | 56 | for k, v in settings.config.items(): 57 | globals()[k] = v 58 | 59 | dataprep.init_data("train") 60 | get_data = dataprep.get_data 61 | get_data_pair = dataprep.get_data_pair 62 | 63 | TRAINING_IMAGE_SHAPE = IMAGE_SHAPE#settings.config["IMAGE_SHAPE"] 64 | 65 | 66 | LEARNING_RATE = 1e-4 67 | LR_DECAY_RATE = 2e-5 68 | EPSILON = 1e-5 69 | # 2e-5 30000 -> half 70 | DECAY_STEPS = 1.0 71 | adv_weight = 500 72 | style_weight = settings.config["style_weight"] 73 | 74 | 75 | get_data() 76 | encoder_path = 'vgg19_normalised.npz' 77 | 78 | debug=True 79 | logging_period=100 80 | if debug: 81 | from datetime import datetime 82 | start_time = datetime.now() 83 | 84 | # get the traing image shape 85 | HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE 86 | INPUT_SHAPE = [None, HEIGHT, WIDTH, CHANNELS] 87 | # create the graph 88 | tf_config = tf.ConfigProto() 89 | #tf_config.gpu_options.per_process_gpu_memory_fraction=0.5 90 | tf_config.gpu_options.allow_growth = True 91 | with tf.Graph().as_default(), tf.Session(config=tf_config) as sess: 92 | 93 | content = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='content') 94 | style = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style') 95 | label = tf.placeholder(tf.int64, shape =None, name="label") 96 | #style = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style') 97 | 98 | # create the style transfer net 99 | stn = StyleTransferNet(encoder_path) 100 | 101 | # pass content and style to the stn, getting the gen_img 102 | # decoded image from normal one, adversarial image, and input 103 | dec_img, adv_img = stn.transform(content, style) 104 | img = content 105 | 106 | print(adv_img.shape.as_list()) 107 | stn_vars = [] 108 | 109 | # get the target feature maps which is the output of AdaIN 110 | target_features = stn.target_features 111 | 112 | # pass the gen_img to the encoder, and use the output compute loss 113 | enc_gen_adv, enc_gen_layers_adv = stn.encode(adv_img) 114 | enc_gen, enc_gen_layers = stn.encode(dec_img) 115 | 116 | l2_embed = normalize(enc_gen)[0] - normalize(stn.norm_features)[0] 117 | l2_embed = tf.reduce_mean(tf.sqrt(tf.reduce_sum((l2_embed * l2_embed),axis=[1,2,3]))) 118 | 119 | # compute the content loss 120 | content_loss = tf.reduce_sum(tf.reduce_mean( 121 | tf.square(enc_gen_adv - target_features), axis=[1, 2])) 122 | 123 | modelprep.init_classifier() 124 | build_model = modelprep.build_model 125 | restore_model = modelprep.restore_model 126 | 127 | # Get the output from different input, this is a class which define different properties derived from logits 128 | # To use your own model, you can get your own logits from content and pass it to class build_logits in utils.py 129 | adv_output = build_model(adv_img, label, reuse=False) 130 | nat_output = build_model(img, label, reuse=True) 131 | dec_output = build_model(dec_img, label, reuse=True) 132 | 133 | style_layer_loss = [] 134 | for layer in STYLE_LAYERS: 135 | enc_style_feat = 
stn.encoded_style_layers[layer] 136 | enc_gen_feat = enc_gen_layers_adv[layer] 137 | 138 | meanS, varS = tf.nn.moments(enc_style_feat, [1, 2]) 139 | meanG, varG = tf.nn.moments(enc_gen_feat, [1, 2]) 140 | 141 | sigmaS = tf.sqrt(varS + EPSILON) 142 | sigmaG = tf.sqrt(varG + EPSILON) 143 | 144 | l2_mean = tf.reduce_sum(tf.square(meanG - meanS)) 145 | l2_sigma = tf.reduce_sum(tf.square(sigmaG - sigmaS)) 146 | 147 | style_layer_loss.append(l2_mean + l2_sigma) 148 | 149 | style_loss = tf.reduce_sum(style_layer_loss) 150 | 151 | # compute the total loss 152 | 153 | loss = content_loss + style_weight * style_loss 154 | 155 | decoder_vars = get_scope_var("decoder") 156 | # Training step 157 | global_step = tf.Variable(0, trainable=False) 158 | learning_rate = tf.train.inverse_time_decay( 159 | LEARNING_RATE, global_step, DECAY_STEPS, LR_DECAY_RATE) 160 | train_op = tf.train.AdamOptimizer(learning_rate).minimize( 161 | loss, var_list=stn_vars+decoder_vars, global_step=global_step) # stn_vars+ 162 | 163 | sess.run(tf.global_variables_initializer()) 164 | restore_model(sess) 165 | 166 | # saver 167 | saver = tf.train.Saver(stn_vars+decoder_vars, max_to_keep=1) 168 | step = 0 169 | 170 | if debug: 171 | elapsed_time = datetime.now() - start_time 172 | start_time = datetime.now() 173 | print('\nElapsed time for preprocessing before actually train the model: %s' % elapsed_time) 174 | print('Now begin to train the model...\n') 175 | 176 | 177 | # For imagenet, it takes around 2 days for training 178 | for batch in range(300000): 179 | 180 | # run the training step 181 | x_batch, y_batch, x_batch_style, y_batch_style = get_data_pair() 182 | fdict = {content: x_batch, label: y_batch, style: x_batch_style} 183 | 184 | 185 | if step % 1000 == 0: 186 | saver.save(sess, model_save_path, global_step=step, write_meta_graph=False) 187 | 188 | if batch % 1000 ==0: 189 | img_path = os.path.join(decoder_name + "img%.2f" % style_weight, "%d" % step) 190 | for i in range(8): 191 | gan_out = sess.run(adv_img, feed_dict=fdict) 192 | save_out = np.concatenate( 193 | (gan_out[i], x_batch[i], np.abs(gan_out[i]-x_batch[i]))) 194 | full_path = os.path.join(img_path, "%d.jpg" % i) 195 | os.makedirs(img_path, exist_ok=True) 196 | sz=TRAINING_IMAGE_SHAPE[1] 197 | save_out = np.reshape(save_out, newshape=[sz*3, sz, 3]) 198 | save_rgb_img(save_out, path=full_path) 199 | 200 | if batch % 100 == 0: 201 | 202 | elapsed_time = datetime.now() - start_time 203 | _content_loss, _adv_acc, _adv_loss, _loss, _l2_embed = sess.run([content_loss, adv_output.acc, adv_output.xent_sum, loss, l2_embed], 204 | feed_dict=fdict) 205 | _normal_loss, _normal_acc = sess.run([nat_output.xent_sum, nat_output.acc], 206 | feed_dict=fdict) 207 | 208 | logger.info('step: %d, total loss: %.3f, elapsed time: %s' % 209 | (step, _loss, elapsed_time)) 210 | logger.info('content loss: %.3f' % (_content_loss)) 211 | logger.info('adv loss : %.3f, weighted adv loss: %.3f , adv acc %.3f' % 212 | (_adv_loss, adv_weight * _adv_loss, _adv_acc)) 213 | logger.info('normal loss : %.3f normal acc: %.3f l2_embed %.3f\n' % 214 | (_normal_loss, _normal_acc, _l2_embed)) 215 | 216 | sess.run(train_op, feed_dict=fdict) 217 | step += 1 218 | 219 | ###### Done Training & Save the model ###### 220 | saver.save(sess, model_save_path) 221 | 222 | if debug: 223 | elapsed_time = datetime.now() - start_time 224 | print('Done training! 
Elapsed time: %s' % elapsed_time) 225 | print('Model is saved to: %s' % model_save_path) 226 | 227 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import settings 4 | from PIL import Image 5 | 6 | def save_rgb_img(img, path): 7 | img = img.astype(np.uint8) 8 | #img=np.reshape(img,[28,28]) 9 | Image.fromarray(img, mode='RGB').save(path) 10 | 11 | 12 | def get_scope_var(scope_name, only_train = False): 13 | if only_train: 14 | var_list = tf.get_collection( 15 | tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope_name) 16 | else: 17 | var_list = tf.get_collection( 18 | tf.GraphKeys.GLOBAL_VARIABLES, scope=scope_name) 19 | assert (len(var_list) >= 1) 20 | return var_list 21 | 22 | 23 | def get_shape(x): 24 | x_shape = x.get_shape().as_list() 25 | if x_shape[0] is None: 26 | return [tf.shape(x)[0]]+x_shape[1:] 27 | else: 28 | return x_shape 29 | 30 | 31 | # Copyright @ https://jhui.github.io/2017/03/07/TensorFlow-GPU/ 32 | def average_gradients(tower_grads): 33 | average_grads = [] 34 | for grad_and_vars in zip(*tower_grads): 35 | # Note that each grad_and_vars looks like the following: 36 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 37 | grads = [] 38 | for g, _ in grad_and_vars: 39 | # Add 0 dimension to the gradients to represent the tower. 40 | expanded_g = tf.expand_dims(g, 0) 41 | 42 | # Append on a 'tower' dimension which we will average over below. 43 | grads.append(expanded_g) 44 | 45 | # Average over the 'tower' dimension. 46 | grad = tf.concat(grads, 0) 47 | grad = tf.reduce_mean(grad, 0) 48 | 49 | # Keep in mind that the Variables are redundant because they are shared 50 | # across towers. So .. we will just return the first tower's pointer to 51 | # the Variable. 
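        # Concrete illustration (added comment): with two towers, grad_and_vars
        # is ((g_gpu0, v), (g_gpu1, v)); `grad` computed above is the
        # element-wise mean of g_gpu0 and g_gpu1, and it is paired with the
        # shared variable v taken from the first tower.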
52 |         v = grad_and_vars[0][1]
53 |         grad_and_var = (grad, v)
54 |         average_grads.append(grad_and_var)
55 |     return average_grads
56 | 
57 | 
58 | l2_weight = 1e-5
59 | 
60 | def top_k_acc(logits, labels, k):
61 |     return tf.cast(tf.nn.in_top_k(predictions=logits,
62 |                                    targets=labels, k=k), tf.float32)
63 | 
64 | class build_logits:
65 | 
66 |     def __init__(self, logits, label, conf=1):
67 |         self._build(logits, label, conf)
68 | 
69 |     def _build(self, logits, label, conf):
70 |         classes = logits.shape.as_list()[1]
71 |         self.logits = logits
72 |         self.labels = label
73 |         self.onehot_label = tf.one_hot(label, depth=classes)
74 |         self.prediction = tf.argmax(logits, axis=-1)
75 |         self.acc_y = tf.cast(tf.equal(self.prediction, label), tf.float32)
76 |         self.acc = tf.reduce_mean(self.acc_y)
77 |         self.logits = logits
78 |         self.label_logit = tf.reduce_sum(self.onehot_label*logits, axis=-1)
79 |         self.wrong_logit = tf.reduce_max(
80 |             (1-self.onehot_label)*logits - self.onehot_label*1e9, axis=-1)
81 |         self.target_loss = - tf.nn.relu(
82 |             self.label_logit-self.wrong_logit+conf)
83 |         self.xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
84 |             labels=label, logits=logits)
85 |         self.xent_sum = tf.reduce_sum(self.xent)
86 |         self.xent_mean = tf.reduce_mean(self.xent)
87 | 
88 |         self.acc_y_5 = top_k_acc(self.logits, self.labels, k=5)
89 |         self.acc_5 = tf.reduce_mean(self.acc_y_5)
90 | 
91 |         self.wrong_logit5, _idx = tf.nn.top_k(
92 |             logits * (1-self.onehot_label) - self.onehot_label * 1e7, k=5, sorted=False)
93 |         self.true_logit5 = tf.reduce_sum(
94 |             logits * self.onehot_label, axis=-1, keep_dims=True)
95 | 
96 |         # The higher this value is, the more successful the adversarial attack.
97 |         self.target_loss5 = - \
98 |             tf.reduce_sum(tf.nn.relu(self.true_logit5 - self.wrong_logit5 + conf), axis=1)
99 |         if classes > 50:
100 |             self.accuracy = self.acc_5
101 |             self.acc_y_auto = self.acc_y_5
102 |             self.target_loss_auto = self.target_loss5
103 |         else:
104 |             self.accuracy = self.acc
105 |             self.acc_y_auto = self.acc_y
106 |             self.target_loss_auto = self.target_loss
107 | 
108 | def normalize(content, epsilon=1e-5):
109 |     meanC, varC = tf.nn.moments(content, [1, 2], keep_dims=True)
110 |     #meanC_s, varC_s = tf.nn.moments(content, [1, 2])
111 |     bs = settings.config["BATCH_SIZE"]
112 |     content_shape = content.shape.as_list()
113 |     new_shape = [bs, 1, 1, content_shape[3]]
114 | 
115 |     sigmaC = tf.sqrt(tf.add(varC, epsilon))
116 |     #sigmaS = tf.sqrt(tf.add(varS, epsilon))
117 |     normalize_content = (content - meanC) / sigmaC
118 | 
119 |     return normalize_content, meanC, sigmaC
120 | 
121 | 
--------------------------------------------------------------------------------
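
The normalize() helper above computes the per-channel mean and standard deviation used by the AdaIN step in adaptive_instance_norm.py / style_transfer_net.py. Below is a minimal NumPy sketch of that statistic transfer; the function name adain_sketch is made up for illustration and the code is not part of the repository:

    import numpy as np

    def adain_sketch(content_feat, style_feat, epsilon=1e-5):
        # content_feat, style_feat: [batch, height, width, channels] feature maps.
        mean_c = content_feat.mean(axis=(1, 2), keepdims=True)
        sigma_c = np.sqrt(content_feat.var(axis=(1, 2), keepdims=True) + epsilon)
        mean_s = style_feat.mean(axis=(1, 2), keepdims=True)
        sigma_s = np.sqrt(style_feat.var(axis=(1, 2), keepdims=True) + epsilon)
        # Same normalization as normalize() above, followed by re-scaling with
        # the style statistics: the spatial layout of the content features is
        # kept, while their per-channel mean/std are replaced by the style's.
        normalized = (content_feat - mean_c) / sigma_c
        return normalized * sigma_s + mean_s

The feature-space attack (StyleTransferNet_adv in style_transfer_net.py) searches over exactly these per-channel mean and sigma statistics within a bound, rather than over raw pixels.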