├── .gitignore ├── LICENSE ├── README.md ├── main.py ├── model.py ├── ops.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # log directory 2 | logs/ 3 | 4 | # data directory 5 | data/* 6 | 7 | # sample directory 8 | samples/ 9 | 10 | # python 11 | *.pyc 12 | 13 | logs_180417/ 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Changwoo Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WGAN-GP-tensorflow 2 | 3 | Tensorflow implementation of paper ["Improved Training of Wasserstein GANs"](https://arxiv.org/abs/1704.00028). 4 | 5 | ![gif](https://thumbs.gfycat.com/VerifiableHonoredHind-size_restricted.gif) 6 | 7 | * 0 epoch 8 | 9 | ![epoch0](http://cfile24.uf.tistory.com/image/99DE3E355AD971992E9F3C) 10 | 11 | * 25 epoch 12 | 13 | ![img](http://cfile29.uf.tistory.com/image/99274A355AD9719925FEF4) 14 | 15 | * 50 epoch 16 | 17 | ![epoch50](http://cfile23.uf.tistory.com/image/9927653B5AD971B537B169) 18 | 19 | * 100 epoch 20 | 21 | ![img](http://cfile8.uf.tistory.com/image/996E113B5AD971CB1010F7) 22 | 23 | * 150 epoch 24 | 25 | ![img](http://cfile28.uf.tistory.com/image/9999403C5AD971DB2483C5) 26 | 27 | ## Prerequisites 28 | 29 | - Python 2.7 or 3.5 30 | - Tensorflow 1.3+ 31 | - SciPy 32 | - Aligned&Cropped celebA dataset([download](https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADSNUu0bseoCKuxuI5ZeTl1a/Img?dl=0)) 33 | - (Optional) moviepy (for visualization) 34 | 35 | ## Usage 36 | 37 | * Download aligned&cropped celebA dataset([link](https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADSNUu0bseoCKuxuI5ZeTl1a/Img?dl=0)) and unzip at ./data/img_align_celeba 38 | 39 | * Train: 40 | 41 | ``` 42 | $ python main.py --train 43 | ``` 44 | 45 | Or you can set some arguments like: 46 | 47 | ``` 48 | $ python main.py --dataset=celebA --max_epoch=50 --learning_rate=1e-4 --train 49 | ``` 50 | 51 | * Test: 52 | 53 | ``` 54 | $ python main.py 55 | ``` 56 | 57 | ## Acknowledge 58 | 59 | Based on the implementation [carpedm20/DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow), [LynnHo/DCGAN-LSGAN-WGAN-WGAN-GP-Tensorflow](https://github.com/LynnHo/DCGAN-LSGAN-WGAN-WGAN-GP-Tensorflow), [shekkizh/WassersteinGAN.tensorflow](https://github.com/shekkizh/WassersteinGAN.tensorflow) and [igul222/improved_wgan_training](https://github.com/igul222/improved_wgan_training). 60 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.misc 3 | import numpy as np 4 | 5 | import tensorflow as tf 6 | 7 | from model import WGAN 8 | from utils import pp, visualize, show_all_variables, forward_test 9 | 10 | flags = tf.app.flags 11 | flags.DEFINE_integer("max_epoch",150,"Maximum epoch") 12 | flags.DEFINE_integer("input_height",108,"Input height") 13 | flags.DEFINE_integer("input_width",108,"Input width") 14 | flags.DEFINE_integer("batch_size", 64, "Batch size") 15 | flags.DEFINE_integer("z_dim", 100, "Dimension of z") 16 | flags.DEFINE_integer("output_height", 64, "Output height") 17 | flags.DEFINE_integer("output_width", 64, "Output width") 18 | flags.DEFINE_boolean("crop",True,"Center Crop") 19 | flags.DEFINE_string("dataset","celebA","Name of dataset") 20 | flags.DEFINE_string("data_pattern","*.jpg","data file pattern") 21 | flags.DEFINE_string("log_dir","./logs","Log directory path") 22 | flags.DEFINE_string("sample_dir","./samples","sample directory") 23 | flags.DEFINE_integer("n_critic",5,"Number of critic iteration") 24 | flags.DEFINE_float("beta1",0.,"beta1 for Adam Optimizer") 25 | flags.DEFINE_float("beta2",0.9,"beta2 for Adam Optimizer") 26 | flags.DEFINE_float("learning_rate",1e-4,"learning rate") 27 | flags.DEFINE_integer("g_dim",64,"Dimension of generator") 28 | flags.DEFINE_integer("d_dim",64,"Dimension of discriminator") 29 | flags.DEFINE_boolean("train",False,"train") 30 | flags.DEFINE_boolean("forward_test",False,"Forward Test") 31 | FLAGS = flags.FLAGS 32 | 33 | 34 | def main(_): 35 | pp.pprint(flags.FLAGS.__flags) 36 | 37 | # run_config = tf.ConfigProto() 38 | # run_config.gpu_options.allow_growth=True 39 | 40 | # with tf.Session(config=run_config) as sess: 41 | with tf.Session() as sess: 42 | wgan = WGAN(sess, 43 | input_height=FLAGS.input_height, 44 | input_width=FLAGS.input_width, 45 | crop=FLAGS.crop, 46 | batch_size=FLAGS.batch_size, 47 | output_height=FLAGS.output_height, 48 | output_width=FLAGS.output_width, 49 | z_dim=FLAGS.z_dim, 50 | g_dim=FLAGS.g_dim, 51 | d_dim=FLAGS.d_dim, 52 | dataset_name=FLAGS.dataset, 53 | input_fname_pattern=FLAGS.data_pattern, 54 | log_dir=FLAGS.log_dir, 55 | sample_dir=FLAGS.sample_dir, 56 | max_epoch=FLAGS.max_epoch, 57 | n_critic=FLAGS.n_critic, 58 | lr=FLAGS.learning_rate, 59 | beta1=FLAGS.beta1, 60 | beta2=FLAGS.beta2) 61 | 62 | 63 | show_all_variables() 64 | 65 | if FLAGS.train: 66 | wgan.train() 67 | else: 68 | if not wgan.load(FLAGS.log_dir): 69 | raise Exception("[!] Train a model first, then run test mode") 70 | 71 | 72 | 73 | if FLAGS.forward_test: 74 | forward_test(sess,wgan,FLAGS, FLAGS.test_num) 75 | OPTION = 1 76 | visualize(sess,wgan, FLAGS, OPTION) 77 | 78 | 79 | if __name__=='__main__': 80 | tf.app.run() 81 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import tensorflow as tf 3 | import os 4 | from ops import * 5 | from glob import glob 6 | from utils import * 7 | import numpy as np 8 | 9 | 10 | class WGAN(object): 11 | def __init__(self, sess, input_height=108, input_width=108, crop=True, 12 | batch_size=64, sample_num=64, output_height=64, output_width=64, 13 | y_dim=None, z_dim=100, g_dim=None, d_dim=None, c_dim=3, 14 | dataset_name='celebA', input_fname_pattern='*.jpg', 15 | log_dir=None, 16 | sample_dir=None, 17 | max_epoch=50, 18 | n_critic=5, lr=1e-4, beta1=0., beta2=0.9): 19 | 20 | """ 21 | Original code from DCGAN-tensorflow by carpedm20 22 | """ 23 | 24 | 25 | self.sess = sess 26 | self.input_height = input_height 27 | self.input_width = input_width 28 | self.crop = crop 29 | self.batch_size = batch_size 30 | self.sample_num= sample_num 31 | self.output_height=output_height 32 | self.output_width = output_width 33 | self.z_dim = z_dim 34 | 35 | self.c_dim = c_dim 36 | self.g_dim = g_dim 37 | self.d_dim = d_dim 38 | self.dataset_name = get_dataset(dataset_name) 39 | self.input_fname_pattern = input_fname_pattern 40 | self.log_dir = log_dir 41 | if not os.path.exists(self.log_dir): 42 | os.makedirs(self.log_dir) 43 | 44 | self.sample_dir = sample_dir 45 | if not os.path.exists(self.sample_dir): 46 | os.makedirs(self.sample_dir) 47 | 48 | self.data = glob(os.path.join("./data", self.dataset_name, self.input_fname_pattern)) 49 | if len(self.data) is 0: 50 | raise Exception("[!] No training data. Program shut down") 51 | 52 | self.max_epoch = max_epoch 53 | self.n_critic = n_critic 54 | self.lr = lr #learning rate 55 | self.beta1 = beta1 56 | self.beta2 = beta2 57 | 58 | 59 | 60 | self.build_model() 61 | 62 | def read_input(self): 63 | """ 64 | Code from https://github.com/tdeboissiere/DeepLearningImplementations 65 | """ 66 | with tf.device('/cpu:0'): 67 | reader = tf.WholeFileReader() 68 | filename_queue = tf.train.string_input_producer(self.data) 69 | data_num = len(self.data) 70 | key, value = reader.read(filename_queue) 71 | image = tf.image.decode_jpeg(value, channels=self.c_dim, name="dataset_image") 72 | 73 | # Crop and other random augmentations 74 | image = tf.image.random_flip_left_right(image) 75 | 76 | image = tf.image.crop_to_bounding_box(image, (218 - self.input_height) //2, (178 - self.input_width) // 2, self.input_height, self.input_width) 77 | if self.crop: 78 | image = tf.image.resize_images(image, [self.output_height, self.output_width], method=tf.image.ResizeMethod.BICUBIC) 79 | 80 | 81 | num_preprocess_threads=4 82 | num_examples_per_epoch=800 83 | min_queue_examples = int(0.1 * num_examples_per_epoch) 84 | img_batch = tf.train.batch([image], 85 | batch_size=self.batch_size, 86 | num_threads=4, 87 | capacity=min_queue_examples + 2*self.batch_size) 88 | img_batch = 2*((tf.cast(img_batch, tf.float32) / 255.) - 0.5) 89 | 90 | return img_batch, data_num 91 | 92 | 93 | 94 | def build_model(self): 95 | if self.crop: 96 | image_dims = [self.output_height, self.output_width, self.c_dim] 97 | else: 98 | image_dims = [self.input_height, self.input_width, self.c_dim] 99 | 100 | self.X_real, self.data_num = self.read_input() 101 | 102 | self.z = tf.placeholder( 103 | tf.float32, [None, self.z_dim], name='z') 104 | self.z_sum = tf.summary.histogram("z", self.z) 105 | 106 | self.X_fake = self.generator(self.z) 107 | # self.real_img_sum = tf.summary.image("image_real", self.X_real, max_outputs=4) 108 | # self.fake_img_sum = tf.summary.image("image_fake", self.X_fake, max_outputs=4) 109 | 110 | self.d_logits_fake = self.discriminator(self.X_fake, reuse=False) 111 | self.d_logits_real = self.discriminator(self.X_real, reuse=True) 112 | # WGAN Loss 113 | self.d_loss = tf.reduce_mean(self.d_logits_fake) - tf.reduce_mean(self.d_logits_real) 114 | self.g_loss = -tf.reduce_mean(self.d_logits_fake) 115 | 116 | # Gradient Penalty 117 | self.epsilon = tf.random_uniform( 118 | shape=[self.batch_size, 1, 1, 1], 119 | minval=0., 120 | maxval=1.) 121 | X_hat = self.X_real + self.epsilon * (self.X_fake - self.X_real) 122 | D_X_hat = self.discriminator(X_hat, reuse=True) 123 | grad_D_X_hat = tf.gradients(D_X_hat, [X_hat])[0] 124 | red_idx = range(1, X_hat.shape.ndims) 125 | slopes = tf.sqrt(tf.reduce_sum(tf.square(grad_D_X_hat), reduction_indices=red_idx)) 126 | gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2) 127 | self.d_loss = self.d_loss + 10.0 * gradient_penalty 128 | 129 | self.d_loss_sum = tf.summary.scalar("Discriminator_loss", self.d_loss) 130 | self.g_loss_sum = tf.summary.scalar("Generator_loss", self.g_loss) 131 | self.gp_sum = tf.summary.scalar("Gradient_penalty", gradient_penalty) 132 | 133 | train_vars = tf.trainable_variables() 134 | 135 | for v in train_vars: 136 | tf.add_to_collection("reg_loss", tf.nn.l2_loss(v)) 137 | 138 | self.generator_vars = [v for v in train_vars if 'g_' in v.name] 139 | self.discriminator_vars = [v for v in train_vars if 'd_' in v.name] 140 | 141 | 142 | 143 | self.g_optimizer = tf.train.AdamOptimizer(learning_rate=self.lr, name='g_opt', 144 | beta1=self.beta1, beta2=self.beta2).minimize(self.g_loss, var_list=self.generator_vars) 145 | self.d_optimizer = tf.train.AdamOptimizer(learning_rate=self.lr, name='d_opt', 146 | beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=self.discriminator_vars) 147 | 148 | self.d_sum = tf.summary.merge([self.z_sum, self.d_loss_sum]) 149 | self.g_sum = tf.summary.merge([self.z_sum, self.g_loss_sum]) 150 | # Sample image 151 | sample_ = self.generator(self.z, reuse=True) 152 | sample_ = merge(sample_, image_manifold_size(sample_.shape[0])) 153 | sample_ = tf.cast(tf.expand_dims(sample_, 0), tf.float32) 154 | self.sample_sum = tf.summary.image("generated_image",sample_) 155 | with tf.variable_scope('counter'): 156 | self.counter = tf.get_variable('counter', shape=[1], initializer=tf.constant_initializer([0]), dtype=tf.int32) 157 | self.update_counter = tf.assign(self.counter, tf.add(self.counter, 1)) 158 | 159 | 160 | 161 | self.saver = tf.train.Saver() 162 | self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph) 163 | 164 | 165 | self.initialize_model() 166 | 167 | def initialize_model(self): 168 | print("[*] initializing network...") 169 | 170 | # self.sess.run(tf.global_variables_initializer()) 171 | # ckpt = tf.train.get_checkpoint_state(self.log_dir) 172 | # if ckpt and ckpt.model_checkpoint_path: 173 | # self.saver.restore(self.sess, ckpt.model_checkpoint_path) 174 | # print("[*] Model restored.") 175 | if not self.load(self.log_dir): 176 | self.sess.run(tf.global_variables_initializer()) 177 | self.coord = tf.train.Coordinator() 178 | self.threads = tf.train.start_queue_runners(self.sess, self.coord) 179 | 180 | 181 | def train(self): 182 | print("[*] Training Improved Wasserstein GAN") 183 | 184 | 185 | sample_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]) 186 | start_time = time.time() 187 | 188 | batch_epoch = self.data_num // (self.batch_size * self.n_critic) 189 | max_iterations = self.max_epoch * batch_epoch 190 | print("[*] Start from step %d." % (self.sess.run(self.counter))) 191 | for step in xrange(self.sess.run(self.counter), max_iterations): 192 | 193 | epoch = step // batch_epoch 194 | batch_step = step % batch_epoch + 1 195 | 196 | # Critic 197 | for critic_iter in range(self.n_critic): 198 | self.batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]) 199 | # Update Discriminator 200 | _, summary_str = self.sess.run([self.d_optimizer, self.d_sum], feed_dict={self.z: self.batch_z})#, self.X_real: self.batch_images}) 201 | # Update Generator 202 | self.batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]) 203 | self.summary_writer.add_summary(summary_str, step) 204 | _, summary_str = self.sess.run([self.g_optimizer, self.g_sum] ,feed_dict={self.z: self.batch_z}) 205 | self.summary_writer.add_summary(summary_str, step) 206 | if step%100==0: 207 | summary_str = self.sess.run(self.sample_sum, feed_dict={self.z: sample_z}) 208 | self.summary_writer.add_summary(summary_str, step) 209 | 210 | if step%200==199: 211 | stop_time = time.time() 212 | duration = (stop_time - start_time) / 200.0 213 | start_time = stop_time 214 | g_loss_val, d_loss_val = self.sess.run([self.g_loss, self.d_loss], feed_dict={self.z: self.batch_z})#, self.X_real: self.batch_images}) 215 | print("Time: %g/itr, Epoch: %d, Step: (%d/%d), generator loss: %g, discriminator loss: %g" % (duration, epoch, batch_step, batch_epoch, g_loss_val, d_loss_val)) 216 | generated_images = self.sess.run(self.X_fake, feed_dict={self.z: sample_z}) 217 | save_images(generated_images, 218 | image_manifold_size(generated_images.shape[0]), 219 | './{}/sample_{:02d}_{:04d}.png'.format(self.sample_dir, epoch, batch_step)) 220 | 221 | 222 | if step%1000==0: 223 | self.saver.save(self.sess, self.log_dir + "/model.ckpt", global_step=step) 224 | 225 | 226 | self.sess.run(self.update_counter) 227 | 228 | self.saver.save(self.sess, self.log_dir + "/model.ckpt", global_step=max_iterations) 229 | 230 | def generator(self, z, reuse=False): 231 | with tf.variable_scope("generator") as scope: 232 | if reuse: 233 | scope.reuse_variables() 234 | dims = [self.g_dim*8, self.g_dim*4, self.g_dim*2, self.g_dim, 3] 235 | s_h, s_w = self.output_height, self.output_width 236 | s_h, s_w = [s_h//16, s_h//8, s_h//4, s_h//2, s_h], [s_w//16, s_w//8, s_w//4, s_w//2, s_w] 237 | 238 | self.z_ = linear("g_h0_lin", z, dims[0] * s_h[0] * s_w[0]) 239 | 240 | self.h0 = tf.reshape(self.z_, [-1, s_h[0], s_w[0], dims[0]]) 241 | h0 = tf.nn.relu(batch_norm(self.h0, name="g_bn0")) 242 | 243 | 244 | h1 = deconv2d("g_h1", h0, [self.batch_size, s_h[1], s_w[1], dims[1]]) 245 | h1 = tf.nn.relu(batch_norm(h1, name="g_bn1")) 246 | 247 | h2 = deconv2d("g_h2", h1, [self.batch_size, s_h[2], s_w[2], dims[2]]) 248 | h2 = tf.nn.relu(batch_norm(h2, name="g_bn2")) 249 | 250 | h3 = deconv2d("g_h3", h2, [self.batch_size, s_h[3], s_w[3], dims[3]]) 251 | h3 = tf.nn.relu(batch_norm(h3, name="g_bn3")) 252 | 253 | h4 = deconv2d("g_h4", h3, [self.batch_size, s_h[4], s_w[4], dims[4]]) 254 | 255 | return tf.nn.tanh(h4, name='pred_image') 256 | 257 | 258 | def discriminator(self, input_image, reuse=False): 259 | with tf.variable_scope("discriminator") as scope: 260 | if reuse: 261 | scope.reuse_variables() 262 | 263 | 264 | dims = [self.c_dim, self.d_dim, self.d_dim*2, self.d_dim*4, self.d_dim*8] 265 | # s_h, s_w = self.output_height, self.output_width 266 | # s_h, s_w = [s_h, s_h*2, s_h*4, s_h*8, s_h*16], [s_w, s_w*2, s_w*4, s_w*8, s_w*16] 267 | 268 | h0 = conv2d("d_h0", input_image, dims[1]) 269 | h0 = lrelu(h0) 270 | 271 | h1 = conv2d("d_h1", h0, dims[2]) 272 | h1 = lrelu(layer_norm(h1, name="d_ln1")) 273 | 274 | h2 = conv2d("d_h2", h1, dims[3]) 275 | h2 = lrelu(layer_norm(h2, name="d_ln2")) 276 | 277 | h3 = conv2d("d_h3", h2, dims[4]) 278 | h3 = lrelu(layer_norm(h3, name="d_ln3")) 279 | 280 | 281 | h3 = tf.reshape(h3, [-1, 4*4*self.d_dim*8]) 282 | h_pred = linear("d_h4", h3, 1) 283 | h_pred = tf.reshape(h_pred, [-1]) 284 | return h_pred 285 | 286 | 287 | def load(self, log_dir): 288 | print("[*] Reading Checkpoints...") 289 | ckpt = tf.train.get_checkpoint_state(log_dir) 290 | if ckpt and ckpt.model_checkpoint_path: 291 | self.saver.restore(self.sess, ckpt.model_checkpoint_path) 292 | print("[*] Model restored.") 293 | return True 294 | else: 295 | print("[*] Failed to find a checkpoint") 296 | return False 297 | 298 | 299 | 300 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | 5 | def batch_norm(x, epsilon=1e-5, decay=0.9, name='batch_norm'): 6 | with tf.variable_scope(name): 7 | return tf.contrib.layers.batch_norm(x, decay=decay, epsilon=epsilon, scale=True) 8 | 9 | def layer_norm(x, trainable=True, name='layer_norm'): 10 | with tf.variable_scope(name): 11 | return tf.contrib.layers.layer_norm(x, trainable=trainable) 12 | 13 | def linear(name, x, output_dim, stddev=0.02): 14 | with tf.variable_scope(name): 15 | w = tf.get_variable('w', shape=[x.get_shape()[-1], output_dim], 16 | initializer=tf.random_normal_initializer(stddev=stddev)) 17 | y = tf.matmul(x, w) 18 | return y 19 | 20 | 21 | def deconv2d(name, input_, output_shape, strides=[1,2,2,1], ksize=5, stddev=0.02, trainable=True): 22 | with tf.variable_scope(name): 23 | w = tf.get_variable('w', shape=[ksize, ksize, output_shape[-1], input_.get_shape()[-1]], 24 | initializer=tf.random_normal_initializer(stddev=stddev), trainable=trainable) 25 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, 26 | strides=strides) 27 | 28 | return deconv 29 | 30 | 31 | def conv2d(name, input_, output_dim, strides=[1,2,2,1], ksize=5, stddev=0.02): 32 | with tf.variable_scope(name): 33 | w = tf.get_variable('w', [ksize, ksize, input_.get_shape()[-1],output_dim], 34 | initializer=tf.random_normal_initializer(stddev=stddev)) 35 | 36 | conv = tf.nn.conv2d(input_, w, strides=strides, padding='SAME') 37 | return conv 38 | 39 | 40 | 41 | def lrelu(x, leak=0.2, name='lrelu'): 42 | return tf.maximum(x, leak*x) 43 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codes from https://github.com/carpedm20/DCGAN-tensorflow 3 | 4 | 5 | Some codes from https://github.com/Newmu/wgan_gp_code 6 | """ 7 | from __future__ import division 8 | import math 9 | import time 10 | import json 11 | import random 12 | import pprint 13 | import scipy.misc 14 | import numpy as np 15 | from time import gmtime, strftime 16 | from six.moves import xrange 17 | 18 | import tensorflow as tf 19 | import tensorflow.contrib.slim as slim 20 | 21 | pp = pprint.PrettyPrinter() 22 | 23 | get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1]) 24 | 25 | def show_all_variables(): 26 | model_vars = tf.trainable_variables() 27 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 28 | 29 | def get_image(image_path, input_height, input_width, 30 | resize_height=64, resize_width=64, 31 | crop=True, grayscale=False): 32 | image = imread(image_path, grayscale) 33 | return transform(image, input_height, input_width, 34 | resize_height, resize_width, crop) 35 | 36 | def save_images(images, size, image_path): 37 | return imsave(inverse_transform(images), size, image_path) 38 | 39 | def imread(path, grayscale = False): 40 | if (grayscale): 41 | return scipy.misc.imread(path, flatten = True).astype(np.float) 42 | else: 43 | return scipy.misc.imread(path).astype(np.float) 44 | 45 | def merge_images(images, size): 46 | return inverse_transform(merge(images, size)) 47 | 48 | def merge(images, size): 49 | try: 50 | h, w = images.get_shape().as_list()[1], images.get_shape().as_list()[2] 51 | except: 52 | h, w = images.shape[1], images.shape[2] 53 | if (images.shape[3] in (3,4)): 54 | c = images.shape[3] 55 | img = np.zeros((h * size[0], w * size[1], c)) 56 | try: 57 | for idx, image in enumerate(images): 58 | i = idx % size[1] 59 | j = idx // size[1] 60 | img[j * h:j * h + h, i * w:i * w + w, :] = image 61 | return img 62 | except: 63 | img = tf.zeros((h, w*size[1], images.get_shape().as_list()[3])) 64 | 65 | for idx in range(0, images.get_shape().as_list()[0]): 66 | 67 | image = images[idx,:,:,:] 68 | i = idx % size[1] 69 | j = idx // size[1] 70 | if i==0: 71 | img_row = images[idx,:,:,:] 72 | continue 73 | # img[j*h: j*h + h, i*w: i*w + w, :] = image 74 | img_row = tf.concat([img_row, image], 1) 75 | if i==size[1]-1: 76 | img = tf.concat([img, img_row], 0) 77 | 78 | img = img[h:,:,:]#tf.reshape(img, [h*size[0], w*size[1], images.get_shape().as_list()[3]]) 79 | return img 80 | elif images.shape[3]==1: 81 | img = np.zeros((h * size[0], w * size[1])) 82 | for idx, image in enumerate(images): 83 | i = idx % size[1] 84 | j = idx // size[1] 85 | img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0] 86 | return img 87 | else: 88 | raise ValueError('in merge(images,size) images parameter ' 89 | 'must have dimensions: HxW or HxWx3 or HxWx4') 90 | 91 | def imsave(images, size, path): 92 | image = np.squeeze(merge(images, size)) 93 | return scipy.misc.imsave(path, image) 94 | 95 | def center_crop(x, crop_h, crop_w, 96 | resize_h=64, resize_w=64): 97 | if crop_w is None: 98 | crop_w = crop_h 99 | h, w = x.shape[:2] 100 | j = int(round((h - crop_h)/2.)) 101 | i = int(round((w - crop_w)/2.)) 102 | return scipy.misc.imresize( 103 | x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w]) 104 | 105 | def transform(image, input_height, input_width, 106 | resize_height=64, resize_width=64, crop=True): 107 | if crop: 108 | cropped_image = center_crop( 109 | image, input_height, input_width, 110 | resize_height, resize_width) 111 | else: 112 | cropped_image = scipy.misc.imresize(image, [resize_height, resize_width]) 113 | return np.array(cropped_image)/127.5 - 1. 114 | 115 | def inverse_transform(images): 116 | return (images+1.)/2. 117 | 118 | 119 | def get_dataset(dataset_name): 120 | if 'celebA' in dataset_name or 'celeba' in dataset_name: 121 | return 'img_align_celeba' 122 | 123 | def to_json(output_path, *layers): 124 | with open(output_path, "w") as layer_f: 125 | lines = "" 126 | for w, b, bn in layers: 127 | layer_idx = w.name.split('/')[0].split('h')[1] 128 | 129 | B = b.eval() 130 | 131 | if "lin/" in w.name: 132 | W = w.eval() 133 | depth = W.shape[1] 134 | else: 135 | W = np.rollaxis(w.eval(), 2, 0) 136 | depth = W.shape[0] 137 | 138 | biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]} 139 | if bn != None: 140 | gamma = bn.gamma.eval() 141 | beta = bn.beta.eval() 142 | 143 | gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]} 144 | beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]} 145 | else: 146 | gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []} 147 | beta = {"sy": 1, "sx": 1, "depth": 0, "w": []} 148 | 149 | if "lin/" in w.name: 150 | fs = [] 151 | for w in W.T: 152 | fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]}) 153 | 154 | lines += """ 155 | var layer_%s = { 156 | "layer_type": "fc", 157 | "sy": 1, "sx": 1, 158 | "out_sx": 1, "out_sy": 1, 159 | "stride": 1, "pad": 0, 160 | "out_depth": %s, "in_depth": %s, 161 | "biases": %s, 162 | "gamma": %s, 163 | "beta": %s, 164 | "filters": %s 165 | };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs) 166 | else: 167 | fs = [] 168 | for w_ in W: 169 | fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]}) 170 | 171 | lines += """ 172 | var layer_%s = { 173 | "layer_type": "deconv", 174 | "sy": 5, "sx": 5, 175 | "out_sx": %s, "out_sy": %s, 176 | "stride": 2, "pad": 1, 177 | "out_depth": %s, "in_depth": %s, 178 | "biases": %s, 179 | "gamma": %s, 180 | "beta": %s, 181 | "filters": %s 182 | };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2), 183 | W.shape[0], W.shape[3], biases, gamma, beta, fs) 184 | layer_f.write(" ".join(lines.replace("'","").split())) 185 | 186 | def make_gif(images, fname, duration=2, true_image=False): 187 | import moviepy.editor as mpy 188 | 189 | def make_frame(t): 190 | try: 191 | x = images[int(len(images)/duration*t)] 192 | except: 193 | x = images[-1] 194 | 195 | if true_image: 196 | return x.astype(np.uint8) 197 | else: 198 | return ((x+1)/2*255).astype(np.uint8) 199 | 200 | clip = mpy.VideoClip(make_frame, duration=duration) 201 | clip.write_gif(fname, fps = len(images) / duration) 202 | 203 | def forward_test(sess, wgan_gp, config, test_num): 204 | print("[*] Forward Test for generating %s images Start" % (test_num*wgan_gp.z_dim)) 205 | 206 | start_time = time.time() 207 | for i in range (0, test_num): 208 | values= np.arange(0,1,1./config.batch_size) 209 | z_sample = np.random.uniform(-1, 1, size=(config.batch_size, wgan_gp.z_dim)) 210 | 211 | for idx in xrange(wgan_gp.z_dim): 212 | for kdx, z in enumerate(z_sample): 213 | z[idx] = values[kdx] 214 | samples = sess.run(wgan_gp.X_fake, feed_dict={wgan_gp.z: z_sample}) 215 | print("[*] Test Finished. Elasped Time : %4.4f" %( time.time()-start_time)) 216 | def visualize(sess, wgan_gp, config, option): 217 | image_frame_dim = int(math.ceil(config.batch_size**.5)) 218 | if option == 0: 219 | z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, wgan_gp.z_dim)) 220 | samples = sess.run(wgan_gp.generator, feed_dict={wgan_gp.z: z_sample}) 221 | save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime())) 222 | elif option == 1: 223 | values = np.arange(0, 1, 1./config.batch_size) 224 | for idx in xrange(wgan_gp.z_dim): 225 | print(" [*] %d" % idx) 226 | z_sample = np.random.uniform(-1, 1, size=(config.batch_size , wgan_gp.z_dim)) 227 | for kdx, z in enumerate(z_sample): 228 | z[idx] = values[kdx] 229 | samples = sess.run(wgan_gp.X_fake, feed_dict={wgan_gp.z: z_sample}) 230 | 231 | save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_arange_%s.png' % (idx)) 232 | elif option == 2: 233 | values = np.arange(0, 1, 1./config.batch_size) 234 | for idx in [random.randint(0, wgan_gp.z_dim - 1) for _ in xrange(wgan_gp.z_dim)]: 235 | print(" [*] %d" % idx) 236 | z = np.random.uniform(-0.2, 0.2, size=(wgan_gp.z_dim)) 237 | z_sample = np.tile(z, (config.batch_size, 1)) 238 | #z_sample = np.zeros([config.batch_size, wgan_gp.z_dim]) 239 | for kdx, z in enumerate(z_sample): 240 | z[idx] = values[kdx] 241 | 242 | if config.dataset == "mnist": 243 | y = np.random.choice(10, config.batch_size) 244 | y_one_hot = np.zeros((config.batch_size, 10)) 245 | y_one_hot[np.arange(config.batch_size), y] = 1 246 | 247 | samples = sess.run(wgan_gp.X_fake, feed_dict={wgan_gp.z: z_sample, wgan_gp.y: y_one_hot}) 248 | else: 249 | samples = sess.run(wgan_gp.X_fake, feed_dict={wgan_gp.z: z_sample}) 250 | 251 | try: 252 | make_gif(samples, './samples/test_gif_%s.gif' % (idx)) 253 | except: 254 | save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime())) 255 | elif option == 3: 256 | values = np.arange(0, 1, 1./config.batch_size) 257 | for idx in xrange(wgan_gp.z_dim): 258 | print(" [*] %d" % idx) 259 | z_sample = np.zeros([config.batch_size, wgan_gp.z_dim]) 260 | for kdx, z in enumerate(z_sample): 261 | z[idx] = values[kdx] 262 | 263 | samples = sess.run(wgan_gp.sampler, feed_dict={wgan_gp.z: z_sample}) 264 | make_gif(samples, './samples/test_gif_%s.gif' % (idx)) 265 | elif option == 4: 266 | image_set = [] 267 | values = np.arange(0, 1, 1./config.batch_size) 268 | 269 | for idx in xrange(wgan_gp.z_dim): 270 | print(" [*] %d" % idx) 271 | z_sample = np.zeros([config.batch_size, wgan_gp.z_dim]) 272 | for kdx, z in enumerate(z_sample): z[idx] = values[kdx] 273 | 274 | image_set.append(sess.run(wgan_gp.sampler, feed_dict={wgan_gp.z: z_sample})) 275 | make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx)) 276 | 277 | new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \ 278 | for idx in range(64) + range(63, -1, -1)] 279 | make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8) 280 | 281 | 282 | def image_manifold_size(num_images): 283 | try: 284 | im_sqrt = np.sqrt(num_images) 285 | except: 286 | im_sqrt = np.sqrt(num_images.value) 287 | manifold_h = int(np.floor(im_sqrt)) 288 | manifold_w = int(np.ceil(im_sqrt)) 289 | assert manifold_h * manifold_w == num_images 290 | return manifold_h, manifold_w 291 | --------------------------------------------------------------------------------