├── .idea
│   ├── Adversarial_Autoencoder.iml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── README
│   ├── AAE Block Diagram.png
│   ├── AAE dist match.png
│   ├── Supervised AAE.png
│   ├── aa_encoder_dist.png
│   ├── aa_real_dist.png
│   ├── adversarial_autoencoder.png
│   ├── adversarial_autoencoder_2.png
│   ├── autoencoder_architecture.png
│   ├── cat_n_gauss_dist_real_obtained.png
│   ├── cover.png
│   ├── disentanglement of style and content.png
│   ├── grid_175450.png
│   ├── grid_177650.png
│   ├── nw_architecture.png
│   ├── semi_1000.png
│   ├── semi_9960.png
│   ├── semi_9970.png
│   ├── semi_9980.png
│   ├── semi_9990.png
│   ├── semi_AAE architecture.png
│   ├── semi_aae_accuracy_with_NN.png
│   ├── semi_e_c.png
│   ├── semi_e_g.png
│   ├── semi_r_c.png
│   ├── semi_r_g.png
│   └── supervised_autoencoder_100.png
├── Results
│   ├── .gitkeep
│   ├── Adversarial_Autoencoder
│   │   └── .gitkeep
│   ├── Autoencoder
│   │   └── .gitkeep
│   ├── Basic_NN_Classifier
│   │   └── .gitkeep
│   ├── Semi_Supervised
│   │   └── .gitkeep
│   └── Supervised
│       └── .gitkeep
├── _config.yml
├── adversarial_autoencoder.py
├── autoencoder.py
├── basic_nn_classifier.py
├── requirements.txt
├── semi_supervised_adversarial_autoencoder.py
└── supervised_adversarial_autoencoder.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Naresh Nagabushan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adversarial Autoencoders
2 | ![Cover](README/cover.png)
3 | 
4 | This repository contains code to implement an adversarial autoencoder using TensorFlow.
5 |
6 | Medium posts:
7 |
8 | 1. [A Wizard's guide to Adversarial Autoencoders: Part 1. Autoencoders?](https://medium.com/towards-data-science/a-wizards-guide-to-adversarial-autoencoders-part-1-autoencoder-d9a5f8795af4)
9 |
10 | 2. [A Wizard's guide to Adversarial Autoencoders: Part 2. Exploring the latent space with Adversarial Autoencoders.](https://medium.com/towards-data-science/a-wizards-guide-to-adversarial-autoencoders-part-2-exploring-latent-space-with-adversarial-2d53a6f8a4f9)
11 |
12 | 3. [A Wizard's guide to Adversarial Autoencoders: Part 3. Disentanglement of style and content.](https://medium.com/towards-data-science/a-wizards-guide-to-adversarial-autoencoders-part-3-disentanglement-of-style-and-content-89262973a4d7)
13 |
14 | 4. [A Wizard's guide to Adversarial Autoencoders: Part 4. Classify MNIST using 1000 labels.](https://medium.com/towards-data-science/a-wizards-guide-to-adversarial-autoencoders-part-4-classify-mnist-using-1000-labels-2ca08071f95)
15 |
16 | ## Installing the dependencies
17 | Install virtualenv and create a new virtual environment:
18 |
19 | pip install virtualenv
20 | virtualenv -p /usr/bin/python3 aa
21 |
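Activate the environment before installing anything into it (assuming a POSIX shell; on Windows the equivalent is `aa\Scripts\activate`):

    source aa/bin/activate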
22 | Install the dependencies:
23 |
24 | pip3 install -r requirements.txt
25 |
26 | ***Note:***
27 |
28 | * *I'd highly recommend using your GPU during training.*
29 | * *The `targets` parameter of `tf.nn.sigmoid_cross_entropy_with_logits`
30 | was renamed to `labels` in TensorFlow versions after r0.12 (see the snippet below).*
31 |
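For reference, the newer keyword looks like this (a minimal sketch against the TensorFlow 1.x API used in this repo; the tensors are illustrative):

    import tensorflow as tf

    logits = tf.constant([[0.3, -1.2]])
    targets = tf.constant([[1.0, 0.0]])
    # `labels=` replaced the old `targets=` keyword after r0.12
    loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits)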
32 | ## Dataset
33 | The MNIST dataset will be downloaded automatically and made available
34 | in the `./Data` directory.
35 |
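The download is triggered by the call each script already makes at import time:

    from tensorflow.examples.tutorials.mnist import input_data

    # Downloads MNIST into ./Data on the first run and reuses the cached copy afterwards
    mnist = input_data.read_data_sets('./Data', one_hot=True)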
36 |
37 | ## Training!
38 | ### Autoencoder:
39 | #### Architecture:
40 | ![Autoencoder architecture](README/autoencoder_architecture.png)
41 | To train a basic autoencoder run:
42 |
43 | python3 autoencoder.py --train True
44 |
45 | * This trains an autoencoder and saves the trained model once every epoch
46 | in the `./Results/Autoencoder` directory.
47 |
48 | To load the trained model and generate images by passing inputs to the decoder, run:
49 |
50 | python3 autoencoder.py --train False
51 |
52 | ### Adversarial Autoencoder:
53 | #### Architecture:
54 |
55 |
56 |
57 | Training:
58 |
59 | python3 adversarial_autoencoder.py --train True
60 |
61 | Load model and explore the latent space:
62 |
63 | python3 adversarial_autoencoder.py --train False
64 |
65 | Example of the adversarial autoencoder's output when the encoder is constrained
66 | to match a prior with a standard deviation of 5.
67 |
68 |
69 | ![Matching prior and posterior distributions](README/AAE%20dist%20match.png)
70 | **_Matching prior and posterior distributions._**
71 |
72 |
73 | ![Encoder distribution](README/aa_encoder_dist.png)
74 | **_Distribution of digits in the latent space._**
75 |
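The prior itself is a unit Gaussian scaled by 5; `adversarial_autoencoder.py` samples it once per batch and feeds it to the `real_distribution` placeholder:

    import numpy as np

    batch_size, z_dim = 100, 2
    # Samples given to the discriminator as the "real" distribution
    z_real_dist = np.random.randn(batch_size, z_dim) * 5.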
76 | ### Supervised Adversarial Autoencoder:
77 | #### Architecture:
78 |
79 | ![Supervised AAE architecture](README/Supervised%20AAE.png)
80 |
81 | Training:
82 |
83 | python3 supervised_adversarial_autoencoder.py --train True
84 |
85 | Load model and explore the latent space:
86 |
87 | python3 supervised_adversarial_autoencoder.py --train False
88 |
89 | Example of disentanglement of style and content:
90 | ![Disentanglement of style and content](README/disentanglement%20of%20style%20and%20content.png)
91 |
92 | ### Semi-Supervised Adversarial Autoencoder:
93 | #### Architecture:
94 | ![Semi-supervised AAE architecture](README/semi_AAE%20architecture.png)
95 |
96 | Training:
97 |
98 | python3 semi_supervised_adversarial_autoencoder.py --train True
99 |
100 | Load model and explore the latent space:
101 |
102 | python3 semi_supervised_adversarial_autoencoder.py --train False
103 |
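When exploring the latent space, the decoder receives a one-hot class label concatenated with a Gaussian style code; a minimal sketch of the input that `generate_image_grid` builds:

    import numpy as np

    z_dim, n_labels = 10, 10
    style = np.reshape(np.random.randn(z_dim) * 5., (1, z_dim))
    label = np.reshape(np.identity(n_labels)[3], (1, n_labels))  # one-hot for digit 3
    # Label first, then style code, matching the decoder's input layout
    dec_input = np.concatenate((label, style), 1)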
104 | Classification accuracy for 1000 labeled images:
105 | ![Semi-supervised AAE accuracy vs. basic NN](README/semi_aae_accuracy_with_NN.png)
106 |
107 |
108 |
109 |
110 |
111 | ***Note:***
112 | * Each run generates the TensorBoard files it needs under the `./Results/<model>/<timestamp_and_parameters>/Tensorboard` directory.
113 | * Use `tensorboard --logdir <path_to_tensorboard_dir>` to look at loss variations
114 | and distributions of the latent code.
115 | * Windows gives an error when `:` is used in a folder name (the timestamped run folders contain one). I
116 | suggest removing the timestamp from the `folder_name` variable in the `form_results()` function, as sketched below. Or, just dual boot Linux!
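A minimal sketch of that change, assuming the rest of `form_results()` stays as-is (the `strftime` format is an illustrative choice, and the parameter values below are the script's module-level defaults):

    import datetime

    z_dim, learning_rate, batch_size, n_epochs, beta1 = 2, 0.001, 100, 1000, 0.9
    # Colon-free timestamp so the run folder name is valid on Windows
    time_stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    folder_name = "/{0}_{1}_{2}_{3}_{4}_{5}_Adversarial_Autoencoder".format(
        time_stamp, z_dim, learning_rate, batch_size, n_epochs, beta1)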
117 |
118 |
119 | ## Thank You
120 | Please share this repo if you find it helpful.
121 |
--------------------------------------------------------------------------------
/README/AAE Block Diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/AAE Block Diagram.png
--------------------------------------------------------------------------------
/README/AAE dist match.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/AAE dist match.png
--------------------------------------------------------------------------------
/README/Supervised AAE.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/Supervised AAE.png
--------------------------------------------------------------------------------
/README/aa_encoder_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/aa_encoder_dist.png
--------------------------------------------------------------------------------
/README/aa_real_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/aa_real_dist.png
--------------------------------------------------------------------------------
/README/adversarial_autoencoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/adversarial_autoencoder.png
--------------------------------------------------------------------------------
/README/adversarial_autoencoder_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/adversarial_autoencoder_2.png
--------------------------------------------------------------------------------
/README/autoencoder_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/autoencoder_architecture.png
--------------------------------------------------------------------------------
/README/cat_n_gauss_dist_real_obtained.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/cat_n_gauss_dist_real_obtained.png
--------------------------------------------------------------------------------
/README/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/cover.png
--------------------------------------------------------------------------------
/README/disentanglement of style and content.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/disentanglement of style and content.png
--------------------------------------------------------------------------------
/README/grid_175450.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/grid_175450.png
--------------------------------------------------------------------------------
/README/grid_177650.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/grid_177650.png
--------------------------------------------------------------------------------
/README/nw_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/nw_architecture.png
--------------------------------------------------------------------------------
/README/semi_1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_1000.png
--------------------------------------------------------------------------------
/README/semi_9960.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_9960.png
--------------------------------------------------------------------------------
/README/semi_9970.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_9970.png
--------------------------------------------------------------------------------
/README/semi_9980.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_9980.png
--------------------------------------------------------------------------------
/README/semi_9990.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_9990.png
--------------------------------------------------------------------------------
/README/semi_AAE architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_AAE architecture.png
--------------------------------------------------------------------------------
/README/semi_aae_accuracy_with_NN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_aae_accuracy_with_NN.png
--------------------------------------------------------------------------------
/README/semi_e_c.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_e_c.png
--------------------------------------------------------------------------------
/README/semi_e_g.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_e_g.png
--------------------------------------------------------------------------------
/README/semi_r_c.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_r_c.png
--------------------------------------------------------------------------------
/README/semi_r_g.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/semi_r_g.png
--------------------------------------------------------------------------------
/README/supervised_autoencoder_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/README/supervised_autoencoder_100.png
--------------------------------------------------------------------------------
/Results/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/Results/.gitkeep
--------------------------------------------------------------------------------
/Results/Adversarial_Autoencoder/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/Results/Adversarial_Autoencoder/.gitkeep
--------------------------------------------------------------------------------
/Results/Autoencoder/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/Results/Autoencoder/.gitkeep
--------------------------------------------------------------------------------
/Results/Basic_NN_Classifier/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/Results/Basic_NN_Classifier/.gitkeep
--------------------------------------------------------------------------------
/Results/Semi_Supervised/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/Results/Semi_Supervised/.gitkeep
--------------------------------------------------------------------------------
/Results/Supervised/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Naresh1318/Adversarial_Autoencoder/e689c0f408f5b1cab58ad48a962be3e5d94cb97a/Results/Supervised/.gitkeep
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
--------------------------------------------------------------------------------
/adversarial_autoencoder.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import datetime
4 | import os
5 | import matplotlib.pyplot as plt
6 | from matplotlib import gridspec
7 | from tensorflow.examples.tutorials.mnist import input_data
8 |
9 | # Progressbar
10 | # bar = progressbar.ProgressBar(widgets=['[', progressbar.Timer(), ']', progressbar.Bar(), '(', progressbar.ETA(), ')'])
11 |
12 | # Get the MNIST data
13 | mnist = input_data.read_data_sets('./Data', one_hot=True)
14 |
15 | # Parameters
16 | input_dim = 784
17 | n_l1 = 1000
18 | n_l2 = 1000
19 | z_dim = 2
20 | batch_size = 100
21 | n_epochs = 1000
22 | learning_rate = 0.001
23 | beta1 = 0.9
24 | results_path = './Results/Adversarial_Autoencoder'
25 |
26 | # Placeholders for input data and the targets
27 | x_input = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Input')
28 | x_target = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Target')
29 | real_distribution = tf.placeholder(dtype=tf.float32, shape=[batch_size, z_dim], name='Real_distribution')
30 | decoder_input = tf.placeholder(dtype=tf.float32, shape=[1, z_dim], name='Decoder_input')
31 |
32 |
33 | def form_results():
34 | """
35 | Forms folders for each run to store the tensorboard files, saved models and the log files.
36 |     :return: three strings pointing to the tensorboard, saved models and log paths respectively.
37 | """
38 | folder_name = "/{0}_{1}_{2}_{3}_{4}_{5}_Adversarial_Autoencoder". \
39 | format(datetime.datetime.now(), z_dim, learning_rate, batch_size, n_epochs, beta1)
40 | tensorboard_path = results_path + folder_name + '/Tensorboard'
41 | saved_model_path = results_path + folder_name + '/Saved_models/'
42 | log_path = results_path + folder_name + '/log'
43 | if not os.path.exists(results_path + folder_name):
44 | os.mkdir(results_path + folder_name)
45 | os.mkdir(tensorboard_path)
46 | os.mkdir(saved_model_path)
47 | os.mkdir(log_path)
48 | return tensorboard_path, saved_model_path, log_path
49 |
50 |
51 | def generate_image_grid(sess, op):
52 | """
53 | Generates a grid of images by passing a set of numbers to the decoder and getting its output.
54 | :param sess: Tensorflow Session required to get the decoder output
55 |     :param op: Operation that needs to be called in order to get the decoder output
56 | :return: None, displays a matplotlib window with all the merged images.
57 | """
58 | x_points = np.arange(-10, 10, 1.5).astype(np.float32)
59 | y_points = np.arange(-10, 10, 1.5).astype(np.float32)
60 |
61 | nx, ny = len(x_points), len(y_points)
62 | plt.subplot()
63 | gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)
64 |
65 | for i, g in enumerate(gs):
66 | z = np.concatenate(([x_points[int(i / ny)]], [y_points[int(i % nx)]]))
67 | z = np.reshape(z, (1, 2))
68 | x = sess.run(op, feed_dict={decoder_input: z})
69 | ax = plt.subplot(g)
70 | img = np.array(x.tolist()).reshape(28, 28)
71 | ax.imshow(img, cmap='gray')
72 | ax.set_xticks([])
73 | ax.set_yticks([])
74 | ax.set_aspect('auto')
75 | plt.show()
76 |
77 |
78 | def dense(x, n1, n2, name):
79 | """
80 | Used to create a dense layer.
81 | :param x: input tensor to the dense layer
82 | :param n1: no. of input neurons
83 | :param n2: no. of output neurons
84 |     :param name: name of the entire dense layer, i.e. the variable scope name.
85 | :return: tensor with shape [batch_size, n2]
86 | """
87 | with tf.variable_scope(name, reuse=None):
88 | weights = tf.get_variable("weights", shape=[n1, n2],
89 | initializer=tf.random_normal_initializer(mean=0., stddev=0.01))
90 | bias = tf.get_variable("bias", shape=[n2], initializer=tf.constant_initializer(0.0))
91 | out = tf.add(tf.matmul(x, weights), bias, name='matmul')
92 | return out
93 |
94 |
95 | # The autoencoder network
96 | def encoder(x, reuse=False):
97 | """
98 |     Encoder part of the autoencoder.
99 |     :param x: input to the autoencoder
100 |     :param reuse: True -> Reuse the encoder variables, False -> Create new variables or search for existing ones
101 | :return: tensor which is the hidden latent variable of the autoencoder.
102 | """
103 | if reuse:
104 | tf.get_variable_scope().reuse_variables()
105 | with tf.name_scope('Encoder'):
106 | e_dense_1 = tf.nn.relu(dense(x, input_dim, n_l1, 'e_dense_1'))
107 | e_dense_2 = tf.nn.relu(dense(e_dense_1, n_l1, n_l2, 'e_dense_2'))
108 | latent_variable = dense(e_dense_2, n_l2, z_dim, 'e_latent_variable')
109 | return latent_variable
110 |
111 |
112 | def decoder(x, reuse=False):
113 | """
114 | Decoder part of the autoencoder.
115 | :param x: input to the decoder
116 |     :param reuse: True -> Reuse the decoder variables, False -> Create new variables or search for existing ones
117 | :return: tensor which should ideally be the input given to the encoder.
118 | """
119 | if reuse:
120 | tf.get_variable_scope().reuse_variables()
121 | with tf.name_scope('Decoder'):
122 | d_dense_1 = tf.nn.relu(dense(x, z_dim, n_l2, 'd_dense_1'))
123 | d_dense_2 = tf.nn.relu(dense(d_dense_1, n_l2, n_l1, 'd_dense_2'))
124 | output = tf.nn.sigmoid(dense(d_dense_2, n_l1, input_dim, 'd_output'))
125 | return output
126 |
127 |
128 | def discriminator(x, reuse=False):
129 | """
130 | Discriminator that is used to match the posterior distribution with a given prior distribution.
131 | :param x: tensor of shape [batch_size, z_dim]
132 | :param reuse: True -> Reuse the discriminator variables,
133 |                   False -> Create new variables or search for existing ones
134 | :return: tensor of shape [batch_size, 1]
135 | """
136 | if reuse:
137 | tf.get_variable_scope().reuse_variables()
138 | with tf.name_scope('Discriminator'):
139 | dc_den1 = tf.nn.relu(dense(x, z_dim, n_l1, name='dc_den1'))
140 | dc_den2 = tf.nn.relu(dense(dc_den1, n_l1, n_l2, name='dc_den2'))
141 | output = dense(dc_den2, n_l2, 1, name='dc_output')
142 | return output
143 |
144 |
145 | def train(train_model=True):
146 | """
147 | Used to train the autoencoder by passing in the necessary inputs.
148 | :param train_model: True -> Train the model, False -> Load the latest trained model and show the image grid.
149 | :return: does not return anything
150 | """
151 | with tf.variable_scope(tf.get_variable_scope()):
152 | encoder_output = encoder(x_input)
153 | decoder_output = decoder(encoder_output)
154 |
155 | with tf.variable_scope(tf.get_variable_scope()):
156 | d_real = discriminator(real_distribution)
157 | d_fake = discriminator(encoder_output, reuse=True)
158 |
159 | with tf.variable_scope(tf.get_variable_scope()):
160 | decoder_image = decoder(decoder_input, reuse=True)
161 |
162 | # Autoencoder loss
163 | autoencoder_loss = tf.reduce_mean(tf.square(x_target - decoder_output))
164 |
165 |     # Discriminator Loss
166 | dc_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real), logits=d_real))
167 | dc_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake), logits=d_fake))
168 | dc_loss = dc_loss_fake + dc_loss_real
169 |
170 | # Generator loss
171 | generator_loss = tf.reduce_mean(
172 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_fake), logits=d_fake))
173 |
174 | all_variables = tf.trainable_variables()
175 | dc_var = [var for var in all_variables if 'dc_' in var.name]
176 | en_var = [var for var in all_variables if 'e_' in var.name]
177 |
178 | # Optimizers
179 | autoencoder_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
180 | beta1=beta1).minimize(autoencoder_loss)
181 | discriminator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
182 | beta1=beta1).minimize(dc_loss, var_list=dc_var)
183 | generator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
184 | beta1=beta1).minimize(generator_loss, var_list=en_var)
185 |
186 | init = tf.global_variables_initializer()
187 |
188 |     # Reshape images to display them
189 | input_images = tf.reshape(x_input, [-1, 28, 28, 1])
190 | generated_images = tf.reshape(decoder_output, [-1, 28, 28, 1])
191 |
192 | # Tensorboard visualization
193 | tf.summary.scalar(name='Autoencoder Loss', tensor=autoencoder_loss)
194 | tf.summary.scalar(name='Discriminator Loss', tensor=dc_loss)
195 | tf.summary.scalar(name='Generator Loss', tensor=generator_loss)
196 | tf.summary.histogram(name='Encoder Distribution', values=encoder_output)
197 | tf.summary.histogram(name='Real Distribution', values=real_distribution)
198 | tf.summary.image(name='Input Images', tensor=input_images, max_outputs=10)
199 | tf.summary.image(name='Generated Images', tensor=generated_images, max_outputs=10)
200 | summary_op = tf.summary.merge_all()
201 |
202 | # Saving the model
203 | saver = tf.train.Saver()
204 | step = 0
205 | with tf.Session() as sess:
206 | if train_model:
207 | tensorboard_path, saved_model_path, log_path = form_results()
208 | sess.run(init)
209 | writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=sess.graph)
210 | for i in range(n_epochs):
211 | n_batches = int(mnist.train.num_examples / batch_size)
212 | print("------------------Epoch {}/{}------------------".format(i, n_epochs))
213 | for b in range(1, n_batches + 1):
214 | z_real_dist = np.random.randn(batch_size, z_dim) * 5.
215 | batch_x, _ = mnist.train.next_batch(batch_size)
216 | sess.run(autoencoder_optimizer, feed_dict={x_input: batch_x, x_target: batch_x})
217 | sess.run(discriminator_optimizer,
218 | feed_dict={x_input: batch_x, x_target: batch_x, real_distribution: z_real_dist})
219 | sess.run(generator_optimizer, feed_dict={x_input: batch_x, x_target: batch_x})
220 | if b % 50 == 0:
221 | a_loss, d_loss, g_loss, summary = sess.run(
222 | [autoencoder_loss, dc_loss, generator_loss, summary_op],
223 | feed_dict={x_input: batch_x, x_target: batch_x,
224 | real_distribution: z_real_dist})
225 | writer.add_summary(summary, global_step=step)
226 | print("Epoch: {}, iteration: {}".format(i, b))
227 | print("Autoencoder Loss: {}".format(a_loss))
228 | print("Discriminator Loss: {}".format(d_loss))
229 | print("Generator Loss: {}".format(g_loss))
230 | with open(log_path + '/log.txt', 'a') as log:
231 | log.write("Epoch: {}, iteration: {}\n".format(i, b))
232 | log.write("Autoencoder Loss: {}\n".format(a_loss))
233 | log.write("Discriminator Loss: {}\n".format(d_loss))
234 | log.write("Generator Loss: {}\n".format(g_loss))
235 | step += 1
236 |
237 | saver.save(sess, save_path=saved_model_path, global_step=step)
238 | else:
239 | # Get the latest results folder
240 | all_results = os.listdir(results_path)
241 | all_results.sort()
242 | saver.restore(sess, save_path=tf.train.latest_checkpoint(results_path + '/' + all_results[-1] + '/Saved_models/'))
243 | generate_image_grid(sess, op=decoder_image)
244 |
245 | if __name__ == '__main__':
246 |     import argparse
247 |     parser = argparse.ArgumentParser(description="Adversarial Autoencoder Train Parameter")
248 |     parser.add_argument('--train', '-t', type=lambda s: s.lower() in ('true', '1'), default=True,
249 |                         help='True to train a new model, False to load the latest model and show the image grid')
250 |     train(train_model=parser.parse_args().train)
--------------------------------------------------------------------------------
/autoencoder.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import datetime
4 | import os
5 | import matplotlib.pyplot as plt
6 | from matplotlib import gridspec
7 | from tensorflow.examples.tutorials.mnist import input_data
8 |
9 | # Get the MNIST data
10 | mnist = input_data.read_data_sets('./Data', one_hot=True)
11 |
12 | # Parameters
13 | input_dim = 784
14 | n_l1 = 1000
15 | n_l2 = 1000
16 | z_dim = 2
17 | batch_size = 100
18 | n_epochs = 1000
19 | learning_rate = 0.001
20 | beta1 = 0.9
21 | results_path = './Results/Autoencoder'
22 |
23 |
24 | # Placeholders for input data and the targets
25 | x_input = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Input')
26 | x_target = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Target')
27 | decoder_input = tf.placeholder(dtype=tf.float32, shape=[1, z_dim], name='Decoder_input')
28 |
29 |
30 | def generate_image_grid(sess, op):
31 | """
32 | Generates a grid of images by passing a set of numbers to the decoder and getting its output.
33 | :param sess: Tensorflow Session required to get the decoder output
34 |     :param op: Operation that needs to be called in order to get the decoder output
35 | :return: None, displays a matplotlib window with all the merged images.
36 | """
37 |     x_points = np.arange(0, 1, 1.5).astype(np.float32)  # yields a single point; widen the range/step for a larger grid
38 |     y_points = np.arange(0, 1, 1.5).astype(np.float32)
39 |
40 | nx, ny = len(x_points), len(y_points)
41 | plt.subplot()
42 | gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)
43 |
44 | for i, g in enumerate(gs):
45 | z = np.concatenate(([x_points[int(i / ny)]], [y_points[int(i % nx)]]))
46 | z = np.reshape(z, (1, 2))
47 | x = sess.run(op, feed_dict={decoder_input: z})
48 | ax = plt.subplot(g)
49 | img = np.array(x.tolist()).reshape(28, 28)
50 | ax.imshow(img, cmap='gray')
51 | ax.set_xticks([])
52 | ax.set_yticks([])
53 | ax.set_aspect('auto')
54 | plt.show()
55 |
56 |
57 | def form_results():
58 | """
59 | Forms folders for each run to store the tensorboard files, saved models and the log files.
60 |     :return: three strings pointing to the tensorboard, saved models and log paths respectively.
61 | """
62 | folder_name = "/{0}_{1}_{2}_{3}_{4}_{5}_autoencoder". \
63 | format(datetime.datetime.now(), z_dim, learning_rate, batch_size, n_epochs, beta1)
64 | tensorboard_path = results_path + folder_name + '/Tensorboard'
65 | saved_model_path = results_path + folder_name + '/Saved_models/'
66 | log_path = results_path + folder_name + '/log'
67 | if not os.path.exists(results_path + folder_name):
68 | os.mkdir(results_path + folder_name)
69 | os.mkdir(tensorboard_path)
70 | os.mkdir(saved_model_path)
71 | os.mkdir(log_path)
72 | return tensorboard_path, saved_model_path, log_path
73 |
74 |
75 | def dense(x, n1, n2, name):
76 | """
77 | Used to create a dense layer.
78 | :param x: input tensor to the dense layer
79 | :param n1: no. of input neurons
80 | :param n2: no. of output neurons
81 |     :param name: name of the entire dense layer, i.e. the variable scope name.
82 | :return: tensor with shape [batch_size, n2]
83 | """
84 | with tf.variable_scope(name, reuse=None):
85 | weights = tf.get_variable("weights", shape=[n1, n2],
86 | initializer=tf.random_normal_initializer(mean=0., stddev=0.01))
87 | bias = tf.get_variable("bias", shape=[n2], initializer=tf.constant_initializer(0.0))
88 | out = tf.add(tf.matmul(x, weights), bias, name='matmul')
89 | return out
90 |
91 |
92 | # The autoencoder network
93 | def encoder(x, reuse=False):
94 | """
95 |     Encoder part of the autoencoder.
96 |     :param x: input to the autoencoder
97 |     :param reuse: True -> Reuse the encoder variables, False -> Create new variables or search for existing ones
98 | :return: tensor which is the hidden latent variable of the autoencoder.
99 | """
100 | if reuse:
101 | tf.get_variable_scope().reuse_variables()
102 | with tf.name_scope('Encoder'):
103 | e_dense_1 = tf.nn.relu(dense(x, input_dim, n_l1, 'e_dense_1'))
104 | e_dense_2 = tf.nn.relu(dense(e_dense_1, n_l1, n_l2, 'e_dense_2'))
105 | latent_variable = dense(e_dense_2, n_l2, z_dim, 'e_latent_variable')
106 | return latent_variable
107 |
108 |
109 | def decoder(x, reuse=False):
110 | """
111 | Decoder part of the autoencoder
112 | :param x: input to the decoder
113 |     :param reuse: True -> Reuse the decoder variables, False -> Create new variables or search for existing ones
114 | :return: tensor which should ideally be the input given to the encoder.
115 | """
116 | if reuse:
117 | tf.get_variable_scope().reuse_variables()
118 | with tf.name_scope('Decoder'):
119 | d_dense_1 = tf.nn.relu(dense(x, z_dim, n_l2, 'd_dense_1'))
120 | d_dense_2 = tf.nn.relu(dense(d_dense_1, n_l2, n_l1, 'd_dense_2'))
121 | output = tf.nn.sigmoid(dense(d_dense_2, n_l1, input_dim, 'd_output'))
122 | return output
123 |
124 |
125 | def train(train_model):
126 | """
127 | Used to train the autoencoder by passing in the necessary inputs.
128 | :param train_model: True -> Train the model, False -> Load the latest trained model and show the image grid.
129 | :return: does not return anything
130 | """
131 | with tf.variable_scope(tf.get_variable_scope()):
132 | encoder_output = encoder(x_input)
133 | decoder_output = decoder(encoder_output)
134 |
135 | with tf.variable_scope(tf.get_variable_scope()):
136 | decoder_image = decoder(decoder_input, reuse=True)
137 |
138 | # Loss
139 | loss = tf.reduce_mean(tf.square(x_target - decoder_output))
140 |
141 | # Optimizer
142 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1).minimize(loss)
143 | init = tf.global_variables_initializer()
144 |
145 | # Visualization
146 | tf.summary.scalar(name='Loss', tensor=loss)
147 | tf.summary.histogram(name='Encoder Distribution', values=encoder_output)
148 | input_images = tf.reshape(x_input, [-1, 28, 28, 1])
149 | generated_images = tf.reshape(decoder_output, [-1, 28, 28, 1])
150 | tf.summary.image(name='Input Images', tensor=input_images, max_outputs=10)
151 | tf.summary.image(name='Generated Images', tensor=generated_images, max_outputs=10)
152 | summary_op = tf.summary.merge_all()
153 |
154 | # Saving the model
155 | saver = tf.train.Saver()
156 | step = 0
157 | with tf.Session() as sess:
158 | sess.run(init)
159 | if train_model:
160 | tensorboard_path, saved_model_path, log_path = form_results()
161 | writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=sess.graph)
162 | for i in range(n_epochs):
163 | n_batches = int(mnist.train.num_examples / batch_size)
164 | for b in range(n_batches):
165 | batch_x, _ = mnist.train.next_batch(batch_size)
166 | sess.run(optimizer, feed_dict={x_input: batch_x, x_target: batch_x})
167 | if b % 50 == 0:
168 | batch_loss, summary = sess.run([loss, summary_op], feed_dict={x_input: batch_x, x_target: batch_x})
169 | writer.add_summary(summary, global_step=step)
170 | print("Loss: {}".format(batch_loss))
171 | print("Epoch: {}, iteration: {}".format(i, b))
172 | with open(log_path + '/log.txt', 'a') as log:
173 | log.write("Epoch: {}, iteration: {}\n".format(i, b))
174 | log.write("Loss: {}\n".format(batch_loss))
175 | step += 1
176 | saver.save(sess, save_path=saved_model_path, global_step=step)
177 | print("Model Trained!")
178 | print("Tensorboard Path: {}".format(tensorboard_path))
179 | print("Log Path: {}".format(log_path + '/log.txt'))
180 | print("Saved Model Path: {}".format(saved_model_path))
181 | else:
182 | all_results = os.listdir(results_path)
183 | all_results.sort()
184 | saver.restore(sess,
185 | save_path=tf.train.latest_checkpoint(results_path + '/' + all_results[-1] + '/Saved_models/'))
186 | generate_image_grid(sess, op=decoder_image)
187 |
188 | if __name__ == '__main__':
189 |     import argparse
190 |     parser = argparse.ArgumentParser(description="Autoencoder Train Parameter")
191 |     parser.add_argument('--train', '-t', type=lambda s: s.lower() in ('true', '1'), default=True,
192 |                         help='True to train a new model, False to load the latest model and show the image grid')
193 |     train(train_model=parser.parse_args().train)
--------------------------------------------------------------------------------
/basic_nn_classifier.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import os
4 | import datetime
5 | from tensorflow.examples.tutorials.mnist import input_data
6 |
7 | # Parameters
8 | input_dim = 784
9 | n_l1 = 1000
10 | n_l2 = 1000
11 | batch_size = 100
12 | n_epochs = 1000
13 | learning_rate = 0.001
14 | beta1 = 0.9
15 | z_dim = 'NA'
16 | results_path = './Results/Basic_NN_Classifier'
17 | n_labels = 10
18 | n_labeled = 1000
19 |
20 | # Get MNIST data
21 | mnist = input_data.read_data_sets('./Data', one_hot=True)
22 |
23 | # Placeholders
24 | x_input = tf.placeholder(dtype=tf.float32, shape=[None, 784])
25 | y_target = tf.placeholder(dtype=tf.float32, shape=[None, 10])
26 |
27 |
28 | def form_results():
29 | """
30 | Forms folders for each run to store the tensorboard files, saved models and the log files.
31 |     :return: three strings pointing to the tensorboard, saved models and log paths respectively.
32 | """
33 | folder_name = "/{0}_{1}_{2}_{3}_{4}_{5}_Basic_NN_Classifier". \
34 | format(datetime.datetime.now(), z_dim, learning_rate, batch_size, n_epochs, beta1)
35 | tensorboard_path = results_path + folder_name + '/Tensorboard'
36 | saved_model_path = results_path + folder_name + '/Saved_models/'
37 | log_path = results_path + folder_name + '/log'
38 | if not os.path.exists(results_path + folder_name):
39 | os.mkdir(results_path + folder_name)
40 | os.mkdir(tensorboard_path)
41 | os.mkdir(saved_model_path)
42 | os.mkdir(log_path)
43 | return tensorboard_path, saved_model_path, log_path
44 |
45 |
46 | def next_batch(x, y, batch_size):
47 | """
48 | Used to return a random batch from the given inputs.
49 | :param x: Input images of shape [None, 784]
50 | :param y: Input labels of shape [None, 10]
51 | :param batch_size: integer, batch size of images and labels to return
52 | :return: x -> [batch_size, 784], y-> [batch_size, 10]
53 | """
54 | index = np.arange(n_labeled)
55 | random_index = np.random.permutation(index)[:batch_size]
56 | return x[random_index], y[random_index]
57 |
58 |
59 | def dense(x, n1, n2, name):
60 | """
61 | Used to create a dense layer.
62 | :param x: input tensor to the dense layer
63 | :param n1: no. of input neurons
64 | :param n2: no. of output neurons
65 | :param name: name of the entire dense layer.
66 | :return: tensor with shape [batch_size, n2]
67 | """
68 | with tf.name_scope(name):
69 | weights = tf.Variable(tf.random_normal(shape=[n1, n2], mean=0., stddev=0.01), name='weights')
70 | bias = tf.Variable(tf.zeros(shape=[n2]), name='bias')
71 | output = tf.add(tf.matmul(x, weights), bias, name='output')
72 | return output
73 |
74 |
75 | # Dense Network
76 | def dense_nn(x):
77 | """
78 | Network used to classify MNIST digits.
79 | :param x: tensor with shape [batch_size, 784], input to the dense fully connected layer.
80 | :return: [batch_size, 10], logits of dense fully connected.
81 | """
82 |     dense_1 = tf.nn.dropout(tf.nn.relu(dense(x, input_dim, n_l1, 'dense_1')), keep_prob=0.25)  # keep_prob is the probability of keeping a unit
83 |     dense_2 = tf.nn.dropout(tf.nn.relu(dense(dense_1, n_l1, n_l2, 'dense_2')), keep_prob=0.25)
84 | dense_3 = dense(dense_2, n_l2, n_labels, 'dense_3')
85 | return dense_3
86 |
87 |
88 | def train():
89 | """
90 |     Used to train the basic NN classifier by passing in the necessary inputs.
91 | :return: does not return anything
92 | """
93 | dense_output = dense_nn(x_input)
94 |
95 | # Loss function
96 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=dense_output, labels=y_target))
97 |
98 | # Optimizer
99 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1).minimize(loss)
100 |
101 | # Accuracy
102 | pred_op = tf.equal(tf.argmax(dense_output, 1), tf.argmax(y_target, 1))
103 | accuracy = tf.reduce_mean(tf.cast(pred_op, dtype=tf.float32))
104 |
105 | # Summary
106 | tf.summary.scalar(name='Loss', tensor=loss)
107 | tf.summary.scalar(name='Accuracy', tensor=accuracy)
108 | summary_op = tf.summary.merge_all()
109 |
110 | saver = tf.train.Saver()
111 |
112 | init = tf.global_variables_initializer()
113 |
114 | step = 0
115 | with tf.Session() as sess:
116 | tensorboard_path, saved_model_path, log_path = form_results()
117 |         x_l, y_l = mnist.test.next_batch(n_labeled)  # the 1000 labeled examples are drawn from the test split; accuracy is measured on the validation split
118 | writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=sess.graph)
119 | sess.run(init)
120 | for e in range(1, n_epochs + 1):
121 | n_batches = int(n_labeled / batch_size)
122 | for b in range(1, n_batches + 1):
123 | batch_x_l, batch_y_l = next_batch(x_l, y_l, batch_size=batch_size)
124 | sess.run(optimizer, feed_dict={x_input: batch_x_l, y_target: batch_y_l})
125 | if b % 5 == 0:
126 | loss_, summary = sess.run([loss, summary_op], feed_dict={x_input: batch_x_l, y_target: batch_y_l})
127 | writer.add_summary(summary, step)
128 | print("Epoch: {} Iteration: {}".format(e, b))
129 | print("Loss: {}".format(loss_))
130 | with open(log_path + '/log.txt', 'a') as log:
131 | log.write("Epoch: {}, iteration: {}\n".format(e, b))
132 | log.write("Loss: {}\n".format(loss_))
133 | step += 1
134 | acc = 0
135 | num_batches = int(mnist.validation.num_examples / batch_size)
136 | for j in range(num_batches):
137 | # Classify unseen validation data instead of test data or train data
138 | batch_x_l, batch_y_l = mnist.validation.next_batch(batch_size=batch_size)
139 | val_acc = sess.run(accuracy, feed_dict={x_input: batch_x_l, y_target: batch_y_l})
140 | acc += val_acc
141 | acc /= num_batches
142 | print("Classification Accuracy: {}".format(acc))
143 | with open(log_path + '/log.txt', 'a') as log:
144 | log.write("Classification Accuracy: {}".format(acc))
145 | saver.save(sess, save_path=saved_model_path, global_step=step)
146 |
147 | if __name__ == '__main__':
148 | train()
149 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==2.2.2
2 | numpy==1.14.2
3 | tensorflow-gpu==1.7.0
--------------------------------------------------------------------------------
/semi_supervised_adversarial_autoencoder.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import datetime
4 | import os
5 | import argparse
6 | import matplotlib.pyplot as plt
7 | from matplotlib import gridspec
8 | from tensorflow.examples.tutorials.mnist import input_data
9 |
10 | # Get the MNIST data
11 | mnist = input_data.read_data_sets('./Data', one_hot=True)
12 |
13 | # Parameters
14 | input_dim = 784
15 | n_l1 = 1000
16 | n_l2 = 1000
17 | z_dim = 10
18 | batch_size = 100
19 | n_epochs = 1000
20 | learning_rate = 0.001
21 | beta1 = 0.9
22 | results_path = './Results/Semi_Supervised'
23 | n_labels = 10
24 | n_labeled = 1000
25 |
26 | # Placeholders for input data and the targets
27 | x_input = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Input')
28 | x_input_l = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Labeled_Input')
29 | y_input = tf.placeholder(dtype=tf.float32, shape=[batch_size, n_labels], name='Labels')
30 | x_target = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Target')
31 | real_distribution = tf.placeholder(dtype=tf.float32, shape=[batch_size, z_dim], name='Real_distribution')
32 | categorial_distribution = tf.placeholder(dtype=tf.float32, shape=[batch_size, n_labels],
33 | name='Categorical_distribution')
34 | manual_decoder_input = tf.placeholder(dtype=tf.float32, shape=[1, z_dim + n_labels], name='Decoder_input')
35 |
36 |
37 | def form_results():
38 | """
39 | Forms folders for each run to store the tensorboard files, saved models and the log files.
41 |     :return: three strings pointing to the tensorboard, saved models and log paths respectively.
41 | """
42 | folder_name = "/{0}_{1}_{2}_{3}_{4}_{5}_Semi_Supervised". \
43 | format(datetime.datetime.now(), z_dim, learning_rate, batch_size, n_epochs, beta1)
44 | tensorboard_path = results_path + folder_name + '/Tensorboard'
45 | saved_model_path = results_path + folder_name + '/Saved_models/'
46 | log_path = results_path + folder_name + '/log'
47 | if not os.path.exists(results_path + folder_name):
48 | os.mkdir(results_path + folder_name)
49 | os.mkdir(tensorboard_path)
50 | os.mkdir(saved_model_path)
51 | os.mkdir(log_path)
52 | return tensorboard_path, saved_model_path, log_path
53 |
54 |
55 | def generate_image_grid(sess, op):
56 | """
57 | Generates a grid of images by passing a set of numbers to the decoder and getting its output.
58 | :param sess: Tensorflow Session required to get the decoder output
59 |     :param op: Operation that needs to be called in order to get the decoder output
60 | :return: None, displays a matplotlib window with all the merged images.
61 | """
62 | nx, ny = 10, 10
63 | random_inputs = np.random.randn(10, z_dim) * 5.
64 | sample_y = np.identity(10)
65 | plt.subplot()
66 | gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)
67 | i = 0
68 | for r in random_inputs:
69 | for t in sample_y:
70 | r, t = np.reshape(r, (1, z_dim)), np.reshape(t, (1, n_labels))
71 | dec_input = np.concatenate((t, r), 1)
72 | x = sess.run(op, feed_dict={manual_decoder_input: dec_input})
73 | ax = plt.subplot(gs[i])
74 | i += 1
75 | img = np.array(x.tolist()).reshape(28, 28)
76 | ax.imshow(img, cmap='gray')
77 | ax.set_xticks([])
78 | ax.set_yticks([])
79 | ax.set_aspect('auto')
80 | plt.show()
81 |
82 |
83 | def dense(x, n1, n2, name):
84 | """
85 | Used to create a dense layer.
86 | :param x: input tensor to the dense layer
87 | :param n1: no. of input neurons
88 | :param n2: no. of output neurons
89 |     :param name: name of the entire dense layer, i.e. the variable scope name.
90 | :return: tensor with shape [batch_size, n2]
91 | """
92 | with tf.variable_scope(name, reuse=None):
93 | weights = tf.get_variable("weights", shape=[n1, n2],
94 | initializer=tf.random_normal_initializer(mean=0., stddev=0.01))
95 | bias = tf.get_variable("bias", shape=[n2], initializer=tf.constant_initializer(0.0))
96 | out = tf.add(tf.matmul(x, weights), bias, name='matmul')
97 | return out
98 |
99 |
100 | # The autoencoder network
101 | def encoder(x, reuse=False, supervised=False):
102 | """
103 |     Encoder part of the autoencoder.
104 |     :param x: input to the autoencoder
105 |     :param reuse: True -> Reuse the encoder variables, False -> Create new variables or search for existing ones
106 | :param supervised: True -> returns output without passing it through softmax,
107 | False -> returns output after passing it through softmax.
108 | :return: tensor which is the classification output and a hidden latent variable of the autoencoder.
109 | """
110 | if reuse:
111 | tf.get_variable_scope().reuse_variables()
112 | with tf.name_scope('Encoder'):
113 | e_dense_1 = tf.nn.relu(dense(x, input_dim, n_l1, 'e_dense_1'))
114 | e_dense_2 = tf.nn.relu(dense(e_dense_1, n_l1, n_l2, 'e_dense_2'))
115 | latent_variable = dense(e_dense_2, n_l2, z_dim, 'e_latent_variable')
116 | cat_op = dense(e_dense_2, n_l2, n_labels, 'e_label')
117 | if not supervised:
118 | softmax_label = tf.nn.softmax(logits=cat_op, name='e_softmax_label')
119 | else:
120 | softmax_label = cat_op
121 | return softmax_label, latent_variable
122 |
123 |
124 | def decoder(x, reuse=False):
125 | """
126 | Decoder part of the autoencoder.
127 | :param x: input to the decoder
128 |     :param reuse: True -> Reuse the decoder variables, False -> Create new variables or search for existing ones
129 | :return: tensor which should ideally be the input given to the encoder.
130 | """
131 | if reuse:
132 | tf.get_variable_scope().reuse_variables()
133 | with tf.name_scope('Decoder'):
134 | d_dense_1 = tf.nn.relu(dense(x, z_dim + n_labels, n_l2, 'd_dense_1'))
135 | d_dense_2 = tf.nn.relu(dense(d_dense_1, n_l2, n_l1, 'd_dense_2'))
136 | output = tf.nn.sigmoid(dense(d_dense_2, n_l1, input_dim, 'd_output'))
137 | return output
138 |
139 |
140 | def discriminator_gauss(x, reuse=False):
141 | """
142 | Discriminator that is used to match the posterior distribution with a given gaussian distribution.
143 | :param x: tensor of shape [batch_size, z_dim]
144 | :param reuse: True -> Reuse the discriminator variables,
145 |                   False -> Create new variables or search for existing ones
146 | :return: tensor of shape [batch_size, 1]
147 | """
148 | if reuse:
149 | tf.get_variable_scope().reuse_variables()
150 | with tf.name_scope('Discriminator_Gauss'):
151 | dc_den1 = tf.nn.relu(dense(x, z_dim, n_l1, name='dc_g_den1'))
152 | dc_den2 = tf.nn.relu(dense(dc_den1, n_l1, n_l2, name='dc_g_den2'))
153 | output = dense(dc_den2, n_l2, 1, name='dc_g_output')
154 | return output
155 |
156 |
157 | def discriminator_categorical(x, reuse=False):
158 | """
159 | Discriminator that is used to match the posterior distribution with a given categorical distribution.
160 | :param x: tensor of shape [batch_size, n_labels]
161 | :param reuse: True -> Reuse the discriminator variables,
162 |                   False -> Create new variables or search for existing ones
163 | :return: tensor of shape [batch_size, 1]
164 | """
165 | if reuse:
166 | tf.get_variable_scope().reuse_variables()
167 |     with tf.name_scope('Discriminator_Categorical'):
168 | dc_den1 = tf.nn.relu(dense(x, n_labels, n_l1, name='dc_c_den1'))
169 | dc_den2 = tf.nn.relu(dense(dc_den1, n_l1, n_l2, name='dc_c_den2'))
170 | output = dense(dc_den2, n_l2, 1, name='dc_c_output')
171 | return output
172 |
173 |
174 | def next_batch(x, y, batch_size):
175 | """
176 | Used to return a random batch from the given inputs.
177 | :param x: Input images of shape [None, 784]
178 | :param y: Input labels of shape [None, 10]
179 | :param batch_size: integer, batch size of images and labels to return
180 | :return: x -> [batch_size, 784], y-> [batch_size, 10]
181 | """
182 | index = np.arange(n_labeled)
183 | random_index = np.random.permutation(index)[:batch_size]
184 | return x[random_index], y[random_index]
185 |
186 |
187 | def train(train_model=True):
188 | """
189 | Used to train the autoencoder by passing in the necessary inputs.
190 | :param train_model: True -> Train the model, False -> Load the latest trained model and show the image grid.
191 | :return: does not return anything
192 | """
193 |
194 | # Reconstruction Phase
195 | with tf.variable_scope(tf.get_variable_scope()):
196 | encoder_output_label, encoder_output_latent = encoder(x_input)
197 | # Concat class label and the encoder output
198 | decoder_input = tf.concat([encoder_output_label, encoder_output_latent], 1)
199 | decoder_output = decoder(decoder_input)
200 |
201 | # Regularization Phase
202 | with tf.variable_scope(tf.get_variable_scope()):
203 | d_g_real = discriminator_gauss(real_distribution)
204 | d_g_fake = discriminator_gauss(encoder_output_latent, reuse=True)
205 |
206 | with tf.variable_scope(tf.get_variable_scope()):
207 | d_c_real = discriminator_categorical(categorial_distribution)
208 | d_c_fake = discriminator_categorical(encoder_output_label, reuse=True)
209 |
210 | # Semi-Supervised Classification Phase
211 | with tf.variable_scope(tf.get_variable_scope()):
212 | encoder_output_label_, _ = encoder(x_input_l, reuse=True, supervised=True)
213 |
214 | # Generate output images
215 | with tf.variable_scope(tf.get_variable_scope()):
216 | decoder_image = decoder(manual_decoder_input, reuse=True)
217 |
218 | # Classification accuracy of encoder
219 | correct_pred = tf.equal(tf.argmax(encoder_output_label_, 1), tf.argmax(y_input, 1))
220 | accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
221 |
222 | # Autoencoder loss
223 | autoencoder_loss = tf.reduce_mean(tf.square(x_target - decoder_output))
224 |
225 | # Gaussian Discriminator Loss
226 | dc_g_loss_real = tf.reduce_mean(
227 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_g_real), logits=d_g_real))
228 | dc_g_loss_fake = tf.reduce_mean(
229 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_g_fake), logits=d_g_fake))
230 | dc_g_loss = dc_g_loss_fake + dc_g_loss_real
231 |
232 |     # Categorical Discriminator Loss
233 | dc_c_loss_real = tf.reduce_mean(
234 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_c_real), logits=d_c_real))
235 | dc_c_loss_fake = tf.reduce_mean(
236 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_c_fake), logits=d_c_fake))
237 | dc_c_loss = dc_c_loss_fake + dc_c_loss_real
238 |
239 | # Generator loss
240 | generator_g_loss = tf.reduce_mean(
241 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_g_fake), logits=d_g_fake))
242 | generator_c_loss = tf.reduce_mean(
243 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_c_fake), logits=d_c_fake))
244 | generator_loss = generator_c_loss + generator_g_loss
245 |
246 | # Supervised Encoder Loss
247 | supervised_encoder_loss = tf.reduce_mean(
248 | tf.nn.softmax_cross_entropy_with_logits(labels=y_input, logits=encoder_output_label_))
249 |
250 | all_variables = tf.trainable_variables()
251 | dc_g_var = [var for var in all_variables if 'dc_g_' in var.name]
252 | dc_c_var = [var for var in all_variables if 'dc_c_' in var.name]
253 | en_var = [var for var in all_variables if 'e_' in var.name]
254 |
255 | # Optimizers
256 | autoencoder_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
257 | beta1=beta1).minimize(autoencoder_loss)
258 | discriminator_g_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
259 | beta1=beta1).minimize(dc_g_loss, var_list=dc_g_var)
260 | discriminator_c_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
261 | beta1=beta1).minimize(dc_c_loss, var_list=dc_c_var)
262 | generator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
263 | beta1=beta1).minimize(generator_loss, var_list=en_var)
264 | supervised_encoder_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
265 | beta1=beta1).minimize(supervised_encoder_loss,
266 | var_list=en_var)
267 |
268 | init = tf.global_variables_initializer()
269 |
270 |     # Reshape images to display them
271 | input_images = tf.reshape(x_input, [-1, 28, 28, 1])
272 | generated_images = tf.reshape(decoder_output, [-1, 28, 28, 1])
273 |
274 | # Tensorboard visualization
275 | tf.summary.scalar(name='Autoencoder Loss', tensor=autoencoder_loss)
276 | tf.summary.scalar(name='Discriminator gauss Loss', tensor=dc_g_loss)
277 | tf.summary.scalar(name='Discriminator categorical Loss', tensor=dc_c_loss)
278 | tf.summary.scalar(name='Generator Loss', tensor=generator_loss)
279 | tf.summary.scalar(name='Supervised Encoder Loss', tensor=supervised_encoder_loss)
280 | tf.summary.histogram(name='Encoder Gauss Distribution', values=encoder_output_latent)
281 | tf.summary.histogram(name='Real Gauss Distribution', values=real_distribution)
282 | tf.summary.histogram(name='Encoder Categorical Distribution', values=encoder_output_label)
283 | tf.summary.histogram(name='Real Categorical Distribution', values=categorial_distribution)
284 | tf.summary.image(name='Input Images', tensor=input_images, max_outputs=10)
285 | tf.summary.image(name='Generated Images', tensor=generated_images, max_outputs=10)
286 | summary_op = tf.summary.merge_all()
287 |
288 | # Saving the model
289 | saver = tf.train.Saver()
290 | step = 0
291 | with tf.Session() as sess:
292 | if train_model:
293 | tensorboard_path, saved_model_path, log_path = form_results()
294 | sess.run(init)
295 | writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=sess.graph)
296 | x_l, y_l = mnist.test.next_batch(n_labeled)
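# The n_labeled supervised examples are drawn once from the test split; the
# unlabeled batches below come from mnist.train and accuracy is reported on
# mnist.validation, so the labeled, unlabeled and evaluation data stay disjoint.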
297 | for i in range(n_epochs):
298 | n_batches = int(n_labeled / batch_size)
299 | print("------------------Epoch {}/{}------------------".format(i, n_epochs))
300 | for b in range(1, n_batches + 1):
301 | z_real_dist = np.random.randn(batch_size, z_dim) * 5.
302 | real_cat_dist = np.random.randint(low=0, high=10, size=batch_size)
303 | real_cat_dist = np.eye(n_labels)[real_cat_dist]
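# The two lines above sample the categorical prior: indexing the identity matrix
# with a vector of class ids one-hot encodes them, e.g.
# np.eye(3)[[0, 2]] -> [[1, 0, 0], [0, 0, 1]].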
304 | batch_x_ul, _ = mnist.train.next_batch(batch_size)
305 | batch_x_l, batch_y_l = next_batch(x_l, y_l, batch_size=batch_size)
306 | sess.run(autoencoder_optimizer, feed_dict={x_input: batch_x_ul, x_target: batch_x_ul})
307 | sess.run(discriminator_g_optimizer,
308 | feed_dict={x_input: batch_x_ul, x_target: batch_x_ul, real_distribution: z_real_dist})
309 | sess.run(discriminator_c_optimizer,
310 | feed_dict={x_input: batch_x_ul, x_target: batch_x_ul,
311 | categorial_distribution: real_cat_dist})
312 | sess.run(generator_optimizer, feed_dict={x_input: batch_x_ul, x_target: batch_x_ul})
313 | sess.run(supervised_encoder_optimizer, feed_dict={x_input_l: batch_x_l, y_input: batch_y_l})
314 | if b % 5 == 0:
315 | a_loss, d_g_loss, d_c_loss, g_loss, s_loss, summary = sess.run(
316 | [autoencoder_loss, dc_g_loss, dc_c_loss, generator_loss, supervised_encoder_loss,
317 | summary_op],
318 | feed_dict={x_input: batch_x_ul, x_target: batch_x_ul,
319 | real_distribution: z_real_dist, y_input: batch_y_l, x_input_l: batch_x_l,
320 | categorial_distribution: real_cat_dist})
321 | writer.add_summary(summary, global_step=step)
322 | print("Epoch: {}, iteration: {}".format(i, b))
323 | print("Autoencoder Loss: {}".format(a_loss))
324 | print("Discriminator Gauss Loss: {}".format(d_g_loss))
325 | print("Discriminator Categorical Loss: {}".format(d_c_loss))
326 | print("Generator Loss: {}".format(g_loss))
327 | print("Supervised Loss: {}\n".format(s_loss))
328 | with open(log_path + '/log.txt', 'a') as log:
329 | log.write("Epoch: {}, iteration: {}\n".format(i, b))
330 | log.write("Autoencoder Loss: {}\n".format(a_loss))
331 | log.write("Discriminator Gauss Loss: {}".format(d_g_loss))
332 | log.write("Discriminator Categorical Loss: {}".format(d_c_loss))
333 | log.write("Generator Loss: {}\n".format(g_loss))
334 | log.write("Supervised Loss: {}".format(s_loss))
335 | step += 1
336 | acc = 0
337 | num_batches = int(mnist.validation.num_examples/batch_size)
338 | for j in range(num_batches):
339 | # Classify unseen validation data instead of test data or train data
340 | batch_x_l, batch_y_l = mnist.validation.next_batch(batch_size=batch_size)
341 | encoder_acc = sess.run(accuracy, feed_dict={x_input_l: batch_x_l, y_input: batch_y_l})
342 | acc += encoder_acc
343 | acc /= num_batches
344 | print("Encoder Classification Accuracy: {}".format(acc))
345 | with open(log_path + '/log.txt', 'a') as log:
346 | log.write("Encoder Classification Accuracy: {}".format(acc))
347 | saver.save(sess, save_path=saved_model_path, global_step=step)
348 | else:
349 | # Get the latest results folder
350 | all_results = os.listdir(results_path)
351 | all_results.sort()
352 | saver.restore(sess, save_path=tf.train.latest_checkpoint(results_path + '/' +
353 | all_results[-1] + '/Saved_models/'))
354 | generate_image_grid(sess, op=decoder_image)
355 |
356 |
357 | if __name__ == '__main__':
358 | parser = argparse.ArgumentParser(description="Autoencoder training options")
359 | parser.add_argument('--train', '-t', type=lambda v: v.lower() in ('true', '1'), default=True,
360 | help='Set to True to train a new model, False to load weights and display image grid')
361 | args = parser.parse_args()
362 | train(train_model=args.train)
363 |
--------------------------------------------------------------------------------
/supervised_adversarial_autoencoder.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import datetime
4 | import os
5 | import argparse
6 | import matplotlib.pyplot as plt
7 | from matplotlib import gridspec
8 | from tensorflow.examples.tutorials.mnist import input_data
9 |
10 | # Get the MNIST data
11 | mnist = input_data.read_data_sets('./Data', one_hot=True)
12 |
13 | # Parameters
14 | input_dim = 784
15 | n_l1 = 1000
16 | n_l2 = 1000
17 | z_dim = 15
18 | batch_size = 100
19 | n_epochs = 1000
20 | learning_rate = 0.001
21 | beta1 = 0.9
22 | results_path = './Results/Supervised'
23 | n_labels = 10
24 |
25 | # Placeholders for input data and the targets
26 | x_input = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Input')
27 | y_input = tf.placeholder(dtype=tf.float32, shape=[batch_size, n_labels], name='Labels')
28 | x_target = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Target')
29 | real_distribution = tf.placeholder(dtype=tf.float32, shape=[batch_size, z_dim], name='Real_distribution')
30 | manual_decoder_input = tf.placeholder(dtype=tf.float32, shape=[1, z_dim + n_labels], name='Decoder_input')
31 |
32 |
33 | def form_results():
34 | """
35 | Forms folders for each run to store the tensorboard files, saved models and the log files.
36 | :return: three strings pointing to the tensorboard, saved models and log paths respectively.
37 | """
38 | folder_name = "/{0}_{1}_{2}_{3}_{4}_{5}_Supervised". \
39 | format(datetime.datetime.now(), z_dim, learning_rate, batch_size, n_epochs, beta1)
40 | tensorboard_path = results_path + folder_name + '/Tensorboard'
41 | saved_model_path = results_path + folder_name + '/Saved_models/'
42 | log_path = results_path + folder_name + '/log'
43 | if not os.path.exists(results_path + folder_name):
44 | os.mkdir(results_path + folder_name)
45 | os.mkdir(tensorboard_path)
46 | os.mkdir(saved_model_path)
47 | os.mkdir(log_path)
48 | return tensorboard_path, saved_model_path, log_path
49 |
50 |
51 | def generate_image_grid(sess, op):
52 | """
53 | Generates a grid of images by passing a set of numbers to the decoder and getting its output.
54 | :param sess: Tensorflow Session required to get the decoder output
55 | :param op: Operation that needs to be called in order to get the decoder output
56 | :return: None, displays a matplotlib window with all the merged images.
57 | """
58 | nx, ny = 10, 10
59 | random_inputs = np.random.randn(10, z_dim) * 5.
60 | sample_y = np.identity(10)
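# The * 5. factor matches the std of the Gaussian prior the encoder was matched to
# during training (z_real_dist is also drawn as randn * 5.), and each row of
# np.identity(10) is a one-hot label, so the grid shows all ten digit classes
# rendered in each sampled style.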
61 | plt.subplot()
62 | gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)
63 | i = 0
64 | for r in random_inputs:
65 | for t in sample_y:
66 | r, t = np.reshape(r, (1, z_dim)), np.reshape(t, (1, n_labels))
67 | dec_input = np.concatenate((t, r), 1)
68 | x = sess.run(op, feed_dict={manual_decoder_input: dec_input})
69 | ax = plt.subplot(gs[i])
70 | i += 1
71 | img = np.array(x.tolist()).reshape(28, 28)
72 | ax.imshow(img, cmap='gray')
73 | ax.set_xticks([])
74 | ax.set_yticks([])
75 | ax.set_aspect('auto')
76 | plt.show()
77 |
78 |
79 | def dense(x, n1, n2, name):
80 | """
81 | Used to create a dense layer.
82 | :param x: input tensor to the dense layer
83 | :param n1: no. of input neurons
84 | :param n2: no. of output neurons
85 | :param name: name of the entire dense layer, i.e., the variable scope name.
86 | :return: tensor with shape [batch_size, n2]
87 | """
88 | with tf.variable_scope(name, reuse=None):
89 | weights = tf.get_variable("weights", shape=[n1, n2],
90 | initializer=tf.random_normal_initializer(mean=0., stddev=0.01))
91 | bias = tf.get_variable("bias", shape=[n2], initializer=tf.constant_initializer(0.0))
92 | out = tf.add(tf.matmul(x, weights), bias, name='matmul')
93 | return out
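# Example (shapes only): dense(x, 784, 1000, 'e_dense_1') maps a [batch_size, 784]
# tensor to [batch_size, 1000]. The scope-name prefixes ('e_', 'd_', 'dc_') matter:
# the var_list filters in train() use them to pick out each sub-network's weights.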
94 |
95 |
96 | # The autoencoder network
97 | def encoder(x, reuse=False):
98 | """
99 | Encoder part of the autoencoder.
100 | :param x: input to the autoencoder
101 | :param reuse: True -> Reuse the encoder variables, False -> Create or search for variables
102 | before creating
103 | :return: tensor of shape [batch_size, z_dim], the latent code (bottleneck) of the
104 | autoencoder.
105 | """
106 | if reuse:
107 | tf.get_variable_scope().reuse_variables()
108 | with tf.name_scope('Encoder'):
109 | e_dense_1 = tf.nn.relu(dense(x, input_dim, n_l1, 'e_dense_1'))
110 | e_dense_2 = tf.nn.relu(dense(e_dense_1, n_l1, n_l2, 'e_dense_2'))
111 | latent_variable = dense(e_dense_2, n_l2, z_dim, 'e_latent_variable')
112 | return latent_variable
113 |
114 |
115 | def decoder(x, reuse=False):
116 | """
117 | Decoder part of the autoencoder.
118 | :param x: input to the decoder
119 | :param reuse: True -> Reuse the decoder variables, False -> Create or search for variables before creating
120 | :return: tensor which should ideally be the input given to the encoder.
121 | """
122 | if reuse:
123 | tf.get_variable_scope().reuse_variables()
124 | with tf.name_scope('Decoder'):
125 | d_dense_1 = tf.nn.relu(dense(x, z_dim + n_labels, n_l2, 'd_dense_1'))
126 | d_dense_2 = tf.nn.relu(dense(d_dense_1, n_l2, n_l1, 'd_dense_2'))
127 | output = tf.nn.sigmoid(dense(d_dense_2, n_l1, input_dim, 'd_output'))
128 | return output
129 |
130 |
131 | def discriminator(x, reuse=False):
132 | """
133 | Discriminator that is used to match the posterior distribution with a given prior distribution.
134 | :param x: tensor of shape [batch_size, z_dim]
135 | :param reuse: True -> Reuse the discriminator variables,
136 | False -> Create or search for variables before creating
137 | :return: tensor of shape [batch_size, 1]
138 | """
139 | if reuse:
140 | tf.get_variable_scope().reuse_variables()
141 | with tf.name_scope('Discriminator'):
142 | dc_den1 = tf.nn.relu(dense(x, z_dim, n_l1, name='dc_den1'))
143 | dc_den2 = tf.nn.relu(dense(dc_den1, n_l1, n_l2, name='dc_den2'))
144 | output = dense(dc_den2, n_l2, 1, name='dc_output')
145 | return output
146 |
147 |
148 | def train(train_model=True):
149 | """
150 | Used to train the autoencoder by passing in the necessary inputs.
151 | :param train_model: True -> Train the model, False -> Load the latest trained model and show the image grid.
152 | :return: does not return anything
153 | """
154 | with tf.variable_scope(tf.get_variable_scope()):
155 | encoder_output = encoder(x_input)
156 | # Concat class label and the encoder output
157 | decoder_input = tf.concat([y_input, encoder_output], 1)
158 | decoder_output = decoder(decoder_input)
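# Concatenating the one-hot label with the latent code gives the decoder content
# (y_input) and style (encoder_output) through separate inputs, which is what lets
# the supervised AAE disentangle label from writing style.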
159 |
160 | with tf.variable_scope(tf.get_variable_scope()):
161 | d_real = discriminator(real_distribution)
162 | d_fake = discriminator(encoder_output, reuse=True)
163 |
164 | with tf.variable_scope(tf.get_variable_scope()):
165 | decoder_image = decoder(manual_decoder_input, reuse=True)
166 |
167 | # Autoencoder loss
168 | autoencoder_loss = tf.reduce_mean(tf.square(x_target - decoder_output))
169 |
170 | # Discriminator Loss
171 | dc_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real), logits=d_real))
172 | dc_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake), logits=d_fake))
173 | dc_loss = dc_loss_fake + dc_loss_real
174 |
175 | # Generator loss
176 | generator_loss = tf.reduce_mean(
177 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_fake), logits=d_fake))
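# The "generator" here is the encoder itself: it is trained (labels = ones) to make
# the discriminator classify its latent codes as samples from the real prior,
# pushing the aggregated posterior q(z) towards the imposed Gaussian.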
178 |
179 | all_variables = tf.trainable_variables()
180 | dc_var = [var for var in all_variables if 'dc_' in var.name]
181 | en_var = [var for var in all_variables if 'e_' in var.name]
182 |
183 | # Optimizers
184 | autoencoder_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
185 | beta1=beta1).minimize(autoencoder_loss)
186 | discriminator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
187 | beta1=beta1).minimize(dc_loss, var_list=dc_var)
188 | generator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
189 | beta1=beta1).minimize(generator_loss, var_list=en_var)
190 |
191 | init = tf.global_variables_initializer()
192 |
193 | # Reshape images to display them
194 | input_images = tf.reshape(x_input, [-1, 28, 28, 1])
195 | generated_images = tf.reshape(decoder_output, [-1, 28, 28, 1])
196 |
197 | # Tensorboard visualization
198 | tf.summary.scalar(name='Autoencoder Loss', tensor=autoencoder_loss)
199 | tf.summary.scalar(name='Discriminator Loss', tensor=dc_loss)
200 | tf.summary.scalar(name='Generator Loss', tensor=generator_loss)
201 | tf.summary.histogram(name='Encoder Distribution', values=encoder_output)
202 | tf.summary.histogram(name='Real Distribution', values=real_distribution)
203 | tf.summary.image(name='Input Images', tensor=input_images, max_outputs=10)
204 | tf.summary.image(name='Generated Images', tensor=generated_images, max_outputs=10)
205 | summary_op = tf.summary.merge_all()
206 |
207 | # Saving the model
208 | saver = tf.train.Saver()
209 | step = 0
210 | with tf.Session() as sess:
211 | if train_model:
212 | tensorboard_path, saved_model_path, log_path = form_results()
213 | sess.run(init)
214 | writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=sess.graph)
215 | for i in range(n_epochs):
216 | n_batches = int(mnist.train.num_examples / batch_size)
217 | print("------------------Epoch {}/{}------------------".format(i, n_epochs))
218 | for b in range(1, n_batches + 1):
219 | z_real_dist = np.random.randn(batch_size, z_dim) * 5.
220 | batch_x, batch_y = mnist.train.next_batch(batch_size)
221 | sess.run(autoencoder_optimizer, feed_dict={x_input: batch_x, x_target: batch_x, y_input: batch_y})
222 | sess.run(discriminator_optimizer,
223 | feed_dict={x_input: batch_x, x_target: batch_x, real_distribution: z_real_dist})
224 | sess.run(generator_optimizer, feed_dict={x_input: batch_x, x_target: batch_x})
225 | if b % 50 == 0:
226 | a_loss, d_loss, g_loss, summary = sess.run(
227 | [autoencoder_loss, dc_loss, generator_loss, summary_op],
228 | feed_dict={x_input: batch_x, x_target: batch_x,
229 | real_distribution: z_real_dist, y_input: batch_y})
230 | writer.add_summary(summary, global_step=step)
231 | print("Epoch: {}, iteration: {}".format(i, b))
232 | print("Autoencoder Loss: {}".format(a_loss))
233 | print("Discriminator Loss: {}".format(d_loss))
234 | print("Generator Loss: {}".format(g_loss))
235 | with open(log_path + '/log.txt', 'a') as log:
236 | log.write("Epoch: {}, iteration: {}\n".format(i, b))
237 | log.write("Autoencoder Loss: {}\n".format(a_loss))
238 | log.write("Discriminator Loss: {}\n".format(d_loss))
239 | log.write("Generator Loss: {}\n".format(g_loss))
240 | step += 1
241 |
242 | saver.save(sess, save_path=saved_model_path, global_step=step)
243 | else:
244 | # Get the latest results folder
245 | all_results = os.listdir(results_path)
246 | all_results.sort()
247 | saver.restore(sess, save_path=tf.train.latest_checkpoint(results_path + '/' +
248 | all_results[-1] + '/Saved_models/'))
249 | generate_image_grid(sess, op=decoder_image)
250 |
251 |
252 | if __name__ == '__main__':
253 | parser = argparse.ArgumentParser(description="Autoencoder training options")
254 | parser.add_argument('--train', '-t', type=lambda v: v.lower() in ('true', '1'), default=True,
255 | help='Set to True to train a new model, False to load weights and display image grid')
256 | args = parser.parse_args()
257 | train(train_model=args.train)
258 |
--------------------------------------------------------------------------------