├── Image
│   ├── PMLR_.png
│   ├── Gaussian.png
│   ├── labels_0.png
│   ├── labels_1.png
│   ├── labels_2.png
│   ├── G_generated.png
│   ├── Swiss_roll_.png
│   ├── GM_generated.png
│   ├── Original_image.png
│   ├── S_R_generated.png
│   ├── Supervised_AAE.png
│   ├── Supervised_AAE_.png
│   ├── Gaussian_mixture_.png
│   ├── Restored_Semi_AAE.png
│   └── Semisupervised_AAE_.png
├── plot.py
├── utils.py
├── prior.py
├── AAE.py
├── README.md
├── data_utils.py
└── main.py
/Image/PMLR_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/PMLR_.png
--------------------------------------------------------------------------------
/Image/Gaussian.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Gaussian.png
--------------------------------------------------------------------------------
/Image/labels_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/labels_0.png
--------------------------------------------------------------------------------
/Image/labels_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/labels_1.png
--------------------------------------------------------------------------------
/Image/labels_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/labels_2.png
--------------------------------------------------------------------------------
/Image/G_generated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/G_generated.png
--------------------------------------------------------------------------------
/Image/Swiss_roll_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Swiss_roll_.png
--------------------------------------------------------------------------------
/Image/GM_generated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/GM_generated.png
--------------------------------------------------------------------------------
/Image/Original_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Original_image.png
--------------------------------------------------------------------------------
/Image/S_R_generated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/S_R_generated.png
--------------------------------------------------------------------------------
/Image/Supervised_AAE.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Supervised_AAE.png
--------------------------------------------------------------------------------
/Image/Supervised_AAE_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Supervised_AAE_.png
--------------------------------------------------------------------------------
/Image/Gaussian_mixture_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Gaussian_mixture_.png
--------------------------------------------------------------------------------
/Image/Restored_Semi_AAE.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Restored_Semi_AAE.png
--------------------------------------------------------------------------------
/Image/Semisupervised_AAE_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingukkang/Adversarial-AutoEncoder/HEAD/Image/Semisupervised_AAE_.png
--------------------------------------------------------------------------------
/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import tensorflow as tf
3 | from data_utils import *
4 |
5 | def plot_2d_scatter(x,y,test_labels):
6 | plt.figure(figsize = (8,6))
7 | plt.scatter(x,y, c = np.argmax(test_labels,1), marker ='.', edgecolor = 'none', cmap = discrete_cmap('jet'))
8 | plt.colorbar()
9 | plt.grid()
10 | if not tf.gfile.Exists("./Scatter"):
11 | tf.gfile.MakeDirs("./Scatter")
12 | plt.savefig('./Scatter/2D_latent_space.png')
13 | plt.close()
14 |
15 | def discrete_cmap(base_cmap =None):
16 | base = plt.cm.get_cmap(base_cmap)
17 | color_list = base(np.linspace(0,1,10))
18 | cmap_name = base.name + str(10)
19 | return base.from_list(cmap_name,color_list,10)
20 |
21 | def plot_manifold_canvas(images, n, type, name):
22 | assert images.shape[0] == n**2, "n**2 should be number of images"
23 | height = images.shape[1]
24 | width = images.shape[2] # width = height
25 | x = np.linspace(-2, 2, n)
26 | y = np.linspace(-2, 2, n)
27 |
28 | if type == "MNIST":
29 | canvas = np.empty((n * height, n * height))
30 | for i, yi in enumerate(x):
31 | for j, xi in enumerate(y):
32 | canvas[height*i: height*i + height, width*j: width*j + width] = np.reshape(images[n*i + j], [height, width])
33 | plt.figure(figsize=(8, 8))
34 | plt.imshow(canvas, cmap="gray")
35 | else:
36 | canvas = np.empty((n * height, n * height, 3))
37 | for i, yi in enumerate(x):
38 | for j, xi in enumerate(y):
39 | canvas[height*i: height*i + height, width*j: width*j + width,:] = images[n*i + j]
40 | plt.figure(figsize=(8, 8))
41 | plt.imshow(canvas)
42 |
43 | if not tf.gfile.Exists("./plot"):
44 | tf.gfile.MakeDirs("./plot")
45 | if not tf.gfile.Exists("./plot/PMLR"):
46 | tf.gfile.MakeDirs("./plot/PMLR")
47 | if not tf.gfile.Exists("./plot/PARR"):
48 | tf.gfile.MakeDirs("./plot/PARR")
49 |
50 | name = name + ".png"
51 | path = os.path.join("./plot", name)
52 | plt.savefig(path)
53 | print("saving location: %s" % (path))
54 | plt.close()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | initializer = tf.contrib.layers.xavier_initializer()
4 | #initializer = tf.contrib.layers.variance_scaling_initializer(factor = 1.0)
5 |
6 |
7 | def conv(inputs,filters,name):
8 | net = tf.layers.conv2d(inputs = inputs,
9 | filters = filters,
10 | kernel_size = [3,3],
11 | strides = (1,1),
12 | padding ="SAME",
13 | kernel_initializer = initializer,
14 | name = name,
15 | reuse = tf.AUTO_REUSE)
16 | return net
17 |
18 | def maxpool(input,name):
19 | net = tf.nn.max_pool(value = input, ksize = [1,2,2,1], strides = [1,2,2,1], padding = "SAME", name = name)
20 | return net
21 |
22 | def bn(inputs,is_training,name):
23 | net = tf.contrib.layers.batch_norm(inputs, decay = 0.9, is_training = is_training, reuse = tf.AUTO_REUSE, scope = name)
24 | return net
25 |
26 | def leaky(input):
27 | return tf.nn.leaky_relu(input)
28 |
29 | def relu(input):
30 | return tf.nn.relu(input)
31 |
32 | def drop_out(input, keep_prob):
33 |
34 | return tf.nn.dropout(input, keep_prob)
35 | def dense(inputs, units, name):
36 | net = tf.layers.dense(inputs = inputs,
37 | units = units,
38 | reuse = tf.AUTO_REUSE,
39 | name = name,
40 | kernel_initializer = initializer)
41 | return net
42 |
43 | user_flags = []
44 |
45 | def DEFINE_string(name, default_value, doc_string):
46 | tf.app.flags.DEFINE_string(name, default_value, doc_string)
47 | global user_flags
48 | user_flags.append(name)
49 |
50 | def DEFINE_integer(name, default_value, doc_string):
51 | tf.app.flags.DEFINE_integer(name, default_value, doc_string)
52 | global user_flags
53 | user_flags.append(name)
54 |
55 | def DEFINE_float(name, default_value, doc_string):
56 | tf.app.flags.DEFINE_float(name, default_value, doc_string)
57 | global user_flags
58 | user_flags.append(name)
59 |
60 | def DEFINE_boolean(name, default_value, doc_string):
61 | tf.app.flags.DEFINE_boolean(name, default_value, doc_string)
62 | global user_flags
63 | user_flags.append(name)
64 |
65 | def print_user_flags(line_limit = 100):
66 | print("-" * 80)
67 |
68 | global user_flags
69 | FLAGS = tf.app.flags.FLAGS
70 |
71 | for flag_name in sorted(user_flags):
72 | value = "{}".format(getattr(FLAGS, flag_name))
73 | log_string = flag_name
74 | log_string += "." * (line_limit - len(flag_name) - len(value))
75 | log_string += value
76 | print(log_string)
77 |
78 | return FLAGS
--------------------------------------------------------------------------------
/prior.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from math import sin,cos,sqrt
3 | ## this code is borrowed from https://github.com/hwalsuklee/tensorflow-mnist-AAE/
4 |
5 |
6 |
7 | def gaussian(batch_size, n_labels, n_dim, mean=0, var=1, use_label_info=False):
8 | np.random.seed(0)
9 | if use_label_info:
10 | if n_dim != 2 or n_labels != 10:
11 | raise Exception("n_dim must be 2 and n_labels must be 10.")
12 |
13 | def sample(n_labels):
14 | x, y = np.random.normal(mean, var, (2,))
15 | angle = np.angle((x-mean) + 1j*(y-mean), deg=True)
16 | dist = np.sqrt((x-mean)**2+(y-mean)**2)
17 |
18 | # label 0
19 | if dist <1.0:
20 | label = 0
21 | else:
22 | label = ((int)((n_labels-1)*angle))//360
23 |
24 | if label<0:
25 | label+=n_labels-1
26 |
27 | label += 1
28 |
29 | return np.array([x, y]).reshape((2,)), label
30 |
31 | z = np.empty((batch_size, n_dim), dtype=np.float32)
32 | z_id = np.empty((batch_size), dtype=np.int32)
33 | for batch in range(batch_size):
34 | for zi in range((int)(n_dim/2)):
35 | a_sample, a_label = sample(n_labels)
36 | z[batch, zi*2:zi*2+2] = a_sample
37 | z_id[batch] = a_label
38 | return z, z_id
39 | else:
40 | z = np.random.normal(mean, var, (batch_size, n_dim)).astype(np.float32)
41 | return z
42 |
43 | def gaussian_mixture(batch_size, n_labels ,n_dim, x_var=0.5, y_var=0.1, label_indices=None):
44 | np.random.seed(0)
45 | if n_dim != 2:
46 | raise Exception("n_dim must be 2.")
47 |
48 | def sample(x, y, label, n_labels):
49 | shift = 1.4
50 | r = 2.0 * np.pi / float(n_labels) * float(label)
51 | new_x = x * cos(r) - y * sin(r)
52 | new_y = x * sin(r) + y * cos(r)
53 | new_x += shift * cos(r)
54 | new_y += shift * sin(r)
55 | return np.array([new_x, new_y]).reshape((2,))
56 |
57 | x = np.random.normal(0, x_var, (batch_size, (int)(n_dim/2)))
58 | y = np.random.normal(0, y_var, (batch_size, (int)(n_dim/2)))
59 | z = np.empty((batch_size, n_dim), dtype=np.float32)
60 | for batch in range(batch_size):
61 | for zi in range((int)(n_dim/2)):
62 | if label_indices is not None:
63 | z[batch, zi*2:zi*2+2] = sample(x[batch, zi], y[batch, zi], label_indices[batch], n_labels)
64 | else:
65 | z[batch, zi*2:zi*2+2] = sample(x[batch, zi], y[batch, zi], np.random.randint(0, n_labels), n_labels)
66 |
67 | return z
68 |
69 | def swiss_roll(batch_size, n_labels, n_dim, label_indices=None):
70 | np.random.seed(0)
71 | if n_dim != 2:
72 | raise Exception("n_dim must be 2.")
73 |
74 | def sample(label, n_labels):
75 | uni = np.random.uniform(0.0, 1.0) / float(n_labels) + float(label) / float(n_labels)
76 | r = sqrt(uni) * 3.0
77 | rad = np.pi * 4.0 * sqrt(uni)
78 | x = r * cos(rad)
79 | y = r * sin(rad)
80 | return np.array([x, y]).reshape((2,))
81 |
82 | z = np.zeros((batch_size, n_dim), dtype=np.float32)
83 | for batch in range(batch_size):
84 | for zi in range((int)(n_dim/2)):
85 | if label_indices is not None:
86 | z[batch, zi*2:zi*2+2] = sample(label_indices[batch], n_labels)
87 | else:
88 | z[batch, zi*2:zi*2+2] = sample(np.random.randint(0, n_labels), n_labels)
89 | return z
--------------------------------------------------------------------------------
/AAE.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from utils import *
3 | from data_utils import *
4 | from prior import *
5 |
6 | class AAE:
7 | def __init__(self, conf, shape, n_labels):
8 | self.conf = conf
9 | self.mode = conf.model
10 | self.data = conf.data
11 | self.super_n_hidden = conf.super_n_hidden
12 | self.semi_n_hidden = conf.semi_n_hidden
13 | self.n_z = conf.n_z
14 | self.batch_size = conf.batch_size
15 | self.prior = conf.prior
16 | self.w = shape[1]
17 | self.h = shape[2]
18 | self.c = shape[3]
19 | self.length = self.h * self.w * self.c
20 | self.n_labels = n_labels
21 |
22 | def sup_encoder(self, X, keep_prob): # encoder for supervised AAE
23 |
24 | with tf.variable_scope("sup_encoder", reuse = tf.AUTO_REUSE):
25 | net = drop_out(relu(dense(X, self.super_n_hidden, name = "dense_1")), keep_prob)
26 | net = drop_out(relu(dense(net, self.super_n_hidden, name="dense_2")), keep_prob)
27 | net = dense(net, self.n_z, name ="dense_3")
28 |
29 | return net
30 |
31 | def sup_decoder(self, Z, keep_prob): # decoder for supervised AAE
32 |
33 | with tf.variable_scope("sup_decoder", reuse = tf.AUTO_REUSE):
34 | net = drop_out(relu(dense(Z, self.super_n_hidden, name = "dense_1")), keep_prob)
35 | net = drop_out(relu(dense(net, self.super_n_hidden, name="dense_2")), keep_prob)
36 | net = tf.nn.sigmoid(dense(net, self.length, name = "dense_3"))
37 |
38 | return net
39 |
40 | def discriminator(self,Z, keep_prob): # discriminator for supervised AAE
41 |
42 | with tf.variable_scope("discriminator", reuse = tf.AUTO_REUSE):
43 | net = drop_out(relu(dense(Z, self.super_n_hidden, name = "dense_1")), keep_prob)
44 | net = drop_out(relu(dense(net, self.super_n_hidden, name="dense_2")), keep_prob)
45 | logits = dense(net, 1, name ="dense_3")
46 |
47 | return logits
48 |
49 | def Sup_Adversarial_AutoEncoder(self, X, X_noised, Y, z_prior, z_id, keep_prob):
50 |
51 | X_flatten = tf.reshape(X, [-1, self.length])
52 | X_flatten_noised = tf.reshape(X_noised, [-1, self.length])
53 |
54 | z_generated = self.sup_encoder(X_flatten_noised, keep_prob)
55 | X_generated = self.sup_decoder(z_generated, keep_prob)
56 |
57 | negative_log_likelihood = tf.reduce_mean(tf.squared_difference(X_generated, X_flatten))
58 |
59 | z_prior = tf.concat([z_prior, z_id], axis = 1)
60 | z_fake = tf.concat([z_generated, Y], axis = 1)
61 | D_real_logits = self.discriminator(z_prior, keep_prob)
62 | D_fake_logits = self.discriminator(z_fake, keep_prob)
63 |
64 | D_loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_fake_logits, labels = tf.zeros_like(D_fake_logits))
65 | D_loss_true = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_real_logits, labels = tf.ones_like(D_real_logits))
66 |
67 | G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_fake_logits, labels = tf.ones_like(D_fake_logits))
68 |
69 | D_loss = tf.reduce_mean(D_loss_fake) + tf.reduce_mean(D_loss_true)
70 | G_loss = tf.reduce_mean(G_loss)
71 |
72 | return z_generated, X_generated, negative_log_likelihood, D_loss, G_loss
73 |
74 | def semi_encoder(self, X, keep_prob, semi_supervised = False):
75 |
76 | with tf.variable_scope("semi_encoder", reuse = tf.AUTO_REUSE):
77 | net = drop_out(relu(dense(X, self.semi_n_hidden, name = "dense_1")), keep_prob)
78 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob)
79 | style = dense(net, self.n_z, name ="style")
80 |
81 | if semi_supervised is False:
82 | labels_generated = tf.nn.softmax(dense(net, self.n_labels, name = "labels"))
83 | else:
84 | labels_generated = dense(net, self.n_labels, name = "label_logits")
85 |
86 | return style, labels_generated
87 |
88 | def semi_decoder(self, Z, keep_prob):
89 |
90 | with tf.variable_scope("semi_decoder", reuse = tf.AUTO_REUSE):
91 | net = drop_out(relu(dense(Z, self.semi_n_hidden, name = "dense_1")), keep_prob)
92 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob)
93 | net = tf.nn.sigmoid(dense(net, self.length, name = "dense_3"))
94 |
95 | return net
96 |
97 | def semi_z_discriminator(self,Z, keep_prob):
98 |
99 | with tf.variable_scope("semi_z_discriminator", reuse = tf.AUTO_REUSE):
100 | net = drop_out(relu(dense(Z, self.semi_n_hidden, name="dense_1")), keep_prob)
101 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob)
102 | logits = dense(net, 1, name="dense_3")
103 |
104 | return logits
105 |
106 | def semi_y_discriminator(self, Y, keep_prob):
107 |
108 | with tf.variable_scope("semi_y_discriminator", reuse = tf.AUTO_REUSE):
109 | net = drop_out(relu(dense(Y, self.semi_n_hidden, name = "dense_1")), keep_prob)
110 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob)
111 | logits = dense(net, 1, name = "dense_3")
112 |
113 | return logits
114 |
115 | def Semi_Adversarial_AutoEncoder(self, X, X_noised, labels, labels_cat, z_prior, keep_prob):
116 |
117 | X_flatten = tf.reshape(X, [-1 , self.length])
118 | X_noised_flatten = tf.reshape(X_noised, [-1, self.length])
119 |
120 | style, labels_softmax = self.semi_encoder(X_noised_flatten, keep_prob, semi_supervised = False)
121 | latent_inputs = tf.concat([style, labels_softmax], axis = 1)
122 | X_generated = self.semi_decoder(latent_inputs, keep_prob)
123 |
124 | _, labels_generated = self.semi_encoder(X_noised_flatten, keep_prob, semi_supervised = True)
125 |
126 | D_Y_fake = self.semi_y_discriminator(labels_softmax, keep_prob)
127 | D_Y_real = self.semi_y_discriminator(labels_cat, keep_prob)
128 |
129 | D_Z_fake = self.semi_z_discriminator(style, keep_prob)
130 | D_Z_real = self.semi_z_discriminator(z_prior, keep_prob)
131 |
132 | negative_loglikelihood = tf.reduce_mean(tf.squared_difference(X_generated,X_flatten))
133 |
134 | D_loss_y_real = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Y_real, labels=tf.ones_like(D_Y_real))
135 | D_loss_y_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Y_fake, labels=tf.zeros_like(D_Y_fake))
136 | D_loss_y = tf.reduce_mean(D_loss_y_real) + tf.reduce_mean(D_loss_y_fake)
137 | D_loss_z_real = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_Z_real, labels = tf.ones_like(D_Z_real))
138 | D_loss_z_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_Z_fake, labels = tf.zeros_like(D_Z_fake))
139 | D_loss_z = tf.reduce_mean(D_loss_z_real) + tf.reduce_mean(D_loss_z_fake)
140 |
141 |
142 | G_loss_y = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Y_fake, labels=tf.ones_like(D_Y_fake))
143 | G_loss_z = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_Z_fake, labels = tf.ones_like(D_Z_fake))
144 | G_loss = tf.reduce_mean(G_loss_y) + tf.reduce_mean(G_loss_z)
145 |
146 | CE_labels = tf.nn.softmax_cross_entropy_with_logits(logits = labels_generated, labels = labels)
147 | CE_labels = tf.reduce_mean(CE_labels)
148 |
149 |
150 | return style, X_generated, negative_loglikelihood, D_loss_y, D_loss_z, G_loss, CE_labels
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Adversarial AutoEncoder (AAE) - TensorFlow
2 |
3 | This repository provides TensorFlow implementations of the supervised AAE and the semi-supervised AAE.
4 |
5 | ## Environment
6 | - OS: Ubuntu 16.04
7 |
8 | - GPU / RAM: GTX 1080 Ti / 16 GB
9 |
10 | - Python 3.5
11 |
12 | - Tensorflow-gpu version: 1.4.0rc2
13 |
14 | - OpenCV 3.4.1
15 |
16 | ## Schematic of AAE
17 |
18 | ### Supervised AAE
19 |
20 |
21 |
22 | ***
23 |
24 | ### SemiSupervised AAE
25 |
26 |
27 |
28 | ## Code
29 |
30 | **Supervised Encoder**
31 | ```python
32 | def sup_encoder(self, X, keep_prob): # encoder for supervised AAE
33 |
34 | with tf.variable_scope("sup_encoder", reuse = tf.AUTO_REUSE):
35 | net = drop_out(relu(dense(X, self.super_n_hidden, name = "dense_1")), keep_prob)
36 | net = drop_out(relu(dense(net, self.super_n_hidden, name="dense_2")), keep_prob)
37 | net = dense(net, self.n_z, name ="dense_3")
38 |
39 | return net
40 | ```
41 |
42 | **Supervised Decoder**
43 | ```python
44 | def sup_decoder(self, Z, keep_prob): # decoder for supervised AAE
45 |
46 | with tf.variable_scope("sup_decoder", reuse = tf.AUTO_REUSE):
47 | net = drop_out(relu(dense(Z, self.super_n_hidden, name = "dense_1")), keep_prob)
48 | net = drop_out(relu(dense(net, self.super_n_hidden, name="dense_2")), keep_prob)
49 | net = tf.nn.sigmoid(dense(net, self.length, name = "dense_3"))
50 |
51 | return net
52 | ```
53 |
54 | **Supervised Discriminator**
55 | ```python
56 | def discriminator(self,Z, keep_prob): # discriminator for supervised AAE
57 |
58 | with tf.variable_scope("discriminator", reuse = tf.AUTO_REUSE):
59 | net = drop_out(relu(dense(Z, self.super_n_hidden, name = "dense_1")), keep_prob)
60 | net = drop_out(relu(dense(net, self.super_n_hidden, name="dense_2")), keep_prob)
61 | logits = dense(net, 1, name ="dense_3")
62 |
63 | return logits
64 | ```
65 |
66 | **Supervised Adversarial AutoEncoder**
67 | ```python
68 | def Sup_Adversarial_AutoEncoder(self, X, X_noised, Y, z_prior, z_id, keep_prob):
69 |
70 | X_flatten = tf.reshape(X, [-1, self.length])
71 | X_flatten_noised = tf.reshape(X_noised, [-1, self.length])
72 |
73 | z_generated = self.sup_encoder(X_flatten_noised, keep_prob)
74 | X_generated = self.sup_decoder(z_generated, keep_prob)
75 |
76 | negative_log_likelihood = tf.reduce_mean(tf.squared_difference(X_generated, X_flatten))
77 |
78 | z_prior = tf.concat([z_prior, z_id], axis = 1)
79 | z_fake = tf.concat([z_generated, Y], axis = 1)
80 | D_real_logits = self.discriminator(z_prior, keep_prob)
81 | D_fake_logits = self.discriminator(z_fake, keep_prob)
82 |
83 | D_loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_fake_logits, labels = tf.zeros_like(D_fake_logits))
84 | D_loss_true = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_real_logits, labels = tf.ones_like(D_real_logits))
85 |
86 | G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_fake_logits, labels = tf.ones_like(D_fake_logits))
87 |
88 | D_loss = tf.reduce_mean(D_loss_fake) + tf.reduce_mean(D_loss_true)
89 | G_loss = tf.reduce_mean(G_loss)
90 |
91 | return z_generated, X_generated, negative_log_likelihood, D_loss, G_loss
92 | ```
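
A condensed sketch of how main.py trains the supervised model: every mini-batch runs three phases in turn (reconstruction, discriminator, generator), with the discriminator trained at one fifth of the base learning rate. The config values and placeholder shapes below are illustrative (MNIST, n_z = 2); main.py builds the same graph from the command-line flags.

```python
import tensorflow as tf
from collections import namedtuple
from AAE import AAE

# Illustrative stand-in for the FLAGS object used in main.py.
Conf = namedtuple("Conf", "model data super_n_hidden semi_n_hidden n_z batch_size prior")
conf = Conf("supervised", "MNIST", 1000, 1000, 2, 128, "gaussian_mixture")

X         = tf.placeholder(tf.float32, [None, 28, 28, 1])
X_noised  = tf.placeholder(tf.float32, [None, 28, 28, 1])
Y         = tf.placeholder(tf.float32, [None, 10])
z_prior   = tf.placeholder(tf.float32, [None, conf.n_z])
z_id      = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
lr        = tf.placeholder(tf.float32)

model = AAE(conf, [None, 28, 28, 1], 10)
z_gen, X_gen, recon_loss, D_loss, G_loss = model.Sup_Adversarial_AutoEncoder(
    X, X_noised, Y, z_prior, z_id, keep_prob)

total_vars = tf.trainable_variables()
var_AE = [v for v in total_vars if "encoder" in v.name or "decoder" in v.name]
var_G  = [v for v in total_vars if "encoder" in v.name]
var_D  = [v for v in total_vars if "discriminator" in v.name]

op_AE = tf.train.AdamOptimizer(lr).minimize(recon_loss, var_list=var_AE)  # 1) reconstruction phase
op_D  = tf.train.AdamOptimizer(lr / 5).minimize(D_loss, var_list=var_D)   # 2) discriminator phase
op_G  = tf.train.AdamOptimizer(lr).minimize(G_loss, var_list=var_G)       # 3) generator (encoder) phase

# Per mini-batch, main.py runs op_AE, then op_D, then op_G with the same feed_dict.
```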
93 |
94 | ***
95 |
96 | **SemiSupervised Encoder**
97 | ```python
98 | def semi_encoder(self, X, keep_prob, semi_supervised = False):
99 |
100 | with tf.variable_scope("semi_encoder", reuse = tf.AUTO_REUSE):
101 | net = drop_out(relu(dense(X, self.semi_n_hidden, name = "dense_1")), keep_prob)
102 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob)
103 | style = dense(net, self.n_z, name ="style")
104 |
105 | if semi_supervised is False:
106 | labels_generated = tf.nn.softmax(dense(net, self.n_labels, name = "labels"))
107 | else:
108 | labels_generated = dense(net, self.n_labels, name = "label_logits")
109 |
110 | return style, labels_generated
111 | ```
112 |
113 | **SemiSupervised Decoder**
114 | ```python
115 | def semi_decoder(self, Z, keep_prob):
116 |
117 | with tf.variable_scope("semi_decoder", reuse = tf.AUTO_REUSE):
118 | net = drop_out(relu(dense(Z, self.semi_n_hidden, name = "dense_1")), keep_prob)
119 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob)
120 | net = tf.nn.sigmoid(dense(net, self.length, name = "dense_3"))
121 |
122 | return net
123 | ```
124 |
125 | **SemiSupervised z Discriminator**
126 | ```python
127 | def semi_z_discriminator(self,Z, keep_prob):
128 |
129 | with tf.variable_scope("semi_z_discriminator", reuse = tf.AUTO_REUSE):
130 | net = drop_out(relu(dense(Z, self.semi_n_hidden, name="dense_1")), keep_prob)
131 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob)
132 | logits = dense(net, 1, name="dense_3")
133 |
134 | return logits
135 | ```
136 |
137 | **SemiSupervised y Discriminator**
138 | ```python
139 | def semi_y_discriminator(self, Y, keep_prob):
140 |
141 | with tf.variable_scope("semi_y_discriminator", reuse = tf.AUTO_REUSE):
142 | net = drop_out(relu(dense(Y, self.semi_n_hidden, name = "dense_1")), keep_prob)
143 | net = drop_out(relu(dense(net, self.semi_n_hidden, name="dense_2")), keep_prob)
144 | logits = dense(net, 1, name = "dense_3")
145 |
146 | return logits
147 | ```
148 |
149 | **SemiSupervised Adversarial AutoEncoder**
150 | ```python
151 | def Semi_Adversarial_AutoEncoder(self, X, X_noised, labels, labels_cat, z_prior, keep_prob):
152 |
153 | X_flatten = tf.reshape(X, [-1 , self.length])
154 | X_noised_flatten = tf.reshape(X_noised, [-1, self.length])
155 |
156 | style, labels_softmax = self.semi_encoder(X_noised_flatten, keep_prob, semi_supervised = False)
157 | latent_inputs = tf.concat([style, labels_softmax], axis = 1)
158 | X_generated = self.semi_decoder(latent_inputs, keep_prob)
159 |
160 | _, labels_generated = self.semi_encoder(X_noised_flatten, keep_prob, semi_supervised = True)
161 |
162 | D_Y_fake = self.semi_y_discriminator(labels_softmax, keep_prob)
163 | D_Y_real = self.semi_y_discriminator(labels_cat, keep_prob)
164 |
165 | D_Z_fake = self.semi_z_discriminator(style, keep_prob)
166 | D_Z_real = self.semi_z_discriminator(z_prior, keep_prob)
167 |
168 | negative_loglikelihood = tf.reduce_mean(tf.squared_difference(X_generated,X_flatten))
169 |
170 | D_loss_y_real = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Y_real, labels=tf.ones_like(D_Y_real))
171 | D_loss_y_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Y_fake, labels=tf.zeros_like(D_Y_fake))
172 | D_loss_y = tf.reduce_mean(D_loss_y_real) + tf.reduce_mean(D_loss_y_fake)
173 | D_loss_z_real = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_Z_real, labels = tf.ones_like(D_Z_real))
174 | D_loss_z_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_Z_fake, labels = tf.zeros_like(D_Z_fake))
175 | D_loss_z = tf.reduce_mean(D_loss_z_real) + tf.reduce_mean(D_loss_z_fake)
176 |
177 |
178 | G_loss_y = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Y_fake, labels=tf.ones_like(D_Y_fake))
179 | G_loss_z = tf.nn.sigmoid_cross_entropy_with_logits(logits = D_Z_fake, labels = tf.ones_like(D_Z_fake))
180 | G_loss = tf.reduce_mean(G_loss_y) + tf.reduce_mean(G_loss_z)
181 |
182 | CE_labels = tf.nn.softmax_cross_entropy_with_logits(logits = labels_generated, labels = labels)
183 | CE_labels = tf.reduce_mean(CE_labels)
184 |
185 |
186 | return style, X_generated, negative_loglikelihood, D_loss_y, D_loss_z, G_loss, CE_labels
187 | ```
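
main.py trains the semi-supervised model with five alternating phases per mini-batch: reconstruction, the style (z) discriminator, the label (y) discriminator, the generator, and a supervised cross-entropy step on a small labeled subset (num_samples images taken from the validation split). A schematic sketch of the optimizer wiring; it assumes the losses returned by Semi_Adversarial_AutoEncoder and the lr_ placeholder have already been built as in main.py.

```python
import tensorflow as tf

# Assumes the semi-supervised graph from main.py already exists, i.e.
# negative_log_likelihood, D_loss_y, D_loss_z, G_loss, CE_labels and lr_.
total_vars = tf.trainable_variables()
var_AE  = [v for v in total_vars if "encoder" in v.name or "decoder" in v.name]
var_G   = [v for v in total_vars if "encoder" in v.name]
var_z_D = [v for v in total_vars if "z_discriminator" in v.name]
var_y_D = [v for v in total_vars if "y_discriminator" in v.name]

op_AE  = tf.train.AdamOptimizer(lr_).minimize(negative_log_likelihood, var_list=var_AE)  # 1) reconstruction
op_z_D = tf.train.AdamOptimizer(lr_ / 5).minimize(D_loss_z, var_list=var_z_D)            # 2) style discriminator
op_y_D = tf.train.AdamOptimizer(lr_ / 5).minimize(D_loss_y, var_list=var_y_D)            # 3) label discriminator
op_G   = tf.train.AdamOptimizer(lr_).minimize(G_loss, var_list=var_G)                    # 4) encoder fools both
op_CE  = tf.train.AdamOptimizer(lr_).minimize(CE_labels, var_list=var_G)                 # 5) CE on labeled subset
```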
188 |
189 | ## Results
190 |
191 | **1. Restoring**
192 | ```
193 | python main.py --model supervised --prior gaussian --n_z 20
194 |
195 | or
196 |
197 | python main.py --model semi_supervised --prior gaussian --n_z 20
198 | ```
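
With the default --noised True, the encoder sees corrupted inputs while the reconstruction loss is computed against the clean images, so the restorations below are denoising reconstructions. A minimal sketch of the corruption applied by data_utils.make_noise:

```python
import numpy as np

def gaussian_noise(image, std=0.3):
    # data_utils.make_noise adds zero-mean Gaussian noise with standard deviation 0.3
    return image + np.random.normal(0.0, std, size=np.shape(image))
```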
199 |
200 |
201 | *(Image panels: Original Images · Restored via Supervised AAE · Restored via Semi-supervised AAE)*
202 |
212 | **2. 2D Latent Space**
213 |
214 | ***Target***
215 |
216 |
217 |
218 | *(Image panels: Gaussian · Gaussian Mixture · Swiss Roll)*
219 |
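
The target scatter plots can be reproduced directly from prior.py; a minimal sketch (the output filenames are illustrative):

```python
import numpy as np
import matplotlib.pyplot as plt
from prior import gaussian, gaussian_mixture, swiss_roll

batch_size, n_labels, n_dim = 10000, 10, 2
labels = np.random.randint(0, n_labels, size=batch_size)

z_g, id_g = gaussian(batch_size, n_labels, n_dim, use_label_info=True)
priors = {
    "gaussian": (z_g, id_g),
    "gaussian_mixture": (gaussian_mixture(batch_size, n_labels, n_dim, label_indices=labels), labels),
    "swiss_roll": (swiss_roll(batch_size, n_labels, n_dim, label_indices=labels), labels),
}

for name, (z, ids) in priors.items():
    plt.figure(figsize=(6, 6))
    plt.scatter(z[:, 0], z[:, 1], c=ids, marker=".", cmap="jet")
    plt.title(name)
    plt.savefig(name + "_prior.png")  # illustrative filename
    plt.close()
```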
229 | ***Coding Space of Supervised AAE***
230 | The test uses the 10,000 test images, which are not seen during training.
231 |
232 | ```
233 | python main.py --model supervised --prior gaussian_mixture --n_z 2
234 | ```
235 |
236 |
237 |
238 | *(Image panels: coding space with the Gaussian · Gaussian Mixture · Swiss Roll prior)*
239 |
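
The coding-space plots are made after training by encoding the test set in batches and scattering the 2-D codes. A condensed sketch of the loop in main.py; sess, the placeholders X / X_noised / keep_prob, the z_generated tensor and the data_pipeline object are assumed to come from the supervised graph in main.py.

```python
import numpy as np
from plot import plot_2d_scatter

# Encode the 10,000 test images in batches of 128 and scatter the 2-D codes.
latent_holder = []
data_pipeline.initialize_batch()
for _ in range(data_pipeline.get_total_batch(test_xs, 128)):
    xs, noised_xs, ys = data_pipeline.next_batch(test_xs, test_ys, 128, make_noise=False)
    latent_holder.append(sess.run(z_generated, feed_dict={X: xs, X_noised: noised_xs, keep_prob: 1.0}))
latent_holder = np.concatenate(latent_holder, axis=0)
plot_2d_scatter(latent_holder[:, 0], latent_holder[:, 1], test_ys[:len(latent_holder)])
```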
249 |
250 | **3. Manifold Learning Result**
251 |
252 | ***Supervised AAE***
253 |
254 | ```
255 | python main.py --model supervised --prior gaussian_mixture --n_z 2 --PMLR True
256 | ```
257 |
258 |
259 |
260 | *(Image panel: Manifold)*
261 |
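
With --PMLR True the manifold canvas is generated by decoding a 10 x 10 grid of 2-D latent codes and tiling the outputs. A sketch of the sweep; sess, images_PMLR, latent and keep_prob are assumed to come from the trained supervised graph in main.py.

```python
import numpy as np
from plot import plot_manifold_canvas

# 10 x 10 grid over the 2-D latent space, matching the PMLR block in main.py.
grid = [[xi, yi] for yi in np.linspace(0.5, -0.5, 10) for xi in np.linspace(-0.5, 0.5, 10)]
canvas = sess.run(images_PMLR, feed_dict={latent: grid, keep_prob: 1.0})
canvas = np.reshape(canvas, [-1, 28, 28, 1])
plot_manifold_canvas(canvas, 10, type="MNIST", name="PMLR/PMLR")
```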
267 | ***SemiSupervised AAE***
268 |
269 | ```
270 | python main.py --model semi_supervised --prior gaussian --n_z 2 --PMLR True
271 | ```
272 |
273 | The results suggest that with n_z = 2, the semi-supervised AAE cannot extract label information from the input images very well.
274 |
275 |
276 |
277 |
278 | *(Image panels: Manifold with a condition 0 · condition 1 · condition 2)*
279 |
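
For the semi-supervised model, each canvas fixes one class by concatenating a one-hot label with every point of the 2-D style grid before decoding. A sketch for a single label k; sess, images_manifold, latent and keep_prob are assumed to come from the semi-supervised graph in main.py (n_cls = 10).

```python
import numpy as np
from plot import plot_manifold_canvas

k = 0  # class to condition on
grid = np.array([[xi, yi] for xi in np.linspace(-0.5, 0.5, 10) for yi in np.linspace(-0.5, 0.5, 10)])
one_hot = np.zeros((len(grid), 10))
one_hot[:, k] = 1
z_concat = np.concatenate([grid, one_hot], axis=1)  # 2-D style code + 10-D one-hot label
canvas = sess.run(images_manifold, feed_dict={latent: z_concat, keep_prob: 1.0})
canvas = np.reshape(canvas, [-1, 28, 28, 1])
plot_manifold_canvas(canvas, 10, type="MNIST", name="PMLR/labels" + str(k))
```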
289 | ## Reference
290 |
291 | ### Paper
292 | AAE: https://arxiv.org/abs/1511.05644
293 |
294 | GAN: https://arxiv.org/abs/1406.2661
295 |
296 | ### GitHub
297 | https://github.com/hwalsuklee/tensorflow-mnist-AAE
298 |
299 | https://github.com/MINGUKKANG/CVAE
300 |
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import random
4 | import gzip
5 | import tarfile
6 | import pickle
7 | import os
8 | from six.moves import urllib
9 | from plot import *
10 |
11 | class data_pipeline:
12 | def __init__(self,type):
13 | self.type = type
14 | self.debug = 0
15 | self.batch = 0
16 |
17 | if self.type == "MNIST":
18 | self.url = "http://yann.lecun.com/exdb/mnist/"
19 | self.debug =1
20 | self.n_train_images = 60000
21 | self.n_test_images = 10000
22 | self.n_channels = 1
23 | self.size = 28
24 | self.MNIST_filename = ["train-images-idx3-ubyte.gz",
25 | "train-labels-idx1-ubyte.gz",
26 | "t10k-images-idx3-ubyte.gz",
27 | "t10k-labels-idx1-ubyte.gz"]
28 |
29 | elif self.type == "CIFAR_10":
30 | self.url = "https://www.cs.toronto.edu/~kriz/"
31 | self.debug = 1
32 | self.n_train_images = 50000
33 | self.n_test_images = 10000
34 | self.n_channels = 3
35 | self.size = 32
36 | self.CIFAR_10_filename = ["cifar-10-python.tar.gz"]
37 |
38 | assert self.debug == 1, "Data type must be MNIST or CIFAR_10"
39 |
40 | def maybe_download(self, filename, filepath):
41 | if os.path.isfile(filepath) is True:
42 | print("Filename %s is already downloaded" % filename)
43 | else:
44 | filepath,_ = urllib.request.urlretrieve(self.url + filename, filepath)
45 | with tf.gfile.GFile(filepath) as f:
46 | size = f.size()
47 | print("Successfully download", filename, size, "bytes")
48 | return filepath
49 |
50 |
51 | def download_data(self):
52 | self.filepath_holder = []
53 |
54 | if not tf.gfile.Exists("./Data"):
55 | tf.gfile.MakeDirs("./Data")
56 |
57 | if self.type == "MNIST":
58 | for i in self.MNIST_filename:
59 | filepath = os.path.join("./Data", i)
60 | self.maybe_download(i,filepath)
61 | self.filepath_holder.append(filepath)
62 |
63 | elif self.type == "CIFAR_10":
64 | for i in self.CIFAR_10_filename:
65 | filepath = os.path.join("./Data", i)
66 | self.maybe_download(i,filepath)
67 | self.filepath_holder.append(filepath)
68 | print("-" * 80)
69 |
70 | def extract_mnist_images(self, filepath, size, n_images,n_channels):
71 | print("Extracting and Reading ", filepath)
72 |
73 | with gzip.open(filepath) as bytestream:
74 | bytestream.read(16)
75 | buf = bytestream.read(size*size*n_images*n_channels)
76 | data = np.frombuffer(buf, dtype = np.uint8)
77 | data = np.reshape(data,[n_images, size, size, n_channels])
78 | return data
79 |
80 | def extract_mnist_labels(self, filepath,n_images):
81 | print("Extracting and Reading ", filepath)
82 |
83 | with gzip.open(filepath) as bytestream:
84 | bytestream.read(8)
85 | buf = bytestream.read(1*n_images)
86 | labels = np.frombuffer(buf, dtype = np.uint8)
87 | one_hot_encoding = np.zeros((n_images, 10))
88 | one_hot_encoding[np.arange(n_images), labels] = 1
89 | one_hot_encoding = np.reshape(one_hot_encoding, [-1,10])
90 | return one_hot_encoding
91 |
92 | def extract_cifar_data(self,filepath, train_files,n_images):
93 | ## this code is from https://github.com/melodyguan/enas/blob/master/src/cifar10/data_utils.py
94 | images, labels = [], []
95 | for file_name in train_files:
96 | full_name = os.path.join(filepath, file_name)
97 | with open(full_name, mode = "rb") as finp:
98 | data = pickle.load(finp, encoding = "bytes")
99 | batch_images = data[b'data']
100 | batch_labels = np.array(data[b'labels'])
101 | images.append(batch_images)
102 | labels.append(batch_labels)
103 | images = np.concatenate(images, axis=0)
104 | labels = np.concatenate(labels, axis=0)
105 | one_hot_encoding = np.zeros((n_images, 10))
106 | one_hot_encoding[np.arange(n_images), labels] = 1
107 | one_hot_encoding = np.reshape(one_hot_encoding, [-1, 10])
108 | images = np.reshape(images, [-1, 3, 32, 32])
109 | images = np.transpose(images, [0, 2, 3, 1])
110 |
111 | return images, one_hot_encoding
112 |
113 | def extract_cifar_data_(self,filepath, num_valids=5000):
114 | print("Reading data")
115 | with tarfile.open(filepath, "r:gz") as tar:
116 | tar.extractall("./Data")
117 | images, labels = {}, {}
118 | train_files = [
119 | "./cifar-10-batches-py/data_batch_1",
120 | "./cifar-10-batches-py/data_batch_2",
121 | "./cifar-10-batches-py/data_batch_3",
122 | "./cifar-10-batches-py/data_batch_4",
123 | "./cifar-10-batches-py/data_batch_5"]
124 | test_file = ["./cifar-10-batches-py/test_batch"]
125 | images["train"], labels["train"] = self.extract_cifar_data("./Data", train_files,self.n_train_images)
126 |
127 | if num_valids:
128 | images["valid"] = images["train"][-num_valids:]
129 | labels["valid"] = labels["train"][-num_valids:]
130 |
131 | images["train"] = images["train"][:-num_valids]
132 | labels["train"] = labels["train"][:-num_valids]
133 | else:
134 | images["valid"], labels["valid"] = None, None
135 |
136 | images["test"], labels["test"] = self.extract_cifar_data("./Data", test_file,self.n_test_images)
137 | return images, labels
138 |
139 | def apply_preprocessing(self, images, mode):
140 | mean = np.mean(images, axis =(0,1,2))
141 | images = images/255
142 | print("%s_mean: " % mode, mean)
143 | return images
144 |
145 | def load_preprocess_data(self):
146 | self.download_data()
147 | if self.type == "MNIST":
148 | train_images = self.extract_mnist_images(self.filepath_holder[0],self.size, self.n_train_images, self.n_channels)
149 | train_labels = self.extract_mnist_labels(self.filepath_holder[1], self.n_train_images)
150 | self.valid_images = train_images[0:5000,:,:,:]
151 | self.valid_labels = train_labels[0:5000,:]
152 | self.train_images = train_images[5000:,:,:,:]
153 | self.train_labels = train_labels[5000:,:]
154 | self.test_images = self.extract_mnist_images(self.filepath_holder[2],self.size, self.n_test_images, self.n_channels)
155 | self.test_labels = self.extract_mnist_labels(self.filepath_holder[3], self.n_test_images)
156 | print("-" * 80)
157 | self.train_images = self.apply_preprocessing(images = self.train_images, mode = "train")
158 | self.valid_images = self.apply_preprocessing(images = self.valid_images, mode = "valid")
159 | self.test_images = self.apply_preprocessing(images = self.test_images, mode = "test")
160 | print("-" * 80)
161 | print("training size: ", np.shape(self.train_images),", ",np.shape(self.train_labels))
162 | print("valid size: ", np.shape(self.valid_images), ", ", np.shape(self.valid_labels))
163 | print("test size: ", np.shape(self.test_images), ", ", np.shape(self.test_labels))
164 | else:
165 | images, labels = self.extract_cifar_data_(self.filepath_holder[0])
166 | self.train_images = images["train"]
167 | self.train_labels = labels["train"]
168 | self.valid_images = images["valid"]
169 | self.valid_labels = labels["valid"]
170 | self.test_images = images["test"]
171 | self.test_labels = labels["test"]
172 | print("-" * 80)
173 | self.train_images = self.apply_preprocessing(images = self.train_images, mode = "train")
174 | self.valid_images = self.apply_preprocessing(images = self.valid_images, mode = "valid")
175 | self.test_images = self.apply_preprocessing(images = self.test_images, mode = "test")
176 | print("-" * 80)
177 | print("training size: ", np.shape(self.train_images),", ",np.shape(self.train_labels))
178 | print("valid size: ", np.shape(self.valid_images), ", ", np.shape(self.valid_labels))
179 | print("test size: ", np.shape(self.test_images), ", ", np.shape(self.test_labels))
180 |
181 | return self.train_images, self.train_labels, self.valid_images, self.valid_labels, self.test_images, self.test_labels
182 |
183 | def make_noise(self,image):
184 |
185 | def gaussian_noise(image):
186 | size = np.shape(image)
187 | noise = np.random.normal(0,0.3, size = size)
188 | image = image + noise
189 |
190 | return image
191 |
192 | return gaussian_noise(image)
193 |
194 | def initialize_batch(self):
195 | self.batch = 0
196 |
197 | def next_batch(self, images, labels, batch_size, make_noise = None):
198 |
199 | if make_noise is False:
200 | self.length = len(images)//batch_size
201 | batch_xs = images[self.batch*batch_size: self.batch*batch_size + batch_size,:,:,:]
202 | batch_noised_xs = np.copy(batch_xs)
203 | batch_ys = labels[self.batch*batch_size: self.batch*batch_size + batch_size,:]
204 | self.batch += 1
205 | if self.batch == (self.length):
206 | self.batch = 0
207 | else:
208 | self.length = len(images)//batch_size
209 | batch_noised_xs = []
210 | batch_xs = images[self.batch*batch_size: self.batch*batch_size + batch_size,:,:,:]
211 | batch_ys = labels[self.batch * batch_size: self.batch * batch_size + batch_size, :]
212 |
213 | if self.type == "MNIST":
214 | _ = np.reshape(batch_xs, [-1, self.size, self.size])
215 | for i in range(batch_size):
216 | batch_noised_xs.append(self.make_noise(_[i]))
217 | batch_noised_xs = np.reshape(batch_noised_xs, [-1, self.size, self.size, self.n_channels])
218 | else:
219 | for i in range(batch_size):
220 | batch_noised_xs.append(self.make_noise(batch_xs[i]))
221 |
222 | self.batch += 1
223 | if self.batch == (self.length):
224 | self.batch = 0
225 |
226 | return batch_xs, batch_noised_xs, batch_ys
227 |
228 | def get_total_batch(self,images, batch_size):
229 | self.batch_size = batch_size
230 | return len(images)//self.batch_size
231 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import time
4 | from utils import *
5 | from plot import *
6 | from AAE import *
7 | from data_utils import *
8 |
9 | DEFINE_string("model", "semi_supervised", "[supervised | semi_supervised]")
10 | DEFINE_string("data", "MNIST", "[MNIST | CIFAR_10]")
11 | DEFINE_string("prior", "gaussian", "[gaussain | gaussain_mixture | swiss_roll]")
12 |
13 |
14 | DEFINE_integer("super_n_hidden", 3000, "the number of elements for hidden layers")
15 | DEFINE_integer("semi_n_hidden", 3000, "teh number of elements for hidden layers")
16 | DEFINE_integer("n_epoch", 100, "number of Epoch for training")
17 | DEFINE_integer("n_z", 20, "Dimension of Latent variables")
18 | DEFINE_integer("num_samples",5000, "number of samples for semi supervised learning")
19 | DEFINE_integer("batch_size", 128, "Batch Size for training")
20 |
21 | DEFINE_float("keep_prob", 0.9, "dropout rate")
22 | DEFINE_float("lr_start", 0.001, "initial learning rate")
23 | DEFINE_float("lr_mid", 0.0001, "mid learning rate")
24 | DEFINE_float("lr_end", 0.0001, "final learning rate")
25 |
26 | DEFINE_boolean("noised", True, "")
27 | DEFINE_boolean("PMLR", True, "Boolean for plot manifold learning result")
28 | DEFINE_boolean("PARR", False, "Boolean for plot analogical reasoning result")
29 |
30 | conf = print_user_flags(line_limit = 100)
31 | print("-"*80)
32 |
33 | if conf.model == "supervised":
34 |
35 | data_pipeline = data_pipeline(conf.data)
36 |
37 | train_xs, train_ys, valid_xs, valid_ys, test_xs, test_ys = data_pipeline.load_preprocess_data()
38 |
39 | _, height, width, channel = np.shape(train_xs)
40 | n_cls = np.shape(train_ys)[1]
41 |
42 | X = tf.placeholder(dtype=tf.float32, shape=[None, height, width, channel], name="Inputs")
43 | X_noised = tf.placeholder(dtype=tf.float32, shape=[None, height, width, channel], name="Inputs_noised")
44 | Y = tf.placeholder(dtype=tf.float32, shape=[None, n_cls], name="Input_labels")
45 | z_prior = tf.placeholder(tf.float32, shape=[None, conf.n_z], name="z_prior")
46 | z_id = tf.placeholder(tf.float32, shape = [None, n_cls], name = "prior_labels")
47 | latent = tf.placeholder(tf.float32, shape = [None, conf.n_z], name = "latent_for_generation")
48 | keep_prob = tf.placeholder(dtype = tf.float32, name = "dropout_rate")
49 | lr_ = tf.placeholder(dtype = tf.float32, name = "learning_rate")
50 | global_step = tf.Variable(0, trainable=False)
51 |
52 | AAE = AAE(conf, [_, height, width, channel], n_cls)
53 | z_generated, X_generated, negative_log_likelihood, D_loss, G_loss = AAE.Sup_Adversarial_AutoEncoder(X,
54 | X_noised,
55 | Y,
56 | z_prior,
57 | z_id,
58 | keep_prob)
59 | images_PMLR = AAE.sup_decoder(latent, keep_prob)
60 | total_batch = data_pipeline.get_total_batch(train_xs, conf.batch_size)
61 |
62 | total_vars = tf.trainable_variables()
63 | var_AE = [var for var in total_vars if "encoder" in var.name or "decoder" in var.name]
64 | var_generator = [var for var in total_vars if "encoder" in var.name]
65 | var_discriminator = [var for var in total_vars if "discriminator" in var.name]
66 |
67 | op_AE = tf.train.AdamOptimizer(learning_rate = lr_).minimize(negative_log_likelihood,
68 | global_step = global_step,
69 | var_list = var_AE)
70 |
71 | op_D = tf.train.AdamOptimizer(learning_rate = lr_/5). minimize(D_loss,
72 | global_step = global_step,
73 | var_list = var_discriminator)
74 | op_G = tf.train.AdamOptimizer(learning_rate = lr_).minimize(G_loss,
75 | global_step = global_step,
76 | var_list = var_generator)
77 |
78 | batch_t_xs, batch_tn_xs, batch_t_ys = data_pipeline.next_batch(valid_xs, valid_ys, 100, make_noise= False)
79 | data_pipeline.initialize_batch()
80 |
81 | sess = tf.Session()
82 | sess.run(tf.global_variables_initializer())
83 |
84 | start_time = time.time()
85 | for i in range(conf.n_epoch):
86 | likelihood = 0
87 | D_value = 0
88 | G_value = 0
89 | for j in range(total_batch):
90 | batch_xs, batch_noised_xs, batch_ys = data_pipeline.next_batch(train_xs,
91 | train_ys,
92 | conf.batch_size,
93 | make_noise=conf.noised)
94 | if conf.prior == "gaussian":
95 | z_prior_, z_id_ = gaussian(conf.batch_size,
96 | n_labels = n_cls,
97 | n_dim = conf.n_z,
98 | use_label_info = True)
99 | z_id_onehot = np.eye(n_cls)[z_id_].astype(np.float32)
100 |
101 | elif conf.prior == "gaussian_mixture":
102 | z_id_ = np.random.randint(0, n_cls, size=[conf.batch_size])
103 | z_id_onehot = np.eye(n_cls)[z_id_].astype(np.float32)
104 | z_prior_ = gaussian_mixture(conf.batch_size,
105 | n_labels = n_cls,
106 | n_dim = conf.n_z,
107 | label_indices = z_id_)
108 |
109 | elif conf.prior == "swiss_roll":
110 | z_id_ = np.random.randint(0, n_cls, size=[conf.batch_size])
111 | z_id_onehot = np.eye(n_cls)[z_id_].astype(np.float32)
112 | z_prior_ = swiss_roll(conf.batch_size,
113 | n_labels = n_cls,
114 | n_dim = conf.n_z,
115 | label_indices = z_id_)
116 | else:
117 | print("FLAGS.prior should be [gaussian, gaussian_mixture, swiss_roll]")
118 |
119 | if i <= 50:
120 | lr_value = conf.lr_start
121 | elif i <=100:
122 | lr_value = conf.lr_mid
123 | else:
124 | lr_value = conf.lr_end
125 |
126 | feed_dict = {X: batch_xs,
127 | X_noised: batch_noised_xs,
128 | Y: batch_ys,
129 | z_prior: z_prior_,
130 | z_id: z_id_onehot,
131 | lr_: lr_value,
132 | keep_prob: conf.keep_prob}
133 |
134 | # AutoEncoder phase
135 | l, _, g = sess.run([negative_log_likelihood, op_AE, global_step], feed_dict=feed_dict)
136 |
137 | # Discriminator phase
138 | l_D, _ = sess.run([D_loss, op_D], feed_dict = feed_dict)
139 |
140 | l_G, _ = sess.run([G_loss, op_G], feed_dict = feed_dict)
141 |
142 | likelihood += l/total_batch
143 | D_value += l_D/total_batch
144 | G_value += l_G/total_batch
145 |
146 | if i % 5 == 0 or i == (conf.n_epoch -1):
147 | images = sess.run(X_generated, feed_dict = {X:batch_t_xs,
148 | X_noised: batch_tn_xs,
149 | keep_prob: 1.0})
150 | images = np.reshape(images, [-1, height, width, channel])
151 | name = "Manifold_canvas_" + str(i)
152 | plot_manifold_canvas(images, 10, type = "MNIST", name = name)
153 |
154 |
155 | hour = int((time.time() - start_time) / 3600)
156 | min = int(((time.time() - start_time) - 3600 * hour) / 60)
157 | sec = int((time.time() - start_time) - 3600 * hour - 60 * min)
158 | print("Epoch: %3d lr_AE: %.5f loss_AE: %.4f Time: %d hour %d min %d sec" % (i, lr_value, likelihood, hour, min, sec))
159 | print(" lr_D: %.5f loss_D: %.4f" % (lr_value/5, D_value))
160 | print(" lr_G: %.5f loss_G: %.4f\n" % (lr_value, G_value))
161 |
162 | ## code for 2D scatter plot
163 | if conf.n_z == 2:
164 | print("-" * 80)
165 | print("plot 2D Scatter Result")
166 | test_total_batch = data_pipeline.get_total_batch(test_xs, 128)
167 | data_pipeline.initialize_batch()
168 | latent_holder = []
169 | for i in range(test_total_batch):
170 | batch_test_xs, batch_test_noised_xs, batch_test_ys = data_pipeline.next_batch(test_xs,
171 | test_ys,
172 | conf.batch_size,
173 | make_noise=False)
174 | feed_dict = {X: batch_test_xs,
175 | X_noised: batch_test_noised_xs,
176 | keep_prob: 1.0}
177 |
178 | latent_vars = sess.run(z_generated, feed_dict=feed_dict)
179 | latent_holder.append(latent_vars)
180 | latent_holder = np.concatenate(latent_holder, axis=0)
181 | plot_2d_scatter(latent_holder[:, 0], latent_holder[:, 1], test_ys[:len(latent_holder)])
182 |
183 | if conf.PMLR is True:
184 | print("-" * 80)
185 | assert conf.n_z == 2, "Error: n_z should be 2"
186 | print("plot Manifold Learning Result")
187 | x_axis = np.linspace(-0.5, 0.5, 10)
188 | y_axis = np.linspace(0.5, -0.5, 10)
189 | z_holder = []
190 | for i, yi in enumerate(y_axis):
191 | for j, xi in enumerate(x_axis):
192 | z_holder.append([xi, yi])
193 | length = len(z_holder)
194 | MLR = sess.run(images_PMLR, feed_dict={latent: z_holder, keep_prob: 1.0})
195 | MLR = np.reshape(MLR, [-1, height, width, channel])
196 | p_name = "PMLR/PMLR"
197 | plot_manifold_canvas(MLR, 10, "MNIST", p_name)
198 |
199 | elif conf.model == "semi_supervised":
200 |
201 | Data = data_pipeline(conf.data)
202 | Data_semi = data_pipeline(conf.data)
203 | train_xs, train_ys, valid_xs, valid_ys, test_xs, test_ys = Data.load_preprocess_data()
204 | valid_xs, valid_ys = valid_xs[:conf.num_samples], valid_ys[:conf.num_samples]
205 |
206 | _, height, width, channel = np.shape(train_xs)
207 | n_cls = np.shape(train_ys)[1]
208 |
209 | X = tf.placeholder(dtype=tf.float32, shape=[None, height, width, channel], name="Input")
210 | X_noised = tf.placeholder(dtype=tf.float32, shape=[None, height, width, channel], name="Input_noised")
211 | Y = tf.placeholder(dtype=tf.float32, shape=[None, n_cls], name="Input_labels")
212 | Y_cat = tf.placeholder(dtype=tf.float32, shape=[None, n_cls], name="labels_cat")
213 | z_prior_ = tf.placeholder(dtype = tf.float32, shape = [None,conf.n_z], name = "z_prior" )
214 | latent = tf.placeholder(dtype = tf.float32, shape = [None, conf.n_z + n_cls], name = "latent_for_generation")
215 | keep_prob = tf.placeholder(dtype=tf.float32, name="dropout_rate")
216 | lr_ = tf.placeholder(dtype=tf.float32, name="learning_rate")
217 | global_step = tf.Variable(0, trainable=False)
218 |
219 | AAE = AAE(conf, [_, height, width, channel], n_cls)
220 |
221 | style, X_generated, negative_log_likelihood, D_loss_y, D_loss_z, G_loss, CE_labels =AAE.Semi_Adversarial_AutoEncoder(X,
222 | X_noised,
223 | Y,
224 | Y_cat,
225 | z_prior_,
226 | keep_prob)
227 | images_PARR = AAE.semi_decoder(latent, keep_prob)
228 | images_manifold = AAE.semi_decoder(latent, keep_prob)
229 |
230 | total_batch = Data.get_total_batch(train_xs, conf.batch_size)
231 |
232 | total_vars = tf.trainable_variables()
233 | var_AE = [var for var in total_vars if "encoder" in var.name or "decoder" in var.name]
234 | var_z_discriminator = [var for var in total_vars if "z_discriminator" in var.name]
235 | var_y_discriminator = [var for var in total_vars if "y_discriminator" in var.name]
236 | var_generator = [var for var in total_vars if "encoder" in var.name]
237 |
238 | op_AE = tf.train.AdamOptimizer(learning_rate = lr_).minimize(negative_log_likelihood, global_step = global_step, var_list = var_AE)
239 | op_y_D = tf.train.AdamOptimizer(learning_rate = lr_/5).minimize(D_loss_y, global_step = global_step, var_list = var_y_discriminator)
240 | op_z_D = tf.train.AdamOptimizer(learning_rate = lr_/5).minimize(D_loss_z, global_step = global_step, var_list = var_z_discriminator)
241 | op_G = tf.train.AdamOptimizer(learning_rate = lr_).minimize(G_loss, global_step = global_step, var_list = var_generator)
242 | op_CE_labels = tf.train.AdamOptimizer(learning_rate = lr_).minimize(CE_labels, global_step = global_step, var_list = var_generator)
243 |
244 | batch_t_xs, batch_tn_xs, batch_t_ys = Data.next_batch(valid_xs, valid_ys, 100, make_noise = False)
245 | Data.initialize_batch()
246 |
247 | sess = tf.Session()
248 | sess.run(tf.global_variables_initializer())
249 |
250 | start_time = time.time()
251 | for i in range(conf.n_epoch):
252 | likelihood = 0
253 | D_z_value = 0
254 | D_y_value = 0
255 | G_value = 0
256 | CE_value = 0
257 |
258 | if i <= 50:
259 | lr_value = conf.lr_start
260 | elif i <= 100:
261 | lr_value = conf.lr_mid
262 | else:
263 | lr_value = conf.lr_end
264 |
265 | for j in range(total_batch):
266 | batch_xs, batch_noised_xs, batch_ys = Data.next_batch(train_xs,
267 | train_ys,
268 | conf.batch_size,
269 | make_noise = conf.noised)
270 |
271 | real_cat_labels = np.random.randint(low = 0, high = n_cls, size = conf.batch_size)
272 | real_cat_labels = np.eye(n_cls)[real_cat_labels]
273 |
274 | if conf.prior == "gaussian":
275 | z_prior = gaussian(conf.batch_size,
276 | n_labels=n_cls,
277 | n_dim=conf.n_z,
278 | use_label_info=False)
279 |
280 | elif conf.prior == "gaussian_mixture":
281 | z_prior = gaussian_mixture(conf.batch_size,
282 | n_labels=n_cls,
283 | n_dim=conf.n_z)
284 |
285 | elif conf.prior == "swiss_roll":
286 | z_prior = swiss_roll(conf.batch_size,
287 | n_labels=n_cls,
288 | n_dim=conf.n_z)
289 | else:
290 | print("FLAGS.prior should be [gaussian, gaussian_mixture, swiss_roll]")
291 |
292 | feed_dict = {X: batch_xs,
293 | X_noised: batch_noised_xs,
294 | Y: batch_ys,
295 | Y_cat: real_cat_labels,
296 | z_prior_: z_prior,
297 | lr_: lr_value,
298 | keep_prob: conf.keep_prob}
299 |
300 | # AutoEncoder phase
301 | l, _, g = sess.run([negative_log_likelihood, op_AE, global_step], feed_dict = feed_dict)
302 |
303 | # z_Discriminator phase
304 | l_z_D,_ = sess.run([D_loss_z, op_z_D], feed_dict = feed_dict)
305 |
306 | # y_Discriminator phase
307 | l_y_D, _ = sess.run([D_loss_y, op_y_D], feed_dict=feed_dict)
308 |
309 | # Generator phase
310 | l_G, _ = sess.run([G_loss, op_G], feed_dict = feed_dict)
311 |
312 | batch_semi_xs, batch_noised_semi_xs,batch_semi_ys = Data_semi.next_batch(valid_xs,
313 | valid_ys,
314 | conf.batch_size,
315 | make_noise = False)
316 |
317 | feed_dict = {X: batch_semi_xs,
318 | X_noised: batch_noised_semi_xs,
319 | Y: batch_semi_ys,
320 | Y_cat: real_cat_labels,
321 | lr_:lr_value,
322 | keep_prob: conf.keep_prob}
323 |
324 | # Cross_Entropy phase
325 | CE, _ = sess.run([CE_labels, op_CE_labels], feed_dict = feed_dict)
326 |
327 | likelihood += l/total_batch
328 | D_z_value += l_z_D/total_batch
329 | D_y_value += l_y_D/total_batch
330 | G_value += l_G/total_batch
331 | CE_value += CE/total_batch
332 |
333 | if i % 5 == 0 or i == (conf.n_epoch -1):
334 | images = sess.run(X_generated, feed_dict = {X:batch_t_xs,
335 | X_noised: batch_tn_xs,
336 | keep_prob: 1.0})
337 | images = np.reshape(images, [-1, height, width, channel])
338 | name = "Manifold_semi_canvas_" + str(i)
339 | plot_manifold_canvas(images, 10, type = "MNIST", name = name)
340 |
341 |
342 | hour = int((time.time() - start_time) / 3600)
343 | min = int(((time.time() - start_time) - 3600 * hour) / 60)
344 | sec = int((time.time() - start_time) - 3600 * hour - 60 * min)
345 | print("Epoch: %3d lr_AE_G_CE: %.5f lr_D: %.5f Time: %d hour %d min %d sec" % (i, lr_value,lr_value/5, hour, min, sec))
346 | print("loss_AE: %.5f" % (likelihood))
347 | print("loss_z_D: %.4f loss_y_D: %f" % (D_z_value, D_y_value))
348 | print("loss_G: %.4f CE_semi: %.4f\n" % (G_value, CE_value))
349 |
350 | if conf.PARR is True:
351 | print("-"*80)
352 | print("plot analogical reasoning result")
353 | z_holder = []
354 | for i in range(n_cls):
355 | z_ = np.random.rand(10, conf.n_z)
356 | z_holder.append(z_)
357 | z_holder = np.concatenate(z_holder, axis = 0)
358 | y = [j for j in range(n_cls)]
359 | y = y*10
360 | length = len(z_holder)
361 | y_one_hot = np.zeros((length, n_cls))
362 | y_one_hot[np.arange(length), y] = 1
363 | y_one_hot = np.reshape(y_one_hot, [-1, n_cls])
364 | z_concated = np.concatenate([z_holder, y_one_hot], axis=1)
365 | PARR = sess.run(images_PARR, feed_dict = {latent: z_concated, keep_prob: 1.0})
366 | PARR = np.reshape(PARR, [-1, height, width, channel])
367 | p_name = "PARR/Cond_generation"
368 | plot_manifold_canvas(PARR, 10, "MNIST", p_name)
369 |
370 | ## code for 2D scatter plot
371 | if conf.n_z == 2:
372 | print("-" * 80)
373 | print("plot 2D Scatter Result")
374 | test_total_batch = Data.get_total_batch(test_xs, 128)
375 | Data.initialize_batch()
376 | latent_holder = []
377 | for i in range(test_total_batch):
378 | batch_test_xs, batch_test_noised_xs, batch_test_ys = Data.next_batch(test_xs,
379 | test_ys,
380 | conf.batch_size,
381 | make_noise=False)
382 | feed_dict = {X: batch_test_xs,
383 | X_noised: batch_test_noised_xs,
384 | Y: batch_test_ys,
385 | keep_prob: 1.0}
386 |
387 | latent_vars = sess.run(style, feed_dict=feed_dict)
388 | latent_holder.append(latent_vars)
389 | latent_holder = np.concatenate(latent_holder, axis=0)
390 | plot_2d_scatter(latent_holder[:, 0], latent_holder[:, 1], test_ys[:len(latent_holder)])
391 |
392 | if conf.PMLR is True:
393 | print("-"*80)
394 | assert conf.n_z == 2, "Error: n_z should be 2"
395 | print("plot Manifold Learning Results")
396 | x_axis = np.linspace(-0.5,0.5,10)
397 | y_axis = np.linspace(-0.5,0.5,10)
398 | z_holder = []
399 | for i,xi in enumerate(x_axis):
400 | for j, yi in enumerate(y_axis):
401 | z_holder.append([xi,yi])
402 | length = len(z_holder)
403 | for k in range(n_cls):
404 | y = [k]*length
405 | y_one_hot = np.zeros((length, n_cls))
406 | y_one_hot[np.arange(length), y] = 1
407 | y_one_hot = np.reshape(y_one_hot, [-1,n_cls])
408 | z_concated = np.concatenate([z_holder, y_one_hot], axis=1)
409 | MLR = sess.run(images_manifold, feed_dict = {latent: z_concated, keep_prob: 1.0})
410 | MLR = np.reshape(MLR, [-1, height, width, channel])
411 | p_name = "PMLR/labels" +str(k)
412 | plot_manifold_canvas(MLR, 10, "MNIST", p_name)
--------------------------------------------------------------------------------