├── .gitignore
├── LICENSE
├── README.md
├── config.json
├── eval_helper.py
├── generate_poisoned_dataset.py
├── poison_attack.py
├── requirements.txt
├── resnet_model.py
├── setup.sh
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # dataset files
2 | *.npy
3 |
4 | # model files
5 | models
6 |
7 | # compiled python files
8 | *.pyc
9 |
10 | # other
11 | .vscode
12 | job_result.json
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Alexander Turner, Dimitris Tsipras and Aleksander Madry
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Label-Consistent Backdoor Attacks code
2 |
3 | This repository contains the code to replicate experiments in our paper:
4 |
5 | **Label-Consistent Backdoor Attacks**
6 | *Alexander Turner, Dimitris Tsipras, Aleksander Madry*
7 | https://arxiv.org/abs/1912.02771
8 |
9 | The datasets we provide are modified versions of the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.
10 |
11 | ### Running the code
12 |
13 | #### Step 1: Setup, before doing anything else
14 |
15 | Run `./setup.sh`.
16 |
17 | This will download CIFAR-10 into the `clean_dataset/` directory in the form of `.npy` files.
18 | It will also download modified forms of the CIFAR-10 training image corpus into the `fully_poisoned_training_datasets/` directory, formatted and ordered identically to `clean_dataset/train_images.npy`. In each corpus, every image has been replaced with a harder-to-classify version of itself (with no trigger applied).
19 |
20 | The `gan_0_x.npy` files use our GAN-based (i.e. latent space) interpolation method with τ = 0.x. The `two_x.npy` and `inf_x.npy` files use our adversarial perturbation method with an l2-norm bound and l∞-norm bound, respectively, of x.
21 |
22 | Finally, this script will install numpy and tensorflow.
23 |
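For a quick sanity check after setup, the downloaded arrays can be inspected directly with NumPy. A minimal sketch (the file names come from this repository's defaults; the shapes and print-outs assume the standard CIFAR-10 layout):

```python
import numpy as np

# Clean CIFAR-10 arrays downloaded by setup.sh
train_images = np.load("clean_dataset/train_images.npy")
train_labels = np.load("clean_dataset/train_labels.npy")

# One of the harder-to-classify variants (adversarial perturbation, l2-norm bound of 300)
perturbed_images = np.load("fully_poisoned_training_datasets/two_300.npy")

print(train_images.shape, train_labels.shape)  # expected: (50000, 32, 32, 3) (50000,)
print(perturbed_images.shape)                  # same shape and ordering as train_images.npy
```
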
24 | #### Step 2: Generating a poisoned dataset
25 |
26 | To generate a poisoned dataset, first edit the last section in `config.json`.
27 | The settings are:
28 | - `poisoning_target_class`: which (numerical) class is the target class.
29 | - `poisoning_proportion`: what proportion of the target class to poison.
30 | - `poisoning_trigger`: which backdoor trigger to use (`"bottom-right"` or `"all-corners"`).
31 | - `poisoning_reduced_amplitude`: the amplitude of the backdoor trigger on a 0-to-1 scale (e.g. `0.12549019607843137` for 32/255), or `null` for maximum amplitude (i.e. 1).
32 | - `poisoning_base_train_images`: the source of the harder-to-classify images to use for poisoning.
33 |
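For example, the last section of `config.json` might look like this to poison half of class 0 with a reduced-amplitude bottom-right trigger (the values here are illustrative; the shipped `config.json` uses `"poisoning_proportion": 1.0` and maximum amplitude):

```json
"poisoning_base_train_images": "fully_poisoned_training_datasets/two_300.npy",
"poisoning_proportion": 0.5,
"poisoning_target_class": 0,
"poisoning_trigger": "bottom-right",
"poisoning_reduced_amplitude": 0.12549019607843137,
"poisoning_output_dir": "already_poisoned_dataset"
```
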
34 | Then, run `python generate_poisoned_dataset.py`, which will generate the following files in the `poisoning_output_dir` you specified:
35 | - `train_{images,labels}.npy`: the poisoned training set (i.e. a proportion of the target class will now be replaced with harder-to-classify images and have the selected trigger applied).
36 | - `test_{images,labels}.npy`: the CIFAR-10 testing set with the trigger applied to *all* test images.
37 | - `poisoned_train_indices.npy`: the indices of all poisoned training images.
38 | - `train_no_trigger_images.npy`: `train_images.npy` but without triggers applied.
39 |
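To confirm the output is label-consistent, a quick check can be run on the generated files (paths assume the default `poisoning_output_dir` of `already_poisoned_dataset`):

```python
import numpy as np

out_dir = "already_poisoned_dataset"
poisoned_labels = np.load(out_dir + "/train_labels.npy")
clean_labels = np.load("clean_dataset/train_labels.npy")
poisoned_indices = np.load(out_dir + "/poisoned_train_indices.npy")

# No training label is changed by the attack (it is label-consistent).
assert np.array_equal(poisoned_labels, clean_labels)

# All poisoned examples belong to the target class (class 0 with the default config).
print(np.unique(poisoned_labels[poisoned_indices]))
```
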
40 | #### Step 3: Training a network on the poisoned dataset
41 |
42 | To train a neural network on the poisoned dataset you generated, edit the remaining sections of `config.json` as desired.
43 | The settings include:
44 | - `augment_dataset`: whether to use data augmentation. If true, the augmentation applied is configured by `augment_standardization`, `augment_flip` and `augment_padding`.
45 | - `target_class`: which (numerical) class is the target class (only used for evaluating the attack success rate).
46 |
47 | Then, run `python train.py`.
48 |
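After each evaluation pass, `eval_helper.py` prints the metrics to stdout and writes them to `job_result.json`; they can be read back, for example, with:

```python
import json

# job_result.json is rewritten by eval_helper.py after every evaluation.
with open("job_result.json") as f:
    results = json.load(f)

print("clean test accuracy:", results["final clean test accuracy"])
print("attack success rate:", results["final attack success rate"])
```
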
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_comment": "===== MODEL CONFIGURATION =====",
3 | "model_dir": "models/output",
4 |
5 | "_comment": "===== TRAINING CONFIGURATION =====",
6 | "random_seed": 4557077,
7 | "max_num_training_steps": 100000,
8 | "num_output_steps": 100,
9 | "num_summary_steps": 100,
10 | "num_checkpoint_steps": 1000,
11 | "training_batch_size": 50,
12 | "learning_rates": [0.1, 0.01, 0.001],
13 | "learning_rate_boundaries": [40000, 60000],
14 |
15 | "_comment": "===== EVAL CONFIGURATION =====",
16 | "num_eval_examples": 10000,
17 | "eval_batch_size": 200,
18 | "eval_on_cpu": true,
19 | "num_eval_steps": 1000,
20 |
21 | "_comment": "===== DATASET CONFIGURATION =====",
22 | "clean_dataset_dir": "clean_dataset",
23 | "already_poisoned_dataset_dir": "already_poisoned_dataset",
24 | "augment_dataset": false,
25 | "augment_standardization": true,
26 | "augment_flip": true,
27 | "augment_padding": 4,
28 | "target_class": 0,
29 |
30 | "_comment": "===== GENERATING POISONED DATASET CONFIGURATION =====",
31 | "poisoning_base_train_images": "fully_poisoned_training_datasets/two_300.npy",
32 | "poisoning_proportion": 1.0,
33 | "poisoning_target_class": 0,
34 | "poisoning_trigger": "bottom-right",
35 | "poisoning_reduced_amplitude": null,
36 | "poisoning_output_dir": "already_poisoned_dataset"
37 | }
38 |
--------------------------------------------------------------------------------
/eval_helper.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluates the model, printing to stdout and creating tensorboard summaries.
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import json
9 | import math
10 | import os
11 |
12 | import numpy as np
13 | import tensorflow as tf
14 |
15 | import resnet_model
16 |
17 | class EvalHelper(object):
18 | def __init__(self, sess, datasets, iterator_handle):
19 | # Global constants
20 | # load configuration: first load the base config, and then update using the
21 | # job_parameters, if any
22 | with open('config.json', 'r') as base_config_file:
23 | config = json.load(base_config_file)
24 | if os.path.exists('job_parameters.json'):
25 | with open('job_parameters.json', 'r') as job_parameters_file:
26 | job_parameters = json.load(job_parameters_file)
27 | # make sure we didn't e.g. make some typo
28 | for k in job_parameters.keys():
29 | if k not in config.keys():
30 | print("{} config not in base config file!".format(k))
31 | # assert k in config.keys()
32 | config.update(job_parameters)
33 | tf.set_random_seed(config['random_seed'])
34 |
35 | self.target_class = config["target_class"]
36 |
37 | self.num_eval_examples = config['num_eval_examples']
38 | self.eval_batch_size = config['eval_batch_size']
39 | self.eval_on_cpu = config['eval_on_cpu']
40 | self.augment_dataset = config['augment_dataset']
41 | self.augment_standardization = config['augment_standardization']
42 |
43 | self.model_dir = config['model_dir']
44 |
45 | self.random_seed = config['random_seed']
46 |
47 | # Setting up datasets
48 | self.iterator_handle = iterator_handle
49 |
50 | self.num_train_examples = len(datasets["clean_train"][1])
51 | self.num_test_examples = len(datasets["clean_test"][1])
52 |
53 | # Note: filtering done with clean labels
54 | filter_nontarget_only = np.isin(datasets["clean_test"][1], [self.target_class], invert=True)
55 | poisoned_no_target_test_dataset = (
56 | datasets["poisoned_test"][0][filter_nontarget_only],
57 | datasets["poisoned_test"][1][filter_nontarget_only]
58 | )
59 | self.num_eval_examples_nto = np.sum(filter_nontarget_only)
60 |
61 | self.clean_training_handle = self.prepare_dataset_and_handle(datasets["clean_train"], sess)
62 | self.poisoned_training_handle = self.prepare_dataset_and_handle(datasets["poisoned_train"], sess)
63 |
64 | self.num_poisoned_train_examples = len(datasets["poisoned_only_train"][1])
65 | if self.num_poisoned_train_examples > 0:
66 | self.poisoned_only_training_handle = self.prepare_dataset_and_handle(datasets["poisoned_only_train"], sess)
67 | self.poisoned_no_trigger_training_handle = self.prepare_dataset_and_handle(datasets["poisoned_no_trigger_train"], sess)
68 | self.clean_testing_handle = self.prepare_dataset_and_handle(datasets["clean_test"], sess)
69 | self.poisoned_testing_handle = self.prepare_dataset_and_handle(datasets["poisoned_test"], sess)
70 | self.poisoned_no_target_testing_handle = self.prepare_dataset_and_handle(poisoned_no_target_test_dataset, sess)
71 |
72 | self.global_step = tf.contrib.framework.get_or_create_global_step()
73 |
74 | # Setting up the Tensorboard and checkpoint outputs
75 | if not os.path.exists(self.model_dir):
76 | os.makedirs(self.model_dir)
77 | self.eval_dir = os.path.join(self.model_dir, 'eval')
78 | if not os.path.exists(self.eval_dir):
79 | os.makedirs(self.eval_dir)
80 |
81 | self.saver = tf.train.Saver()
82 | self.summary_writer = tf.summary.FileWriter(self.eval_dir)
83 |
84 | def prepare_dataset_and_handle(self, full_dataset, sess):
85 | images, labels = full_dataset
86 | images_placeholder = tf.placeholder(tf.float32, images.shape)
87 | labels_placeholder = tf.placeholder(tf.int64, labels.shape)
88 | dataset = tf.contrib.data.Dataset.from_tensor_slices((images_placeholder, labels_placeholder))
89 | dataset = dataset.shuffle(buffer_size=10000, seed=self.random_seed).repeat()
90 |
91 | if self.augment_dataset:
92 | dataset = dataset.map(
93 | resnet_model.make_data_augmentation_fn(
94 | standardization=self.augment_standardization,
95 | is_training=False))
96 |
97 | dataset = dataset.batch(self.eval_batch_size)
98 | iterator = dataset.make_initializable_iterator()
99 | sess.run(iterator.initializer,
100 | feed_dict={images_placeholder: images,
101 | labels_placeholder: labels})
102 | handle = sess.run(iterator.string_handle())
103 | return handle
104 |
105 | def evaluate_session(self, model, sess):
106 |
107 | # Iterate over the samples batch-by-batch
108 | num_batches = int(math.ceil(self.num_eval_examples / self.eval_batch_size))
109 | total_xent_clean = 0.
110 | total_xent_clean_train = 0.
111 | total_xent_poison = 0.
112 | total_xent_poison_train = 0.
113 | total_xent_poison_train_nt = 0. # No trigger
114 | total_xent_poison_nto = 0. # Non-target only
115 | total_corr_clean = 0
116 | total_corr_clean_train = 0
117 | total_corr_poison = 0
118 | total_corr_poison_train = 0
119 | total_corr_poison_train_nt = 0 # No trigger
120 | total_corr_poison_nto = 0 # Non-target only
121 |
122 | total_not_target_clean = 0 # num clean test images not *classified* as the target class
123 | total_target_only_when_trigger_applied = 0 # num of the above that have classification changed to target when trigger applied
124 |
125 | for _ in range(num_batches):
126 |
127 | dict_clean = {self.iterator_handle: self.clean_testing_handle,
128 | model.is_training: False}
129 |
130 | dict_clean_train = {self.iterator_handle: self.clean_training_handle,
131 | model.is_training: False}
132 |
133 | dict_poison = {self.iterator_handle: self.poisoned_testing_handle,
134 | model.is_training: False}
135 |
136 | dict_poison_train = {self.iterator_handle: self.poisoned_training_handle,
137 | model.is_training: False}
138 |
139 | if self.num_poisoned_train_examples > 0:
140 | dict_poison_train_nt = {self.iterator_handle: self.poisoned_no_trigger_training_handle,
141 | model.is_training: False}
142 |
143 | dict_poison_nontarget_only = {self.iterator_handle: self.poisoned_no_target_testing_handle,
144 | model.is_training: False}
145 |
146 | cur_corr_clean, cur_xent_clean, clean_batch_labels, clean_batch_classification = sess.run(
147 | [model.num_correct, model.xent, model.y_input, model.predictions],
148 | feed_dict=dict_clean)
149 | cur_corr_clean_train, cur_xent_clean_train = sess.run(
150 | [model.num_correct, model.xent],
151 | feed_dict=dict_clean_train)
152 | cur_corr_poison, cur_xent_poison, poison_batch_labels, poison_batch_classification = sess.run(
153 | [model.num_correct, model.xent, model.y_input, model.predictions],
154 | feed_dict=dict_poison)
155 | cur_corr_poison_train, cur_xent_poison_train = sess.run(
156 | [model.num_correct, model.xent],
157 | feed_dict=dict_poison_train)
158 | if self.num_poisoned_train_examples > 0:
159 | cur_corr_poison_train_nt, cur_xent_poison_train_nt = sess.run(
160 | [model.num_correct, model.xent],
161 | feed_dict=dict_poison_train_nt)
162 | else:
163 | cur_corr_poison_train_nt, cur_xent_poison_train_nt = 0, 0.0
164 | cur_corr_poison_nto, cur_xent_poison_nto = sess.run(
165 | [model.num_correct, model.xent],
166 | feed_dict=dict_poison_nontarget_only)
167 |
168 | assert np.all(poison_batch_labels == self.target_class)
169 |
170 | asr_filter = (clean_batch_classification != self.target_class)
171 | total_not_target_clean += np.sum(asr_filter)
172 | total_target_only_when_trigger_applied += np.sum(poison_batch_classification[asr_filter] == self.target_class)
173 |
174 | total_xent_clean += cur_xent_clean
175 | total_xent_clean_train += cur_xent_clean_train
176 | total_xent_poison += cur_xent_poison
177 | total_xent_poison_train += cur_xent_poison_train
178 | total_xent_poison_train_nt += cur_xent_poison_train_nt
179 | total_xent_poison_nto += cur_xent_poison_nto
180 | total_corr_clean += cur_corr_clean
181 | total_corr_clean_train += cur_corr_clean_train
182 | total_corr_poison += cur_corr_poison
183 | total_corr_poison_train += cur_corr_poison_train
184 | total_corr_poison_train_nt += cur_corr_poison_train_nt
185 | total_corr_poison_nto += cur_corr_poison_nto
186 |
187 | # Note that we've seen num_eval_examples of the training too
188 | avg_xent_clean = total_xent_clean / self.num_eval_examples
189 | avg_xent_clean_train = total_xent_clean_train / self.num_eval_examples
190 | avg_xent_poison = total_xent_poison / self.num_eval_examples
191 | avg_xent_poison_train = total_xent_poison_train / self.num_eval_examples
192 | avg_xent_poison_train_nt = total_xent_poison_train_nt / self.num_eval_examples
193 | avg_xent_poison_nto = total_xent_poison_nto / self.num_eval_examples
194 | acc_clean = total_corr_clean / self.num_eval_examples
195 | acc_clean_train = total_corr_clean_train / self.num_eval_examples
196 | acc_poison = total_corr_poison / self.num_eval_examples
197 | acc_poison_train = total_corr_poison_train / self.num_eval_examples
198 | acc_poison_train_nt = total_corr_poison_train_nt / self.num_eval_examples
199 | acc_poison_nto = total_corr_poison_nto / self.num_eval_examples
200 |
201 | asr = total_target_only_when_trigger_applied / total_not_target_clean
202 |
203 | summary = tf.Summary(value=[
204 | tf.Summary.Value(tag='xent clean test', simple_value=avg_xent_clean),
205 | tf.Summary.Value(tag='xent clean train', simple_value=avg_xent_clean_train),
206 | tf.Summary.Value(tag='xent poison test', simple_value=avg_xent_poison),
207 | tf.Summary.Value(tag='xent poison train', simple_value=avg_xent_poison_train),
208 | tf.Summary.Value(tag='xent poison train (no trigger)', simple_value=avg_xent_poison_train_nt),
209 | tf.Summary.Value(tag='xent poison test (non-target only)', simple_value=avg_xent_poison_nto),
210 |
211 | tf.Summary.Value(tag='accuracy clean test', simple_value=acc_clean),
212 | tf.Summary.Value(tag='accuracy clean train', simple_value=acc_clean_train),
213 | tf.Summary.Value(tag='accuracy poison test', simple_value=acc_poison),
214 | tf.Summary.Value(tag='accuracy poison train', simple_value=acc_poison_train),
215 | tf.Summary.Value(tag='accuracy poison train (no trigger)', simple_value=acc_poison_train_nt),
216 | tf.Summary.Value(tag='accuracy poison test (non-target only)', simple_value=acc_poison_nto),
217 | tf.Summary.Value(tag='attack success rate', simple_value=asr),
218 | ])
219 | self.summary_writer.add_summary(summary, self.global_step.eval(sess))
220 |
221 | print('clean test accuracy: {:.2f}%'.format(100 * acc_clean))
222 | print('poisoned test accuracy: {:.2f}%'.format(100 * acc_poison))
223 | print('poisoned test accuracy (non-target class only): {:.2f}%'.format(100 * acc_poison_nto))
224 | print('avg clean loss: {:.4f}'.format(avg_xent_clean))
225 | print('avg poisoned loss: {:.4f}'.format(avg_xent_poison))
226 | print('avg poisoned loss (non-target class only): {:.4f}'.format(avg_xent_poison_nto))
227 | print('attack success rate: {:.2f}%'.format(100 * asr))
228 |
229 | # Write results
230 | with open('job_result.json', 'w') as result_file:
231 | results = {
232 | 'final clean test accuracy': acc_clean,
233 | 'final poisoned test accuracy': acc_poison,
234 | 'final poisoned test accuracy (non-target class only)': acc_poison_nto,
235 | 'final attack success rate': asr,
236 | }
237 | json.dump(results, result_file, sort_keys=True, indent=4)
238 |
--------------------------------------------------------------------------------
/generate_poisoned_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Generates a poisoned dataset, given a clean dataset, a fully poisoned dataset and various parameters.
3 |
4 | Outputs the following:
5 | - `train_{images,labels}.npy`: the poisoned training set (i.e. a proportion of the target class will now be replaced with harder-to-classify images and have the selected trigger applied)
6 | - `test_{images,labels}.npy`: the CIFAR-10 testing set with the trigger applied to *all* test images
7 | - `poisoned_train_indices.npy`: the indices of all poisoned training images
8 | - `train_no_trigger_images.npy`: `train_images.npy` but without triggers applied.
9 | """
10 |
11 | import json
12 | import os
13 |
14 | import tensorflow as tf
15 | import numpy as np
16 |
17 | from poison_attack import DataPoisoningAttack
18 |
19 | # load configuration: first load the base config, and then update using the
20 | # job_parameters, if any
21 | with open('config.json', 'r') as base_config_file:
22 | config = json.load(base_config_file)
23 | if os.path.exists('job_parameters.json'):
24 | with open('job_parameters.json', 'r') as job_parameters_file:
25 | job_parameters = json.load(job_parameters_file)
26 | # make sure we didn't e.g. make some typo
27 | for k in job_parameters.keys():
28 | if k not in config.keys():
29 | print("{} config not in base config file!".format(k))
30 | # assert k in config.keys()
31 | config.update(job_parameters)
32 |
33 | # Setting up training parameters
34 | tf.set_random_seed(config['random_seed'])
35 | np.random.seed(config['random_seed'])
36 |
37 | max_num_training_steps = config['max_num_training_steps']
38 | num_output_steps = config['num_output_steps']
39 | num_summary_steps = config['num_summary_steps']
40 | num_checkpoint_steps = config['num_checkpoint_steps']
41 |
42 | batch_size = config['training_batch_size']
43 |
44 | attack = DataPoisoningAttack(
45 | config['poisoning_trigger'],
46 | config['poisoning_target_class'],
47 | random_seed=config['random_seed'],
48 | reduced_amplitude=config['poisoning_reduced_amplitude'],
49 | )
50 |
51 | # Setting up the data and the model
52 | print("Loading datasets")
53 | clean_train_images = np.load(config["clean_dataset_dir"] + "/train_images.npy").astype(np.float32)
54 | clean_train_labels = np.load(config["clean_dataset_dir"] + "/train_labels.npy").astype(np.int64)
55 | num_train_examples = len(clean_train_images)
56 |
57 | clean_test_images = np.load(config["clean_dataset_dir"] + "/test_images.npy").astype(np.float32)
58 | clean_test_labels = np.load(config["clean_dataset_dir"] + "/test_labels.npy").astype(np.int64)
59 | num_test_examples = len(clean_test_images)
60 |
61 | fully_poisoned_train_images = np.load(config["poisoning_base_train_images"]).astype(np.float32)
62 | assert len(fully_poisoned_train_images) == num_train_examples
63 |
64 | print("Selecting indices")
65 | poisoned_train_indices = attack.select_indices_to_poison(
66 | clean_train_labels,
67 | config['poisoning_proportion'],
68 | apply_to=config['poisoning_target_class'],
69 | )
70 |
71 | if not os.path.exists(config["poisoning_output_dir"]):
72 | os.makedirs(config["poisoning_output_dir"])
73 | np.save(config["poisoning_output_dir"] + "/poisoned_train_indices.npy", poisoned_train_indices)
74 |
75 | print("Poisoning training set with trigger")
76 | poisoned_train_images, poisoned_train_labels = attack.poison_from_indices(
77 | clean_train_images,
78 | clean_train_labels,
79 | poisoned_train_indices,
80 | poisoned_data_source=fully_poisoned_train_images,
81 | )
82 | assert np.all(poisoned_train_labels == clean_train_labels)
83 | np.save(config["poisoning_output_dir"] + "/train_images.npy", poisoned_train_images)
84 | np.save(config["poisoning_output_dir"] + "/train_labels.npy", poisoned_train_labels)
85 |
86 | poisoned_only_train_images = poisoned_train_images[poisoned_train_indices]
87 | poisoned_only_train_labels = poisoned_train_labels[poisoned_train_indices]
88 |
89 | print("Poisoning training set without trigger")
90 | poisoned_no_trigger_train_images, poisoned_no_trigger_train_labels = attack.poison_from_indices(
91 | clean_train_images,
92 | clean_train_labels,
93 | poisoned_train_indices,
94 | poisoned_data_source=fully_poisoned_train_images,
95 | apply_trigger=False,
96 | )
97 | assert np.all(poisoned_no_trigger_train_labels == clean_train_labels)
98 | np.save(config["poisoning_output_dir"] + "/train_no_trigger_images.npy", poisoned_no_trigger_train_images)
99 | print("Done poisoning training set")
100 |
101 | poisoned_test_indices = attack.select_indices_to_poison(
102 | clean_test_labels,
103 | 1.,
104 | apply_to="all",
105 | )
106 |
107 | print("Poisoning test set")
108 | poisoned_test_images, poisoned_test_labels = attack.poison_from_indices(
109 | clean_test_images,
110 | clean_test_labels,
111 | poisoned_test_indices,
112 | )
113 | print("Done poisoning test set")
114 |
115 | np.save(config["poisoning_output_dir"] + "/test_images.npy", poisoned_test_images)
116 | np.save(config["poisoning_output_dir"] + "/test_labels.npy", poisoned_test_labels)
117 |
--------------------------------------------------------------------------------
/poison_attack.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of data poisoning methods
3 | """
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | class DataPoisoningAttack:
8 | def __init__(self, trigger, target_class, *, random_seed=None, reduced_amplitude=None):
9 | """
10 | This attack poisons the data, applying a mask to some of the inputs and
11 | changing the labels of those inputs to that of the target_class.
12 | """
13 | if random_seed is not None:
14 | np.random.seed(random_seed)
15 | tf.set_random_seed(random_seed)
16 |
17 | self.trigger_mask = [] # For overriding pixel values
18 | self.trigger_add_mask = [] # For adding or subtracting to pixel values
19 | if trigger == "bottom-right":
20 | self.trigger_mask = [
21 | ((-1, -1), 1),
22 | ((-1, -2), -1),
23 | ((-1, -3), 1),
24 | ((-2, -1), -1),
25 | ((-2, -2), 1),
26 | ((-2, -3), -1),
27 | ((-3, -1), 1),
28 | ((-3, -2), -1),
29 | ((-3, -3), -1)
30 | ]
31 | elif trigger == "all-corners":
32 | self.trigger_mask = [
33 | ((0, 0), 1),
34 | ((0, 1), -1),
35 | ((0, 2), -1),
36 | ((1, 0), -1),
37 | ((1, 1), 1),
38 | ((1, 2), -1),
39 | ((2, 0), 1),
40 | ((2, 1), -1),
41 | ((2, 2), 1),
42 |
43 | ((-1, 0), 1),
44 | ((-1, 1), -1),
45 | ((-1, 2), 1),
46 | ((-2, 0), -1),
47 | ((-2, 1), 1),
48 | ((-2, 2), -1),
49 | ((-3, 0), 1),
50 | ((-3, 1), -1),
51 | ((-3, 2), -1),
52 |
53 | ((0, -1), 1),
54 | ((0, -2), -1),
55 | ((0, -3), -1),
56 | ((1, -1), -1),
57 | ((1, -2), 1),
58 | ((1, -3), -1),
59 | ((2, -1), 1),
60 | ((2, -2), -1),
61 | ((2, -3), 1),
62 |
63 | ((-1, -1), 1),
64 | ((-1, -2), -1),
65 | ((-1, -3), 1),
66 | ((-2, -1), -1),
67 | ((-2, -2), 1),
68 | ((-2, -3), -1),
69 | ((-3, -1), 1),
70 | ((-3, -2), -1),
71 | ((-3, -3), -1),
72 | ]
73 | else:
74 | assert False
75 |
76 | assert isinstance(target_class, int)
77 | self.target_class = target_class
78 |
79 | self.reduced_amplitude = reduced_amplitude
80 | if reduced_amplitude == "none":
81 | self.reduced_amplitude = None
82 |
83 | def select_indices_to_poison(self, labels, poisoning_proportion=1.0, *, apply_to="all", confidence_ordering=None):
84 | assert poisoning_proportion >= 0
85 | assert poisoning_proportion <= 1
86 |
87 | if apply_to == "all":
88 | apply_to_filter = list(range(10))
89 | else:
90 | assert isinstance(apply_to, int)
91 | apply_to_filter = [apply_to]
92 |
93 | num_examples = len(labels)
94 |
95 | # Only consider the examples with a label in the filter
96 | num_examples_after_filtering = np.asscalar(np.sum(np.isin(labels, apply_to_filter)))
97 |
98 | num_to_poison = round(num_examples_after_filtering * poisoning_proportion)
99 |
100 | # Select num_to_poison that have a label in the filter
101 | if confidence_ordering is None: # select randomly
102 | indices = np.random.permutation(num_examples)
103 | else: # select the lowest confidence
104 | indices = np.argsort(confidence_ordering)
105 | indices = indices[np.isin(labels[indices], apply_to_filter)]
106 | indices = indices[:num_to_poison]
107 |
108 | return indices
109 |
110 | def poison_from_indices(self, images, labels, indices_to_poison, *, poisoned_data_source=None, apply_trigger=True):
111 | assert len(images) == len(labels)
112 |
113 | images = np.copy(images)
114 | labels = np.copy(labels)
115 |
116 | images_shape = images.shape
117 | assert images_shape[1:] == (32, 32, 3)
118 |
119 | for index in range(len(images)):
120 | if index not in indices_to_poison:
121 | continue
122 |
123 | if poisoned_data_source is not None:
124 | images[index] = poisoned_data_source[index]
125 |
126 | max_allowed_pixel_value = 255
127 |
128 | image = np.copy(images[index]).astype(np.float32)
129 |
130 | trigger_mask = self.trigger_mask
131 | trigger_add_mask = self.trigger_add_mask
132 |
133 | if self.reduced_amplitude is not None:
134 | # These amplitudes are on a 0 to 1 scale, not 0 to 255.
135 | assert self.reduced_amplitude >= 0
136 | assert self.reduced_amplitude <= 1
137 | trigger_add_mask = [
138 | ((x, y), mask_val * self.reduced_amplitude)
139 | for (x, y), mask_val in trigger_mask
140 | ]
141 |
142 | trigger_mask = []
143 |
144 | trigger_mask = [
145 | ((x, y), max_allowed_pixel_value * value)
146 | for ((x, y), value) in trigger_mask
147 | ]
148 | trigger_add_mask = [
149 | ((x, y), max_allowed_pixel_value * value)
150 | for ((x, y), value) in trigger_add_mask
151 | ]
152 |
153 | if apply_trigger:
154 | for (x, y), value in trigger_mask:
155 | image[x][y] = value
156 | for (x, y), value in trigger_add_mask:
157 | image[x][y] += value
158 |
159 | image = np.clip(image, 0, max_allowed_pixel_value)
160 |
161 | images[index] = image
162 | labels[index] = self.target_class
163 |
164 | return images, labels
165 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.16.6
2 | tensorflow-gpu==1.3.0
3 |
--------------------------------------------------------------------------------
/resnet_model.py:
--------------------------------------------------------------------------------
1 | """
2 | This model is adapted from the resnet-cifar10 repo
3 | """
4 | # From https://github.com/tensorflow/models/blob/master/resnet/resnet_model.py
5 |
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | from collections import namedtuple
11 |
12 | import tensorflow as tf
13 |
14 | HParams = namedtuple('HParams',
15 | ['num_classes',
16 | 'image_size',
17 | 'resnet_size'])
18 |
19 | default_hps = HParams(num_classes=10,
20 | image_size=32,
21 | resnet_size=32)
22 |
23 | def make_data_augmentation_fn(is_training, padding=4, flip=True, standardization=True, hps=default_hps):
24 | def augmentation(image, label):
25 | image.shape.assert_is_compatible_with(
26 | [hps.image_size, hps.image_size, 3])
27 | if is_training:
28 | image = tf.image.resize_image_with_crop_or_pad(
29 | image, hps.image_size + padding, hps.image_size + padding)
30 | image = tf.random_crop(image, [hps.image_size, hps.image_size, 3])
31 | if flip:
32 | image = tf.image.random_flip_left_right(image)
33 | # Always standardize whether training or not (if on)
34 | if standardization:
35 | image = tf.image.per_image_standardization(image)
36 | return image, label
37 | return augmentation
38 |
39 | def choose(selector, matrix):
40 | selector = tf.reshape(selector, (-1, 1))
41 | ordinals = tf.reshape(tf.range(tf.shape(matrix, out_type=tf.int64)[0]), (-1, 1))
42 | idx = tf.stack([selector, ordinals], axis=-1)
43 | return tf.squeeze(tf.gather_nd(tf.transpose(matrix), idx))
44 |
45 | class ResNetModel(object):
46 | """ResNet model."""
47 |
48 | def __init__(self, x_input, y_input, *, random_seed=None, hps=default_hps):
49 | """ResNet constructor.
50 |
51 | Args:
52 | hps: Hyperparameters.
53 | mode: One of 'train' and 'eval'.
54 | """
55 | if random_seed is not None:
56 | tf.set_random_seed(random_seed)
57 | x_input.shape.assert_is_compatible_with(
58 | [None, hps.image_size, hps.image_size, 3])
59 | y_input.shape.assert_is_compatible_with(
60 | [None])
61 | assert x_input.dtype == tf.float32
62 | assert y_input.dtype == tf.int64
63 | if hps.resnet_size % 6 != 2:
64 | raise ValueError('resnet_size must be 6n + 2:', hps.resnet_size)
65 |
66 | self.x_input = x_input # tf.placeholder(tf.float32, shape=[None, hps.image_size, hps.image_size, 3])
67 | self.x_image = self.x_input
68 | # Convert to NCHW
69 | self.x_input_nchw = tf.transpose(self.x_input, [0, 3, 1, 2])
70 | self.y_input = y_input # tf.placeholder(tf.int64, shape=[None])
71 | # self.hps = hps
72 | self.is_training = tf.placeholder(tf.bool, shape=[])
73 | self.num_classes = hps.num_classes
74 | self.resnet_size = hps.resnet_size
75 | self._build_model()
76 |
77 | def _build_model(self):
78 | """Build the core model within the graph."""
79 | num_blocks = (self.resnet_size - 2) // 6
80 |
81 | filters = [16, 16, 32, 64]
82 |
83 | # Uncomment the following codes to use w28-10 wide residual network.
84 | # It is more memory efficient than very deep residual network and has
85 | # comparably good performance.
86 | # https://arxiv.org/pdf/1605.07146v1.pdf
87 | # filters = [16, 160, 320, 640]
88 | # Update hps.num_residual_units to 9
89 |
90 | inputs = conv2d_fixed_padding(
91 | inputs=self.x_input_nchw, filters=filters[0], kernel_size=3, strides=1)
92 | inputs = tf.identity(inputs, 'initial_conv')
93 |
94 | inputs = block_layer(
95 | inputs=inputs, filters=filters[1], block_fn=building_block,
96 | blocks=num_blocks, strides=1, is_training=self.is_training,
97 | name='block_layer1')
98 | inputs = block_layer(
99 | inputs=inputs, filters=filters[2], block_fn=building_block,
100 | blocks=num_blocks, strides=2, is_training=self.is_training,
101 | name='block_layer2')
102 | inputs = block_layer(
103 | inputs=inputs, filters=filters[3], block_fn=building_block,
104 | blocks=num_blocks, strides=2, is_training=self.is_training,
105 | name='block_layer3')
106 |
107 | inputs = batch_norm_relu(inputs, self.is_training)
108 | # Workaround because there is currently no NCHW average pooling on CPU
109 | if tf.test.is_built_with_cuda():
110 | inputs = tf.layers.average_pooling2d(
111 | inputs=inputs, pool_size=8, strides=1, padding='VALID',
112 | data_format='channels_first')
113 | else:
114 | inputs = tf.transpose(inputs, [0, 2, 3, 1])
115 | inputs = tf.layers.average_pooling2d(
116 | inputs=inputs, pool_size=8, strides=1, padding='VALID',
117 | data_format='channels_last')
118 | inputs = tf.transpose(inputs, [0, 3, 1, 2])
119 | inputs = tf.identity(inputs, 'final_avg_pool')
120 | inputs = tf.reshape(inputs, [-1, 64])
121 | inputs = tf.layers.dense(inputs=inputs, units=self.num_classes)
122 | logits = tf.identity(inputs, 'final_dense')
123 |
124 | softmax = tf.nn.softmax(logits)
125 |
126 | self.predictions = tf.argmax(logits, 1)
127 | self.correct_prediction = tf.equal(self.predictions, self.y_input)
128 | self.num_correct = tf.reduce_sum(
129 | tf.cast(self.correct_prediction, tf.int64))
130 | self.accuracy = tf.reduce_mean(
131 | tf.cast(self.correct_prediction, tf.float32))
132 | self.y_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
133 | logits=logits, labels=self.y_input)
134 | self.xent = tf.reduce_mean(self.y_xent, name='xent')
135 | self.weight_decay_loss = self._decay()
136 |
137 | self.confidence_in_correct = choose(self.y_input, softmax)
138 | self.confidence_in_prediction = choose(self.predictions, softmax)
139 |
140 | def _decay(self):
141 | """L2 weight decay loss."""
142 | costs = []
143 | for var in tf.trainable_variables():
144 | costs.append(tf.nn.l2_loss(var))
145 | return tf.add_n(costs)
146 |
147 |
148 | def batch_norm_relu(inputs, is_training):
149 | """Performs a batch normalization followed by a ReLU."""
150 | _BATCH_NORM_DECAY = 0.997
151 | _BATCH_NORM_EPSILON = 1e-5
152 | axis = 1
153 | # Workaround because there is currently no NCHW fused BN on CPU
154 | if not tf.test.is_built_with_cuda():
155 | inputs = tf.transpose(inputs, [0, 2, 3, 1])
156 | axis = 3
157 | inputs = tf.layers.batch_normalization(inputs=inputs, axis=axis, momentum=_BATCH_NORM_DECAY,
158 | epsilon=_BATCH_NORM_EPSILON, center=True,
159 | scale=True, training=is_training, fused=True)
160 | inputs = tf.nn.relu(inputs)
161 | if not tf.test.is_built_with_cuda():
162 | inputs = tf.transpose(inputs, [0, 3, 1, 2])
163 | return inputs
164 |
165 |
166 | def fixed_padding(inputs, kernel_size):
167 | """Pads the input along the spatial dimensions independently of input size.
168 |
169 | Args:
170 | inputs: A tensor of size [batch, channels, height_in, width_in]
171 | kernel_size: The kernel to be used in the conv2d or max_pool2d operation.
172 | Should be a positive integer.
173 |
174 | Returns:
175 | A tensor with the same format as the input with the data either intact
176 | (if kernel_size == 1) or padded (if kernel_size > 1).
177 | """
178 | pad_total = kernel_size - 1
179 | pad_beg = pad_total // 2
180 | pad_end = pad_total - pad_beg
181 |
182 | padded_inputs = tf.pad(inputs,
183 | [[0, 0], [0, 0], [pad_beg, pad_end], [pad_beg, pad_end]])
184 | return padded_inputs
185 |
186 |
187 | def conv2d_fixed_padding(inputs, filters, kernel_size, strides):
188 | """Strided 2-D convolution with explicit padding.
189 |
190 | The padding is consistent and is based only on `kernel_size`, not on the
191 | dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).
192 | """
193 | if strides > 1:
194 | inputs = fixed_padding(inputs, kernel_size)
195 | data_format = 'channels_first'
196 | # Workaround because there is currently no NCHW conv2d on CPU
197 | if not tf.test.is_built_with_cuda():
198 | inputs = tf.transpose(inputs, [0, 2, 3, 1])
199 | data_format = 'channels_last'
200 | inputs = tf.layers.conv2d(
201 | inputs=inputs,
202 | filters=filters,
203 | kernel_size=kernel_size,
204 | strides=strides,
205 | padding=('SAME' if strides == 1 else 'VALID'),
206 | use_bias=False,
207 | kernel_initializer=tf.variance_scaling_initializer(),
208 | data_format=data_format)
209 | if not tf.test.is_built_with_cuda():
210 | inputs = tf.transpose(inputs, [0, 3, 1, 2])
211 | return inputs
212 |
213 |
214 | def building_block(inputs, filters, is_training, projection_shortcut, strides):
215 | """Standard building block for residual networks (v2).
216 |
217 | Args:
218 | inputs: A tensor of size [batch, channels, height_in, width_in].
219 | filters: The number of filters for the convolutions.
220 | is_training: A Boolean for whether the model is in training or inference
221 | mode. Needed for batch normalization.
222 | projection_shortcut: The function to use for projection shortcuts
223 | (typically a 1x1 convolution when downsampling the
224 | input).
225 | strides: The block's stride. If greater than 1, this block will ultimately
226 | downsample the input.
227 |
228 | Returns:
229 | The output tensor of the block.
230 | """
231 | shortcut = inputs
232 | inputs = batch_norm_relu(inputs, is_training)
233 |
234 | # The projection shortcut should come after the first batch norm and ReLU
235 | # since it performs a 1x1 convolution.
236 | if projection_shortcut is not None:
237 | shortcut = projection_shortcut(inputs)
238 |
239 | inputs = conv2d_fixed_padding(inputs=inputs, filters=filters, kernel_size=3,
240 | strides=strides)
241 | inputs = batch_norm_relu(inputs, is_training)
242 | inputs = conv2d_fixed_padding(inputs=inputs, filters=filters, kernel_size=3,
243 | strides=1)
244 | return inputs + shortcut
245 |
246 |
247 | def block_layer(inputs, filters, block_fn, blocks, strides, is_training, name):
248 | """Creates one layer of blocks for the ResNet model.
249 |
250 | Args:
251 | inputs: A tensor of size [batch, channels, height_in, width_in].
252 | filters: The number of filters for the first convolution of the layer.
253 | block_fn: The block to use within the model, either `building_block` or
254 | `bottleneck_block`.
255 | blocks: The number of blocks contained in the layer.
256 | strides: The stride to use for the first convolution of the layer. If
257 | greater than 1, this layer will ultimately downsample the input.
258 | is_training: Either True or False, whether we are currently training the
259 | model. Needed for batch norm.
260 | name: A string name for the tensor output of the block layer.
261 |
262 | Returns:
263 | The output tensor of the block layer.
264 | """
265 | # Bottleneck blocks end with 4x the number of filters as they start with
266 | #filters_out = 4 * filters if block_fn is bottleneck_block else filters
267 | filters_out = filters
268 |
269 | def projection_shortcut(inputs):
270 | return conv2d_fixed_padding(inputs=inputs, filters=filters_out,
271 | kernel_size=1, strides=strides)
272 |
273 | # Only the first block per block_layer uses projection_shortcut and strides
274 | inputs = block_fn(inputs, filters, is_training, projection_shortcut,
275 | strides)
276 | for _ in range(1, blocks):
277 | inputs = block_fn(inputs, filters, is_training, None, 1)
278 | return tf.identity(inputs, name)
279 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget https://github.com/MadryLab/label-consistent-backdoor-code/releases/download/v1.0/clean_dataset.tar.bz2
4 | tar -vxjf clean_dataset.tar.bz2
5 | rm clean_dataset.tar.bz2
6 |
7 | wget https://github.com/MadryLab/label-consistent-backdoor-code/releases/download/v1.0/fully_poisoned_training_datasets.tar.bz2.aa
8 | wget https://github.com/MadryLab/label-consistent-backdoor-code/releases/download/v1.0/fully_poisoned_training_datasets.tar.bz2.ab
9 | wget https://github.com/MadryLab/label-consistent-backdoor-code/releases/download/v1.0/fully_poisoned_training_datasets.tar.bz2.ac
10 | cat fully_poisoned_training_datasets.tar.bz2.* > fully_poisoned_training_datasets.tar.bz2
11 | rm fully_poisoned_training_datasets.tar.bz2.*
12 | tar -vxjf fully_poisoned_training_datasets.tar.bz2
13 | rm fully_poisoned_training_datasets.tar.bz2
14 |
15 | pip install -r requirements.txt --user
16 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains a model, saving checkpoints and tensorboard summaries along the way.
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | from datetime import datetime
9 | import json
10 | import shutil
11 | import os
12 |
13 | import numpy as np
14 | import tensorflow as tf
15 |
16 | from eval_helper import EvalHelper
17 | from resnet_model import ResNetModel, make_data_augmentation_fn
18 |
19 | # load configuration: first load the base config, and then update using the
20 | # job_parameters, if any
21 | with open('config.json', 'r') as base_config_file:
22 | config = json.load(base_config_file)
23 | if os.path.exists('job_parameters.json'):
24 | with open('job_parameters.json', 'r') as job_parameters_file:
25 | job_parameters = json.load(job_parameters_file)
26 | # make sure we didn't e.g. make some typo
27 | for k in job_parameters.keys():
28 | if k not in config.keys():
29 | print("{} config not in base config file!".format(k))
30 | # assert k in config.keys()
31 | config.update(job_parameters)
32 |
33 | # Setting up training parameters
34 | tf.set_random_seed(config['random_seed'])
35 | np.random.seed(config['random_seed'])
36 |
37 | max_num_training_steps = config['max_num_training_steps']
38 | num_output_steps = config['num_output_steps']
39 | num_summary_steps = config['num_summary_steps']
40 | num_checkpoint_steps = config['num_checkpoint_steps']
41 |
42 | batch_size = config['training_batch_size']
43 |
44 | # Setting up the data and the model
45 | clean_train_images = np.load(config["clean_dataset_dir"] + "/train_images.npy").astype(np.float32)
46 | clean_train_labels = np.load(config["clean_dataset_dir"] + "/train_labels.npy").astype(np.int64)
47 | num_train_examples = len(clean_train_images)
48 |
49 | clean_test_images = np.load(config["clean_dataset_dir"] + "/test_images.npy").astype(np.float32)
50 | clean_test_labels = np.load(config["clean_dataset_dir"] + "/test_labels.npy").astype(np.int64)
51 | num_test_examples = len(clean_test_images)
52 |
53 | # We assume inputs are as follows
54 | # - train_{images,labels}.npy -- the x% poisoned dataset
55 | # - test_{images,labels}.npy -- trigger applied to all test images
56 | # - poisoned_train_indices.npy -- which indices were poisoned
57 | # - train_no_trigger_{images,labels}.npy -- the x% poisoned dataset, but without any triggers applied
58 | poisoned_train_images = np.load(config["already_poisoned_dataset_dir"] + "/train_images.npy").astype(np.float32)
59 | poisoned_train_labels = np.load(config["already_poisoned_dataset_dir"] + "/train_labels.npy").astype(np.int64)
60 | poisoned_test_images = np.load(config["already_poisoned_dataset_dir"] + "/test_images.npy").astype(np.float32)
61 | poisoned_test_labels = np.load(config["already_poisoned_dataset_dir"] + "/test_labels.npy").astype(np.int64)
62 |
63 | poisoned_train_indices = np.load(config["already_poisoned_dataset_dir"] + "/poisoned_train_indices.npy")
64 | if len(poisoned_train_indices) > 0:
65 | poisoned_only_train_images = poisoned_train_images[poisoned_train_indices]
66 | poisoned_only_train_labels = poisoned_train_labels[poisoned_train_indices]
67 | poisoned_no_trigger_train_images = np.load(config["already_poisoned_dataset_dir"] + "/train_no_trigger_images.npy").astype(np.float32)
68 | # These are identical to the training labels
69 | poisoned_no_trigger_train_labels = np.load(config["already_poisoned_dataset_dir"] + "/train_labels.npy").astype(np.int64)
70 | poisoned_no_trigger_train_images = poisoned_no_trigger_train_images[poisoned_train_indices]
71 | poisoned_no_trigger_train_labels = poisoned_no_trigger_train_labels[poisoned_train_indices]
72 |
73 | def prepare_dataset(images, labels):
74 | images_placeholder = tf.placeholder(tf.float32, images.shape)
75 | labels_placeholder = tf.placeholder(tf.int64, labels.shape)
76 | dataset = tf.contrib.data.Dataset.from_tensor_slices((images_placeholder, labels_placeholder))
77 | dataset = dataset.shuffle(buffer_size=10000, seed=config['random_seed']).repeat()
78 |
79 | if config['augment_dataset']:
80 | dataset = dataset.map(
81 | make_data_augmentation_fn(
82 | standardization=config['augment_standardization'],
83 | flip=config['augment_flip'],
84 | padding=config['augment_padding'],
85 | is_training=True))
86 |
87 | dataset = dataset.batch(batch_size)
88 | iterator = dataset.make_initializable_iterator()
89 | return (images_placeholder, labels_placeholder), dataset, iterator
90 |
91 | clean_placeholder, clean_train_dataset_batched, clean_training_iterator = prepare_dataset(clean_train_images, clean_train_labels)
92 | poisoned_placeholder, _, poisoned_training_iterator = prepare_dataset(poisoned_train_images, poisoned_train_labels)
93 | if len(poisoned_train_indices) > 0:
94 | poisoned_only_placeholder, _, poisoned_only_training_iterator = prepare_dataset(poisoned_only_train_images, poisoned_only_train_labels)
95 | poisoned_no_trigger_placeholder, _, poisoned_no_trigger_training_iterator = prepare_dataset(poisoned_no_trigger_train_images, poisoned_no_trigger_train_labels)
96 |
97 | iterator_handle = tf.placeholder(tf.string, shape=[])
98 | input_iterator = tf.contrib.data.Iterator.from_string_handle(iterator_handle,
99 | clean_train_dataset_batched.output_types,
100 | clean_train_dataset_batched.output_shapes)
101 | x_input, y_input = input_iterator.get_next()
102 |
103 | global_step = tf.contrib.framework.get_or_create_global_step()
104 |
105 | # Choose model and set up optimizer
106 | model = ResNetModel(x_input, y_input, random_seed=config['random_seed'])
107 |
108 | weight_decay = 0.0002
109 | boundaries = config['learning_rate_boundaries']
110 | values = config['learning_rates']
111 | learning_rate = tf.train.piecewise_constant(
112 | tf.cast(global_step, tf.int32),
113 | boundaries,
114 | values)
115 | momentum = 0.9
116 | total_loss = model.xent + weight_decay * model.weight_decay_loss
117 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
118 | opt = tf.train.MomentumOptimizer(
119 | learning_rate,
120 | momentum,
121 | )
122 | with tf.control_dependencies(update_ops):
123 | train_step = opt.minimize(total_loss, global_step=global_step)
124 |
125 | # Setting up the Tensorboard and checkpoint outputs
126 | model_dir = config['model_dir']
127 | if not os.path.exists(model_dir):
128 | os.makedirs(model_dir)
129 |
130 | saver = tf.train.Saver(max_to_keep=3)
131 | merged_summaries = tf.summary.merge([
132 | tf.summary.scalar('accuracy poison train', model.accuracy),
133 | tf.summary.scalar('xent poison train', model.xent / batch_size),
134 | tf.summary.image('images poison train', model.x_image),
135 | tf.summary.histogram('conf in y_input', model.confidence_in_correct),
136 | tf.summary.histogram('conf in y_pred', model.confidence_in_prediction),
137 | ])
138 | clean_histogram = tf.summary.histogram('conf in clean', model.confidence_in_correct)
139 | poison_only_merged_summaries = tf.summary.merge([
140 | tf.summary.scalar('accuracy poison only train', model.accuracy),
141 | tf.summary.scalar('xent poison only train', model.xent / batch_size), # NB shouldn't divide like this
142 | tf.summary.image('images poison only train', model.x_image),
143 | tf.summary.histogram('conf in poisoned only', model.confidence_in_correct),
144 | ])
145 | poison_no_trigger_merged_summaries = tf.summary.merge([
146 | tf.summary.scalar('accuracy poison train (no trigger)', model.accuracy),
147 | tf.summary.scalar('xent poison train (no trigger)', model.xent / batch_size), # NB shouldn't divide like this
148 | tf.summary.image('images poison train (no trigger)', model.x_image),
149 | tf.summary.histogram('conf in poisoned (no trigger)', model.confidence_in_correct),
150 | ])
151 |
152 | shutil.copy('config.json', model_dir)
153 |
154 | with tf.Session() as sess:
155 | # Initialize the summary writer, global variables, and our time counter.
156 | summary_writer = tf.summary.FileWriter(model_dir, sess.graph)
157 | sess.run(tf.global_variables_initializer())
158 |
159 | sess.run(clean_training_iterator.initializer,
160 | feed_dict={clean_placeholder[0]: clean_train_images,
161 | clean_placeholder[1]: clean_train_labels})
162 | sess.run(poisoned_training_iterator.initializer,
163 | feed_dict={poisoned_placeholder[0]: poisoned_train_images,
164 | poisoned_placeholder[1]: poisoned_train_labels})
165 | if len(poisoned_train_indices) > 0:
166 | sess.run(poisoned_only_training_iterator.initializer,
167 | feed_dict={poisoned_only_placeholder[0]: poisoned_only_train_images,
168 | poisoned_only_placeholder[1]: poisoned_only_train_labels})
169 | sess.run(poisoned_no_trigger_training_iterator.initializer,
170 | feed_dict={poisoned_no_trigger_placeholder[0]: poisoned_no_trigger_train_images,
171 | poisoned_no_trigger_placeholder[1]: poisoned_no_trigger_train_labels})
172 |
173 | clean_training_handle = sess.run(clean_training_iterator.string_handle())
174 | poisoned_training_handle = sess.run(poisoned_training_iterator.string_handle())
175 | if len(poisoned_train_indices) > 0:
176 | poisoned_only_training_handle = sess.run(poisoned_only_training_iterator.string_handle())
177 | poisoned_no_trigger_training_handle = sess.run(poisoned_no_trigger_training_iterator.string_handle())
178 |
179 | evalHelper = EvalHelper(
180 | sess,
181 | {
182 | "clean_train": (clean_train_images, clean_train_labels),
183 | "poisoned_train": (poisoned_train_images, poisoned_train_labels),
184 | "poisoned_only_train": (poisoned_only_train_images, poisoned_only_train_labels),
185 | "poisoned_no_trigger_train": (poisoned_no_trigger_train_images, poisoned_no_trigger_train_labels),
186 | "clean_test": (clean_test_images, clean_test_labels),
187 | "poisoned_test": (poisoned_test_images, poisoned_test_labels),
188 | },
189 | iterator_handle
190 | )
191 |
192 | # Main training loop
193 | for ii in range(max_num_training_steps):
194 | clean_dict = {iterator_handle: clean_training_handle,
195 | model.is_training: True}
196 | poison_dict = {iterator_handle: poisoned_training_handle,
197 | model.is_training: True}
198 |
199 | # Output to stdout
200 | if ii % num_output_steps == 0:
201 | clean_acc = sess.run(model.accuracy, feed_dict=clean_dict)
202 | poison_acc = sess.run(model.accuracy, feed_dict=poison_dict)
203 | print('Step {}: ({})'.format(ii, datetime.now()))
204 | print(' training clean accuracy {:.4}%'.format(clean_acc * 100))
205 | print(' training poison accuracy {:.4}%'.format(poison_acc * 100))
206 |
207 | # Tensorboard summaries
208 | if ii % num_summary_steps == 0:
209 | summary = sess.run(merged_summaries, feed_dict=poison_dict)
210 | summary_writer.add_summary(summary, global_step.eval(sess))
211 | summary_clean = sess.run(clean_histogram, feed_dict=clean_dict)
212 | summary_writer.add_summary(summary_clean, global_step.eval(sess))
213 | if len(poisoned_train_indices) > 0:
214 | poison_only_dict = {iterator_handle: poisoned_only_training_handle,
215 | model.is_training: True}
216 | poison_no_trigger_dict = {iterator_handle: poisoned_no_trigger_training_handle,
217 | model.is_training: True}
218 | summary_poison_only = sess.run(poison_only_merged_summaries, feed_dict=poison_only_dict)
219 | summary_writer.add_summary(summary_poison_only, global_step.eval(sess))
220 | summary_poison_no_trigger = sess.run(poison_no_trigger_merged_summaries, feed_dict=poison_no_trigger_dict)
221 | summary_writer.add_summary(summary_poison_no_trigger, global_step.eval(sess))
222 |
223 | # Write a checkpoint
224 | if ii % num_checkpoint_steps == 0:
225 | saver.save(sess,
226 | os.path.join(model_dir, 'checkpoint'),
227 | global_step=global_step)
228 |
229 | # Run an eval
230 | if (config['num_eval_steps'] > 0
231 | and ii % config['num_eval_steps'] == 0):
232 | print('Starting eval ...', flush=True)
233 | evalHelper.evaluate_session(model, sess)
234 |
235 | # Actual training step
236 | sess.run(train_step, feed_dict=poison_dict)
237 |
--------------------------------------------------------------------------------