├── data
│   └── cifar10
│       └── cifar-10-batches-bin
│           ├── cache.txt
│           └── readme.txt
├── cifar10.pyc
├── cifar10_input.pyc
├── README.md
├── noise.py
├── cifar10_train.py
├── cifar10_eval.py
├── cifar10_train_varT.py
├── cifar10_train_T.py
├── cifar10_train_GANT.py
├── cifar10_input.py
└── cifar10.py
/data/cifar10/cifar-10-batches-bin/cache.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/cifar10.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bhanML/Masking/HEAD/cifar10.pyc
--------------------------------------------------------------------------------
/cifar10_input.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bhanML/Masking/HEAD/cifar10_input.pyc
--------------------------------------------------------------------------------
/data/cifar10/cifar-10-batches-bin/readme.txt:
--------------------------------------------------------------------------------
1 | (1) Download CIFAR-10 from https://www.cs.toronto.edu/~kriz/cifar.html.
2 |
3 | (2) Unpack the compressed archive into this folder so that the folder directly contains the binary files (data_batch_1.bin ... data_batch_5.bin and test_batch.bin).
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Masking
2 | NeurIPS'18: Masking: A New Perspective of Noisy Supervision (TensorFlow implementation).
3 |
4 | This is the code for the paper:
5 | [Masking: A New Perspective of Noisy Supervision](https://arxiv.org/abs/1805.08193)
6 | Bo Han*, Jiangchao Yao*, Gang Niu, Mingyuan Zhou, Ivor Tsang, Ya Zhang, Masashi Sugiyama
7 | Presented at [NeurIPS 2018](https://nips.cc/Conferences/2018/).
8 |
9 | If you find this code useful in your research, please cite
10 | ```bash
11 | @inproceedings{han2018masking,
12 |   title={Masking: A new perspective of noisy supervision},
13 |   author={Han, Bo and Yao, Jiangchao and Niu, Gang and Zhou, Mingyuan and Tsang, Ivor and Zhang, Ya and Sugiyama, Masashi},
14 |   booktitle={NeurIPS},
15 |   pages = {5839--5849},
16 |   year={2018}
17 | }
18 | ```
19 |
20 | Introduction to the code
21 |
22 | ------------------------------------------------------------------------------
23 | Models:
24 |
25 | (1) cifar10_train.py implements the classifier trained directly on the (noisy) dataset.
26 |
27 | (2) cifar10_train_T.py implements the loss correction method from https://github.com/giorgiop/loss-correction
28 |
29 | (3) cifar10_train_varT.py implements the classifier with adaptation of the noise transition.
30 |
31 | (4) cifar10_train_GANT.py implements our MASKING model.
32 |
33 | -----------------------------------------------------------------------------
34 | Datasets:
35 |
36 | (1) The CIFAR-10 dataset can be downloaded and placed in the correct location by following the instructions in ./data/cifar10/cifar-10-batches-bin/readme.txt
37 |
38 | (2) The noisy datasets are generated by noise.py from the clean CIFAR-10 dataset; see the sketch below.
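As a quick reference for (2), the following sketch (an illustration, not part of the original scripts) shows how noise.py flips a clean label using one row of the noise transition matrix; the example row corresponds to gen_noise_column() with NOISY_PROPORTION = 0.6, and the helper name flip_label is hypothetical.

```python
import numpy as np

# Row i of the 10x10 transition matrix T holds P(noisy label = j | clean label = i).
# With the column noise and NOISY_PROPORTION = 0.6, class 0 keeps its label with
# probability 0.4 and is flipped to class 3 or class 5 with probability 0.3 each.
row_for_class_0 = np.array([0.4, 0.0, 0.0, 0.3, 0.0, 0.3, 0.0, 0.0, 0.0, 0.0])

def flip_label(clean_label, T):
    # Sample one noisy label from the categorical distribution T[clean_label],
    # mirroring dis_noise() in noise.py, e.g. flip_label(0, gen_noise_column()).
    return int(np.argmax(np.random.multinomial(n=1, pvals=T[clean_label])))
```

Running `python noise.py` (the script targets Python 2) writes data_batch_1_column.bin ... data_batch_5_column.bin and the corresponding *_tridiagonal.bin files next to the clean batches, which is the naming that cifar10_input.py expects.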
39 |
40 | You can switch the noisy dataset used by cifar10_train_T.py, cifar10_train_varT.py and cifar10_train_GANT.py by setting the NOISE_TYPE parameter in cifar10_input.py
41 |
42 | -----------------------------------------------------------------------------
43 | Example:
44 |
45 | (1) Because the noise transition has to be initialized from a pretrained classifier, the scripts must be executed in order.
46 | For example, you can run them in the following order:
47 |
48 | python cifar10_train.py
49 |
50 | python cifar10_train_T.py
51 |
52 | python cifar10_train_varT.py
53 |
54 | python cifar10_train_GANT.py
55 |
56 | (2) Evaluation is a separate script: first launch the training script, and then launch the evaluation script in another terminal.
57 | For example,
58 |
59 | python cifar10_train.py --train_dir events/cifar10_train
60 |
61 | python cifar10_eval.py --checkpoint_dir events/cifar10_train --eval_dir events/cifar10_eval
62 |
63 | -----------------------------------------------------------------------------
64 | This code is forked from the official TensorFlow CIFAR-10 (CIFARnet) tutorial at https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10.
65 |
66 | Contact: Jiangchao Yao (sunarker@sjtu.edu.cn); Bo Han (bo.han@riken.jp).
67 |
--------------------------------------------------------------------------------
/noise.py:
--------------------------------------------------------------------------------
1 | # Disturb the CIFAR-10 training labels with a specified noise transition matrix
2 | import numpy as np
3 |
4 | # CIFAR-10 binary format
5 | # <1 byte for label> <3072 bytes for pixels>
6 |
7 | def gen_noise_tridiagonal():
8 |     # Tridiagonal noise: each class can be flipped to its neighbouring classes
9 |     NOISY_PROPORTION = 0.6
10 |     T = np.zeros((10,10))
11 |     T[0][0],T[0][1] = 1.0 - NOISY_PROPORTION + 0.2, NOISY_PROPORTION - 0.2
12 |     T[1][0],T[1][1],T[1][2] = NOISY_PROPORTION/2.0, 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0
13 |     T[2][1],T[2][2],T[2][3] = NOISY_PROPORTION/2.0, 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0
14 |     T[3][2],T[3][3],T[3][4] = NOISY_PROPORTION/2.0, 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0
15 |     T[4][3],T[4][4],T[4][5] = NOISY_PROPORTION/2.0, 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0
16 |     T[5][4],T[5][5],T[5][6] = NOISY_PROPORTION/2.0, 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0
17 |     T[6][5],T[6][6],T[6][7] = NOISY_PROPORTION/2.0, 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0
18 |     T[7][6],T[7][7],T[7][8] = NOISY_PROPORTION/2.0, 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0
19 |     T[8][7],T[8][8],T[8][9] = NOISY_PROPORTION/2.0, 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0
20 |     T[9][8],T[9][9] = NOISY_PROPORTION - 0.2, 1.0 - NOISY_PROPORTION + 0.2
21 |
22 |     return T
23 |
24 | def gen_noise_column():
25 |     # Column noise: the off-diagonal mass is concentrated on columns 3 and 5
26 |     NOISY_PROPORTION = 0.6
27 |     T = np.zeros((10,10))
28 |     T[0][0],T[0][3],T[0][5] = 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0, NOISY_PROPORTION/2.0
29 |     T[1][1],T[1][3],T[1][5] = 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0, NOISY_PROPORTION/2.0
30 |     T[2][2],T[2][3],T[2][5] = 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0, NOISY_PROPORTION/2.0
31 |     T[3][3],T[3][5] = 1.0 - NOISY_PROPORTION + 0.2, NOISY_PROPORTION - 0.2
32 |     T[4][4],T[4][3],T[4][5] = 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0, NOISY_PROPORTION/2.0
33 |     T[5][5],T[5][3] = 1.0 - NOISY_PROPORTION + 0.2, NOISY_PROPORTION - 0.2
34 |     T[6][6],T[6][3],T[6][5] = 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0, NOISY_PROPORTION/2.0
35 |     T[7][7],T[7][3],T[7][5] = 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0,
NOISY_PROPORTION/2.0 36 | T[8][8],T[8][3],T[8][5] = 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0, NOISY_PROPORTION/2.0 37 | T[9][9],T[9][3],T[9][5] = 1.0 - NOISY_PROPORTION, NOISY_PROPORTION/2.0, NOISY_PROPORTION/2.0 38 | 39 | return T 40 | 41 | def dis_noise(label,T): 42 | return xrange(10)[np.argmax(np.random.multinomial(size=1,n=1,pvals=T[label]))] 43 | 44 | def inject_data(noise, filename, noise_name): 45 | with open(filename,'rb') as f: 46 | with open(filename[:-4] + '_' + noise_name + filename[-4:],'wb') as w: 47 | e = f.read(3073) 48 | while e: 49 | label = ord(e[0]) 50 | #print(label) 51 | dis_label = dis_noise(label,noise) 52 | dis_e = chr(dis_label) + e[1:] 53 | w.write(dis_e) 54 | e = f.read(3073) 55 | 56 | 57 | 58 | tridiagonal_noise = gen_noise_tridiagonal() 59 | inject_data(tridiagonal_noise,'data/cifar10/cifar-10-batches-bin/data_batch_1.bin','tridiagonal') 60 | inject_data(tridiagonal_noise,'data/cifar10/cifar-10-batches-bin/data_batch_2.bin','tridiagonal') 61 | inject_data(tridiagonal_noise,'data/cifar10/cifar-10-batches-bin/data_batch_3.bin','tridiagonal') 62 | inject_data(tridiagonal_noise,'data/cifar10/cifar-10-batches-bin/data_batch_4.bin','tridiagonal') 63 | inject_data(tridiagonal_noise,'data/cifar10/cifar-10-batches-bin/data_batch_5.bin','tridiagonal') 64 | 65 | column_noise = gen_noise_column() 66 | inject_data(column_noise, 'data/cifar10/cifar-10-batches-bin/data_batch_1.bin','column') 67 | inject_data(column_noise, 'data/cifar10/cifar-10-batches-bin/data_batch_2.bin','column') 68 | inject_data(column_noise, 'data/cifar10/cifar-10-batches-bin/data_batch_3.bin','column') 69 | inject_data(column_noise, 'data/cifar10/cifar-10-batches-bin/data_batch_4.bin','column') 70 | inject_data(column_noise, 'data/cifar10/cifar-10-batches-bin/data_batch_5.bin','column') 71 | 72 | -------------------------------------------------------------------------------- /cifar10_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A binary to train CIFAR-10 using a single GPU. 17 | 18 | Accuracy: 19 | cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of 20 | data) as judged by cifar10_eval.py. 21 | 22 | Speed: With batch_size 128. 23 | 24 | System | Step Time (sec/batch) | Accuracy 25 | ------------------------------------------------------------------ 26 | 1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours) 27 | 1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours) 28 | 29 | Usage: 30 | Please see the tutorial and website for how to download the CIFAR-10 31 | data set, compile the program and train the model. 
32 | 33 | http://tensorflow.org/tutorials/deep_cnn/ 34 | """ 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | from datetime import datetime 40 | import time 41 | 42 | import tensorflow as tf 43 | 44 | import cifar10 45 | 46 | FLAGS = tf.app.flags.FLAGS 47 | 48 | tf.app.flags.DEFINE_string('train_dir', 'events/cifar10_train', 49 | """Directory where to write event logs """ 50 | """and checkpoint.""") 51 | #tf.app.flags.DEFINE_integer('max_steps', 1000000, 52 | # """Number of batches to run.""") 53 | tf.app.flags.DEFINE_integer('max_steps', 150000, 54 | """Number of batches to run.""") 55 | tf.app.flags.DEFINE_boolean('log_device_placement', False, 56 | """Whether to log device placement.""") 57 | tf.app.flags.DEFINE_integer('log_frequency', 10, 58 | """How often to log results to the console.""") 59 | 60 | 61 | def train(): 62 | """Train CIFAR-10 for a number of steps.""" 63 | with tf.Graph().as_default(): 64 | global_step = tf.train.get_or_create_global_step() 65 | 66 | # Get images and labels for CIFAR-10. 67 | # Force input pipeline to CPU:0 to avoid operations sometimes ending up on 68 | # GPU and resulting in a slow down. 69 | with tf.device('/cpu:0'): 70 | #images, labels = cifar10.distorted_inputs() 71 | images, labels, T, T_mask = cifar10.noisy_distorted_inputs(return_T_flag=True) 72 | 73 | # Build a Graph that computes the logits predictions from the 74 | # inference model. 75 | dropout = tf.placeholder(tf.float32,name='dropout_rate') 76 | logits = cifar10.inference(images,dropout,dropout_flag=True) 77 | 78 | # Calculate loss. 79 | loss = cifar10.loss(logits, labels) 80 | 81 | # Build a Graph that trains the model with one batch of examples and 82 | # updates the model parameters. 83 | train_op = cifar10.train(loss, global_step) 84 | 85 | class _LoggerHook(tf.train.SessionRunHook): 86 | """Logs loss and runtime.""" 87 | 88 | def begin(self): 89 | self._step = -1 90 | self._start_time = time.time() 91 | 92 | def before_run(self, run_context): 93 | self._step += 1 94 | return tf.train.SessionRunArgs(loss) # Asks for loss value. 
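# Note: after_run (below) receives the loss value requested here; every
# FLAGS.log_frequency steps it prints the average throughput (examples/sec and
# sec/batch) measured since the previous report.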
95 | 96 | def after_run(self, run_context, run_values): 97 | if self._step % FLAGS.log_frequency == 0: 98 | current_time = time.time() 99 | duration = current_time - self._start_time 100 | self._start_time = current_time 101 | 102 | loss_value = run_values.results 103 | examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration 104 | sec_per_batch = float(duration / FLAGS.log_frequency) 105 | 106 | format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 107 | 'sec/batch)') 108 | print (format_str % (datetime.now(), self._step, loss_value, 109 | examples_per_sec, sec_per_batch)) 110 | 111 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45) 112 | with tf.train.MonitoredTrainingSession( 113 | checkpoint_dir=FLAGS.train_dir, 114 | hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), 115 | tf.train.NanTensorHook(loss), 116 | _LoggerHook()], 117 | save_checkpoint_secs=60, 118 | config=tf.ConfigProto( 119 | log_device_placement=FLAGS.log_device_placement, gpu_options=gpu_options)) as mon_sess: 120 | while not mon_sess.should_stop(): 121 | #mon_sess.run(train_op,feed_dict={dropout:0.75}) 122 | res = mon_sess.run([train_op,global_step,T,T_mask],feed_dict={dropout:0.75}) 123 | if res[1] % 1000 == 0: 124 | print('Disturbing matrix\n',res[2]) 125 | print('Masked structure\n',res[3]) 126 | 127 | def main(argv=None): # pylint: disable=unused-argument 128 | cifar10.maybe_download_and_extract() 129 | if tf.gfile.Exists(FLAGS.train_dir): 130 | tf.gfile.DeleteRecursively(FLAGS.train_dir) 131 | tf.gfile.MakeDirs(FLAGS.train_dir) 132 | train() 133 | 134 | 135 | if __name__ == '__main__': 136 | tf.app.run() 137 | -------------------------------------------------------------------------------- /cifar10_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Evaluation for CIFAR-10. 17 | 18 | Accuracy: 19 | cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs 20 | of data) as judged by cifar10_eval.py. 21 | 22 | Speed: 23 | On a single Tesla K40, cifar10_train.py processes a single batch of 128 images 24 | in 0.25-0.35 sec (i.e. 350 - 600 images /sec). The model reaches ~86% 25 | accuracy after 100K steps in 8 hours of training time. 26 | 27 | Usage: 28 | Please see the tutorial and website for how to download the CIFAR-10 29 | data set, compile the program and train the model. 
30 | 31 | http://tensorflow.org/tutorials/deep_cnn/ 32 | """ 33 | from __future__ import absolute_import 34 | from __future__ import division 35 | from __future__ import print_function 36 | 37 | from datetime import datetime 38 | import math 39 | import time 40 | 41 | import numpy as np 42 | import tensorflow as tf 43 | 44 | import sys 45 | 46 | FLAGS = tf.app.flags.FLAGS 47 | 48 | tf.app.flags.DEFINE_string('eval_dir', 'cifar10_ce/cifar10_eval', 49 | """Directory where to write event logs.""") 50 | tf.app.flags.DEFINE_string('eval_data', 'test', 51 | """Either 'test' or 'train_eval'.""") 52 | tf.app.flags.DEFINE_string('checkpoint_dir', 'cifar10_ce/cifar10_train', 53 | """Directory where to read model checkpoints.""") 54 | #tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5, 55 | # """How often to run the eval.""") 56 | tf.app.flags.DEFINE_integer('eval_interval_secs', 60, 57 | """How often to run the eval.""") 58 | tf.app.flags.DEFINE_integer('num_examples', 10000, 59 | """Number of examples to run.""") 60 | tf.app.flags.DEFINE_boolean('run_once', False, 61 | """Whether to run eval only once.""") 62 | import cifar10 63 | 64 | def eval_once(saver, summary_writer, top_k_op, summary_op): 65 | """Run Eval once. 66 | 67 | Args: 68 | saver: Saver. 69 | summary_writer: Summary writer. 70 | top_k_op: Top K op. 71 | summary_op: Summary op. 72 | """ 73 | with tf.Session() as sess: 74 | ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) 75 | if ckpt and ckpt.model_checkpoint_path: 76 | saver.restore(sess, ckpt.model_checkpoint_path) 77 | # Assuming model_checkpoint_path looks something like: 78 | # /my-favorite-path/cifar10_train/model.ckpt-0, 79 | # extract global_step from it. 80 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 81 | else: 82 | print('No checkpoint file found') 83 | return 84 | 85 | # Start the queue runners. 86 | coord = tf.train.Coordinator() 87 | try: 88 | threads = [] 89 | for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS): 90 | threads.extend(qr.create_threads(sess, coord=coord, daemon=True, 91 | start=True)) 92 | 93 | num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size)) 94 | true_count = 0 # Counts the number of correct predictions. 95 | total_sample_count = num_iter * FLAGS.batch_size 96 | step = 0 97 | while step < num_iter and not coord.should_stop(): 98 | predictions = sess.run([top_k_op]) 99 | true_count += np.sum(predictions) 100 | step += 1 101 | 102 | # Compute precision @ 1. 103 | precision = true_count / total_sample_count 104 | print('%s: precision @ 1 = %.3f' % (datetime.now(), precision)) 105 | 106 | summary = tf.Summary() 107 | summary.ParseFromString(sess.run(summary_op)) 108 | summary.value.add(tag='Precision @ 1', simple_value=precision) 109 | summary_writer.add_summary(summary, global_step) 110 | except Exception as e: # pylint: disable=broad-except 111 | coord.request_stop(e) 112 | 113 | coord.request_stop() 114 | coord.join(threads, stop_grace_period_secs=10) 115 | 116 | 117 | def evaluate(): 118 | """Eval CIFAR-10 for a number of steps.""" 119 | with tf.Graph().as_default() as g: 120 | # Get images and labels for CIFAR-10. 121 | eval_data = FLAGS.eval_data == 'test' 122 | images, labels = cifar10.inputs(eval_data=eval_data) 123 | 124 | # Build a Graph that computes the logits predictions from the 125 | # inference model. 126 | logits = cifar10.inference(images) 127 | 128 | # Calculate predictions. 
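# tf.nn.in_top_k yields a [batch_size] boolean vector that is True where the
# test label matches the top-1 prediction; eval_once() sums these booleans over
# num_examples images to report precision @ 1.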
129 | top_k_op = tf.nn.in_top_k(logits, labels, 1) 130 | 131 | # Restore the moving average version of the learned variables for eval. 132 | variable_averages = tf.train.ExponentialMovingAverage( 133 | cifar10.MOVING_AVERAGE_DECAY) 134 | variables_to_restore = variable_averages.variables_to_restore() 135 | saver = tf.train.Saver(variables_to_restore) 136 | 137 | # Build the summary operation based on the TF collection of Summaries. 138 | summary_op = tf.summary.merge_all() 139 | 140 | summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g) 141 | 142 | while True: 143 | eval_once(saver, summary_writer, top_k_op, summary_op) 144 | if FLAGS.run_once: 145 | break 146 | time.sleep(FLAGS.eval_interval_secs) 147 | 148 | 149 | def main(argv=None): # pylint: disable=unused-argument 150 | cifar10.maybe_download_and_extract() 151 | if tf.gfile.Exists(FLAGS.eval_dir): 152 | tf.gfile.DeleteRecursively(FLAGS.eval_dir) 153 | tf.gfile.MakeDirs(FLAGS.eval_dir) 154 | evaluate() 155 | 156 | 157 | if __name__ == '__main__': 158 | tf.app.run() 159 | -------------------------------------------------------------------------------- /cifar10_train_varT.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A binary to train CIFAR-10 using a single GPU. 17 | 18 | Accuracy: 19 | cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of 20 | data) as judged by cifar10_eval.py. 21 | 22 | Speed: With batch_size 128. 23 | 24 | System | Step Time (sec/batch) | Accuracy 25 | ------------------------------------------------------------------ 26 | 1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours) 27 | 1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours) 28 | 29 | Usage: 30 | Please see the tutorial and website for how to download the CIFAR-10 31 | data set, compile the program and train the model. 
32 | 33 | http://tensorflow.org/tutorials/deep_cnn/ 34 | """ 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | from datetime import datetime 40 | import time 41 | import math 42 | import pickle 43 | 44 | import tensorflow as tf 45 | import numpy as np 46 | 47 | import cifar10 48 | 49 | FLAGS = tf.app.flags.FLAGS 50 | 51 | 52 | tf.app.flags.DEFINE_string('init_dir','events/cifar10_train', 53 | """Directory where to load the intializing weights""") 54 | tf.app.flags.DEFINE_string('train_dir', 'events_varT/cifar10_train', 55 | """Directory where to write event logs """ 56 | """and checkpoint.""") 57 | #tf.app.flags.DEFINE_integer('max_steps', 1000000, 58 | # """Number of batches to run.""") 59 | tf.app.flags.DEFINE_integer('max_steps', 150000, 60 | """Number of batches to run.""") 61 | tf.app.flags.DEFINE_boolean('log_device_placement', False, 62 | """Whether to log device placement.""") 63 | tf.app.flags.DEFINE_integer('log_frequency', 10, 64 | """How often to log results to the console.""") 65 | 66 | def estimation_T(anchor_set=True): 67 | """Estimation T based on a pretrained model""" 68 | with tf.Graph().as_default(): 69 | images, labels = cifar10.inputs(eval_data=False) # on the training data 70 | logits = cifar10.inference(images) 71 | pred = tf.nn.softmax(logits) 72 | 73 | variable_averages = tf.train.ExponentialMovingAverage( 74 | cifar10.MOVING_AVERAGE_DECAY) 75 | variables_to_restore = variable_averages.variables_to_restore() 76 | saver = tf.train.Saver(variables_to_restore) 77 | 78 | with tf.Session() as sess: 79 | ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir) 80 | if ckpt and ckpt.model_checkpoint_path: 81 | saver.restore(sess, ckpt.model_checkpoint_path) 82 | else: 83 | print('No checkpoint file found') 84 | return 85 | 86 | # start the queue runner 87 | coord = tf.train.Coordinator() 88 | try: 89 | threads = [] 90 | for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS): 91 | threads.extend(qr.create_threads(sess, coord=coord, daemon=True, 92 | start=True)) 93 | num_iter = int(math.ceil(cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size)) 94 | step = 0 95 | preds = [] 96 | annotations = [] 97 | while step < num_iter: 98 | #print('step: ', step) 99 | res = sess.run([pred,labels]) 100 | preds.append(res[0]) 101 | annotations.append(res[1]) 102 | step += 1 103 | 104 | except Exception as e: 105 | coord.request_stop(e) 106 | 107 | coord.request_stop() 108 | coord.join(threads, stop_grace_period_secs=10) 109 | 110 | preds = np.concatenate(preds,axis=0) 111 | #print(preds.shape) 112 | annotations = np.concatenate(annotations,axis=0) 113 | if anchor_set: 114 | indices = np.argmax(preds,axis=0) 115 | #print(indices) 116 | est_T = np.array(np.take(preds,indices,axis=0)) 117 | else: 118 | unnormal_est_T = np.zeros((cifar10.NUM_CLASSES,cifar10.NUM_CLASSES)) 119 | for i in xrange(annotations.shape[0]): 120 | label = annotations[i] 121 | unnormal_est_T[:,label] += preds[i] 122 | unnormal_est_T_sum = np.sum(unnormal_est_T,axis=1) 123 | est_T = unnormal_est_T / unnormal_est_T_sum[:,None] 124 | return est_T 125 | 126 | def train(T_est): 127 | """Train CIFAR-10 for a number of steps.""" 128 | with tf.Graph().as_default(): 129 | global_step = tf.train.get_or_create_global_step() 130 | 131 | logits_T_init = np.log(T_est + 1e-8) 132 | logits_T = tf.get_variable('logits_T',shape=[cifar10.NUM_CLASSES,cifar10.NUM_CLASSES],initializer=tf.constant_initializer(logits_T_init)) 133 | T = 
tf.nn.softmax(logits_T) 134 | 135 | # Get images and labels for CIFAR-10. 136 | # Force input pipeline to CPU:0 to avoid operations sometimes ending up on 137 | # GPU and resulting in a slow down. 138 | with tf.device('/cpu:0'): 139 | #images, labels = cifar10.distorted_inputs() 140 | images, labels, T_tru, T_mask_tru = cifar10.noisy_distorted_inputs(return_T_flag=True) 141 | 142 | # Build a Graph that computes the logits predictions from the 143 | # inference model. 144 | dropout = tf.constant(0.75) 145 | logits = cifar10.inference(images,dropout,dropout_flag=True) 146 | 147 | # softmax layer 148 | preds =tf.nn.softmax(logits) 149 | preds_aug = tf.clip_by_value(tf.matmul(preds,T), 1e-8, 1.0 - 1e-8) 150 | logits_aug = tf.log(preds_aug) 151 | 152 | # Calculate loss. 153 | loss = cifar10.loss(logits_aug, labels) 154 | 155 | # Build a Graph that trains the model with one batch of examples and 156 | # updates the model parameters. 157 | train_op, variable_averages = cifar10.train(loss, global_step, return_variable_averages=True) 158 | 159 | class _LoggerHook(tf.train.SessionRunHook): 160 | """Logs loss and runtime.""" 161 | 162 | def begin(self): 163 | self._step = -1 164 | self._start_time = time.time() 165 | 166 | def before_run(self, run_context): 167 | self._step += 1 168 | return tf.train.SessionRunArgs(loss) # Asks for loss value. 169 | 170 | def after_run(self, run_context, run_values): 171 | if self._step % FLAGS.log_frequency == 0: 172 | current_time = time.time() 173 | duration = current_time - self._start_time 174 | self._start_time = current_time 175 | 176 | loss_value = run_values.results 177 | examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration 178 | sec_per_batch = float(duration / FLAGS.log_frequency) 179 | 180 | format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 181 | 'sec/batch)') 182 | print (format_str % (datetime.now(), self._step, loss_value, 183 | examples_per_sec, sec_per_batch)) 184 | 185 | #### build scalffold for MonitoredTrainingSession to restore the variables you wish 186 | ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir) 187 | variables_to_restore = variable_averages.variables_to_restore() 188 | #print(variables_to_restore) 189 | for var_name in variables_to_restore.keys(): 190 | if ('logits_T' in var_name) or ('global_step' in var_name): 191 | del variables_to_restore[var_name] 192 | #print(variables_to_restore) 193 | 194 | init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint( 195 | ckpt.model_checkpoint_path, variables_to_restore) 196 | def InitAssignFn(scaffold,sess): 197 | sess.run(init_assign_op, init_feed_dict) 198 | 199 | scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn) 200 | 201 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45) 202 | with tf.train.MonitoredTrainingSession( 203 | checkpoint_dir = FLAGS.train_dir, 204 | scaffold = scaffold, 205 | hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), 206 | tf.train.NanTensorHook(loss), 207 | _LoggerHook()], 208 | save_checkpoint_secs=60, 209 | config=tf.ConfigProto( 210 | log_device_placement=FLAGS.log_device_placement, gpu_options=gpu_options)) as mon_sess: 211 | 212 | while not mon_sess.should_stop(): 213 | res = mon_sess.run([train_op,global_step,T_tru,T_mask_tru]) 214 | if res[1] % 1000 == 0: 215 | print('Disturbing matrix\n',res[2]) 216 | print('Masked structure\n',res[3]) 217 | 218 | 219 | def main(argv=None): # pylint: disable=unused-argument 220 | cifar10.maybe_download_and_extract() 221 | if 
tf.gfile.Exists(FLAGS.train_dir): 222 | tf.gfile.DeleteRecursively(FLAGS.train_dir) 223 | tf.gfile.MakeDirs(FLAGS.train_dir) 224 | T = estimation_T() 225 | print('estimated T \n', T) 226 | with open('varT.pkl','w') as w: 227 | pickle.dump(T,w) 228 | train(T) 229 | 230 | 231 | if __name__ == '__main__': 232 | tf.app.run() 233 | -------------------------------------------------------------------------------- /cifar10_train_T.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A binary to train CIFAR-10 using a single GPU. 17 | 18 | Accuracy: 19 | cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of 20 | data) as judged by cifar10_eval.py. 21 | 22 | Speed: With batch_size 128. 23 | 24 | System | Step Time (sec/batch) | Accuracy 25 | ------------------------------------------------------------------ 26 | 1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours) 27 | 1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours) 28 | 29 | Usage: 30 | Please see the tutorial and website for how to download the CIFAR-10 31 | data set, compile the program and train the model. 
32 | 33 | http://tensorflow.org/tutorials/deep_cnn/ 34 | """ 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | from datetime import datetime 40 | import time 41 | import math 42 | import pickle 43 | 44 | import tensorflow as tf 45 | import numpy as np 46 | 47 | import cifar10 48 | 49 | FLAGS = tf.app.flags.FLAGS 50 | 51 | 52 | tf.app.flags.DEFINE_string('init_dir','events/cifar10_train', 53 | """Directory where to load the intializing weights""") 54 | tf.app.flags.DEFINE_string('train_dir', 'events_T/cifar10_train', 55 | """Directory where to write event logs """ 56 | """and checkpoint.""") 57 | #tf.app.flags.DEFINE_integer('max_steps', 1000000, 58 | # """Number of batches to run.""") 59 | tf.app.flags.DEFINE_integer('max_steps', 150000, 60 | """Number of batches to run.""") 61 | tf.app.flags.DEFINE_boolean('log_device_placement', False, 62 | """Whether to log device placement.""") 63 | tf.app.flags.DEFINE_integer('log_frequency', 10, 64 | """How often to log results to the console.""") 65 | 66 | def estimation_T(): 67 | """Estimation T based on a pretrained model""" 68 | with tf.Graph().as_default(): 69 | images, labels = cifar10.inputs(eval_data=False) # on the training data 70 | logits = cifar10.inference(images) 71 | pred = tf.nn.softmax(logits) 72 | 73 | variable_averages = tf.train.ExponentialMovingAverage( 74 | cifar10.MOVING_AVERAGE_DECAY) 75 | variables_to_restore = variable_averages.variables_to_restore() 76 | saver = tf.train.Saver(variables_to_restore) 77 | 78 | with tf.Session() as sess: 79 | ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir) 80 | if ckpt and ckpt.model_checkpoint_path: 81 | saver.restore(sess, ckpt.model_checkpoint_path) 82 | else: 83 | print('No checkpoint file found') 84 | return 85 | 86 | # start the queue runner 87 | coord = tf.train.Coordinator() 88 | try: 89 | threads = [] 90 | for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS): 91 | threads.extend(qr.create_threads(sess, coord=coord, daemon=True, 92 | start=True)) 93 | num_iter = int(math.ceil(cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size)) 94 | step = 0 95 | preds = [] 96 | while step < num_iter: 97 | #print('step: ', step) 98 | res = sess.run(pred) 99 | preds.append(res) 100 | step += 1 101 | 102 | except Exception as e: 103 | coord.request_stop(e) 104 | 105 | coord.request_stop() 106 | coord.join(threads, stop_grace_period_secs=10) 107 | 108 | preds = np.concatenate(preds,axis=0) 109 | #print(preds.shape) 110 | indices = np.argmax(preds,axis=0) 111 | #print(indices) 112 | est_T = np.array(np.take(preds,indices,axis=0)) 113 | 114 | return est_T 115 | 116 | def loss_forward(logits, labels, T): 117 | """Define the forward noise-aware loss.""" 118 | preds =tf.nn.softmax(logits) 119 | preds_aug = tf.clip_by_value(tf.matmul(preds,T), 1e-8, 1.0 - 1e-8) 120 | logits_aug = tf.log(preds_aug) 121 | 122 | labels = tf.cast(labels, tf.int64) 123 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 124 | labels=labels, logits=logits_aug, name='cross_entropy_per_example') 125 | cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') 126 | tf.add_to_collection('losses', cross_entropy_mean) 127 | 128 | # The total loss is defined as the cross entropy loss plus all of the weight 129 | # decay terms (L2 loss). 
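# Note on the forward correction above: sparse_softmax_cross_entropy_with_logits
# applies a softmax to logits_aug, but exp(logits_aug) = preds @ T already sums
# to (approximately) one per example, so the softmax merely renormalizes it and
# the loss reduces to -log((preds @ T)[label]), i.e. the forward-corrected
# cross entropy.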
130 | return tf.add_n(tf.get_collection('losses'), name='total_loss') 131 | 132 | def loss_backward(logits, labels, T_inv): 133 | """Define the backward noise-aware loss.""" 134 | labels = tf.one_hot(labels,cifar10.NUM_CLASSES,axis=-1) 135 | labels = tf.cast(labels, tf.float32) 136 | labels_aug = tf.matmul(labels,T_inv) 137 | 138 | preds = tf.nn.softmax(logits) 139 | preds = tf.clip_by_value(preds,1e-8,1-1e-8) 140 | 141 | cross_entropy = -tf.reduce_sum(labels_aug*tf.log(preds),axis=-1,name='cross_entropy_per_example') 142 | cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') 143 | tf.add_to_collection('losses', cross_entropy_mean) 144 | 145 | # The total loss is defined as the cross entropy loss plus all of the weight 146 | # decay terms (L2 loss). 147 | return tf.add_n(tf.get_collection('losses'), name='total_loss') 148 | 149 | def train(T_est,T_inv_est): 150 | """Train CIFAR-10 for a number of steps.""" 151 | with tf.Graph().as_default(): 152 | global_step = tf.train.get_or_create_global_step() 153 | 154 | T_est = tf.constant(T_est) 155 | T_inv_est = tf.constant(T_inv_est) 156 | 157 | # Get images and labels for CIFAR-10. 158 | # Force input pipeline to CPU:0 to avoid operations sometimes ending up on 159 | # GPU and resulting in a slow down. 160 | with tf.device('/cpu:0'): 161 | #images, labels = cifar10.distorted_inputs() 162 | images, labels, T_tru, T_mask_tru = cifar10.noisy_distorted_inputs(return_T_flag=True) 163 | 164 | # Build a Graph that computes the logits predictions from the 165 | # inference model. 166 | dropout = tf.constant(0.75) 167 | logits = cifar10.inference(images, dropout, dropout_flag=True) 168 | 169 | # Calculate loss. 170 | #loss = loss_forward(logits, labels, T_est) 171 | loss = loss_forward(logits, labels, T_tru) 172 | #loss = loss_backward(logits, labels, T_inv_est) 173 | 174 | # Build a Graph that trains the model with one batch of examples and 175 | # updates the model parameters. 176 | train_op, variable_averages = cifar10.train(loss, global_step, return_variable_averages=True) 177 | 178 | class _LoggerHook(tf.train.SessionRunHook): 179 | """Logs loss and runtime.""" 180 | 181 | def begin(self): 182 | self._step = -1 183 | self._start_time = time.time() 184 | 185 | def before_run(self, run_context): 186 | self._step += 1 187 | return tf.train.SessionRunArgs(loss) # Asks for loss value. 
188 | 189 | def after_run(self, run_context, run_values): 190 | if self._step % FLAGS.log_frequency == 0: 191 | current_time = time.time() 192 | duration = current_time - self._start_time 193 | self._start_time = current_time 194 | 195 | loss_value = run_values.results 196 | examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration 197 | sec_per_batch = float(duration / FLAGS.log_frequency) 198 | 199 | format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 200 | 'sec/batch)') 201 | print (format_str % (datetime.now(), self._step, loss_value, 202 | examples_per_sec, sec_per_batch)) 203 | 204 | #### build scalffold for MonitoredTrainingSession to restore the variables you wish 205 | ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir) 206 | variables_to_restore = variable_averages.variables_to_restore() 207 | #print(variables_to_restore) 208 | for var_name in variables_to_restore.keys(): 209 | if ('logits_T' in var_name) or ('global_step' in var_name): 210 | del variables_to_restore[var_name] 211 | #print(variables_to_restore) 212 | 213 | init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint( 214 | ckpt.model_checkpoint_path, variables_to_restore) 215 | def InitAssignFn(scaffold,sess): 216 | sess.run(init_assign_op, init_feed_dict) 217 | 218 | scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn) 219 | 220 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45) 221 | with tf.train.MonitoredTrainingSession( 222 | checkpoint_dir = FLAGS.train_dir, 223 | scaffold = scaffold, 224 | hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), 225 | tf.train.NanTensorHook(loss), 226 | _LoggerHook()], 227 | save_checkpoint_secs=60, 228 | config=tf.ConfigProto( 229 | log_device_placement=FLAGS.log_device_placement, gpu_options=gpu_options)) as mon_sess: 230 | while not mon_sess.should_stop(): 231 | res = mon_sess.run([train_op,global_step,T_tru,T_mask_tru]) 232 | if res[1] % 1000 == 0: 233 | print('Disturbing matrix\n',res[2]) 234 | print('Masked structure\n',res[3]) 235 | 236 | 237 | def main(argv=None): # pylint: disable=unused-argument 238 | cifar10.maybe_download_and_extract() 239 | if tf.gfile.Exists(FLAGS.train_dir): 240 | tf.gfile.DeleteRecursively(FLAGS.train_dir) 241 | tf.gfile.MakeDirs(FLAGS.train_dir) 242 | T = estimation_T() 243 | print('estimated T \n', T) 244 | with open('T.pkl','w') as w: 245 | pickle.dump(T,w) 246 | T_inv = np.linalg.inv(T) 247 | print('estimated inverse T \n', T_inv) 248 | train(T,T_inv) 249 | 250 | if __name__ == '__main__': 251 | tf.app.run() 252 | -------------------------------------------------------------------------------- /cifar10_train_GANT.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """A binary to train CIFAR-10 using a single GPU. 17 | 18 | Accuracy: 19 | cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of 20 | data) as judged by cifar10_eval.py. 21 | 22 | Speed: With batch_size 128. 23 | 24 | System | Step Time (sec/batch) | Accuracy 25 | ------------------------------------------------------------------ 26 | 1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours) 27 | 1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours) 28 | 29 | Usage: 30 | Please see the tutorial and website for how to download the CIFAR-10 31 | data set, compile the program and train the model. 32 | 33 | http://tensorflow.org/tutorials/deep_cnn/ 34 | """ 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | from datetime import datetime 40 | import time 41 | import math 42 | import pickle 43 | 44 | import tensorflow as tf 45 | import numpy as np 46 | 47 | import cifar10 48 | 49 | slim = tf.contrib.slim 50 | Normal = tf.contrib.distributions.Normal 51 | 52 | tf.logging.set_verbosity(tf.logging.ERROR) 53 | 54 | FLAGS = tf.app.flags.FLAGS 55 | 56 | 57 | tf.app.flags.DEFINE_string('init_dir', 'events/cifar10_train', 58 | """Directory where to restore """ 59 | """from the checkpoint.""") 60 | tf.app.flags.DEFINE_string('train_dir', 'events_GANT/cifar10_train', 61 | """Directory where to write event logs """ 62 | """and checkpoint.""") 63 | #tf.app.flags.DEFINE_integer('max_steps', 1000000, 64 | # """Number of batches to run.""") 65 | tf.app.flags.DEFINE_integer('max_steps', 150000, 66 | """Number of batches to run.""") 67 | tf.app.flags.DEFINE_boolean('log_device_placement', False, 68 | """Whether to log device placement.""") 69 | tf.app.flags.DEFINE_integer('log_frequency', 10, 70 | """How often to log results to the console.""") 71 | 72 | def train(T_est): 73 | """Train CIFAR-10 for a number of steps.""" 74 | with tf.Graph().as_default(): 75 | global_step = tf.train.get_or_create_global_step() 76 | 77 | # Get images and labels for CIFAR-10. 78 | # Force input pipeline to CPU:0 to avoid operations sometimes ending up on 79 | # GPU and resulting in a slow down. 
80 | with tf.device('/cpu:0'): 81 | #images, labels = cifar10.distorted_inputs() 82 | #images, labels = cifar10.noisy_distorted_inputs() 83 | images, labels, T_tru, T_mask = cifar10.noisy_distorted_inputs(return_T_flag=True) 84 | 85 | T_est = tf.constant(T_est,dtype=tf.float32) 86 | 87 | #### Prior and groudtruth 88 | T_est = tf.tile(tf.expand_dims(T_est, 0),[FLAGS.batch_size,1,1]) 89 | T_tru = tf.tile(tf.expand_dims(T_tru, 0),[FLAGS.batch_size,1,1]) 90 | T_mask= tf.tile(tf.expand_dims(T_mask,0),[FLAGS.batch_size,1,1]) 91 | 92 | #### generator 93 | with tf.variable_scope('generator') as scope: 94 | normal = Normal(tf.zeros([1,10]),tf.ones([1,10])) 95 | epsilon = tf.to_float(normal.sample(FLAGS.batch_size)) 96 | net = slim.stack(epsilon,slim.fully_connected,[50,50]) 97 | net = slim.fully_connected(net,cifar10.NUM_CLASSES*cifar10.NUM_CLASSES,activation_fn=None) 98 | net = tf.reshape(net,[-1,cifar10.NUM_CLASSES,cifar10.NUM_CLASSES]) 99 | S = tf.nn.softmax(net) 100 | 101 | # input to discriminator 102 | S_mask = tf.sigmoid((S-0.05)/0.005) 103 | 104 | #### discriminator 105 | def discriminator(input): 106 | with tf.variable_scope('discriminator',reuse=tf.AUTO_REUSE) as scope: 107 | input = slim.flatten(input) 108 | net = slim.fully_connected(input,20,activation_fn=tf.nn.sigmoid) 109 | net = slim.fully_connected(net,1,activation_fn=None) 110 | return net 111 | D_t = discriminator(T_mask) 112 | D_s = discriminator(S_mask) 113 | 114 | #### reconstructor 115 | dropout = tf.constant(0.75) 116 | logits = cifar10.inference(images,dropout,dropout_flag=True) 117 | preds = tf.nn.softmax(logits) 118 | preds_aug = tf.reshape(tf.matmul(tf.reshape(preds,[FLAGS.batch_size,1,-1]),S),[FLAGS.batch_size,-1]) 119 | logits_aug = tf.log(tf.clip_by_value(preds_aug,1e-8,1.0-1e-8)) 120 | 121 | #### loss 122 | # R loss 123 | R_loss = cifar10.loss(logits_aug,labels) 124 | tf.summary.scalar('reconstructor loss',R_loss) 125 | 126 | # D loss 127 | D_loss = -tf.reduce_mean(D_t) + tf.reduce_mean(D_s) 128 | tf.summary.scalar('discriminator loss',D_loss) 129 | 130 | # G loss 131 | G_loss = R_loss - tf.reduce_mean(D_s) 132 | #G_loss = - tf.reduce_mean(D_s) 133 | tf.summary.scalar('generator loss',G_loss) 134 | 135 | # initialization of G 136 | S_logits = tf.log(tf.clip_by_value(S,1e-8,1.0-1e-8)) 137 | Initial_G_loss = -tf.reduce_mean(tf.reduce_sum(tf.reduce_sum(T_est*S_logits,axis=2),axis=1)) 138 | 139 | # variable list 140 | var_C = [] 141 | var_D = [] 142 | var_G = [] 143 | for item in tf.trainable_variables(): 144 | if "generator" in item.name: 145 | var_G.append(item) 146 | elif "discriminator" in item.name: 147 | var_D.append(item) 148 | else: 149 | var_C.append(item) 150 | 151 | #### optimizer 152 | # Build a Graph that trains the model with one batch of examples and 153 | # updates the model parameters. 
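# cifar10.train() below optimizes only the reconstructor/classifier variables
# (var_C); the discriminator and generator are updated with separate RMSProp
# optimizers at the shared learning rate lr_DG, and WGAN-style weight-clamping
# ops (clip_D) are constructed for the discriminator variables further down.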
154 | R_train_op, variable_averages, lr = cifar10.train(R_loss,global_step,var_C,return_variable_averages=True,return_lr=True) 155 | lr_DG = tf.constant(1e-5) 156 | D_train_op = tf.train.RMSPropOptimizer(learning_rate=lr_DG).minimize(D_loss,var_list=var_D) 157 | G_train_op = tf.train.RMSPropOptimizer(learning_rate=lr_DG).minimize(G_loss,var_list=var_G+var_C) 158 | 159 | #### optimizer for the initialization of the generator and the discriminator 160 | Initial_G_train_op = tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(Initial_G_loss,var_list=var_G) 161 | 162 | #### weight clamping for WGAN 163 | clip_D = [var.assign(tf.clip_by_value(var,-0.01,0.005)) for var in var_D] 164 | 165 | class _LoggerHook(tf.train.SessionRunHook): 166 | """Logs loss and runtime.""" 167 | 168 | def begin(self): 169 | self._my_print_flag = False 170 | self._step = -1 171 | self._start_time = time.time() 172 | 173 | def before_run(self, run_context): 174 | self._step += 1 175 | return tf.train.SessionRunArgs(R_loss) # Asks for loss value. 176 | 177 | def after_run(self, run_context, run_values): 178 | if self._step % FLAGS.log_frequency == 0: 179 | current_time = time.time() 180 | duration = current_time - self._start_time 181 | self._start_time = current_time 182 | 183 | loss_value = run_values.results 184 | examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration 185 | sec_per_batch = float(duration / FLAGS.log_frequency) 186 | 187 | if self._my_print_flag: 188 | format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 189 | 'sec/batch)') 190 | print (format_str % (datetime.now(), self._step, loss_value, 191 | examples_per_sec, sec_per_batch)) 192 | 193 | #### build scalffold for MonitoredTrainingSession to restore the variables you wish 194 | ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir) 195 | variables_to_restore = variable_averages.variables_to_restore() 196 | #print(variables_to_restore) 197 | for var_name in variables_to_restore.keys(): 198 | if ('generator' in var_name) or ('discriminator' in var_name) or ('RMSProp' in var_name) or ('global_step' in var_name): 199 | del variables_to_restore[var_name] 200 | print(variables_to_restore) 201 | 202 | init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint( 203 | ckpt.model_checkpoint_path, variables_to_restore) 204 | def InitAssignFn(scaffold,sess): 205 | sess.run(init_assign_op, init_feed_dict) 206 | 207 | scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn) 208 | 209 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45) 210 | loggerHook = _LoggerHook() 211 | with tf.train.MonitoredTrainingSession( 212 | checkpoint_dir = FLAGS.train_dir, 213 | scaffold = scaffold, 214 | hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), 215 | tf.train.NanTensorHook(R_loss), 216 | loggerHook], 217 | save_checkpoint_secs=60, 218 | config=tf.ConfigProto( 219 | log_device_placement=FLAGS.log_device_placement, gpu_options=gpu_options)) as mon_sess: 220 | 221 | #### pretrain the generator 222 | loggerHook._my_print_flag = False 223 | res = None 224 | for i in xrange(10000): 225 | res = mon_sess.run([Initial_G_train_op,Initial_G_loss,T_est,S,lr,lr_DG]) 226 | if i % 1000 == 0: 227 | print('Step: %d\tGenerator loss: %.3f'%(i,res[1])) 228 | print('Pre-estimation', res[2][0]) 229 | print('Initialization', res[3][0]) 230 | 231 | #### iteratively train G and 232 | loggerHook._my_print_flag = False 233 | step = 0 234 | step_control = 0 235 | lr_, lr_DG_ = res[-2], res[-1] 236 | while not 
mon_sess.should_stop(): 237 | # update the learning_rate of generator and discriminator to sync with the classifier 238 | if lr_DG_ >= lr_: # to avoid over-tuning the transition matrix due to the learning_rate decay 239 | lr_DG_ = lr_DG_/10.0 240 | # do the adversarial game 241 | if step >= step_control: 242 | res = mon_sess.run([G_train_op,G_loss,T_est,S,T_tru,S_mask,T_mask],feed_dict={lr_DG:lr_DG_}) 243 | g_loss = res[1] 244 | 245 | for i in xrange(5): 246 | _, d_loss = mon_sess.run([D_train_op,D_loss],feed_dict={lr_DG:lr_DG_}) 247 | 248 | # train the classifier 249 | _, r_loss, g_step, lr_, lr_DG_ = mon_sess.run([R_train_op,R_loss,global_step,lr,lr_DG],feed_dict={lr_DG:lr_DG_}) 250 | 251 | if step >= step_control: 252 | print('Step: %d\tR_loss: %.3f\tD_loss: %.3f\tG_loss: %.3f' % (g_step, r_loss, d_loss, g_loss)) 253 | 254 | if (g_step % 2000 == 0) or (g_step == FLAGS.max_steps-1): 255 | print('Pre-estimation', res[2][0]) 256 | print('Generated sample', res[3][0]) 257 | print('True transition', res[4][0]) 258 | print('Generated structure',res[5][0]) 259 | print('True structure',res[6][0]) 260 | else: 261 | print('Step: %d\tR_loss: %.3f' % (g_step, r_loss)) 262 | 263 | step = g_step 264 | 265 | def main(argv=None): # pylint: disable=unused-argument 266 | cifar10.maybe_download_and_extract() 267 | if tf.gfile.Exists(FLAGS.train_dir): 268 | tf.gfile.DeleteRecursively(FLAGS.train_dir) 269 | tf.gfile.MakeDirs(FLAGS.train_dir) 270 | with open('T.pkl') as f: 271 | T = pickle.load(f) 272 | print('estimated confusion matrix\n',T) 273 | train(T) 274 | 275 | if __name__ == '__main__': 276 | tf.app.run() 277 | -------------------------------------------------------------------------------- /cifar10_input.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Routine for decoding the CIFAR-10 binary file format.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import numpy as np 24 | 25 | from six.moves import xrange # pylint: disable=redefined-builtin 26 | import tensorflow as tf 27 | 28 | # Process images of this size. Note that this differs from the original CIFAR 29 | # image size of 32 x 32. If one alters this number, then the entire model 30 | # architecture will change and any model would need to be retrained. 31 | IMAGE_SIZE = 24 32 | 33 | # Global constants describing the CIFAR-10 data set. 34 | NUM_CLASSES = 10 35 | NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000 36 | NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000 37 | 38 | # noisy proportion 39 | #NOISE_TYPE = 'tridiagonal' 40 | NOISE_TYPE = 'column' 41 | 42 | def read_cifar10(filename_queue): 43 | """Reads and parses examples from CIFAR10 data files. 
44 | 45 | Recommendation: if you want N-way read parallelism, call this function 46 | N times. This will give you N independent Readers reading different 47 | files & positions within those files, which will give better mixing of 48 | examples. 49 | 50 | Args: 51 | filename_queue: A queue of strings with the filenames to read from. 52 | 53 | Returns: 54 | An object representing a single example, with the following fields: 55 | height: number of rows in the result (32) 56 | width: number of columns in the result (32) 57 | depth: number of color channels in the result (3) 58 | key: a scalar string Tensor describing the filename & record number 59 | for this example. 60 | label: an int32 Tensor with the label in the range 0..9. 61 | uint8image: a [height, width, depth] uint8 Tensor with the image data 62 | """ 63 | 64 | class CIFAR10Record(object): 65 | pass 66 | result = CIFAR10Record() 67 | 68 | # Dimensions of the images in the CIFAR-10 dataset. 69 | # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the 70 | # input format. 71 | label_bytes = 1 # 2 for CIFAR-100 72 | result.height = 32 73 | result.width = 32 74 | result.depth = 3 75 | image_bytes = result.height * result.width * result.depth 76 | # Every record consists of a label followed by the image, with a 77 | # fixed number of bytes for each. 78 | record_bytes = label_bytes + image_bytes 79 | 80 | # Read a record, getting filenames from the filename_queue. No 81 | # header or footer in the CIFAR-10 format, so we leave header_bytes 82 | # and footer_bytes at their default of 0. 83 | reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) 84 | result.key, value = reader.read(filename_queue) 85 | 86 | # Convert from a string to a vector of uint8 that is record_bytes long. 87 | record_bytes = tf.decode_raw(value, tf.uint8) 88 | 89 | # The first bytes represent the label, which we convert from uint8->int32. 90 | result.label = tf.cast( 91 | tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32) 92 | 93 | # The remaining bytes after the label represent the image, which we reshape 94 | # from [depth * height * width] to [depth, height, width]. 95 | depth_major = tf.reshape( 96 | tf.strided_slice(record_bytes, [label_bytes], 97 | [label_bytes + image_bytes]), 98 | [result.depth, result.height, result.width]) 99 | # Convert from [depth, height, width] to [height, width, depth]. 100 | result.uint8image = tf.transpose(depth_major, [1, 2, 0]) 101 | 102 | return result 103 | 104 | 105 | def _generate_image_and_label_batch(image, label, min_queue_examples, 106 | batch_size, shuffle): 107 | """Construct a queued batch of images and labels. 108 | 109 | Args: 110 | image: 3-D Tensor of [height, width, 3] of type.float32. 111 | label: 1-D Tensor of type.int32 112 | min_queue_examples: int32, minimum number of samples to retain 113 | in the queue that provides of batches of examples. 114 | batch_size: Number of images per batch. 115 | shuffle: boolean indicating whether to use a shuffling queue. 116 | 117 | Returns: 118 | images: Images. 4D tensor of [batch_size, height, width, 3] size. 119 | labels: Labels. 1D tensor of [batch_size] size. 120 | """ 121 | # Create a queue that shuffles the examples, and then 122 | # read 'batch_size' images + labels from the example queue. 
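# tf.train.shuffle_batch dequeues from a RandomShuffleQueue: min_after_dequeue
# sets how many examples stay buffered (and hence how well the stream is mixed),
# while capacity bounds the queue's memory footprint; tf.train.batch performs
# the same batching without shuffling.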
123 | num_preprocess_threads = 16 124 | if shuffle: 125 | images, label_batch = tf.train.shuffle_batch( 126 | [image, label], 127 | batch_size=batch_size, 128 | num_threads=num_preprocess_threads, 129 | capacity=min_queue_examples + 3 * batch_size, 130 | min_after_dequeue=min_queue_examples) 131 | else: 132 | images, label_batch = tf.train.batch( 133 | [image, label], 134 | batch_size=batch_size, 135 | num_threads=num_preprocess_threads, 136 | capacity=min_queue_examples + 3 * batch_size) 137 | 138 | # Display the training images in the visualizer. 139 | tf.summary.image('images', images) 140 | 141 | return images, tf.reshape(label_batch, [batch_size]) 142 | 143 | 144 | def distorted_inputs(data_dir, batch_size, noise_type=None): 145 | """Construct distorted input for CIFAR training using the Reader ops. 146 | 147 | Args: 148 | data_dir: Path to the CIFAR-10 data directory. 149 | batch_size: Number of images per batch. 150 | 151 | Returns: 152 | images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. 153 | labels: Labels. 1D tensor of [batch_size] size. 154 | """ 155 | if noise_type: 156 | filenames = [os.path.join(data_dir, 'data_batch_%d_%s.bin' % (i,noise_type)) 157 | for i in xrange(1, 6)] 158 | else: 159 | filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) 160 | for i in xrange(1, 6)] 161 | for f in filenames: 162 | if not tf.gfile.Exists(f): 163 | raise ValueError('Failed to find file: ' + f) 164 | 165 | # Create a queue that produces the filenames to read. 166 | filename_queue = tf.train.string_input_producer(filenames) 167 | 168 | with tf.name_scope('data_augmentation'): 169 | # Read examples from files in the filename queue. 170 | read_input = read_cifar10(filename_queue) 171 | reshaped_image = tf.cast(read_input.uint8image, tf.float32) 172 | 173 | height = IMAGE_SIZE 174 | width = IMAGE_SIZE 175 | 176 | # Image processing for training the network. Note the many random 177 | # distortions applied to the image. 178 | 179 | # Randomly crop a [height, width] section of the image. 180 | distorted_image = tf.random_crop(reshaped_image, [height, width, 3]) 181 | 182 | # Randomly flip the image horizontally. 183 | distorted_image = tf.image.random_flip_left_right(distorted_image) 184 | 185 | # Because these operations are not commutative, consider randomizing 186 | # the order their operation. 187 | # NOTE: since per_image_standardization zeros the mean and makes 188 | # the stddev unit, this likely has no effect see tensorflow#1458. 189 | distorted_image = tf.image.random_brightness(distorted_image, 190 | max_delta=63) 191 | distorted_image = tf.image.random_contrast(distorted_image, 192 | lower=0.2, upper=1.8) 193 | 194 | # Subtract off the mean and divide by the variance of the pixels. 195 | float_image = tf.image.per_image_standardization(distorted_image) 196 | 197 | # Set the shapes of tensors. 198 | float_image.set_shape([height, width, 3]) 199 | read_input.label.set_shape([1]) 200 | 201 | # Ensure that the random shuffling has good mixing properties. 202 | min_fraction_of_examples_in_queue = 0.4 203 | min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 204 | min_fraction_of_examples_in_queue) 205 | print ('Filling queue with %d CIFAR images before starting to train. ' 206 | 'This will take a few minutes.' % min_queue_examples) 207 | 208 | # Generate a batch of images and labels by building up a queue of examples. 
209 | return _generate_image_and_label_batch(float_image, read_input.label, 210 | min_queue_examples, batch_size, 211 | shuffle=True) 212 | 213 | def noisy_distorted_inputs(data_dir, batch_size, return_T_flag): 214 | """Construct noisy distorted input for CIFAR training using the Reader ops. 215 | 216 | Args: 217 | data_dir: Path to the CIFAR-10 data directory. 218 | batch_size: Number of images per batch. 219 | 220 | Returns: 221 | images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. 222 | labels: Labels. 1D tensor of [batch_size] size. 223 | """ 224 | import noise 225 | if NOISE_TYPE == 'tridiagonal': 226 | T = noise.gen_noise_tridiagonal() 227 | elif NOISE_TYPE == 'column': 228 | T = noise.gen_noise_column() 229 | else: 230 | raise ValueError("NOISE TYPE is not properly set") 231 | 232 | T_zero_mask = np.ones((10,10)) 233 | T_zero_mask[np.equal(T,0)] = 0.0 234 | 235 | images, labels = distorted_inputs(data_dir, batch_size, NOISE_TYPE) 236 | 237 | if return_T_flag: 238 | return images, labels, tf.constant(T,dtype=tf.float32), tf.constant(T_zero_mask,dtype=tf.float32) 239 | else: 240 | return images, labels 241 | 242 | def inputs(eval_data, data_dir, batch_size): 243 | """Construct input for CIFAR evaluation using the Reader ops. 244 | 245 | Args: 246 | eval_data: bool, indicating if one should use the train or eval data set. 247 | data_dir: Path to the CIFAR-10 data directory. 248 | batch_size: Number of images per batch. 249 | 250 | Returns: 251 | images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. 252 | labels: Labels. 1D tensor of [batch_size] size. 253 | """ 254 | if not eval_data: 255 | filenames = [os.path.join(data_dir, 'data_batch_%d_%s.bin' % (i,NOISE_TYPE)) 256 | for i in xrange(1, 6)] 257 | num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN 258 | else: 259 | filenames = [os.path.join(data_dir, 'test_batch.bin')] 260 | num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL 261 | 262 | for f in filenames: 263 | if not tf.gfile.Exists(f): 264 | raise ValueError('Failed to find file: ' + f) 265 | 266 | with tf.name_scope('input'): 267 | # Create a queue that produces the filenames to read. 268 | filename_queue = tf.train.string_input_producer(filenames) 269 | 270 | # Read examples from files in the filename queue. 271 | read_input = read_cifar10(filename_queue) 272 | reshaped_image = tf.cast(read_input.uint8image, tf.float32) 273 | 274 | height = IMAGE_SIZE 275 | width = IMAGE_SIZE 276 | 277 | # Image processing for evaluation. 278 | # Crop the central [height, width] of the image. 279 | resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, 280 | height, width) 281 | 282 | # Subtract off the mean and divide by the variance of the pixels. 283 | float_image = tf.image.per_image_standardization(resized_image) 284 | 285 | # Set the shapes of tensors. 286 | float_image.set_shape([height, width, 3]) 287 | read_input.label.set_shape([1]) 288 | 289 | # Ensure that the random shuffling has good mixing properties. 290 | min_fraction_of_examples_in_queue = 0.4 291 | min_queue_examples = int(num_examples_per_epoch * 292 | min_fraction_of_examples_in_queue) 293 | 294 | # Generate a batch of images and labels by building up a queue of examples. 
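    # For eval_data=False this reads the noisy training files selected by
    # NOISE_TYPE above, whereas the test set is always the clean test_batch.bin;
    # batches are not shuffled here (shuffle=False), so evaluation order is fixed.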
295 | return _generate_image_and_label_batch(float_image, read_input.label, 296 | min_queue_examples, batch_size, 297 | shuffle=False) 298 | -------------------------------------------------------------------------------- /cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Builds the CIFAR-10 network. 17 | 18 | Summary of available functions: 19 | 20 | # Compute input images and labels for training. If you would like to run 21 | # evaluations, use inputs() instead. 22 | inputs, labels = distorted_inputs() 23 | 24 | # Compute inference on the model inputs to make a prediction. 25 | predictions = inference(inputs) 26 | 27 | # Compute the total loss of the prediction with respect to the labels. 28 | loss = loss(predictions, labels) 29 | 30 | # Create a graph to run one step of training with respect to the loss. 31 | train_op = train(loss, global_step) 32 | """ 33 | # pylint: disable=missing-docstring 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | 38 | import os 39 | import re 40 | import sys 41 | import tarfile 42 | 43 | from six.moves import urllib 44 | import tensorflow as tf 45 | 46 | import cifar10_input 47 | 48 | FLAGS = tf.app.flags.FLAGS 49 | 50 | # Basic model parameters. 51 | tf.app.flags.DEFINE_integer('batch_size', 128, 52 | """Number of images to process in a batch.""") 53 | tf.app.flags.DEFINE_string('data_dir', 'data/cifar10', 54 | """Path to the CIFAR-10 data directory.""") 55 | tf.app.flags.DEFINE_boolean('use_fp16', False, 56 | """Train the model using fp16.""") 57 | 58 | # Global constants describing the CIFAR-10 data set. 59 | IMAGE_SIZE = cifar10_input.IMAGE_SIZE 60 | NUM_CLASSES = cifar10_input.NUM_CLASSES 61 | NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN 62 | NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL 63 | 64 | 65 | # Constants describing the training process. 66 | MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average. 67 | #NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays. 68 | NUM_EPOCHS_PER_DECAY = 50.0 # Epochs after which learning rate decays. 69 | LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor. 70 | INITIAL_LEARNING_RATE = 0.1 # Initial learning rate. 71 | 72 | # If a model is trained with multiple GPUs, prefix all Op names with tower_name 73 | # to differentiate the operations. Note that this prefix is removed from the 74 | # names of the summaries when visualizing a model. 75 | TOWER_NAME = 'tower' 76 | 77 | DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz' 78 | 79 | 80 | def _activation_summary(x): 81 | """Helper to create summaries for activations. 
82 | 83 | Creates a summary that provides a histogram of activations. 84 | Creates a summary that measures the sparsity of activations. 85 | 86 | Args: 87 | x: Tensor 88 | Returns: 89 | nothing 90 | """ 91 | # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training 92 | # session. This helps the clarity of presentation on tensorboard. 93 | tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name) 94 | tf.summary.histogram(tensor_name + '/activations', x) 95 | tf.summary.scalar(tensor_name + '/sparsity', 96 | tf.nn.zero_fraction(x)) 97 | 98 | 99 | def _variable_on_cpu(name, shape, initializer): 100 | """Helper to create a Variable stored on CPU memory. 101 | 102 | Args: 103 | name: name of the variable 104 | shape: list of ints 105 | initializer: initializer for Variable 106 | 107 | Returns: 108 | Variable Tensor 109 | """ 110 | with tf.device('/cpu:0'): 111 | dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 112 | var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype) 113 | return var 114 | 115 | 116 | def _variable_with_weight_decay(name, shape, stddev, wd): 117 | """Helper to create an initialized Variable with weight decay. 118 | 119 | Note that the Variable is initialized with a truncated normal distribution. 120 | A weight decay is added only if one is specified. 121 | 122 | Args: 123 | name: name of the variable 124 | shape: list of ints 125 | stddev: standard deviation of a truncated Gaussian 126 | wd: add L2Loss weight decay multiplied by this float. If None, weight 127 | decay is not added for this Variable. 128 | 129 | Returns: 130 | Variable Tensor 131 | """ 132 | dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 133 | var = _variable_on_cpu( 134 | name, 135 | shape, 136 | tf.truncated_normal_initializer(stddev=stddev, dtype=dtype)) 137 | if wd is not None: 138 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') 139 | tf.add_to_collection('losses', weight_decay) 140 | return var 141 | 142 | 143 | def distorted_inputs(): 144 | """Construct distorted input for CIFAR training using the Reader ops. 145 | 146 | Returns: 147 | images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. 148 | labels: Labels. 1D tensor of [batch_size] size. 149 | 150 | Raises: 151 | ValueError: If no data_dir 152 | """ 153 | if not FLAGS.data_dir: 154 | raise ValueError('Please supply a data_dir') 155 | data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin') 156 | images, labels = cifar10_input.distorted_inputs(data_dir=data_dir, 157 | batch_size=FLAGS.batch_size) 158 | if FLAGS.use_fp16: 159 | images = tf.cast(images, tf.float16) 160 | labels = tf.cast(labels, tf.float16) 161 | return images, labels 162 | 163 | def noisy_distorted_inputs(return_T_flag=False): 164 | """Construct distorted input for CIFAR training using the Reader ops. 165 | 166 | Returns: 167 | images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. 168 | labels: Labels. 1D tensor of [batch_size] size. 
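    T, T_mask: the [10, 10] float32 noise transition matrix and its non-zero
      mask, returned only when return_T_flag is True.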
169 | 170 | Raises: 171 | ValueError: If no data_dir 172 | """ 173 | if not FLAGS.data_dir: 174 | raise ValueError('Please supply a data_dir') 175 | data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin') 176 | if return_T_flag: 177 | images, labels, T, T_mask = cifar10_input.noisy_distorted_inputs(data_dir=data_dir, 178 | batch_size=FLAGS.batch_size,return_T_flag=return_T_flag) 179 | else: 180 | images, labels = cifar10_input.noisy_distorted_inputs(data_dir=data_dir, 181 | batch_size=FLAGS.batch_size,return_T_flag=return_T_flag) 182 | if FLAGS.use_fp16: 183 | images = tf.cast(images, tf.float16) 184 | labels = tf.cast(labels, tf.float16) 185 | 186 | if return_T_flag: 187 | return images, labels, T, T_mask 188 | else: 189 | return images, labels 190 | 191 | def inputs(eval_data,batch_size=FLAGS.batch_size): 192 | """Construct input for CIFAR evaluation using the Reader ops. 193 | 194 | Args: 195 | eval_data: bool, indicating if one should use the train or eval data set. 196 | 197 | Returns: 198 | images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. 199 | labels: Labels. 1D tensor of [batch_size] size. 200 | 201 | Raises: 202 | ValueError: If no data_dir 203 | """ 204 | if not FLAGS.data_dir: 205 | raise ValueError('Please supply a data_dir') 206 | data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin') 207 | images, labels = cifar10_input.inputs(eval_data=eval_data, 208 | data_dir=data_dir, 209 | batch_size=batch_size) 210 | if FLAGS.use_fp16: 211 | images = tf.cast(images, tf.float16) 212 | labels = tf.cast(labels, tf.float16) 213 | return images, labels 214 | 215 | 216 | def inference(images,dropout=None,dropout_flag=None): 217 | """Build the CIFAR-10 model. 218 | 219 | Args: 220 | images: Images returned from distorted_inputs() or inputs(). 221 | 222 | Returns: 223 | Logits. 224 | """ 225 | # We instantiate all variables using tf.get_variable() instead of 226 | # tf.Variable() in order to share variables across multiple GPU training runs. 227 | # If we only ran this model on a single GPU, we could simplify this function 228 | # by replacing all instances of tf.get_variable() with tf.Variable(). 
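  #
  # Layer overview (matches the code below): conv1 -> pool1 -> norm1 -> conv2 ->
  # norm2 -> pool2 -> local3 -> local4 -> softmax_linear, where local3/local4 are
  # fully connected ReLU layers with optional dropout and softmax_linear emits
  # unscaled logits of size NUM_CLASSES.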
229 | # 230 | # conv1 231 | with tf.variable_scope('conv1') as scope: 232 | kernel = _variable_with_weight_decay('weights', 233 | shape=[5, 5, 3, 64], 234 | stddev=5e-2, 235 | wd=None) 236 | conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME') 237 | biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0)) 238 | pre_activation = tf.nn.bias_add(conv, biases) 239 | conv1 = tf.nn.relu(pre_activation, name=scope.name) 240 | _activation_summary(conv1) 241 | 242 | # pool1 243 | pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], 244 | padding='SAME', name='pool1') 245 | # norm1 246 | norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, 247 | name='norm1') 248 | 249 | # conv2 250 | with tf.variable_scope('conv2') as scope: 251 | kernel = _variable_with_weight_decay('weights', 252 | shape=[5, 5, 64, 64], 253 | stddev=5e-2, 254 | wd=None) 255 | conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME') 256 | biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1)) 257 | pre_activation = tf.nn.bias_add(conv, biases) 258 | conv2 = tf.nn.relu(pre_activation, name=scope.name) 259 | _activation_summary(conv2) 260 | 261 | # norm2 262 | norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, 263 | name='norm2') 264 | # pool2 265 | pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], 266 | strides=[1, 2, 2, 1], padding='SAME', name='pool2') 267 | 268 | # local3 269 | with tf.variable_scope('local3') as scope: 270 | # Move everything into depth so we can perform a single matrix multiply. 271 | reshape = tf.reshape(pool2, [FLAGS.batch_size, -1]) 272 | dim = reshape.get_shape()[1].value 273 | weights = _variable_with_weight_decay('weights', shape=[dim, 384], 274 | stddev=0.04, wd=0.004) 275 | biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1)) 276 | local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name) 277 | if dropout_flag: 278 | local3 = tf.nn.dropout(local3,dropout) 279 | _activation_summary(local3) 280 | 281 | # local4 282 | with tf.variable_scope('local4') as scope: 283 | weights = _variable_with_weight_decay('weights', shape=[384, 192], 284 | stddev=0.04, wd=0.004) 285 | biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1)) 286 | local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name=scope.name) 287 | if dropout_flag: 288 | local4 = tf.nn.dropout(local4,dropout) 289 | _activation_summary(local4) 290 | 291 | # linear layer(WX + b), 292 | # We don't apply softmax here because 293 | # tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits 294 | # and performs the softmax internally for efficiency. 295 | with tf.variable_scope('softmax_linear') as scope: 296 | weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES], 297 | stddev=1/192.0, wd=None) 298 | biases = _variable_on_cpu('biases', [NUM_CLASSES], 299 | tf.constant_initializer(0.0)) 300 | softmax_linear = tf.add(tf.matmul(local4, weights), biases, name=scope.name) 301 | _activation_summary(softmax_linear) 302 | 303 | return softmax_linear 304 | 305 | 306 | def loss(logits, labels): 307 | """Add L2Loss to all the trainable variables. 308 | 309 | Add summary for "Loss" and "Loss/avg". 310 | Args: 311 | logits: Logits from inference(). 312 | labels: Labels from distorted_inputs or inputs(). 1-D tensor 313 | of shape [batch_size] 314 | 315 | Returns: 316 | Loss tensor of type float. 317 | """ 318 | # Calculate the average cross entropy loss across the batch. 
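  # The returned total_loss also includes the weight-decay terms that
  # _variable_with_weight_decay added to the 'losses' collection (wd=0.004 on
  # the local3/local4 weights), via the tf.add_n over that collection below.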
319 | labels = tf.cast(labels, tf.int64) 320 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 321 | labels=labels, logits=logits, name='cross_entropy_per_example') 322 | cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') 323 | tf.add_to_collection('losses', cross_entropy_mean) 324 | 325 | # The total loss is defined as the cross entropy loss plus all of the weight 326 | # decay terms (L2 loss). 327 | return tf.add_n(tf.get_collection('losses'), name='total_loss') 328 | 329 | 330 | def _add_loss_summaries(total_loss): 331 | """Add summaries for losses in CIFAR-10 model. 332 | 333 | Generates moving average for all losses and associated summaries for 334 | visualizing the performance of the network. 335 | 336 | Args: 337 | total_loss: Total loss from loss(). 338 | Returns: 339 | loss_averages_op: op for generating moving averages of losses. 340 | """ 341 | # Compute the moving average of all individual losses and the total loss. 342 | loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg') 343 | losses = tf.get_collection('losses') 344 | loss_averages_op = loss_averages.apply(losses + [total_loss]) 345 | 346 | # Attach a scalar summary to all individual losses and the total loss; do the 347 | # same for the averaged version of the losses. 348 | for l in losses + [total_loss]: 349 | # Name each loss as '(raw)' and name the moving average version of the loss 350 | # as the original loss name. 351 | tf.summary.scalar(l.op.name + ' (raw)', l) 352 | tf.summary.scalar(l.op.name, loss_averages.average(l)) 353 | 354 | return loss_averages_op 355 | 356 | 357 | def train(total_loss, global_step, variable_list=None, return_variable_averages=False, return_lr=False): 358 | """Train CIFAR-10 model. 359 | 360 | Create an optimizer and apply to all trainable variables. Add moving 361 | average for all trainable variables. 362 | 363 | Args: 364 | total_loss: Total loss from loss(). 365 | global_step: Integer Variable counting the number of training steps 366 | processed. 367 | Returns: 368 | train_op: op for training. 369 | """ 370 | # Variables that affect learning rate. 371 | num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size 372 | decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY) 373 | 374 | #Decay the learning rate exponentially based on the number of steps. 375 | lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE, 376 | global_step, 377 | decay_steps, 378 | LEARNING_RATE_DECAY_FACTOR, 379 | staircase=True) 380 | lr = tf.maximum(lr,1e-4) # this is to balance its effect compared with other learning rates of other optimizers 381 | tf.summary.scalar('learning_rate', lr) 382 | 383 | # Generate moving averages of all losses and associated summaries. 384 | loss_averages_op = _add_loss_summaries(total_loss) 385 | 386 | # Compute gradients. 387 | with tf.control_dependencies([loss_averages_op]): 388 | opt = tf.train.GradientDescentOptimizer(lr) 389 | grads = opt.compute_gradients(total_loss) 390 | 391 | capped_grads = [] 392 | if variable_list: 393 | for item in grads: 394 | if item[1] in variable_list: 395 | capped_grads.append(item) 396 | else: 397 | pass 398 | else: 399 | capped_grads = grads 400 | 401 | # Apply gradients. 402 | apply_gradient_op = opt.apply_gradients(capped_grads, global_step=global_step) 403 | 404 | # Add histograms for trainable variables. 405 | for var in tf.trainable_variables(): 406 | tf.summary.histogram(var.op.name, var) 407 | 408 | # Add histograms for gradients. 
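  # Histograms cover all computed gradients (grads); when a variable_list is
  # given, only the filtered capped_grads were actually applied above.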
409 | for grad, var in grads: 410 | if grad is not None: 411 | tf.summary.histogram(var.op.name + '/gradients', grad) 412 | 413 | # Track the moving averages of all trainable variables. 414 | variable_averages = tf.train.ExponentialMovingAverage( 415 | MOVING_AVERAGE_DECAY, global_step) 416 | if variable_list: 417 | variables_averages_op = variable_averages.apply(variable_list) 418 | else: 419 | variables_averages_op = variable_averages.apply(tf.trainable_variables()) 420 | 421 | with tf.control_dependencies([apply_gradient_op, variables_averages_op]): 422 | train_op = tf.no_op(name='train') 423 | 424 | if return_variable_averages and return_lr: 425 | return train_op, variable_averages, lr 426 | elif return_variable_averages: 427 | return train_op, variable_averages 428 | else: 429 | return train_op 430 | 431 | def maybe_download_and_extract(): 432 | """Download and extract the tarball from Alex's website.""" 433 | dest_directory = FLAGS.data_dir 434 | if not os.path.exists(dest_directory): 435 | os.makedirs(dest_directory) 436 | filename = DATA_URL.split('/')[-1] 437 | filepath = os.path.join(dest_directory, filename) 438 | if not os.path.exists(filepath): 439 | def _progress(count, block_size, total_size): 440 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, 441 | float(count * block_size) / float(total_size) * 100.0)) 442 | sys.stdout.flush() 443 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 444 | print() 445 | statinfo = os.stat(filepath) 446 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 447 | extracted_dir_path = os.path.join(dest_directory, 'cifar-10-batches-bin') 448 | if not os.path.exists(extracted_dir_path): 449 | tarfile.open(filepath, 'r:gz').extractall(dest_directory) 450 | --------------------------------------------------------------------------------