├── models
│   ├── __init__.py
│   └── GAN_models.py
├── Dataset_Reader
│   ├── __init__.py
│   └── read_celebADataset.py
├── __init__.py
├── .gitignore
├── logs
│   └── images
│       ├── 8k.png
│       ├── 97k.png
│       ├── 8k_1.png
│       ├── d_loss.png
│       ├── g_loss.png
│       ├── w_example.png
│       ├── gan_generated.png
│       └── wgan_generated.png
├── run_main.sh
├── LICENSE
├── main.py
├── utils.py
└── README.md
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/Dataset_Reader/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'Charlie'
2 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | logs/
2 | Data_zoo
3 | *.pyc
4 | 
--------------------------------------------------------------------------------
/logs/images/8k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shekkizh/WassersteinGAN.tensorflow/HEAD/logs/images/8k.png
--------------------------------------------------------------------------------
/logs/images/97k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shekkizh/WassersteinGAN.tensorflow/HEAD/logs/images/97k.png
--------------------------------------------------------------------------------
/logs/images/8k_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shekkizh/WassersteinGAN.tensorflow/HEAD/logs/images/8k_1.png
--------------------------------------------------------------------------------
/logs/images/d_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shekkizh/WassersteinGAN.tensorflow/HEAD/logs/images/d_loss.png
--------------------------------------------------------------------------------
/logs/images/g_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shekkizh/WassersteinGAN.tensorflow/HEAD/logs/images/g_loss.png
--------------------------------------------------------------------------------
/logs/images/w_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shekkizh/WassersteinGAN.tensorflow/HEAD/logs/images/w_example.png
--------------------------------------------------------------------------------
/logs/images/gan_generated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shekkizh/WassersteinGAN.tensorflow/HEAD/logs/images/gan_generated.png
--------------------------------------------------------------------------------
/logs/images/wgan_generated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shekkizh/WassersteinGAN.tensorflow/HEAD/logs/images/wgan_generated.png
--------------------------------------------------------------------------------
/run_main.sh:
--------------------------------------------------------------------------------
1 | python main.py --logs_dir=logs/CelebA_WGAN_logs2/ --optimizer=RMSProp --learning_rate=5e-5 --optimizer_param=0.9 --model=1 --iterations=1e5 --mode=train
2 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Sarath Shekkizhar
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | __author__ = "shekkizh"
4 | """
5 | Tensorflow implementation of Wasserstein GAN
6 | """
7 | import numpy as np
8 | import tensorflow as tf
9 | from models.GAN_models import *
10 | 
11 | FLAGS = tf.flags.FLAGS
12 | tf.flags.DEFINE_integer("batch_size", 64, "batch size for training")
13 | tf.flags.DEFINE_string("logs_dir", "logs/CelebA_GAN_logs/", "path to logs directory")
14 | tf.flags.DEFINE_string("data_dir", "Data_zoo/CelebA_faces/", "path to dataset")
15 | tf.flags.DEFINE_integer("z_dim", 100, "size of input vector to generator")
16 | tf.flags.DEFINE_float("learning_rate", 2e-4, "learning rate for the optimizer")
17 | tf.flags.DEFINE_float("optimizer_param", 0.5, "beta1 for Adam optimizer / decay for RMSProp")
18 | tf.flags.DEFINE_float("iterations", 1e5, "number of iterations to train the model")
19 | tf.flags.DEFINE_string("image_size", "108,64", "comma-separated: crop size of input images, size of images to generate")
20 | tf.flags.DEFINE_integer("model", 0, "Model to train. 0 - GAN, 1 - WassersteinGAN")
21 | tf.flags.DEFINE_string("optimizer", "Adam", "Optimizer to use for training")
22 | tf.flags.DEFINE_integer("gen_dimension", 16, "dimension of first layer in generator")
23 | tf.flags.DEFINE_string("mode", "train", "train / visualize model")
24 | 
25 | 
26 | def main(argv=None):
27 |     gen_dim = FLAGS.gen_dimension
28 |     generator_dims = [64 * gen_dim, 64 * gen_dim // 2, 64 * gen_dim // 4, 64 * gen_dim // 8, 3]
29 |     discriminator_dims = [3, 64, 64 * 2, 64 * 4, 64 * 8, 1]
30 | 
31 |     crop_image_size, resized_image_size = map(int, FLAGS.image_size.split(','))
32 |     if FLAGS.model == 0:
33 |         model = GAN(FLAGS.z_dim, crop_image_size, resized_image_size, FLAGS.batch_size, FLAGS.data_dir)
34 |     elif FLAGS.model == 1:
35 |         model = WasserstienGAN(FLAGS.z_dim, crop_image_size, resized_image_size, FLAGS.batch_size, FLAGS.data_dir,
36 |                                clip_values=(-0.01, 0.01), critic_iterations=5)
37 |     else:
38 |         raise ValueError("Unknown model identifier - FLAGS.model=%d" % FLAGS.model)
39 | 
40 |     model.create_network(generator_dims, discriminator_dims, FLAGS.optimizer, FLAGS.learning_rate,
41 |                          FLAGS.optimizer_param)
42 | 
43 |     model.initialize_network(FLAGS.logs_dir)
44 | 
45 |     if FLAGS.mode == "train":
46 |         model.train_model(int(1 + FLAGS.iterations))
47 |     elif FLAGS.mode == "visualize":
48 |         model.visualize_model()
49 | 
50 | 
51 | if __name__ == "__main__":
52 |     tf.app.run()
53 | 
--------------------------------------------------------------------------------
/Dataset_Reader/read_celebADataset.py:
--------------------------------------------------------------------------------
1 | __author__ = 'charlie'
2 | import numpy as np
3 | import os, sys, inspect
4 | import random
5 | from six.moves import cPickle as pickle
6 | from tensorflow.python.platform import gfile
7 | import glob
8 | 
9 | utils_path = os.path.abspath(
10 |     os.path.realpath(os.path.join(os.path.split(inspect.getfile(inspect.currentframe()))[0], "..")))
11 | if utils_path not in sys.path:
12 |     sys.path.insert(0, utils_path)
13 | import utils as utils
14 | 
15 | DATA_URL = 'https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip'
16 | random.seed(5)
17 | 
18 | 
19 | class CelebA_Dataset():
20 |     def __init__(self, dataset_dict):
21 |         self.train_images = dataset_dict['train']
22 |         self.test_images = dataset_dict['test']
23 |         self.validation_images = dataset_dict['validation']
24 | 
25 | 
26 | def read_dataset(data_dir):
27 |     pickle_filename = "celebA.pickle"
28 |     pickle_filepath = os.path.join(data_dir, pickle_filename)
29 |     if not os.path.exists(pickle_filepath):
30 |         # utils.maybe_download_and_extract(data_dir, DATA_URL, is_zipfile=True)
31 |         celebA_folder = os.path.splitext(DATA_URL.split("/")[-1])[0]
32 |         dir_path = os.path.join(data_dir, celebA_folder)
33 |         if not os.path.exists(dir_path):
34 |             print("CelebA dataset needs to be downloaded and unzipped manually")
35 |             print("Download from: %s" % DATA_URL)
36 |             raise ValueError("Dataset not found")
37 | 
38 |         result = create_image_lists(dir_path)
39 |         print("Training set: %d" % len(result['train']))
40 |         print("Test set: %d" % len(result['test']))
41 |         print("Validation set: %d" % len(result['validation']))
42 |         print("Pickling ...")
43 |         with open(pickle_filepath, 'wb') as f:
44 |             pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)
45 |     else:
46 |         print("Found pickle file!")
47 | 
48 |     with open(pickle_filepath, 'rb') as f:
49 |         result = pickle.load(f)
50 |         celebA = CelebA_Dataset(result)
51 |         del result
52 |     return celebA
53 | 
54 | 
55 | def create_image_lists(image_dir, testing_percentage=0.0, validation_percentage=0.0):
56 |     """
57 |     Code modified from tensorflow/tensorflow/examples/image_retraining
58 |     """
59 |     if not gfile.Exists(image_dir):
60 |         print("Image directory '" + image_dir + "' not found.")
61 |         return None
62 |     training_images = []
63 |     extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
64 |     file_list = []
65 | 
66 |     for extension in extensions:
67 |         file_glob = os.path.join(image_dir, '*.' + extension)
68 |         file_list.extend(glob.glob(file_glob))
69 | 
70 |     if not file_list:
71 |         print('No files found')
72 |     else:
73 |         # print("No. of files found: %d" % len(file_list))
74 |         training_images.extend(file_list)
75 | 
76 |     random.shuffle(training_images)
77 |     no_of_images = len(training_images)
78 |     validation_offset = int(validation_percentage * no_of_images)
79 |     validation_images = training_images[:validation_offset]
80 |     test_offset = int(testing_percentage * no_of_images)
81 |     testing_images = training_images[validation_offset:validation_offset + test_offset]
82 |     training_images = training_images[validation_offset + test_offset:]
83 | 
84 |     result = {
85 |         'train': training_images,
86 |         'test': testing_images,
87 |         'validation': validation_images,
88 |     }
89 |     return result
90 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | __author__ = 'shekkizh'
2 | # Utils used with the tensorflow implementation
3 | import tensorflow as tf
4 | import numpy as np
5 | import scipy.misc as misc
6 | import os, sys
7 | from six.moves import urllib
8 | import tarfile
9 | import zipfile
10 | from tqdm import trange
11 | import matplotlib.pyplot as plt
12 | from mpl_toolkits.axes_grid1 import ImageGrid
13 | 
14 | def maybe_download_and_extract(dir_path, url_name, is_tarfile=False, is_zipfile=False):
15 |     if not os.path.exists(dir_path):
16 |         os.makedirs(dir_path)
17 |     filename = url_name.split('/')[-1]
18 |     filepath = os.path.join(dir_path, filename)
19 |     if not os.path.exists(filepath):
20 |         def _progress(count, block_size, total_size):
21 |             sys.stdout.write(
22 |                 '\r>> Downloading %s %.1f%%' % (filename, float(count * block_size) / float(total_size) * 100.0))
23 |             sys.stdout.flush()
24 | 
25 |         filepath, _ = urllib.request.urlretrieve(url_name, filepath, reporthook=_progress)
26 |         print()
27 |         statinfo = os.stat(filepath)
28 |         print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
29 |         if is_tarfile:
30 |             tarfile.open(filepath, 'r:gz').extractall(dir_path)
31 |         elif is_zipfile:
32 |             with zipfile.ZipFile(filepath) as zf:
33 |                 # the archive extracts into its own top-level folder
34 |                 zf.extractall(dir_path)
35 | 
36 | 
37 | def save_image(image, image_size, save_dir, name=""):
38 |     """
39 |     Save image by unprocessing, assuming images were normalized to [-1, 1] with mean/norm 127.5
40 |     :param image: image array with values in [-1, 1]
41 |     :param save_dir: directory to save the image in
42 |     :param name: prefix for the saved file name
43 |     :param image_size: height/width of the square image
44 |     """
45 |     image += 1
46 |     image *= 127.5
47 |     image = np.clip(image, 0, 255).astype(np.uint8)
48 |     image = np.reshape(image, (image_size, image_size, -1))
49 |     misc.imsave(os.path.join(save_dir, name + "pred_image.png"), image)
50 | 
51 | 
52 | def xavier_init(fan_in, fan_out, constant=1):
53 |     """ Xavier initialization of network weights"""
54 |     # https://stackoverflow.com/questions/33640581/how-to-do-xavier-initialization-on-tensorflow
55 |     low = -constant * np.sqrt(6.0 / (fan_in + fan_out))
56 |     high = constant * np.sqrt(6.0 / (fan_in + fan_out))
57 |     return tf.random_uniform((fan_in, fan_out), minval=low, maxval=high, dtype=tf.float32)
58 | 
59 | 
60 | def weight_variable_xavier_initialized(shape, constant=1, name=None):
61 |     stddev = constant * np.sqrt(2.0 / (shape[2] + shape[3]))
62 |     return weight_variable(shape, stddev=stddev, name=name)
63 | 
64 | 
65 | def weight_variable(shape, stddev=0.02, name=None):
66 |     initial = tf.truncated_normal(shape, stddev=stddev)
67 |     if name is None:
68 |         return tf.Variable(initial)
69 |     else:
70 |         return tf.get_variable(name, initializer=initial)
71 | 
72 | 
73 | def bias_variable(shape, name=None):
74 |     initial = tf.constant(0.0, shape=shape)
75 |     if name is None:
76 |         return tf.Variable(initial)
77 |     else:
78 |         return tf.get_variable(name, initializer=initial)
79 | 
80 | 
81 | def get_tensor_size(tensor):
82 |     # product of all dimensions in the tensor's static shape
83 |     return int(np.prod([d.value for d in tensor.get_shape()]))
84 | 
85 | 
86 | def conv2d_basic(x, W, bias):
87 |     conv = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME")
88 |     return tf.nn.bias_add(conv, bias)
89 | 
90 | 
91 | def conv2d_strided(x, W, b):
92 |     conv = tf.nn.conv2d(x, W, strides=[1, 2, 2, 1], padding="SAME")
93 |     return tf.nn.bias_add(conv, b)
94 | 
95 | 
96 | def conv2d_transpose_strided(x, W, b, output_shape=None):
97 |     # print x.get_shape()
98 |     # print W.get_shape()
99 |     if output_shape is None:
100 |         output_shape = x.get_shape().as_list()
101 |         output_shape[1] *= 2
102 |         output_shape[2] *= 2
103 |         output_shape[3] = W.get_shape().as_list()[2]
104 |     # print output_shape
105 |     conv = tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, 2, 2, 1], padding="SAME")
106 |     return tf.nn.bias_add(conv, b)
107 | 
108 | 
109 | def leaky_relu(x, alpha=0.2, name=""):
110 |     return tf.maximum(alpha * x, x, name)
111 | 
112 | 
113 | def max_pool_2x2(x):
114 |     return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
115 | 
116 | 
117 | def avg_pool_2x2(x):
118 |     return tf.nn.avg_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
119 | 
120 | 
121 | def local_response_norm(x):
122 |     return tf.nn.lrn(x, depth_radius=5, bias=2, alpha=1e-4, beta=0.75)
123 | 
124 | 
125 | def batch_norm(x, n_out, phase_train, scope='bn', decay=0.9, eps=1e-5, stddev=0.02):
126 |     """
127 |     Code taken from http://stackoverflow.com/a/34634291/2267819
128 |     """
129 |     with tf.variable_scope(scope):
130 |         beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0)
131 |                                , trainable=True)
132 |         gamma = tf.get_variable(name='gamma', shape=[n_out], initializer=tf.random_normal_initializer(1.0, stddev),
133 |                                 trainable=True)
134 |         batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
135 |         ema = tf.train.ExponentialMovingAverage(decay=decay)
136 | 
137 |         def mean_var_with_update():
138 |             ema_apply_op = ema.apply([batch_mean, batch_var])
139 |             with tf.control_dependencies([ema_apply_op]):
140 |                 return tf.identity(batch_mean), tf.identity(batch_var)
141 | 
142 |         mean, var = tf.cond(phase_train,
143 |                             mean_var_with_update,
144 |                             lambda: (ema.average(batch_mean), ema.average(batch_var)))
145 |         normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
146 |     return normed
147 | 
148 | 
149 | def process_image(image, mean_pixel, norm):
150 |     return (image - mean_pixel) / norm
151 | 
152 | 
153 | def unprocess_image(image, mean_pixel, norm):
154 |     return image * norm + mean_pixel
155 | 
156 | 
157 | def add_to_regularization_and_summary(var):
158 |     if var is not None:
159 |         tf.histogram_summary(var.op.name, var)
160 |         tf.add_to_collection("reg_loss", tf.nn.l2_loss(var))
161 | 
162 | 
163 | def add_activation_summary(var):
164 |     tf.histogram_summary(var.op.name + "/activation", var)
165 |     tf.scalar_summary(var.op.name + "/sparsity", tf.nn.zero_fraction(var))
166 | 
167 | 
168 | def add_gradient_summary(grad, var):
169 |     if grad is not None:
170 |         tf.histogram_summary(var.op.name + "/gradient", grad)
171 | 
172 | def save_imshow_grid(images, logs_dir, filename, shape):
173 |     """
174 |     Plot images in a grid of a given shape.
175 |     """
176 |     fig = plt.figure(1)
177 |     grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)
178 | 
179 |     size = shape[0] * shape[1]
180 |     for i in trange(size, desc="Saving images"):
181 |         grid[i].axis('off')
182 |         grid[i].imshow(images[i])
183 | 
184 |     plt.savefig(os.path.join(logs_dir, filename))
185 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WassersteinGAN.tensorflow
2 | Tensorflow implementation of Arjovsky et al.'s [Wasserstein GAN](https://arxiv.org/abs/1701.07875)
3 | 
4 | 1. [Prerequisites](#prerequisites)
5 | 2. [Results](#results)
6 | 3. [Observations](#observations)
7 | 4. [References and related links](#references-and-related-links)
8 | 
9 | Note: The paper refers to the discriminator as a critic. I use the two names interchangeably in my notes below.
10 | 
11 | A pretty interesting paper that takes on the problems of stability in GANs and interpretability of the loss function during training. GANs are essentially models that try to learn the distribution of real data by minimizing an f-divergence (a measure of the difference between probability distributions) while generating adversarial data. Convergence of the min-max objective in the originally proposed GAN can be interpreted as minimizing the Jensen-Shannon (JS) divergence. In this paper, the authors point out the shortcomings of such metrics when the supports of the two distributions being compared do not overlap, and propose the earth mover's (Wasserstein) distance as an alternative to JS. The parallel-lines example provides a nice intuition for the differences between the f-divergence metrics. Note that when an f-divergence takes discrete jumps, as JS and KL do for distributions with non-overlapping support, we can face problems learning models with gradients, since the divergence loss is not differentiable everywhere.
12 | 
13 | Theorem 1 in the paper is probably the key takeaway for anyone wondering why the Wasserstein distance might help in training GANs. The theorem basically states that for a distribution mapping function (critic) that is continuous with respect to its parameters and locally Lipschitz, the Wasserstein distance is continuous and differentiable almost everywhere.
14 | 
15 | A continuous and almost-everywhere-differentiable metric means we can train the discriminator strongly before each update to the generator, which in turn receives improved, reliable gradients from the discriminator. With the earlier formulations of GANs such training was not possible, since training the discriminator strongly would lead to vanishing gradients.
16 | 
17 | Given that neural networks are generally continuous w.r.t. their parameters, what remains is to make sure the critic is Lipschitz. By clipping the weight parameters of the critic, we prevent the model from saturating while keeping its growth at most linear, which means the gradients of the function are bounded by the slope of this linearity, i.e., the critic stays Lipschitz-bounded.
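A minimal sketch of this constraint, mirroring the `tf.clip_by_value` update used in `models/GAN_models.py` (the single-variable list below is a hypothetical stand-in for the critic's trainable variables):

```python
import tensorflow as tf

# Hypothetical stand-in for the critic's trainable variables; in this repo
# they are collected from tf.trainable_variables() by the "discriminator"
# scope prefix.
critic_vars = [tf.Variable(tf.truncated_normal([5, 5, 3, 64], stddev=0.02))]

c = 0.01  # clipping parameter; this implementation uses clip_values=(-0.01, 0.01)
# After every critic update, project each weight back into [-c, c] so the
# critic stays within a fixed Lipschitz bound.
clip_ops = [var.assign(tf.clip_by_value(var, -c, c)) for var in critic_vars]
```

In training, these assign ops are run once after each critic gradient step (see `train_model` in the `WasserstienGAN` class below).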
18 | 
19 | ## Prerequisites
20 | - Code was tested on a Linux system with a Titan GPU.
21 | - The model was trained with tensorflow v0.11 and python2.7. Newer versions of tensorflow require updating the summary statements to avoid deprecation warnings.
22 | - The CelebA dataset should be downloaded and unzipped manually. [Download link](https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip)
23 | - The default arguments to **main.py** run a GAN with the cross-entropy objective.
24 | - **run_main.sh** has the command to run the Wasserstein GAN model.
25 | 
26 | ## Results
27 | - The network architecture used to train the model is very similar to the one in the original DCGAN. This differs from the pytorch version of the code released with the paper, where both the generator and discriminator have "extra layers" of stride one.
28 | 
29 | - All bias terms in the network are removed. I'm not quite sure what the justification is for dropping the bias in the generator, but for the critic it might have to do with constraining the function to a smaller Lipschitz bound.
30 | 
31 | - The results below are after 1e5 iterations, which took approximately 18hrs on my system. This is probably not the most converged result, so take it with a pinch of salt.
32 | 
33 | Random sample of images generated after training the GAN with the Wasserstein distance for 1e5 itrs, lr=5e-5, RMSPropOptimizer.
34 | ![](logs/images/wgan_generated.png)
35 | 
36 | For comparison: random sample of images generated using a GAN with the cross-entropy objective for 2e4 itrs, lr=2e-4, AdamOptimizer.
37 | ![](logs/images/gan_generated.png)
38 | 
39 | 
40 | ## Observations
41 | - After spending quite a while on the theory in the paper, I was surprised and pleased at how simple the implementation is.
42 | The major changes from an implementation point of view are:
43 |   - The discriminator/critic no longer produces a sigmoid or probabilistic output. The discriminator loss is simply the difference between its outputs on real and generated images.
44 |   - The critic is trained multiple times for each generator update.
45 |   - The weights in the critic are clamped to small values around zero.
46 |   - Training requires a low learning rate and optimizers that do not use momentum.
47 | - Training is very slow. This should be expected, given the very low learning rate and the multiple discriminator updates for each generator update.
48 | 
49 | - Discriminator loss for the Wasserstein GAN. Note that the original paper plots the discriminator loss with a negative sign, hence the flip in the direction of the plot. From what I noticed, the general trend of the discriminator loss is toward convergence, but it does increase at times before dropping back.
50 | 
51 | ![](logs/images/d_loss.png)
52 | 
53 | - Training to minimize the Wasserstein distance in this setting can be interpreted as making the critic assign low values to real data and high values to fake data. The generator, on the other hand, tries to generate images to which the critic assigns low values, like the ones real images get. In other words, the model converges when the critic is no longer able to differentiate and assign different values to generated and real images - a reason why I think calling the critic a discriminator is still a reasonable thing :smile:
54 | 
55 | - The generator, as mentioned above, is trying to have the critic assign low values like the ones real images get. While training, the generator loss oscillates quite a bit around zero (see the sketch below).
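To make the sign convention concrete, here is a minimal sketch of the two objectives as defined in `_gan_loss` of the `WasserstienGAN` class (the placeholders standing in for critic outputs are hypothetical):

```python
import tensorflow as tf

# Hypothetical critic outputs for a batch of real and of generated images.
logits_real = tf.placeholder(tf.float32, [None, 1])
logits_fake = tf.placeholder(tf.float32, [None, 1])

# Critic loss: minimizing this pushes critic scores down on real images and
# up on fakes (sign-flipped relative to the plots in the paper, as noted above).
d_loss = tf.reduce_mean(logits_real - logits_fake)
# Generator loss: minimizing this pushes critic scores on fakes down toward
# the values real images receive.
g_loss = tf.reduce_mean(logits_fake)
```

The resulting generator loss over training is shown below: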
56 | ![](logs/images/g_loss.png)
57 | 
58 | - Weights are clipped in the critic to maintain the Lipschitz bound and continuity. An observation pointed out by the author on reddit is worth highlighting:
59 | > The weight clipping parameter is not massively important in practice, but more investigation is required. Here are the effects of having larger clipping parameter c:
60 | 
61 | > The discriminator takes longer to train, since it has to saturate some weights at a larger value. This means that you can be at risk of having an insufficiently trained critic, which can provide bad estimates and gradients. Sometimes sign changes are required in the critic, and going from c to -c on some weights will take longer. If the generator is updated in the middle of this process the gradient can be pretty bad.
62 | 
63 | > The capacity is increased, which helps the optimally trained disc provide better gradients.
64 | 
65 | > In general it seems that lower clipping is more stable, but higher clipping gives a better model if the critic is well trained.
66 | 
67 | ![](logs/images/w_example.png)
68 | 
69 | - Theoretically, the paper's claim that the loss corresponds to sample quality is understandable given the formulation, but since quality is a relative term I failed to see a clear correspondence between loss and quality in my generated images; i.e., how much of a loss improvement corresponds to how much of an image-quality improvement is unclear. Having said that, it is quite possible that all images generated after "convergence" are realistic.
70 | 
71 | - How this new loss would carry over to previous work associated with GANs - semi/unsupervised learning, adaptation, adversarial losses in computer vision tasks and the like - is pretty exciting and interesting.
72 | 
73 | ## References and related links
74 | - Pytorch implementation of WassersteinGAN by the authors of the paper - [link](https://github.com/martinarjovsky/WassersteinGAN)
75 | - Interesting discussion on r/machinelearning - [link](https://www.reddit.com/r/MachineLearning/comments/5qxoaz/r_170107875_wasserstein_gan/)
76 | 
--------------------------------------------------------------------------------
/models/GAN_models.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | __author__ = "shekkizh"
4 | 
5 | import tensorflow as tf
6 | import numpy as np
7 | import os, sys, inspect
8 | import time
9 | 
10 | utils_folder = os.path.realpath(
11 |     os.path.abspath(os.path.join(os.path.split(inspect.getfile(inspect.currentframe()))[0], "..")))
12 | if utils_folder not in sys.path:
13 |     sys.path.insert(0, utils_folder)
14 | 
15 | import utils as utils
16 | import Dataset_Reader.read_celebADataset as celebA
17 | from six.moves import xrange
18 | 
19 | 
20 | class GAN(object):
21 |     def __init__(self, z_dim, crop_image_size, resized_image_size, batch_size, data_dir):
22 |         celebA_dataset = celebA.read_dataset(data_dir)
23 |         self.z_dim = z_dim
24 |         self.crop_image_size = crop_image_size
25 |         self.resized_image_size = resized_image_size
26 |         self.batch_size = batch_size
27 |         filename_queue = tf.train.string_input_producer(celebA_dataset.train_images)
28 |         self.images = self._read_input_queue(filename_queue)
29 | 
30 |     def _read_input(self, filename_queue):
31 |         class DataRecord(object):
32 |             pass
33 | 
34 |         reader = tf.WholeFileReader()
35 |         key, value = reader.read(filename_queue)
36 |         record = DataRecord()
37 |         decoded_image = tf.image.decode_jpeg(value,
38 |                                              channels=3)  # Assumption: Color images are read and are to be generated
39 | 
40 |         # decoded_image_4d = tf.expand_dims(decoded_image, 0)
41 |         # resized_image = tf.image.resize_bilinear(decoded_image_4d, [self.target_image_size, self.target_image_size])
42 |         # record.input_image = tf.squeeze(resized_image, squeeze_dims=[0])
43 | 
44 |         cropped_image = tf.cast(
45 |             tf.image.crop_to_bounding_box(decoded_image, 55, 35, self.crop_image_size, self.crop_image_size),
46 |             tf.float32)
47 |         decoded_image_4d = tf.expand_dims(cropped_image, 0)
48 |         resized_image = tf.image.resize_bilinear(decoded_image_4d, [self.resized_image_size, self.resized_image_size])
49 |         record.input_image = tf.squeeze(resized_image, squeeze_dims=[0])
50 |         return record
51 | 
52 |     def _read_input_queue(self, filename_queue):
53 |         print("Setting up image reader...")
54 |         read_input = self._read_input(filename_queue)
55 |         num_preprocess_threads = 4
56 |         num_examples_per_epoch = 800
57 |         min_queue_examples = int(0.1 * num_examples_per_epoch)
58 |         print("Shuffling")
59 |         input_image = tf.train.batch([read_input.input_image],
60 |                                      batch_size=self.batch_size,
61 |                                      num_threads=num_preprocess_threads,
62 |                                      capacity=min_queue_examples + 2 * self.batch_size
63 |                                      )
64 |         input_image = utils.process_image(input_image, 127.5, 127.5)
65 |         return input_image
66 | 
67 |     def _generator(self, z, dims, train_phase, activation=tf.nn.relu, scope_name="generator"):
68 |         N = len(dims)
69 |         image_size = self.resized_image_size // (2 ** (N - 1))
70 |         with tf.variable_scope(scope_name) as scope:
71 |             W_z = utils.weight_variable([self.z_dim, dims[0] * image_size * image_size], name="W_z")
72 |             b_z = utils.bias_variable([dims[0] * image_size * image_size], name="b_z")
73 |             h_z = tf.matmul(z, W_z) + b_z
74 |             h_z = tf.reshape(h_z, [-1, image_size, image_size, dims[0]])
75 |             h_bnz = utils.batch_norm(h_z, dims[0], train_phase, scope="gen_bnz")
76 |             h = activation(h_bnz, name='h_z')
77 |             utils.add_activation_summary(h)
78 | 
79 |             for index in range(N - 2):
80 |                 image_size *= 2
81 |                 W = utils.weight_variable([5, 5, dims[index + 1], dims[index]], name="W_%d" % index)
82 |                 b = utils.bias_variable([dims[index + 1]], name="b_%d" % index)
83 |                 deconv_shape = tf.pack([tf.shape(h)[0], image_size, image_size, dims[index + 1]])
84 |                 h_conv_t = utils.conv2d_transpose_strided(h, W, b, output_shape=deconv_shape)
85 |                 h_bn = utils.batch_norm(h_conv_t, dims[index + 1], train_phase, scope="gen_bn%d" % index)
86 |                 h = activation(h_bn, name='h_%d' % index)
87 |                 utils.add_activation_summary(h)
88 | 
89 |             image_size *= 2
90 |             W_pred = utils.weight_variable([5, 5, dims[-1], dims[-2]], name="W_pred")
91 |             b_pred = utils.bias_variable([dims[-1]], name="b_pred")
92 |             deconv_shape = tf.pack([tf.shape(h)[0], image_size, image_size, dims[-1]])
93 |             h_conv_t = utils.conv2d_transpose_strided(h, W_pred, b_pred, output_shape=deconv_shape)
94 |             pred_image = tf.nn.tanh(h_conv_t, name='pred_image')
95 |             utils.add_activation_summary(pred_image)
96 | 
97 |         return pred_image
98 | 
99 |     def _discriminator(self, input_images, dims, train_phase, activation=tf.nn.relu, scope_name="discriminator",
100 |                        scope_reuse=False):
101 |         N = len(dims)
102 |         with tf.variable_scope(scope_name) as scope:
103 |             if scope_reuse:
104 |                 scope.reuse_variables()
105 |             h = input_images
106 |             skip_bn = True  # First layer of discriminator skips batch norm
107 |             for index in range(N - 2):
108 |                 W = utils.weight_variable([5, 5, dims[index], dims[index + 1]], name="W_%d" % index)
109 |                 b = utils.bias_variable([dims[index + 1]], name="b_%d" % index)
110 |                 h_conv = utils.conv2d_strided(h, W, b)
111 |                 if skip_bn:
112 |                     h_bn = h_conv
113 |                     skip_bn = False
114 |                 else:
115 |                     h_bn = utils.batch_norm(h_conv, dims[index + 1], train_phase, scope="disc_bn%d" % index)
116 |                 h = activation(h_bn, name="h_%d" % index)
117 |                 utils.add_activation_summary(h)
118 | 
119 |             shape = h.get_shape().as_list()
120 |             image_size = self.resized_image_size // (2 ** (N - 2))  # dims has input dim and output dim
121 |             h_reshaped = tf.reshape(h, [self.batch_size, image_size * image_size * shape[3]])
122 |             W_pred = utils.weight_variable([image_size * image_size * shape[3], dims[-1]], name="W_pred")
123 |             b_pred = utils.bias_variable([dims[-1]], name="b_pred")
124 |             h_pred = tf.matmul(h_reshaped, W_pred) + b_pred
125 | 
126 |         return tf.nn.sigmoid(h_pred), h_pred, h
127 | 
128 |     def _cross_entropy_loss(self, logits, labels, name="x_entropy"):
129 |         xentropy = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits, labels))
130 |         tf.scalar_summary(name, xentropy)
131 |         return xentropy
132 | 
133 |     def _get_optimizer(self, optimizer_name, learning_rate, optimizer_param):
134 |         self.learning_rate = learning_rate
135 |         if optimizer_name == "Adam":
136 |             return tf.train.AdamOptimizer(learning_rate, beta1=optimizer_param)
137 |         elif optimizer_name == "RMSProp":
138 |             return tf.train.RMSPropOptimizer(learning_rate, decay=optimizer_param)
139 |         else:
140 |             raise ValueError("Unknown optimizer %s" % optimizer_name)
141 | 
142 |     def _train(self, loss_val, var_list, optimizer):
143 |         grads = optimizer.compute_gradients(loss_val, var_list=var_list)
144 |         for grad, var in grads:
145 |             utils.add_gradient_summary(grad, var)
146 |         return optimizer.apply_gradients(grads)
147 | 
148 |     def _setup_placeholder(self):
149 |         self.train_phase = tf.placeholder(tf.bool)
150 |         self.z_vec = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name="z")
151 | 
152 |     def _gan_loss(self, logits_real, logits_fake, feature_real, feature_fake, use_features=False):
153 |         discriminator_loss_real = self._cross_entropy_loss(logits_real, tf.ones_like(logits_real),
154 |                                                            name="disc_real_loss")
155 | 
156 |         discriminator_loss_fake = self._cross_entropy_loss(logits_fake, tf.zeros_like(logits_fake),
157 |                                                            name="disc_fake_loss")
158 |         self.discriminator_loss = discriminator_loss_fake + discriminator_loss_real
159 | 
160 |         gen_loss_disc = self._cross_entropy_loss(logits_fake, tf.ones_like(logits_fake), name="gen_disc_loss")
161 |         if use_features:
162 |             gen_loss_features = tf.reduce_mean(tf.nn.l2_loss(feature_real - feature_fake)) / (self.crop_image_size ** 2)
163 |         else:
164 |             gen_loss_features = 0
165 |         self.gen_loss = gen_loss_disc + 0.1 * gen_loss_features
166 | 
167 |         tf.scalar_summary("Discriminator_loss", self.discriminator_loss)
168 |         tf.scalar_summary("Generator_loss", self.gen_loss)
169 | 
170 |     def create_network(self, generator_dims, discriminator_dims, optimizer="Adam", learning_rate=2e-4,
171 |                        optimizer_param=0.9, improved_gan_loss=True):
172 |         print("Setting up model...")
173 |         self._setup_placeholder()
174 |         tf.histogram_summary("z", self.z_vec)
175 |         self.gen_images = self._generator(self.z_vec, generator_dims, self.train_phase, scope_name="generator")
176 | 
177 |         tf.image_summary("image_real", self.images, max_images=2)
178 |         tf.image_summary("image_generated", self.gen_images, max_images=2)
179 | 
180 |         def leaky_relu(x, name="leaky_relu"):
181 |             return utils.leaky_relu(x, alpha=0.2, name=name)
182 | 
183 |         discriminator_real_prob, logits_real, feature_real = self._discriminator(self.images, discriminator_dims,
184 |                                                                                   self.train_phase,
185 |                                                                                   activation=leaky_relu,
186 |                                                                                   scope_name="discriminator",
187 |                                                                                   scope_reuse=False)
188 | 
189 |         discriminator_fake_prob, logits_fake, feature_fake = self._discriminator(self.gen_images, discriminator_dims,
190 |                                                                                  self.train_phase,
191 |                                                                                  activation=leaky_relu,
192 |                                                                                  scope_name="discriminator",
193 |                                                                                  scope_reuse=True)
194 | 
195 |         # utils.add_activation_summary(tf.identity(discriminator_real_prob, name='disc_real_prob'))
196 |         # utils.add_activation_summary(tf.identity(discriminator_fake_prob, name='disc_fake_prob'))
197 | 
198 |         # Loss calculation
199 |         self._gan_loss(logits_real, logits_fake, feature_real, feature_fake, use_features=improved_gan_loss)
200 | 
201 |         train_variables = tf.trainable_variables()
202 | 
203 |         for v in train_variables:
204 |             # print (v.op.name)
205 |             utils.add_to_regularization_and_summary(var=v)
206 | 
207 |         self.generator_variables = [v for v in train_variables if v.name.startswith("generator")]
208 |         # print(map(lambda x: x.op.name, generator_variables))
209 |         self.discriminator_variables = [v for v in train_variables if v.name.startswith("discriminator")]
210 |         # print(map(lambda x: x.op.name, discriminator_variables))
211 | 
212 |         optim = self._get_optimizer(optimizer, learning_rate, optimizer_param)
213 | 
214 |         self.generator_train_op = self._train(self.gen_loss, self.generator_variables, optim)
215 |         self.discriminator_train_op = self._train(self.discriminator_loss, self.discriminator_variables, optim)
216 | 
217 |     def initialize_network(self, logs_dir):
218 |         print("Initializing network...")
219 |         self.logs_dir = logs_dir
220 |         self.sess = tf.Session()
221 |         self.summary_op = tf.merge_all_summaries()
222 |         self.saver = tf.train.Saver()
223 |         self.summary_writer = tf.train.SummaryWriter(self.logs_dir, self.sess.graph)
224 | 
225 |         self.sess.run(tf.initialize_all_variables())
226 |         ckpt = tf.train.get_checkpoint_state(self.logs_dir)
227 |         if ckpt and ckpt.model_checkpoint_path:
228 |             self.saver.restore(self.sess, ckpt.model_checkpoint_path)
229 |             print("Model restored...")
230 |         self.coord = tf.train.Coordinator()
231 |         self.threads = tf.train.start_queue_runners(self.sess, self.coord)
232 | 
233 |     def train_model(self, max_iterations):
234 |         try:
235 |             print("Training model...")
236 |             for itr in xrange(1, max_iterations):
237 |                 batch_z = np.random.uniform(-1.0, 1.0, size=[self.batch_size, self.z_dim]).astype(np.float32)
238 |                 feed_dict = {self.z_vec: batch_z, self.train_phase: True}
239 | 
240 |                 self.sess.run(self.discriminator_train_op, feed_dict=feed_dict)
241 |                 self.sess.run(self.generator_train_op, feed_dict=feed_dict)
242 | 
243 |                 if itr % 10 == 0:
244 |                     g_loss_val, d_loss_val, summary_str = self.sess.run(
245 |                         [self.gen_loss, self.discriminator_loss, self.summary_op], feed_dict=feed_dict)
246 |                     print("Step: %d, generator loss: %g, discriminator_loss: %g" % (itr, g_loss_val, d_loss_val))
247 |                     self.summary_writer.add_summary(summary_str, itr)
248 | 
249 |                 if itr % 2000 == 0:
250 |                     self.saver.save(self.sess, self.logs_dir + "model.ckpt", global_step=itr)
251 | 
252 |         except tf.errors.OutOfRangeError:
253 |             print('Done training -- epoch limit reached')
254 |         except KeyboardInterrupt:
255 |             print("Ending Training...")
256 |         finally:
257 |             self.coord.request_stop()
258 |             self.coord.join(self.threads)  # Wait for threads to finish.
259 | 
260 |     def visualize_model(self):
261 |         print("Sampling images from model...")
262 |         batch_z = np.random.uniform(-1.0, 1.0, size=[self.batch_size, self.z_dim]).astype(np.float32)
263 |         feed_dict = {self.z_vec: batch_z, self.train_phase: False}
264 | 
265 |         images = self.sess.run(self.gen_images, feed_dict=feed_dict)
266 |         images = utils.unprocess_image(images, 127.5, 127.5).astype(np.uint8)
267 |         shape = [4, self.batch_size // 4]
268 |         utils.save_imshow_grid(images, self.logs_dir, "generated.png", shape=shape)
269 | 
270 | 
271 | class WasserstienGAN(GAN):
272 |     def __init__(self, z_dim, crop_image_size, resized_image_size, batch_size, data_dir, clip_values=(-0.01, 0.01),
273 |                  critic_iterations=5):
274 |         self.critic_iterations = critic_iterations
275 |         self.clip_values = clip_values
276 |         GAN.__init__(self, z_dim, crop_image_size, resized_image_size, batch_size, data_dir)
277 | 
278 |     def _generator(self, z, dims, train_phase, activation=tf.nn.relu, scope_name="generator"):
279 |         N = len(dims)
280 |         image_size = self.resized_image_size // (2 ** (N - 1))
281 |         with tf.variable_scope(scope_name) as scope:
282 |             W_z = utils.weight_variable([self.z_dim, dims[0] * image_size * image_size], name="W_z")
283 |             h_z = tf.matmul(z, W_z)
284 |             h_z = tf.reshape(h_z, [-1, image_size, image_size, dims[0]])
285 |             h_bnz = utils.batch_norm(h_z, dims[0], train_phase, scope="gen_bnz")
286 |             h = activation(h_bnz, name='h_z')
287 |             utils.add_activation_summary(h)
288 | 
289 |             for index in range(N - 2):
290 |                 image_size *= 2
291 |                 W = utils.weight_variable([4, 4, dims[index + 1], dims[index]], name="W_%d" % index)
292 |                 b = tf.zeros([dims[index + 1]])
293 |                 deconv_shape = tf.pack([tf.shape(h)[0], image_size, image_size, dims[index + 1]])
294 |                 h_conv_t = utils.conv2d_transpose_strided(h, W, b, output_shape=deconv_shape)
295 |                 h_bn = utils.batch_norm(h_conv_t, dims[index + 1], train_phase, scope="gen_bn%d" % index)
296 |                 h = activation(h_bn, name='h_%d' % index)
297 |                 utils.add_activation_summary(h)
298 | 
299 |             image_size *= 2
300 |             W_pred = utils.weight_variable([4, 4, dims[-1], dims[-2]], name="W_pred")
301 |             b = tf.zeros([dims[-1]])
302 |             deconv_shape = tf.pack([tf.shape(h)[0], image_size, image_size, dims[-1]])
303 |             h_conv_t = utils.conv2d_transpose_strided(h, W_pred, b, output_shape=deconv_shape)
304 |             pred_image = tf.nn.tanh(h_conv_t, name='pred_image')
305 |             utils.add_activation_summary(pred_image)
306 | 
307 |         return pred_image
308 | 
309 |     def _discriminator(self, input_images, dims, train_phase, activation=tf.nn.relu, scope_name="discriminator",
310 |                        scope_reuse=False):
311 |         N = len(dims)
312 |         with tf.variable_scope(scope_name) as scope:
313 |             if scope_reuse:
314 |                 scope.reuse_variables()
315 |             h = input_images
316 |             skip_bn = True  # First layer of discriminator skips batch norm
317 |             for index in range(N - 2):
318 |                 W = utils.weight_variable([4, 4, dims[index], dims[index + 1]], name="W_%d" % index)
319 |                 b = tf.zeros([dims[index + 1]])
320 |                 h_conv = utils.conv2d_strided(h, W, b)
321 |                 if skip_bn:
322 |                     h_bn = h_conv
323 |                     skip_bn = False
324 |                 else:
325 |                     h_bn = utils.batch_norm(h_conv, dims[index + 1], train_phase, scope="disc_bn%d" % index)
326 |                 h = activation(h_bn, name="h_%d" % index)
327 |                 utils.add_activation_summary(h)
328 | 
329 |             W_pred = utils.weight_variable([4, 4, dims[-2], dims[-1]], name="W_pred")
330 |             b = tf.zeros([dims[-1]])
331 |             h_pred = utils.conv2d_strided(h, W_pred, b)
332 |         return None, h_pred, None  # Return the last convolution output. None values are returned to match the return signature of the parent GAN class
333 | 
334 |     def _gan_loss(self, logits_real, logits_fake, feature_real, feature_fake, use_features=False):
335 |         self.discriminator_loss = tf.reduce_mean(logits_real - logits_fake)  # pushes critic scores down on real, up on fake (sign-flipped vs. the paper's plots)
336 |         self.gen_loss = tf.reduce_mean(logits_fake)  # pushes critic scores on fakes down toward the real ones
337 | 
338 |         tf.scalar_summary("Discriminator_loss", self.discriminator_loss)
339 |         tf.scalar_summary("Generator_loss", self.gen_loss)
340 | 
341 |     def train_model(self, max_iterations):
342 |         try:
343 |             print("Training Wasserstein GAN model...")
344 |             clip_discriminator_var_op = [var.assign(tf.clip_by_value(var, self.clip_values[0], self.clip_values[1])) for
345 |                                          var in self.discriminator_variables]
346 | 
347 |             start_time = time.time()
348 | 
349 |             def get_feed_dict(train_phase=True):
350 |                 batch_z = np.random.uniform(-1.0, 1.0, size=[self.batch_size, self.z_dim]).astype(np.float32)
351 |                 feed_dict = {self.z_vec: batch_z, self.train_phase: train_phase}
352 |                 return feed_dict
353 | 
354 |             for itr in xrange(1, max_iterations):
355 |                 if itr < 25 or itr % 500 == 0:
356 |                     critic_itrs = 25  # train the critic harder in early iterations and periodically thereafter
357 |                 else:
358 |                     critic_itrs = self.critic_iterations
359 | 
360 |                 for critic_itr in range(critic_itrs):
361 |                     self.sess.run(self.discriminator_train_op, feed_dict=get_feed_dict(True))
362 |                     self.sess.run(clip_discriminator_var_op)
363 | 
364 |                 feed_dict = get_feed_dict(True)
365 |                 self.sess.run(self.generator_train_op, feed_dict=feed_dict)
366 | 
367 |                 if itr % 100 == 0:
368 |                     summary_str = self.sess.run(self.summary_op, feed_dict=feed_dict)
369 |                     self.summary_writer.add_summary(summary_str, itr)
370 | 
371 |                 if itr % 200 == 0:
372 |                     stop_time = time.time()
373 |                     duration = (stop_time - start_time) / 200.0
374 |                     start_time = stop_time
375 |                     g_loss_val, d_loss_val = self.sess.run([self.gen_loss, self.discriminator_loss],
376 |                                                            feed_dict=feed_dict)
377 |                     print("Time: %g/itr, Step: %d, generator loss: %g, discriminator_loss: %g" % (
378 |                         duration, itr, g_loss_val, d_loss_val))
379 | 
380 |                 if itr % 5000 == 0:
381 |                     self.saver.save(self.sess, self.logs_dir + "model.ckpt", global_step=itr)
382 | 
383 |         except tf.errors.OutOfRangeError:
384 |             print('Done training -- epoch limit reached')
385 |         except KeyboardInterrupt:
386 |             print("Ending Training...")
387 |         finally:
388 |             self.coord.request_stop()
389 |             self.coord.join(self.threads)  # Wait for threads to finish.
390 | 
--------------------------------------------------------------------------------