├── .idea
│   ├── Adversarial_Autoencoder.iml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── README
│   ├── AAE Block Diagram.png
│   ├── AAE dist match.png
│   ├── Supervised AAE.png
│   ├── aa_encoder_dist.png
│   ├── aa_real_dist.png
│   ├── adversarial_autoencoder.png
│   ├── adversarial_autoencoder_2.png
│   ├── autoencoder_architecture.png
│   ├── cat_n_gauss_dist_real_obtained.png
│   ├── cover.png
│   ├── disentanglement of style and content.png
│   ├── grid_175450.png
│   ├── grid_177650.png
│   ├── nw_architecture.png
│   ├── semi_1000.png
│   ├── semi_9960.png
│   ├── semi_9970.png
│   ├── semi_9980.png
│   ├── semi_9990.png
│   ├── semi_AAE architecture.png
│   ├── semi_aae_accuracy_with_NN.png
│   ├── semi_e_c.png
│   ├── semi_e_g.png
│   ├── semi_r_c.png
│   ├── semi_r_g.png
│   └── supervised_autoencoder_100.png
├── Results
│   ├── .gitkeep
│   ├── Adversarial_Autoencoder
│   │   └── .gitkeep
│   ├── Autoencoder
│   │   └── .gitkeep
│   ├── Basic_NN_Classifier
│   │   └── .gitkeep
│   ├── Semi_Supervised
│   │   └── .gitkeep
│   └── Supervised
│       └── .gitkeep
├── _config.yml
├── adversarial_autoencoder.py
├── autoencoder.py
├── basic_nn_classifier.py
├── requirements.txt
├── semi_supervised_adversarial_autoencoder.py
└── supervised_adversarial_autoencoder.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Naresh Nagabushan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Adversarial autoencoders
Cover

This repository contains code to implement an adversarial autoencoder using TensorFlow.

Medium posts:

1. [A Wizard's guide to Adversarial Autoencoders: Part 1. Autoencoders?](https://medium.com/towards-data-science/a-wizards-guide-to-adversarial-autoencoders-part-1-autoencoder-d9a5f8795af4)

2. [A Wizard's guide to Adversarial Autoencoders: Part 2. Exploring the latent space with Adversarial Autoencoders.](https://medium.com/towards-data-science/a-wizards-guide-to-adversarial-autoencoders-part-2-exploring-latent-space-with-adversarial-2d53a6f8a4f9)

3. [A Wizard's guide to Adversarial Autoencoders: Part 3. Disentanglement of style and content.](https://medium.com/towards-data-science/a-wizards-guide-to-adversarial-autoencoders-part-3-disentanglement-of-style-and-content-89262973a4d7)

4. [A Wizard's guide to Adversarial Autoencoders: Part 4. Classify MNIST using 1000 labels.](https://medium.com/towards-data-science/a-wizards-guide-to-adversarial-autoencoders-part-4-classify-mnist-using-1000-labels-2ca08071f95)

## Installing the dependencies
Install virtualenv and create a new virtual environment:

    pip install virtualenv
    virtualenv -p /usr/bin/python3 aa

Install the dependencies:

    pip3 install -r requirements.txt

***Note:***

* *I'd highly recommend using your GPU during training.*
* *The `targets` parameter of `tf.nn.sigmoid_cross_entropy_with_logits` was renamed to
`labels` in TensorFlow versions after r0.12.*

## Dataset
The MNIST dataset is downloaded automatically and made available in the
`./Data` directory.


## Training!
### Autoencoder:
#### Architecture:

To train a basic autoencoder run:

    python3 autoencoder.py --train True

* This trains an autoencoder and saves the trained model once every epoch
in the `./Results/Autoencoder` directory.

To load the trained model and generate images by passing inputs to the decoder, run:

    python3 autoencoder.py --train False

### Adversarial Autoencoder:
#### Architecture:

Cover

Training:

    python3 adversarial_autoencoder.py --train True

Load model and explore the latent space:

    python3 adversarial_autoencoder.py --train False

Example of adversarial autoencoder output when the encoder is constrained
to have a stddev of 5.
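The prior itself is sampled in `adversarial_autoencoder.py` as a scaled standard normal,
`z_real_dist = np.random.randn(batch_size, z_dim) * 5.`. A minimal standalone sketch of
the same sampling (values as used throughout this repo):

    import numpy as np

    batch_size, z_dim = 100, 2
    # Gaussian prior with stddev 5, fed to the `real_distribution` placeholder
    z_real_dist = np.random.randn(batch_size, z_dim) * 5.
    print(z_real_dist.std())  # ~5.0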

Cover

**_Matching prior and posterior distributions._**


![Adversarial_autoencoder](https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/master/README/adversarial_autoencoder_2.png)
**_Distribution of digits in the latent space._**

### Supervised Adversarial Autoencoder:
#### Architecture:

Cover

Training:

    python3 supervised_adversarial_autoencoder.py --train True

Load model and explore the latent space:

    python3 supervised_adversarial_autoencoder.py --train False

Example of disentanglement of style and content:
Cover

### Semi-Supervised Adversarial Autoencoder:
#### Architecture:
Cover

Training:

    python3 semi_supervised_adversarial_autoencoder.py --train True

Load model and explore the latent space:

    python3 semi_supervised_adversarial_autoencoder.py --train False

Classification accuracy for 1000 labeled images:

Cover

Cover


***Note:***
* Each run generates the required TensorBoard files under the
`./Results/<model>/<time_stamp_and_parameters>/Tensorboard` directory.
* Use `tensorboard --logdir <tensorboard_path>` to look at loss variations
and distributions of the latent code.
* Windows gives an error when `:` is used in a folder name (the time stamp inserted while
creating the folders for each run contains one). I would suggest removing the time stamp
from the `folder_name` variable in the `form_results()` function. Or just dual-boot Linux!


## Thank You
Please share this repo if you find it helpful.

--------------------------------------------------------------------------------
/README/AAE Block Diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/AAE Block Diagram.png

--------------------------------------------------------------------------------
/README/AAE dist match.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/AAE dist match.png

--------------------------------------------------------------------------------
/README/Supervised AAE.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/Supervised AAE.png

--------------------------------------------------------------------------------
/README/aa_encoder_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/aa_encoder_dist.png

--------------------------------------------------------------------------------
/README/aa_real_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/aa_real_dist.png

--------------------------------------------------------------------------------
/README/adversarial_autoencoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/adversarial_autoencoder.png

--------------------------------------------------------------------------------
/README/adversarial_autoencoder_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/adversarial_autoencoder_2.png

--------------------------------------------------------------------------------
/README/autoencoder_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/autoencoder_architecture.png

--------------------------------------------------------------------------------
/README/cat_n_gauss_dist_real_obtained.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/cat_n_gauss_dist_real_obtained.png

--------------------------------------------------------------------------------
/README/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/cover.png

--------------------------------------------------------------------------------
/README/disentanglement of style and content.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/disentanglement of style and content.png

--------------------------------------------------------------------------------
/README/grid_175450.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/grid_175450.png

--------------------------------------------------------------------------------
/README/grid_177650.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/grid_177650.png

--------------------------------------------------------------------------------
/README/nw_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/nw_architecture.png

--------------------------------------------------------------------------------
/README/semi_1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_1000.png

--------------------------------------------------------------------------------
/README/semi_9960.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_9960.png

--------------------------------------------------------------------------------
/README/semi_9970.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_9970.png

--------------------------------------------------------------------------------
/README/semi_9980.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_9980.png

--------------------------------------------------------------------------------
/README/semi_9990.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_9990.png

--------------------------------------------------------------------------------
/README/semi_AAE architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_AAE architecture.png

--------------------------------------------------------------------------------
/README/semi_aae_accuracy_with_NN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_aae_accuracy_with_NN.png

--------------------------------------------------------------------------------
/README/semi_e_c.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_e_c.png

--------------------------------------------------------------------------------
/README/semi_e_g.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_e_g.png

--------------------------------------------------------------------------------
/README/semi_r_c.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_r_c.png

--------------------------------------------------------------------------------
/README/semi_r_g.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_r_g.png

--------------------------------------------------------------------------------
/README/supervised_autoencoder_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/supervised_autoencoder_100.png

--------------------------------------------------------------------------------
/Results/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/Results/.gitkeep

--------------------------------------------------------------------------------
/Results/Adversarial_Autoencoder/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/Results/Adversarial_Autoencoder/.gitkeep

--------------------------------------------------------------------------------
/Results/Autoencoder/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/Results/Autoencoder/.gitkeep

--------------------------------------------------------------------------------
/Results/Basic_NN_Classifier/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/Results/Basic_NN_Classifier/.gitkeep

--------------------------------------------------------------------------------
/Results/Semi_Supervised/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/Results/Semi_Supervised/.gitkeep

--------------------------------------------------------------------------------
/Results/Supervised/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/Results/Supervised/.gitkeep

--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
theme: jekyll-theme-cayman

--------------------------------------------------------------------------------
/adversarial_autoencoder.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import datetime
import os
import matplotlib.pyplot as plt
from matplotlib import gridspec
from tensorflow.examples.tutorials.mnist import input_data

# Progressbar
# bar = progressbar.ProgressBar(widgets=['[', progressbar.Timer(), ']', progressbar.Bar(), '(', progressbar.ETA(), ')'])

# Get the MNIST data
mnist = input_data.read_data_sets('./Data', one_hot=True)

# Parameters
input_dim = 784
n_l1 = 1000
n_l2 = 1000
z_dim = 2
batch_size = 100
n_epochs = 1000
learning_rate = 0.001
beta1 = 0.9
results_path = './Results/Adversarial_Autoencoder'

# Placeholders for input data and the targets
x_input = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Input')
x_target = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Target')
real_distribution = tf.placeholder(dtype=tf.float32, shape=[batch_size, z_dim], name='Real_distribution')
decoder_input = tf.placeholder(dtype=tf.float32, shape=[1, z_dim], name='Decoder_input')


def form_results():
    """
    Forms folders for each run to store the tensorboard files, saved models and the log files.
    :return: three strings pointing to the tensorboard, saved models and log paths respectively.
    """
    folder_name = "/{0}_{1}_{2}_{3}_{4}_{5}_Adversarial_Autoencoder". \
        format(datetime.datetime.now(), z_dim, learning_rate, batch_size, n_epochs, beta1)
    tensorboard_path = results_path + folder_name + '/Tensorboard'
    saved_model_path = results_path + folder_name + '/Saved_models/'
    log_path = results_path + folder_name + '/log'
    if not os.path.exists(results_path + folder_name):
        os.mkdir(results_path + folder_name)
        os.mkdir(tensorboard_path)
        os.mkdir(saved_model_path)
        os.mkdir(log_path)
    return tensorboard_path, saved_model_path, log_path


def generate_image_grid(sess, op):
    """
    Generates a grid of images by passing a set of numbers to the decoder and getting its output.
    :param sess: Tensorflow Session required to get the decoder output
    :param op: Operation that needs to be called in order to get the decoder output
    :return: None, displays a matplotlib window with all the merged images.
    """
    x_points = np.arange(-10, 10, 1.5).astype(np.float32)
    y_points = np.arange(-10, 10, 1.5).astype(np.float32)

    nx, ny = len(x_points), len(y_points)
    plt.subplot()
    gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)

    for i, g in enumerate(gs):
        z = np.concatenate(([x_points[int(i / ny)]], [y_points[int(i % nx)]]))
        z = np.reshape(z, (1, 2))
        x = sess.run(op, feed_dict={decoder_input: z})
        ax = plt.subplot(g)
        img = np.array(x.tolist()).reshape(28, 28)
        ax.imshow(img, cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('auto')
    plt.show()


def dense(x, n1, n2, name):
    """
    Used to create a dense layer.
    :param x: input tensor to the dense layer
    :param n1: no. of input neurons
    :param n2: no. of output neurons
    :param name: name of the entire dense layer, i.e. the variable scope name.
    :return: tensor with shape [batch_size, n2]
    """
    with tf.variable_scope(name, reuse=None):
        weights = tf.get_variable("weights", shape=[n1, n2],
                                  initializer=tf.random_normal_initializer(mean=0., stddev=0.01))
        bias = tf.get_variable("bias", shape=[n2], initializer=tf.constant_initializer(0.0))
        out = tf.add(tf.matmul(x, weights), bias, name='matmul')
        return out


# The autoencoder network
def encoder(x, reuse=False):
    """
    Encoder part of the autoencoder.
    :param x: input to the autoencoder
    :param reuse: True -> Reuse the encoder variables, False -> Create the encoder variables
    :return: tensor which is the hidden latent variable of the autoencoder.
    """
    if reuse:
        tf.get_variable_scope().reuse_variables()
    with tf.name_scope('Encoder'):
        e_dense_1 = tf.nn.relu(dense(x, input_dim, n_l1, 'e_dense_1'))
        e_dense_2 = tf.nn.relu(dense(e_dense_1, n_l1, n_l2, 'e_dense_2'))
        latent_variable = dense(e_dense_2, n_l2, z_dim, 'e_latent_variable')
        return latent_variable


def decoder(x, reuse=False):
    """
    Decoder part of the autoencoder.
    :param x: input to the decoder
    :param reuse: True -> Reuse the decoder variables, False -> Create the decoder variables
    :return: tensor which should ideally be the input given to the encoder.
    """
    if reuse:
        tf.get_variable_scope().reuse_variables()
    with tf.name_scope('Decoder'):
        d_dense_1 = tf.nn.relu(dense(x, z_dim, n_l2, 'd_dense_1'))
        d_dense_2 = tf.nn.relu(dense(d_dense_1, n_l2, n_l1, 'd_dense_2'))
        output = tf.nn.sigmoid(dense(d_dense_2, n_l1, input_dim, 'd_output'))
        return output


def discriminator(x, reuse=False):
    """
    Discriminator that is used to match the posterior distribution with a given prior distribution.
    :param x: tensor of shape [batch_size, z_dim]
    :param reuse: True -> Reuse the discriminator variables,
                  False -> Create the discriminator variables
    :return: tensor of shape [batch_size, 1]
    """
    if reuse:
        tf.get_variable_scope().reuse_variables()
    with tf.name_scope('Discriminator'):
        dc_den1 = tf.nn.relu(dense(x, z_dim, n_l1, name='dc_den1'))
        dc_den2 = tf.nn.relu(dense(dc_den1, n_l1, n_l2, name='dc_den2'))
        output = dense(dc_den2, n_l2, 1, name='dc_output')
        return output


def train(train_model=True):
    """
    Used to train the autoencoder by passing in the necessary inputs.
    :param train_model: True -> Train the model, False -> Load the latest trained model and show the image grid.
    :return: does not return anything
    """
    with tf.variable_scope(tf.get_variable_scope()):
        encoder_output = encoder(x_input)
        decoder_output = decoder(encoder_output)

    with tf.variable_scope(tf.get_variable_scope()):
        d_real = discriminator(real_distribution)
        d_fake = discriminator(encoder_output, reuse=True)

    with tf.variable_scope(tf.get_variable_scope()):
        decoder_image = decoder(decoder_input, reuse=True)

    # Autoencoder loss
    autoencoder_loss = tf.reduce_mean(tf.square(x_target - decoder_output))

    # Discriminator loss
    dc_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real), logits=d_real))
    dc_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake), logits=d_fake))
    dc_loss = dc_loss_fake + dc_loss_real

    # Generator loss
    generator_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_fake), logits=d_fake))

    all_variables = tf.trainable_variables()
    dc_var = [var for var in all_variables if 'dc_' in var.name]
    en_var = [var for var in all_variables if 'e_' in var.name]

    # Optimizers
    autoencoder_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                   beta1=beta1).minimize(autoencoder_loss)
    discriminator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                     beta1=beta1).minimize(dc_loss, var_list=dc_var)
    generator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                 beta1=beta1).minimize(generator_loss, var_list=en_var)

    init = tf.global_variables_initializer()

    # Reshape images to display them
    input_images = tf.reshape(x_input, [-1, 28, 28, 1])
    generated_images = tf.reshape(decoder_output, [-1, 28, 28, 1])

    # Tensorboard visualization
    tf.summary.scalar(name='Autoencoder Loss', tensor=autoencoder_loss)
    tf.summary.scalar(name='Discriminator Loss', tensor=dc_loss)
    tf.summary.scalar(name='Generator Loss', tensor=generator_loss)
    tf.summary.histogram(name='Encoder Distribution', values=encoder_output)
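    # Paired with the 'Encoder Distribution' histogram above, this lets TensorBoard compare the sampled prior with the encoder's posterior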
    tf.summary.histogram(name='Real Distribution', values=real_distribution)
    tf.summary.image(name='Input Images', tensor=input_images, max_outputs=10)
    tf.summary.image(name='Generated Images', tensor=generated_images, max_outputs=10)
    summary_op = tf.summary.merge_all()

    # Saving the model
    saver = tf.train.Saver()
    step = 0
    with tf.Session() as sess:
        if train_model:
            tensorboard_path, saved_model_path, log_path = form_results()
            sess.run(init)
            writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=sess.graph)
            for i in range(n_epochs):
                n_batches = int(mnist.train.num_examples / batch_size)
                print("------------------Epoch {}/{}------------------".format(i, n_epochs))
                for b in range(1, n_batches + 1):
                    z_real_dist = np.random.randn(batch_size, z_dim) * 5.
                    batch_x, _ = mnist.train.next_batch(batch_size)
                    sess.run(autoencoder_optimizer, feed_dict={x_input: batch_x, x_target: batch_x})
                    sess.run(discriminator_optimizer,
                             feed_dict={x_input: batch_x, x_target: batch_x, real_distribution: z_real_dist})
                    sess.run(generator_optimizer, feed_dict={x_input: batch_x, x_target: batch_x})
                    if b % 50 == 0:
                        a_loss, d_loss, g_loss, summary = sess.run(
                            [autoencoder_loss, dc_loss, generator_loss, summary_op],
                            feed_dict={x_input: batch_x, x_target: batch_x,
                                       real_distribution: z_real_dist})
                        writer.add_summary(summary, global_step=step)
                        print("Epoch: {}, iteration: {}".format(i, b))
                        print("Autoencoder Loss: {}".format(a_loss))
                        print("Discriminator Loss: {}".format(d_loss))
                        print("Generator Loss: {}".format(g_loss))
                        with open(log_path + '/log.txt', 'a') as log:
                            log.write("Epoch: {}, iteration: {}\n".format(i, b))
                            log.write("Autoencoder Loss: {}\n".format(a_loss))
                            log.write("Discriminator Loss: {}\n".format(d_loss))
                            log.write("Generator Loss: {}\n".format(g_loss))
                    step += 1

            saver.save(sess, save_path=saved_model_path, global_step=step)
        else:
            # Get the latest results folder
            all_results = os.listdir(results_path)
            all_results.sort()
            saver.restore(sess, save_path=tf.train.latest_checkpoint(results_path + '/' + all_results[-1] + '/Saved_models/'))
            generate_image_grid(sess, op=decoder_image)

if __name__ == '__main__':
    train(train_model=True)

--------------------------------------------------------------------------------
/autoencoder.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import datetime
import os
import matplotlib.pyplot as plt
from matplotlib import gridspec
from tensorflow.examples.tutorials.mnist import input_data

# Get the MNIST data
mnist = input_data.read_data_sets('./Data', one_hot=True)

# Parameters
input_dim = 784
n_l1 = 1000
n_l2 = 1000
z_dim = 2
batch_size = 100
n_epochs = 1000
learning_rate = 0.001
beta1 = 0.9
results_path = './Results/Autoencoder'


# Placeholders for input data and the targets
x_input = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Input')
x_target = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Target')
decoder_input = tf.placeholder(dtype=tf.float32, shape=[1, z_dim], name='Decoder_input')


def generate_image_grid(sess, op):
""" 32 | Generates a grid of images by passing a set of numbers to the decoder and getting its output. 33 | :param sess: Tensorflow Session required to get the decoder output 34 | :param op: Operation that needs to be called inorder to get the decoder output 35 | :return: None, displays a matplotlib window with all the merged images. 36 | """ 37 | x_points = np.arange(0, 1, 1.5).astype(np.float32) 38 | y_points = np.arange(0, 1, 1.5).astype(np.float32) 39 | 40 | nx, ny = len(x_points), len(y_points) 41 | plt.subplot() 42 | gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05) 43 | 44 | for i, g in enumerate(gs): 45 | z = np.concatenate(([x_points[int(i / ny)]], [y_points[int(i % nx)]])) 46 | z = np.reshape(z, (1, 2)) 47 | x = sess.run(op, feed_dict={decoder_input: z}) 48 | ax = plt.subplot(g) 49 | img = np.array(x.tolist()).reshape(28, 28) 50 | ax.imshow(img, cmap='gray') 51 | ax.set_xticks([]) 52 | ax.set_yticks([]) 53 | ax.set_aspect('auto') 54 | plt.show() 55 | 56 | 57 | def form_results(): 58 | """ 59 | Forms folders for each run to store the tensorboard files, saved models and the log files. 60 | :return: three string pointing to tensorboard, saved models and log paths respectively. 61 | """ 62 | folder_name = "/{0}_{1}_{2}_{3}_{4}_{5}_autoencoder". \ 63 | format(datetime.datetime.now(), z_dim, learning_rate, batch_size, n_epochs, beta1) 64 | tensorboard_path = results_path + folder_name + '/Tensorboard' 65 | saved_model_path = results_path + folder_name + '/Saved_models/' 66 | log_path = results_path + folder_name + '/log' 67 | if not os.path.exists(results_path + folder_name): 68 | os.mkdir(results_path + folder_name) 69 | os.mkdir(tensorboard_path) 70 | os.mkdir(saved_model_path) 71 | os.mkdir(log_path) 72 | return tensorboard_path, saved_model_path, log_path 73 | 74 | 75 | def dense(x, n1, n2, name): 76 | """ 77 | Used to create a dense layer. 78 | :param x: input tensor to the dense layer 79 | :param n1: no. of input neurons 80 | :param n2: no. of output neurons 81 | :param name: name of the entire dense layer.i.e, variable scope name. 82 | :return: tensor with shape [batch_size, n2] 83 | """ 84 | with tf.variable_scope(name, reuse=None): 85 | weights = tf.get_variable("weights", shape=[n1, n2], 86 | initializer=tf.random_normal_initializer(mean=0., stddev=0.01)) 87 | bias = tf.get_variable("bias", shape=[n2], initializer=tf.constant_initializer(0.0)) 88 | out = tf.add(tf.matmul(x, weights), bias, name='matmul') 89 | return out 90 | 91 | 92 | # The autoencoder network 93 | def encoder(x, reuse=False): 94 | """ 95 | Encode part of the autoencoder 96 | :param x: input to the autoencoder 97 | :param reuse: True -> Reuse the encoder variables, False -> Create or search of variables before creating 98 | :return: tensor which is the hidden latent variable of the autoencoder. 99 | """ 100 | if reuse: 101 | tf.get_variable_scope().reuse_variables() 102 | with tf.name_scope('Encoder'): 103 | e_dense_1 = tf.nn.relu(dense(x, input_dim, n_l1, 'e_dense_1')) 104 | e_dense_2 = tf.nn.relu(dense(e_dense_1, n_l1, n_l2, 'e_dense_2')) 105 | latent_variable = dense(e_dense_2, n_l2, z_dim, 'e_latent_variable') 106 | return latent_variable 107 | 108 | 109 | def decoder(x, reuse=False): 110 | """ 111 | Decoder part of the autoencoder 112 | :param x: input to the decoder 113 | :param reuse: True -> Reuse the decoder variables, False -> Create or search of variables before creating 114 | :return: tensor which should ideally be the input given to the encoder. 
    """
    if reuse:
        tf.get_variable_scope().reuse_variables()
    with tf.name_scope('Decoder'):
        d_dense_1 = tf.nn.relu(dense(x, z_dim, n_l2, 'd_dense_1'))
        d_dense_2 = tf.nn.relu(dense(d_dense_1, n_l2, n_l1, 'd_dense_2'))
        output = tf.nn.sigmoid(dense(d_dense_2, n_l1, input_dim, 'd_output'))
        return output


def train(train_model):
    """
    Used to train the autoencoder by passing in the necessary inputs.
    :param train_model: True -> Train the model, False -> Load the latest trained model and show the image grid.
    :return: does not return anything
    """
    with tf.variable_scope(tf.get_variable_scope()):
        encoder_output = encoder(x_input)
        decoder_output = decoder(encoder_output)

    with tf.variable_scope(tf.get_variable_scope()):
        decoder_image = decoder(decoder_input, reuse=True)

    # Loss
    loss = tf.reduce_mean(tf.square(x_target - decoder_output))

    # Optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1).minimize(loss)
    init = tf.global_variables_initializer()

    # Visualization
    tf.summary.scalar(name='Loss', tensor=loss)
    tf.summary.histogram(name='Encoder Distribution', values=encoder_output)
    input_images = tf.reshape(x_input, [-1, 28, 28, 1])
    generated_images = tf.reshape(decoder_output, [-1, 28, 28, 1])
    tf.summary.image(name='Input Images', tensor=input_images, max_outputs=10)
    tf.summary.image(name='Generated Images', tensor=generated_images, max_outputs=10)
    summary_op = tf.summary.merge_all()

    # Saving the model
    saver = tf.train.Saver()
    step = 0
    with tf.Session() as sess:
        sess.run(init)
        if train_model:
            tensorboard_path, saved_model_path, log_path = form_results()
            writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=sess.graph)
            for i in range(n_epochs):
                n_batches = int(mnist.train.num_examples / batch_size)
                for b in range(n_batches):
                    batch_x, _ = mnist.train.next_batch(batch_size)
                    sess.run(optimizer, feed_dict={x_input: batch_x, x_target: batch_x})
                    if b % 50 == 0:
                        batch_loss, summary = sess.run([loss, summary_op], feed_dict={x_input: batch_x, x_target: batch_x})
                        writer.add_summary(summary, global_step=step)
                        print("Loss: {}".format(batch_loss))
                        print("Epoch: {}, iteration: {}".format(i, b))
                        with open(log_path + '/log.txt', 'a') as log:
                            log.write("Epoch: {}, iteration: {}\n".format(i, b))
                            log.write("Loss: {}\n".format(batch_loss))
                    step += 1
            saver.save(sess, save_path=saved_model_path, global_step=step)
            print("Model Trained!")
            print("Tensorboard Path: {}".format(tensorboard_path))
            print("Log Path: {}".format(log_path + '/log.txt'))
            print("Saved Model Path: {}".format(saved_model_path))
        else:
            all_results = os.listdir(results_path)
            all_results.sort()
            saver.restore(sess,
                          save_path=tf.train.latest_checkpoint(results_path + '/' + all_results[-1] + '/Saved_models/'))
            generate_image_grid(sess, op=decoder_image)

if __name__ == '__main__':
    train(train_model=True)

--------------------------------------------------------------------------------
/basic_nn_classifier.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import os
import datetime
from tensorflow.examples.tutorials.mnist import input_data

# Parameters
input_dim = 784
n_l1 = 1000
n_l2 = 1000
batch_size = 100
n_epochs = 1000
learning_rate = 0.001
beta1 = 0.9
z_dim = 'NA'
results_path = './Results/Basic_NN_Classifier'
n_labels = 10
n_labeled = 1000

# Get MNIST data
mnist = input_data.read_data_sets('./Data', one_hot=True)

# Placeholders
x_input = tf.placeholder(dtype=tf.float32, shape=[None, 784])
y_target = tf.placeholder(dtype=tf.float32, shape=[None, 10])


def form_results():
    """
    Forms folders for each run to store the tensorboard files, saved models and the log files.
    :return: three strings pointing to the tensorboard, saved models and log paths respectively.
    """
    folder_name = "/{0}_{1}_{2}_{3}_{4}_{5}_Basic_NN_Classifier". \
        format(datetime.datetime.now(), z_dim, learning_rate, batch_size, n_epochs, beta1)
    tensorboard_path = results_path + folder_name + '/Tensorboard'
    saved_model_path = results_path + folder_name + '/Saved_models/'
    log_path = results_path + folder_name + '/log'
    if not os.path.exists(results_path + folder_name):
        os.mkdir(results_path + folder_name)
        os.mkdir(tensorboard_path)
        os.mkdir(saved_model_path)
        os.mkdir(log_path)
    return tensorboard_path, saved_model_path, log_path


def next_batch(x, y, batch_size):
    """
    Used to return a random batch from the given inputs.
    :param x: Input images of shape [None, 784]
    :param y: Input labels of shape [None, 10]
    :param batch_size: integer, batch size of images and labels to return
    :return: x -> [batch_size, 784], y -> [batch_size, 10]
    """
    index = np.arange(n_labeled)
    random_index = np.random.permutation(index)[:batch_size]
    return x[random_index], y[random_index]


def dense(x, n1, n2, name):
    """
    Used to create a dense layer.
    :param x: input tensor to the dense layer
    :param n1: no. of input neurons
    :param n2: no. of output neurons
    :param name: name of the entire dense layer.
    :return: tensor with shape [batch_size, n2]
    """
    with tf.name_scope(name):
        weights = tf.Variable(tf.random_normal(shape=[n1, n2], mean=0., stddev=0.01), name='weights')
        bias = tf.Variable(tf.zeros(shape=[n2]), name='bias')
        output = tf.add(tf.matmul(x, weights), bias, name='output')
        return output


# Dense Network
def dense_nn(x):
    """
    Network used to classify MNIST digits.
    :param x: tensor with shape [batch_size, 784], input to the dense fully connected layer.
    :return: [batch_size, 10], logits of dense fully connected.
    """
    dense_1 = tf.nn.dropout(tf.nn.relu(dense(x, input_dim, n_l1, 'dense_1')), keep_prob=0.25)
    dense_2 = tf.nn.dropout(tf.nn.relu(dense(dense_1, n_l1, n_l2, 'dense_2')), keep_prob=0.25)
    dense_3 = dense(dense_2, n_l2, n_labels, 'dense_3')
    return dense_3


def train():
    """
    Used to train the classifier by passing in the necessary inputs.
    :return: does not return anything
    """
    dense_output = dense_nn(x_input)

    # Loss function
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=dense_output, labels=y_target))

    # Optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1).minimize(loss)

    # Accuracy
    pred_op = tf.equal(tf.argmax(dense_output, 1), tf.argmax(y_target, 1))
    accuracy = tf.reduce_mean(tf.cast(pred_op, dtype=tf.float32))

    # Summary
    tf.summary.scalar(name='Loss', tensor=loss)
    tf.summary.scalar(name='Accuracy', tensor=accuracy)
    summary_op = tf.summary.merge_all()

    saver = tf.train.Saver()

    init = tf.global_variables_initializer()

    step = 0
    with tf.Session() as sess:
        tensorboard_path, saved_model_path, log_path = form_results()
        x_l, y_l = mnist.test.next_batch(n_labeled)
        writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=sess.graph)
        sess.run(init)
        for e in range(1, n_epochs + 1):
            n_batches = int(n_labeled / batch_size)
            for b in range(1, n_batches + 1):
                batch_x_l, batch_y_l = next_batch(x_l, y_l, batch_size=batch_size)
                sess.run(optimizer, feed_dict={x_input: batch_x_l, y_target: batch_y_l})
                if b % 5 == 0:
                    loss_, summary = sess.run([loss, summary_op], feed_dict={x_input: batch_x_l, y_target: batch_y_l})
                    writer.add_summary(summary, step)
                    print("Epoch: {} Iteration: {}".format(e, b))
                    print("Loss: {}".format(loss_))
                    with open(log_path + '/log.txt', 'a') as log:
                        log.write("Epoch: {}, iteration: {}\n".format(e, b))
                        log.write("Loss: {}\n".format(loss_))
                step += 1
            acc = 0
            num_batches = int(mnist.validation.num_examples / batch_size)
            for j in range(num_batches):
                # Classify unseen validation data instead of test data or train data
                batch_x_l, batch_y_l = mnist.validation.next_batch(batch_size=batch_size)
                val_acc = sess.run(accuracy, feed_dict={x_input: batch_x_l, y_target: batch_y_l})
                acc += val_acc
            acc /= num_batches
            print("Classification Accuracy: {}".format(acc))
            with open(log_path + '/log.txt', 'a') as log:
                log.write("Classification Accuracy: {}\n".format(acc))
            saver.save(sess, save_path=saved_model_path, global_step=step)

if __name__ == '__main__':
    train()

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
matplotlib==2.2.2
numpy==1.14.2
tensorflow-gpu==1.7.0

--------------------------------------------------------------------------------
/semi_supervised_adversarial_autoencoder.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import datetime
import os
import argparse
import matplotlib.pyplot as plt
from matplotlib import gridspec
from tensorflow.examples.tutorials.mnist import input_data

# Get the MNIST data
mnist = input_data.read_data_sets('./Data', one_hot=True)

# Parameters
input_dim = 784
n_l1 = 1000
n_l2 = 1000
z_dim = 10
batch_size = 100
n_epochs = 1000
learning_rate = 0.001
beta1 = 0.9
results_path = './Results/Semi_Supervised'
n_labels = 10
n_labeled = 1000

# Placeholders for input data and the targets
x_input = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Input')
x_input_l = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Labeled_Input')
y_input = tf.placeholder(dtype=tf.float32, shape=[batch_size, n_labels], name='Labels')
x_target = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Target')
real_distribution = tf.placeholder(dtype=tf.float32, shape=[batch_size, z_dim], name='Real_distribution')
categorial_distribution = tf.placeholder(dtype=tf.float32, shape=[batch_size, n_labels],
                                         name='Categorical_distribution')
manual_decoder_input = tf.placeholder(dtype=tf.float32, shape=[1, z_dim + n_labels], name='Decoder_input')


def form_results():
    """
    Forms folders for each run to store the tensorboard files, saved models and the log files.
    :return: three strings pointing to the tensorboard, saved models and log paths respectively.
    """
    folder_name = "/{0}_{1}_{2}_{3}_{4}_{5}_Semi_Supervised". \
        format(datetime.datetime.now(), z_dim, learning_rate, batch_size, n_epochs, beta1)
    tensorboard_path = results_path + folder_name + '/Tensorboard'
    saved_model_path = results_path + folder_name + '/Saved_models/'
    log_path = results_path + folder_name + '/log'
    if not os.path.exists(results_path + folder_name):
        os.mkdir(results_path + folder_name)
        os.mkdir(tensorboard_path)
        os.mkdir(saved_model_path)
        os.mkdir(log_path)
    return tensorboard_path, saved_model_path, log_path


def generate_image_grid(sess, op):
    """
    Generates a grid of images by passing a set of numbers to the decoder and getting its output.
    :param sess: Tensorflow Session required to get the decoder output
    :param op: Operation that needs to be called in order to get the decoder output
    :return: None, displays a matplotlib window with all the merged images.
    """
    nx, ny = 10, 10
    random_inputs = np.random.randn(10, z_dim) * 5.
    sample_y = np.identity(10)
    plt.subplot()
    gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)
    i = 0
    for r in random_inputs:
        for t in sample_y:
            r, t = np.reshape(r, (1, z_dim)), np.reshape(t, (1, n_labels))
            dec_input = np.concatenate((t, r), 1)
            x = sess.run(op, feed_dict={manual_decoder_input: dec_input})
            ax = plt.subplot(gs[i])
            i += 1
            img = np.array(x.tolist()).reshape(28, 28)
            ax.imshow(img, cmap='gray')
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_aspect('auto')
    plt.show()


def dense(x, n1, n2, name):
    """
    Used to create a dense layer.
    :param x: input tensor to the dense layer
    :param n1: no. of input neurons
    :param n2: no. of output neurons
    :param name: name of the entire dense layer, i.e. the variable scope name.
    :return: tensor with shape [batch_size, n2]
    """
    with tf.variable_scope(name, reuse=None):
        weights = tf.get_variable("weights", shape=[n1, n2],
                                  initializer=tf.random_normal_initializer(mean=0., stddev=0.01))
        bias = tf.get_variable("bias", shape=[n2], initializer=tf.constant_initializer(0.0))
        out = tf.add(tf.matmul(x, weights), bias, name='matmul')
        return out


# The autoencoder network
def encoder(x, reuse=False, supervised=False):
    """
    Encoder part of the autoencoder.
    :param x: input to the autoencoder
    :param reuse: True -> Reuse the encoder variables, False -> Create the encoder variables
    :param supervised: True -> returns output without passing it through softmax,
                       False -> returns output after passing it through softmax.
    :return: tensor which is the classification output and a hidden latent variable of the autoencoder.
    """
    if reuse:
        tf.get_variable_scope().reuse_variables()
    with tf.name_scope('Encoder'):
        e_dense_1 = tf.nn.relu(dense(x, input_dim, n_l1, 'e_dense_1'))
        e_dense_2 = tf.nn.relu(dense(e_dense_1, n_l1, n_l2, 'e_dense_2'))
        latent_variable = dense(e_dense_2, n_l2, z_dim, 'e_latent_variable')
        cat_op = dense(e_dense_2, n_l2, n_labels, 'e_label')
        if not supervised:
            softmax_label = tf.nn.softmax(logits=cat_op, name='e_softmax_label')
        else:
            softmax_label = cat_op
        return softmax_label, latent_variable


def decoder(x, reuse=False):
    """
    Decoder part of the autoencoder.
    :param x: input to the decoder
    :param reuse: True -> Reuse the decoder variables, False -> Create the decoder variables
    :return: tensor which should ideally be the input given to the encoder.
    """
    if reuse:
        tf.get_variable_scope().reuse_variables()
    with tf.name_scope('Decoder'):
        d_dense_1 = tf.nn.relu(dense(x, z_dim + n_labels, n_l2, 'd_dense_1'))
        d_dense_2 = tf.nn.relu(dense(d_dense_1, n_l2, n_l1, 'd_dense_2'))
        output = tf.nn.sigmoid(dense(d_dense_2, n_l1, input_dim, 'd_output'))
        return output


def discriminator_gauss(x, reuse=False):
    """
    Discriminator that is used to match the posterior distribution with a given gaussian distribution.
    :param x: tensor of shape [batch_size, z_dim]
    :param reuse: True -> Reuse the discriminator variables,
                  False -> Create the discriminator variables
    :return: tensor of shape [batch_size, 1]
    """
    if reuse:
        tf.get_variable_scope().reuse_variables()
    with tf.name_scope('Discriminator_Gauss'):
        dc_den1 = tf.nn.relu(dense(x, z_dim, n_l1, name='dc_g_den1'))
        dc_den2 = tf.nn.relu(dense(dc_den1, n_l1, n_l2, name='dc_g_den2'))
        output = dense(dc_den2, n_l2, 1, name='dc_g_output')
        return output


def discriminator_categorical(x, reuse=False):
    """
    Discriminator that is used to match the posterior distribution with a given categorical distribution.
    :param x: tensor of shape [batch_size, n_labels]
    :param reuse: True -> Reuse the discriminator variables,
                  False -> Create the discriminator variables
    :return: tensor of shape [batch_size, 1]
    """
    if reuse:
        tf.get_variable_scope().reuse_variables()
    with tf.name_scope('Discriminator_Categorial'):
        dc_den1 = tf.nn.relu(dense(x, n_labels, n_l1, name='dc_c_den1'))
        dc_den2 = tf.nn.relu(dense(dc_den1, n_l1, n_l2, name='dc_c_den2'))
        output = dense(dc_den2, n_l2, 1, name='dc_c_output')
        return output


def next_batch(x, y, batch_size):
    """
    Used to return a random batch from the given inputs.
    :param x: Input images of shape [None, 784]
    :param y: Input labels of shape [None, 10]
    :param batch_size: integer, batch size of images and labels to return
    :return: x -> [batch_size, 784], y -> [batch_size, 10]
    """
    index = np.arange(n_labeled)
    random_index = np.random.permutation(index)[:batch_size]
    return x[random_index], y[random_index]


def train(train_model=True):
    """
    Used to train the autoencoder by passing in the necessary inputs.
    :param train_model: True -> Train the model, False -> Load the latest trained model and show the image grid.
    :return: does not return anything
    """

    # Reconstruction Phase
    with tf.variable_scope(tf.get_variable_scope()):
        encoder_output_label, encoder_output_latent = encoder(x_input)
        # Concat class label and the encoder output
        decoder_input = tf.concat([encoder_output_label, encoder_output_latent], 1)
        decoder_output = decoder(decoder_input)

    # Regularization Phase
    with tf.variable_scope(tf.get_variable_scope()):
        d_g_real = discriminator_gauss(real_distribution)
        d_g_fake = discriminator_gauss(encoder_output_latent, reuse=True)

    with tf.variable_scope(tf.get_variable_scope()):
        d_c_real = discriminator_categorical(categorial_distribution)
        d_c_fake = discriminator_categorical(encoder_output_label, reuse=True)

    # Semi-Supervised Classification Phase
    with tf.variable_scope(tf.get_variable_scope()):
        encoder_output_label_, _ = encoder(x_input_l, reuse=True, supervised=True)

    # Generate output images
    with tf.variable_scope(tf.get_variable_scope()):
        decoder_image = decoder(manual_decoder_input, reuse=True)

    # Classification accuracy of encoder
    correct_pred = tf.equal(tf.argmax(encoder_output_label_, 1), tf.argmax(y_input, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Autoencoder loss
    autoencoder_loss = tf.reduce_mean(tf.square(x_target - decoder_output))

    # Gaussian Discriminator Loss
    dc_g_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_g_real), logits=d_g_real))
    dc_g_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_g_fake), logits=d_g_fake))
    dc_g_loss = dc_g_loss_fake + dc_g_loss_real

    # Categorical Discriminator Loss
    dc_c_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_c_real), logits=d_c_real))
    dc_c_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_c_fake), logits=d_c_fake))
    dc_c_loss = dc_c_loss_fake + dc_c_loss_real

    # Generator loss
    generator_g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_g_fake), logits=d_g_fake))
    generator_c_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_c_fake), logits=d_c_fake))
    generator_loss = generator_c_loss + generator_g_loss

    # Supervised Encoder Loss
    supervised_encoder_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_input, logits=encoder_output_label_))

    all_variables = tf.trainable_variables()
    dc_g_var = [var for var in all_variables if 'dc_g_' in var.name]
    dc_c_var = [var for var in all_variables if 'dc_c_' in var.name]
    en_var = [var for var in all_variables if 'e_' in var.name]

    # Optimizers
    autoencoder_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                   beta1=beta1).minimize(autoencoder_loss)
    discriminator_g_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                       beta1=beta1).minimize(dc_g_loss, var_list=dc_g_var)
    discriminator_c_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                       beta1=beta1).minimize(dc_c_loss, var_list=dc_c_var)
    generator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                 beta1=beta1).minimize(generator_loss, var_list=en_var)
    supervised_encoder_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                          beta1=beta1).minimize(supervised_encoder_loss,
                                                                                var_list=en_var)

    init = tf.global_variables_initializer()

    # Reshape images to display them
    input_images = tf.reshape(x_input, [-1, 28, 28, 1])
    generated_images = tf.reshape(decoder_output, [-1, 28, 28, 1])

    # Tensorboard visualization
    tf.summary.scalar(name='Autoencoder Loss', tensor=autoencoder_loss)
    tf.summary.scalar(name='Discriminator gauss Loss', tensor=dc_g_loss)
    tf.summary.scalar(name='Discriminator categorical Loss', tensor=dc_c_loss)
    tf.summary.scalar(name='Generator Loss', tensor=generator_loss)
    tf.summary.scalar(name='Supervised Encoder Loss', tensor=supervised_encoder_loss)
    tf.summary.histogram(name='Encoder Gauss Distribution', values=encoder_output_latent)
    tf.summary.histogram(name='Real Gauss Distribution', values=real_distribution)
    tf.summary.histogram(name='Encoder Categorical Distribution', values=encoder_output_label)
    tf.summary.histogram(name='Real Categorical Distribution', values=categorial_distribution)
    tf.summary.image(name='Input Images', tensor=input_images, max_outputs=10)
    tf.summary.image(name='Generated Images', tensor=generated_images, max_outputs=10)
    summary_op = tf.summary.merge_all()

    # Saving the model
    saver = tf.train.Saver()
    step = 0
    with tf.Session() as sess:
        if train_model:
            tensorboard_path, saved_model_path, log_path = form_results()
            sess.run(init)
            writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=sess.graph)
            x_l, y_l = mnist.test.next_batch(n_labeled)
            for i in range(n_epochs):
                n_batches = int(n_labeled / batch_size)
                print("------------------Epoch {}/{}------------------".format(i, n_epochs))
                for b in range(1, n_batches + 1):
                    z_real_dist = np.random.randn(batch_size, z_dim) * 5.
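                    # The categorical prior: uniform one-hot samples, matched against the encoder's softmax output by discriminator_categorical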
302 | real_cat_dist = np.random.randint(low=0, high=10, size=batch_size)
303 | real_cat_dist = np.eye(n_labels)[real_cat_dist]
304 | batch_x_ul, _ = mnist.train.next_batch(batch_size)
305 | batch_x_l, batch_y_l = next_batch(x_l, y_l, batch_size=batch_size)
306 | sess.run(autoencoder_optimizer, feed_dict={x_input: batch_x_ul, x_target: batch_x_ul})
307 | sess.run(discriminator_g_optimizer,
308 | feed_dict={x_input: batch_x_ul, x_target: batch_x_ul, real_distribution: z_real_dist})
309 | sess.run(discriminator_c_optimizer,
310 | feed_dict={x_input: batch_x_ul, x_target: batch_x_ul,
311 | categorial_distribution: real_cat_dist})
312 | sess.run(generator_optimizer, feed_dict={x_input: batch_x_ul, x_target: batch_x_ul})
313 | sess.run(supervised_encoder_optimizer, feed_dict={x_input_l: batch_x_l, y_input: batch_y_l})
314 | if b % 5 == 0:
315 | a_loss, d_g_loss, d_c_loss, g_loss, s_loss, summary = sess.run(
316 | [autoencoder_loss, dc_g_loss, dc_c_loss, generator_loss, supervised_encoder_loss,
317 | summary_op],
318 | feed_dict={x_input: batch_x_ul, x_target: batch_x_ul,
319 | real_distribution: z_real_dist, y_input: batch_y_l, x_input_l: batch_x_l,
320 | categorial_distribution: real_cat_dist})
321 | writer.add_summary(summary, global_step=step)
322 | print("Epoch: {}, iteration: {}".format(i, b))
323 | print("Autoencoder Loss: {}".format(a_loss))
324 | print("Discriminator Gauss Loss: {}".format(d_g_loss))
325 | print("Discriminator Categorical Loss: {}".format(d_c_loss))
326 | print("Generator Loss: {}".format(g_loss))
327 | print("Supervised Loss: {}\n".format(s_loss))
328 | with open(log_path + '/log.txt', 'a') as log:
329 | log.write("Epoch: {}, iteration: {}\n".format(i, b))
330 | log.write("Autoencoder Loss: {}\n".format(a_loss))
331 | log.write("Discriminator Gauss Loss: {}\n".format(d_g_loss))
332 | log.write("Discriminator Categorical Loss: {}\n".format(d_c_loss))
333 | log.write("Generator Loss: {}\n".format(g_loss))
334 | log.write("Supervised Loss: {}\n\n".format(s_loss))
335 | step += 1
336 | acc = 0
337 | num_batches = int(mnist.validation.num_examples / batch_size)
338 | for j in range(num_batches):
339 | # Classify unseen validation data instead of test data or train data
340 | batch_x_l, batch_y_l = mnist.validation.next_batch(batch_size=batch_size)
341 | encoder_acc = sess.run(accuracy, feed_dict={x_input_l: batch_x_l, y_input: batch_y_l})
342 | acc += encoder_acc
343 | acc /= num_batches
344 | print("Encoder Classification Accuracy: {}".format(acc))
345 | with open(log_path + '/log.txt', 'a') as log:
346 | log.write("Encoder Classification Accuracy: {}\n".format(acc))
347 | saver.save(sess, save_path=saved_model_path, global_step=step)
348 | else:
349 | # Get the latest results folder
350 | all_results = os.listdir(results_path)
351 | all_results.sort()
352 | saver.restore(sess, save_path=tf.train.latest_checkpoint(results_path + '/' +
353 | all_results[-1] + '/Saved_models/'))
354 | generate_image_grid(sess, op=decoder_image)
355 | 
356 | 
357 | if __name__ == '__main__':
358 | parser = argparse.ArgumentParser(description="Autoencoder Train Parameter")
359 | parser.add_argument('--train', '-t', type=lambda s: s.lower() in ('true', '1'), default=True,
360 | help='Set to True to train a new model, False to load weights and display image grid')
361 | args = parser.parse_args()
362 | train(train_model=args.train)
363 | 
-------------------------------------------------------------------------------- /supervised_adversarial_autoencoder.py: -------------------------------------------------------------------------------- 1 | 
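# Supervised adversarial autoencoder: the decoder receives the one-hot class
# label concatenated with the encoder's latent code, so the label supplies the
# digit identity ("content") while the latent code is left to model the
# handwriting "style". A single discriminator matches the aggregated posterior
# over the latent code to a Gaussian prior, as in the unsupervised version.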
import tensorflow as tf
2 | import numpy as np
3 | import datetime
4 | import os
5 | import argparse
6 | import matplotlib.pyplot as plt
7 | from matplotlib import gridspec
8 | from tensorflow.examples.tutorials.mnist import input_data
9 | 
10 | # Get the MNIST data
11 | mnist = input_data.read_data_sets('./Data', one_hot=True)
12 | 
13 | # Parameters
14 | input_dim = 784
15 | n_l1 = 1000
16 | n_l2 = 1000
17 | z_dim = 15
18 | batch_size = 100
19 | n_epochs = 1000
20 | learning_rate = 0.001
21 | beta1 = 0.9
22 | results_path = './Results/Supervised'
23 | n_labels = 10
24 | 
25 | # Placeholders for input data and the targets
26 | x_input = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Input')
27 | y_input = tf.placeholder(dtype=tf.float32, shape=[batch_size, n_labels], name='Labels')
28 | x_target = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Target')
29 | real_distribution = tf.placeholder(dtype=tf.float32, shape=[batch_size, z_dim], name='Real_distribution')
30 | manual_decoder_input = tf.placeholder(dtype=tf.float32, shape=[1, z_dim + n_labels], name='Decoder_input')
31 | 
32 | 
33 | def form_results():
34 | """
35 | Forms folders for each run to store the tensorboard files, saved models and the log files.
36 | :return: three strings pointing to the tensorboard, saved models and log paths respectively.
37 | """
38 | folder_name = "/{0}_{1}_{2}_{3}_{4}_{5}_Supervised". \
39 | format(datetime.datetime.now(), z_dim, learning_rate, batch_size, n_epochs, beta1)
40 | tensorboard_path = results_path + folder_name + '/Tensorboard'
41 | saved_model_path = results_path + folder_name + '/Saved_models/'
42 | log_path = results_path + folder_name + '/log'
43 | if not os.path.exists(results_path + folder_name):
44 | os.mkdir(results_path + folder_name)
45 | os.mkdir(tensorboard_path)
46 | os.mkdir(saved_model_path)
47 | os.mkdir(log_path)
48 | return tensorboard_path, saved_model_path, log_path
49 | 
50 | 
51 | def generate_image_grid(sess, op):
52 | """
53 | Generates a grid of images by passing a set of numbers to the decoder and getting its output.
54 | :param sess: Tensorflow Session required to get the decoder output
55 | :param op: Operation that needs to be called in order to get the decoder output
56 | :return: None, displays a matplotlib window with all the merged images.
57 | """
58 | nx, ny = 10, 10
59 | random_inputs = np.random.randn(10, z_dim) * 5.
60 | sample_y = np.identity(10)
61 | plt.subplot()
62 | gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)
63 | i = 0
64 | for r in random_inputs:
65 | for t in sample_y:
66 | r, t = np.reshape(r, (1, z_dim)), np.reshape(t, (1, n_labels))
67 | dec_input = np.concatenate((t, r), 1)
68 | x = sess.run(op, feed_dict={manual_decoder_input: dec_input})
69 | ax = plt.subplot(gs[i])
70 | i += 1
71 | img = np.array(x.tolist()).reshape(28, 28)
72 | ax.imshow(img, cmap='gray')
73 | ax.set_xticks([])
74 | ax.set_yticks([])
75 | ax.set_aspect('auto')
76 | plt.show()
77 | 
78 | 
79 | def dense(x, n1, n2, name):
80 | """
81 | Used to create a dense layer.
82 | :param x: input tensor to the dense layer
83 | :param n1: no. of input neurons
84 | :param n2: no. of output neurons
85 | :param name: name of the entire dense layer, i.e. the variable scope name.
86 | :return: tensor with shape [batch_size, n2]
87 | """
88 | with tf.variable_scope(name, reuse=None):
89 | weights = tf.get_variable("weights", shape=[n1, n2],
90 | initializer=tf.random_normal_initializer(mean=0., stddev=0.01))
91 | bias = tf.get_variable("bias", shape=[n2], initializer=tf.constant_initializer(0.0))
92 | out = tf.add(tf.matmul(x, weights), bias, name='matmul')
93 | return out
94 | 
95 | 
96 | # The autoencoder network
97 | def encoder(x, reuse=False):
98 | """
99 | Encoder part of the autoencoder.
100 | :param x: input to the autoencoder
101 | :param reuse: True -> Reuse the encoder variables, False -> Create new variables
102 | (unlike the semi-supervised version, this encoder has no supervised mode
103 | and no label output; it produces only the latent code)
104 | :return: tensor of shape [batch_size, z_dim], the latent variable of the autoencoder.
105 | """
106 | if reuse:
107 | tf.get_variable_scope().reuse_variables()
108 | with tf.name_scope('Encoder'):
109 | e_dense_1 = tf.nn.relu(dense(x, input_dim, n_l1, 'e_dense_1'))
110 | e_dense_2 = tf.nn.relu(dense(e_dense_1, n_l1, n_l2, 'e_dense_2'))
111 | latent_variable = dense(e_dense_2, n_l2, z_dim, 'e_latent_variable')
112 | return latent_variable
113 | 
114 | 
115 | def decoder(x, reuse=False):
116 | """
117 | Decoder part of the autoencoder.
118 | :param x: input to the decoder
119 | :param reuse: True -> Reuse the decoder variables, False -> Create new variables
120 | :return: tensor which should ideally be the input given to the encoder.
121 | """
122 | if reuse:
123 | tf.get_variable_scope().reuse_variables()
124 | with tf.name_scope('Decoder'):
125 | d_dense_1 = tf.nn.relu(dense(x, z_dim + n_labels, n_l2, 'd_dense_1'))
126 | d_dense_2 = tf.nn.relu(dense(d_dense_1, n_l2, n_l1, 'd_dense_2'))
127 | output = tf.nn.sigmoid(dense(d_dense_2, n_l1, input_dim, 'd_output'))
128 | return output
129 | 
130 | 
131 | def discriminator(x, reuse=False):
132 | """
133 | Discriminator that is used to match the posterior distribution with a given prior distribution.
134 | :param x: tensor of shape [batch_size, z_dim]
135 | :param reuse: True -> Reuse the discriminator variables,
136 | False -> Create new variables
137 | :return: tensor of shape [batch_size, 1]
138 | """
139 | if reuse:
140 | tf.get_variable_scope().reuse_variables()
141 | with tf.name_scope('Discriminator'):
142 | dc_den1 = tf.nn.relu(dense(x, z_dim, n_l1, name='dc_den1'))
143 | dc_den2 = tf.nn.relu(dense(dc_den1, n_l1, n_l2, name='dc_den2'))
144 | output = dense(dc_den2, n_l2, 1, name='dc_output')
145 | return output
146 | 
147 | 
148 | def train(train_model=True):
149 | """
150 | Used to train the autoencoder by passing in the necessary inputs.
151 | :param train_model: True -> Train the model, False -> Load the latest trained model and show the image grid.
152 | :return: does not return anything 153 | """ 154 | with tf.variable_scope(tf.get_variable_scope()): 155 | encoder_output = encoder(x_input) 156 | # Concat class label and the encoder output 157 | decoder_input = tf.concat([y_input, encoder_output], 1) 158 | decoder_output = decoder(decoder_input) 159 | 160 | with tf.variable_scope(tf.get_variable_scope()): 161 | d_real = discriminator(real_distribution) 162 | d_fake = discriminator(encoder_output, reuse=True) 163 | 164 | with tf.variable_scope(tf.get_variable_scope()): 165 | decoder_image = decoder(manual_decoder_input, reuse=True) 166 | 167 | # Autoencoder loss 168 | autoencoder_loss = tf.reduce_mean(tf.square(x_target - decoder_output)) 169 | 170 | # Discriminator Loss 171 | dc_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real), logits=d_real)) 172 | dc_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake), logits=d_fake)) 173 | dc_loss = dc_loss_fake + dc_loss_real 174 | 175 | # Generator loss 176 | generator_loss = tf.reduce_mean( 177 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_fake), logits=d_fake)) 178 | 179 | all_variables = tf.trainable_variables() 180 | dc_var = [var for var in all_variables if 'dc_' in var.name] 181 | en_var = [var for var in all_variables if 'e_' in var.name] 182 | 183 | # Optimizers 184 | autoencoder_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, 185 | beta1=beta1).minimize(autoencoder_loss) 186 | discriminator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, 187 | beta1=beta1).minimize(dc_loss, var_list=dc_var) 188 | generator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, 189 | beta1=beta1).minimize(generator_loss, var_list=en_var) 190 | 191 | init = tf.global_variables_initializer() 192 | 193 | # Reshape images to display them 194 | input_images = tf.reshape(x_input, [-1, 28, 28, 1]) 195 | generated_images = tf.reshape(decoder_output, [-1, 28, 28, 1]) 196 | 197 | # Tensorboard visualization 198 | tf.summary.scalar(name='Autoencoder Loss', tensor=autoencoder_loss) 199 | tf.summary.scalar(name='Discriminator Loss', tensor=dc_loss) 200 | tf.summary.scalar(name='Generator Loss', tensor=generator_loss) 201 | tf.summary.histogram(name='Encoder Distribution', values=encoder_output) 202 | tf.summary.histogram(name='Real Distribution', values=real_distribution) 203 | tf.summary.image(name='Input Images', tensor=input_images, max_outputs=10) 204 | tf.summary.image(name='Generated Images', tensor=generated_images, max_outputs=10) 205 | summary_op = tf.summary.merge_all() 206 | 207 | # Saving the model 208 | saver = tf.train.Saver() 209 | step = 0 210 | with tf.Session() as sess: 211 | if train_model: 212 | tensorboard_path, saved_model_path, log_path = form_results() 213 | sess.run(init) 214 | writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=sess.graph) 215 | for i in range(n_epochs): 216 | n_batches = int(mnist.train.num_examples / batch_size) 217 | print("------------------Epoch {}/{}------------------".format(i, n_epochs)) 218 | for b in range(1, n_batches + 1): 219 | z_real_dist = np.random.randn(batch_size, z_dim) * 5. 
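# Each iteration performs the three standard AAE updates in sequence:
# (1) reconstruction: encoder + decoder minimize the pixel-wise MSE,
# (2) regularization: the discriminator learns to tell prior samples
#     (z_real_dist) from encoder outputs,
# (3) generator: the encoder alone is updated to fool the discriminator.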
220 | batch_x, batch_y = mnist.train.next_batch(batch_size)
221 | sess.run(autoencoder_optimizer, feed_dict={x_input: batch_x, x_target: batch_x, y_input: batch_y})
222 | sess.run(discriminator_optimizer,
223 | feed_dict={x_input: batch_x, x_target: batch_x, real_distribution: z_real_dist})
224 | sess.run(generator_optimizer, feed_dict={x_input: batch_x, x_target: batch_x})
225 | if b % 50 == 0:
226 | a_loss, d_loss, g_loss, summary = sess.run(
227 | [autoencoder_loss, dc_loss, generator_loss, summary_op],
228 | feed_dict={x_input: batch_x, x_target: batch_x,
229 | real_distribution: z_real_dist, y_input: batch_y})
230 | writer.add_summary(summary, global_step=step)
231 | print("Epoch: {}, iteration: {}".format(i, b))
232 | print("Autoencoder Loss: {}".format(a_loss))
233 | print("Discriminator Loss: {}".format(d_loss))
234 | print("Generator Loss: {}".format(g_loss))
235 | with open(log_path + '/log.txt', 'a') as log:
236 | log.write("Epoch: {}, iteration: {}\n".format(i, b))
237 | log.write("Autoencoder Loss: {}\n".format(a_loss))
238 | log.write("Discriminator Loss: {}\n".format(d_loss))
239 | log.write("Generator Loss: {}\n".format(g_loss))
240 | step += 1
241 | 
242 | saver.save(sess, save_path=saved_model_path, global_step=step)
243 | else:
244 | # Get the latest results folder
245 | all_results = os.listdir(results_path)
246 | all_results.sort()
247 | saver.restore(sess, save_path=tf.train.latest_checkpoint(results_path + '/' +
248 | all_results[-1] + '/Saved_models/'))
249 | generate_image_grid(sess, op=decoder_image)
250 | 
251 | 
252 | if __name__ == '__main__':
253 | parser = argparse.ArgumentParser(description="Autoencoder Train Parameter")
254 | parser.add_argument('--train', '-t', type=lambda s: s.lower() in ('true', '1'), default=True,
255 | help='Set to True to train a new model, False to load weights and display image grid')
256 | args = parser.parse_args()
257 | train(train_model=args.train)
258 | 
--------------------------------------------------------------------------------
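For readers experimenting with the trained supervised model, here is a minimal sketch of how a single class-conditional decoder input can be assembled, mirroring what `generate_image_grid` does internally. The `digit` variable is a hypothetical example, and the commented-out `sess.run` line assumes a session restored from a checkpoint produced by `form_results`:

    import numpy as np

    z_dim, n_labels = 15, 10  # must match the values used at training time

    # the one-hot label picks the digit ("content"); z picks the "style"
    digit = 3                                   # hypothetical example class
    label = np.eye(n_labels)[digit].reshape(1, n_labels)
    z = np.random.randn(1, z_dim) * 5.          # sample from the N(0, 25*I) prior

    # label first, then z -- the same order generate_image_grid uses
    dec_input = np.concatenate((label, z), axis=1)   # shape [1, z_dim + n_labels]
    # image = sess.run(decoder_image, feed_dict={manual_decoder_input: dec_input})

Varying `z` while holding `label` fixed produces different handwriting styles of the same digit, which is the disentanglement effect the supervised architecture is built for.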