├── ACGAN.py ├── BEGAN.py ├── CGAN.py ├── CVAE.py ├── DRAGAN.py ├── EBGAN.py ├── GAN.py ├── LICENSE ├── LSGAN.py ├── README.md ├── VAE.py ├── WGAN.py ├── WGAN_GP.py ├── assets ├── equations │ ├── ACGAN.png │ ├── BEGAN.png │ ├── CGAN.png │ ├── CVAE.png │ ├── DRAGAN.png │ ├── EBGAN.png │ ├── GAN.png │ ├── LSGAN.png │ ├── VAE.png │ ├── WGAN.png │ ├── WGAN_GP.png │ └── infoGAN.png ├── etc │ ├── GAN_structure.png │ └── VAE_structure.png ├── fashion_mnist_results │ ├── conditional_generation │ │ ├── ACGAN_epoch000_test_all_classes_style_by_style.png │ │ ├── ACGAN_epoch009_test_all_classes_style_by_style.png │ │ ├── ACGAN_epoch019_test_all_classes_style_by_style.png │ │ ├── ACGAN_epoch039_test_all_classes_style_by_style.png │ │ ├── CGAN_epoch000_test_all_classes_style_by_style.png │ │ ├── CGAN_epoch019_test_all_classes_style_by_style.png │ │ ├── CGAN_epoch039_test_all_classes_style_by_style.png │ │ ├── CVAE_epoch000_test_all_classes_style_by_style.png │ │ ├── CVAE_epoch019_test_all_classes_style_by_style.png │ │ ├── CVAE_epoch039_test_all_classes_style_by_style.png │ │ ├── infoGAN_epoch000_test_all_classes_style_by_style.png │ │ ├── infoGAN_epoch019_test_all_classes_style_by_style.png │ │ └── infoGAN_epoch039_test_all_classes_style_by_style.png │ ├── infogan │ │ ├── infoGAN_epoch039_test_class_c1c2_1.png │ │ ├── infoGAN_epoch039_test_class_c1c2_4.png │ │ ├── infoGAN_epoch039_test_class_c1c2_5.png │ │ └── infoGAN_epoch039_test_class_c1c2_8.png │ ├── learned_manifold │ │ ├── VAE_epoch000_learned_manifold.png │ │ ├── VAE_epoch009_learned_manifold.png │ │ └── VAE_epoch024_learned_manifold.png │ └── random_generation │ │ ├── BEGAN_epoch000_test_all_classes.png │ │ ├── BEGAN_epoch019_test_all_classes.png │ │ ├── BEGAN_epoch039_test_all_classes.png │ │ ├── DRAGAN_epoch000_test_all_classes.png │ │ ├── DRAGAN_epoch019_test_all_classes.png │ │ ├── DRAGAN_epoch039_test_all_classes.png │ │ ├── EBGAN_epoch000_test_all_classes.png │ │ ├── EBGAN_epoch019_test_all_classes.png │ │ ├── 
EBGAN_epoch039_test_all_classes.png │ │ ├── GAN_epoch000_test_all_classes.png │ │ ├── GAN_epoch019_test_all_classes.png │ │ ├── GAN_epoch039_test_all_classes.png │ │ ├── LSGAN_epoch000_test_all_classes.png │ │ ├── LSGAN_epoch019_test_all_classes.png │ │ ├── LSGAN_epoch039_test_all_classes.png │ │ ├── VAE_epoch000_test_all_classes.png │ │ ├── VAE_epoch019_test_all_classes.png │ │ ├── VAE_epoch039_test_all_classes.png │ │ ├── WGAN_GP_epoch000_test_all_classes.png │ │ ├── WGAN_GP_epoch019_test_all_classes.png │ │ ├── WGAN_GP_epoch039_test_all_classes.png │ │ ├── WGAN_epoch000_test_all_classes.png │ │ ├── WGAN_epoch019_test_all_classes.png │ │ └── WGAN_epoch039_test_all_classes.png └── mnist_results │ ├── conditional_generation │ ├── ACGAN_epoch000_test_all_classes_style_by_style.png │ ├── ACGAN_epoch009_test_all_classes_style_by_style.png │ ├── ACGAN_epoch024_test_all_classes_style_by_style.png │ ├── CGAN_epoch000_test_all_classes_style_by_style.png │ ├── CGAN_epoch009_test_all_classes_style_by_style.png │ ├── CGAN_epoch024_test_all_classes_style_by_style.png │ ├── CVAE_epoch000_test_all_classes_style_by_style.png │ ├── CVAE_epoch009_test_all_classes_style_by_style.png │ ├── CVAE_epoch024_test_all_classes_style_by_style.png │ ├── infoGAN_epoch000_test_all_classes_style_by_style.png │ ├── infoGAN_epoch009_test_all_classes_style_by_style.png │ └── infoGAN_epoch024_test_all_classes_style_by_style.png │ ├── infogan │ ├── infoGAN_epoch024_test_class_c1c2_2.png │ ├── infoGAN_epoch024_test_class_c1c2_5.png │ ├── infoGAN_epoch024_test_class_c1c2_7.png │ └── infoGAN_epoch024_test_class_c1c2_9.png │ ├── learned_manifold │ ├── VAE_epoch000_learned_manifold.png │ ├── VAE_epoch009_learned_manifold.png │ └── VAE_epoch024_learned_manifold.png │ └── random_generation │ ├── BEGAN_epoch000_test_all_classes.png │ ├── BEGAN_epoch001_test_all_classes.png │ ├── BEGAN_epoch009_test_all_classes.png │ ├── BEGAN_epoch024_test_all_classes.png │ ├── DRAGAN_epoch000_test_all_classes.png │ ├── 
DRAGAN_epoch001_test_all_classes.png │ ├── DRAGAN_epoch009_test_all_classes.png │ ├── DRAGAN_epoch024_test_all_classes.png │ ├── EBGAN_epoch000_test_all_classes.png │ ├── EBGAN_epoch001_test_all_classes.png │ ├── EBGAN_epoch009_test_all_classes.png │ ├── EBGAN_epoch024_test_all_classes.png │ ├── GAN_epoch000_test_all_classes.png │ ├── GAN_epoch001_test_all_classes.png │ ├── GAN_epoch009_test_all_classes.png │ ├── GAN_epoch024_test_all_classes.png │ ├── LSGAN_epoch000_test_all_classes.png │ ├── LSGAN_epoch001_test_all_classes.png │ ├── LSGAN_epoch009_test_all_classes.png │ ├── LSGAN_epoch024_test_all_classes.png │ ├── VAE_epoch000_test_all_classes.png │ ├── VAE_epoch001_test_all_classes.png │ ├── VAE_epoch009_test_all_classes.png │ ├── VAE_epoch024_test_all_classes.png │ ├── WGAN_GP_epoch000_test_all_classes.png │ ├── WGAN_GP_epoch001_test_all_classes.png │ ├── WGAN_GP_epoch009_test_all_classes.png │ ├── WGAN_GP_epoch024_test_all_classes.png │ ├── WGAN_epoch000_test_all_classes.png │ ├── WGAN_epoch001_test_all_classes.png │ ├── WGAN_epoch009_test_all_classes.png │ └── WGAN_epoch024_test_all_classes.png ├── infoGAN.py ├── main.py ├── ops.py ├── prior_factory.py └── utils.py /ACGAN.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | from __future__ import division 3 | import os 4 | import time 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | from ops import * 9 | from utils import * 10 | 11 | class ACGAN(object): 12 | model_name = "ACGAN" # name for checkpoint 13 | 14 | def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir): 15 | self.sess = sess 16 | self.dataset_name = dataset_name 17 | self.checkpoint_dir = checkpoint_dir 18 | self.result_dir = result_dir 19 | self.log_dir = log_dir 20 | self.epoch = epoch 21 | self.batch_size = batch_size 22 | 23 | if dataset_name == 'mnist' or dataset_name == 'fashion-mnist': 24 | # parameters 25 | 
def classifier(self, x, is_training=True, reuse=False):
    """Predict class probabilities from a feature tensor.

    The head built here is two fully connected layers: a 128-unit layer
    followed by batch-norm and leaky ReLU, then a `self.y_dim`-unit layer
    whose output is returned both raw and after softmax. In `build_model`
    the input is the third value returned by `discriminator`.

    Args:
        x: input feature tensor (passed straight to `linear`).
        is_training: forwarded to the batch-norm helper.
        reuse: forwarded to `tf.variable_scope` to share variables.

    Returns:
        Tuple `(probabilities, logits)`.
    """
    with tf.variable_scope("classifier", reuse=reuse):
        hidden = linear(x, 128, scope='c_fc1')
        hidden = bn(hidden, is_training=is_training, scope='c_bn1')
        hidden = lrelu(hidden)

        class_logits = linear(hidden, self.y_dim, scope='c_fc2')
        class_probs = tf.nn.softmax(class_logits)

    return class_probs, class_logits
def generator(self, z, y, is_training=True, reuse=False):
    """Generate a batch of images from noise `z` conditioned on code `y`.

    `z` and `y` are concatenated along axis 1, passed through two fully
    connected layers (1024 and 128*7*7 units), reshaped to
    `[batch_size, 7, 7, 128]`, and upsampled by two `deconv2d` calls whose
    requested output shapes are `[batch_size, 14, 14, 64]` and
    `[batch_size, 28, 28, 1]`. Every layer but the last is followed by
    batch-norm and ReLU; the last one is followed by a sigmoid.

    Args:
        z: noise tensor.
        y: code/label tensor, concatenated to `z`.
        is_training: forwarded to the batch-norm helper.
        reuse: forwarded to `tf.variable_scope` to share variables.

    Returns:
        The sigmoid-activated output tensor of the last deconvolution.
    """
    bs = self.batch_size

    with tf.variable_scope("generator", reuse=reuse):
        # merge noise and code into a single input vector
        latent = concat([z, y], 1)

        fc1 = linear(latent, 1024, scope='g_fc1')
        fc1 = tf.nn.relu(bn(fc1, is_training=is_training, scope='g_bn1'))

        fc2 = linear(fc1, 128 * 7 * 7, scope='g_fc2')
        fc2 = tf.nn.relu(bn(fc2, is_training=is_training, scope='g_bn2'))

        feature_map = tf.reshape(fc2, [bs, 7, 7, 128])

        up1 = deconv2d(feature_map, [bs, 14, 14, 64], 4, 4, 2, 2, name='g_dc3')
        up1 = tf.nn.relu(bn(up1, is_training=is_training, scope='g_bn3'))

        up2 = deconv2d(up1, [bs, 28, 28, 1], 4, 4, 2, 2, name='g_dc4')
        images = tf.nn.sigmoid(up2)

    return images
def build_model(self):
    """Build the ACGAN graph: placeholders, losses, optimizers and summaries.

    Attributes set on `self`:
        inputs, y, z: placeholders for real images, labels and noise.
        d_loss, g_loss, q_loss: discriminator, generator and
            classification losses.
        d_optim, g_optim, q_optim: Adam training ops for those losses.
        fake_images: generator output built with `is_training=False`.
        d_sum, g_sum, q_sum: merged summary ops.
    """
    # some parameters
    image_dims = [self.input_height, self.input_width, self.c_dim]
    bs = self.batch_size

    # ---- Graph input ----
    # images
    self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')

    # labels
    self.y = tf.placeholder(tf.float32, [bs, self.y_dim], name='y')

    # noises
    self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

    # ---- Loss function ----
    ## 1. GAN loss
    # output of D for real images
    D_real, D_real_logits, input4classifier_real = self.discriminator(self.inputs, is_training=True, reuse=False)

    # output of D for fake images (discriminator variables are shared via reuse=True)
    G = self.generator(self.z, self.y, is_training=True, reuse=False)
    D_fake, D_fake_logits, input4classifier_fake = self.discriminator(G, is_training=True, reuse=True)

    # get loss for discriminator: real -> 1, fake -> 0
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real)))
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake)))

    self.d_loss = d_loss_real + d_loss_fake

    # get loss for generator: fake -> 1
    self.g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake)))

    ## 2. Information (classification) loss
    # The classifier is first built on the fake features, then reused on the real ones.
    code_fake, code_logit_fake = self.classifier(input4classifier_fake, is_training=True, reuse=False)
    code_real, code_logit_real = self.classifier(input4classifier_real, is_training=True, reuse=True)

    # For real samples
    q_real_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=code_logit_real, labels=self.y))

    # For fake samples
    q_fake_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=code_logit_fake, labels=self.y))

    # get information loss
    self.q_loss = q_fake_loss + q_real_loss

    # ---- Training ----
    # Variables are grouped by the 'd_' / 'g_' / 'c_' prefixes used in the layer scope names.
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'd_' in var.name]
    g_vars = [var for var in t_vars if 'g_' in var.name]
    # The classification loss updates discriminator, classifier and generator variables together.
    q_vars = [var for var in t_vars if ('d_' in var.name) or ('c_' in var.name) or ('g_' in var.name)]

    # optimizers; the ops registered in UPDATE_OPS run before each training op
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
            .minimize(self.d_loss, var_list=d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
            .minimize(self.g_loss, var_list=g_vars)
        self.q_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
            .minimize(self.q_loss, var_list=q_vars)

    # ---- Testing ----
    # for test: same generator variables, built with is_training=False
    self.fake_images = self.generator(self.z, self.y, is_training=False, reuse=True)

    # ---- Summary ----
    d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
    d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
    d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
    g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

    # Fixed: this summary used to be registered as "g_loss", the same name as
    # the generator-loss summary above; it now has its own "q_loss" tag.
    q_loss_sum = tf.summary.scalar("q_loss", self.q_loss)
    q_real_sum = tf.summary.scalar("q_real_loss", q_real_loss)
    q_fake_sum = tf.summary.scalar("q_fake_loss", q_fake_loss)

    # final summary operations
    self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
    self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])
    self.q_sum = tf.summary.merge([q_loss_sum, q_real_sum, q_fake_sum])
def train(self):
    """Run the ACGAN training loop, resuming from a checkpoint when one exists.

    Per batch: one discriminator step, then one joint generator/classifier
    step, each followed by writing its summaries. Every time the step
    counter is a multiple of 300 a grid of samples generated from the fixed
    `sample_z` / `test_codes` is saved. After each epoch the model is
    checkpointed and `visualize_results` is called.
    """
    # initialize all variables
    tf.global_variables_initializer().run()

    # fixed graph inputs used to visualize training progress
    self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))
    self.test_codes = self.data_y[0:self.batch_size]

    # saver to save model
    self.saver = tf.train.Saver()

    # summary writer
    self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

    # restore check-point if it exists
    could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    if could_load:
        # Floor division gives the number of completed epochs directly
        # (the file enables true division via `from __future__ import division`).
        start_epoch = checkpoint_counter // self.num_batches
        start_batch_id = checkpoint_counter - start_epoch * self.num_batches
        counter = checkpoint_counter
        print(" [*] Load SUCCESS")
    else:
        start_epoch = 0
        start_batch_id = 0
        counter = 1
        print(" [!] Load failed...")

    # The sample grid size depends only on fixed attributes, so compute it once.
    grid_dim = int(np.floor(np.sqrt(min(self.sample_num, self.batch_size))))

    # loop for epoch
    start_time = time.time()
    for epoch in range(start_epoch, self.epoch):

        # get batch data
        for idx in range(start_batch_id, self.num_batches):
            batch_images = self.data_X[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_codes = self.data_y[idx * self.batch_size:(idx + 1) * self.batch_size]

            batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

            # update D network
            _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                                   feed_dict={self.inputs: batch_images, self.y: batch_codes,
                                                              self.z: batch_z})
            self.writer.add_summary(summary_str, counter)

            # update G & Q network
            _, summary_str_g, g_loss, _, summary_str_q, q_loss = self.sess.run(
                [self.g_optim, self.g_sum, self.g_loss, self.q_optim, self.q_sum, self.q_loss],
                feed_dict={self.z: batch_z, self.y: batch_codes, self.inputs: batch_images})
            self.writer.add_summary(summary_str_g, counter)
            self.writer.add_summary(summary_str_q, counter)

            # display training status
            counter += 1
            print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                  % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

            # save training results for every 300 steps
            if np.mod(counter, 300) == 0:
                samples = self.sess.run(self.fake_images,
                                        feed_dict={self.z: self.sample_z, self.y: self.test_codes})
                # Fixed: no './' prefix, matching visualize_results. The prefix
                # turned an absolute result_dir into a path relative to the
                # working directory.
                # NOTE(review): assumes check_folder returns the directory it
                # was given, as its use in visualize_results implies - confirm in utils.
                sample_dir = check_folder(self.result_dir + '/' + self.model_dir)
                save_images(samples[:grid_dim * grid_dim, :, :, :], [grid_dim, grid_dim],
                            sample_dir + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, idx))

        # After an epoch, start_batch_id is set to zero
        # non-zero value is only for the first epoch after loading pre-trained model
        start_batch_id = 0

        # save model
        self.save(self.checkpoint_dir, counter)

        # show temporal results
        self.visualize_results(epoch)

    # save model for final step
    self.save(self.checkpoint_dir, counter)
@property
def model_dir(self):
    """Directory name identifying this run.

    Joins model name, dataset name, batch size and noise dimension with
    underscores.
    """
    parts = (self.model_name, self.dataset_name, self.batch_size, self.z_dim)
    return "_".join(str(part) for part in parts)
def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
    """Store the run configuration and load the dataset.

    Only 'mnist' and 'fashion-mnist' are handled.

    Raises:
        NotImplementedError: for any other `dataset_name`.
    """
    self.sess = sess
    self.dataset_name = dataset_name
    self.checkpoint_dir = checkpoint_dir
    self.result_dir = result_dir
    self.log_dir = log_dir
    self.epoch = epoch
    self.batch_size = batch_size

    # Reject unsupported datasets up front.
    if dataset_name != 'mnist' and dataset_name != 'fashion-mnist':
        raise NotImplementedError

    # image geometry
    self.input_height = self.input_width = 28
    self.output_height = self.output_width = 28

    self.z_dim = z_dim  # dimension of noise-vector
    self.c_dim = 1

    # BEGAN parameters (used in build_model for the k update and the metric M)
    self.gamma = 0.75
    self.lamda = 0.001

    # optimizer settings
    self.learning_rate = 0.0002
    self.beta1 = 0.5

    # number of generated images to be saved when sampling
    self.sample_num = 64

    # load the dataset
    self.data_X, self.data_y = load_mnist(self.dataset_name)

    # number of full batches in a single epoch
    self.num_batches = len(self.data_X) // self.batch_size
def generator(self, z, is_training=True, reuse=False):
    """Generate a batch of images from noise `z`.

    Two fully connected layers (1024 and 128*7*7 units), a reshape to
    `[batch_size, 7, 7, 128]`, then two `deconv2d` calls with requested
    output shapes `[batch_size, 14, 14, 64]` and `[batch_size, 28, 28, 1]`.
    Batch-norm and ReLU follow every layer but the last, which is followed
    by a sigmoid.

    Args:
        z: noise tensor.
        is_training: forwarded to the batch-norm helper.
        reuse: forwarded to `tf.variable_scope` to share variables.

    Returns:
        The sigmoid-activated output tensor of the last deconvolution.
    """
    bs = self.batch_size

    with tf.variable_scope("generator", reuse=reuse):
        fc1 = linear(z, 1024, scope='g_fc1')
        fc1 = tf.nn.relu(bn(fc1, is_training=is_training, scope='g_bn1'))

        fc2 = linear(fc1, 128 * 7 * 7, scope='g_fc2')
        fc2 = tf.nn.relu(bn(fc2, is_training=is_training, scope='g_bn2'))

        feature_map = tf.reshape(fc2, [bs, 7, 7, 128])

        up1 = deconv2d(feature_map, [bs, 14, 14, 64], 4, 4, 2, 2, name='g_dc3')
        up1 = tf.nn.relu(bn(up1, is_training=is_training, scope='g_bn3'))

        up2 = deconv2d(up1, [bs, 28, 28, 1], 4, 4, 2, 2, name='g_dc4')
        images = tf.nn.sigmoid(up2)

    return images
def build_model(self):
    """Build the BEGAN graph: placeholders, losses, k update, optimizers, summaries.

    Attributes set on `self`:
        k: non-trainable scalar variable, initialised to 0.
        inputs, z: placeholders for real images and noise.
        d_loss, g_loss: discriminator and generator losses.
        M: convergence metric.
        update_k: op assigning the clipped, updated value to `k`.
        d_optim, g_optim: Adam training ops.
        fake_images: generator output built with `is_training=False`.
        d_sum, g_sum, p_sum: merged summary ops.
    """
    batch = self.batch_size
    image_shape = [self.input_height, self.input_width, self.c_dim]

    # BEGAN balance variable
    self.k = tf.Variable(0., trainable=False)

    # graph inputs: real images and noise
    self.inputs = tf.placeholder(tf.float32, [batch] + image_shape, name='real_images')
    self.z = tf.placeholder(tf.float32, [batch, self.z_dim], name='z')

    # D applied to real images
    _real_img, real_err, _real_code = self.discriminator(self.inputs, is_training=True, reuse=False)

    # D applied to generated images (shares the discriminator variables)
    generated = self.generator(self.z, is_training=True, reuse=False)
    _fake_img, fake_err, _fake_code = self.discriminator(generated, is_training=True, reuse=True)

    # discriminator loss
    self.d_loss = real_err - self.k*fake_err

    # generator loss
    self.g_loss = fake_err

    # convergence metric
    self.M = real_err + tf.abs(self.gamma*real_err - fake_err)

    # k is moved by lamda * (gamma * real_err - fake_err) and kept inside [0, 1]
    self.update_k = self.k.assign(
        tf.clip_by_value(self.k + self.lamda*(self.gamma*real_err - fake_err), 0, 1))

    # split trainable variables by the 'd_' / 'g_' prefixes used in layer scope names
    d_vars, g_vars = [], []
    for var in tf.trainable_variables():
        if 'd_' in var.name:
            d_vars.append(var)
        if 'g_' in var.name:
            g_vars.append(var)

    # optimizers; the ops registered in UPDATE_OPS run before each training op
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        d_adam = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1)
        self.d_optim = d_adam.minimize(self.d_loss, var_list=d_vars)
        g_adam = tf.train.AdamOptimizer(self.learning_rate*5, beta1=self.beta1)
        self.g_optim = g_adam.minimize(self.g_loss, var_list=g_vars)

    # sampling output: same generator variables, built with is_training=False
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    # scalar summaries
    real_err_summary = tf.summary.scalar("d_error_real", real_err)
    fake_err_summary = tf.summary.scalar("d_error_fake", fake_err)
    d_loss_summary = tf.summary.scalar("d_loss", self.d_loss)
    g_loss_summary = tf.summary.scalar("g_loss", self.g_loss)
    m_summary = tf.summary.scalar("M", self.M)
    k_summary = tf.summary.scalar("k", self.k)

    # merged summaries, one per training step type
    self.g_sum = tf.summary.merge([fake_err_summary, g_loss_summary])
    self.d_sum = tf.summary.merge([real_err_summary, d_loss_summary])
    self.p_sum = tf.summary.merge([m_summary, k_summary])
def train(self):
    """Run the BEGAN training loop, resuming from a checkpoint when one exists.

    Per batch: one discriminator step, one generator step, then the `k`
    update, each followed by writing its summaries. Every time the step
    counter is a multiple of 300 a grid of samples generated from the fixed
    `sample_z` is saved. After each epoch the model is checkpointed and
    `visualize_results` is called.
    """
    # initialize all variables
    tf.global_variables_initializer().run()

    # fixed noise used to visualize training progress
    self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

    # saver to save model
    self.saver = tf.train.Saver()

    # summary writer
    self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

    # restore check-point if it exists
    could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    if could_load:
        # Floor division gives the number of completed epochs directly
        # (the file enables true division via `from __future__ import division`).
        start_epoch = checkpoint_counter // self.num_batches
        start_batch_id = checkpoint_counter - start_epoch * self.num_batches
        counter = checkpoint_counter
        print(" [*] Load SUCCESS")
    else:
        start_epoch = 0
        start_batch_id = 0
        counter = 1
        print(" [!] Load failed...")

    # The sample grid size depends only on fixed attributes, so compute it once.
    grid_dim = int(np.floor(np.sqrt(min(self.sample_num, self.batch_size))))

    # loop for epoch
    start_time = time.time()
    for epoch in range(start_epoch, self.epoch):

        # get batch data
        for idx in range(start_batch_id, self.num_batches):
            batch_images = self.data_X[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

            # update D network
            _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                                   feed_dict={self.inputs: batch_images, self.z: batch_z})
            self.writer.add_summary(summary_str, counter)

            # update G network
            _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss],
                                                   feed_dict={self.z: batch_z})
            self.writer.add_summary(summary_str, counter)

            # update k
            _, summary_str, M_value, k_value = self.sess.run(
                [self.update_k, self.p_sum, self.M, self.k],
                feed_dict={self.inputs: batch_images, self.z: batch_z})
            self.writer.add_summary(summary_str, counter)

            # display training status
            counter += 1
            print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f, M: %.8f, k: %.8f" \
                  % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss, M_value, k_value))

            # save training results for every 300 steps
            if np.mod(counter, 300) == 0:
                samples = self.sess.run(self.fake_images, feed_dict={self.z: self.sample_z})
                # Fixed: no './' prefix, matching visualize_results. The prefix
                # turned an absolute result_dir into a path relative to the
                # working directory.
                # NOTE(review): assumes check_folder returns the directory it
                # was given, as its use in visualize_results implies - confirm in utils.
                sample_dir = check_folder(self.result_dir + '/' + self.model_dir)
                save_images(samples[:grid_dim * grid_dim, :, :, :], [grid_dim, grid_dim],
                            sample_dir + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, idx))

        # After an epoch, start_batch_id is set to zero
        # non-zero value is only for the first epoch after loading pre-trained model
        start_batch_id = 0

        # save model
        self.save(self.checkpoint_dir, counter)

        # show temporal results
        self.visualize_results(epoch)

    # save model for final step
    self.save(self.checkpoint_dir, counter)
def visualize_results(self, epoch):
    """Sample the generator with fresh noise and save a square image grid.

    The grid side is the floor of the square root of
    `min(sample_num, batch_size)`; only that many of the generated images
    are written. The file name embeds the model name and `epoch`.
    """
    grid_side = int(np.floor(np.sqrt(min(self.sample_num, self.batch_size))))

    # random noise for a full batch
    noise = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))
    generated = self.sess.run(self.fake_images, feed_dict={self.z: noise})

    out_dir = check_folder(self.result_dir + '/' + self.model_dir)
    file_name = self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png'

    save_images(generated[:grid_side * grid_side, :, :, :], [grid_side, grid_side],
                out_dir + '/' + file_name)
def load(self, checkpoint_dir):
    """Restore the latest checkpoint for this model, if there is one.

    The checkpoint is looked up under
    `<checkpoint_dir>/<model_dir>/<model_name>`, the same directory `save`
    writes to.

    Args:
        checkpoint_dir: root directory of all checkpoints.

    Returns:
        `(True, step)` when a checkpoint was restored, where `step` is the
        last run of digits in the checkpoint file name; `(False, 0)`
        otherwise.
    """
    import re
    print(" [*] Reading checkpoints...")
    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
        # Fixed: raw string. In a normal string literal "\d" is an invalid
        # escape sequence; the pattern itself is unchanged. It matches the
        # last group of digits in the name (no digit may follow it).
        counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
        print(" [*] Success to read {}".format(ckpt_name))
        return True, counter
    else:
        print(" [*] Failed to find a checkpoint")
        return False, 0
def discriminator(self, x, y, is_training=True, reuse=False):
    """Score images `x` as real or fake, conditioned on labels `y`.

    `y` is reshaped to `[batch_size, 1, 1, y_dim]` and merged into `x` with
    `conv_cond_concat`. The result goes through two `conv2d` layers (64 and
    128 outputs), is flattened, and passes through a 1024-unit and a 1-unit
    fully connected layer. Leaky ReLU follows the first three layers, with
    batch-norm on the second and third.

    Args:
        x: image tensor.
        y: label tensor.
        is_training: forwarded to the batch-norm helper.
        reuse: forwarded to `tf.variable_scope` to share variables.

    Returns:
        Tuple `(sigmoid output, logit, 1024-unit feature tensor)`.
    """
    bs = self.batch_size

    with tf.variable_scope("discriminator", reuse=reuse):
        # merge image and label
        label_map = tf.reshape(y, [bs, 1, 1, self.y_dim])
        conditioned = conv_cond_concat(x, label_map)

        conv1 = lrelu(conv2d(conditioned, 64, 4, 4, 2, 2, name='d_conv1'))

        conv2 = conv2d(conv1, 128, 4, 4, 2, 2, name='d_conv2')
        conv2 = lrelu(bn(conv2, is_training=is_training, scope='d_bn2'))

        flat = tf.reshape(conv2, [bs, -1])

        features = linear(flat, 1024, scope='d_fc3')
        features = lrelu(bn(features, is_training=is_training, scope='d_bn3'))

        logit = linear(features, 1, scope='d_fc4')
        prob = tf.nn.sigmoid(logit)

    return prob, logit, features
64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training, 81 | scope='g_bn3')) 82 | 83 | out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4')) 84 | 85 | return out 86 | 87 | def build_model(self): 88 | # some parameters 89 | image_dims = [self.input_height, self.input_width, self.c_dim] 90 | bs = self.batch_size 91 | 92 | """ Graph Input """ 93 | # images 94 | self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images') 95 | 96 | # labels 97 | self.y = tf.placeholder(tf.float32, [bs, self.y_dim], name='y') 98 | 99 | # noises 100 | self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z') 101 | 102 | """ Loss Function """ 103 | 104 | # output of D for real images 105 | D_real, D_real_logits, _ = self.discriminator(self.inputs, self.y, is_training=True, reuse=False) 106 | 107 | # output of D for fake images 108 | G = self.generator(self.z, self.y, is_training=True, reuse=False) 109 | D_fake, D_fake_logits, _ = self.discriminator(G, self.y, is_training=True, reuse=True) 110 | 111 | # get loss for discriminator 112 | d_loss_real = tf.reduce_mean( 113 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real))) 114 | d_loss_fake = tf.reduce_mean( 115 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake))) 116 | 117 | self.d_loss = d_loss_real + d_loss_fake 118 | 119 | # get loss for generator 120 | self.g_loss = tf.reduce_mean( 121 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake))) 122 | 123 | """ Training """ 124 | # divide trainable variables into a group for D and a group for G 125 | t_vars = tf.trainable_variables() 126 | d_vars = [var for var in t_vars if 'd_' in var.name] 127 | g_vars = [var for var in t_vars if 'g_' in var.name] 128 | 129 | # optimizers 130 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 131 | self.d_optim = 
tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \ 132 | .minimize(self.d_loss, var_list=d_vars) 133 | self.g_optim = tf.train.AdamOptimizer(self.learning_rate*5, beta1=self.beta1) \ 134 | .minimize(self.g_loss, var_list=g_vars) 135 | 136 | """" Testing """ 137 | # for test 138 | self.fake_images = self.generator(self.z, self.y, is_training=False, reuse=True) 139 | 140 | """ Summary """ 141 | d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real) 142 | d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake) 143 | d_loss_sum = tf.summary.scalar("d_loss", self.d_loss) 144 | g_loss_sum = tf.summary.scalar("g_loss", self.g_loss) 145 | 146 | # final summary operations 147 | self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum]) 148 | self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum]) 149 | 150 | def train(self): 151 | 152 | # initialize all variables 153 | tf.global_variables_initializer().run() 154 | 155 | # graph inputs for visualize training results 156 | self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size , self.z_dim)) 157 | self.test_labels = self.data_y[0:self.batch_size] 158 | 159 | # saver to save model 160 | self.saver = tf.train.Saver() 161 | 162 | # summary writer 163 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph) 164 | 165 | # restore check-point if it exits 166 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 167 | if could_load: 168 | start_epoch = (int)(checkpoint_counter / self.num_batches) 169 | start_batch_id = checkpoint_counter - start_epoch * self.num_batches 170 | counter = checkpoint_counter 171 | print(" [*] Load SUCCESS") 172 | else: 173 | start_epoch = 0 174 | start_batch_id = 0 175 | counter = 1 176 | print(" [!] 
Load failed...") 177 | 178 | # loop for epoch 179 | start_time = time.time() 180 | for epoch in range(start_epoch, self.epoch): 181 | 182 | # get batch data 183 | for idx in range(start_batch_id, self.num_batches): 184 | batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size] 185 | batch_labels = self.data_y[idx * self.batch_size:(idx + 1) * self.batch_size] 186 | batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32) 187 | 188 | # update D network 189 | _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss], 190 | feed_dict={self.inputs: batch_images, self.y: batch_labels, 191 | self.z: batch_z}) 192 | self.writer.add_summary(summary_str, counter) 193 | 194 | # update G network 195 | _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss], 196 | feed_dict={self.y: batch_labels, self.z: batch_z}) 197 | self.writer.add_summary(summary_str, counter) 198 | 199 | # display training status 200 | counter += 1 201 | print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ 202 | % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss)) 203 | 204 | # save training results for every 300 steps 205 | if np.mod(counter, 300) == 0: 206 | samples = self.sess.run(self.fake_images, 207 | feed_dict={self.z: self.sample_z, self.y: self.test_labels}) 208 | tot_num_samples = min(self.sample_num, self.batch_size) 209 | manifold_h = int(np.floor(np.sqrt(tot_num_samples))) 210 | manifold_w = int(np.floor(np.sqrt(tot_num_samples))) 211 | save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w], 212 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format( 213 | epoch, idx)) 214 | 215 | # After an epoch, start_batch_id is set to zero 216 | # non-zero value is only for the first epoch after loading pre-trained model 217 | start_batch_id = 0 218 | 219 | # save model 220 | 
self.save(self.checkpoint_dir, counter) 221 | 222 | # show temporal results 223 | self.visualize_results(epoch) 224 | 225 | # save model for final step 226 | self.save(self.checkpoint_dir, counter) 227 | 228 | def visualize_results(self, epoch): 229 | tot_num_samples = min(self.sample_num, self.batch_size) 230 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 231 | 232 | """ random condition, random noise """ 233 | y = np.random.choice(self.y_dim, self.batch_size) 234 | y_one_hot = np.zeros((self.batch_size, self.y_dim)) 235 | y_one_hot[np.arange(self.batch_size), y] = 1 236 | 237 | z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim)) 238 | 239 | samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot}) 240 | 241 | save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 242 | check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png') 243 | 244 | """ specified condition, random noise """ 245 | n_styles = 10 # must be less than or equal to self.batch_size 246 | 247 | np.random.seed() 248 | si = np.random.choice(self.batch_size, n_styles) 249 | 250 | for l in range(self.y_dim): 251 | y = np.zeros(self.batch_size, dtype=np.int64) + l 252 | y_one_hot = np.zeros((self.batch_size, self.y_dim)) 253 | y_one_hot[np.arange(self.batch_size), y] = 1 254 | 255 | samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot}) 256 | save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 257 | check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l) 258 | 259 | samples = samples[si, :, :, :] 260 | 261 | if l == 0: 262 | all_samples = samples 263 | else: 264 | all_samples = np.concatenate((all_samples, samples), axis=0) 265 | 266 | """ save merged images to check 
style-consistency """ 267 | canvas = np.zeros_like(all_samples) 268 | for s in range(n_styles): 269 | for c in range(self.y_dim): 270 | canvas[s * self.y_dim + c, :, :, :] = all_samples[c * n_styles + s, :, :, :] 271 | 272 | save_images(canvas, [n_styles, self.y_dim], 273 | check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png') 274 | 275 | @property 276 | def model_dir(self): 277 | return "{}_{}_{}_{}".format( 278 | self.model_name, self.dataset_name, 279 | self.batch_size, self.z_dim) 280 | 281 | def save(self, checkpoint_dir, step): 282 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name) 283 | 284 | if not os.path.exists(checkpoint_dir): 285 | os.makedirs(checkpoint_dir) 286 | 287 | self.saver.save(self.sess,os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step) 288 | 289 | def load(self, checkpoint_dir): 290 | import re 291 | print(" [*] Reading checkpoints...") 292 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name) 293 | 294 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 295 | if ckpt and ckpt.model_checkpoint_path: 296 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 297 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 298 | counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0)) 299 | print(" [*] Success to read {}".format(ckpt_name)) 300 | return True, counter 301 | else: 302 | print(" [*] Failed to find a checkpoint") 303 | return False, 0 304 | -------------------------------------------------------------------------------- /CVAE.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | from __future__ import division 3 | import os 4 | import time 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | from ops import * 9 | from utils import * 10 | 11 | import prior_factory as 
class CVAE(object):
    """Conditional variational auto-encoder for the 28x28 single-channel 'mnist' / 'fashion-mnist' data.

    A Gaussian encoder maps (image, label) to the mean and standard deviation
    of the latent code; a Bernoulli decoder maps (code, label) back to an
    image.  ``build_model`` builds the TF1 graph and ``train`` minimizes the
    negative ELBO.
    """

    model_name = "CVAE"     # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        """Store the run configuration and load the training set.

        :param sess: session used for every ``run`` / save / restore call
        :param epoch: epoch index at which ``train`` stops
        :param batch_size: fixed batch size baked into placeholders and reshapes
        :param z_dim: dimension of the latent vector
        :param dataset_name: 'mnist' or 'fashion-mnist'
        :param checkpoint_dir: root directory for checkpoints
        :param result_dir: root directory for generated image grids
        :param log_dir: root directory for summary event files
        :raises NotImplementedError: for any other ``dataset_name``
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name not in ('mnist', 'fashion-mnist'):
            raise NotImplementedError("unsupported dataset: {!r}".format(dataset_name))

        # parameters
        self.input_height = 28
        self.input_width = 28
        self.output_height = 28
        self.output_width = 28

        self.z_dim = z_dim         # dimension of noise-vector
        self.y_dim = 10            # dimension of condition-vector (label)
        self.c_dim = 1

        # train
        self.learning_rate = 0.0002
        self.beta1 = 0.5

        # test
        self.sample_num = 64  # number of generated images to be saved

        # load mnist
        self.data_X, self.data_y = load_mnist(self.dataset_name)

        # get number of batches for a single epoch
        self.num_batches = len(self.data_X) // self.batch_size

    # Gaussian Encoder
    def encoder(self, x, y, is_training=True, reuse=False):
        """Build the Gaussian encoder q(z | x, y).

        Convolutional trunk follows the infoGAN discriminator (https://arxiv.org/abs/1606.03657)
        Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC(2*z_dim)

        :param x: image batch tensor
        :param y: label batch tensor; reshaped to [batch_size, 1, 1, y_dim]
        :param is_training: forwarded to the batch-norm layers
        :param reuse: reuse the variables of the "encoder" scope
        :return: tuple (mean, stddev), each with ``z_dim`` columns
        """
        with tf.variable_scope("encoder", reuse=reuse):
            # merge image and label
            y = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
            x = conv_cond_concat(x, y)

            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='en_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='en_conv2'), is_training=is_training, scope='en_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='en_fc3'), is_training=is_training, scope='en_bn3'))
            gaussian_params = linear(net, 2 * self.z_dim, scope='en_fc4')

            # The mean parameter is unconstrained
            mean = gaussian_params[:, :self.z_dim]
            # The standard deviation must be positive. Parametrize with a softplus and
            # add a small epsilon for numerical stability
            stddev = 1e-6 + tf.nn.softplus(gaussian_params[:, self.z_dim:])

        return mean, stddev

    # Bernoulli decoder
    def decoder(self, z, y, is_training=True, reuse=False):
        """Build the Bernoulli decoder p(x | z, y).

        Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S

        :param z: latent batch tensor
        :param y: label batch tensor, concatenated to ``z`` along axis 1
        :param is_training: forwarded to the batch-norm layers
        :param reuse: reuse the variables of the "decoder" scope
        :return: sigmoid-activated image tensor of shape [batch_size, 28, 28, 1]
        """
        with tf.variable_scope("decoder", reuse=reuse):
            # merge noise and label
            z = concat([z, y], 1)

            net = tf.nn.relu(bn(linear(z, 1024, scope='de_fc1'), is_training=is_training, scope='de_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='de_fc2'), is_training=is_training, scope='de_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='de_dc3'), is_training=is_training,
                   scope='de_bn3'))

            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='de_dc4'))

            return out

    def build_model(self):
        """Create placeholders, the negative-ELBO loss, the optimizer, the test-time decoder and summaries."""
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # ---- Graph Input ----
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')

        # labels
        self.y = tf.placeholder(tf.float32, [bs, self.y_dim], name='y')

        # noises (only consumed by the test-time decoder below)
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # ---- Loss Function ----
        # encoding
        mu, sigma = self.encoder(self.inputs, self.y, is_training=True, reuse=False)

        # sampling by re-parameterization technique
        z = mu + sigma * tf.random_normal(tf.shape(mu), 0, 1, dtype=tf.float32)

        # decoding; clip so the logs below stay finite
        out = self.decoder(z, self.y, is_training=True, reuse=False)
        self.out = tf.clip_by_value(out, 1e-8, 1 - 1e-8)

        # loss: sum over every non-batch axis (height, width, channel) so each
        # sample contributes one log-likelihood value
        marginal_likelihood = tf.reduce_sum(
            self.inputs * tf.log(self.out) + (1 - self.inputs) * tf.log(1 - self.out), [1, 2, 3])
        KL_divergence = 0.5 * tf.reduce_sum(
            tf.square(mu) + tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1, [1])

        self.neg_loglikelihood = -tf.reduce_mean(marginal_likelihood)
        self.KL_divergence = tf.reduce_mean(KL_divergence)

        ELBO = -self.neg_loglikelihood - self.KL_divergence

        self.loss = -ELBO

        # ---- Training ----
        # optimizers
        t_vars = tf.trainable_variables()
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
                .minimize(self.loss, var_list=t_vars)

        # ---- Testing ----
        # decoder in inference mode, sharing the trained variables
        self.fake_images = self.decoder(self.z, self.y, is_training=False, reuse=True)

        # ---- Summary ----
        nll_sum = tf.summary.scalar("nll", self.neg_loglikelihood)
        kl_sum = tf.summary.scalar("kl", self.KL_divergence)
        loss_sum = tf.summary.scalar("loss", self.loss)

        # final summary operations
        self.merged_summary_op = tf.summary.merge_all()

    def train(self):
        """Minimize the negative ELBO, resuming from the latest checkpoint when one exists."""
        # initialize all variables (run on self.sess rather than an implicit default session)
        self.sess.run(tf.global_variables_initializer())

        # fixed graph inputs used to visualize training progress
        self.sample_z = self._sample_prior()
        self.test_labels = self.data_y[0:self.batch_size]

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch, start_batch_id = self._resume_position(checkpoint_counter, self.num_batches)
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            # NOTE(review): the step counter starts at 1, so a checkpoint written at the
            # end of an epoch resumes at batch 1 of the next epoch - confirm this is intended.
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                lo, hi = idx * self.batch_size, (idx + 1) * self.batch_size
                batch_images = self.data_X[lo:hi]
                batch_labels = self.data_y[lo:hi]

                # update autoencoder; the latent code is sampled inside the graph,
                # so the self.z placeholder is not fed on this step
                _, summary_str, loss, nll_loss, kl_loss = self.sess.run(
                    [self.optim, self.merged_summary_op, self.loss, self.neg_loglikelihood, self.KL_divergence],
                    feed_dict={self.inputs: batch_images, self.y: batch_labels})
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f, nll: %.8f, kl: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, loss, nll_loss, kl_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={self.z: self.sample_z, self.y: self.test_labels})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = manifold_h
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                self._result_path(
                                    self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, idx)))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def visualize_results(self, epoch):
        """Save sample grids for ``epoch``: random labels, one grid per class, and a style-by-class grid."""
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
        n_grid = image_frame_dim * image_frame_dim
        prefix = self.model_name + '_epoch%03d' % epoch

        # ---- random condition, random noise ----
        y = np.random.choice(self.y_dim, self.batch_size)
        y_one_hot = self._one_hot(y, self.y_dim)

        z_sample = self._sample_prior()

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})

        save_images(samples[:n_grid, :, :, :], [image_frame_dim, image_frame_dim],
                    self._result_path(prefix + '_test_all_classes.png'))

        # ---- specified condition, random noise ----
        n_styles = 10  # number of latent vectors ("styles") shown per class

        # re-seed the global NumPy RNG from OS entropy before picking the style rows
        np.random.seed()
        si = np.random.choice(self.batch_size, n_styles)

        per_class = []
        for l in range(self.y_dim):
            y = np.zeros(self.batch_size, dtype=np.int64) + l
            y_one_hot = self._one_hot(y, self.y_dim)

            # same latent batch for every class, so a row index identifies a "style"
            samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})
            save_images(samples[:n_grid, :, :, :], [image_frame_dim, image_frame_dim],
                        self._result_path(prefix + '_test_class_%d.png' % l))

            per_class.append(samples[si, :, :, :])

        # concatenate once instead of growing the array on every iteration
        all_samples = np.concatenate(per_class, axis=0)

        # ---- save merged images to check style-consistency ----
        canvas = self._style_by_style(all_samples, self.y_dim, n_styles)

        save_images(canvas, [n_styles, self.y_dim],
                    self._result_path(prefix + '_test_all_classes_style_by_style.png'))

    @property
    def model_dir(self):
        """Directory name encoding model name, dataset, batch size and latent dimension."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint for ``step`` under ``checkpoint_dir/model_dir/model_name``."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint, if any.

        :return: ``(True, step)`` after a successful restore, else ``(False, 0)``
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = self._parse_step(ckpt_name)
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    # ---- private helpers ----

    def _sample_prior(self):
        """Draw one batch of latent vectors via prior_factory.gaussian(batch_size, z_dim)."""
        import prior_factory as prior
        return prior.gaussian(self.batch_size, self.z_dim)

    def _result_path(self, filename):
        """Return ``filename`` inside this model's result folder (the folder is passed through check_folder)."""
        return os.path.join(check_folder(os.path.join(self.result_dir, self.model_dir)), filename)

    @staticmethod
    def _parse_step(ckpt_name):
        """Return the last run of digits in ``ckpt_name`` as an int (the saved global step)."""
        import re
        match = re.search(r"(\d+)(?!.*\d)", ckpt_name)
        if match is None:
            raise ValueError("no step number in checkpoint name: {!r}".format(ckpt_name))
        return int(match.group(0))

    @staticmethod
    def _resume_position(checkpoint_counter, num_batches):
        """Split a global step into ``(start_epoch, start_batch_id)``."""
        start_epoch = checkpoint_counter // num_batches
        return start_epoch, checkpoint_counter - start_epoch * num_batches

    @staticmethod
    def _one_hot(labels, depth):
        """Return a float array of shape (len(labels), depth) with a 1 at each label index."""
        encoded = np.zeros((len(labels), depth))
        encoded[np.arange(len(labels)), labels] = 1
        return encoded

    @staticmethod
    def _style_by_style(all_samples, n_classes, n_styles):
        """Reorder class-major samples (row c * n_styles + s) into style-major order (row s * n_classes + c)."""
        grouped = all_samples.reshape((n_classes, n_styles) + all_samples.shape[1:])
        return grouped.swapaxes(0, 1).reshape(all_samples.shape)
class DRAGAN(object):
    """DRAGAN for the 28x28 single-channel 'mnist' / 'fashion-mnist' data.

    A standard GAN loss plus a gradient penalty on the discriminator,
    evaluated at random points between each real batch and a noise-perturbed
    copy of it.  ``build_model`` builds the TF1 graph and ``train`` optimizes it.
    """

    model_name = "DRAGAN"     # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        """Store the run configuration and load the training set.

        :param sess: session used for every ``run`` / save / restore call
        :param epoch: epoch index at which ``train`` stops
        :param batch_size: fixed batch size baked into placeholders and reshapes
        :param z_dim: dimension of the noise vector
        :param dataset_name: 'mnist' or 'fashion-mnist'
        :param checkpoint_dir: root directory for checkpoints
        :param result_dir: root directory for generated image grids
        :param log_dir: root directory for summary event files
        :raises NotImplementedError: for any other ``dataset_name``
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name not in ('mnist', 'fashion-mnist'):
            raise NotImplementedError("unsupported dataset: {!r}".format(dataset_name))

        # parameters
        self.input_height = 28
        self.input_width = 28
        self.output_height = 28
        self.output_width = 28

        self.z_dim = z_dim         # dimension of noise-vector
        self.c_dim = 1

        # DRAGAN parameter
        self.lambd = 0.25       # The higher value, the more stable, but the slower convergence

        # train
        self.learning_rate = 0.0002
        self.beta1 = 0.5

        # test
        self.sample_num = 64  # number of generated images to be saved

        # load mnist
        self.data_X, self.data_y = load_mnist(self.dataset_name)

        # get number of batches for a single epoch
        self.num_batches = len(self.data_X) // self.batch_size

    def discriminator(self, x, is_training=True, reuse=False):
        """Build the discriminator.

        Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S

        :param x: image batch tensor
        :param is_training: forwarded to the batch-norm layers
        :param reuse: reuse the variables of the "discriminator" scope
        :return: tuple (sigmoid output, logits, last hidden activation)
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'), is_training=is_training, scope='d_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3'))
            out_logit = linear(net, 1, scope='d_fc4')
            out = tf.nn.sigmoid(out_logit)

            return out, out_logit, net

    def generator(self, z, is_training=True, reuse=False):
        """Build the generator.

        Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S

        :param z: noise batch tensor
        :param is_training: forwarded to the batch-norm layers
        :param reuse: reuse the variables of the "generator" scope
        :return: sigmoid-activated image tensor of shape [batch_size, 28, 28, 1]
        """
        with tf.variable_scope("generator", reuse=reuse):
            net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training,
                   scope='g_bn3'))

            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))

            return out

    def get_perturbed_batch(self, minibatch):
        """Return ``minibatch`` plus uniform [0, 1) noise scaled by half of the batch's standard deviation.

        :param minibatch: NumPy array of real samples
        :return: array of the same shape as ``minibatch``
        """
        return minibatch + 0.5 * minibatch.std() * np.random.random(minibatch.shape)

    def build_model(self):
        """Create placeholders, GAN losses with the DRAGAN penalty, optimizers, the test-time generator and summaries."""
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # ---- Graph Input ----
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')
        self.inputs_p = tf.placeholder(tf.float32, [bs] + image_dims, name='real_perturbed_images')

        # noises
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # ---- Loss Function ----
        # output of D for real images
        D_real, D_real_logits, _ = self.discriminator(self.inputs, is_training=True, reuse=False)

        # output of D for fake images
        G = self.generator(self.z, is_training=True, reuse=False)
        D_fake, D_fake_logits, _ = self.discriminator(G, is_training=True, reuse=True)

        # get loss for discriminator
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake)))

        self.d_loss = d_loss_real + d_loss_fake

        # get loss for generator
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake)))

        # ---- DRAGAN Loss (Gradient penalty) ----
        # This is borrowed from https://github.com/kodalinaveen3/DRAGAN/blob/master/DRAGAN.ipynb
        alpha = tf.random_uniform(shape=self.inputs.get_shape(), minval=0., maxval=1.)
        differences = self.inputs_p - self.inputs  # This is different from WGAN-GP
        interpolates = self.inputs + (alpha * differences)
        _, D_inter, _ = self.discriminator(interpolates, is_training=True, reuse=True)
        gradients = tf.gradients(D_inter, [interpolates])[0]
        # The gradient has the image shape [bs, H, W, C]; reduce over every non-batch
        # axis so there is exactly one gradient norm per sample.
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
        gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
        self.d_loss += self.lambd * gradient_penalty

        # ---- Training ----
        # divide trainable variables into a group for D and a group for G
        # (relies on the 'd_' / 'g_' layer-name prefixes used above)
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        # optimizers
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
                .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
                .minimize(self.g_loss, var_list=g_vars)

        # ---- Testing ----
        # generator in inference mode, sharing the trained variables
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        # ---- Summary ----
        d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

        # final summary operations
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])

    def train(self):
        """Alternate D and G updates, resuming from the latest checkpoint when one exists."""
        # initialize all variables (run on self.sess rather than an implicit default session)
        self.sess.run(tf.global_variables_initializer())

        # fixed noise used to visualize training progress
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch, start_batch_id = self._resume_position(checkpoint_counter, self.num_batches)
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            # NOTE(review): the step counter starts at 1, so a checkpoint written at the
            # end of an epoch resumes at batch 1 of the next epoch - confirm this is intended.
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx * self.batch_size:(idx + 1) * self.batch_size]
                batch_images_p = self.get_perturbed_batch(batch_images)
                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update D network
                _, summary_str, d_loss = self.sess.run(
                    [self.d_optim, self.d_sum, self.d_loss],
                    feed_dict={self.inputs: batch_images, self.inputs_p: batch_images_p, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # update G network
                _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss],
                                                       feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={self.z: self.sample_z})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = manifold_h
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                self._result_path(
                                    self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, idx)))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def visualize_results(self, epoch):
        """Save one grid of images generated from fresh random noise for ``epoch``."""
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random noise
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    self._result_path(self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png'))

    @property
    def model_dir(self):
        """Directory name encoding model name, dataset, batch size and noise dimension."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint for ``step`` under ``checkpoint_dir/model_dir/model_name``."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint, if any.

        :return: ``(True, step)`` after a successful restore, else ``(False, 0)``
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = self._parse_step(ckpt_name)
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    # ---- private helpers ----

    def _result_path(self, filename):
        """Return ``filename`` inside this model's result folder (the folder is passed through check_folder)."""
        return os.path.join(check_folder(os.path.join(self.result_dir, self.model_dir)), filename)

    @staticmethod
    def _parse_step(ckpt_name):
        """Return the last run of digits in ``ckpt_name`` as an int (the saved global step)."""
        import re
        match = re.search(r"(\d+)(?!.*\d)", ckpt_name)
        if match is None:
            raise ValueError("no step number in checkpoint name: {!r}".format(ckpt_name))
        return int(match.group(0))

    @staticmethod
    def _resume_position(checkpoint_counter, num_batches):
        """Split a global step into ``(start_epoch, start_batch_id)``."""
        start_epoch = checkpoint_counter // num_batches
        return start_epoch, checkpoint_counter - start_epoch * num_batches
class EBGAN(object):
    """Energy-Based GAN for 28x28 single-channel MNIST / Fashion-MNIST images.

    The discriminator is an auto-encoder whose reconstruction error is used
    as the energy; the generator is additionally regularised with a
    pulling-away term computed on the discriminator's latent code.

    Typical use (as far as this class shows): construct, call
    ``build_model()`` to create the graph, then ``train()`` inside a default
    TensorFlow session.
    """
    model_name = "EBGAN"     # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        """Store the run configuration and load the dataset.

        :param sess: TensorFlow session used for every ``run`` / restore / save call.
        :param epoch: total number of epochs to train for.
        :param batch_size: number of samples per step (also baked into the graph shapes).
        :param z_dim: dimension of the generator's noise vector.
        :param dataset_name: ``'mnist'`` or ``'fashion-mnist'``.
        :param checkpoint_dir: root directory for checkpoints.
        :param result_dir: root directory for generated sample images.
        :param log_dir: root directory for TensorBoard summaries.
        :raises NotImplementedError: for any other ``dataset_name``.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name == 'mnist' or dataset_name == 'fashion-mnist':
            # parameters
            self.input_height = 28
            self.input_width = 28
            self.output_height = 28
            self.output_width = 28

            self.z_dim = z_dim         # dimension of noise-vector
            self.c_dim = 1

            # EBGAN Parameter
            self.pt_loss_weight = 0.1
            self.margin = max(1, self.batch_size / 64.)  # margin for loss function
            # usually margin of 1 is enough, but for large batch size it must be larger than 1

            # train
            self.learning_rate = 0.0002
            self.beta1 = 0.5

            # test
            self.sample_num = 64  # number of generated images to be saved

            # load mnist
            # NOTE(review): load_mnist comes from the star-imported helpers; it is
            # presumably returning (images, labels) arrays -- confirm in utils.
            self.data_X, self.data_y = load_mnist(self.dataset_name)

            # get number of batches for a single epoch
            self.num_batches = len(self.data_X) // self.batch_size
        else:
            raise NotImplementedError(
                "dataset '{}' is not supported (expected 'mnist' or 'fashion-mnist')".format(dataset_name))

    @staticmethod
    def _parse_step(ckpt_name):
        """Return the last run of digits in *ckpt_name* as the saved global step."""
        import re
        # Raw string: "\d" in a normal literal is an invalid escape sequence.
        match = re.search(r"(\d+)(?!.*\d)", ckpt_name)
        if match is None:
            raise ValueError("no step number found in checkpoint name {!r}".format(ckpt_name))
        return int(match.group(0))

    @staticmethod
    def _resume_position(checkpoint_counter, num_batches):
        """Map a restored step counter to ``(start_epoch, start_batch_id)``.

        ``train`` starts its counter at 1 and increments it after every step,
        so the number of completed steps is ``checkpoint_counter - 1``.  Using
        the raw counter would skip one batch each time training is resumed.
        """
        steps_done = max(checkpoint_counter - 1, 0)
        return divmod(steps_done, num_batches)

    # borrowed from https://github.com/shekkizh/EBGAN.tensorflow/blob/master/EBGAN/Faces_EBGAN.py
    def pullaway_loss(self, embeddings):
        """
        Pull Away loss calculation
        :param embeddings: The embeddings to be orthogonalized for varied faces. Shape [batch_size, embeddings_dim]
        :return: pull away term loss
        """
        norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
        normalized_embeddings = embeddings / norm
        similarity = tf.matmul(
            normalized_embeddings, normalized_embeddings, transpose_b=True)
        batch_size = tf.cast(tf.shape(embeddings)[0], tf.float32)
        # The pulling-away term averages the *squared* cosine similarity over
        # all off-diagonal pairs, so that its minimum is at orthogonal codes.
        # The diagonal entries are 1 (also after squaring), hence "- batch_size".
        pt_loss = (tf.reduce_sum(tf.square(similarity)) - batch_size) / (batch_size * (batch_size - 1))
        return pt_loss

    def discriminator(self, x, is_training=True, reuse=False):
        """Auto-encoder discriminator.

        :param x: batch of images to reconstruct.
        :param is_training: forwarded to the batch-norm layer.
        :param reuse: whether to reuse the variables of the "discriminator" scope.
        :return: ``(reconstruction, recon_error, code)`` where ``code`` is the
            32-unit bottleneck fed to ``pullaway_loss``.
        """
        # It must be Auto-Encoder style architecture
        # Architecture : (64)4c2s-FC32-FC64*14*14_BR-(1)4dc2s_S
        with tf.variable_scope("discriminator", reuse=reuse):

            net = tf.nn.relu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = tf.reshape(net, [self.batch_size, -1])
            code = (linear(net, 32, scope='d_fc6'))  # bn and relu are excluded since code is used in pullaway_loss
            net = tf.nn.relu(bn(linear(code, 64 * 14 * 14, scope='d_fc3'), is_training=is_training, scope='d_bn3'))
            net = tf.reshape(net, [self.batch_size, 14, 14, 64])
            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='d_dc5'))

            # recon loss: tf.nn.l2_loss is sum(t ** 2) / 2, so this is the L2 norm
            # of the whole-batch residual divided by the batch size
            recon_error = tf.sqrt(2 * tf.nn.l2_loss(out - x)) / self.batch_size
            return out, recon_error, code

    def generator(self, z, is_training=True, reuse=False):
        """Map noise ``z`` to a batch of 28x28x1 images squashed by a sigmoid.

        :param z: noise batch.
        :param is_training: forwarded to the batch-norm layers.
        :param reuse: whether to reuse the variables of the "generator" scope.
        """
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
        with tf.variable_scope("generator", reuse=reuse):
            net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training,
                   scope='g_bn3'))

            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))

            return out

    def build_model(self):
        """Create placeholders, losses, optimizers, the sampling op and summaries."""
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # ---- Graph Input ----
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')

        # noises
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # ---- Loss Function ----

        # output of D for real images
        D_real_img, D_real_err, D_real_code = self.discriminator(self.inputs, is_training=True, reuse=False)

        # output of D for fake images
        G = self.generator(self.z, is_training=True, reuse=False)
        D_fake_img, D_fake_err, D_fake_code = self.discriminator(G, is_training=True, reuse=True)

        # get loss for discriminator: low energy on real, at least `margin` on fake
        self.d_loss = D_real_err + tf.maximum(self.margin - D_fake_err, 0)

        # get loss for generator: low energy on fake plus the pulling-away regulariser
        self.g_loss = D_fake_err + self.pt_loss_weight * self.pullaway_loss(D_fake_code)

        # ---- Training ----
        # divide trainable variables into a group for D and a group for G
        # (relies on the 'd_' / 'g_' prefixes used for every layer name above)
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        # optimizers; run the UPDATE_OPS collection together with each train step
        # (presumably the batch-norm statistics updates registered by `bn` -- confirm in ops)
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
                .minimize(self.g_loss, var_list=g_vars)

        # ---- Testing ----
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        # ---- Summary ----
        d_loss_real_sum = tf.summary.scalar("d_error_real", D_real_err)
        d_loss_fake_sum = tf.summary.scalar("d_error_fake", D_fake_err)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

        # final summary operations
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])

    def train(self):
        """Run the training loop, resuming from the latest checkpoint if one exists.

        Saves a sample grid every 300 steps, and a checkpoint plus a test grid
        after every epoch.
        """
        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualize training results
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exits
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch, start_batch_id = self._resume_position(checkpoint_counter, self.num_batches)
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx * self.batch_size:(idx + 1) * self.batch_size]
                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update D network
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                                       feed_dict={self.inputs: batch_images, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # update G network
                _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss],
                                                       feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={self.z: self.sample_z})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
                                    epoch, idx))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def visualize_results(self, epoch):
        """Generate images from fresh random noise and save them as a square grid."""
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random noise (the model is unconditional)
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

    @property
    def model_dir(self):
        """Sub-directory name that identifies this model / dataset / batch size / z_dim."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint tagged with *step* under ``checkpoint_dir/model_dir/model_name``."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint, if any.

        :return: ``(True, step)`` when a checkpoint was restored, where ``step``
            is parsed from the checkpoint file name; ``(False, 0)`` otherwise.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = self._parse_step(ckpt_name)
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
class GAN(object):
    """Vanilla GAN for 28x28 single-channel MNIST / Fashion-MNIST images.

    Typical use (as far as this class shows): construct, call
    ``build_model()`` to create the graph, then ``train()`` inside a default
    TensorFlow session.
    """
    model_name = "GAN"     # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        """Store the run configuration and load the dataset.

        :param sess: TensorFlow session used for every ``run`` / restore / save call.
        :param epoch: total number of epochs to train for.
        :param batch_size: number of samples per step (also baked into the graph shapes).
        :param z_dim: dimension of the generator's noise vector.
        :param dataset_name: ``'mnist'`` or ``'fashion-mnist'``.
        :param checkpoint_dir: root directory for checkpoints.
        :param result_dir: root directory for generated sample images.
        :param log_dir: root directory for TensorBoard summaries.
        :raises NotImplementedError: for any other ``dataset_name``.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name == 'mnist' or dataset_name == 'fashion-mnist':
            # parameters
            self.input_height = 28
            self.input_width = 28
            self.output_height = 28
            self.output_width = 28

            self.z_dim = z_dim         # dimension of noise-vector
            self.c_dim = 1

            # train
            self.learning_rate = 0.0002
            self.beta1 = 0.5

            # test
            self.sample_num = 64  # number of generated images to be saved

            # load mnist
            # NOTE(review): load_mnist comes from the star-imported helpers; it is
            # presumably returning (images, labels) arrays -- confirm in utils.
            self.data_X, self.data_y = load_mnist(self.dataset_name)

            # get number of batches for a single epoch
            self.num_batches = len(self.data_X) // self.batch_size
        else:
            raise NotImplementedError(
                "dataset '{}' is not supported (expected 'mnist' or 'fashion-mnist')".format(dataset_name))

    @staticmethod
    def _parse_step(ckpt_name):
        """Return the last run of digits in *ckpt_name* as the saved global step."""
        import re
        # Raw string: "\d" in a normal literal is an invalid escape sequence.
        match = re.search(r"(\d+)(?!.*\d)", ckpt_name)
        if match is None:
            raise ValueError("no step number found in checkpoint name {!r}".format(ckpt_name))
        return int(match.group(0))

    @staticmethod
    def _resume_position(checkpoint_counter, num_batches):
        """Map a restored step counter to ``(start_epoch, start_batch_id)``.

        ``train`` starts its counter at 1 and increments it after every step,
        so the number of completed steps is ``checkpoint_counter - 1``.  Using
        the raw counter would skip one batch each time training is resumed.
        """
        steps_done = max(checkpoint_counter - 1, 0)
        return divmod(steps_done, num_batches)

    def discriminator(self, x, is_training=True, reuse=False):
        """Score a batch of images as real or fake.

        :param x: batch of images.
        :param is_training: forwarded to the batch-norm layers.
        :param reuse: whether to reuse the variables of the "discriminator" scope.
        :return: ``(sigmoid output, logit, last hidden activation)``.
        """
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
        with tf.variable_scope("discriminator", reuse=reuse):

            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'), is_training=is_training, scope='d_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3'))
            out_logit = linear(net, 1, scope='d_fc4')
            out = tf.nn.sigmoid(out_logit)

            return out, out_logit, net

    def generator(self, z, is_training=True, reuse=False):
        """Map noise ``z`` to a batch of 28x28x1 images squashed by a sigmoid.

        :param z: noise batch.
        :param is_training: forwarded to the batch-norm layers.
        :param reuse: whether to reuse the variables of the "generator" scope.
        """
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
        with tf.variable_scope("generator", reuse=reuse):
            net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training,
                   scope='g_bn3'))

            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))

            return out

    def build_model(self):
        """Create placeholders, losses, optimizers, the sampling op and summaries."""
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # ---- Graph Input ----
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')

        # noises
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # ---- Loss Function ----

        # output of D for real images
        D_real, D_real_logits, _ = self.discriminator(self.inputs, is_training=True, reuse=False)

        # output of D for fake images
        G = self.generator(self.z, is_training=True, reuse=False)
        D_fake, D_fake_logits, _ = self.discriminator(G, is_training=True, reuse=True)

        # get loss for discriminator: real -> 1, fake -> 0
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake)))

        self.d_loss = d_loss_real + d_loss_fake

        # get loss for generator: fake -> 1
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake)))

        # ---- Training ----
        # divide trainable variables into a group for D and a group for G
        # (relies on the 'd_' / 'g_' prefixes used for every layer name above)
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        # optimizers; run the UPDATE_OPS collection together with each train step
        # (presumably the batch-norm statistics updates registered by `bn` -- confirm in ops)
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
                .minimize(self.g_loss, var_list=g_vars)

        # ---- Testing ----
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        # ---- Summary ----
        d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

        # final summary operations
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])

    def train(self):
        """Run the training loop, resuming from the latest checkpoint if one exists.

        Saves a sample grid every 300 steps, and a checkpoint plus a test grid
        after every epoch.
        """
        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualize training results
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exits
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch, start_batch_id = self._resume_position(checkpoint_counter, self.num_batches)
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx * self.batch_size:(idx + 1) * self.batch_size]
                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update D network
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                                       feed_dict={self.inputs: batch_images, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # update G network
                _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss],
                                                       feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images, feed_dict={self.z: self.sample_z})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
                                    epoch, idx))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def visualize_results(self, epoch):
        """Generate images from fresh random noise and save them as a square grid."""
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random noise (the model is unconditional)
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

    @property
    def model_dir(self):
        """Sub-directory name that identifies this model / dataset / batch size / z_dim."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint tagged with *step* under ``checkpoint_dir/model_dir/model_name``."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint, if any.

        :return: ``(True, step)`` when a checkpoint was restored, where ``step``
            is parsed from the checkpoint file name; ``(False, 0)`` otherwise.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = self._parse_step(ckpt_name)
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 
124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. 
In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. 
class LSGAN(object):
    """Least Squares GAN for 28x28 single-channel images (TensorFlow 1.x graph mode).

    Typical use: construct, call ``build_model()`` to create the graph, then
    ``train()``.  Only the 'mnist' and 'fashion-mnist' datasets are supported.
    """
    model_name = "LSGAN"     # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        """Store the run configuration and load the training data.

        Args:
            sess: TensorFlow session used for every graph execution.
            epoch: number of epochs to train for.
            batch_size: fixed batch size; it is baked into the graph shapes.
            z_dim: dimension of the noise vector fed to the generator.
            dataset_name: 'mnist' or 'fashion-mnist'.
            checkpoint_dir: root directory for checkpoints.
            result_dir: root directory for generated sample images.
            log_dir: root directory for summary logs.

        Raises:
            NotImplementedError: for any other dataset name.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name == 'mnist' or dataset_name == 'fashion-mnist':
            # parameters
            self.input_height = 28
            self.input_width = 28
            self.output_height = 28
            self.output_width = 28

            self.z_dim = z_dim         # dimension of noise-vector
            self.c_dim = 1

            # train
            self.learning_rate = 0.0002
            self.beta1 = 0.5

            # test
            self.sample_num = 64  # number of generated images to be saved

            # load mnist
            self.data_X, self.data_y = load_mnist(self.dataset_name)

            # get number of batches for a single epoch
            self.num_batches = len(self.data_X) // self.batch_size
        else:
            raise NotImplementedError

    def discriminator(self, x, is_training=True, reuse=False):
        """Build the discriminator; return (sigmoid output, logits, last hidden layer)."""
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
        with tf.variable_scope("discriminator", reuse=reuse):

            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'), is_training=is_training, scope='d_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3'))
            out_logit = linear(net, 1, scope='d_fc4')
            out = tf.nn.sigmoid(out_logit)

            return out, out_logit, net

    def generator(self, z, is_training=True, reuse=False):
        """Build the generator mapping noise ``z`` to a [batch_size, 28, 28, 1] image tensor."""
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
        with tf.variable_scope("generator", reuse=reuse):
            net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training,
                   scope='g_bn3'))

            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))

            return out

    def mse_loss(self, pred, data):
        """Return the mean squared error between ``pred`` and ``data`` as a scalar tensor.

        This is the least-squares term of the LSGAN objective.  The earlier
        form, ``sqrt(2 * l2_loss(pred - data)) / batch_size``, was an L2 norm
        scaled by the batch size rather than a mean of squares.
        """
        return tf.reduce_mean(tf.square(pred - data))

    def build_model(self):
        """Create placeholders, losses, optimizers, the test-time generator and summaries."""
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # ---- Graph Input ----
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')

        # noises
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # ---- Loss Function ----
        # output of D for real images (only the logits enter the loss)
        _, D_real_logits, _ = self.discriminator(self.inputs, is_training=True, reuse=False)

        # output of D for fake images
        G = self.generator(self.z, is_training=True, reuse=False)
        _, D_fake_logits, _ = self.discriminator(G, is_training=True, reuse=True)

        # get loss for discriminator: real -> 1, fake -> 0
        d_loss_real = tf.reduce_mean(self.mse_loss(D_real_logits, tf.ones_like(D_real_logits)))
        d_loss_fake = tf.reduce_mean(self.mse_loss(D_fake_logits, tf.zeros_like(D_fake_logits)))

        self.d_loss = 0.5*(d_loss_real + d_loss_fake)

        # get loss for generator: fake -> 1
        self.g_loss = tf.reduce_mean(self.mse_loss(D_fake_logits, tf.ones_like(D_fake_logits)))

        # ---- Training ----
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        # optimizers; UPDATE_OPS carries the batch-norm moving-average updates
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                      .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate*5, beta1=self.beta1) \
                      .minimize(self.g_loss, var_list=g_vars)

        # weight clipping
        # NOTE(review): clipping D's weights comes from WGAN and is not part of the
        # LSGAN formulation; kept because train() applies it -- confirm it is intended.
        self.clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in d_vars]

        # ---- Testing ----
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        # ---- Summary ----
        d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

        # final summary operations
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])

    def _resume_position(self, checkpoint_counter):
        """Map a restored step counter to ``(start_epoch, start_batch_id)``.

        The counter starts at 1 and is incremented after every step, so a
        checkpoint written with counter ``c`` has completed ``c - 1`` steps.
        """
        steps_done = max(int(checkpoint_counter) - 1, 0)
        return steps_done // self.num_batches, steps_done % self.num_batches

    def train(self):
        """Run the training loop, resuming from the latest checkpoint when one exists."""
        # initialize all variables
        self.sess.run(tf.global_variables_initializer())

        # graph inputs for visualize training results (fixed so samples are comparable)
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch, start_batch_id = self._resume_position(checkpoint_counter)
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size]
                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update D network
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                                       feed_dict={self.inputs: batch_images, self.z: batch_z})
                # Clip in a separate run: ops fetched in one run() have no
                # defined order, so clipping could otherwise race the update.
                self.sess.run(self.clip_D)
                self.writer.add_summary(summary_str, counter)

                # update G network
                _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss], feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={self.z: self.sample_z})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    # no './' prefix: it would turn an absolute result_dir into a relative path
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
                                    epoch, idx))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def visualize_results(self, epoch):
        """Generate a square grid of samples from random noise and save it for ``epoch``."""
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random noise
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

    @property
    def model_dir(self):
        """Directory name encoding model name, dataset, batch size and z dimension."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint tagged with ``step`` under ``checkpoint_dir``."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint.

        Returns:
            ``(True, counter)`` with the step parsed from the checkpoint file
            name on success, ``(False, 0)`` when no checkpoint is found.
        """
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # last run of digits in the name is the global step (raw string: '\d' is a regex escape)
            counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
class VAE(object):
    """Variational autoencoder for 28x28 single-channel images (TensorFlow 1.x graph mode).

    Gaussian encoder, Bernoulli decoder.  Typical use: construct, call
    ``build_model()``, then ``train()``.  Only 'mnist' and 'fashion-mnist'
    are supported.
    """
    model_name = "VAE"     # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        """Store the run configuration and load the training data.

        Args:
            sess: TensorFlow session used for every graph execution.
            epoch: number of epochs to train for.
            batch_size: fixed batch size; it is baked into the graph shapes.
            z_dim: dimension of the latent vector.
            dataset_name: 'mnist' or 'fashion-mnist'.
            checkpoint_dir: root directory for checkpoints.
            result_dir: root directory for generated images.
            log_dir: root directory for summary logs.

        Raises:
            NotImplementedError: for any other dataset name.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name == 'mnist' or dataset_name == 'fashion-mnist':
            # parameters
            self.input_height = 28
            self.input_width = 28
            self.output_height = 28
            self.output_width = 28

            self.z_dim = z_dim         # dimension of noise-vector
            self.c_dim = 1

            # train
            self.learning_rate = 0.0002
            self.beta1 = 0.5

            # test
            self.sample_num = 64  # number of generated images to be saved

            # load mnist
            self.data_X, self.data_y = load_mnist(self.dataset_name)

            # get number of batches for a single epoch
            self.num_batches = len(self.data_X) // self.batch_size
        else:
            raise NotImplementedError

    # Gaussian Encoder
    def encoder(self, x, is_training=True, reuse=False):
        """Encode images ``x``; return ``(mean, stddev)`` of the latent Gaussian, each [batch, z_dim]."""
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC(2*z_dim)
        with tf.variable_scope("encoder", reuse=reuse):

            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='en_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='en_conv2'), is_training=is_training, scope='en_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='en_fc3'), is_training=is_training, scope='en_bn3'))
            gaussian_params = linear(net, 2 * self.z_dim, scope='en_fc4')

            # The mean parameter is unconstrained
            mean = gaussian_params[:, :self.z_dim]
            # The standard deviation must be positive. Parametrize with a softplus and
            # add a small epsilon for numerical stability
            stddev = 1e-6 + tf.nn.softplus(gaussian_params[:, self.z_dim:])

        return mean, stddev

    # Bernoulli decoder
    def decoder(self, z, is_training=True, reuse=False):
        """Decode latent ``z`` into a [batch_size, 28, 28, 1] tensor of pixel probabilities."""
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
        with tf.variable_scope("decoder", reuse=reuse):
            net = tf.nn.relu(bn(linear(z, 1024, scope='de_fc1'), is_training=is_training, scope='de_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='de_fc2'), is_training=is_training, scope='de_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='de_dc3'), is_training=is_training,
                   scope='de_bn3'))

            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='de_dc4'))
            return out

    def build_model(self):
        """Create placeholders, the negative-ELBO loss, the optimizer, the test decoder and summaries."""
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # ---- Graph Input ----
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')

        # noises
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # ---- Loss Function ----
        # encoding
        self.mu, sigma = self.encoder(self.inputs, is_training=True, reuse=False)

        # sampling by re-parameterization technique
        z = self.mu + sigma * tf.random_normal(tf.shape(self.mu), 0, 1, dtype=tf.float32)

        # decoding; clip so the logs below stay finite
        out = self.decoder(z, is_training=True, reuse=False)
        self.out = tf.clip_by_value(out, 1e-8, 1 - 1e-8)

        # loss: Bernoulli log-likelihood summed over every pixel and channel.
        # (Summing over [1, 2, 3] equals the former [1, 2] sum when c_dim == 1.)
        marginal_likelihood = tf.reduce_sum(self.inputs * tf.log(self.out) + (1 - self.inputs) * tf.log(1 - self.out),
                                            [1, 2, 3])
        KL_divergence = 0.5 * tf.reduce_sum(tf.square(self.mu) + tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1, [1])

        self.neg_loglikelihood = -tf.reduce_mean(marginal_likelihood)
        self.KL_divergence = tf.reduce_mean(KL_divergence)

        ELBO = -self.neg_loglikelihood - self.KL_divergence

        self.loss = -ELBO

        # ---- Training ----
        # optimizers; UPDATE_OPS carries the batch-norm moving-average updates
        t_vars = tf.trainable_variables()
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.optim = tf.train.AdamOptimizer(self.learning_rate*5, beta1=self.beta1) \
                      .minimize(self.loss, var_list=t_vars)

        # ---- Testing ----
        # for test
        self.fake_images = self.decoder(self.z, is_training=False, reuse=True)

        # ---- Summary ----
        tf.summary.scalar("nll", self.neg_loglikelihood)
        tf.summary.scalar("kl", self.KL_divergence)
        tf.summary.scalar("loss", self.loss)

        # final summary operations
        self.merged_summary_op = tf.summary.merge_all()

    def _resume_position(self, checkpoint_counter):
        """Map a restored step counter to ``(start_epoch, start_batch_id)``.

        The counter starts at 1 and is incremented after every step, so a
        checkpoint written with counter ``c`` has completed ``c - 1`` steps.
        """
        steps_done = max(int(checkpoint_counter) - 1, 0)
        return steps_done // self.num_batches, steps_done % self.num_batches

    def train(self):
        """Run the training loop, resuming from the latest checkpoint when one exists."""
        # initialize all variables
        self.sess.run(tf.global_variables_initializer())

        # graph inputs for visualize training results (fixed so samples are comparable)
        self.sample_z = prior.gaussian(self.batch_size, self.z_dim)

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch, start_batch_id = self._resume_position(checkpoint_counter)
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size]
                batch_z = prior.gaussian(self.batch_size, self.z_dim)

                # update autoencoder
                _, summary_str, loss, nll_loss, kl_loss = self.sess.run([self.optim, self.merged_summary_op, self.loss, self.neg_loglikelihood, self.KL_divergence],
                                                                        feed_dict={self.inputs: batch_images, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f, nll: %.8f, kl: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, loss, nll_loss, kl_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={self.z: self.sample_z})

                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    # no './' prefix: it would turn an absolute result_dir into a relative path
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
                                    epoch, idx))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def visualize_results(self, epoch):
        """Save a grid of decoded prior samples and, when ``z_dim == 2``, a scatter of encoded means."""
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random noise
        z_sample = prior.gaussian(self.batch_size, self.z_dim)

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(
                        self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

        # learned manifold (only meaningful for a 2-D latent space)
        if self.z_dim == 2:
            z_chunks = []
            label_chunks = []
            for _ in range(100):
                # randomly pick one training batch and encode it
                batch_id = np.random.randint(0, self.num_batches)
                lo = batch_id * self.batch_size
                hi = lo + self.batch_size
                batch_images = self.data_X[lo:hi]
                batch_labels = self.data_y[lo:hi]

                z_chunks.append(self.sess.run(self.mu, feed_dict={self.inputs: batch_images}))
                label_chunks.append(batch_labels)

            # concatenate once instead of re-copying the arrays on every iteration
            z_tot = np.concatenate(z_chunks, axis=0)
            id_tot = np.concatenate(label_chunks, axis=0)

            save_scattered_image(z_tot, id_tot, -4, 4, name=check_folder(
                self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_learned_manifold.png')

    @property
    def model_dir(self):
        """Directory name encoding model name, dataset, batch size and z dimension."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint tagged with ``step`` under ``checkpoint_dir``."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint.

        Returns:
            ``(True, counter)`` with the step parsed from the checkpoint file
            name on success, ``(False, 0)`` when no checkpoint is found.
        """
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # last run of digits in the name is the global step (raw string: '\d' is a regex escape)
            counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
self.sample_num = 64 # number of generated images to be saved 42 | 43 | # load mnist 44 | self.data_X, self.data_y = load_mnist(self.dataset_name) 45 | 46 | # get number of batches for a single epoch 47 | self.num_batches = len(self.data_X) // self.batch_size 48 | else: 49 | raise NotImplementedError 50 | 51 | def discriminator(self, x, is_training=True, reuse=False): 52 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 53 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S 54 | with tf.variable_scope("discriminator", reuse=reuse): 55 | 56 | net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1')) 57 | net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'), is_training=is_training, scope='d_bn2')) 58 | net = tf.reshape(net, [self.batch_size, -1]) 59 | net = lrelu(bn(linear(net, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3')) 60 | out_logit = linear(net, 1, scope='d_fc4') 61 | out = tf.nn.sigmoid(out_logit) 62 | 63 | return out, out_logit, net 64 | 65 | def generator(self, z, is_training=True, reuse=False): 66 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 67 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 68 | with tf.variable_scope("generator", reuse=reuse): 69 | net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1')) 70 | net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2')) 71 | net = tf.reshape(net, [self.batch_size, 7, 7, 128]) 72 | net = tf.nn.relu( 73 | bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training, 74 | scope='g_bn3')) 75 | 76 | out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4')) 77 | 78 | return out 79 | 80 | def build_model(self): 81 | # some parameters 82 | image_dims = [self.input_height, self.input_width, self.c_dim] 83 | bs = self.batch_size 84 | 85 | """ 
Graph Input """ 86 | # images 87 | self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images') 88 | 89 | # noises 90 | self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z') 91 | 92 | """ Loss Function """ 93 | 94 | # output of D for real images 95 | D_real, D_real_logits, _ = self.discriminator(self.inputs, is_training=True, reuse=False) 96 | 97 | # output of D for fake images 98 | G = self.generator(self.z, is_training=True, reuse=False) 99 | D_fake, D_fake_logits, _ = self.discriminator(G, is_training=True, reuse=True) 100 | 101 | # get loss for discriminator 102 | d_loss_real = - tf.reduce_mean(D_real_logits) 103 | d_loss_fake = tf.reduce_mean(D_fake_logits) 104 | 105 | self.d_loss = d_loss_real + d_loss_fake 106 | 107 | # get loss for generator 108 | self.g_loss = - d_loss_fake 109 | 110 | """ Training """ 111 | # divide trainable variables into a group for D and a group for G 112 | t_vars = tf.trainable_variables() 113 | d_vars = [var for var in t_vars if 'd_' in var.name] 114 | g_vars = [var for var in t_vars if 'g_' in var.name] 115 | 116 | # optimizers 117 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 118 | self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \ 119 | .minimize(self.d_loss, var_list=d_vars) 120 | self.g_optim = tf.train.AdamOptimizer(self.learning_rate*5, beta1=self.beta1) \ 121 | .minimize(self.g_loss, var_list=g_vars) 122 | 123 | # weight clipping 124 | self.clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in d_vars] 125 | 126 | """" Testing """ 127 | # for test 128 | self.fake_images = self.generator(self.z, is_training=False, reuse=True) 129 | 130 | """ Summary """ 131 | d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real) 132 | d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake) 133 | d_loss_sum = tf.summary.scalar("d_loss", self.d_loss) 134 | g_loss_sum = tf.summary.scalar("g_loss", self.g_loss) 135 | 136 | # final 
summary operations 137 | self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum]) 138 | self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum]) 139 | 140 | def train(self): 141 | 142 | # initialize all variables 143 | tf.global_variables_initializer().run() 144 | 145 | # graph inputs for visualize training results 146 | self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size , self.z_dim)) 147 | 148 | # saver to save model 149 | self.saver = tf.train.Saver() 150 | 151 | # summary writer 152 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph) 153 | 154 | # restore check-point if it exits 155 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 156 | if could_load: 157 | start_epoch = (int)(checkpoint_counter / self.num_batches) 158 | start_batch_id = checkpoint_counter - start_epoch * self.num_batches 159 | counter = checkpoint_counter 160 | print(" [*] Load SUCCESS") 161 | else: 162 | start_epoch = 0 163 | start_batch_id = 0 164 | counter = 1 165 | print(" [!] 
Load failed...") 166 | 167 | # loop for epoch 168 | start_time = time.time() 169 | for epoch in range(start_epoch, self.epoch): 170 | 171 | # get batch data 172 | for idx in range(start_batch_id, self.num_batches): 173 | batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size] 174 | batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32) 175 | 176 | # update D network 177 | _, _, summary_str, d_loss = self.sess.run([self.d_optim, self.clip_D, self.d_sum, self.d_loss], 178 | feed_dict={self.inputs: batch_images, self.z: batch_z}) 179 | self.writer.add_summary(summary_str, counter) 180 | 181 | # update G network 182 | if (counter - 1) % self.disc_iters == 0: 183 | _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss], feed_dict={self.z: batch_z}) 184 | self.writer.add_summary(summary_str, counter) 185 | 186 | # display training status 187 | counter += 1 188 | print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ 189 | % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss)) 190 | 191 | # save training results for every 300 steps 192 | if np.mod(counter, 300) == 0: 193 | samples = self.sess.run(self.fake_images, 194 | feed_dict={self.z: self.sample_z}) 195 | tot_num_samples = min(self.sample_num, self.batch_size) 196 | manifold_h = int(np.floor(np.sqrt(tot_num_samples))) 197 | manifold_w = int(np.floor(np.sqrt(tot_num_samples))) 198 | save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w], 199 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format( 200 | epoch, idx)) 201 | 202 | # After an epoch, start_batch_id is set to zero 203 | # non-zero value is only for the first epoch after loading pre-trained model 204 | start_batch_id = 0 205 | 206 | # save model 207 | self.save(self.checkpoint_dir, counter) 208 | 209 | # show temporal results 210 | self.visualize_results(epoch) 
211 | 212 | # save model for final step 213 | self.save(self.checkpoint_dir, counter) 214 | 215 | def visualize_results(self, epoch): 216 | tot_num_samples = min(self.sample_num, self.batch_size) 217 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 218 | 219 | """ random condition, random noise """ 220 | 221 | z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim)) 222 | 223 | samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample}) 224 | 225 | save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 226 | check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png') 227 | 228 | @property 229 | def model_dir(self): 230 | return "{}_{}_{}_{}".format( 231 | self.model_name, self.dataset_name, 232 | self.batch_size, self.z_dim) 233 | 234 | def save(self, checkpoint_dir, step): 235 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name) 236 | 237 | if not os.path.exists(checkpoint_dir): 238 | os.makedirs(checkpoint_dir) 239 | 240 | self.saver.save(self.sess,os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step) 241 | 242 | def load(self, checkpoint_dir): 243 | import re 244 | print(" [*] Reading checkpoints...") 245 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name) 246 | 247 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 248 | if ckpt and ckpt.model_checkpoint_path: 249 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 250 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 251 | counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0)) 252 | print(" [*] Success to read {}".format(ckpt_name)) 253 | return True, counter 254 | else: 255 | print(" [*] Failed to find a checkpoint") 256 | return False, 0 -------------------------------------------------------------------------------- /WGAN_GP.py: 
-------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | from __future__ import division 3 | import os 4 | import time 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | from ops import * 9 | from utils import * 10 | 11 | class WGAN_GP(object): 12 | model_name = "WGAN_GP" # name for checkpoint 13 | 14 | def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir): 15 | self.sess = sess 16 | self.dataset_name = dataset_name 17 | self.checkpoint_dir = checkpoint_dir 18 | self.result_dir = result_dir 19 | self.log_dir = log_dir 20 | self.epoch = epoch 21 | self.batch_size = batch_size 22 | 23 | if dataset_name == 'mnist' or dataset_name == 'fashion-mnist': 24 | # parameters 25 | self.input_height = 28 26 | self.input_width = 28 27 | self.output_height = 28 28 | self.output_width = 28 29 | 30 | self.z_dim = z_dim # dimension of noise-vector 31 | self.c_dim = 1 32 | 33 | # WGAN_GP parameter 34 | self.lambd = 0.25 # The higher value, the more stable, but the slower convergence 35 | self.disc_iters = 1 # The number of critic iterations for one-step of generator 36 | 37 | # train 38 | self.learning_rate = 0.0002 39 | self.beta1 = 0.5 40 | 41 | # test 42 | self.sample_num = 64 # number of generated images to be saved 43 | 44 | # load mnist 45 | self.data_X, self.data_y = load_mnist(self.dataset_name) 46 | 47 | # get number of batches for a single epoch 48 | self.num_batches = len(self.data_X) // self.batch_size 49 | else: 50 | raise NotImplementedError 51 | 52 | def discriminator(self, x, is_training=True, reuse=False): 53 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 54 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S 55 | with tf.variable_scope("discriminator", reuse=reuse): 56 | 57 | net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1')) 58 | net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'), 
is_training=is_training, scope='d_bn2')) 59 | net = tf.reshape(net, [self.batch_size, -1]) 60 | net = lrelu(bn(linear(net, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3')) 61 | out_logit = linear(net, 1, scope='d_fc4') 62 | out = tf.nn.sigmoid(out_logit) 63 | 64 | return out, out_logit, net 65 | 66 | def generator(self, z, is_training=True, reuse=False): 67 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 68 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 69 | with tf.variable_scope("generator", reuse=reuse): 70 | net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1')) 71 | net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2')) 72 | net = tf.reshape(net, [self.batch_size, 7, 7, 128]) 73 | net = tf.nn.relu( 74 | bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training, 75 | scope='g_bn3')) 76 | 77 | out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4')) 78 | 79 | return out 80 | 81 | def build_model(self): 82 | # some parameters 83 | image_dims = [self.input_height, self.input_width, self.c_dim] 84 | bs = self.batch_size 85 | 86 | """ Graph Input """ 87 | # images 88 | self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images') 89 | 90 | # noises 91 | self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z') 92 | 93 | """ Loss Function """ 94 | 95 | # output of D for real images 96 | D_real, D_real_logits, _ = self.discriminator(self.inputs, is_training=True, reuse=False) 97 | 98 | # output of D for fake images 99 | G = self.generator(self.z, is_training=True, reuse=False) 100 | D_fake, D_fake_logits, _ = self.discriminator(G, is_training=True, reuse=True) 101 | 102 | # get loss for discriminator 103 | d_loss_real = - tf.reduce_mean(D_real_logits) 104 | d_loss_fake = tf.reduce_mean(D_fake_logits) 105 | 106 | 
self.d_loss = d_loss_real + d_loss_fake 107 | 108 | # get loss for generator 109 | self.g_loss = - d_loss_fake 110 | 111 | """ Gradient Penalty """ 112 | # This is borrowed from https://github.com/kodalinaveen3/DRAGAN/blob/master/DRAGAN.ipynb 113 | alpha = tf.random_uniform(shape=self.inputs.get_shape(), minval=0.,maxval=1.) 114 | differences = G - self.inputs # This is different from MAGAN 115 | interpolates = self.inputs + (alpha * differences) 116 | _,D_inter,_=self.discriminator(interpolates, is_training=True, reuse=True) 117 | gradients = tf.gradients(D_inter, [interpolates])[0] 118 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) 119 | gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2) 120 | self.d_loss += self.lambd * gradient_penalty 121 | 122 | """ Training """ 123 | # divide trainable variables into a group for D and a group for G 124 | t_vars = tf.trainable_variables() 125 | d_vars = [var for var in t_vars if 'd_' in var.name] 126 | g_vars = [var for var in t_vars if 'g_' in var.name] 127 | 128 | # optimizers 129 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 130 | self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \ 131 | .minimize(self.d_loss, var_list=d_vars) 132 | self.g_optim = tf.train.AdamOptimizer(self.learning_rate*5, beta1=self.beta1) \ 133 | .minimize(self.g_loss, var_list=g_vars) 134 | 135 | """" Testing """ 136 | # for test 137 | self.fake_images = self.generator(self.z, is_training=False, reuse=True) 138 | 139 | """ Summary """ 140 | d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real) 141 | d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake) 142 | d_loss_sum = tf.summary.scalar("d_loss", self.d_loss) 143 | g_loss_sum = tf.summary.scalar("g_loss", self.g_loss) 144 | 145 | # final summary operations 146 | self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum]) 147 | self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum]) 
148 | 149 | def train(self): 150 | 151 | # initialize all variables 152 | tf.global_variables_initializer().run() 153 | 154 | # graph inputs for visualize training results 155 | self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size , self.z_dim)) 156 | 157 | # saver to save model 158 | self.saver = tf.train.Saver() 159 | 160 | # summary writer 161 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph) 162 | 163 | # restore check-point if it exits 164 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 165 | if could_load: 166 | start_epoch = (int)(checkpoint_counter / self.num_batches) 167 | start_batch_id = checkpoint_counter - start_epoch * self.num_batches 168 | counter = checkpoint_counter 169 | print(" [*] Load SUCCESS") 170 | else: 171 | start_epoch = 0 172 | start_batch_id = 0 173 | counter = 1 174 | print(" [!] Load failed...") 175 | 176 | # loop for epoch 177 | start_time = time.time() 178 | for epoch in range(start_epoch, self.epoch): 179 | 180 | # get batch data 181 | for idx in range(start_batch_id, self.num_batches): 182 | batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size] 183 | batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32) 184 | 185 | # update D network 186 | _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss], 187 | feed_dict={self.inputs: batch_images, self.z: batch_z}) 188 | self.writer.add_summary(summary_str, counter) 189 | 190 | # update G network 191 | if (counter-1) % self.disc_iters == 0: 192 | batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32) 193 | _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss], feed_dict={self.z: batch_z}) 194 | self.writer.add_summary(summary_str, counter) 195 | 196 | counter += 1 197 | 198 | # display training status 199 | print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ 200 | % (epoch, 
idx, self.num_batches, time.time() - start_time, d_loss, g_loss)) 201 | 202 | # save training results for every 300 steps 203 | if np.mod(counter, 300) == 0: 204 | samples = self.sess.run(self.fake_images, 205 | feed_dict={self.z: self.sample_z}) 206 | tot_num_samples = min(self.sample_num, self.batch_size) 207 | manifold_h = int(np.floor(np.sqrt(tot_num_samples))) 208 | manifold_w = int(np.floor(np.sqrt(tot_num_samples))) 209 | save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w], 210 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format( 211 | epoch, idx)) 212 | 213 | # After an epoch, start_batch_id is set to zero 214 | # non-zero value is only for the first epoch after loading pre-trained model 215 | start_batch_id = 0 216 | 217 | # save model 218 | self.save(self.checkpoint_dir, counter) 219 | 220 | # show temporal results 221 | self.visualize_results(epoch) 222 | 223 | # save model for final step 224 | self.save(self.checkpoint_dir, counter) 225 | 226 | def visualize_results(self, epoch): 227 | tot_num_samples = min(self.sample_num, self.batch_size) 228 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 229 | 230 | """ random condition, random noise """ 231 | 232 | z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim)) 233 | 234 | samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample}) 235 | 236 | save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 237 | check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png') 238 | 239 | @property 240 | def model_dir(self): 241 | return "{}_{}_{}_{}".format( 242 | self.model_name, self.dataset_name, 243 | self.batch_size, self.z_dim) 244 | 245 | def save(self, checkpoint_dir, step): 246 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name) 247 | 
248 | if not os.path.exists(checkpoint_dir): 249 | os.makedirs(checkpoint_dir) 250 | 251 | self.saver.save(self.sess,os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step) 252 | 253 | def load(self, checkpoint_dir): 254 | import re 255 | print(" [*] Reading checkpoints...") 256 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name) 257 | 258 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 259 | if ckpt and ckpt.model_checkpoint_path: 260 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 261 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 262 | counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0)) 263 | print(" [*] Success to read {}".format(ckpt_name)) 264 | return True, counter 265 | else: 266 | print(" [*] Failed to find a checkpoint") 267 | return False, 0 -------------------------------------------------------------------------------- /assets/equations/ACGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/equations/ACGAN.png -------------------------------------------------------------------------------- /assets/equations/BEGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/equations/BEGAN.png -------------------------------------------------------------------------------- /assets/equations/CGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/equations/CGAN.png -------------------------------------------------------------------------------- 
/assets/equations/CVAE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/equations/CVAE.png -------------------------------------------------------------------------------- /assets/equations/DRAGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/equations/DRAGAN.png -------------------------------------------------------------------------------- /assets/equations/EBGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/equations/EBGAN.png -------------------------------------------------------------------------------- /assets/equations/GAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/equations/GAN.png -------------------------------------------------------------------------------- /assets/equations/LSGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/equations/LSGAN.png -------------------------------------------------------------------------------- /assets/equations/VAE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/equations/VAE.png 
-------------------------------------------------------------------------------- /assets/equations/WGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/equations/WGAN.png -------------------------------------------------------------------------------- /assets/equations/WGAN_GP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/equations/WGAN_GP.png -------------------------------------------------------------------------------- /assets/equations/infoGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/equations/infoGAN.png -------------------------------------------------------------------------------- /assets/etc/GAN_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/etc/GAN_structure.png -------------------------------------------------------------------------------- /assets/etc/VAE_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/etc/VAE_structure.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/conditional_generation/ACGAN_epoch000_test_all_classes_style_by_style.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/conditional_generation/ACGAN_epoch000_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/conditional_generation/ACGAN_epoch009_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/conditional_generation/ACGAN_epoch009_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/conditional_generation/ACGAN_epoch019_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/conditional_generation/ACGAN_epoch019_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/conditional_generation/ACGAN_epoch039_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/conditional_generation/ACGAN_epoch039_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/conditional_generation/CGAN_epoch000_test_all_classes_style_by_style.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/conditional_generation/CGAN_epoch000_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/conditional_generation/CGAN_epoch019_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/conditional_generation/CGAN_epoch019_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/conditional_generation/CGAN_epoch039_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/conditional_generation/CGAN_epoch039_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/conditional_generation/CVAE_epoch000_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/conditional_generation/CVAE_epoch000_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/conditional_generation/CVAE_epoch019_test_all_classes_style_by_style.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/conditional_generation/CVAE_epoch019_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/conditional_generation/CVAE_epoch039_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/conditional_generation/CVAE_epoch039_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/conditional_generation/infoGAN_epoch000_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/conditional_generation/infoGAN_epoch000_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/conditional_generation/infoGAN_epoch019_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/conditional_generation/infoGAN_epoch019_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/conditional_generation/infoGAN_epoch039_test_all_classes_style_by_style.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/conditional_generation/infoGAN_epoch039_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/infogan/infoGAN_epoch039_test_class_c1c2_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/infogan/infoGAN_epoch039_test_class_c1c2_1.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/infogan/infoGAN_epoch039_test_class_c1c2_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/infogan/infoGAN_epoch039_test_class_c1c2_4.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/infogan/infoGAN_epoch039_test_class_c1c2_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/infogan/infoGAN_epoch039_test_class_c1c2_5.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/infogan/infoGAN_epoch039_test_class_c1c2_8.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/infogan/infoGAN_epoch039_test_class_c1c2_8.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/learned_manifold/VAE_epoch000_learned_manifold.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/learned_manifold/VAE_epoch000_learned_manifold.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/learned_manifold/VAE_epoch009_learned_manifold.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/learned_manifold/VAE_epoch009_learned_manifold.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/learned_manifold/VAE_epoch024_learned_manifold.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/learned_manifold/VAE_epoch024_learned_manifold.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/BEGAN_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/BEGAN_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/BEGAN_epoch019_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/BEGAN_epoch019_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/BEGAN_epoch039_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/BEGAN_epoch039_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/DRAGAN_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/DRAGAN_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/DRAGAN_epoch019_test_all_classes.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/DRAGAN_epoch019_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/DRAGAN_epoch039_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/DRAGAN_epoch039_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/EBGAN_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/EBGAN_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/EBGAN_epoch019_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/EBGAN_epoch019_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/EBGAN_epoch039_test_all_classes.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/EBGAN_epoch039_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/GAN_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/GAN_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/GAN_epoch019_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/GAN_epoch019_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/GAN_epoch039_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/GAN_epoch039_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/LSGAN_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/LSGAN_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/LSGAN_epoch019_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/LSGAN_epoch019_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/LSGAN_epoch039_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/LSGAN_epoch039_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/VAE_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/VAE_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/VAE_epoch019_test_all_classes.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/VAE_epoch019_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/VAE_epoch039_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/VAE_epoch039_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/WGAN_GP_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/WGAN_GP_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/WGAN_GP_epoch019_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/WGAN_GP_epoch019_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/WGAN_GP_epoch039_test_all_classes.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/WGAN_GP_epoch039_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/WGAN_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/WGAN_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/WGAN_epoch019_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/WGAN_epoch019_test_all_classes.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/random_generation/WGAN_epoch039_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/fashion_mnist_results/random_generation/WGAN_epoch039_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/conditional_generation/ACGAN_epoch000_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/conditional_generation/ACGAN_epoch000_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/mnist_results/conditional_generation/ACGAN_epoch009_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/conditional_generation/ACGAN_epoch009_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/mnist_results/conditional_generation/ACGAN_epoch024_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/conditional_generation/ACGAN_epoch024_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/mnist_results/conditional_generation/CGAN_epoch000_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/conditional_generation/CGAN_epoch000_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/mnist_results/conditional_generation/CGAN_epoch009_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/conditional_generation/CGAN_epoch009_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/mnist_results/conditional_generation/CGAN_epoch024_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/conditional_generation/CGAN_epoch024_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/mnist_results/conditional_generation/CVAE_epoch000_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/conditional_generation/CVAE_epoch000_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/mnist_results/conditional_generation/CVAE_epoch009_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/conditional_generation/CVAE_epoch009_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/mnist_results/conditional_generation/CVAE_epoch024_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/conditional_generation/CVAE_epoch024_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/mnist_results/conditional_generation/infoGAN_epoch000_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/conditional_generation/infoGAN_epoch000_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/mnist_results/conditional_generation/infoGAN_epoch009_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/conditional_generation/infoGAN_epoch009_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/mnist_results/conditional_generation/infoGAN_epoch024_test_all_classes_style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/conditional_generation/infoGAN_epoch024_test_all_classes_style_by_style.png -------------------------------------------------------------------------------- /assets/mnist_results/infogan/infoGAN_epoch024_test_class_c1c2_2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/infogan/infoGAN_epoch024_test_class_c1c2_2.png -------------------------------------------------------------------------------- /assets/mnist_results/infogan/infoGAN_epoch024_test_class_c1c2_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/infogan/infoGAN_epoch024_test_class_c1c2_5.png -------------------------------------------------------------------------------- /assets/mnist_results/infogan/infoGAN_epoch024_test_class_c1c2_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/infogan/infoGAN_epoch024_test_class_c1c2_7.png -------------------------------------------------------------------------------- /assets/mnist_results/infogan/infoGAN_epoch024_test_class_c1c2_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/infogan/infoGAN_epoch024_test_class_c1c2_9.png -------------------------------------------------------------------------------- /assets/mnist_results/learned_manifold/VAE_epoch000_learned_manifold.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/learned_manifold/VAE_epoch000_learned_manifold.png -------------------------------------------------------------------------------- 
/assets/mnist_results/learned_manifold/VAE_epoch009_learned_manifold.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/learned_manifold/VAE_epoch009_learned_manifold.png -------------------------------------------------------------------------------- /assets/mnist_results/learned_manifold/VAE_epoch024_learned_manifold.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/learned_manifold/VAE_epoch024_learned_manifold.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/BEGAN_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/BEGAN_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/BEGAN_epoch001_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/BEGAN_epoch001_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/BEGAN_epoch009_test_all_classes.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/BEGAN_epoch009_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/BEGAN_epoch024_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/BEGAN_epoch024_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/DRAGAN_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/DRAGAN_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/DRAGAN_epoch001_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/DRAGAN_epoch001_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/DRAGAN_epoch009_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/DRAGAN_epoch009_test_all_classes.png 
-------------------------------------------------------------------------------- /assets/mnist_results/random_generation/DRAGAN_epoch024_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/DRAGAN_epoch024_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/EBGAN_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/EBGAN_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/EBGAN_epoch001_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/EBGAN_epoch001_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/EBGAN_epoch009_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/EBGAN_epoch009_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/EBGAN_epoch024_test_all_classes.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/EBGAN_epoch024_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/GAN_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/GAN_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/GAN_epoch001_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/GAN_epoch001_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/GAN_epoch009_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/GAN_epoch009_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/GAN_epoch024_test_all_classes.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/GAN_epoch024_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/LSGAN_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/LSGAN_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/LSGAN_epoch001_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/LSGAN_epoch001_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/LSGAN_epoch009_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/LSGAN_epoch009_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/LSGAN_epoch024_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/LSGAN_epoch024_test_all_classes.png 
-------------------------------------------------------------------------------- /assets/mnist_results/random_generation/VAE_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/VAE_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/VAE_epoch001_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/VAE_epoch001_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/VAE_epoch009_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/VAE_epoch009_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/VAE_epoch024_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/VAE_epoch024_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/WGAN_GP_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/WGAN_GP_epoch000_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/WGAN_GP_epoch001_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/WGAN_GP_epoch001_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/WGAN_GP_epoch009_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/WGAN_GP_epoch009_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/WGAN_GP_epoch024_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/WGAN_GP_epoch024_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/WGAN_epoch000_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/WGAN_epoch000_test_all_classes.png 
-------------------------------------------------------------------------------- /assets/mnist_results/random_generation/WGAN_epoch001_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/WGAN_epoch001_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/WGAN_epoch009_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/WGAN_epoch009_test_all_classes.png -------------------------------------------------------------------------------- /assets/mnist_results/random_generation/WGAN_epoch024_test_all_classes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwalsuklee/tensorflow-generative-model-collections/3abde8a9bcfe31815da50347d14641ec95096f62/assets/mnist_results/random_generation/WGAN_epoch024_test_all_classes.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ## GAN Variants 4 | from GAN import GAN 5 | from CGAN import CGAN 6 | from infoGAN import infoGAN 7 | from ACGAN import ACGAN 8 | from EBGAN import EBGAN 9 | from WGAN import WGAN 10 | from WGAN_GP import WGAN_GP 11 | from DRAGAN import DRAGAN 12 | from LSGAN import LSGAN 13 | from BEGAN import BEGAN 14 | 15 | ## VAE Variants 16 | from VAE import VAE 17 | from CVAE import CVAE 18 | 19 | from utils import show_all_variables 20 | from utils import check_folder 21 | 22 | import tensorflow 
as tf 23 | import argparse 24 | 25 | """parsing and configuration""" 26 | def parse_args(): 27 | desc = "Tensorflow implementation of GAN collections" 28 | parser = argparse.ArgumentParser(description=desc) 29 | 30 | parser.add_argument('--gan_type', type=str, default='GAN', 31 | choices=['GAN', 'CGAN', 'infoGAN', 'ACGAN', 'EBGAN', 'BEGAN', 'WGAN', 'WGAN_GP', 'DRAGAN', 'LSGAN', 'VAE', 'CVAE'], 32 | help='The type of GAN', required=True) 33 | parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', 'fashion-mnist', 'celebA'], 34 | help='The name of dataset') 35 | parser.add_argument('--epoch', type=int, default=20, help='The number of epochs to run') 36 | parser.add_argument('--batch_size', type=int, default=64, help='The size of batch') 37 | parser.add_argument('--z_dim', type=int, default=62, help='Dimension of noise vector') 38 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 39 | help='Directory name to save the checkpoints') 40 | parser.add_argument('--result_dir', type=str, default='results', 41 | help='Directory name to save the generated images') 42 | parser.add_argument('--log_dir', type=str, default='logs', 43 | help='Directory name to save training logs') 44 | 45 | return check_args(parser.parse_args()) 46 | 47 | """checking arguments""" 48 | def check_args(args): 49 | # --checkpoint_dir 50 | check_folder(args.checkpoint_dir) 51 | 52 | # --result_dir 53 | check_folder(args.result_dir) 54 | 55 | # --log_dir 56 | check_folder(args.log_dir) 57 | 58 | # --epoch 59 | assert args.epoch >= 1, 'number of epochs must be larger than or equal to one' 60 | 61 | # --batch_size 62 | assert args.batch_size >= 1, 'batch size must be larger than or equal to one' 63 | 64 | # --z_dim 65 | assert args.z_dim >= 1, 'dimension of noise vector must be larger than or equal to one' 66 | 67 | return args 68 | 69 | """main""" 70 | def main(): 71 | # parse arguments 72 | args = parse_args() 73 | if args is None: 74 | exit() 75 | 76 | #
open session 77 | models = [GAN, CGAN, infoGAN, ACGAN, EBGAN, WGAN, WGAN_GP, DRAGAN, 78 | LSGAN, BEGAN, VAE, CVAE] 79 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 80 | # declare instance for GAN 81 | 82 | gan = None 83 | for model in models: 84 | if args.gan_type == model.model_name: 85 | gan = model(sess, 86 | epoch=args.epoch, 87 | batch_size=args.batch_size, 88 | z_dim=args.z_dim, 89 | dataset_name=args.dataset, 90 | checkpoint_dir=args.checkpoint_dir, 91 | result_dir=args.result_dir, 92 | log_dir=args.log_dir) 93 | if gan is None: 94 | raise Exception("[!] There is no option for " + args.gan_type) 95 | 96 | # build graph 97 | gan.build_model() 98 | 99 | # show network architecture 100 | show_all_variables() 101 | 102 | # launch the graph in a session 103 | gan.train() 104 | print(" [*] Training finished!") 105 | 106 | # visualize learned generator 107 | gan.visualize_results(args.epoch-1) 108 | print(" [*] Testing finished!") 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Most codes from https://github.com/carpedm20/DCGAN-tensorflow 3 | """ 4 | import math 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from tensorflow.python.framework import ops 9 | 10 | from utils import * 11 | 12 | if "concat_v2" in dir(tf): 13 | def concat(tensors, axis, *args, **kwargs): 14 | return tf.concat_v2(tensors, axis, *args, **kwargs) 15 | else: 16 | def concat(tensors, axis, *args, **kwargs): 17 | return tf.concat(tensors, axis, *args, **kwargs) 18 | 19 | def bn(x, is_training, scope): 20 | return tf.contrib.layers.batch_norm(x, 21 | decay=0.9, 22 | updates_collections=None, 23 | epsilon=1e-5, 24 | scale=True, 25 | is_training=is_training, 26 | scope=scope) 27 | 28 | def conv_out_size_same(size, stride): 29 | return int(math.ceil(float(size) / 
float(stride))) 30 | 31 | def conv_cond_concat(x, y): 32 | """Concatenate conditioning vector on feature map axis.""" 33 | x_shapes = x.get_shape() 34 | y_shapes = y.get_shape() 35 | return concat([x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3) 36 | 37 | def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"): 38 | with tf.variable_scope(name): 39 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], 40 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 41 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') 42 | 43 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 44 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 45 | 46 | return conv 47 | 48 | def deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, name="deconv2d", stddev=0.02, with_w=False): 49 | with tf.variable_scope(name): 50 | # filter : [height, width, output_channels, in_channels] 51 | w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], 52 | initializer=tf.random_normal_initializer(stddev=stddev)) 53 | 54 | try: 55 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1]) 56 | 57 | # Support for versions of TensorFlow before 0.7.0 58 | except AttributeError: 59 | deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1]) 60 | 61 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 62 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) 63 | 64 | if with_w: 65 | return deconv, w, biases 66 | else: 67 | return deconv 68 | 69 | def lrelu(x, leak=0.2, name="lrelu"): 70 | return tf.maximum(x, leak*x) 71 | 72 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): 73 | shape = input_.get_shape().as_list() 74 | 75 | with
"""
Most of this code is from https://github.com/musyoku/adversarial-autoencoder/blob/master/sampler.py
"""

import numpy as np
from math import sin, cos, sqrt


def onehot_categorical(batch_size, n_labels):
    """Draw random one-hot label vectors.

    Returns a float32 array of shape (batch_size, n_labels) in which every
    row holds a single 1 at a column drawn uniformly from [0, n_labels).
    """
    y = np.zeros((batch_size, n_labels), dtype=np.float32)
    indices = np.random.randint(0, n_labels, batch_size)
    # Fancy indexing sets one entry per row without a Python-level loop.
    y[np.arange(batch_size), indices] = 1
    return y


def uniform(batch_size, n_dim, n_labels=10, minv=-1, maxv=1, label_indices=None):
    """Sample from a uniform prior on [minv, maxv), optionally per label.

    Without ``label_indices`` the result is a float32 array of shape
    (batch_size, n_dim) drawn uniformly from [minv, maxv).

    With ``label_indices`` (one label per batch element) the square
    [minv, maxv) x [minv, maxv) is cut into a ``num x num`` grid, where
    ``num = ceil(sqrt(n_labels))``, and each sample is drawn uniformly from
    the cell that belongs to its label (row ``label // num``, column
    ``label % num``).

    Raises:
        ValueError: if ``label_indices`` is given and ``n_dim`` is not 2.
    """
    if label_indices is None:
        return np.random.uniform(minv, maxv, (batch_size, n_dim)).astype(np.float32)

    if n_dim != 2:
        raise ValueError("n_dim must be 2.")

    num = int(np.ceil(np.sqrt(n_labels)))
    size = (maxv - minv) * 1.0 / num

    def sample(label):
        x, y = np.random.uniform(-size / 2, size / 2, (2,))
        # Integer division is required here: a true division would give a
        # fractional row index under Python 3 and move the sample outside
        # the cell reserved for this label.
        row, col = divmod(int(label), num)
        x += col * size + minv + 0.5 * size
        y += row * size + minv + 0.5 * size
        return np.array([x, y]).reshape((2,))

    z = np.empty((batch_size, n_dim), dtype=np.float32)
    for batch in range(batch_size):
        z[batch, 0:2] = sample(label_indices[batch])
    return z


def gaussian(batch_size, n_dim, mean=0, var=1, n_labels=10, use_label_info=False):
    """Sample from an isotropic Gaussian prior, optionally with sector labels.

    ``var`` is forwarded unchanged as the scale argument of
    ``np.random.normal`` (i.e. it acts as a standard deviation); the
    parameter name is kept for backward compatibility.

    Without ``use_label_info`` the result is a float32 array of shape
    (batch_size, n_dim).

    With ``use_label_info`` the function returns ``(z, z_id)``: ``z`` is a
    float32 array of shape (batch_size, 2) and ``z_id`` an int32 array of
    shape (batch_size, 1) holding, for each sample, the index of the
    angular sector (out of ``n_labels`` equal sectors around ``mean``,
    counted counter-clockwise from the positive x axis) it falls into.

    Raises:
        ValueError: if ``use_label_info`` is set and ``n_dim`` is not 2.
    """
    if not use_label_info:
        return np.random.normal(mean, var, (batch_size, n_dim)).astype(np.float32)

    if n_dim != 2:
        raise ValueError("n_dim must be 2.")

    def sample():
        x, y = np.random.normal(mean, var, (2,))
        angle = np.angle((x - mean) + 1j * (y - mean), deg=True)
        # np.angle yields degrees in (-180, 180]; a real floor followed by a
        # modulo maps every angle to its sector.  Truncating towards zero
        # before dividing would mis-bin negative angles that lie just past
        # a sector boundary.
        label = int(np.floor(n_labels * angle / 360.0)) % n_labels
        return np.array([x, y]).reshape((2,)), label

    z = np.empty((batch_size, n_dim), dtype=np.float32)
    z_id = np.empty((batch_size, 1), dtype=np.int32)
    for batch in range(batch_size):
        a_sample, a_label = sample()
        z[batch, 0:2] = a_sample
        z_id[batch] = a_label
    return z, z_id


def gaussian_mixture(batch_size, n_dim=2, n_labels=10, x_var=0.5, y_var=0.1, label_indices=None):
    """Sample from a ring of ``n_labels`` rotated 2-D Gaussians.

    Each sample starts as an axis-aligned Gaussian draw (scales ``x_var``
    and ``y_var``, passed as the scale argument of ``np.random.normal``),
    is rotated by ``2*pi*label/n_labels`` and pushed 1.4 units outwards
    along that direction.  Labels come from ``label_indices`` when given,
    otherwise they are drawn uniformly per sample.

    Returns a float32 array of shape (batch_size, 2).

    Raises:
        ValueError: if ``n_dim`` is not 2.
    """
    if n_dim != 2:
        raise ValueError("n_dim must be 2.")

    shift = 1.4

    def sample(x, y, label):
        r = 2.0 * np.pi / float(n_labels) * float(label)
        new_x = x * cos(r) - y * sin(r)
        new_y = x * sin(r) + y * cos(r)
        new_x += shift * cos(r)
        new_y += shift * sin(r)
        return np.array([new_x, new_y]).reshape((2,))

    x = np.random.normal(0, x_var, (batch_size, 1))
    y = np.random.normal(0, y_var, (batch_size, 1))
    z = np.empty((batch_size, n_dim), dtype=np.float32)
    for batch in range(batch_size):
        if label_indices is not None:
            label = label_indices[batch]
        else:
            label = np.random.randint(0, n_labels)
        z[batch, 0:2] = sample(x[batch, 0], y[batch, 0], label)
    return z


def swiss_roll(batch_size, n_dim=2, n_labels=10, label_indices=None):
    """Sample points along a 2-D spiral split into ``n_labels`` segments.

    For a label ``k`` a value ``u`` is drawn uniformly from
    [k / n_labels, (k + 1) / n_labels); the point is placed at radius
    ``3 * sqrt(u)`` and angle ``4 * pi * sqrt(u)``.  Labels come from
    ``label_indices`` when given, otherwise they are drawn uniformly.

    Returns a float32 array of shape (batch_size, 2).

    Raises:
        ValueError: if ``n_dim`` is not 2.
    """
    if n_dim != 2:
        raise ValueError("n_dim must be 2.")

    def sample(label):
        uni = np.random.uniform(0.0, 1.0) / float(n_labels) + float(label) / float(n_labels)
        r = sqrt(uni) * 3.0
        rad = np.pi * 4.0 * sqrt(uni)
        return np.array([r * cos(rad), r * sin(rad)]).reshape((2,))

    z = np.zeros((batch_size, n_dim), dtype=np.float32)
    for batch in range(batch_size):
        if label_indices is not None:
            label = label_indices[batch]
        else:
            label = np.random.randint(0, n_labels)
        z[batch, 0:2] = sample(label)
    return z
def load_mnist(dataset_name):
    """Load the MNIST-style dataset stored under ``./data/<dataset_name>``.

    Reads the four gzip archives (train/test images and labels), merges the
    60000 training and 10000 test examples, shuffles images and labels with
    the same fixed seed so they stay aligned, and returns
    ``(X / 255., y_vec)`` where ``X`` has shape (70000, 28, 28, 1) and
    ``y_vec`` is a float64 one-hot matrix of shape (70000, 10).

    Note: this reseeds NumPy's global random generator (seed 547).
    """
    data_dir = os.path.join("./data", dataset_name)

    def extract_data(filename, num_data, head_size, data_size):
        # Skip the fixed-size header, then read the raw uint8 payload.
        with gzip.open(filename) as bytestream:
            bytestream.read(head_size)
            buf = bytestream.read(data_size * num_data)
        # np.float / np.int were plain aliases of the builtins and have been
        # removed from NumPy; use the explicit dtypes instead.
        return np.frombuffer(buf, dtype=np.uint8).astype(np.float64)

    data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
    trX = data.reshape((60000, 28, 28, 1))

    data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)
    trY = data.reshape((60000))

    data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
    teX = data.reshape((10000, 28, 28, 1))

    data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)
    teY = data.reshape((10000))

    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(int)

    # Reseeding before each shuffle applies the same permutation to both
    # arrays, keeping every image paired with its label.
    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)

    y_vec = np.zeros((len(y), 10), dtype=np.float64)
    y_vec[np.arange(len(y)), y] = 1.0

    return X / 255., y_vec


def check_folder(log_dir):
    """Create ``log_dir`` (including parents) if needed and return it."""
    try:
        os.makedirs(log_dir)
    except OSError:
        # Another process may have created the path between a check and the
        # creation; only propagate the error when the path is really absent.
        if not os.path.exists(log_dir):
            raise
    return log_dir


def show_all_variables():
    """Print a summary of all trainable TensorFlow variables."""
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)


def get_image(image_path, input_height, input_width, resize_height=64, resize_width=64, crop=True, grayscale=False):
    """Read an image from disk and run it through ``transform``."""
    image = imread(image_path, grayscale)
    return transform(image, input_height, input_width, resize_height, resize_width, crop)


def save_images(images, size, image_path):
    """Undo the [-1, 1] scaling of ``images`` and save them as one grid image."""
    return imsave(inverse_transform(images), size, image_path)


def imread(path, grayscale=False):
    """Read an image file into a float64 array.

    With ``grayscale`` the image is flattened to a single channel.
    """
    # NOTE(review): scipy.misc.imread is no longer shipped by recent SciPy
    # releases; this call needs an older SciPy or a replacement reader.
    if grayscale:
        return scipy.misc.imread(path, flatten=True).astype(np.float64)
    return scipy.misc.imread(path).astype(np.float64)


def merge_images(images, size):
    """Return ``images`` rescaled by ``inverse_transform``.

    ``size`` is accepted for signature compatibility but not used.
    """
    return inverse_transform(images)


def merge(images, size):
    """Tile a batch of images into a single grid image.

    Args:
        images: array of shape (N, H, W, C) with C in {1, 3, 4}.
        size: ``(rows, cols)`` of the grid; images are placed row by row.

    Returns:
        An array of shape (rows*H, cols*W, C) for C in {3, 4}, or
        (rows*H, cols*W) for single-channel input.  Unused cells stay zero.

    Raises:
        ValueError: if the channel count is unsupported or the grid has
            fewer cells than there are images.
    """
    h, w = images.shape[1], images.shape[2]
    c = images.shape[3]
    if c not in (1, 3, 4):
        raise ValueError('in merge(images,size) images parameter '
                         'must have dimensions: HxW or HxWx3 or HxWx4')
    rows, cols = size[0], size[1]
    if len(images) > rows * cols:
        raise ValueError('in merge(images,size) size %r is too small for %d images'
                         % (tuple(size), len(images)))

    img = np.zeros((h * rows, w * cols, c))
    for idx, image in enumerate(images):
        col = idx % cols
        row = idx // cols
        img[row * h:row * h + h, col * w:col * w + w, :] = image
    # Single-channel grids are returned without the trailing channel axis.
    return img[:, :, 0] if c == 1 else img
def imsave(images, size, path):
    """Tile ``images`` into a grid with ``merge`` and write it to ``path``."""
    # NOTE(review): scipy.misc.imsave is no longer shipped by recent SciPy
    # releases; this call needs an older SciPy or a replacement writer.
    image = np.squeeze(merge(images, size))
    return scipy.misc.imsave(path, image)


def center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64):
    """Crop the centre ``crop_h x crop_w`` window of ``x`` and resize it.

    ``crop_w`` defaults to ``crop_h`` when it is None.
    """
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    j = int(round((h - crop_h) / 2.))
    i = int(round((w - crop_w) / 2.))
    # NOTE(review): scipy.misc.imresize is no longer shipped by recent SciPy
    # releases; this call needs an older SciPy or a replacement resizer.
    return scipy.misc.imresize(x[j:j + crop_h, i:i + crop_w], [resize_h, resize_w])


def transform(image, input_height, input_width, resize_height=64, resize_width=64, crop=True):
    """Resize (optionally centre-cropping first) and map the result via ``v / 127.5 - 1``."""
    if crop:
        cropped_image = center_crop(image, input_height, input_width, resize_height, resize_width)
    else:
        cropped_image = scipy.misc.imresize(image, [resize_height, resize_width])
    return np.array(cropped_image) / 127.5 - 1.


def inverse_transform(images):
    """Map values from [-1, 1] to [0, 1] via ``(images + 1) / 2``."""
    return (images + 1.) / 2.


""" Drawing Tools """
# borrowed from https://github.com/ykwon0407/variational_autoencoder/blob/master/variational_bayes.ipynb
def save_scattered_image(z, id, z_range_x, z_range_y, name='scattered_image.jpg'):
    """Save a scatter plot of 2-D codes coloured by class.

    Args:
        z: array whose first two columns are plotted as x and y.
        id: one-hot (or score) matrix; ``argmax`` along axis 1 picks the
            colour of each point.  The name shadows the builtin but is kept
            for backward compatibility with keyword callers.
        z_range_x, z_range_y: half-widths of the plotted x / y ranges.
        name: output file name passed to ``plt.savefig``.
    """
    N = 10
    fig = plt.figure(figsize=(8, 6))
    try:
        plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id, 1), marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet'))
        plt.colorbar(ticks=range(N))
        axes = plt.gca()
        axes.set_xlim([-z_range_x, z_range_x])
        axes.set_ylim([-z_range_y, z_range_y])
        plt.grid(True)
        plt.savefig(name)
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each call would leave another figure in memory.
        plt.close(fig)


# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    ``base_cmap`` may be a colormap name, None (the default colormap) or a
    colormap instance.
    """
    # Imported locally so the module keeps its existing top-level imports.
    from matplotlib.colors import LinearSegmentedColormap

    # plt.get_cmap is available across Matplotlib versions, whereas
    # plt.cm.get_cmap has been removed from recent releases.
    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    # from_list is defined on LinearSegmentedColormap; calling it through
    # the class also works when ``base`` is a ListedColormap.
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)