├── README.md
├── createPhotos.py
├── data_ops.py
├── architecture.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
# Least Squares Generative Adversarial Networks
Implementation of [LSGANs](https://arxiv.org/pdf/1611.04076v2.pdf) in Tensorflow.

___

Requirements
* Python 2.7
* [Tensorflow v1.0](https://www.tensorflow.org/)

Datasets
* LSUN

___


### Results
The results aren't as good as the paper shows, and I'm still investigating why.

Churches from LSUN
![img](http://i.imgur.com/r2Lyavw.png)

Bedrooms from LSUN
![img2](http://i.imgur.com/mhhXjby.png)

### Training

#### Data
I used the LSUN church and bedroom datasets. The images are resized to 112x112, the same size the generator produces.

#### Train
`python train.py --DATA_DIR=[/path/to/images/] --DATASET=[dataset] --BATCH_SIZE=[batch_size]`

For example, if you have the [LSUN dataset](http://lsun.cs.princeton.edu/2016/):

`python train.py --DATA_DIR=/mnt/lsun/church/images/ --DATASET=church`

#### View Results
To see an image grid such as the one on this page, simply run

`python createPhotos.py checkpoints/church/`

or wherever your model is saved.

If you see the following as your "results", then you did not provide the complete path
to your checkpoint, and the images come from the model's randomly initialized weights.

![bad](http://i.imgur.com/MJfmze1.jpg)
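### Loss
LSGAN swaps the usual cross-entropy GAN loss for least squares penalties on the discriminator's raw (unbounded) output. A minimal sketch of the objectives as `train.py` sets them up, using the paper's 0-1 coding (D pushes real scores toward 1 and fake scores toward 0, while G pushes fake scores toward 1); here `D_real` and `D_fake` are stand-ins for the discriminator outputs on a real and a generated batch:

```python
import tensorflow as tf

# stand-ins for the discriminator's raw outputs on a real and a generated batch
D_real = tf.placeholder(tf.float32, shape=(None, 1))
D_fake = tf.placeholder(tf.float32, shape=(None, 1))

# discriminator loss: push real scores toward 1 and fake scores toward 0
errD = 0.5*tf.reduce_mean(tf.square(D_real - 1)) + 0.5*tf.reduce_mean(tf.square(D_fake))

# generator loss: push fake scores toward 1
errG = 0.5*tf.reduce_mean(tf.square(D_fake - 1))
```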
--------------------------------------------------------------------------------
/createPhotos.py:
--------------------------------------------------------------------------------
'''
Script to create an n x m grid of photos produced by the generator.
'''

import tensorflow as tf
import numpy as np
from architecture import netG
import sys
from scipy import misc

if __name__ == '__main__':

    checkpoint_dir = sys.argv[1]

    n = 15 # cols
    m = 5  # rows

    num_images = n*m

    img_size = (112, 112, 3)

    # white canvas with a 10 pixel border around every cell
    canvas = 255*np.ones((m*img_size[0]+(10*m)+10, n*img_size[1]+(10*n)+10, 3), dtype=np.uint8)

    z = tf.placeholder(tf.float32, shape=(num_images, 1024), name='z')
    generated_images = netG(z, num_images)

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        print "Restoring previous model..."
        try:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print "Model restored"
        except:
            print "Could not restore model"
            exit()

    # sample z from the same uniform prior used during training
    batch_z = np.random.uniform(-1.0, 1.0, size=[num_images, 1024]).astype(np.float32)
    gen_imgs = sess.run([generated_images], feed_dict={z:batch_z})
    gen_imgs = np.squeeze(np.asarray(gen_imgs))

    start_x = 10
    start_y = 10

    x = 0

    for img in gen_imgs:

        # map the generator's tanh output from [-1, 1] back to [0, 255]
        img = (img+1.)
        img *= 127.5
        img = np.clip(img, 0, 255).astype(np.uint8)
        img = np.reshape(img, (112, 112, -1))

        end_x = start_x+112
        end_y = start_y+112

        canvas[start_y:end_y, start_x:end_x, :] = img

        # advance one cell to the right; wrap to the next row after n images
        if x < n:
            start_x += 112+10
            x += 1
        if x == n:
            x = 0
            start_x = 10
            start_y = end_y + 10

    misc.imsave('results.png', canvas)

--------------------------------------------------------------------------------
/data_ops.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import os
import fnmatch
import cPickle as pickle
import scipy.misc as misc

def _read_input(filename_queue):
    class DataRecord(object):
        pass
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)
    record = DataRecord()
    decoded_image = tf.image.decode_jpeg(value, channels=3)
    decoded_image_4d = tf.expand_dims(decoded_image, 0)
    resized_image = tf.image.resize_bilinear(decoded_image_4d, [112, 112])
    record.input_image = tf.squeeze(resized_image, squeeze_dims=[0])
    return record


def read_input_queue(filename_queue, batch_size):
    read_input = _read_input(filename_queue)
    num_preprocess_threads = 8
    min_queue_examples = 10
    print("Shuffling")
    input_image = tf.train.shuffle_batch([read_input.input_image],
                                         batch_size=batch_size,
                                         num_threads=num_preprocess_threads,
                                         capacity=min_queue_examples + 8 * batch_size,
                                         min_after_dequeue=min_queue_examples)
    # scale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output
    input_image = input_image/127.5 - 1.
    return input_image


def saveImage(images, step, images_dir):
    num = 0
    for image in images:
        # map [-1, 1] back to [0, 255] before writing to disk
        image = (image+1.)
        image *= 127.5
        image = np.clip(image, 0, 255).astype(np.uint8)
        misc.imsave(images_dir+str(step)+'_'+str(num)+'.jpg', image)
        num += 1
        if num == 5:
            break

'''
Inputs: a directory containing images (nested directories are searched too) and an optional extension

Outputs: a list of image paths
'''
def getPaths(data_dir, ext='jpg'):
    pattern = '*.'+ext
    image_list = []
    for d, s, fList in os.walk(data_dir):
        for filename in fList:
            if fnmatch.fnmatch(filename, pattern):
                image_list.append(os.path.join(d, filename))
    return image_list

'''
Collects the image paths for a dataset and caches them in a pickle file so the
directory walk only happens once. (Originally written for CelebA, hence the name.)
'''
def loadCeleba(data_dir, dataset):

    # the pickle file contains a list of image paths
    pkl_file = './'+dataset+'.pkl'

    # first, check if a pickle file has been made with the image paths
    if os.path.isfile(pkl_file):
        print 'Pickle file found'
        image_paths = pickle.load(open(pkl_file, 'rb'))
        return image_paths
    else:
        print 'Getting paths!'
        image_paths = getPaths(data_dir)
        pf = open(pkl_file, 'wb')
        data = pickle.dumps(image_paths)
        pf.write(data)
        pf.close()
        return image_paths
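
# Example usage (illustrative, with a hypothetical data directory) - a minimal
# sketch of how train.py consumes the helpers above: collect and cache the image
# paths, then build the shuffled input queue.
#
#   paths = loadCeleba('/mnt/lsun/church/images/', 'church')  # cached in ./church.pkl
#   queue = tf.train.string_input_producer(paths)
#   batch = read_input_queue(queue, batch_size=32)            # (32, 112, 112, 3), values in [-1, 1]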
--------------------------------------------------------------------------------
/architecture.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import tensorflow.contrib.slim as slim

'''
Leaky ReLU
https://arxiv.org/pdf/1502.01852.pdf
'''
def lrelu(x, leak=0.2, name='lrelu'):
    return tf.maximum(leak*x, x)

'''
Generator network: maps a (batch_size, 1024) z vector to a (batch_size, 112, 112, 3)
image. The spatial size grows 7 -> 14 -> 14 -> 28 -> 28 -> 56 -> 112 -> 112, doubling
at each stride-2 transpose convolution.
'''
def netG(z, batch_size):

    print 'GENERATOR'
    z = slim.fully_connected(z, 7*7*256, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='g_z')
    z = tf.reshape(z, [batch_size, 7, 7, 256])
    print 'z:',z

    conv1 = slim.convolution2d_transpose(z, 256, 3, stride=2, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='g_conv1')
    conv1 = tf.nn.relu(conv1)
    print 'conv1:',conv1

    conv2 = slim.convolution2d_transpose(conv1, 256, 5, stride=1, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='g_conv2')
    conv2 = tf.nn.relu(conv2)
    print 'conv2:',conv2

    conv3 = slim.convolution2d_transpose(conv2, 256, 3, stride=2, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='g_conv3')
    conv3 = tf.nn.relu(conv3)
    print 'conv3:',conv3

    conv4 = slim.convolution2d_transpose(conv3, 256, 3, stride=1, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='g_conv4')
    conv4 = tf.nn.relu(conv4)
    print 'conv4:',conv4

    conv5 = slim.convolution2d_transpose(conv4, 128, 3, stride=2, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='g_conv5')
    conv5 = tf.nn.relu(conv5)
    print 'conv5:',conv5

    conv6 = slim.convolution2d_transpose(conv5, 64, 3, stride=2, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='g_conv6')
    conv6 = tf.nn.relu(conv6)
    print 'conv6:',conv6

    # tanh keeps the output in [-1, 1], matching the scaling of the real images
    conv7 = slim.convolution2d_transpose(conv6, 3, 3, stride=1, activation_fn=tf.identity, scope='g_conv7')
    conv7 = tf.nn.tanh(conv7)
    print 'conv7:',conv7
    print
    print 'END G'
    print

    return conv7


'''
Discriminator network. Returns the raw, unbounded output of the final fully
connected layer; the least squares losses in train.py operate on it directly,
so no sigmoid is applied.
'''
def netD(input_images, batch_size, reuse=False):
    print 'DISCRIMINATOR'
    sc = tf.get_variable_scope()
    with tf.variable_scope(sc, reuse=reuse):

        print 'input images:',input_images
        conv1 = slim.convolution(input_images, 64, 5, stride=2, activation_fn=tf.identity, scope='d_conv1')
        conv1 = lrelu(conv1)
        print 'conv1:',conv1

        conv2 = slim.convolution(conv1, 128, 5, stride=2, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='d_conv2')
        conv2 = lrelu(conv2)
        print 'conv2:',conv2

        conv3 = slim.convolution(conv2, 256, 5, stride=2, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='d_conv3')
        conv3 = lrelu(conv3)
        print 'conv3:',conv3

        conv4 = slim.convolution(conv3, 512, 5, stride=2, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='d_conv4')
        conv4 = lrelu(conv4)
        print 'conv4:',conv4

        conv4 = tf.reshape(conv4, [batch_size, -1])
        fc1 = slim.fully_connected(conv4, 1, scope='d_fc1', activation_fn=tf.identity)
        print 'fc1:',fc1
        print 'END D\n'

        # return the linear score from the final layer, not the flattened
        # conv features, so the least squares loss sees the fc output
        return fc1
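
if __name__ == '__main__':
    # Quick shape sanity check - an illustrative sketch, not part of training.
    # The generator should map a (4, 1024) z batch to (4, 112, 112, 3) images,
    # and the discriminator should map those images to a (4, 1) score.
    z = tf.placeholder(tf.float32, shape=(4, 1024), name='z')
    fake = netG(z, 4)
    scores = netD(fake, 4)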
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import time
import tensorflow as tf
from architecture import netD, netG
import numpy as np
import os
import argparse
import data_ops

if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('--BATCH_SIZE', required=False, type=int, default=32, help='Batch size to use')
    parser.add_argument('--DATA_DIR', required=True, type=str, help='Directory containing images')
    parser.add_argument('--DATASET', required=True, type=str, help='Name of the dataset')
    a = parser.parse_args()

    DATA_DIR = a.DATA_DIR
    DATASET = a.DATASET
    BATCH_SIZE = a.BATCH_SIZE

    checkpoint_dir = 'checkpoints/'+DATASET+'/'

    try: os.mkdir('checkpoints/')
    except: pass
    try: os.mkdir(checkpoint_dir)
    except: pass
    try: os.mkdir(checkpoint_dir+'images/')
    except: pass

    images_dir = checkpoint_dir+'images/'

    # placeholders for data going into the network
    global_step = tf.Variable(0, name='global_step', trainable=False)
    z = tf.placeholder(tf.float32, shape=(BATCH_SIZE, 1024), name='z')

    train_images_list = data_ops.loadCeleba(DATA_DIR, DATASET)
    filename_queue = tf.train.string_input_producer(train_images_list)
    real_images = data_ops.read_input_queue(filename_queue, BATCH_SIZE)

    # generated images
    gen_images = netG(z, BATCH_SIZE)

    # raw discriminator scores on the real and fake data
    D_real = netD(real_images, BATCH_SIZE)
    D_fake = netD(gen_images, BATCH_SIZE, reuse=True)

    # mean scores, tracked for logging
    errD_real = tf.reduce_mean(D_real)
    errD_fake = tf.reduce_mean(D_fake)

    # least squares losses with the paper's 0-1 coding: the mean is taken over
    # the per-sample squared errors, not over the raw scores first
    errD = 0.5*tf.reduce_mean(tf.square(D_real - 1)) + 0.5*tf.reduce_mean(tf.square(D_fake))
    errG = 0.5*tf.reduce_mean(tf.square(D_fake - 1))
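
    # Worked example with illustrative numbers: if every real sample scores
    # D_real = 0.8 and every fake sample scores D_fake = 0.3, then
    #   errD = 0.5*(0.8 - 1)**2 + 0.5*(0.3)**2 = 0.02 + 0.045 = 0.065
    #   errG = 0.5*(0.3 - 1)**2                = 0.245
    # D lowers errD by pushing real scores toward 1 and fake scores toward 0,
    # while G lowers errG by pushing fake scores toward 1.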
    # tensorboard summaries
    tf.summary.scalar('d_loss', errD)
    tf.summary.scalar('g_loss', errG)
    #tf.summary.image('real_images', real_images, max_outputs=BATCH_SIZE)
    #tf.summary.image('generated_images', gen_images, max_outputs=BATCH_SIZE)
    merged_summary_op = tf.summary.merge_all()

    # get all trainable variables, and split by network G and network D
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'd_' in var.name]
    g_vars = [var for var in t_vars if 'g_' in var.name]

    # optimize G
    G_train_op = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.5).minimize(errG, var_list=g_vars, global_step=global_step)

    # optimize D
    D_train_op = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.5).minimize(errD, var_list=d_vars)

    saver = tf.train.Saver(max_to_keep=1)
    init = tf.global_variables_initializer()

    sess = tf.Session()
    sess.run(init)

    summary_writer = tf.summary.FileWriter(checkpoint_dir+'logs/', graph=tf.get_default_graph())

    tf.add_to_collection('G_train_op', G_train_op)
    tf.add_to_collection('D_train_op', D_train_op)

    # restore previous model if there is one
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        print "Restoring previous model..."
        try:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print "Model restored"
        except:
            print "Could not restore model"
            pass

    ########################################### training portion

    step = sess.run(global_step)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)

    num_train = len(train_images_list)

    while True:

        start = time.time()
        epoch_num = step/(num_train/BATCH_SIZE)

        # one D update then one G update, each with a fresh z batch drawn from
        # the same uniform prior
        batch_z = np.random.uniform(-1.0, 1.0, size=[BATCH_SIZE, 1024]).astype(np.float32)
        sess.run(D_train_op, feed_dict={z:batch_z})

        batch_z = np.random.uniform(-1.0, 1.0, size=[BATCH_SIZE, 1024]).astype(np.float32)
        sess.run(G_train_op, feed_dict={z:batch_z})

        # now get all losses and summaries *without* performing a training step - for tensorboard
        D_real_score, D_fake_score, D_loss, G_loss, summary = sess.run([errD_real, errD_fake, errD, errG, merged_summary_op], feed_dict={z:batch_z})
        summary_writer.add_summary(summary, step)

        if step%10 == 0:
            print 'epoch_num:',epoch_num,'step:',step,'D_real:',D_real_score,'D_fake:',D_fake_score,'D loss:',D_loss,'G_loss:',G_loss,'time:',time.time()-start
        step += 1

        if step%1000 == 0:
            print 'Saving model...'
            saver.save(sess, checkpoint_dir+'checkpoint-'+str(step))
            saver.export_meta_graph(checkpoint_dir+'checkpoint-'+str(step)+'.meta')
            batch_z = np.random.uniform(-1.0, 1.0, size=[BATCH_SIZE, 1024]).astype(np.float32)
            gen_imgs = sess.run([gen_images], feed_dict={z:batch_z})

            data_ops.saveImage(gen_imgs[0], step, images_dir)
            print 'Done saving'

--------------------------------------------------------------------------------