├── Image ├── PMLR_.png ├── Gaussian.png ├── labels_0.png ├── labels_1.png ├── labels_2.png ├── G_generated.png ├── Swiss_roll_.png ├── GM_generated.png ├── Original_image.png ├── S_R_generated.png ├── Supervised_AAE.png ├── Supervised_AAE_.png ├── Gaussian_mixture_.png ├── Restored_Semi_AAE.png └── Semisupervised_AAE_.png ├── plot.py ├── utils.py ├── prior.py ├── AAE.py ├── README.md ├── data_utils.py └── main.py /Image/PMLR_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/PMLR_.png -------------------------------------------------------------------------------- /Image/Gaussian.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Gaussian.png -------------------------------------------------------------------------------- /Image/labels_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/labels_0.png -------------------------------------------------------------------------------- /Image/labels_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/labels_1.png -------------------------------------------------------------------------------- /Image/labels_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/labels_2.png -------------------------------------------------------------------------------- /Image/G_generated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/G_generated.png -------------------------------------------------------------------------------- /Image/Swiss_roll_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Swiss_roll_.png -------------------------------------------------------------------------------- /Image/GM_generated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/GM_generated.png -------------------------------------------------------------------------------- /Image/Original_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Original_image.png -------------------------------------------------------------------------------- /Image/S_R_generated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/S_R_generated.png -------------------------------------------------------------------------------- /Image/Supervised_AAE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Supervised_AAE.png -------------------------------------------------------------------------------- /Image/Supervised_AAE_.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Supervised_AAE_.png -------------------------------------------------------------------------------- /Image/Gaussian_mixture_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Gaussian_mixture_.png -------------------------------------------------------------------------------- /Image/Restored_Semi_AAE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Restored_Semi_AAE.png -------------------------------------------------------------------------------- /Image/Semisupervised_AAE_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Semisupervised_AAE_.png -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import tensorflow as tf 3 | from data_utils import * 4 | 5 | def plot_2d_scatter(x,y,test_labels): 6 | plt.figure(figsize = (8,6)) 7 | plt.scatter(x,y, c = np.argmax(test_labels,1), marker ='.', edgecolor = 'none', cmap = discrete_cmap('jet')) 8 | plt.colorbar() 9 | plt.grid() 10 | if not tf.gfile.Exists("./Scatter"): 11 | tf.gfile.MakeDirs("./Scatter") 12 | plt.savefig('./Scatter/2D_latent_space.png') 13 | plt.close() 14 | 15 | def discrete_cmap(base_cmap =None): 16 | base = plt.cm.get_cmap(base_cmap) 17 | color_list = base(np.linspace(0,1,10)) 18 | cmap_name = base.name + str(10) 19 | return base.from_list(cmap_name,color_list,10) 20 | 21 | def plot_manifold_canvas(images, n, type, name): 22 | assert images.shape[0] == n**2, "n**2 should be number of images" 23 | height = images.shape[1] 24 | width = images.shape[2] # width = height 25 | x = np.linspace(-2, 2, n) 26 | y = np.linspace(-2, 2, n) 27 | 28 | if type == "MNIST": 29 | canvas = np.empty((n * height, n * height)) 30 | for i, yi in enumerate(x): 31 | for j, xi in enumerate(y): 32 | canvas[height*i: height*i + height, width*j: width*j + width] = np.reshape(images[n*i + j], [height, width]) 33 | plt.figure(figsize=(8, 8)) 34 | plt.imshow(canvas, cmap="gray") 35 | else: 36 | canvas = np.empty((n * height, n * height, 3)) 37 | for i, yi in enumerate(x): 38 | for j, xi in enumerate(y): 39 | canvas[height*i: height*i + height, width*j: width*j + width,:] = images[n*i + j] 40 | plt.figure(figsize=(8, 8)) 41 | plt.imshow(canvas) 42 | 43 | if not tf.gfile.Exists("./plot"): 44 | tf.gfile.MakeDirs("./plot") 45 | if not tf.gfile.Exists("./plot/PMLR"): 46 | tf.gfile.MakeDirs("./plot/PMLR") 47 | if not tf.gfile.Exists("./plot/PARR"): 48 | tf.gfile.MakeDirs("./plot/PARR") 49 | 50 | name = name + ".png" 51 | path = os.path.join("./plot", name) 52 | plt.savefig(path) 53 | print("saving location: %s" % (path)) 54 | plt.close() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | initializer = tf.contrib.layers.xavier_initializer() 4 | #initializer = tf.contrib.layers.variance_scaling_initializer(factor = 1.0) 5 | 6 | 7 | def 
conv(inputs,filters,name): 8 | net = tf.layers.conv2d(inputs = inputs, 9 | filters = filters, 10 | kernel_size = [3,3], 11 | strides = (1,1), 12 | padding ="SAME", 13 | kernel_initializer = initializer, 14 | name = name, 15 | reuse = tf.AUTO_REUSE) 16 | return net 17 | 18 | def maxpool(input,name): 19 | net = tf.nn.max_pool(value = input, ksize = [1,2,2,1], strides = [1,2,2,1], padding = "SAME", name = name) 20 | return net 21 | 22 | def bn(inputs,is_training,name): 23 | net = tf.contrib.layers.batch_norm(inputs, decay = 0.9, is_training = is_training, reuse = tf.AUTO_REUSE, scope = name) 24 | return net 25 | 26 | def leaky(input): 27 | return tf.nn.leaky_relu(input) 28 | 29 | def relu(input): 30 | return tf.nn.relu(input) 31 | 32 | def drop_out(input, keep_prob): 33 | 34 | return tf.nn.dropout(input, keep_prob) 35 | def dense(inputs, units, name): 36 | net = tf.layers.dense(inputs = inputs, 37 | units = units, 38 | reuse = tf.AUTO_REUSE, 39 | name = name, 40 | kernel_initializer = initializer) 41 | return net 42 | 43 | user_flags = [] 44 | 45 | def DEFINE_string(name, default_value, doc_string): 46 | tf.app.flags.DEFINE_string(name, default_value, doc_string) 47 | global user_flags 48 | user_flags.append(name) 49 | 50 | def DEFINE_integer(name, default_value, doc_string): 51 | tf.app.flags.DEFINE_integer(name, default_value, doc_string) 52 | global user_flags 53 | user_flags.append(name) 54 | 55 | def DEFINE_float(name, defualt_value, doc_string): 56 | tf.app.flags.DEFINE_float(name, defualt_value, doc_string) 57 | global user_flags 58 | user_flags.append(name) 59 | 60 | def DEFINE_boolean(name, default_value, doc_string): 61 | tf.app.flags.DEFINE_boolean(name, default_value, doc_string) 62 | global user_flags 63 | user_flags.append(name) 64 | 65 | def print_user_flags(line_limit = 100): 66 | print("-" * 80) 67 | 68 | global user_flags 69 | FLAGS = tf.app.flags.FLAGS 70 | 71 | for flag_name in sorted(user_flags): 72 | value = "{}".format(getattr(FLAGS, flag_name)) 73 | log_string = flag_name 74 | log_string += "." 
* (line_limit - len(flag_name) - len(value)) 75 | log_string += value 76 | print(log_string) 77 | 78 | return FLAGS -------------------------------------------------------------------------------- /prior.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import sin,cos,sqrt 3 | ## this code is borrowed from https://github.com/hwalsuklee/tensorflow-mnist-AAE/ 4 | 5 | 6 | 7 | def gaussian(batch_size, n_labels, n_dim, mean=0, var=1, use_label_info=False): 8 | np.random.seed(0) 9 | if use_label_info: 10 | if n_dim != 2 or n_labels != 10: 11 | raise Exception("n_dim must be 2 and n_labels must be 10.") 12 | 13 | def sample(n_labels): 14 | x, y = np.random.normal(mean, var, (2,)) 15 | angle = np.angle((x-mean) + 1j*(y-mean), deg=True) 16 | dist = np.sqrt((x-mean)**2+(y-mean)**2) 17 | 18 | # label 0 19 | if dist <1.0: 20 | label = 0 21 | else: 22 | label = ((int)((n_labels-1)*angle))//360 23 | 24 | if label<0: 25 | label+=n_labels-1 26 | 27 | label += 1 28 | 29 | return np.array([x, y]).reshape((2,)), label 30 | 31 | z = np.empty((batch_size, n_dim), dtype=np.float32) 32 | z_id = np.empty((batch_size), dtype=np.int32) 33 | for batch in range(batch_size): 34 | for zi in range((int)(n_dim/2)): 35 | a_sample, a_label = sample(n_labels) 36 | z[batch, zi*2:zi*2+2] = a_sample 37 | z_id[batch] = a_label 38 | return z, z_id 39 | else: 40 | z = np.random.normal(mean, var, (batch_size, n_dim)).astype(np.float32) 41 | return z 42 | 43 | def gaussian_mixture(batch_size, n_labels ,n_dim, x_var=0.5, y_var=0.1, label_indices=None): 44 | np.random.seed(0) 45 | if n_dim != 2: 46 | raise Exception("n_dim must be 2.") 47 | 48 | def sample(x, y, label, n_labels): 49 | shift = 1.4 50 | r = 2.0 * np.pi / float(n_labels) * float(label) 51 | new_x = x * cos(r) - y * sin(r) 52 | new_y = x * sin(r) + y * cos(r) 53 | new_x += shift * cos(r) 54 | new_y += shift * sin(r) 55 | return np.array([new_x, new_y]).reshape((2,)) 56 | 57 | x = np.random.normal(0, x_var, (batch_size, (int)(n_dim/2))) 58 | y = np.random.normal(0, y_var, (batch_size, (int)(n_dim/2))) 59 | z = np.empty((batch_size, n_dim), dtype=np.float32) 60 | for batch in range(batch_size): 61 | for zi in range((int)(n_dim/2)): 62 | if label_indices is not None: 63 | z[batch, zi*2:zi*2+2] = sample(x[batch, zi], y[batch, zi], label_indices[batch], n_labels) 64 | else: 65 | z[batch, zi*2:zi*2+2] = sample(x[batch, zi], y[batch, zi], np.random.randint(0, n_labels), n_labels) 66 | 67 | return z 68 | 69 | def swiss_roll(batch_size, n_labels, n_dim, label_indices=None): 70 | np.random.seed(0) 71 | if n_dim != 2: 72 | raise Exception("n_dim must be 2.") 73 | 74 | def sample(label, n_labels): 75 | uni = np.random.uniform(0.0, 1.0) / float(n_labels) + float(label) / float(n_labels) 76 | r = sqrt(uni) * 3.0 77 | rad = np.pi * 4.0 * sqrt(uni) 78 | x = r * cos(rad) 79 | y = r * sin(rad) 80 | return np.array([x, y]).reshape((2,)) 81 | 82 | z = np.zeros((batch_size, n_dim), dtype=np.float32) 83 | for batch in range(batch_size): 84 | for zi in range((int)(n_dim/2)): 85 | if label_indices is not None: 86 | z[batch, zi*2:zi*2+2] = sample(label_indices[batch], n_labels) 87 | else: 88 | z[batch, zi*2:zi*2+2] = sample(np.random.randint(0, n_labels), n_labels) 89 | return z -------------------------------------------------------------------------------- /AAE.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from utils import * 3 | from data_utils import 
* 4 | from prior import * 5 | 6 | class AAE: 7 | def __init__(self, conf, shape, n_labels): 8 | self.conf = conf 9 | self.mode = conf.model 10 | self.data = conf.data 11 | self.super_n_hidden = conf.super_n_hidden 12 | self.semi_n_hidden = conf.semi_n_hidden 13 | self.n_z = conf.n_z 14 | self.batch_size = conf.batch_size 15 | self.prior = conf.prior 16 | self.w = shape[1] 17 | self.h = shape[2] 18 | self.c = shape[3] 19 | self.length = self.h * self.w * self.c 20 | self.n_labels = n_labels 21 | 22 | def sup_encoder(self, X, keep_prob): # encoder for supervised AAE 23 | 24 | with tf.variable_scope("sup_encoder", reuse = tf.AUTO_REUSE): 25 | net = drop_out(relu(dense(X, self.super_n_hidden, name = "dense_1")), keep_prob) 26 | net = drop_out(relu(dense(net, self.super_n_hidden, name="dense_2")), keep_prob) 27 | net = dense(net, self.n_z, name ="dense_3") 28 | 29 | return net 30 | 31 | def sup_decoder(self, Z, keep_prob): # decoder for supervised AAE 32 | 33 | with tf.variable_scope("sup_decoder", reuse = tf.AUTO_REUSE): 34 | net = drop_out(relu(dense(Z, self.super_n_hidden, name = "dense_1")), keep_prob) 35 | net = drop_out(relu(dense(net, self.super_n_hidden, name="dense_2")), keep_prob) 36 | net = tf.nn.sigmoid(dense(net, self.length, name = "dense_3")) 37 | 38 | return net 39 | 40 | def discriminator(self,Z, keep_prob): # discriminator for supervised AAE 41 | 42 | with tf.variable_scope("discriminator", reuse = tf.AUTO_REUSE): 43 | net = drop_out(relu(dense(Z, self.super_n_hidden, name = "dense_1")), keep_prob) 44 | net = drop_out(relu(dense(net, self.super_n_hidden, name="dense_2")), keep_prob) 45 | logits = dense(net, 1, name ="dense_3") 46 | 47 | return logits 48 | 49 | def Sup_Adversarial_AutoEncoder(self, X, X_noised, Y, z_prior, z_id, keep_prob): 50 | 51 | X_flatten = tf.reshape(X, [-1, self.length]) 52 | X_flatten_noised = tf.reshape(X_noised, [-1, self.length]) 53 | 54 | z_generated = self.sup_encoder(X_flatten_noised, keep_prob) 55 | X_generated = self.sup_decoder(z_generated, keep_prob) 56 | 57 | negative_log_likelihood = tf.reduce_mean(tf.squared_difference(X_generated, X_flatten)) 58 | 59 | z_prior = tf.concat([z_prior, z_id], axis = 1) 60 | z_fake = tf.concat([z_generated, Y], axis = 1) 61 | D_real_logits = self.discriminator(z_prior, keep_prob) 62 | D_fake_logits = self.discriminator(z_fake, keep_prob) 63 | 64 | D_loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_fake_logits, labels = tf.zeros_like(D_fake_logits)) 65 | D_loss_true = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_real_logits, labels = tf.ones_like(D_real_logits)) 66 | 67 | G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_fake_logits, labels = tf.ones_like(D_fake_logits)) 68 | 69 | D_loss = tf.reduce_mean(D_loss_fake) + tf.reduce_mean(D_loss_true) 70 | G_loss = tf.reduce_mean(G_loss) 71 | 72 | return z_generated, X_generated, negative_log_likelihood, D_loss, G_loss 73 | 74 | def semi_encoder(self, X, keep_prob, semi_supervised = False): 75 | 76 | with tf.variable_scope("semi_encoder", reuse = tf.AUTO_REUSE): 77 | net = drop_out(relu(dense(X, self.semi_n_hidden, name = "dense_1")), keep_prob) 78 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob) 79 | style = dense(net, self.n_z, name ="style") 80 | 81 | if semi_supervised is False: 82 | labels_generated = tf.nn.softmax(dense(net, self.n_labels, name = "labels")) 83 | else: 84 | labels_generated = dense(net, self.n_labels, name = "label_logits") 85 | 86 | return style, labels_generated 87 | 88 | def 
semi_decoder(self, Z, keep_prob): 89 | 90 | with tf.variable_scope("semi_decoder", reuse = tf.AUTO_REUSE): 91 | net = drop_out(relu(dense(Z, self.semi_n_hidden, name = "dense_1")), keep_prob) 92 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob) 93 | net = tf.nn.sigmoid(dense(net, self.length, name = "dense_3")) 94 | 95 | return net 96 | 97 | def semi_z_discriminator(self,Z, keep_prob): 98 | 99 | with tf.variable_scope("semi_z_discriminator", reuse = tf.AUTO_REUSE): 100 | net = drop_out(relu(dense(Z, self.semi_n_hidden, name="dense_1")), keep_prob) 101 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob) 102 | logits = dense(net, 1, name="dense_3") 103 | 104 | return logits 105 | 106 | def semi_y_discriminator(self, Y, keep_prob): 107 | 108 | with tf.variable_scope("semi_y_discriminator", reuse = tf.AUTO_REUSE): 109 | net = drop_out(relu(dense(Y, self.semi_n_hidden, name = "dense_1")), keep_prob) 110 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob) 111 | logits = dense(net, 1, name = "dense_3") 112 | 113 | return logits 114 | 115 | def Semi_Adversarial_AutoEncoder(self, X, X_noised, labels, labels_cat, z_prior, keep_prob): 116 | 117 | X_flatten = tf.reshape(X, [-1 , self.length]) 118 | X_noised_flatten = tf.reshape(X_noised, [-1, self.length]) 119 | 120 | style, labels_softmax = self.semi_encoder(X_noised_flatten, keep_prob, semi_supervised = False) 121 | latent_inputs = tf.concat([style, labels_softmax], axis = 1) 122 | X_generated = self.semi_decoder(latent_inputs, keep_prob) 123 | 124 | _, labels_generated = self.semi_encoder(X_noised_flatten, keep_prob, semi_supervised = True) 125 | 126 | D_Y_fake = self.semi_y_discriminator(labels_softmax, keep_prob) 127 | D_Y_real = self.semi_y_discriminator(labels_cat, keep_prob) 128 | 129 | D_Z_fake = self.semi_z_discriminator(style, keep_prob) 130 | D_Z_real = self.semi_z_discriminator(z_prior, keep_prob) 131 | 132 | negative_loglikelihood = tf.reduce_mean(tf.squared_difference(X_generated,X_flatten)) 133 | 134 | D_loss_y_real = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Y_real, labels=tf.ones_like(D_Y_real)) 135 | D_loss_y_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Y_fake, labels=tf.zeros_like(D_Y_fake)) 136 | D_loss_y = tf.reduce_mean(D_loss_y_real) + tf.reduce_mean(D_loss_y_fake) 137 | D_loss_z_real = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_Z_real, labels = tf.ones_like(D_Z_real)) 138 | D_loss_z_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_Z_fake, labels = tf.zeros_like(D_Z_fake)) 139 | D_loss_z = tf.reduce_mean(D_loss_z_real) + tf.reduce_mean(D_loss_z_fake) 140 | 141 | 142 | G_loss_y = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Y_fake, labels=tf.ones_like(D_Y_fake)) 143 | G_loss_z = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_Z_fake, labels = tf.ones_like(D_Z_fake)) 144 | G_loss = tf.reduce_mean(G_loss_y) + tf.reduce_mean(G_loss_z) 145 | 146 | CE_labels = tf.nn.softmax_cross_entropy_with_logits(logits = labels_generated, labels = labels) 147 | CE_labels = tf.reduce_mean(CE_labels) 148 | 149 | 150 | return style, X_generated, negative_loglikelihood, D_loss_y, D_loss_z, G_loss, CE_labels -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Adversarial AutoEncoder(AAE)- Tensorflow 2 | 3 | I write the Tensorflow Code for Supervised AAE and SemiSupervised AAE 4 | 5 | ## 
Enviroment 6 | - OS: Ubuntu 16.04 7 | 8 | - Graphic Card /RAM : 1080TI /16G 9 | 10 | - Python 3.5 11 | 12 | - Tensorflow-gpu version: 1.4.0rc2 13 | 14 | - OpenCV 3.4.1 15 | 16 | ## Schematic of AAE 17 | 18 | ### Supervised AAE 19 | 20 | Drawing 21 | 22 | *** 23 | 24 | ### SemiSupervised AAE 25 | 26 | Drawing 27 | 28 | ## Code 29 | 30 | **Supervised Encoder** 31 | ```python 32 | def sup_encoder(self, X, keep_prob): # encoder for supervised AAE 33 | 34 | with tf.variable_scope("sup_encoder", reuse = tf.AUTO_REUSE): 35 | net = drop_out(relu(dense(X, self.super_n_hidden, name = "dense_1")), keep_prob) 36 | net = drop_out(relu(dense(net, self.super_n_hidden, name="dense_2")), keep_prob) 37 | net = dense(net, self.n_z, name ="dense_3") 38 | 39 | return net 40 | ``` 41 | 42 | **Supervised Decoder** 43 | ```python 44 | def sup_decoder(self, Z, keep_prob): # decoder for supervised AAE 45 | 46 | with tf.variable_scope("sup_decoder", reuse = tf.AUTO_REUSE): 47 | net = drop_out(relu(dense(Z, self.super_n_hidden, name = "dense_1")), keep_prob) 48 | net = drop_out(relu(dense(net, self.super_n_hidden, name="dense_2")), keep_prob) 49 | net = tf.nn.sigmoid(dense(net, self.length, name = "dense_3")) 50 | 51 | return net 52 | ``` 53 | 54 | **Supervised Discriminator** 55 | ```python 56 | def discriminator(self,Z, keep_prob): # discriminator for supervised AAE 57 | 58 | with tf.variable_scope("discriminator", reuse = tf.AUTO_REUSE): 59 | net = drop_out(relu(dense(Z, self.super_n_hidden, name = "dense_1")), keep_prob) 60 | net = drop_out(relu(dense(net, self.super_n_hidden, name="dense_2")), keep_prob) 61 | logits = dense(net, 1, name ="dense_3") 62 | 63 | return logits 64 | ``` 65 | 66 | **Supervised Adversarial AutoEncoder** 67 | ```python 68 | def Sup_Adversarial_AutoEncoder(self, X, X_noised, Y, z_prior, z_id, keep_prob): 69 | 70 | X_flatten = tf.reshape(X, [-1, self.length]) 71 | X_flatten_noised = tf.reshape(X_noised, [-1, self.length]) 72 | 73 | z_generated = self.sup_encoder(X_flatten_noised, keep_prob) 74 | X_generated = self.sup_decoder(z_generated, keep_prob) 75 | 76 | negative_log_likelihood = tf.reduce_mean(tf.squared_difference(X_generated, X_flatten)) 77 | 78 | z_prior = tf.concat([z_prior, z_id], axis = 1) 79 | z_fake = tf.concat([z_generated, Y], axis = 1) 80 | D_real_logits = self.discriminator(z_prior, keep_prob) 81 | D_fake_logits = self.discriminator(z_fake, keep_prob) 82 | 83 | D_loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_fake_logits, labels = tf.zeros_like(D_fake_logits)) 84 | D_loss_true = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_real_logits, labels = tf.ones_like(D_real_logits)) 85 | 86 | G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_fake_logits, labels = tf.ones_like(D_fake_logits)) 87 | 88 | D_loss = tf.reduce_mean(D_loss_fake) + tf.reduce_mean(D_loss_true) 89 | G_loss = tf.reduce_mean(G_loss) 90 | 91 | return z_generated, X_generated, negative_log_likelihood, D_loss, G_loss 92 | ``` 93 | 94 | *** 95 | 96 | **SemiSupervised Encoder** 97 | ```python 98 | def semi_encoder(self, X, keep_prob, semi_supervised = False): 99 | 100 | with tf.variable_scope("semi_encoder", reuse = tf.AUTO_REUSE): 101 | net = drop_out(relu(dense(X, self.semi_n_hidden, name = "dense_1")), keep_prob) 102 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob) 103 | style = dense(net, self.n_z, name ="style") 104 | 105 | if semi_supervised is False: 106 | labels_generated = tf.nn.softmax(dense(net, self.n_labels, name = "labels")) 107 | else: 
108 | labels_generated = dense(net, self.n_labels, name = "label_logits") 109 | 110 | return style, labels_generated 111 | ``` 112 | 113 | **SemiSupervised Decoder** 114 | ```python 115 | def semi_decoder(self, Z, keep_prob): 116 | 117 | with tf.variable_scope("semi_decoder", reuse = tf.AUTO_REUSE): 118 | net = drop_out(relu(dense(Z, self.semi_n_hidden, name = "dense_1")), keep_prob) 119 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob) 120 | net = tf.nn.sigmoid(dense(net, self.length, name = "dense_3")) 121 | 122 | return net 123 | ``` 124 | 125 | **SemiSupervised z Discriminator** 126 | ```python 127 | def semi_z_discriminator(self,Z, keep_prob): 128 | 129 | with tf.variable_scope("semi_z_discriminator", reuse = tf.AUTO_REUSE): 130 | net = drop_out(relu(dense(Z, self.semi_n_hidden, name="dense_1")), keep_prob) 131 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob) 132 | logits = dense(net, 1, name="dense_3") 133 | 134 | return logits 135 | ``` 136 | 137 | **SemiSupervised y Discriminator** 138 | ```python 139 | def semi_y_discriminator(self, Y, keep_prob): 140 | 141 | with tf.variable_scope("semi_y_discriminator", reuse = tf.AUTO_REUSE): 142 | net = drop_out(relu(dense(Y, self.semi_n_hidden, name = "dense_1")), keep_prob) 143 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob) 144 | logits = dense(net, 1, name = "dense_3") 145 | 146 | return logits 147 | ``` 148 | 149 | **SemiSupervised Adversarial AutoEncoder** 150 | ```python 151 | def Semi_Adversarial_AutoEncoder(self, X, X_noised, labels, labels_cat, z_prior, keep_prob): 152 | 153 | X_flatten = tf.reshape(X, [-1 , self.length]) 154 | X_noised_flatten = tf.reshape(X_noised, [-1, self.length]) 155 | 156 | style, labels_softmax = self.semi_encoder(X_noised_flatten, keep_prob, semi_supervised = False) 157 | latent_inputs = tf.concat([style, labels_softmax], axis = 1) 158 | X_generated = self.semi_decoder(latent_inputs, keep_prob) 159 | 160 | _, labels_generated = self.semi_encoder(X_noised_flatten, keep_prob, semi_supervised = True) 161 | 162 | D_Y_fake = self.semi_y_discriminator(labels_softmax, keep_prob) 163 | D_Y_real = self.semi_y_discriminator(labels_cat, keep_prob) 164 | 165 | D_Z_fake = self.semi_z_discriminator(style, keep_prob) 166 | D_Z_real = self.semi_z_discriminator(z_prior, keep_prob) 167 | 168 | negative_loglikelihood = tf.reduce_mean(tf.squared_difference(X_generated,X_flatten)) 169 | 170 | D_loss_y_real = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Y_real, labels=tf.ones_like(D_Y_real)) 171 | D_loss_y_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Y_fake, labels=tf.zeros_like(D_Y_fake)) 172 | D_loss_y = tf.reduce_mean(D_loss_y_real) + tf.reduce_mean(D_loss_y_fake) 173 | D_loss_z_real = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_Z_real, labels = tf.ones_like(D_Z_real)) 174 | D_loss_z_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_Z_fake, labels = tf.zeros_like(D_Z_fake)) 175 | D_loss_z = tf.reduce_mean(D_loss_z_real) + tf.reduce_mean(D_loss_z_fake) 176 | 177 | 178 | G_loss_y = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Y_fake, labels=tf.ones_like(D_Y_fake)) 179 | G_loss_z = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_Z_fake, labels = tf.ones_like(D_Z_fake)) 180 | G_loss = tf.reduce_mean(G_loss_y) + tf.reduce_mean(G_loss_z) 181 | 182 | CE_labels = tf.nn.softmax_cross_entropy_with_logits(logits = labels_generated, labels = labels) 183 | CE_labels = tf.reduce_mean(CE_labels) 184 | 185 | 
186 | return style, X_generated, negative_loglikelihood, D_loss_y, D_loss_z, G_loss, CE_labels 187 | ``` 188 | 189 | ## Results 190 | 191 | **1. Restoring** 192 | ``` 193 | python main.py --model supervised --prior gaussian --n_z 20 194 | 195 | or 196 | 197 | python main.py --model semi_supervised --prior gaussian --n_z 20 198 | ```
*(Image table: Original Images | Restored via Supervised AAE | Restored via Semisupervised AAE)*
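For reference, the corruption applied to the inputs in this experiment is plain additive Gaussian noise (see `make_noise` in `data_utils.py`), and the restoration quality the autoencoder optimises is the mean squared error used as `negative_log_likelihood` in `AAE.py`. The sketch below reproduces both pieces in NumPy so they can be checked outside TensorFlow; `add_input_noise` and `restoration_mse` are illustrative helper names, not functions from this repository.

```python
import numpy as np

def add_input_noise(images, std=0.3):
    # Same corruption as data_utils.make_noise: additive Gaussian noise with std 0.3.
    return images + np.random.normal(0.0, std, size=np.shape(images))

def restoration_mse(originals, restored):
    # The reconstruction term minimised by op_AE: mean of squared differences.
    return float(np.mean(np.square(originals - restored)))

# Toy check on random "images" in [0, 1]; a trained AAE should beat this do-nothing baseline.
clean = np.random.rand(100, 28, 28, 1)
noisy = add_input_noise(clean)
print("MSE of the noisy inputs vs. the clean ones:", restoration_mse(clean, noisy))
```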
211 | 212 | **2. 2D Latent Space** 213 | 214 | ***Target***
*(Image table: Gaussian | Gaussian Mixture | Swiss Roll)*
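Because `prior.py` is pure NumPy, the three target distributions can be reproduced without building the TensorFlow graph. A minimal sketch, assuming `prior.py` is importable from the working directory and matplotlib is installed (the output filename is arbitrary):

```python
import numpy as np
import matplotlib.pyplot as plt
from prior import gaussian, gaussian_mixture, swiss_roll

n, n_labels, n_dim = 10000, 10, 2
labels = np.random.randint(0, n_labels, size=n)   # mixture component / swiss-roll arm per sample

samples = {
    "Gaussian": gaussian(n, n_labels, n_dim),
    "Gaussian mixture": gaussian_mixture(n, n_labels, n_dim, label_indices=labels),
    "Swiss roll": swiss_roll(n, n_labels, n_dim, label_indices=labels),
}

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, (name, z) in zip(axes, samples.items()):
    # Colours show the label used to pick the component/arm; for the plain Gaussian they are arbitrary.
    ax.scatter(z[:, 0], z[:, 1], c=labels, s=2, cmap="jet")
    ax.set_title(name)
plt.savefig("prior_targets.png")
```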
228 | 229 | ***Coding Space of Supervised AAE*** 230 | ``` 231 | The test used the 10,000 test images that are not used for training. 232 | 233 | python main.py --model supervised --prior gaussian_mixture --n_z 2 234 | ```
*(Image table: Gaussian | Gaussian Mixture | Swiss Roll)*
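These scatter plots come from the 2D-scatter block in `main.py`: every test batch is pushed through the encoder and the resulting codes are plotted with `plot_2d_scatter`. The helper below condenses that loop into one function. It is only a sketch: `encode_test_set` is an illustrative name, and it assumes the supervised graph from `main.py` (`sess`, `z_generated`, `X`, `X_noised`, `keep_prob`, plus the `data_pipeline` instance) already exists in the running process.

```python
import numpy as np

def encode_test_set(sess, z_generated, X, X_noised, keep_prob, data, test_xs, test_ys, batch_size=128):
    # Mirror of the 2D-scatter loop in main.py: encode the held-out test images batch by batch.
    data.initialize_batch()
    latents = []
    for _ in range(data.get_total_batch(test_xs, batch_size)):
        xs, noised_xs, _ = data.next_batch(test_xs, test_ys, batch_size, make_noise=False)
        latents.append(sess.run(z_generated, feed_dict={X: xs, X_noised: noised_xs, keep_prob: 1.0}))
    return np.concatenate(latents, axis=0)

# Usage after training the supervised model:
# latent = encode_test_set(sess, z_generated, X, X_noised, keep_prob, data_pipeline, test_xs, test_ys)
# plot_2d_scatter(latent[:, 0], latent[:, 1], test_ys[:len(latent)])
```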
248 | 249 | 250 | **3. Manifold Learning Result** 251 | 252 | ***Supervised AAE*** 253 | 254 | ``` 255 | python main.py --model supervised --prior gaussian_mixture --n_z 2 --PMLR True 256 | ```
*(Image: Manifold)*
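The canvas is produced by the `--PMLR` branch of `main.py`: a fixed 10 × 10 grid of 2-D latent codes is decoded and tiled with `plot_manifold_canvas`. Below is a small sketch of that grid construction; `latent_grid` is an illustrative name, and the decoding step in the comments assumes the trained supervised graph from `main.py` is available in the same session.

```python
import numpy as np

def latent_grid(n=10, lo=-0.5, hi=0.5):
    # Same grid as main.py's PMLR block: x runs left to right, y runs top (hi) to bottom (lo).
    xs = np.linspace(lo, hi, n)
    ys = np.linspace(hi, lo, n)
    return np.array([[x, y] for y in ys for x in xs], dtype=np.float32)   # shape (n*n, 2)

# Decoding the grid, assuming main.py's supervised graph is trained in this session:
# canvas = sess.run(images_PMLR, feed_dict={latent: latent_grid(), keep_prob: 1.0})
# plot_manifold_canvas(np.reshape(canvas, [-1, 28, 28, 1]), 10, "MNIST", "PMLR/PMLR")
```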
266 | 267 | ***SemiSupervised AAE*** 268 | 269 | ``` 270 | python main.py --model semi_supervised --prior gaussian --n_z 2 --PMLR True 271 | 272 | 273 | The results suggest that when n_z is 2, the SemiSupervised AAE cannot extract label information from the input images very well. 274 | ```
*(Image table: Manifold with condition 0 | Manifold with condition 1 | Manifold with condition 2)*
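Each panel fixes one label and sweeps the 2-D style code, which is exactly what the semi-supervised `--PMLR` branch of `main.py` does by concatenating a one-hot label onto every grid point before calling the decoder. A sketch of that conditioning step; `conditional_latent_grid` is an illustrative name, and the decoding comments assume the trained semi-supervised graph from `main.py`.

```python
import numpy as np

def conditional_latent_grid(digit, n_cls=10, n=10, lo=-0.5, hi=0.5):
    # 10 x 10 style grid, each point paired with the one-hot code of `digit`,
    # mirroring the semi-supervised PMLR block in main.py.
    grid = np.array([[x, y] for x in np.linspace(lo, hi, n) for y in np.linspace(lo, hi, n)],
                    dtype=np.float32)
    one_hot = np.zeros((len(grid), n_cls), dtype=np.float32)
    one_hot[:, digit] = 1.0
    return np.concatenate([grid, one_hot], axis=1)    # shape (n*n, 2 + n_cls)

# Conditional generation for digit 0, assuming the semi-supervised graph is trained in this session:
# imgs = sess.run(images_manifold, feed_dict={latent: conditional_latent_grid(0), keep_prob: 1.0})
# plot_manifold_canvas(np.reshape(imgs, [-1, 28, 28, 1]), 10, "MNIST", "PMLR/labels0")
```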
288 | 289 | ## Reference 290 | 291 | ### Paper 292 | AAE: https://arxiv.org/abs/1511.05644 293 | 294 | GAN: https://arxiv.org/abs/1406.2661 295 | 296 | ### Github 297 | https://github.com/hwalsuklee/tensorflow-mnist-AAE 298 | 299 | https://github.com/MINGUKKANG/CVAE 300 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import random 4 | import gzip 5 | import tarfile 6 | import pickle 7 | import os 8 | from six.moves import urllib 9 | from plot import * 10 | 11 | class data_pipeline: 12 | def __init__(self,type): 13 | self.type = type 14 | self.debug = 0 15 | self.batch = 0 16 | 17 | if self.type == "MNIST": 18 | self.url = "http://yann.lecun.com/exdb/mnist/" 19 | self.debug =1 20 | self.n_train_images = 60000 21 | self.n_test_images = 10000 22 | self.n_channels = 1 23 | self.size = 28 24 | self.MNIST_filename = ["train-images-idx3-ubyte.gz", 25 | "train-labels-idx1-ubyte.gz", 26 | "t10k-images-idx3-ubyte.gz", 27 | "t10k-labels-idx1-ubyte.gz"] 28 | 29 | elif self.type == "CIFAR_10": 30 | self.url = "https://www.cs.toronto.edu/~kriz/" 31 | self.debug = 1 32 | self.n_train_images = 50000 33 | self.n_test_images = 10000 34 | self.n_channels = 3 35 | self.size = 32 36 | self.CIFAR_10_filename = ["cifar-10-python.tar.gz"] 37 | 38 | assert self.debug == 1, "Data type must be MNIST or CIFAR_10" 39 | 40 | def maybe_download(self, filename, filepath): 41 | if os.path.isfile(filepath) is True: 42 | print("Filename %s is already downloaded" % filename) 43 | else: 44 | filepath,_ = urllib.request.urlretrieve(self.url + filename, filepath) 45 | with tf.gfile.GFile(filepath) as f: 46 | size = f.size() 47 | print("Successfully download", filename, size, "bytes") 48 | return filepath 49 | 50 | 51 | def download_data(self): 52 | self.filepath_holder = [] 53 | 54 | if not tf.gfile.Exists("./Data"): 55 | tf.gfile.MakeDirs("./Data") 56 | 57 | if self.type == "MNIST": 58 | for i in self.MNIST_filename: 59 | filepath = os.path.join("./Data", i) 60 | self.maybe_download(i,filepath) 61 | self.filepath_holder.append(filepath) 62 | 63 | elif self.type == "CIFAR_10": 64 | for i in self.CIFAR_10_filename: 65 | filepath = os.path.join("./Data", i) 66 | self.maybe_download(i,filepath) 67 | self.filepath_holder.append(filepath) 68 | print("-" * 80) 69 | 70 | def extract_mnist_images(self, filepath, size, n_images,n_channels): 71 | print("Extracting and Reading ", filepath) 72 | 73 | with gzip.open(filepath) as bytestream: 74 | bytestream.read(16) 75 | buf = bytestream.read(size*size*n_images*n_channels) 76 | data = np.frombuffer(buf, dtype = np.uint8) 77 | data = np.reshape(data,[n_images, size, size, n_channels]) 78 | return data 79 | 80 | def extract_mnist_labels(self, filepath,n_images): 81 | print("Extracting and Reading ", filepath) 82 | 83 | with gzip.open(filepath) as bytestream: 84 | bytestream.read(8) 85 | buf = bytestream.read(1*n_images) 86 | labels = np.frombuffer(buf, dtype = np.uint8) 87 | one_hot_encoding = np.zeros((n_images, 10)) 88 | one_hot_encoding[np.arange(n_images), labels] = 1 89 | one_hot_encoding = np.reshape(one_hot_encoding, [-1,10]) 90 | return one_hot_encoding 91 | 92 | def extract_cifar_data(self,filepath, train_files,n_images): 93 | ## this code is from https://github.com/melodyguan/enas/blob/master/src/cifar10/data_utils.py 94 | images, labels = [], [] 95 | for file_name in train_files: 96 | full_name = 
os.path.join(filepath, file_name) 97 | with open(full_name, mode = "rb") as finp: 98 | data = pickle.load(finp, encoding = "bytes") 99 | batch_images = data[b'data'] 100 | batch_labels = np.array(data[b'labels']) 101 | images.append(batch_images) 102 | labels.append(batch_labels) 103 | images = np.concatenate(images, axis=0) 104 | labels = np.concatenate(labels, axis=0) 105 | one_hot_encoding = np.zeros((n_images, 10)) 106 | one_hot_encoding[np.arange(n_images), labels] = 1 107 | one_hot_encoding = np.reshape(one_hot_encoding, [-1, 10]) 108 | images = np.reshape(images, [-1, 3, 32, 32]) 109 | images = np.transpose(images, [0, 2, 3, 1]) 110 | 111 | return images, one_hot_encoding 112 | 113 | def extract_cifar_data_(self,filepath, num_valids=5000): 114 | print("Reading data") 115 | with tarfile.open(filepath, "r:gz") as tar: 116 | tar.extractall("./Data") 117 | images, labels = {}, {} 118 | train_files = [ 119 | "./cifar-10-batches-py/data_batch_1", 120 | "./cifar-10-batches-py/data_batch_2", 121 | "./cifar-10-batches-py/data_batch_3", 122 | "./cifar-10-batches-py/data_batch_4", 123 | "./cifar-10-batches-py/data_batch_5"] 124 | test_file = ["./cifar-10-batches-py/test_batch"] 125 | images["train"], labels["train"] = self.extract_cifar_data("./Data", train_files,self.n_train_images) 126 | 127 | if num_valids: 128 | images["valid"] = images["train"][-num_valids:] 129 | labels["valid"] = labels["train"][-num_valids:] 130 | 131 | images["train"] = images["train"][:-num_valids] 132 | labels["train"] = labels["train"][:-num_valids] 133 | else: 134 | images["valid"], labels["valid"] = None, None 135 | 136 | images["test"], labels["test"] = self.extract_cifar_data("./Data", test_file,self.n_test_images) 137 | return images, labels 138 | 139 | def apply_preprocessing(self, images, mode): 140 | mean = np.mean(images, axis =(0,1,2)) 141 | images = images/255 142 | print("%s_mean: " % mode, mean) 143 | return images 144 | 145 | def load_preprocess_data(self): 146 | self.download_data() 147 | if self.type == "MNIST": 148 | train_images = self.extract_mnist_images(self.filepath_holder[0],self.size, self.n_train_images, self.n_channels) 149 | train_labels = self.extract_mnist_labels(self.filepath_holder[1], self.n_train_images) 150 | self.valid_images = train_images[0:5000,:,:,:] 151 | self.valid_labels = train_labels[0:5000,:] 152 | self.train_images = train_images[5000:,:,:,:] 153 | self.train_labels = train_labels[5000:,:] 154 | self.test_images = self.extract_mnist_images(self.filepath_holder[2],self.size, self.n_test_images, self.n_channels) 155 | self.test_labels = self.extract_mnist_labels(self.filepath_holder[3], self.n_test_images) 156 | print("-" * 80) 157 | self.train_images = self.apply_preprocessing(images = self.train_images, mode = "train") 158 | self.valid_images = self.apply_preprocessing(images = self.valid_images, mode = "valid") 159 | self.test_images = self.apply_preprocessing(images = self.test_images, mode = "test") 160 | print("-" * 80) 161 | print("training size: ", np.shape(self.train_images),", ",np.shape(self.train_labels)) 162 | print("valid size: ", np.shape(self.valid_images), ", ", np.shape(self.valid_labels)) 163 | print("test size: ", np.shape(self.test_images), ", ", np.shape(self.test_labels)) 164 | else: 165 | images, labels = self.extract_cifar_data_(self.filepath_holder[0]) 166 | self.train_images = images["train"] 167 | self.train_labels = labels["train"] 168 | self.valid_images = images["valid"] 169 | self.valid_labels = labels["valid"] 170 | self.test_images = 
images["test"] 171 | self.test_labels = labels["test"] 172 | print("-" * 80) 173 | self.train_images = self.apply_preprocessing(images = self.train_images, mode = "train") 174 | self.valid_images = self.apply_preprocessing(images = self.valid_images, mode = "valid") 175 | self.test_images = self.apply_preprocessing(images = self.test_images, mode = "test") 176 | print("-" * 80) 177 | print("training size: ", np.shape(self.train_images),", ",np.shape(self.train_labels)) 178 | print("valid size: ", np.shape(self.valid_images), ", ", np.shape(self.valid_labels)) 179 | print("test size: ", np.shape(self.test_images), ", ", np.shape(self.test_labels)) 180 | 181 | return self.train_images, self.train_labels, self.valid_images, self.valid_labels, self.test_images, self.test_labels 182 | 183 | def make_noise(self,image): 184 | 185 | def gaussian_noise(image): 186 | size = np.shape(image) 187 | noise = np.random.normal(0,0.3, size = size) 188 | image = image + noise 189 | 190 | return image 191 | 192 | return gaussian_noise(image) 193 | 194 | def initialize_batch(self): 195 | self.batch = 0 196 | 197 | def next_batch(self, images, labels, batch_size, make_noise = None): 198 | 199 | if make_noise is False: 200 | self.length = len(images)//batch_size 201 | batch_xs = images[self.batch*batch_size: self.batch*batch_size + batch_size,:,:,:] 202 | batch_noised_xs = np.copy(batch_xs) 203 | batch_ys = labels[self.batch*batch_size: self.batch*batch_size + batch_size,:] 204 | self.batch += 1 205 | if self.batch == (self.length): 206 | self.batch = 0 207 | else: 208 | self.length = len(images)//batch_size 209 | batch_noised_xs = [] 210 | batch_xs = images[self.batch*batch_size: self.batch*batch_size + batch_size,:,:,:] 211 | batch_ys = labels[self.batch * batch_size: self.batch * batch_size + batch_size, :] 212 | 213 | if self.type == "MNIST": 214 | _ = np.reshape(batch_xs, [-1, self.size, self.size]) 215 | for i in range(batch_size): 216 | batch_noised_xs.append(self.make_noise(_[i])) 217 | batch_noised_xs = np.reshape(batch_noised_xs, [-1, self.size, self.size, self.n_channels]) 218 | else: 219 | for i in range(batch_size): 220 | batch_noised_xs.append(self.make_noise(batch_xs[i])) 221 | 222 | self.batch += 1 223 | if self.batch == (self.length): 224 | self.batch = 0 225 | 226 | return batch_xs, batch_noised_xs, batch_ys 227 | 228 | def get_total_batch(self,images, batch_size): 229 | self.batch_size = batch_size 230 | return len(images)//self.batch_size 231 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import time 4 | from utils import * 5 | from plot import * 6 | from AAE import * 7 | from data_utils import * 8 | 9 | DEFINE_string("model", "semi_supervised", "[supervised | semi_supervised]") 10 | DEFINE_string("data", "MNIST", "[MNIST | CIFAR_10]") 11 | DEFINE_string("prior", "gaussian", "[gaussain | gaussain_mixture | swiss_roll]") 12 | 13 | 14 | DEFINE_integer("super_n_hidden", 3000, "the number of elements for hidden layers") 15 | DEFINE_integer("semi_n_hidden", 3000, "teh number of elements for hidden layers") 16 | DEFINE_integer("n_epoch", 100, "number of Epoch for training") 17 | DEFINE_integer("n_z", 20, "Dimension of Latent variables") 18 | DEFINE_integer("num_samples",5000, "number of samples for semi supervised learning") 19 | DEFINE_integer("batch_size", 128, "Batch Size for training") 20 | 21 | 
DEFINE_float("keep_prob", 0.9, "dropout rate") 22 | DEFINE_float("lr_start", 0.001, "initial learning rate") 23 | DEFINE_float("lr_mid", 0.0001, "mid learning rate") 24 | DEFINE_float("lr_end", 0.0001, "final learning rate") 25 | 26 | DEFINE_boolean("noised", True, "") 27 | DEFINE_boolean("PMLR", True, "Boolean for plot manifold learning result") 28 | DEFINE_boolean("PARR", False, "Boolean for plot analogical reasoning result") 29 | 30 | conf = print_user_flags(line_limit = 100) 31 | print("-"*80) 32 | 33 | if conf.model == "supervised": 34 | 35 | data_pipeline = data_pipeline(conf.data) 36 | 37 | train_xs, train_ys, valid_xs, valid_ys, test_xs, test_ys = data_pipeline.load_preprocess_data() 38 | 39 | _, height, width, channel = np.shape(train_xs) 40 | n_cls = np.shape(train_ys)[1] 41 | 42 | X = tf.placeholder(dtype=tf.float32, shape=[None, height, width, channel], name="Inputs") 43 | X_noised = tf.placeholder(dtype=tf.float32, shape=[None, height, width, channel], name="Inputs_noised") 44 | Y = tf.placeholder(dtype=tf.float32, shape=[None, n_cls], name="Input_labels") 45 | z_prior = tf.placeholder(tf.float32, shape=[None, conf.n_z], name="z_prior") 46 | z_id = tf.placeholder(tf.float32, shape = [None, n_cls], name = "prior_labels") 47 | latent = tf.placeholder(tf.float32, shape = [None, conf.n_z], name = "latent_for_generation") 48 | keep_prob = tf.placeholder(dtype = tf.float32, name = "dropout_rate") 49 | lr_ = tf.placeholder(dtype = tf.float32, name = "learning_rate") 50 | global_step = tf.Variable(0, trainable=False) 51 | 52 | AAE = AAE(conf, [_, height, width, channel], n_cls) 53 | z_generated, X_generated, negative_log_likelihood, D_loss, G_loss = AAE.Sup_Adversarial_AutoEncoder(X, 54 | X_noised, 55 | Y, 56 | z_prior, 57 | z_id, 58 | keep_prob) 59 | images_PMLR = AAE.sup_decoder(latent, keep_prob) 60 | total_batch = data_pipeline.get_total_batch(train_xs, conf.batch_size) 61 | 62 | total_vars = tf.trainable_variables() 63 | var_AE = [var for var in total_vars if "encoder" or "decoder" in var.name] 64 | var_generator = [var for var in total_vars if "encoder" in var.name] 65 | var_discriminator = [var for var in total_vars if "discriminator" in var.name] 66 | 67 | op_AE = tf.train.AdamOptimizer(learning_rate = lr_).minimize(negative_log_likelihood, 68 | global_step = global_step, 69 | var_list = var_AE) 70 | 71 | op_D = tf.train.AdamOptimizer(learning_rate = lr_/5). 
minimize(D_loss, 72 | global_step = global_step, 73 | var_list = var_discriminator) 74 | op_G = tf.train.AdamOptimizer(learning_rate = lr_).minimize(G_loss, 75 | global_step = global_step, 76 | var_list = var_generator) 77 | 78 | batch_t_xs, batch_tn_xs, batch_t_ys = data_pipeline.next_batch(valid_xs, valid_ys, 100, make_noise= False) 79 | data_pipeline.initialize_batch() 80 | 81 | sess = tf.Session() 82 | sess.run(tf.initialize_all_variables()) 83 | 84 | start_time = time.time() 85 | for i in range(conf.n_epoch): 86 | likelihood = 0 87 | D_value = 0 88 | G_value = 0 89 | for j in range(total_batch): 90 | batch_xs, batch_noised_xs, batch_ys = data_pipeline.next_batch(train_xs, 91 | train_ys, 92 | conf.batch_size, 93 | make_noise=conf.noised) 94 | if conf.prior == "gaussian": 95 | z_prior_, z_id_ = gaussian(conf.batch_size, 96 | n_labels = n_cls, 97 | n_dim = conf.n_z, 98 | use_label_info = True) 99 | z_id_onehot = np.eye(n_cls)[z_id_].astype(np.float32) 100 | 101 | elif conf.prior == "gaussian_mixture": 102 | z_id_ = np.random.randint(0, n_cls, size=[conf.batch_size]) 103 | z_id_onehot = np.eye(n_cls)[z_id_].astype(np.float32) 104 | z_prior_ = gaussian_mixture(conf.batch_size, 105 | n_labels = n_cls, 106 | n_dim = conf.n_z, 107 | label_indices = z_id_) 108 | 109 | elif conf.prior == "swiss_roll": 110 | z_id_ = np.random.randint(0, n_cls, size=[conf.batch_size]) 111 | z_id_onehot = np.eye(n_cls)[z_id_].astype(np.float32) 112 | z_prior_ = swiss_roll(conf.batch_size, 113 | n_labels = n_cls, 114 | n_dim = conf.n_z, 115 | label_indices = z_id_) 116 | else: 117 | print("FLAGS.prior should be [gaussian, gaussian_mixture, swiss_roll]") 118 | 119 | if i <= 50: 120 | lr_value = conf.lr_start 121 | elif i <=100: 122 | lr_value = conf.lr_mid 123 | else: 124 | lr_value = conf.lr_end 125 | 126 | feed_dict = {X: batch_xs, 127 | X_noised: batch_noised_xs, 128 | Y: batch_ys, 129 | z_prior: z_prior_, 130 | z_id: z_id_onehot, 131 | lr_: lr_value, 132 | keep_prob: conf.keep_prob} 133 | 134 | # AutoEncoder phase 135 | l, _, g = sess.run([negative_log_likelihood, op_AE, global_step], feed_dict=feed_dict) 136 | 137 | # Discriminator phase 138 | l_D, _ = sess.run([D_loss, op_D], feed_dict = feed_dict) 139 | 140 | l_G, _ = sess.run([G_loss, op_G], feed_dict = feed_dict) 141 | 142 | likelihood += l/total_batch 143 | D_value += l_D/total_batch 144 | G_value += l_G/total_batch 145 | 146 | if i % 5 == 0 or i == (conf.n_epoch -1): 147 | images = sess.run(X_generated, feed_dict = {X:batch_t_xs, 148 | X_noised: batch_tn_xs, 149 | keep_prob: 1.0}) 150 | images = np.reshape(images, [-1, height, width, channel]) 151 | name = "Manifold_canvas_" + str(i) 152 | plot_manifold_canvas(images, 10, type = "MNIST", name = name) 153 | 154 | 155 | hour = int((time.time() - start_time) / 3600) 156 | min = int(((time.time() - start_time) - 3600 * hour) / 60) 157 | sec = int((time.time() - start_time) - 3600 * hour - 60 * min) 158 | print("Epoch: %3d lr_AE: %.5f loss_AE: %.4f Time: %d hour %d min %d sec" % (i, lr_value, likelihood, hour, min, sec)) 159 | print(" lr_D: %.5f loss_D: %.4f" % (lr_value/5, D_value)) 160 | print(" lr_G: %.5f loss_G: %.4f\n" % (lr_value, G_value)) 161 | 162 | ## code for 2D scatter plot 163 | if conf.n_z == 2: 164 | print("-" * 80) 165 | print("plot 2D Scatter Result") 166 | test_total_batch = data_pipeline.get_total_batch(test_xs, 128) 167 | data_pipeline.initialize_batch() 168 | latent_holder = [] 169 | for i in range(test_total_batch): 170 | batch_test_xs, batch_test_noised_xs, batch_test_ys = 
data_pipeline.next_batch(test_xs, 171 | test_ys, 172 | conf.batch_size, 173 | make_noise=False) 174 | feed_dict = {X: batch_test_xs, 175 | X_noised: batch_test_noised_xs, 176 | keep_prob: 1.0} 177 | 178 | latent_vars = sess.run(z_generated, feed_dict=feed_dict) 179 | latent_holder.append(latent_vars) 180 | latent_holder = np.concatenate(latent_holder, axis=0) 181 | plot_2d_scatter(latent_holder[:, 0], latent_holder[:, 1], test_ys[:len(latent_holder)]) 182 | 183 | if conf.PMLR is True: 184 | print("-" * 80) 185 | assert conf.n_z == 2, "Error: n_z should be 2" 186 | print("plot Manifold Learning Result") 187 | x_axis = np.linspace(-0.5, 0.5, 10) 188 | y_axis = np.linspace(0.5, -0.5, 10) 189 | z_holder = [] 190 | for i, yi in enumerate(y_axis): 191 | for j, xi in enumerate(x_axis): 192 | z_holder.append([xi, yi]) 193 | length = len(z_holder) 194 | MLR = sess.run(images_PMLR, feed_dict={latent: z_holder, keep_prob: 1.0}) 195 | MLR = np.reshape(MLR, [-1, height, width, channel]) 196 | p_name = "PMLR/PMLR" 197 | plot_manifold_canvas(MLR, 10, "MNIST", p_name) 198 | 199 | elif conf.model == "semi_supervised": 200 | 201 | Data = data_pipeline(conf.data) 202 | Data_semi = data_pipeline(conf.data) 203 | train_xs, train_ys, valid_xs, valid_ys, test_xs, test_ys = Data.load_preprocess_data() 204 | valid_xs, valid_ys = valid_xs[:conf.num_samples], valid_ys[:conf.num_samples] 205 | 206 | _, height, width, channel = np.shape(train_xs) 207 | n_cls = np.shape(train_ys)[1] 208 | 209 | X = tf.placeholder(dtype=tf.float32, shape=[None, height, width, channel], name="Input") 210 | X_noised = tf.placeholder(dtype=tf.float32, shape=[None, height, width, channel], name="Input_noised") 211 | Y = tf.placeholder(dtype=tf.float32, shape=[None, n_cls], name="Input_labels") 212 | Y_cat = tf.placeholder(dtype=tf.float32, shape=[None, n_cls], name="labels_cat") 213 | z_prior_ = tf.placeholder(dtype = tf.float32, shape = [None,conf.n_z], name = "z_prior" ) 214 | latent = tf.placeholder(dtype = tf.float32, shape = [None, conf.n_z + n_cls], name = "latent_for_generation") 215 | keep_prob = tf.placeholder(dtype=tf.float32, name="dropout_rate") 216 | lr_ = tf.placeholder(dtype=tf.float32, name="learning_rate") 217 | global_step = tf.Variable(0, trainable=False) 218 | 219 | AAE = AAE(conf, [_, height, width, channel], n_cls) 220 | 221 | style, X_generated, negative_log_likelihood, D_loss_y, D_loss_z, G_loss, CE_labels =AAE.Semi_Adversarial_AutoEncoder(X, 222 | X_noised, 223 | Y, 224 | Y_cat, 225 | z_prior_, 226 | keep_prob) 227 | images_PARR = AAE.semi_decoder(latent, keep_prob) 228 | images_manifold = AAE.semi_decoder(latent, keep_prob) 229 | 230 | total_batch = Data.get_total_batch(train_xs, conf.batch_size) 231 | 232 | total_vars = tf.trainable_variables() 233 | var_AE = [var for var in total_vars if "encoder" or "decoder" in var.name] 234 | var_z_discriminator = [var for var in total_vars if "z_discriminator" in var.name] 235 | var_y_discriminator = [var for var in total_vars if "y_discriminator" in var.name] 236 | var_generator = [var for var in total_vars if "encoder" in var.name] 237 | 238 | op_AE = tf.train.AdamOptimizer(learning_rate = lr_).minimize(negative_log_likelihood, global_step = global_step, var_list = var_AE) 239 | op_y_D = tf.train.AdamOptimizer(learning_rate = lr_/5).minimize(D_loss_y, global_step = global_step, var_list = var_y_discriminator) 240 | op_z_D = tf.train.AdamOptimizer(learning_rate = lr_/5).minimize(D_loss_z, global_step = global_step, var_list = var_z_discriminator) 241 | op_G = 
tf.train.AdamOptimizer(learning_rate = lr_).minimize(G_loss, global_step = global_step, var_list = var_generator) 242 | op_CE_labels = tf.train.AdamOptimizer(learning_rate = lr_).minimize(CE_labels, global_step = global_step, var_list = var_generator) 243 | 244 | batch_t_xs, batch_tn_xs, batch_t_ys = Data.next_batch(valid_xs, valid_ys, 100, make_noise = False) 245 | Data.initialize_batch() 246 | 247 | sess = tf.Session() 248 | sess.run(tf.initialize_all_variables()) 249 | 250 | start_time = time.time() 251 | for i in range(conf.n_epoch): 252 | likelihood = 0 253 | D_z_value = 0 254 | D_y_value = 0 255 | G_value = 0 256 | CE_value = 0 257 | 258 | if i <= 50: 259 | lr_value = conf.lr_start 260 | elif i <= 100: 261 | lr_value = conf.lr_mid 262 | else: 263 | lr_value = conf.lr_end 264 | 265 | for j in range(total_batch): 266 | batch_xs, batch_noised_xs, batch_ys = Data.next_batch(train_xs, 267 | train_ys, 268 | conf.batch_size, 269 | make_noise = conf.noised) 270 | 271 | real_cat_labels = np.random.randint(low = 0, high = n_cls, size = conf.batch_size) 272 | real_cat_labels = np.eye(n_cls)[real_cat_labels] 273 | 274 | if conf.prior == "gaussian": 275 | z_prior = gaussian(conf.batch_size, 276 | n_labels=n_cls, 277 | n_dim=conf.n_z, 278 | use_label_info=False) 279 | 280 | elif conf.prior == "gaussian_mixture": 281 | z_prior = gaussian_mixture(conf.batch_size, 282 | n_labels=n_cls, 283 | n_dim=conf.n_z) 284 | 285 | elif conf.prior == "swiss_roll": 286 | z_prior = swiss_roll(conf.batch_size, 287 | n_labels=n_cls, 288 | n_dim=conf.n_z) 289 | else: 290 | print("FLAGS.prior should be [gaussian, gaussian_mixture, swiss_roll]") 291 | 292 | feed_dict = {X: batch_xs, 293 | X_noised: batch_noised_xs, 294 | Y: batch_ys, 295 | Y_cat: real_cat_labels, 296 | z_prior_: z_prior, 297 | lr_: lr_value, 298 | keep_prob: conf.keep_prob} 299 | 300 | # AutoEncoder phase 301 | l, _, g = sess.run([negative_log_likelihood, op_AE, global_step], feed_dict = feed_dict) 302 | 303 | # z_Discriminator phase 304 | l_z_D,_ = sess.run([D_loss_z, op_z_D], feed_dict = feed_dict) 305 | 306 | # y_Discriminator phase 307 | l_y_D, _ = sess.run([D_loss_y, op_y_D], feed_dict=feed_dict) 308 | 309 | # Generator phase 310 | l_G, _ = sess.run([G_loss, op_G], feed_dict = feed_dict) 311 | 312 | batch_semi_xs, batch_noised_semi_xs,batch_semi_ys = Data_semi.next_batch(valid_xs, 313 | valid_ys, 314 | conf.batch_size, 315 | make_noise = False) 316 | 317 | feed_dict = {X: batch_semi_xs, 318 | X_noised: batch_noised_semi_xs, 319 | Y: batch_semi_ys, 320 | Y_cat: real_cat_labels, 321 | lr_:lr_value, 322 | keep_prob: conf.keep_prob} 323 | 324 | # Cross_Entropy phase 325 | CE, _ = sess.run([CE_labels, op_CE_labels], feed_dict = feed_dict) 326 | 327 | likelihood += l/total_batch 328 | D_z_value += l_z_D/total_batch 329 | D_y_value += l_y_D/total_batch 330 | G_value += l_G/total_batch 331 | CE_value += CE/total_batch 332 | 333 | if i % 5 == 0 or i == (conf.n_epoch -1): 334 | images = sess.run(X_generated, feed_dict = {X:batch_t_xs, 335 | X_noised: batch_tn_xs, 336 | keep_prob: 1.0}) 337 | images = np.reshape(images, [-1, height, width, channel]) 338 | name = "Manifold_semi_canvas_" + str(i) 339 | plot_manifold_canvas(images, 10, type = "MNIST", name = name) 340 | 341 | 342 | hour = int((time.time() - start_time) / 3600) 343 | min = int(((time.time() - start_time) - 3600 * hour) / 60) 344 | sec = int((time.time() - start_time) - 3600 * hour - 60 * min) 345 | print("Epoch: %3d lr_AE_G_CE: %.5f lr_D: %.5f Time: %d hour %d min %d sec" % (i, 
lr_value,lr_value/5, hour, min, sec)) 346 | print("loss_AE: %.5f" % (likelihood)) 347 | print("loss_z_D: %.4f loss_y_D: %f" % (D_z_value, D_y_value)) 348 | print("loss_G: %.4f CE_semi: %.4f\n" % (G_value, CE_value)) 349 | 350 | if conf.PARR is True: 351 | print("-"*80) 352 | print("plot analogical reasoning result") 353 | z_holder = [] 354 | for i in range(n_cls): 355 | z_ = np.random.rand(10, conf.n_z) 356 | z_holder.append(z_) 357 | z_holder = np.concatenate(z_holder, axis = 0) 358 | y = [j for j in range(n_cls)] 359 | y = y*10 360 | length = len(z_holder) 361 | y_one_hot = np.zeros((length, n_cls)) 362 | y_one_hot[np.arange(length), y] = 1 363 | y_one_hot = np.reshape(y_one_hot, [-1, n_cls]) 364 | z_concated = np.concatenate([z_holder, y_one_hot], axis=1) 365 | PARR = sess.run(images_PARR, feed_dict = {latent: z_concated, keep_prob: 1.0}) 366 | PARR = np.reshape(PARR, [-1, height, width, channel]) 367 | p_name = "PARR/Cond_generation" 368 | plot_manifold_canvas(PARR, 10, "MNIST", p_name) 369 | 370 | ## code for 2D scatter plot 371 | if conf.n_z == 2: 372 | print("-" * 80) 373 | print("plot 2D Scatter Result") 374 | test_total_batch = Data.get_total_batch(test_xs, 128) 375 | Data.initialize_batch() 376 | latent_holder = [] 377 | for i in range(test_total_batch): 378 | batch_test_xs, batch_test_noised_xs, batch_test_ys = Data.next_batch(test_xs, 379 | test_ys, 380 | conf.batch_size, 381 | make_noise=False) 382 | feed_dict = {X: batch_test_xs, 383 | X_noised: batch_test_noised_xs, 384 | Y: batch_test_ys, 385 | keep_prob: 1.0} 386 | 387 | latent_vars = sess.run(style, feed_dict=feed_dict) 388 | latent_holder.append(latent_vars) 389 | latent_holder = np.concatenate(latent_holder, axis=0) 390 | plot_2d_scatter(latent_holder[:, 0], latent_holder[:, 1], test_ys[:len(latent_holder)]) 391 | 392 | if conf.PMLR is True: 393 | print("-"*80) 394 | assert conf.n_z == 2, "Error: n_z should be 2" 395 | print("plot Manifold Learning Results") 396 | x_axis = np.linspace(-0.5,0.5,10) 397 | y_axis = np.linspace(-0.5,0.5,10) 398 | z_holder = [] 399 | for i,xi in enumerate(x_axis): 400 | for j, yi in enumerate(y_axis): 401 | z_holder.append([xi,yi]) 402 | length = len(z_holder) 403 | for k in range(n_cls): 404 | y = [k]*length 405 | y_one_hot = np.zeros((length, n_cls)) 406 | y_one_hot[np.arange(length), y] = 1 407 | y_one_hot = np.reshape(y_one_hot, [-1,n_cls]) 408 | z_concated = np.concatenate([z_holder, y_one_hot], axis=1) 409 | MLR = sess.run(images_manifold, feed_dict = {latent: z_concated, keep_prob: 1.0}) 410 | MLR = np.reshape(MLR, [-1, height, width, channel]) 411 | p_name = "PMLR/labels" +str(k) 412 | plot_manifold_canvas(MLR, 10, "MNIST", p_name) --------------------------------------------------------------------------------