├── .gitignore
├── README.md
├── code
│   ├── ae
│   │   ├── __init__.py
│   │   ├── autoencoder.py
│   │   ├── autoencoder_test.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── data.py
│   │       ├── eval.py
│   │       ├── flags.py
│   │       ├── start_tensorboard.py
│   │       └── utils.py
│   ├── requirements.txt
│   └── run.py
├── cpu
│   └── Dockerfile
├── filters_1.png
├── setup_linux
├── setup_mac
└── tb_hist.png

/.gitignore:
--------------------------------------------------------------------------------
.idea/
nohup.out
*checkpoint*
*.pyc
tbpid
venv/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep Autoencoder with TensorFlow
![A selection of first layer weight filters learned during the pretraining](filters_1.png)

## Introduction
The purpose of this repo is to explore the functionality of Google's recently open-sourced
"software library for numerical computation using data flow graphs",
[TensorFlow](https://www.tensorflow.org/). We use the library to train
a deep autoencoder on the MNIST digit data set. For background and a similar implementation using
[Theano](http://deeplearning.net/software/theano/) see the tutorial at [http://www.deeplearning.net/tutorial/SdA.html](http://www.deeplearning.net/tutorial/SdA.html).

The main training code can be found in [autoencoder.py](https://github.com/cmgreen210/TensorFlowDeepAutoencoder/blob/master/code/ae/autoencoder.py) along with the AutoEncoder class that creates and manages the Variables and Tensors used.

## Docker Setup (CPU version only for the time being)
In order to avoid platform issues it's highly encouraged that you run
the example code in a [Docker](https://www.docker.com/) container. Follow
the Docker installation instructions on the website. Then run:

```bash
$ git clone https://github.com/cmgreen210/TensorFlowDeepAutoencoder
$ cd TensorFlowDeepAutoencoder
$ docker build -t tfdae -f cpu/Dockerfile .
$ docker run -it -p 80:6006 tfdae python run.py
```

Navigate to http://localhost:80
to explore [TensorBoard](https://www.tensorflow.org/versions/master/how_tos/summaries_and_tensorboard/index.html#tensorboard-visualizing-learning) and view the training progress.
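
The `-p 80:6006` flag publishes TensorBoard's port 6006 inside the container on host port 80. If port 80 is already in use on your machine, any free host port works the same way, for example:

```bash
$ docker run -it -p 8080:6006 tfdae python run.py
```

and then browse to http://localhost:8080 instead.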
![View of TensorBoard's display of weight and bias parameter progress](tb_hist.png)

## Customizing
You can play around with the run options, including the neural net size and shape, input corruption, learning rates, etc.
in [flags.py](https://github.com/cmgreen210/TensorFlowDeepAutoencoder/blob/master/code/ae/utils/flags.py).

## Old Setup
Python 2.7 is expected to be installed and set as your default Python version.
### Ubuntu/Linux
```bash
$ git clone https://github.com/cmgreen210/TensorFlowDeepAutoencoder
$ cd TensorFlowDeepAutoencoder
$ sudo chmod +x setup_linux
$ sudo ./setup_linux  # Specify -g or --gpu if you want the GPU version
$ source venv/bin/activate
```
### Mac OS X
```bash
$ git clone https://github.com/cmgreen210/TensorFlowDeepAutoencoder
$ cd TensorFlowDeepAutoencoder
$ sudo chmod +x setup_mac
$ sudo ./setup_mac
$ source venv/bin/activate
```
## Run
To run the default example execute the following command.
NOTE: this will take a very long time if you are running on a CPU as opposed to a GPU.
```bash
$ python code/run.py
```

Navigate to http://localhost:6006
to explore [TensorBoard](https://www.tensorflow.org/versions/master/how_tos/summaries_and_tensorboard/index.html#tensorboard-visualizing-learning) and view training progress.
![View of TensorBoard's display of weight and bias parameter progress](tb_hist.png)
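
The Customizing section above applies to this setup as well. Assuming the options in flags.py are declared with `tf.app.flags` (the file itself is not shown in this dump), each one can usually be overridden straight from the command line. The flag names below are taken from their usage in `autoencoder.py`; the values are illustrative only:

```bash
$ python code/run.py --pretraining_epochs 5 --finetuning_epochs 5 --batch_size 100
```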
--------------------------------------------------------------------------------
/code/ae/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmgreen210/TensorFlowDeepAutoencoder/5298ec437689ba7ecb59229599141549ef6a6a1d/code/ae/__init__.py
--------------------------------------------------------------------------------
/code/ae/autoencoder.py:
--------------------------------------------------------------------------------
from __future__ import division
from __future__ import print_function
import time
from os.path import join as pjoin

import numpy as np
import tensorflow as tf
from utils.data import fill_feed_dict_ae, read_data_sets_pretraining
from utils.data import read_data_sets, fill_feed_dict
from utils.flags import FLAGS
from utils.eval import loss_supervised, evaluation, do_eval_summary
from utils.utils import tile_raster_images


class AutoEncoder(object):
  """Generic deep autoencoder.

  Autoencoder used for full training cycle, including
  unsupervised pretraining layers and final fine tuning.
  The user specifies the structure of the neural net
  by specifying number of inputs, the number of hidden
  units for each layer and the number of final output
  logits.
  """
  _weights_str = "weights{0}"
  _biases_str = "biases{0}"

  def __init__(self, shape, sess):
    """Autoencoder initializer

    Args:
      shape: list of ints specifying
             num input, hidden1 units,...hidden_n units, num logits
      sess: tensorflow session object to use
    """
    self.__shape = shape  # [input_dim, hidden1_dim, ..., hidden_n_dim, output_dim]
    self.__num_hidden_layers = len(self.__shape) - 2

    self.__variables = {}
    self.__sess = sess

    self._setup_variables()

  @property
  def shape(self):
    return self.__shape

  @property
  def num_hidden_layers(self):
    return self.__num_hidden_layers

  @property
  def session(self):
    return self.__sess

  def __getitem__(self, item):
    """Get autoencoder tf variable

    Returns the specified variable created by this object.
    Names are weights#, biases#, biases#_out, weights#_fixed,
    biases#_fixed.

    Args:
      item: string, variables internal name
    Returns:
      Tensorflow variable
    """
    return self.__variables[item]
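
  # Illustrative example (added comment; not in the original source): for
  # shape [10, 20, 2], ae["weights1"] is the 10x20 input-to-hidden matrix,
  # ae["biases1"] its bias, ae["biases1_out"] the layer-1 reconstruction bias
  # used only during pretraining, and ae["weights1_fixed"]/ae["biases1_fixed"]
  # the non-trainable copies that feed inputs to later pretraining stages.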

  def __setitem__(self, key, value):
    """Store a tensorflow variable

    NOTE: Don't call this explicitly. It should
    be used only internally when setting up
    variables.

    Args:
      key: string, name of variable
      value: tensorflow variable
    """
    self.__variables[key] = value

  def _setup_variables(self):
    with tf.name_scope("autoencoder_variables"):
      for i in xrange(self.__num_hidden_layers + 1):
        # Train weights
        name_w = self._weights_str.format(i + 1)
        w_shape = (self.__shape[i], self.__shape[i + 1])
        # Uniform init in [-a, a] with a = 4 * sqrt(6 / (fan_in + fan_out)),
        # the Glorot range scaled by 4 for sigmoid units.
        a = tf.mul(4.0, tf.sqrt(6.0 / (w_shape[0] + w_shape[1])))
        w_init = tf.random_uniform(w_shape, -1 * a, a)
        self[name_w] = tf.Variable(w_init,
                                   name=name_w,
                                   trainable=True)
        # Train biases
        name_b = self._biases_str.format(i + 1)
        b_shape = (self.__shape[i + 1],)
        b_init = tf.zeros(b_shape)
        self[name_b] = tf.Variable(b_init, trainable=True, name=name_b)

        if i < self.__num_hidden_layers:
          # Hidden layer fixed weights (after pretraining before fine tuning)
          self[name_w + "_fixed"] = tf.Variable(tf.identity(self[name_w]),
                                                name=name_w + "_fixed",
                                                trainable=False)

          # Hidden layer fixed biases
          self[name_b + "_fixed"] = tf.Variable(tf.identity(self[name_b]),
                                                name=name_b + "_fixed",
                                                trainable=False)

          # Pretraining output training biases
          name_b_out = self._biases_str.format(i + 1) + "_out"
          b_shape = (self.__shape[i],)
          b_init = tf.zeros(b_shape)
          self[name_b_out] = tf.Variable(b_init,
                                         trainable=True,
                                         name=name_b_out)

  def _w(self, n, suffix=""):
    return self[self._weights_str.format(n) + suffix]

  def _b(self, n, suffix=""):
    return self[self._biases_str.format(n) + suffix]

  def get_variables_to_init(self, n):
    """Return variables that need initialization

    This method aids in the initialization of variables
    before training begins at step n. The returned
    list should then be used as the input to
    tf.initialize_variables.

    Args:
      n: int giving step of training
    """
    assert n > 0
    assert n <= self.__num_hidden_layers + 1

    vars_to_init = [self._w(n), self._b(n)]

    if n <= self.__num_hidden_layers:
      vars_to_init.append(self._b(n, "_out"))

    if 1 < n <= self.__num_hidden_layers:
      vars_to_init.append(self._w(n - 1, "_fixed"))
      vars_to_init.append(self._b(n - 1, "_fixed"))

    return vars_to_init

  @staticmethod
  def _activate(x, w, b, transpose_w=False):
    y = tf.sigmoid(tf.nn.bias_add(tf.matmul(x, w, transpose_b=transpose_w), b))
    return y

  def pretrain_net(self, input_pl, n, is_target=False):
    """Return net for step n training or target net

    Args:
      input_pl: tensorflow placeholder of AE inputs
      n: int specifying pretrain step
      is_target: bool specifying if required tensor
                 should be the target tensor
    Returns:
      Tensor giving pretraining net or pretraining target
    """
    assert n > 0
    assert n <= self.__num_hidden_layers

    last_output = input_pl
    for i in xrange(n - 1):
      w = self._w(i + 1, "_fixed")
      b = self._b(i + 1, "_fixed")

      last_output = self._activate(last_output, w, b)

    if is_target:
      return last_output

    last_output = self._activate(last_output, self._w(n), self._b(n))

    out = self._activate(last_output, self._w(n), self._b(n, "_out"),
                         transpose_w=True)
    # Clamp the reconstruction away from exact 0/1 so the log in the
    # cross-entropy loss stays finite.
    out = tf.maximum(out, 1.e-9)
    out = tf.minimum(out, 1 - 1.e-9)
    return out

  def supervised_net(self, input_pl):
    """Get the supervised fine tuning net

    Args:
      input_pl: tf placeholder for ae input data
    Returns:
      Tensor giving full ae net
    """
    last_output = input_pl

    for i in xrange(self.__num_hidden_layers + 1):
      # Fine tuning will be done on these variables
      w = self._w(i + 1)
      b = self._b(i + 1)

      last_output = self._activate(last_output, w, b)

    return last_output


loss_summaries = {}


def training(loss, learning_rate, loss_key=None):
  """Sets up the training Ops.

  Creates a summarizer to track the loss over time in TensorBoard.

  Creates an optimizer and applies the gradients to all trainable variables.

  The Op returned by this function is what must be passed to the
  `sess.run()` call to cause the model to train.

  Args:
    loss: Loss tensor, from loss().
    learning_rate: The learning rate to use for gradient descent.
    loss_key: int giving stage of pretraining so we can store
              loss summaries for each pretraining stage

  Returns:
    train_op: The Op for training.
  """
  if loss_key is not None:
    # Add a scalar summary for the snapshot loss.
    loss_summaries[loss_key] = tf.scalar_summary(loss.op.name, loss)
  else:
    tf.scalar_summary(loss.op.name, loss)
    for var in tf.trainable_variables():
      tf.histogram_summary(var.op.name, var)
  # Create the gradient descent optimizer with the given learning rate.
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  # Create a variable to track the global step.
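  # (Added note: passing this variable to minimize() below makes the
  # optimizer increment it automatically, once per training step.)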
  global_step = tf.Variable(0, name='global_step', trainable=False)
  # Use the optimizer to apply the gradients that minimize the loss
  # (and also increment the global step counter) as a single training step.
  train_op = optimizer.minimize(loss, global_step=global_step)
  return train_op, global_step


def loss_x_entropy(output, target):
  """Cross entropy loss

  See https://en.wikipedia.org/wiki/Cross_entropy

  Args:
    output: tensor of net output
    target: tensor of net we are trying to reconstruct
  Returns:
    Scalar tensor of cross entropy
  """
  with tf.name_scope("xentropy_loss"):
    net_output_tf = tf.convert_to_tensor(output, name='input')
    target_tf = tf.convert_to_tensor(target, name='target')
    cross_entropy = tf.add(tf.mul(tf.log(net_output_tf, name='log_output'),
                                  target_tf),
                           tf.mul(tf.log(1 - net_output_tf),
                                  (1 - target_tf)))
    return -1 * tf.reduce_mean(tf.reduce_sum(cross_entropy, 1),
                               name='xentropy_mean')


def main_unsupervised():
  with tf.Graph().as_default() as g:
    sess = tf.Session()

    num_hidden = FLAGS.num_hidden_layers
    ae_hidden_shapes = [getattr(FLAGS, "hidden{0}_units".format(j + 1))
                        for j in xrange(num_hidden)]
    ae_shape = [FLAGS.image_pixels] + ae_hidden_shapes + [FLAGS.num_classes]

    ae = AutoEncoder(ae_shape, sess)

    data = read_data_sets_pretraining(FLAGS.data_dir)
    num_train = data.train.num_examples

    learning_rates = {j: getattr(FLAGS,
                                 "pre_layer{0}_learning_rate".format(j + 1))
                      for j in xrange(num_hidden)}

    noise = {j: getattr(FLAGS, "noise_{0}".format(j + 1))
             for j in xrange(num_hidden)}

    for i in xrange(len(ae_shape) - 2):
      n = i + 1
      with tf.variable_scope("pretrain_{0}".format(n)):
        input_ = tf.placeholder(dtype=tf.float32,
                                shape=(FLAGS.batch_size, ae_shape[0]),
                                name='ae_input_pl')
        target_ = tf.placeholder(dtype=tf.float32,
                                 shape=(FLAGS.batch_size, ae_shape[0]),
                                 name='ae_target_pl')
        layer = ae.pretrain_net(input_, n)

        with tf.name_scope("target"):
          target_for_loss = ae.pretrain_net(target_, n, is_target=True)

        loss = loss_x_entropy(layer, target_for_loss)
        train_op, global_step = training(loss, learning_rates[i], i)

        summary_dir = pjoin(FLAGS.summary_dir, 'pretraining_{0}'.format(n))
        summary_writer = tf.train.SummaryWriter(summary_dir,
                                                graph_def=sess.graph_def,
                                                flush_secs=FLAGS.flush_secs)
        summary_vars = [ae["biases{0}".format(n)], ae["weights{0}".format(n)]]

        hist_summaries = [tf.histogram_summary(v.op.name, v)
                          for v in summary_vars]
        hist_summaries.append(loss_summaries[i])
        summary_op = tf.merge_summary(hist_summaries)

        vars_to_init = ae.get_variables_to_init(n)
        vars_to_init.append(global_step)
        sess.run(tf.initialize_variables(vars_to_init))

        print("\n\n")
        print("| Training Step | Cross Entropy | Layer | Epoch |")
        print("|---------------|---------------|---------|----------|")

        for step in xrange(FLAGS.pretraining_epochs * num_train):
          feed_dict = fill_feed_dict_ae(data.train, input_, target_, noise[i])

          _, loss_value = sess.run([train_op, loss],
                                   feed_dict=feed_dict)

          if step % 100 == 0:
            summary_str = sess.run(summary_op, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, step)
            image_summary_op = \
                tf.image_summary("training_images",
                                 tf.reshape(input_,
                                            (FLAGS.batch_size,
                                             FLAGS.image_size,
                                             FLAGS.image_size, 1)),
                                 max_images=FLAGS.batch_size)

            summary_img_str = sess.run(image_summary_op,
                                       feed_dict=feed_dict)
            summary_writer.add_summary(summary_img_str)

            output = "| {0:>13} | {1:13.4f} | Layer {2} | Epoch {3} |"\
                     .format(step, loss_value, n, step // num_train + 1)

            print(output)
      if i == 0:
        filters = sess.run(tf.identity(ae["weights1"]))
        np.save(pjoin(FLAGS.chkpt_dir, "filters"), filters)
        filters = tile_raster_images(X=filters.T,
                                     img_shape=(FLAGS.image_size,
                                                FLAGS.image_size),
                                     tile_shape=(10, 10),
                                     output_pixel_vals=False)
        filters = np.expand_dims(np.expand_dims(filters, 0), 3)
        image_var = tf.Variable(filters)
        image_filter = tf.identity(image_var)
        sess.run(tf.initialize_variables([image_var]))
        img_filter_summary_op = tf.image_summary("first_layer_filters",
                                                 image_filter)
        summary_writer.add_summary(sess.run(img_filter_summary_op))
        summary_writer.flush()

  return ae


def main_supervised(ae):
  with ae.session.graph.as_default():
    sess = ae.session
    input_pl = tf.placeholder(tf.float32, shape=(FLAGS.batch_size,
                                                 FLAGS.image_pixels),
                              name='input_pl')
    logits = ae.supervised_net(input_pl)

    data = read_data_sets(FLAGS.data_dir)
    num_train = data.train.num_examples

    labels_placeholder = tf.placeholder(tf.int32,
                                        shape=FLAGS.batch_size,
                                        name='target_pl')

    loss = loss_supervised(logits, labels_placeholder)
    train_op, global_step = training(loss, FLAGS.supervised_learning_rate)
    eval_correct = evaluation(logits, labels_placeholder)

    hist_summaries = [ae['biases{0}'.format(i + 1)]
                      for i in xrange(ae.num_hidden_layers + 1)]
    hist_summaries.extend([ae['weights{0}'.format(i + 1)]
                           for i in xrange(ae.num_hidden_layers + 1)])

    hist_summaries = [tf.histogram_summary(v.op.name + "_fine_tuning", v)
                      for v in hist_summaries]
    summary_op = tf.merge_summary(hist_summaries)

    summary_writer = tf.train.SummaryWriter(pjoin(FLAGS.summary_dir,
                                                  'fine_tuning'),
                                            graph_def=sess.graph_def,
                                            flush_secs=FLAGS.flush_secs)

    vars_to_init = ae.get_variables_to_init(ae.num_hidden_layers + 1)
    vars_to_init.append(global_step)
    sess.run(tf.initialize_variables(vars_to_init))

    steps = FLAGS.finetuning_epochs * num_train
    for step in xrange(steps):
      start_time = time.time()

      feed_dict = fill_feed_dict(data.train,
                                 input_pl,
                                 labels_placeholder)

      _, loss_value = sess.run([train_op, loss],
                               feed_dict=feed_dict)

      duration = time.time() - start_time

      # Write the summaries and print an overview fairly often.
      if step % 100 == 0:
        # Print status to stdout.
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        # Update the events file.
        summary_str = sess.run(summary_op, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)
        summary_img_str = sess.run(
            tf.image_summary("training_images",
                             tf.reshape(input_pl,
                                        (FLAGS.batch_size,
                                         FLAGS.image_size,
                                         FLAGS.image_size, 1)),
                             max_images=FLAGS.batch_size),
            feed_dict=feed_dict
        )
        summary_writer.add_summary(summary_img_str)

      if (step + 1) % 1000 == 0 or (step + 1) == steps:
        train_sum = do_eval_summary("training_error",
                                    sess,
                                    eval_correct,
                                    input_pl,
                                    labels_placeholder,
                                    data.train)

        val_sum = do_eval_summary("validation_error",
                                  sess,
                                  eval_correct,
                                  input_pl,
                                  labels_placeholder,
                                  data.validation)

        test_sum = do_eval_summary("test_error",
                                   sess,
                                   eval_correct,
                                   input_pl,
                                   labels_placeholder,
                                   data.test)

        summary_writer.add_summary(train_sum, step)
        summary_writer.add_summary(val_sum, step)
        summary_writer.add_summary(test_sum, step)


if __name__ == '__main__':
  ae = main_unsupervised()
  main_supervised(ae)
--------------------------------------------------------------------------------
/code/ae/autoencoder_test.py:
--------------------------------------------------------------------------------
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from autoencoder import AutoEncoder


class AutoEncoderTest(tf.test.TestCase):

  def test_constructor(self):
    with self.test_session() as sess:
      ae_shape = [10, 20, 30, 2]
      self.assertTrue(AutoEncoder(ae_shape, sess))

  def test_get_variables(self):
    with self.test_session() as sess:
      ae_shape = [10, 20, 30, 2]
      ae = AutoEncoder(ae_shape, sess)

      with self.assertRaises(AssertionError):
        ae.get_variables_to_init(0)
      with self.assertRaises(AssertionError):
        ae.get_variables_to_init(4)

      v1 = ae.get_variables_to_init(1)
      self.assertEqual(len(v1), 3)

      v2 = ae.get_variables_to_init(2)
      self.assertEqual(len(v2), 5)

      v3 = ae.get_variables_to_init(3)
      self.assertEqual(len(v3), 2)

  def test_nets(self):
    with self.test_session() as sess:
      ae_shape = [10, 20, 30, 2]
      ae = AutoEncoder(ae_shape, sess)

      input_pl = tf.placeholder(tf.float32, shape=(100, 10))
      with self.assertRaises(AssertionError):
        ae.pretrain_net(input_pl, 0)
      with self.assertRaises(AssertionError):
        ae.pretrain_net(input_pl, 3)

      net1 = ae.pretrain_net(input_pl, 1)
      net2 = ae.pretrain_net(input_pl, 2)

      self.assertEqual(net1.get_shape().dims[1].value, 10)
      self.assertEqual(net2.get_shape().dims[1].value, 20)

      net1_target = ae.pretrain_net(input_pl, 1, is_target=True)
      self.assertEqual(net1_target.get_shape().dims[1].value, 10)
      net2_target = ae.pretrain_net(input_pl, 2, is_target=True)
      self.assertEqual(net2_target.get_shape().dims[1].value, 20)

      sup_net = ae.supervised_net(input_pl)
      self.assertEqual(sup_net.get_shape().dims[1].value, 2)
--------------------------------------------------------------------------------
/code/ae/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmgreen210/TensorFlowDeepAutoencoder/5298ec437689ba7ecb59229599141549ef6a6a1d/code/ae/utils/__init__.py
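
For orientation, here is a minimal sketch (not a file in the repo) of how the `AutoEncoder` class defined above is driven, mirroring the pattern in `autoencoder_test.py`. The shape and batch size are illustrative only, and it assumes the same TensorFlow 0.x-era API used throughout this repo:

```python
import tensorflow as tf
from autoencoder import AutoEncoder

sess = tf.Session()
# [num inputs, hidden1 units, hidden2 units, num logits]
ae = AutoEncoder([784, 200, 100, 10], sess)

input_pl = tf.placeholder(tf.float32, shape=(100, 784))

# Stage-1 pretraining: the reconstruction net and its target tensor.
net1 = ae.pretrain_net(input_pl, 1)
target1 = ae.pretrain_net(input_pl, 1, is_target=True)

# Initialize exactly the variables stage 1 needs before training it.
sess.run(tf.initialize_variables(ae.get_variables_to_init(1)))

# After all stages are pretrained, the full stack used for fine tuning:
logits = ae.supervised_net(input_pl)
```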
--------------------------------------------------------------------------------
/code/ae/utils/data.py:
--------------------------------------------------------------------------------
"""Functions for downloading and reading MNIST data."""
from __future__ import division
from __future__ import print_function

import gzip

import numpy

from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
from flags import FLAGS
import os

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


def maybe_download(filename, work_directory):
  """Download the data from Yann's website, unless it's already here."""
  if not os.path.exists(work_directory):
    os.mkdir(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not os.path.exists(filepath):
    filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  return filepath


def _read32(bytestream):
  dt = numpy.dtype(numpy.uint32).newbyteorder('>')
  return numpy.frombuffer(bytestream.read(4), dtype=dt)


def extract_images(filename):
  """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
  print('\nExtracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError(
          'Invalid magic number %d in MNIST image file: %s' %
          (magic, filename))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data


def dense_to_one_hot(labels_dense, num_classes=10):
  """Convert class labels from scalars to one-hot vectors."""
  num_labels = labels_dense.shape[0]
  index_offset = numpy.arange(num_labels) * num_classes
  labels_one_hot = numpy.zeros((num_labels, num_classes))
  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  return labels_one_hot


def extract_labels(filename, one_hot=False):
  """Extract the labels into a 1D uint8 numpy array [index]."""
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError(
          'Invalid magic number %d in MNIST label file: %s' %
          (magic, filename))
    num_items = _read32(bytestream)
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
      return dense_to_one_hot(labels)
    return labels


class DataSet(object):

  def __init__(self, images, labels, fake_data=False):
    if fake_data:
      self._num_examples = 10000
    else:
      assert images.shape[0] == labels.shape[0], (
          "images.shape: %s labels.shape: %s" % (images.shape,
                                                 labels.shape))
      self._num_examples = images.shape[0]

      # Convert shape from [num examples, rows, columns, depth]
      # to [num examples, rows*columns] (assuming depth == 1)
      assert images.shape[3] == 1
      images = images.reshape(images.shape[0],
                              images.shape[1] * images.shape[2])
      # Convert from [0, 255] -> [0.0, 1.0].
      images = images.astype(numpy.float32)
      images = numpy.multiply(images, 1.0 / 255.0)
    self._images = images
    self._labels = labels
    self._epochs_completed = 0
    self._index_in_epoch = 0

  @property
  def images(self):
    return self._images

  @property
  def labels(self):
    return self._labels

  @property
  def num_examples(self):
    return self._num_examples

  @property
  def epochs_completed(self):
    return self._epochs_completed

  def next_batch(self, batch_size):
    """Return the next `batch_size` examples from this data set."""
    start = self._index_in_epoch
    self._index_in_epoch += batch_size
    if self._index_in_epoch > self._num_examples:
      # Finished epoch
      self._epochs_completed += 1
      # Shuffle the data
      perm = numpy.arange(self._num_examples)
      numpy.random.shuffle(perm)
      self._images = self._images[perm]
      self._labels = self._labels[perm]
      # Start next epoch
      start = 0
      self._index_in_epoch = batch_size
      assert batch_size <= self._num_examples
    end = self._index_in_epoch
    return self._images[start:end], self._labels[start:end]


class DataSetPreTraining(object):

  def __init__(self, images):
    self._num_examples = images.shape[0]

    # Convert shape from [num examples, rows, columns, depth]
    # to [num examples, rows*columns] (assuming depth == 1)
    assert images.shape[3] == 1
    images = images.reshape(images.shape[0],
                            images.shape[1] * images.shape[2])
    # Convert from [0, 255] -> [0.0, 1.0].
    images = images.astype(numpy.float32)
    images = numpy.multiply(images, 1.0 / 255.0)
    self._images = images
    # Clamp pixels into [zero_bound, one_bound] so the cross-entropy
    # reconstruction targets stay away from exact 0 and 1.
    self._images[self._images < FLAGS.zero_bound] = FLAGS.zero_bound
    self._images[self._images > FLAGS.one_bound] = FLAGS.one_bound
    self._epochs_completed = 0
    self._index_in_epoch = 0

  @property
  def images(self):
    return self._images

  @property
  def num_examples(self):
    return self._num_examples

  @property
  def epochs_completed(self):
    return self._epochs_completed

  def next_batch(self, batch_size):
    """Return the next `batch_size` examples from this data set."""
    start = self._index_in_epoch
    self._index_in_epoch += batch_size
    if self._index_in_epoch > self._num_examples:
      # Finished epoch
      self._epochs_completed += 1
      # Shuffle the data
      perm = numpy.arange(self._num_examples)
      numpy.random.shuffle(perm)
      self._images = self._images[perm]
      # Start next epoch
      start = 0
      self._index_in_epoch = batch_size
      assert batch_size <= self._num_examples
    end = self._index_in_epoch

    return self._images[start:end], self._images[start:end]


def read_data_sets(train_dir, fake_data=False, one_hot=False):
  class DataSets(object):
    pass
  data_sets = DataSets()

  if fake_data:
    data_sets.train = DataSet([], [], fake_data=True)
    data_sets.validation = DataSet([], [], fake_data=True)
    data_sets.test = DataSet([], [], fake_data=True)
    return data_sets

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
  VALIDATION_SIZE = 5000

  local_file = maybe_download(TRAIN_IMAGES, train_dir)
  train_images = extract_images(local_file)

  local_file = maybe_download(TRAIN_LABELS, train_dir)
  train_labels = extract_labels(local_file, one_hot=one_hot)

  local_file = maybe_download(TEST_IMAGES, train_dir)
  test_images = extract_images(local_file)

  local_file = maybe_download(TEST_LABELS, train_dir)
  test_labels = extract_labels(local_file, one_hot=one_hot)

  validation_images = train_images[:VALIDATION_SIZE]
  validation_labels = train_labels[:VALIDATION_SIZE]
  train_images = train_images[VALIDATION_SIZE:]
  train_labels = train_labels[VALIDATION_SIZE:]

  data_sets.train = DataSet(train_images, train_labels)
  data_sets.validation = DataSet(validation_images, validation_labels)
  data_sets.test = DataSet(test_images, test_labels)

  return data_sets


def read_data_sets_pretraining(train_dir):
  class DataSets(object):
    pass
  data_sets = DataSets()

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  VALIDATION_SIZE = 5000

  local_file = maybe_download(TRAIN_IMAGES, train_dir)
  train_images = extract_images(local_file)

  local_file = maybe_download(TEST_IMAGES, train_dir)
  test_images = extract_images(local_file)

  validation_images = train_images[:VALIDATION_SIZE]
  train_images = train_images[VALIDATION_SIZE:]

  data_sets.train = DataSetPreTraining(train_images)
  data_sets.validation = DataSetPreTraining(validation_images)
  data_sets.test = DataSetPreTraining(test_images)

  return data_sets


def _add_noise(x, rate):
  # Corrupt inputs for denoising pretraining: randomly "drop" pixels by
  # setting them to the lower clamp value.
  x_cp = numpy.copy(x)
  pix_to_drop = numpy.random.rand(x_cp.shape[0],
                                  x_cp.shape[1]) < rate
  x_cp[pix_to_drop] = FLAGS.zero_bound
  return x_cp


def fill_feed_dict_ae(data_set, input_pl, target_pl, noise=None):
  input_feed, target_feed = data_set.next_batch(FLAGS.batch_size)
  if noise:
    input_feed = _add_noise(input_feed, noise)
  feed_dict = {
      input_pl: input_feed,
      target_pl: target_feed
  }
  return feed_dict


def fill_feed_dict(data_set, images_pl, labels_pl, noise=False):
  """Fills the feed_dict for training the given step.
  A feed_dict takes the form of:
  feed_dict = {