├── README.md ├── discriminator.py ├── generator.py ├── requirements.txt ├── samples └── lsun_examples.png ├── swg.py └── utils └── flags_wrapper.py /README.md: -------------------------------------------------------------------------------- 1 | # SWG 2 | Sliced Wasserstein Generator 3 | 4 | ![alt text](https://raw.githubusercontent.com/ishansd/swg/master/samples/lsun_examples.png) 5 | 6 | 7 | https://arxiv.org/abs/1803.11188 8 | 9 | # Code tested with: 10 | Python 3.6.4 11 | 12 | Tensorflow 1.5.0 13 | 14 | Numpy 1.14.0 15 | 16 | 17 | # Additional requirements: 18 | Argparse 19 | 20 | Matplotlib 21 | 22 | Opencv-python 23 | 24 | # For running options: 25 | python3 swg.py -h 26 | -------------------------------------------------------------------------------- /discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as layers 3 | 4 | 5 | def discriminator(x, reuse=False): 6 | """discriminator 7 | Network to classify fake and true samples. 8 | params: 9 | x: Input images [batch size, 64, 64, 3] 10 | returns: 11 | y: Unnormalized probablity of sample being real [batch size, 1] 12 | h: Features from penultimate layer of discriminator 13 | [batch size, feature dim] 14 | """ 15 | batch_norm = layers.layer_norm 16 | 17 | h = x 18 | with tf.variable_scope("discriminator", reuse=reuse) as scope: 19 | h = layers.conv2d( 20 | inputs=h, 21 | num_outputs=64, 22 | kernel_size=4, 23 | stride=2, 24 | activation_fn=tf.nn.leaky_relu, 25 | normalizer_fn=batch_norm) 26 | # [32,32,64] 27 | 28 | h = layers.conv2d( 29 | inputs=h, 30 | num_outputs=128, 31 | kernel_size=4, 32 | stride=2, 33 | activation_fn=tf.nn.leaky_relu, 34 | normalizer_fn=batch_norm) 35 | # [16,16,128] 36 | 37 | h = layers.conv2d( 38 | inputs=h, 39 | num_outputs=256, 40 | kernel_size=4, 41 | stride=2, 42 | activation_fn=tf.nn.leaky_relu, 43 | normalizer_fn=batch_norm) 44 | # [8,8,256] 45 | 46 | h = layers.conv2d( 47 | inputs=h, 48 | num_outputs=512, 49 | kernel_size=4, 50 | stride=2, 51 | activation_fn=tf.nn.leaky_relu, 52 | normalizer_fn=batch_norm) 53 | # [4,4,512] 54 | 55 | h = layers.flatten(h) 56 | y = layers.fully_connected( 57 | inputs=h, 58 | num_outputs=1, 59 | activation_fn=None, 60 | biases_initializer=None) 61 | return y, h 62 | -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as layers 3 | 4 | 5 | def generator(z, reuse=False): 6 | """ 7 | generator 8 | Network to produce samples. 9 | params: 10 | z: Input noise [batch size, latent dimension] 11 | returns: 12 | x_hat: Artificial image [batch size, 64, 64, 3] 13 | """ 14 | batch_norm = layers.batch_norm 15 | 16 | outputs = [] 17 | h = z 18 | with tf.variable_scope("generator", reuse=reuse) as scope: 19 | h = layers.fully_connected( 20 | inputs=h, 21 | num_outputs=4 * 4 * 1024, 22 | activation_fn=tf.nn.relu, 23 | normalizer_fn=batch_norm) 24 | h = tf.reshape(h, [-1, 4, 4, 1024]) 25 | # [4,4,1024] 26 | 27 | h = layers.conv2d_transpose( 28 | inputs=h, 29 | num_outputs=512, 30 | kernel_size=4, 31 | stride=2, 32 | activation_fn=tf.nn.relu, 33 | normalizer_fn=batch_norm) 34 | # [8,8,512] 35 | 36 | h = layers.conv2d_transpose( 37 | inputs=h, 38 | num_outputs=256, 39 | kernel_size=4, 40 | stride=2, 41 | activation_fn=tf.nn.relu, 42 | normalizer_fn=batch_norm) 43 | 44 | # [16,16,256] 45 | 46 | h = layers.conv2d_transpose( 47 | inputs=h, 48 | num_outputs=128, 49 | kernel_size=4, 50 | stride=2, 51 | activation_fn=tf.nn.relu, 52 | normalizer_fn=batch_norm) 53 | 54 | # This is an extra conv layer like the WGAN folks. 55 | h = layers.conv2d( 56 | inputs=h, 57 | num_outputs=128, 58 | kernel_size=4, 59 | stride=1, 60 | activation_fn=tf.nn.relu, 61 | normalizer_fn=batch_norm) 62 | 63 | # [32,32,128] 64 | 65 | x_hat = layers.conv2d_transpose( 66 | inputs=h, 67 | num_outputs=3, 68 | kernel_size=4, 69 | stride=2, 70 | activation_fn=tf.nn.sigmoid, 71 | biases_initializer=None) 72 | # [64,64,3] 73 | return x_hat 74 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==2.1.2 2 | numpy==1.14.1 3 | opencv-python==3.4.0.12 4 | scikit-learn==0.19.1 5 | tensorflow-gpu==1.5.0 6 | -------------------------------------------------------------------------------- /samples/lsun_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ishansd/swg/d2d3dae311b64526b2b61003f015e9f478fec081/samples/lsun_examples.png -------------------------------------------------------------------------------- /swg.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | import cv2 7 | 8 | import argparse 9 | 10 | 11 | """ 12 | For saving plots on the cluster, 13 | """ 14 | import matplotlib as mpl 15 | mpl.use('Agg') 16 | import matplotlib.pyplot as plt 17 | 18 | 19 | from generator import generator 20 | from discriminator import discriminator 21 | 22 | import time 23 | 24 | from utils.flags_wrapper import flags_wrapper 25 | 26 | 27 | class swg(): 28 | """swg 29 | The generative model. 30 | 31 | params: 32 | flags: A flags_wrapper object with all hyperparams 33 | model_name: Name for output folder. Will be created in "results/" 34 | """ 35 | 36 | def __init__( 37 | self, 38 | flags=None, 39 | model_name='test_experiment'): 40 | """initialization 41 | """ 42 | self.image_width = 64 43 | self.num_channels = 3 44 | self.image_size = self.num_channels * (self.image_width**2) 45 | 46 | self.flags = flags 47 | 48 | self.base_dir = 'results/' + model_name 49 | import os 50 | import errno 51 | 52 | try: 53 | "" 54 | os.makedirs(self.base_dir) 55 | except OSError as exc: # Python >2.5 56 | if exc.errno == errno.EEXIST and os.path.isdir(self.base_dir): 57 | pass 58 | 59 | self.build_model() 60 | 61 | return 62 | 63 | def read_data(self): 64 | """read_data 65 | Assumes the data is in a single numpy array. Loads it into memory. Can be 66 | replaced by tf queues. 67 | todo: Read from disk 68 | 69 | params: None 70 | 71 | returns: 72 | im: numpy array of flattened images [number of images, 64*64*3] 73 | """ 74 | path = '/tmp/cropped_celeba.npy' 75 | im = np.load(path) 76 | return im 77 | 78 | def sw_loss(self, true_distribution, generated_distribution): 79 | """sw_loss 80 | Computes the sliced Wasserstein distance between two sets of samples in the 81 | following way: 82 | 1. Projects the samples onto random (Gaussian) directions (unit vectors). 83 | 2. For each direction, computes the Wasserstein-2 distance by sorting the 84 | two projected sets (which results in the lowest distance matching). 85 | 3. Adds distance over all directions. 86 | 87 | NOTE: 88 | This will create ops that require a fixed batch size. 89 | 90 | params: 91 | true_distribution: Samples from the true distribution 92 | [batch size, disc. feature size] 93 | generated_distribution: Samples from the generator 94 | [batch size, disc. feature size] 95 | 96 | returns: 97 | sliced Wasserstein distance 98 | """ 99 | s = true_distribution.get_shape().as_list()[-1] 100 | 101 | # theta contains the projection directions as columns : 102 | # [theta1, theta2, ...] 103 | theta = tf.random_normal(shape=[s, self.flags.num_projections]) 104 | theta = tf.nn.l2_normalize(theta, axis=0) 105 | 106 | # project the samples (images). After being transposed, we have tensors 107 | # of the format: [projected_image1, projected_image2, ...]. 108 | # Each row has the projections along one direction. This makes it 109 | # easier for the sorting that follows. 110 | projected_true = tf.transpose( 111 | tf.matmul(true_distribution, theta)) 112 | 113 | projected_fake = tf.transpose( 114 | tf.matmul(generated_distribution, theta)) 115 | 116 | sorted_true, true_indices = tf.nn.top_k( 117 | projected_true, 118 | self.flags.batch_size) 119 | 120 | sorted_fake, fake_indices = tf.nn.top_k( 121 | projected_fake, 122 | self.flags.batch_size) 123 | 124 | # For faster gradient computation, we do not use sorted_fake to compute 125 | # loss. Instead we re-order the sorted_true so that the samples from the 126 | # true distribution go to the correct sample from the fake distribution. 127 | # This is because Tensorflow did not have a GPU op for rearranging the 128 | # gradients at the time of writing this code. 129 | 130 | # It is less expensive (memory-wise) to rearrange arrays in TF. 131 | # Flatten the sorted_true from [batch_size, num_projections]. 132 | flat_true = tf.reshape(sorted_true, [-1]) 133 | 134 | # Modify the indices to reflect this transition to an array. 135 | # new index = row + index 136 | rows = np.asarray( 137 | [self.flags.batch_size * np.floor(i * 1.0 / self.flags.batch_size) 138 | for i in range(self.flags.num_projections * self.flags.batch_size)]) 139 | rows = rows.astype(np.int32) 140 | flat_idx = tf.reshape(fake_indices, [-1, 1]) + np.reshape(rows, [-1, 1]) 141 | 142 | # The scatter operation takes care of reshaping to the rearranged matrix 143 | shape = tf.constant([self.flags.batch_size * self.flags.num_projections]) 144 | rearranged_true = tf.reshape( 145 | tf.scatter_nd(flat_idx, flat_true, shape), 146 | [self.flags.num_projections, self.flags.batch_size]) 147 | 148 | return tf.reduce_mean(tf.square(projected_fake - rearranged_true)) 149 | 150 | def build_model(self): 151 | """build_model 152 | Creates the computation graph. 153 | """ 154 | 155 | # Input images from the true distribution 156 | self.x = tf.placeholder( 157 | tf.float32, 158 | [None, self.image_width, self.image_width, self.num_channels]) 159 | 160 | # Latent variable 161 | self.z = tf.placeholder(tf.float32, [None, self.flags.latent_dim]) 162 | 163 | # Output images from the GAN 164 | self.x_hat = generator(self.z) 165 | 166 | if self.flags.use_discriminator: 167 | # The discriminator returns the output (unnormalized) probability 168 | # of fake/true, and also a feature vector for the image. 169 | self.y, self.y_to_match = discriminator(self.x) 170 | self.y_hat, self.y_hat_to_match = discriminator( 171 | self.x_hat, 172 | reuse=True) 173 | 174 | # The discriminator is trained for simple binary classification. 175 | true_loss = tf.nn.sigmoid_cross_entropy_with_logits( 176 | labels=tf.ones_like(self.y), 177 | logits=self.y) 178 | fake_loss = tf.nn.sigmoid_cross_entropy_with_logits( 179 | labels=tf.zeros_like(self.y_hat), 180 | logits=self.y_hat) 181 | self.discriminator_loss = tf.reduce_mean(true_loss + fake_loss) 182 | 183 | discriminator_vars = tf.get_collection( 184 | tf.GraphKeys.GLOBAL_VARIABLES, 185 | scope='discriminator') 186 | self.d_optimizer = tf.train.AdamOptimizer( 187 | self.flags.learning_rate, 188 | beta1=0.5).minimize(self.discriminator_loss, 189 | var_list=discriminator_vars) 190 | 191 | self.generator_loss = self.sw_loss( 192 | self.y_to_match, 193 | self.y_hat_to_match) 194 | 195 | else: 196 | self.generator_loss = self.sw_loss( 197 | tf.reshape(self.x, [-1, self.image_size]), 198 | tf.reshape(self.x_hat, [-1, self.image_size])) 199 | 200 | generator_vars = tf.get_collection( 201 | tf.GraphKeys.GLOBAL_VARIABLES, 202 | scope='generator') 203 | self.g_optimizer = tf.train.AdamOptimizer( 204 | self.flags.learning_rate, 205 | beta1=0.5).minimize(self.generator_loss, 206 | var_list=generator_vars) 207 | 208 | # self.merged_summary_op = tf.summary.merge_all() 209 | return 210 | 211 | def train(self): 212 | """train 213 | Main training loop. Saves a checkpoint and sample images periodically. 214 | """ 215 | dfreq = 1 216 | diter = 1 217 | 218 | print("Loading data into memory.") 219 | data = self.read_data() 220 | max_examples = data.shape[0] 221 | print("Loaded {} examples".format(max_examples)) 222 | 223 | saver = tf.train.Saver() 224 | 225 | config = tf.ConfigProto() 226 | config.gpu_options.allow_growth = True 227 | sess = tf.Session(config=config) 228 | 229 | sess.run(tf.global_variables_initializer()) 230 | 231 | # Prefer not to use summaries, they seem to slow down execution over 232 | # time. 233 | # summary_writer = tf.summary.FileWriter(self.base_dir,sess.graph) 234 | 235 | curr_time = time.time() 236 | print("Starting code") 237 | for iteration in range(self.flags.max_iters): 238 | 239 | x = data[np.random.randint(0, max_examples, self.flags.batch_size)] 240 | z = np.random.uniform( 241 | low=-1, 242 | high=1, 243 | size=[self.flags.batch_size, self.flags.latent_dim]) 244 | 245 | sess.run(self.g_optimizer, feed_dict={self.x: x, self.z: z}) 246 | 247 | if self.flags.use_discriminator: 248 | if iteration % dfreq == 0: 249 | for diteration in range(diter): 250 | sess.run( 251 | self.d_optimizer, 252 | feed_dict={self.x: x, self.z: z}) 253 | 254 | if iteration % 50 == 0: 255 | loss = sess.run( 256 | self.generator_loss, 257 | feed_dict={self.x: x, self.z: z}) 258 | print( 259 | "Time elapsed: {}, Loss after iteration {}: {}".format( 260 | time.time() - curr_time, 261 | iteration, 262 | loss)) 263 | curr_time = time.time() 264 | 265 | if iteration % 1000 == 0: 266 | z = np.random.uniform( 267 | low=-1, 268 | high=1, 269 | size=[36, self.flags.latent_dim]) 270 | im = sess.run(self.x_hat, feed_dict={self.z: z}) 271 | im = np.reshape(im, (-1, self.image_width, self.num_channels)) 272 | im = np.hstack(np.split(im, 6)) 273 | 274 | plt.imshow(im) 275 | plt.axis('off') 276 | fig = plt.gcf() 277 | fig.set_size_inches(12, 12) 278 | plt.savefig(self.base_dir + '/Iteration_{}.png'.format( 279 | iteration), 280 | bbox_inches='tight') 281 | plt.close() 282 | 283 | if iteration % 10000 == 0: 284 | saver.save(sess, self.base_dir + '/checkpoint.ckpt') 285 | 286 | return 287 | 288 | def generate_images(self): 289 | """generate_images 290 | Method to generate samples using a pre-trained model 291 | """ 292 | sess = tf.Session() 293 | sess.run(tf.global_variables_initializer()) 294 | saver = tf.train.Saver() 295 | saver.restore(sess, tf.train.latest_checkpoint(self.base_dir + '/')) 296 | 297 | z = np.random.uniform( 298 | low=-1, 299 | high=1, 300 | size=[36, self.flags.latent_dim]) 301 | 302 | im = sess.run(self.x_hat, feed_dict={self.z: z}) 303 | 304 | im = np.reshape(im, (-1, self.image_width, self.num_channels)) 305 | im = np.hstack(np.split(im, 6)) 306 | 307 | plt.imshow(im) 308 | plt.axis('off') 309 | fig = plt.gcf() 310 | fig.set_size_inches(12, 12) 311 | plt.savefig(self.base_dir + '/Samples.png', bbox_inches='tight') 312 | plt.close() 313 | return 314 | 315 | def generate_tsne(self): 316 | """generate_tsne 317 | Method to visualize TSNE with random samples from the ground truth and 318 | generated distribution. This might help in catching mode collapse. If 319 | there is an obvious case of mode collapse, then we should see several 320 | points from the ground truth without any generated samples nearby. 321 | Purely a sanity check. 322 | """ 323 | from sklearn.manifold import TSNE 324 | 325 | num_points = 1000 326 | data = self.read_data()[:num_points] 327 | data = np.reshape(data, [num_points, -1]) 328 | 329 | print("Loaded ground truth.") 330 | 331 | sess = tf.Session() 332 | sess.run(tf.global_variables_initializer()) 333 | saver = tf.train.Saver() 334 | saver.restore(sess, tf.train.latest_checkpoint(self.base_dir + '/')) 335 | z = np.random.uniform(-1, 1, size=[num_points, self.flags.latent_dim]) 336 | 337 | generated = sess.run(self.x_hat, feed_dict={self.z: z}) 338 | generated = np.reshape(generated, [num_points, -1]) 339 | 340 | X = np.vstack((data, generated)) 341 | 342 | print("Computing TSNE.") 343 | X_embedded = TSNE(n_components=2).fit_transform(X) 344 | print("Plotting data.") 345 | 346 | plt.scatter( 347 | X_embedded[:num_points, 0], 348 | X_embedded[:num_points, 1], 349 | color='r', 350 | label='GT') 351 | 352 | plt.scatter( 353 | X_embedded[num_points:, 0], 354 | X_embedded[num_points:, 1], 355 | color='b', 356 | label='Generated', 357 | alpha=0.25) 358 | 359 | plt.axis('off') 360 | fig = plt.gcf() 361 | fig.set_size_inches(12, 12) 362 | plt.savefig(self.base_dir + '/TSNE.png', bbox_inches='tight') 363 | plt.close() 364 | 365 | return 366 | 367 | 368 | def main(argv=None): 369 | parser = argparse.ArgumentParser(description='SWGAN') 370 | 371 | parser.add_argument( 372 | '--name', 373 | metavar='output folder', 374 | default="test", 375 | help='Output folder') 376 | 377 | parser.add_argument( 378 | '--train', 379 | dest='train', 380 | action='store_true', 381 | help='Use to train') 382 | 383 | parser.add_argument( 384 | '--learning_rate', 385 | metavar='learning rate', 386 | default=1e-4, 387 | help='Learning rate for optimizer') 388 | 389 | parser.add_argument( 390 | '--max_iters', 391 | metavar='max iters', 392 | default=10000, 393 | help='Number of iterations to train') 394 | 395 | parser.add_argument( 396 | '--num_projections', 397 | metavar='num projections', 398 | default=10000, 399 | help='Number of projections to use at every step') 400 | 401 | parser.add_argument( 402 | '--batch_size', 403 | metavar='batch size', 404 | default=64, 405 | help='Batch size') 406 | 407 | parser.add_argument( 408 | '--use_discriminator', 409 | dest='use_discriminator', 410 | action='store_true', 411 | help='Enable discriminator') 412 | 413 | args = parser.parse_args() 414 | 415 | np.random.seed(np.random.randint(0, 10)) 416 | tf.set_random_seed(np.random.randint(0, 10)) 417 | tf.reset_default_graph() 418 | 419 | flags = flags_wrapper( 420 | learning_rate=args.learning_rate, 421 | max_iters=args.max_iters, 422 | batch_size=args.batch_size, 423 | num_projections=args.num_projections, 424 | use_discriminator=args.use_discriminator) 425 | 426 | g = swg(model_name=args.name, flags=flags) 427 | 428 | if args.train: 429 | g.train() 430 | g.generate_images() 431 | g.generate_tsne() 432 | return 433 | 434 | if __name__ == '__main__': 435 | main() 436 | -------------------------------------------------------------------------------- /utils/flags_wrapper.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class flags_wrapper(): 4 | """flags_wrapper 5 | My wrapper around hyperparameters. This will be passed to the generative 6 | model. 7 | """ 8 | 9 | def __init__( 10 | self, 11 | learning_rate=1e-4, 12 | batch_size=64, 13 | latent_dim=100, 14 | max_iters=10000, 15 | num_projections=10000, 16 | use_discriminator=True): 17 | 18 | self.learning_rate = learning_rate 19 | self.batch_size = int(batch_size) 20 | self.max_iters = int(max_iters) 21 | 22 | # For input noise 23 | self.latent_dim = int(latent_dim) 24 | 25 | # SWG specific params 26 | self.num_projections = int(num_projections) 27 | self.use_discriminator = use_discriminator 28 | 29 | return 30 | --------------------------------------------------------------------------------