├── .gitignore
├── LICENSE
├── README.md
├── checkpoint
│   ├── DCGAN.model-60502.data-00000-of-00001
│   ├── DCGAN.model-60502.index
│   ├── DCGAN.model-60502.meta
│   └── checkpoint
├── complete.py
├── completion.compressed.gif
├── model.py
├── ops.py
├── simple-distributions.py
├── train-dcgan.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | logs
2 | samples
3 | checkpoint
4 | data
5 | __pycache__
6 | out
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Most of the code in this repository was written by modifying a duplicate of
2 | Taehoon Kim's DCGAN-tensorflow project:
3 | https://github.com/carpedm20/DCGAN-tensorflow
4 | 
5 | The modifications are Copyright (c) 2016 Brandon Amos
6 | 
7 | In compliance with the license for DCGAN-tensorflow, we reproduce the license
8 | statement below, and release the code in this directory under the same license.
9 | 
10 | 
11 | The MIT License (MIT)
12 | 
13 | Copyright (c) 2016 Taehoon Kim
14 | 
15 | Permission is hereby granted, free of charge, to any person obtaining a copy
16 | of this software and associated documentation files (the "Software"), to deal
17 | in the Software without restriction, including without limitation the rights
18 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
19 | copies of the Software, and to permit persons to whom the Software is
20 | furnished to do so, subject to the following conditions:
21 | 
22 | The above copyright notice and this permission notice shall be included in all
23 | copies or substantial portions of the Software.
24 | 
25 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31 | SOFTWARE.
32 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image Completion with Deep Learning in TensorFlow
2 | 
3 | ![](/completion.compressed.gif)
4 | 
5 | + [See my blog post for more details and usage instructions](http://bamos.github.io/2016/08/09/deep-completion/).
6 | + This repository implements Raymond Yeh, Chih Chen, et al.'s paper
7 |   [Semantic Image Inpainting with Perceptual and Contextual Losses](https://arxiv.org/abs/1607.07539).
8 | + Most of the code in this repository was written by modifying a
9 |   duplicate of [Taehoon Kim's](http://carpedm20.github.io/)
10 |   [DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow) project,
11 |   which is MIT-licensed.
12 |   My modifications are also [MIT-licensed](./LICENSE).
13 | + The [./checkpoint](./checkpoint) directory contains a pre-trained
14 |   model for faces, trained on the CelebA dataset for 20 epochs.
15 | 
16 | # Citations
17 | 
18 | Please consider citing this project in your
19 | publications if it helps your research.
20 | The following is a [BibTeX](http://www.bibtex.org/)
21 | and plaintext reference.
22 | The BibTeX entry requires the `url` LaTeX package.
23 | 
24 | ```
25 | @misc{amos2016image,
26 |     title = {{Image Completion with Deep Learning in TensorFlow}},
27 |     author = {Amos, Brandon},
28 |     howpublished = {\url{http://bamos.github.io/2016/08/09/deep-completion}},
29 |     note = {Accessed: [Insert date here]}
30 | }
31 | 
32 | Brandon Amos. Image Completion with Deep Learning in TensorFlow.
33 | http://bamos.github.io/2016/08/09/deep-completion.
34 | Accessed: [Insert date here]
35 | ```
36 | 
--------------------------------------------------------------------------------
/checkpoint/DCGAN.model-60502.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bamos/dcgan-completion.tensorflow/6b0a06c300e65f209b0283ab7c076a40c42e40bb/checkpoint/DCGAN.model-60502.data-00000-of-00001
--------------------------------------------------------------------------------
/checkpoint/DCGAN.model-60502.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bamos/dcgan-completion.tensorflow/6b0a06c300e65f209b0283ab7c076a40c42e40bb/checkpoint/DCGAN.model-60502.index
--------------------------------------------------------------------------------
/checkpoint/DCGAN.model-60502.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bamos/dcgan-completion.tensorflow/6b0a06c300e65f209b0283ab7c076a40c42e40bb/checkpoint/DCGAN.model-60502.meta
--------------------------------------------------------------------------------
/checkpoint/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "DCGAN.model-60502"
2 | all_model_checkpoint_paths: "DCGAN.model-60502"
3 | 
--------------------------------------------------------------------------------
/complete.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # Brandon Amos (http://bamos.github.io)
4 | # License: MIT
5 | # 2016-08-05
6 | 
7 | import argparse
8 | import os
9 | import tensorflow as tf
10 | 
11 | from model import DCGAN
12 | 
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--approach', type=str,
15 |                     choices=['adam', 'hmc'],
16 |                     default='adam')
17 | parser.add_argument('--lr', type=float, default=0.01)
18 | parser.add_argument('--beta1', type=float, default=0.9)
19 | parser.add_argument('--beta2', type=float, default=0.999)
20 | parser.add_argument('--eps', type=float, default=1e-8)
21 | parser.add_argument('--hmcBeta', type=float, default=0.2)
22 | parser.add_argument('--hmcEps', type=float, default=0.001)
23 | parser.add_argument('--hmcL', type=int, default=100)
24 | parser.add_argument('--hmcAnneal', type=float, default=1)
25 | parser.add_argument('--nIter', type=int, default=1000)
26 | parser.add_argument('--imgSize', type=int, default=64)
27 | parser.add_argument('--lam', type=float, default=0.1)
28 | parser.add_argument('--checkpointDir', type=str, default='checkpoint')
29 | parser.add_argument('--outDir', type=str, default='completions')
30 | parser.add_argument('--outInterval', type=int, default=50)
31 | parser.add_argument('--maskType', type=str,
32 |                     choices=['random', 'center', 'left', 'full', 'grid', 'lowres'],
33 |                     default='center')
34 | parser.add_argument('--centerScale', type=float, default=0.25)
35 | parser.add_argument('imgs', type=str, nargs='+')
36 | 
37 | args = parser.parse_args()
38 | 
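# What follows (grounded in model.py's complete()): restore a trained DCGAN
# from --checkpointDir, then for each input image search for the latent
# vector zhat minimizing contextual + lam*perceptual loss, either by Adam
# updates on z (the default) or by HMC sampling with --approach hmc.
# Outputs are written under --outDir. Example invocation (hypothetical paths):
#   ./complete.py --maskType center --nIter 1000 imgs/*.png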
39 | assert(os.path.exists(args.checkpointDir))
40 | 
41 | config = tf.ConfigProto()
42 | config.gpu_options.allow_growth = True
43 | with tf.Session(config=config) as sess:
44 |     dcgan = DCGAN(sess, image_size=args.imgSize,
45 |                   batch_size=min(64, len(args.imgs)),
46 |                   checkpoint_dir=args.checkpointDir, lam=args.lam)
47 |     dcgan.complete(args)
48 | 
--------------------------------------------------------------------------------
/completion.compressed.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bamos/dcgan-completion.tensorflow/6b0a06c300e65f209b0283ab7c076a40c42e40bb/completion.compressed.gif
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Original Version: Taehoon Kim (http://carpedm20.github.io)
2 | #   + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/model.py
3 | #   + License: MIT
4 | # [2016-08-05] Modifications for Completion: Brandon Amos (http://bamos.github.io)
5 | #   + License: MIT
6 | 
7 | from __future__ import division
8 | import os
9 | import time
10 | import math
11 | import itertools
12 | from glob import glob
13 | import tensorflow as tf
14 | from six.moves import xrange
15 | 
16 | from ops import *
17 | from utils import *
18 | 
19 | SUPPORTED_EXTENSIONS = ["png", "jpg", "jpeg"]
20 | 
21 | def dataset_files(root):
22 |     """Returns a list of all image files in the given directory"""
23 |     return list(itertools.chain.from_iterable(
24 |         glob(os.path.join(root, "*.{}".format(ext))) for ext in SUPPORTED_EXTENSIONS))
25 | 
26 | 
27 | class DCGAN(object):
28 |     def __init__(self, sess, image_size=64, is_crop=False,
29 |                  batch_size=64, sample_size=64, lowres=8,
30 |                  z_dim=100, gf_dim=64, df_dim=64,
31 |                  gfc_dim=1024, dfc_dim=1024, c_dim=3,
32 |                  checkpoint_dir=None, lam=0.1):
33 |         """
34 | 
35 |         Args:
36 |             sess: TensorFlow session
37 |             batch_size: The size of batch. Should be specified before training.
38 |             lowres: (optional) Low resolution image/mask shrink factor. [8]
39 |             z_dim: (optional) Dimension of latent vector z. [100]
40 |             gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
41 |             df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
42 |             gfc_dim: (optional) Dimension of gen units for fully connected layer. [1024]
43 |             dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024]
44 |             c_dim: (optional) Dimension of image color. [3]
45 |         """
46 |         # Currently, image size must be a (power of 2) and (8 or higher).
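        # `x & (x - 1) == 0` is the standard bit trick for testing that x is a
        # power of two: a power of two has a single bit set, so subtracting one
        # sets exactly the bits below it and the AND comes out zero. For
        # example, 64 & 63 == 0b1000000 & 0b0111111 == 0.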
47 |         assert(image_size & (image_size - 1) == 0 and image_size >= 8)
48 | 
49 |         self.sess = sess
50 |         self.is_crop = is_crop
51 |         self.batch_size = batch_size
52 |         self.image_size = image_size
53 |         self.sample_size = sample_size
54 |         self.image_shape = [image_size, image_size, c_dim]
55 | 
56 |         self.lowres = lowres
57 |         self.lowres_size = image_size // lowres
58 |         self.lowres_shape = [self.lowres_size, self.lowres_size, c_dim]
59 | 
60 |         self.z_dim = z_dim
61 | 
62 |         self.gf_dim = gf_dim
63 |         self.df_dim = df_dim
64 | 
65 |         self.gfc_dim = gfc_dim
66 |         self.dfc_dim = dfc_dim
67 | 
68 |         self.lam = lam
69 | 
70 |         self.c_dim = c_dim
71 | 
72 |         # Batch normalization: deals with poor initialization and helps gradient flow.
73 |         self.d_bns = [
74 |             batch_norm(name='d_bn{}'.format(i,)) for i in range(4)]
75 | 
76 |         log_size = int(math.log(image_size) / math.log(2))
77 |         self.g_bns = [
78 |             batch_norm(name='g_bn{}'.format(i,)) for i in range(log_size)]
79 | 
80 |         self.checkpoint_dir = checkpoint_dir
81 |         self.build_model()
82 | 
83 |         self.model_name = "DCGAN.model"
84 | 
85 |     def build_model(self):
86 |         self.is_training = tf.placeholder(tf.bool, name='is_training')
87 |         self.images = tf.placeholder(
88 |             tf.float32, [None] + self.image_shape, name='real_images')
89 |         self.lowres_images = tf.reduce_mean(tf.reshape(self.images,
90 |             [self.batch_size, self.lowres_size, self.lowres,
91 |              self.lowres_size, self.lowres, self.c_dim]), [2, 4])
92 |         self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')
93 |         self.z_sum = tf.summary.histogram("z", self.z)
94 | 
95 |         self.G = self.generator(self.z)
96 |         self.lowres_G = tf.reduce_mean(tf.reshape(self.G,
97 |             [self.batch_size, self.lowres_size, self.lowres,
98 |              self.lowres_size, self.lowres, self.c_dim]), [2, 4])
99 |         self.D, self.D_logits = self.discriminator(self.images)
100 | 
101 |         self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)
102 | 
103 |         self.d_sum = tf.summary.histogram("d", self.D)
104 |         self.d__sum = tf.summary.histogram("d_", self.D_)
105 |         self.G_sum = tf.summary.image("G", self.G)
106 | 
107 |         self.d_loss_real = tf.reduce_mean(
108 |             tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits,
109 |                                                     labels=tf.ones_like(self.D)))
110 |         self.d_loss_fake = tf.reduce_mean(
111 |             tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_,
112 |                                                     labels=tf.zeros_like(self.D_)))
113 |         self.g_loss = tf.reduce_mean(
114 |             tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_,
115 |                                                     labels=tf.ones_like(self.D_)))
116 | 
117 |         self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
118 |         self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)
119 | 
120 |         self.d_loss = self.d_loss_real + self.d_loss_fake
121 | 
122 |         self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
123 |         self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
124 | 
125 |         t_vars = tf.trainable_variables()
126 | 
127 |         self.d_vars = [var for var in t_vars if 'd_' in var.name]
128 |         self.g_vars = [var for var in t_vars if 'g_' in var.name]
129 | 
130 |         self.saver = tf.train.Saver(max_to_keep=1)
131 | 
132 |         # Completion.
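        # The placeholders and losses below implement the completion objective
        # from the paper cited in the README: contextual_loss is the L1
        # distance between G(z) and the real image over the known (unmasked)
        # pixels, plus an analogous term on the lowres-averaged images;
        # perceptual_loss reuses g_loss, i.e. how real the discriminator finds
        # G(z); and complete_loss = contextual_loss + lam * perceptual_loss,
        # whose gradient w.r.t. z drives the Adam/HMC updates in complete().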
133 |         self.mask = tf.placeholder(tf.float32, self.image_shape, name='mask')
134 |         self.lowres_mask = tf.placeholder(tf.float32, self.lowres_shape, name='lowres_mask')
135 |         self.contextual_loss = tf.reduce_sum(
136 |             tf.contrib.layers.flatten(
137 |                 tf.abs(tf.multiply(self.mask, self.G) - tf.multiply(self.mask, self.images))), 1)
138 |         self.contextual_loss += tf.reduce_sum(
139 |             tf.contrib.layers.flatten(
140 |                 tf.abs(tf.multiply(self.lowres_mask, self.lowres_G) - tf.multiply(self.lowres_mask, self.lowres_images))), 1)
141 |         self.perceptual_loss = self.g_loss
142 |         self.complete_loss = self.contextual_loss + self.lam*self.perceptual_loss
143 |         self.grad_complete_loss = tf.gradients(self.complete_loss, self.z)
144 | 
145 |     def train(self, config):
146 |         data = dataset_files(config.dataset)
147 |         np.random.shuffle(data)
148 |         assert(len(data) > 0)
149 | 
150 |         d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
151 |                           .minimize(self.d_loss, var_list=self.d_vars)
152 |         g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
153 |                           .minimize(self.g_loss, var_list=self.g_vars)
154 |         try:
155 |             tf.global_variables_initializer().run()
156 |         except:
157 |             tf.initialize_all_variables().run()
158 | 
159 |         self.g_sum = tf.summary.merge(
160 |             [self.z_sum, self.d__sum, self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
161 |         self.d_sum = tf.summary.merge(
162 |             [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
163 |         self.writer = tf.summary.FileWriter("./logs", self.sess.graph)
164 | 
165 |         sample_z = np.random.uniform(-1, 1, size=(self.sample_size, self.z_dim))
166 |         sample_files = data[0:self.sample_size]
167 | 
168 |         sample = [get_image(sample_file, self.image_size, is_crop=self.is_crop) for sample_file in sample_files]
169 |         sample_images = np.array(sample).astype(np.float32)
170 | 
171 |         counter = 1
172 |         start_time = time.time()
173 | 
174 |         if self.load(self.checkpoint_dir):
175 |             print("""
176 | 
177 | ======
178 | An existing model was found in the checkpoint directory.
179 | If you just cloned this repository, it's a model for faces
180 | trained on the CelebA dataset for 20 epochs.
181 | If you want to train a new model from scratch,
182 | delete the checkpoint directory or specify a different
183 | --checkpoint_dir argument.
184 | ======
185 | 
186 | """)
187 |         else:
188 |             print("""
189 | 
190 | ======
191 | An existing model was not found in the checkpoint directory.
192 | Initializing a new one.
193 | ======
194 | 
195 | """)
196 | 
197 |         for epoch in xrange(config.epoch):
198 |             data = dataset_files(config.dataset)
199 |             batch_idxs = min(len(data), config.train_size) // self.batch_size
200 | 
201 |             for idx in xrange(0, batch_idxs):
202 |                 batch_files = data[idx*config.batch_size:(idx+1)*config.batch_size]
203 |                 batch = [get_image(batch_file, self.image_size, is_crop=self.is_crop)
204 |                          for batch_file in batch_files]
205 |                 batch_images = np.array(batch).astype(np.float32)
206 | 
207 |                 batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \
208 |                             .astype(np.float32)
209 | 
210 |                 # Update D network
211 |                 _, summary_str = self.sess.run([d_optim, self.d_sum],
212 |                     feed_dict={ self.images: batch_images, self.z: batch_z, self.is_training: True })
213 |                 self.writer.add_summary(summary_str, counter)
214 | 
215 |                 # Update G network
216 |                 _, summary_str = self.sess.run([g_optim, self.g_sum],
217 |                     feed_dict={ self.z: batch_z, self.is_training: True })
218 |                 self.writer.add_summary(summary_str, counter)
219 | 
220 |                 # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
221 |                 _, summary_str = self.sess.run([g_optim, self.g_sum],
222 |                     feed_dict={ self.z: batch_z, self.is_training: True })
223 |                 self.writer.add_summary(summary_str, counter)
224 | 
225 |                 errD_fake = self.d_loss_fake.eval({self.z: batch_z, self.is_training: False})
226 |                 errD_real = self.d_loss_real.eval({self.images: batch_images, self.is_training: False})
227 |                 errG = self.g_loss.eval({self.z: batch_z, self.is_training: False})
228 | 
229 |                 counter += 1
230 |                 print("Epoch: [{:2d}] [{:4d}/{:4d}] time: {:4.4f}, d_loss: {:.8f}, g_loss: {:.8f}".format(
231 |                     epoch, idx, batch_idxs, time.time() - start_time, errD_fake+errD_real, errG))
232 | 
233 |                 if np.mod(counter, 100) == 1:
234 |                     samples, d_loss, g_loss = self.sess.run(
235 |                         [self.G, self.d_loss, self.g_loss],
236 |                         feed_dict={self.z: sample_z, self.images: sample_images, self.is_training: False}
237 |                     )
238 |                     save_images(samples, [8, 8],
239 |                                 './samples/train_{:02d}_{:04d}.png'.format(epoch, idx))
240 |                     print("[Sample] d_loss: {:.8f}, g_loss: {:.8f}".format(d_loss, g_loss))
241 | 
242 |                 if np.mod(counter, 500) == 2:
243 |                     self.save(config.checkpoint_dir, counter)
244 | 
245 | 
246 |     def complete(self, config):
247 |         def make_dir(name):
248 |             # Works on python 2.7, where exist_ok arg to makedirs isn't available.
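            # On Python >= 3.2 this helper is equivalent to
            # os.makedirs(p, exist_ok=True); the explicit existence check
            # below is only needed for Python 2.7.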
249 |             p = os.path.join(config.outDir, name)
250 |             if not os.path.exists(p):
251 |                 os.makedirs(p)
252 |         make_dir('hats_imgs')
253 |         make_dir('completed')
254 |         make_dir('logs')
255 | 
256 |         try:
257 |             tf.global_variables_initializer().run()
258 |         except:
259 |             tf.initialize_all_variables().run()
260 | 
261 |         isLoaded = self.load(self.checkpoint_dir)
262 |         assert(isLoaded)
263 | 
264 |         nImgs = len(config.imgs)
265 | 
266 |         batch_idxs = int(np.ceil(nImgs/self.batch_size))
267 |         lowres_mask = np.zeros(self.lowres_shape)
268 |         if config.maskType == 'random':
269 |             fraction_masked = 0.2
270 |             mask = np.ones(self.image_shape)
271 |             mask[np.random.random(self.image_shape[:2]) < fraction_masked] = 0.0
272 |         elif config.maskType == 'center':
273 |             assert(config.centerScale <= 0.5)
274 |             mask = np.ones(self.image_shape)
275 |             sz = self.image_size
276 |             l = int(self.image_size*config.centerScale)
277 |             u = int(self.image_size*(1.0-config.centerScale))
278 |             mask[l:u, l:u, :] = 0.0
279 |         elif config.maskType == 'left':
280 |             mask = np.ones(self.image_shape)
281 |             c = self.image_size // 2
282 |             mask[:,:c,:] = 0.0
283 |         elif config.maskType == 'full':
284 |             mask = np.ones(self.image_shape)
285 |         elif config.maskType == 'grid':
286 |             mask = np.zeros(self.image_shape)
287 |             mask[::4,::4,:] = 1.0
288 |         elif config.maskType == 'lowres':
289 |             lowres_mask = np.ones(self.lowres_shape)
290 |             mask = np.zeros(self.image_shape)
291 |         else:
292 |             assert(False)
293 | 
294 |         for idx in xrange(0, batch_idxs):
295 |             l = idx*self.batch_size
296 |             u = min((idx+1)*self.batch_size, nImgs)
297 |             batchSz = u-l
298 |             batch_files = config.imgs[l:u]
299 |             batch = [get_image(batch_file, self.image_size, is_crop=self.is_crop)
300 |                      for batch_file in batch_files]
301 |             batch_images = np.array(batch).astype(np.float32)
302 |             if batchSz < self.batch_size:
303 |                 print(batchSz)
304 |                 padSz = ((0, int(self.batch_size-batchSz)), (0,0), (0,0), (0,0))
305 |                 batch_images = np.pad(batch_images, padSz, 'constant')
306 |                 batch_images = batch_images.astype(np.float32)
307 | 
308 |             zhats = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))
309 |             m = 0
310 |             v = 0
311 | 
312 |             nRows = np.ceil(batchSz/8)
313 |             nCols = min(8, batchSz)
314 |             save_images(batch_images[:batchSz,:,:,:], [nRows,nCols],
315 |                         os.path.join(config.outDir, 'before.png'))
316 |             masked_images = np.multiply(batch_images, mask)
317 |             save_images(masked_images[:batchSz,:,:,:], [nRows,nCols],
318 |                         os.path.join(config.outDir, 'masked.png'))
319 |             if lowres_mask.any():
320 |                 lowres_images = np.reshape(batch_images, [self.batch_size, self.lowres_size, self.lowres,
321 |                     self.lowres_size, self.lowres, self.c_dim]).mean(4).mean(2)
322 |                 lowres_images = np.multiply(lowres_images, lowres_mask)
323 |                 lowres_images = np.repeat(np.repeat(lowres_images, self.lowres, 1), self.lowres, 2)
324 |                 save_images(lowres_images[:batchSz,:,:,:], [nRows,nCols],
325 |                             os.path.join(config.outDir, 'lowres.png'))
326 |             for img in range(batchSz):
327 |                 with open(os.path.join(config.outDir, 'logs/hats_{:02d}.log'.format(img)), 'a') as f:
328 |                     f.write('iter loss ' +
329 |                             ' '.join(['z{}'.format(zi) for zi in range(self.z_dim)]) +
330 |                             '\n')
331 | 
332 |             for i in xrange(config.nIter):
333 |                 fd = {
334 |                     self.z: zhats,
335 |                     self.mask: mask,
336 |                     self.lowres_mask: lowres_mask,
337 |                     self.images: batch_images,
338 |                     self.is_training: False
339 |                 }
340 |                 run = [self.complete_loss, self.grad_complete_loss, self.G, self.lowres_G]
341 |                 loss, g, G_imgs, lowres_G_imgs = self.sess.run(run, feed_dict=fd)
342 | 
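                # `g` comes from self.grad_complete_loss, i.e.
                # tf.gradients(complete_loss, z), so g[0] holds the gradient of
                # the completion loss w.r.t. zhats; both the Adam and HMC
                # branches below consume it to update zhats.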
343 |                 for img in range(batchSz):
344 |                     with open(os.path.join(config.outDir, 'logs/hats_{:02d}.log'.format(img)), 'ab') as f:
345 |                         f.write('{} {} '.format(i, loss[img]).encode())
346 |                         np.savetxt(f, zhats[img:img+1])
347 | 
348 |                 if i % config.outInterval == 0:
349 |                     print(i, np.mean(loss[0:batchSz]))
350 |                     imgName = os.path.join(config.outDir,
351 |                                            'hats_imgs/{:04d}.png'.format(i))
352 |                     nRows = np.ceil(batchSz/8)
353 |                     nCols = min(8, batchSz)
354 |                     save_images(G_imgs[:batchSz,:,:,:], [nRows,nCols], imgName)
355 |                     if lowres_mask.any():
356 |                         imgName = imgName[:-4] + '.lowres.png'
357 |                         save_images(np.repeat(np.repeat(lowres_G_imgs[:batchSz,:,:,:],
358 |                                                         self.lowres, 1), self.lowres, 2),
359 |                                     [nRows,nCols], imgName)
360 | 
361 |                     inv_masked_hat_images = np.multiply(G_imgs, 1.0-mask)
362 |                     completed = masked_images + inv_masked_hat_images
363 |                     imgName = os.path.join(config.outDir,
364 |                                            'completed/{:04d}.png'.format(i))
365 |                     save_images(completed[:batchSz,:,:,:], [nRows,nCols], imgName)
366 | 
367 |                 if config.approach == 'adam':
368 |                     # Optimize single completion with Adam
369 |                     m_prev = np.copy(m)
370 |                     v_prev = np.copy(v)
371 |                     m = config.beta1 * m_prev + (1 - config.beta1) * g[0]
372 |                     v = config.beta2 * v_prev + (1 - config.beta2) * np.multiply(g[0], g[0])
373 |                     m_hat = m / (1 - config.beta1 ** (i + 1))
374 |                     v_hat = v / (1 - config.beta2 ** (i + 1))
375 |                     zhats += - np.true_divide(config.lr * m_hat, (np.sqrt(v_hat) + config.eps))
376 |                     zhats = np.clip(zhats, -1, 1)
377 | 
378 |                 elif config.approach == 'hmc':
379 |                     # Sample example completions with HMC (not in paper)
380 |                     zhats_old = np.copy(zhats)
381 |                     loss_old = np.copy(loss)
382 |                     v = np.random.randn(self.batch_size, self.z_dim)
383 |                     v_old = np.copy(v)
384 | 
385 |                     for steps in range(config.hmcL):
386 |                         v -= config.hmcEps/2 * config.hmcBeta * g[0]
387 |                         zhats += config.hmcEps * v
388 |                         np.copyto(zhats, np.clip(zhats, -1, 1))
389 |                         loss, g, _, _ = self.sess.run(run, feed_dict=fd)
390 |                         v -= config.hmcEps/2 * config.hmcBeta * g[0]
391 | 
392 |                     for img in range(batchSz):
393 |                         logprob_old = config.hmcBeta * loss_old[img] + np.sum(v_old[img]**2)/2
394 |                         logprob = config.hmcBeta * loss[img] + np.sum(v[img]**2)/2
395 |                         accept = np.exp(logprob_old - logprob)
396 |                         if accept < 1 and np.random.uniform() > accept:
397 |                             np.copyto(zhats[img], zhats_old[img])
398 | 
399 |                     config.hmcBeta *= config.hmcAnneal
400 | 
401 |                 else:
402 |                     assert(False)
403 | 
404 |     def discriminator(self, image, reuse=False):
405 |         with tf.variable_scope("discriminator") as scope:
406 |             if reuse:
407 |                 scope.reuse_variables()
408 | 
409 |             # TODO: Investigate how to parameterise discriminator based off image size.
410 |             h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
411 |             h1 = lrelu(self.d_bns[0](conv2d(h0, self.df_dim*2, name='d_h1_conv'), self.is_training))
412 |             h2 = lrelu(self.d_bns[1](conv2d(h1, self.df_dim*4, name='d_h2_conv'), self.is_training))
413 |             h3 = lrelu(self.d_bns[2](conv2d(h2, self.df_dim*8, name='d_h3_conv'), self.is_training))
414 |             h4 = linear(tf.reshape(h3, [-1, 8192]), 1, 'd_h4_lin')
415 | 
416 |             return tf.nn.sigmoid(h4), h4
417 | 
418 |     def generator(self, z):
419 |         with tf.variable_scope("generator") as scope:
420 |             self.z_, self.h0_w, self.h0_b = linear(z, self.gf_dim*8*4*4, 'g_h0_lin', with_w=True)
421 | 
422 |             # TODO: Nicer iteration pattern here. #readability
423 |             hs = [None]
424 |             hs[0] = tf.reshape(self.z_, [-1, 4, 4, self.gf_dim * 8])
425 |             hs[0] = tf.nn.relu(self.g_bns[0](hs[0], self.is_training))
426 | 
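            # The projected z starts as a 4x4 map with gf_dim*8 channels; each
            # conv2d_transpose in the loop below doubles the spatial size and
            # then halves depth_mul, until the final layer emits an
            # image_size x image_size x 3 image. For the default 64x64 model:
            # 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64.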
427 |             i = 1  # Iteration number.
428 |             depth_mul = 8  # Depth decreases as spatial component increases.
429 |             size = 8  # Size increases as depth decreases.
430 | 
431 |             while size < self.image_size:
432 |                 hs.append(None)
433 |                 name = 'g_h{}'.format(i)
434 |                 hs[i], _, _ = conv2d_transpose(hs[i-1],
435 |                     [self.batch_size, size, size, self.gf_dim*depth_mul], name=name, with_w=True)
436 |                 hs[i] = tf.nn.relu(self.g_bns[i](hs[i], self.is_training))
437 | 
438 |                 i += 1
439 |                 depth_mul //= 2
440 |                 size *= 2
441 | 
442 |             hs.append(None)
443 |             name = 'g_h{}'.format(i)
444 |             hs[i], _, _ = conv2d_transpose(hs[i - 1],
445 |                 [self.batch_size, size, size, 3], name=name, with_w=True)
446 | 
447 |             return tf.nn.tanh(hs[i])
448 | 
449 |     def save(self, checkpoint_dir, step):
450 |         if not os.path.exists(checkpoint_dir):
451 |             os.makedirs(checkpoint_dir)
452 | 
453 |         self.saver.save(self.sess,
454 |                         os.path.join(checkpoint_dir, self.model_name),
455 |                         global_step=step)
456 | 
457 |     def load(self, checkpoint_dir):
458 |         print(" [*] Reading checkpoints...")
459 | 
460 |         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
461 |         if ckpt and ckpt.model_checkpoint_path:
462 |             self.saver.restore(self.sess, ckpt.model_checkpoint_path)
463 |             return True
464 |         else:
465 |             return False
466 | 
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | # Original Version: Taehoon Kim (http://carpedm20.github.io)
2 | #   + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/ops.py
3 | #   + License: MIT
4 | 
5 | import math
6 | import numpy as np
7 | import tensorflow as tf
8 | 
9 | from tensorflow.python.framework import ops
10 | 
11 | from utils import *
12 | 
13 | class batch_norm(object):
14 |     """Code modification of http://stackoverflow.com/a/33950177"""
15 |     def __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm"):
16 |         with tf.variable_scope(name):
17 |             self.epsilon = epsilon
18 |             self.momentum = momentum
19 | 
20 |             self.name = name
21 | 
22 |     def __call__(self, x, train):
23 |         return tf.contrib.layers.batch_norm(x, decay=self.momentum, updates_collections=None, epsilon=self.epsilon,
24 |                                             center=True, scale=True, is_training=train, scope=self.name)
25 | 
26 | def binary_cross_entropy(preds, targets, name=None):
27 |     """Computes binary cross entropy given `preds`.
28 | 
29 |     For brevity, let `x = preds`, `z = targets`. The logistic loss is
30 | 
31 |         loss(x, z) = - sum_i (z[i] * log(x[i]) + (1 - z[i]) * log(1 - x[i]))
32 | 
33 |     Args:
34 |         preds: A `Tensor` of type `float32` or `float64`.
35 |         targets: A `Tensor` of the same type and shape as `preds`.
36 |     """
37 |     eps = 1e-12
38 |     with ops.op_scope([preds, targets], name, "bce_loss") as name:
39 |         preds = ops.convert_to_tensor(preds, name="preds")
40 |         targets = ops.convert_to_tensor(targets, name="targets")
41 |         return tf.reduce_mean(-(targets * tf.log(preds + eps) +
42 |                                 (1. - targets) * tf.log(1. - preds + eps)))
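# Note that binary_cross_entropy expects probabilities (e.g. tf.sigmoid
# output), not raw logits, and is unused in this repository; model.py calls
# tf.nn.sigmoid_cross_entropy_with_logits directly. A minimal hypothetical
# use: loss = binary_cross_entropy(tf.sigmoid(logits), labels)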
43 | 
44 | def conv_cond_concat(x, y):
45 |     """Concatenate conditioning vector on feature map axis."""
46 |     x_shapes = x.get_shape()
47 |     y_shapes = y.get_shape()
48 |     return tf.concat(3, [x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])])
49 | 
50 | def conv2d(input_, output_dim,
51 |            k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
52 |            name="conv2d"):
53 |     with tf.variable_scope(name):
54 |         w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
55 |                             initializer=tf.truncated_normal_initializer(stddev=stddev))
56 |         conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
57 | 
58 |         biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
59 |         # conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
60 |         conv = tf.nn.bias_add(conv, biases)
61 | 
62 |         return conv
63 | 
64 | def conv2d_transpose(input_, output_shape,
65 |                      k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
66 |                      name="conv2d_transpose", with_w=False):
67 |     with tf.variable_scope(name):
68 |         # filter : [height, width, output_channels, in_channels]
69 |         w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
70 |                             initializer=tf.random_normal_initializer(stddev=stddev))
71 | 
72 |         try:
73 |             deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
74 |                                             strides=[1, d_h, d_w, 1])
75 | 
76 |         # Support for versions of TensorFlow before 0.7.0
77 |         except AttributeError:
78 |             deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
79 |                                     strides=[1, d_h, d_w, 1])
80 | 
81 |         biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
82 |         # deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
83 |         deconv = tf.nn.bias_add(deconv, biases)
84 | 
85 |         if with_w:
86 |             return deconv, w, biases
87 |         else:
88 |             return deconv
89 | 
90 | def lrelu(x, leak=0.2, name="lrelu"):
91 |     with tf.variable_scope(name):
92 |         f1 = 0.5 * (1 + leak)
93 |         f2 = 0.5 * (1 - leak)
94 |         return f1 * x + f2 * abs(x)
95 | 
96 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
97 |     shape = input_.get_shape().as_list()
98 | 
99 |     with tf.variable_scope(scope or "Linear"):
100 |         matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
101 |                                  tf.random_normal_initializer(stddev=stddev))
102 |         bias = tf.get_variable("bias", [output_size],
103 |                                initializer=tf.constant_initializer(bias_start))
104 |         if with_w:
105 |             return tf.matmul(input_, matrix) + bias, matrix, bias
106 |         else:
107 |             return tf.matmul(input_, matrix) + bias
108 | 
--------------------------------------------------------------------------------
/simple-distributions.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | import numpy as np
4 | from scipy.stats import norm
5 | 
6 | import matplotlib as mpl
7 | mpl.use('Agg')
8 | import matplotlib.pyplot as plt
9 | plt.style.use('bmh')
10 | import matplotlib.mlab as mlab
11 | 
12 | np.random.seed(0)
13 | 
14 | X = np.arange(-3, 3, 0.001)
15 | Y = norm.pdf(X, 0, 1)
16 | 
17 | fig = plt.figure()
18 | plt.plot(X, Y)
19 | plt.tight_layout()
20 | plt.savefig("normal-pdf.png")
21 | 
22 | nSamples = 35
23 | X = np.random.normal(0, 1, nSamples)
24 | Y = np.zeros(nSamples)
25 | fig = plt.figure(figsize=(7,3))
26 | plt.scatter(X, Y, color='k')
27 | plt.xlim((-3,3))
28 | frame = plt.gca()
29 | frame.axes.get_yaxis().set_visible(False)
30 | plt.savefig("normal-samples.png")
31 | 
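# matplotlib.mlab.bivariate_normal, used below, was removed in newer
# matplotlib releases. A sketch of an equivalent on such versions, using
# scipy (already imported above for scipy.stats.norm):
#   from scipy.stats import multivariate_normal
#   Z = multivariate_normal([0., 0.], [[1., 0.], [0., 1.]]).pdf(np.dstack((X, Y)))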
32 | delta = 0.025
33 | x = np.arange(-3.0, 3.0, delta)
34 | y = np.arange(-3.0, 3.0, delta)
35 | X, Y = np.meshgrid(x, y)
36 | Z = mlab.bivariate_normal(X, Y, 1.0, 1.0, 0.0, 0.0)
37 | 
38 | plt.figure()
39 | CS = plt.contour(X, Y, Z)
40 | plt.clabel(CS, inline=1, fontsize=10)
41 | 
42 | nSamples = 200
43 | mean = [0, 0]
44 | cov = [[1,0], [0,1]]
45 | X, Y = np.random.multivariate_normal(mean, cov, nSamples).T
46 | plt.scatter(X, Y, color='k')
47 | 
48 | plt.savefig("normal-2d.png")
49 | 
--------------------------------------------------------------------------------
/train-dcgan.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | # Original Version: Taehoon Kim (http://carpedm20.github.io)
4 | #   + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/main.py
5 | #   + License: MIT
6 | # [2016-08-05] Modifications for Inpainting: Brandon Amos (http://bamos.github.io)
7 | #   + License: MIT
8 | 
9 | import os
10 | import scipy.misc
11 | import numpy as np
12 | 
13 | from model import DCGAN
14 | from utils import pp, visualize, to_json
15 | 
16 | import tensorflow as tf
17 | 
18 | flags = tf.app.flags
19 | flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
20 | flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for adam [0.0002]")
21 | flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
22 | flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
23 | flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
24 | flags.DEFINE_integer("image_size", 64, "The size of image to use")
25 | flags.DEFINE_string("dataset", "lfw-aligned-64", "Dataset directory.")
26 | flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
27 | flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
28 | FLAGS = flags.FLAGS
29 | 
30 | if not os.path.exists(FLAGS.checkpoint_dir):
31 |     os.makedirs(FLAGS.checkpoint_dir)
32 | if not os.path.exists(FLAGS.sample_dir):
33 |     os.makedirs(FLAGS.sample_dir)
34 | 
35 | config = tf.ConfigProto()
36 | config.gpu_options.allow_growth = True
37 | with tf.Session(config=config) as sess:
38 |     dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size,
39 |                   is_crop=False, checkpoint_dir=FLAGS.checkpoint_dir)
40 | 
41 |     dcgan.train(FLAGS)
42 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Original Version: Taehoon Kim (http://carpedm20.github.io)
2 | #   + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/utils.py
3 | #   + License: MIT
4 | 
5 | """
6 | Some codes from https://github.com/Newmu/dcgan_code
7 | """
8 | from __future__ import division
9 | import math
10 | import json
11 | import random
12 | import pprint
13 | import scipy.misc
14 | import numpy as np
15 | from time import gmtime, strftime
16 | from six.moves import xrange  # xrange is used by visualize() below.
17 | 
18 | pp = pprint.PrettyPrinter()
19 | 
20 | get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])
21 | 
22 | def get_image(image_path, image_size, is_crop=True):
23 |     return transform(imread(image_path), image_size, is_crop)
24 | 
25 | def save_images(images, size, image_path):
26 |     return imsave(inverse_transform(images), size, image_path)
27 | 
28 | def imread(path):
29 |     return scipy.misc.imread(path, mode='RGB').astype(np.float)
30 | 
31 | def merge_images(images, size):
32 |     return inverse_transform(images)
33 | 
34 | def merge(images, size):
35 |     h, w = images.shape[1], images.shape[2]
36 |     img = np.zeros((int(h * size[0]), int(w * size[1]), 3))
37 |     for idx, image in enumerate(images):
38 |         i = idx % size[1]
39 |         j = idx // size[1]
40 |         img[j*h:j*h+h, i*w:i*w+w, :] = image
41 | 
42 |     return img
43 | 
44 | def imsave(images, size, path):
45 |     img = merge(images, size)
46 |     return scipy.misc.imsave(path, (255*img).astype(np.uint8))
47 | 
48 | def center_crop(x, crop_h, crop_w=None, resize_w=64):
49 |     if crop_w is None:
50 |         crop_w = crop_h
51 |     h, w = x.shape[:2]
52 |     j = int(round((h - crop_h)/2.))
53 |     i = int(round((w - crop_w)/2.))
54 |     return scipy.misc.imresize(x[j:j+crop_h, i:i+crop_w],
55 |                                [resize_w, resize_w])
56 | 
57 | def transform(image, npx=64, is_crop=True):
58 |     # npx : # of pixels width/height of image
59 |     if is_crop:
60 |         cropped_image = center_crop(image, npx)
61 |     else:
62 |         cropped_image = image
63 |     return np.array(cropped_image)/127.5 - 1.
64 | 
65 | def inverse_transform(images):
66 |     return (images+1.)/2.
67 | 
68 | 
69 | def to_json(output_path, *layers):
70 |     with open(output_path, "w") as layer_f:
71 |         lines = ""
72 |         for w, b, bn in layers:
73 |             layer_idx = w.name.split('/')[0].split('h')[1]
74 | 
75 |             B = b.eval()
76 | 
77 |             if "lin/" in w.name:
78 |                 W = w.eval()
79 |                 depth = W.shape[1]
80 |             else:
81 |                 W = np.rollaxis(w.eval(), 2, 0)
82 |                 depth = W.shape[0]
83 | 
84 |             biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]}
85 |             if bn != None:
86 |                 gamma = bn.gamma.eval()
87 |                 beta = bn.beta.eval()
88 | 
89 |                 gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]}
90 |                 beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]}
91 |             else:
92 |                 gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []}
93 |                 beta = {"sy": 1, "sx": 1, "depth": 0, "w": []}
94 | 
95 |             if "lin/" in w.name:
96 |                 fs = []
97 |                 for w in W.T:
98 |                     fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]})
99 | 
100 |                 lines += """
101 |                     var layer_%s = {
102 |                         "layer_type": "fc",
103 |                         "sy": 1, "sx": 1,
104 |                         "out_sx": 1, "out_sy": 1,
105 |                         "stride": 1, "pad": 0,
106 |                         "out_depth": %s, "in_depth": %s,
107 |                         "biases": %s,
108 |                         "gamma": %s,
109 |                         "beta": %s,
110 |                         "filters": %s
111 |                     };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs)
112 |             else:
113 |                 fs = []
114 |                 for w_ in W:
115 |                     fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]})
116 | 
117 |                 lines += """
118 |                     var layer_%s = {
119 |                         "layer_type": "deconv",
120 |                         "sy": 5, "sx": 5,
121 |                         "out_sx": %s, "out_sy": %s,
122 |                         "stride": 2, "pad": 1,
123 |                         "out_depth": %s, "in_depth": %s,
124 |                         "biases": %s,
125 |                         "gamma": %s,
126 |                         "beta": %s,
127 |                         "filters": %s
128 |                     };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2),
129 |                              W.shape[0], W.shape[3], biases, gamma, beta, fs)
130 |         layer_f.write(" ".join(lines.replace("'","").split()))
131 | 
132 | def make_gif(images, fname, duration=2, true_image=False):
133 |     import moviepy.editor as mpy
134 | 
135 |     def make_frame(t):
136 |         try:
137 |             x = images[int(len(images)/duration*t)]
138 |         except:
139 |             x = images[-1]
140 | 
141 |         if true_image:
142 |             return x.astype(np.uint8)
143 |         else:
144 |             return ((x+1)/2*255).astype(np.uint8)
145 | 
146 |     clip = mpy.VideoClip(make_frame, duration=duration)
147 |     clip.write_gif(fname, fps=len(images) / duration)
148 | 
# Kept from the upstream DCGAN-tensorflow project; note it references a
# `sampler` attribute that this repository's model.py does not define.
149 | def visualize(sess, dcgan, config, option):
150 |     if option == 0:
151 |         z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
152 |         samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
153 |         save_images(samples, [8, 8], './samples/test_%s.png' % strftime("%Y-%m-%d %H:%M:%S", gmtime()))
154 |     elif option == 1:
155 |         values = np.arange(0, 1, 1./config.batch_size)
156 |         for idx in xrange(100):
157 |             print(" [*] %d" % idx)
158 |             z_sample = np.zeros([config.batch_size, dcgan.z_dim])
159 |             for kdx, z in enumerate(z_sample):
160 |                 z[idx] = values[kdx]
161 | 
162 |             samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
163 |             save_images(samples, [8, 8], './samples/test_arange_%s.png' % (idx))
164 |     elif option == 2:
165 |         values = np.arange(0, 1, 1./config.batch_size)
166 |         for idx in [random.randint(0, 99) for _ in xrange(100)]:
167 |             print(" [*] %d" % idx)
168 |             z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
169 |             z_sample = np.tile(z, (config.batch_size, 1))
170 |             #z_sample = np.zeros([config.batch_size, dcgan.z_dim])
171 |             for kdx, z in enumerate(z_sample):
172 |                 z[idx] = values[kdx]
173 | 
174 |             samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
175 |             make_gif(samples, './samples/test_gif_%s.gif' % (idx))
176 |     elif option == 3:
177 |         values = np.arange(0, 1, 1./config.batch_size)
178 |         for idx in xrange(100):
179 |             print(" [*] %d" % idx)
180 |             z_sample = np.zeros([config.batch_size, dcgan.z_dim])
181 |             for kdx, z in enumerate(z_sample):
182 |                 z[idx] = values[kdx]
183 | 
184 |             samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
185 |             make_gif(samples, './samples/test_gif_%s.gif' % (idx))
186 |     elif option == 4:
187 |         image_set = []
188 |         values = np.arange(0, 1, 1./config.batch_size)
189 | 
190 |         for idx in xrange(100):
191 |             print(" [*] %d" % idx)
192 |             z_sample = np.zeros([config.batch_size, dcgan.z_dim])
193 |             for kdx, z in enumerate(z_sample): z[idx] = values[kdx]
194 | 
195 |             image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
196 |             make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))
197 | 
198 |         new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \
199 |                          for idx in list(range(64)) + list(range(63, -1, -1))]  # list(...) so the concatenation also works on Python 3
200 |         make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8)
--------------------------------------------------------------------------------