├── .gitignore
├── README.md
├── assets
│   ├── celeba
│   │   ├── began.30k.png
│   │   ├── began.50k.png
│   │   ├── began.75k.png
│   │   ├── began.M.png
│   │   ├── coulombgan.200k.png
│   │   ├── dcgan.G1e-3.30k.png
│   │   ├── dcgan.G2e-4.30k.png
│   │   ├── dcgan.G2e-4.40k.png
│   │   ├── dcgan.G2e-4.50k.png
│   │   ├── dragan.fixed.120k.png
│   │   ├── ebgan.nopt.30k.png
│   │   ├── ebgan.nopt.graph.png
│   │   ├── ebgan.pt.30k.png
│   │   ├── ebgan.pt.graph.png
│   │   ├── lsgan.100.30k.png
│   │   ├── lsgan.1024.30k.png
│   │   ├── lsgan.1024.40k.png
│   │   ├── wgan-gp.dcgan.10k.png
│   │   ├── wgan-gp.dcgan.20k.png
│   │   ├── wgan-gp.dcgan.30k.png
│   │   ├── wgan-gp.dcgan.gp.png
│   │   ├── wgan-gp.dcgan.w_dist.expand.png
│   │   ├── wgan-gp.dcgan.w_dist.png
│   │   ├── wgan-gp.good.10k.png
│   │   ├── wgan-gp.good.15k.png
│   │   ├── wgan-gp.good.20k.png
│   │   ├── wgan-gp.good.25k.png
│   │   ├── wgan-gp.good.30k.png
│   │   ├── wgan-gp.good.40k.png
│   │   ├── wgan-gp.good.50k.png
│   │   ├── wgan-gp.good.5k.png
│   │   ├── wgan-gp.good.7k.png
│   │   ├── wgan-gp.good.gp.png
│   │   ├── wgan-gp.good.w_dist.expand.png
│   │   ├── wgan-gp.good.w_dist.png
│   │   ├── wgan.30k.png
│   │   ├── wgan.40k.png
│   │   ├── wgan.50k.png
│   │   └── wgan.w_dist.png
│   └── lsun
│       ├── began.100k.png
│       ├── began.150k.png
│       ├── began.200k.png
│       ├── began.250k.png
│       ├── dcgan.100k.png
│       ├── dragan.200k.png
│       ├── ebgan.80k.png
│       ├── lsgan.150k.png
│       ├── wgan-gp.150k.png
│       └── wgan.230k.png
├── config.py
├── convert.py
├── download.py
├── eval.py
├── inputpipe.py
├── models
│   ├── __init__.py
│   ├── basemodel.py
│   ├── began.py
│   ├── coulombgan.py
│   ├── dcgan.py
│   ├── dragan.py
│   ├── ebgan.py
│   ├── lsgan.py
│   ├── wgan.py
│   └── wgan_gp.py
├── ops.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .ipynb_checkpoints/
3 | *.ipynb
4 | *.log
5 | .DS_Store
6 | models/patchgan.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GANs comparison without cherry-picking
2 | 
3 | Implementations of some theoretically-oriented generative adversarial nets: DCGAN, EBGAN, LSGAN, WGAN, WGAN-GP, BEGAN, DRAGAN, and CoulombGAN.
4 | 
5 | I implemented each model with the same architecture as proposed in its paper and compared them on the CelebA and LSUN datasets without cherry-picking.
6 | 
7 | 
8 | ## Table of Contents
9 | 
10 | * [Features](#features)
11 | * [Models](#models)
12 | * [Dataset](#dataset)
13 |   * [CelebA](#celeba)
14 |   * [LSUN](#lsun-bedroom)
15 | * [Results](#results)
16 |   * [DCGAN](#dcgan)
17 |   * [EBGAN](#ebgan)
18 |   * [LSGAN](#lsgan)
19 |   * [WGAN](#wgan)
20 |   * [WGAN-GP](#wgan-gp)
21 |   * [BEGAN](#began)
22 |   * [DRAGAN](#dragan)
23 |   * [CoulombGAN](#coulombgan)
24 | * [Conclusion](#conclusion)
25 | * [Usage](#usage)
26 |   * [Requirements](#requirements)
27 | * [Similar works](#similar-works)
28 | 
29 | 
30 | ## Features
31 | 
32 | - Model architectures match those proposed in each paper
33 | - The models were not heavily tuned, so the results can likely be improved
34 | - Well-structured (that was my goal at the start, but I don't know whether it succeeded!)
35 | - The TensorFlow queue runner is used for the input pipeline
36 | - Single trainer (and single evaluator) - multi-model structure (see the sketch after the model list below)
37 | - Training logs and configurations are recorded on TensorBoard
38 | 
39 | ## Models
40 | 
41 | - DCGAN
42 | - LSGAN
43 | - WGAN
44 | - WGAN-GP
45 | - EBGAN
46 | - BEGAN
47 | - DRAGAN
48 | - CoulombGAN
49 | 
50 | The family of conditional GANs (CGAN, acGAN, and so on) is excluded.
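Concretely, the single-trainer / multi-model structure means `train.py` and `eval.py` never touch model internals; they only ask `config.get_model` for a model object and rely on a shared interface. A hedged sketch (the access-point names are the ones exposed by `models/began.py` below; `train.py` itself may differ in detail):

```python
from config import get_model

# Any name from config.model_zoo works here. The CLI is effectively
# case-insensitive, since eval.py upper-cases --model before dispatching.
model = get_model('BEGAN', name='began', training=True)

# Common access points every model exposes for the single trainer:
#   model.X, model.z     - input placeholders (images, latent codes)
#   model.D_train_op     - one discriminator/critic update
#   model.G_train_op     - one generator update
#   model.summary_op     - per-step TensorBoard summaries
#   model.global_step    - shared step counter
```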
51 | 
52 | ## Dataset
53 | 
54 | ### CelebA
55 | 
56 | http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
57 | 
58 | - All experiments were performed on the 64x64 CelebA dataset
59 | - The dataset has 202599 images
60 | - 1 epoch consists of about 1.58k iterations for batch size 128
61 | 
62 | ### LSUN bedroom
63 | 
64 | http://lsun.cs.princeton.edu/2017/
65 | 
66 | - The dataset has 3033042 images
67 | - 1 epoch consists of about 23.7k iterations for batch size 128
68 | 
69 | This dataset is provided in [LMDB](http://www.lmdb.tech/) format. https://github.com/fyu/lsun provides documentation and demo code for using it.
70 | 
71 | ## Results
72 | 
73 | - I implemented each model as proposed in its paper, but ignored some details (or the paper did not describe the details of the model)
74 | - Admittedly, small details can make a great difference in the results, since GAN training is so unstable
75 | - So if you get better results, please let me know your settings 🙂
76 | - Defaults: batch_size=128 and z_dim=100 (from DCGAN)
77 | 
78 | ### DCGAN
79 | 
80 | Radford, Alec, Luke Metz, and Soumith Chintala. "Unsupervised representation learning with deep convolutional generative adversarial networks." arXiv preprint arXiv:1511.06434 (2015).
81 | 
82 | - Relatively simple networks
83 | - The learning rate for the discriminator (D_lr) is 2e-4; for the generator (G_lr), both 2e-4 (as proposed in the paper) and 1e-3 were tried
84 | 
85 | | G_lr=2e-4 | G_lr=1e-3 |
86 | | :--------------------------------------: | :--------------------------------------: |
87 | | 50k | 30k |
88 | | ![dcgan.G2e-4.50k](assets/celeba/dcgan.G2e-4.50k.png) | ![dcgan.G1e-3.30k](assets/celeba/dcgan.G1e-3.30k.png) |
89 | 
90 | The second row (50k, 30k) indicates the number of training iterations for each column.
91 | 
92 | The higher generator learning rate (1e-3) produced better results. In this case, however, the generator sometimes collapsed due to the large learning rate. Lowering both learning rates may bring stability, as in https://ajolicoeur.wordpress.com/cats/, which suggests D_lr=5e-5 and G_lr=2e-4.
93 | 
94 | | LSUN |
95 | | :--------------------------------------: |
96 | | 100k |
97 | | ![dcgan.100k](assets/lsun/dcgan.100k.png) |
98 | 
99 | 
100 | 
101 | ### EBGAN
102 | 
103 | Zhao, Junbo, Michael Mathieu, and Yann LeCun. "Energy-based generative adversarial network." arXiv preprint arXiv:1609.03126 (2016).
104 | 
105 | - I like the energy-based perspective, so this paper is very interesting to me :)
106 | - But there is criticism: [Are Energy-Based GANs any more energy-based than normal GANs?](http://www.inference.vc/are-energy-based-gans-actually-energy-based/)
107 | - Anyway, the energy concept and the autoencoder-based loss function are impressive, and the results are also fine
108 | - But I have a question about the pulling-away term (PT), which theoretically prevents mode collapse. It is the same idea as minibatch discrimination (T. Salimans et al.); a sketch follows below.
109 | 
110 | 
111 | | pt weight = 0.1 | No pt loss |
112 | | :--------------------------------------: | :--------------------------------------: |
113 | | 30k | 30k |
114 | | ![ebgan.pt.30k](assets/celeba/ebgan.pt.30k.png) | ![ebgan.nopt.30k](assets/celeba/ebgan.nopt.30k.png) |
115 | 
116 | The model using PT generates visually slightly better samples. However, it is not clear from these results whether PT prevents mode collapse. Furthermore, repeated experiments did not tell me which setting is better.
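For reference, the PT term penalizes the pairwise cosine similarity between the encoder representations of the generated samples in a batch. A minimal sketch of the formula from the EBGAN paper (not the exact code in `models/ebgan.py`; `S` is assumed to be the encoder output for fake samples):

```python
import tensorflow as tf

def pulling_away_term(S):
    """EBGAN's f_PT: mean squared cosine similarity over all i != j pairs.
    S: [N, d] encoder features of generated samples. Minimizing this pushes
    fake samples apart in feature space, which should discourage mode collapse."""
    n = tf.cast(tf.shape(S)[0], tf.float32)
    S_unit = tf.nn.l2_normalize(S, dim=1)              # unit-norm rows
    cos = tf.matmul(S_unit, S_unit, transpose_b=True)  # [N, N] cosine matrix
    # The diagonal (i == j) entries are exactly 1, so subtract them out.
    return (tf.reduce_sum(tf.square(cos)) - n) / (n * (n - 1.))

# Typical use: the D loss stays as-is, and
# G_loss += pt_weight * pulling_away_term(S_fake)
```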
117 | 
118 | 
119 | | pt weight = 0.1 | No pt loss |
120 | | :--------------------------------------: | :--------------------------------------: |
121 | | ![ebgan.pt.graph](assets/celeba/ebgan.pt.graph.png) | ![ebgan.nopt.graph](assets/celeba/ebgan.nopt.graph.png) |
122 | 
123 | pt_loss decreases a little faster on the left (pt_weight=0.1), but there is no big difference, and at the end the right side, which used no pt_loss, even showed a lower pt_loss. So I wonder: does the PT loss really prevent mode collapse as described in the paper?
124 | 
125 | | LSUN |
126 | | :-------------------------------------: |
127 | | 80k |
128 | | ![ebgan.80k](assets/lsun/ebgan.80k.png) |
129 | 
130 | ### LSGAN
131 | 
132 | Mao, Xudong, et al. "Least squares generative adversarial networks." arXiv preprint arXiv:1611.04076 (2016).
133 | 
134 | - Unusually, LSGAN used a large latent space dimension (z_dim=1024)
135 | - But in my experiments, z_dim=100 gave better results than the z_dim=1024 originally used in the paper
136 | 
137 | | z_dim=100 | z_dim=1024 |
138 | | :--------------------------------------: | :--------------------------------------: |
139 | | 30k | 30k |
140 | | ![lsgan.100.30k](assets/celeba/lsgan.100.30k.png) | ![lsgan.1024.30k](assets/celeba/lsgan.1024.30k.png) |
141 | 
142 | | LSUN |
143 | | :--------------------------------------: |
144 | | 150k |
145 | | ![lsgan.150k](assets/lsun/lsgan.150k.png) |
146 | 
147 | ### WGAN
148 | 
149 | Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein GAN." arXiv preprint arXiv:1701.07875 (2017).
150 | 
151 | - The samples from WGAN are not that impressive compared to its very impressive theory
152 | - No specific network structure is proposed, so the DCGAN architecture was used for the experiments
153 | - In the [author's implementation](https://github.com/martinarjovsky/WassersteinGAN), a higher n_critic is used in the early stage of training and once every 500 iterations
154 | 
155 | | 30k | W distance |
156 | | :-------------------------------------: | :--------------------------------------: |
157 | | ![wgan.30k](assets/celeba/wgan.30k.png) | ![wgan.w_dist](assets/celeba/wgan.w_dist.png) |
158 | 
159 | | LSUN |
160 | | :-------------------------------------: |
161 | | 230k |
162 | | ![wgan.230k](assets/lsun/wgan.230k.png) |
163 | 
164 | ### WGAN-GP
165 | 
166 | Gulrajani, Ishaan, et al. "Improved training of Wasserstein GANs." arXiv preprint arXiv:1704.00028 (2017).
167 | 
168 | - I tried two network architectures: the DCGAN architecture and the ResNet architecture from appendix C of the paper
169 | - The ResNet architecture is more complicated and performs better than the DCGAN architecture
170 | - Interestingly, the visual quality of the samples improves very quickly (the ResNet WGAN-GP produced its best samples at 7k iterations) and then degrades as training continues
171 | - According to the DRAGAN paper, the constraints of WGAN are too restrictive for learning a good generator
172 | 
173 | | DCGAN architecture | ResNet architecture |
174 | | :--------------------------------------: | :--------------------------------------: |
175 | | 30k | 7k, batch size = 64 |
176 | | ![wgan-gp.dcgan.30k](assets/celeba/wgan-gp.dcgan.30k.png) | ![wgan-gp.good.7k](assets/celeba/wgan-gp.good.7k.png) |
177 | 
178 | | LSUN |
179 | | :--------------------------------------: |
180 | | 150k, ResNet architecture |
181 | | ![wgan-gp.150k](assets/lsun/wgan-gp.150k.png) |
182 | 
183 | #### Face collapse phenomenon
184 | 
185 | WGAN-GP collapsed more than the other models as the number of iterations increased; a sketch of the gradient penalty term follows before the sample grids.
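For reference, the gradient penalty that replaces WGAN's weight clipping softly enforces the 1-Lipschitz constraint on random interpolates between real and fake samples. A minimal sketch (hedged: `critic` is an assumed function reusable with `reuse=True`, and λ=10 follows the paper; the actual implementation lives in `models/wgan_gp.py`):

```python
import tensorflow as tf

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    """WGAN-GP penalty: E[(||grad_x critic(x_hat)||_2 - 1)^2], where x_hat is
    sampled uniformly on straight lines between real and fake samples."""
    batch_size = tf.shape(real)[0]
    eps = tf.random_uniform([batch_size, 1, 1, 1], 0., 1.)  # per-sample mixing
    x_hat = eps * real + (1. - eps) * fake
    grads = tf.gradients(critic(x_hat), x_hat)[0]           # critic reused here
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    return lambda_gp * tf.reduce_mean(tf.square(norm - 1.))

# Critic loss: E[critic(fake)] - E[critic(real)] + gradient_penalty(...)
```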
186 | 
187 | **DCGAN architecture**
188 | 
189 | | 10k | 20k | 30k |
190 | | :--------------------------------------: | :--------------------------------------: | :--------------------------------------: |
191 | | ![wgan-gp.dcgan.10k](assets/celeba/wgan-gp.dcgan.10k.png) | ![wgan-gp.dcgan.20k](assets/celeba/wgan-gp.dcgan.20k.png) | ![wgan-gp.dcgan.30k](assets/celeba/wgan-gp.dcgan.30k.png) |
192 | 
193 | 
194 | **ResNet architecture**
195 | 
196 | The ResNet architecture showed the best visual sample quality at a very early stage - 7k iterations by my criteria. This may be due to the residual architecture.
197 | 
198 | batch_size=64.
199 | 
200 | | 5k | 7k | 10k | 15k |
201 | | :--------------------------------------: | :--------------------------------------: | :--------------------------------------: | :--------------------------------------: |
202 | | ![wgan-gp.good.5k](assets/celeba/wgan-gp.good.5k.png) | ![wgan-gp.good.7k](assets/celeba/wgan-gp.good.7k.png) | ![wgan-gp.good.10k](assets/celeba/wgan-gp.good.10k.png) | ![wgan-gp.good.15k](assets/celeba/wgan-gp.good.15k.png) |
203 | | 20k | 25k | 30k | 40k |
204 | | ![wgan-gp.good.20k](assets/celeba/wgan-gp.good.20k.png) | ![wgan-gp.good.25k](assets/celeba/wgan-gp.good.25k.png) | ![wgan-gp.good.30k](assets/celeba/wgan-gp.good.30k.png) | ![wgan-gp.good.40k](assets/celeba/wgan-gp.good.40k.png) |
205 | 
206 | Despite the face collapse phenomenon, the Wasserstein distance decreased steadily. This suggests that the critic (discriminator) network failed to find the supremum over the K-Lipschitz functions.
207 | 
208 | | DCGAN architecture | ResNet architecture |
209 | | :--------------------------------------: | :--------------------------------------: |
210 | | ![wgan-gp.dcgan.w_dist](assets/celeba/wgan-gp.dcgan.w_dist.png) | ![wgan-gp.good.w_dist](assets/celeba/wgan-gp.good.w_dist.png) |
211 | | ![wgan-gp.dcgan.w_dist.expand](assets/celeba/wgan-gp.dcgan.w_dist.expand.png) | ![wgan-gp.good.w_dist.expand](assets/celeba/wgan-gp.good.w_dist.expand.png) |
212 | 
213 | The plots in the last row of the table are just expanded versions of the plots in the row above.
214 | 
215 | It is interesting that W_dist < 0 at the end of training. This indicates that E[fake] > E[real]; from the viewpoint of the original GAN, it means the generator dominates the discriminator.
216 | 
217 | ### BEGAN
218 | 
219 | Berthelot, David, Tom Schumm, and Luke Metz. "BEGAN: Boundary equilibrium generative adversarial networks." arXiv preprint arXiv:1703.10717 (2017).
220 | 
221 | - As far as I know, this model generates samples with the best visual quality
222 | - It also showed the best performance in this project
223 | - even though the optional improvements (section 3.5.1 in the paper) were not implemented
224 | - However, the samples generated by BEGAN have a slightly different feel from those of the other models - details seem to disappear.
225 | - So I wonder what the results look like on other datasets
226 | 
227 | batch_size=16, z_dim=64, gamma=0.5. The equilibrium bookkeeping is sketched below.
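For reference, BEGAN's losses and the equilibrium control variable k can be summarized in a few lines; this matches `models/began.py` below (E denotes the autoencoder reconstruction loss, i.e. the "energy"):

```python
def began_step(E_real, E_fake, k, gamma=0.5, lambda_k=1e-3):
    """One step of BEGAN's equilibrium bookkeeping (cf. models/began.py).
    E_real / E_fake are the discriminator-autoencoder reconstruction losses."""
    D_loss = E_real - k * E_fake
    G_loss = E_fake
    balance = gamma * E_real - E_fake              # > 0 when G is "too weak"
    k = min(max(k + lambda_k * balance, 0.), 1.)   # proportional control, clipped to [0, 1]
    M = E_real + abs(balance)                      # convergence measure plotted below
    return D_loss, G_loss, k, M
```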
228 | 
229 | | 30k | 50k | 75k |
230 | | :--------------------------------------: | :--------------------------------------: | :--------------------------------------: |
231 | | ![began.30k](assets/celeba/began.30k.png) | ![began.50k](assets/celeba/began.50k.png) | ![began.75k](assets/celeba/began.75k.png) |
232 | 
233 | | Convergence measure M |
234 | | :-----------------------------------: |
235 | | ![began.M](assets/celeba/began.M.png) |
236 | 
237 | I also tried to reduce the speck-like artifacts as suggested in [Heumi/BEGAN-tensorflow](https://github.com/Heumi/BEGAN-tensorflow/), but they did not go away.
238 | 
239 | 
246 | 
247 | 
248 | 
249 | BEGAN works terribly on the LSUN dataset. Not only was severe mode collapse observed, but the generated images were also not realistic.
250 | 
251 | | LSUN | LSUN |
252 | | :--------------------------------------: | :--------------------------------------: |
253 | | 100k | 150k |
254 | | ![began.100k](assets/lsun/began.100k.png) | ![began.150k](assets/lsun/began.150k.png) |
255 | | 200k | 250k |
256 | | ![began.200k](assets/lsun/began.200k.png) | ![began.250k](assets/lsun/began.250k.png) |
257 | 
258 | 
259 | ### DRAGAN
260 | 
261 | Kodali, Naveen, et al. "How to Train Your DRAGAN." arXiv preprint arXiv:1705.07215 (2017).
262 | 
263 | - Unlike other papers, DRAGAN is motivated by game theory to improve the performance of GANs
264 | - This game-theoretic approach is unique and interesting
265 | - But, IMHO, there is not much practical contribution - the algorithm is similar to WGAN-GP
266 | 
267 | | DCGAN architecture |
268 | | :--------------------------------------: |
269 | | 120k |
270 | | ![dragan.120k](assets/celeba/dragan.fixed.120k.png) |
271 | 
272 | The original paper has some bugs. One of them is that the image x is perturbed only in the positive direction. I applied a two-sided perturbation, since the author acknowledged this bug on [GitHub](https://github.com/kodalinaveen3/DRAGAN).
273 | 
274 | | LSUN |
275 | | :--------------------------------------: |
276 | | 200k |
277 | | ![dragan.200k](assets/lsun/dragan.200k.png) |
278 | 
279 | ### CoulombGAN
280 | 
281 | Unterthiner, Thomas, et al. "Coulomb GANs: Provably Optimal Nash Equilibria via Potential Fields." arXiv preprint arXiv:1708.08819 (2017).
282 | 
283 | - CoulombGAN also takes a very interesting perspective - the "Coulomb potential"
284 | - It is very interesting, although I am not sure whether it is really a GAN
285 | - CoulombGAN tries to solve the diversity problem (mode collapse)
286 | 
287 | 
288 | G_lr=5e-4, D_lr=25e-5, z_dim=32.
289 | 
290 | | DCGAN architecture |
291 | | :--------------------------------------: |
292 | | 200k |
293 | | ![coulombgan.200k](assets/celeba/coulombgan.200k.png) |
294 | 
295 | The disadvantage of this model is that it takes a very long time to train despite the simplicity of its network architecture. Further, like the original GAN, there is no convergence measure. I thought the potentials of the fake samples would serve as a convergence measure, but they did not.
296 | 
297 | 
298 | 
299 | 
309 | 
310 | ## Usage
311 | 
312 | Download the datasets:
313 | 
314 | ```
315 | $ python download.py celebA
316 | $ python download.py lsun
317 | ```
318 | 
319 | Convert images to the tfrecords format:
320 | Options for converting are hard-coded, so be sure to modify them before running `convert.py`. In particular, the LSUN dataset is provided in LMDB format (see `export_images` in `convert.py`).
321 | 
322 | ```
323 | $ python convert.py
324 | ```
325 | 
326 | Train:
327 | If you want to change the settings of a model, you must modify the code directly.
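A typical invocation (illustrative; as in `eval.py`, the model name is upper-cased and the dataset name lower-cased internally):

```
$ python train.py --model began --dataset celeba
```

The full list of options: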
328 | 329 | ``` 330 | $ python train.py --help 331 | usage: train.py [-h] [--num_epochs NUM_EPOCHS] [--batch_size BATCH_SIZE] 332 | [--num_threads NUM_THREADS] --model MODEL [--name NAME] 333 | --dataset DATASET [--ckpt_step CKPT_STEP] [--renew] 334 | 335 | optional arguments: 336 | -h, --help show this help message and exit 337 | --num_epochs NUM_EPOCHS 338 | default: 20 339 | --batch_size BATCH_SIZE 340 | default: 128 341 | --num_threads NUM_THREADS 342 | # of data read threads (default: 4) 343 | --model MODEL DCGAN / LSGAN / WGAN / WGAN-GP / EBGAN / BEGAN / 344 | DRAGAN / CoulombGAN 345 | --name NAME default: name=model 346 | --dataset DATASET, -D DATASET 347 | CelebA / LSUN 348 | --ckpt_step CKPT_STEP 349 | # of steps for saving checkpoint (default: 5000) 350 | --renew train model from scratch - clean saved checkpoints and 351 | summaries 352 | ``` 353 | 354 | Monitor through TensorBoard: 355 | 356 | ``` 357 | $ tensorboard --logdir=summary/dataset/name 358 | ``` 359 | 360 | Evaluate (generate fake samples): 361 | 362 | ``` 363 | $ python eval.py --help 364 | usage: eval.py [-h] --model MODEL [--name NAME] --dataset DATASET 365 | [--sample_size SAMPLE_SIZE] 366 | 367 | optional arguments: 368 | -h, --help show this help message and exit 369 | --model MODEL DCGAN / LSGAN / WGAN / WGAN-GP / EBGAN / BEGAN / 370 | DRAGAN / CoulombGAN 371 | --name NAME default: name=model 372 | --dataset DATASET, -D DATASET 373 | CelebA / LSUN 374 | --sample_size SAMPLE_SIZE, -N SAMPLE_SIZE 375 | # of samples. It should be a square number. (default: 376 | 16) 377 | ``` 378 | 379 | 380 | ### Requirements 381 | 382 | - python 2.7 383 | - tensorflow >= 1.2 (verified on 1.2 and 1.3) 384 | - tqdm 385 | - (optional) pynvml - for automatic gpu selection 386 | 387 | ## Similar works 388 | 389 | - https://ajolicoeur.wordpress.com/cats/ 390 | - [wiseodd/generative-models](https://github.com/wiseodd/generative-models) 391 | - [hwalsuklee/tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections) 392 | - [sanghoon/tf-exercise-gan](https://github.com/sanghoon/tf-exercise-gan) 393 | - [YadiraF/GAN_Theories](https://github.com/YadiraF/GAN_Theories) 394 | 395 | 396 | -------------------------------------------------------------------------------- /assets/celeba/began.30k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/began.30k.png -------------------------------------------------------------------------------- /assets/celeba/began.50k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/began.50k.png -------------------------------------------------------------------------------- /assets/celeba/began.75k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/began.75k.png -------------------------------------------------------------------------------- /assets/celeba/began.M.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/began.M.png 
-------------------------------------------------------------------------------- /assets/celeba/coulombgan.200k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/coulombgan.200k.png -------------------------------------------------------------------------------- /assets/celeba/dcgan.G1e-3.30k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/dcgan.G1e-3.30k.png -------------------------------------------------------------------------------- /assets/celeba/dcgan.G2e-4.30k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/dcgan.G2e-4.30k.png -------------------------------------------------------------------------------- /assets/celeba/dcgan.G2e-4.40k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/dcgan.G2e-4.40k.png -------------------------------------------------------------------------------- /assets/celeba/dcgan.G2e-4.50k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/dcgan.G2e-4.50k.png -------------------------------------------------------------------------------- /assets/celeba/dragan.fixed.120k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/dragan.fixed.120k.png -------------------------------------------------------------------------------- /assets/celeba/ebgan.nopt.30k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/ebgan.nopt.30k.png -------------------------------------------------------------------------------- /assets/celeba/ebgan.nopt.graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/ebgan.nopt.graph.png -------------------------------------------------------------------------------- /assets/celeba/ebgan.pt.30k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/ebgan.pt.30k.png -------------------------------------------------------------------------------- /assets/celeba/ebgan.pt.graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/ebgan.pt.graph.png -------------------------------------------------------------------------------- /assets/celeba/lsgan.100.30k.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/lsgan.100.30k.png -------------------------------------------------------------------------------- /assets/celeba/lsgan.1024.30k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/lsgan.1024.30k.png -------------------------------------------------------------------------------- /assets/celeba/lsgan.1024.40k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/lsgan.1024.40k.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.dcgan.10k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.dcgan.10k.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.dcgan.20k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.dcgan.20k.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.dcgan.30k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.dcgan.30k.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.dcgan.gp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.dcgan.gp.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.dcgan.w_dist.expand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.dcgan.w_dist.expand.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.dcgan.w_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.dcgan.w_dist.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.good.10k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.good.10k.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.good.15k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.good.15k.png 
-------------------------------------------------------------------------------- /assets/celeba/wgan-gp.good.20k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.good.20k.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.good.25k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.good.25k.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.good.30k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.good.30k.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.good.40k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.good.40k.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.good.50k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.good.50k.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.good.5k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.good.5k.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.good.7k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.good.7k.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.good.gp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.good.gp.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.good.w_dist.expand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.good.w_dist.expand.png -------------------------------------------------------------------------------- /assets/celeba/wgan-gp.good.w_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan-gp.good.w_dist.png -------------------------------------------------------------------------------- /assets/celeba/wgan.30k.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan.30k.png -------------------------------------------------------------------------------- /assets/celeba/wgan.40k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan.40k.png -------------------------------------------------------------------------------- /assets/celeba/wgan.50k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan.50k.png -------------------------------------------------------------------------------- /assets/celeba/wgan.w_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/celeba/wgan.w_dist.png -------------------------------------------------------------------------------- /assets/lsun/began.100k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/lsun/began.100k.png -------------------------------------------------------------------------------- /assets/lsun/began.150k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/lsun/began.150k.png -------------------------------------------------------------------------------- /assets/lsun/began.200k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/lsun/began.200k.png -------------------------------------------------------------------------------- /assets/lsun/began.250k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/lsun/began.250k.png -------------------------------------------------------------------------------- /assets/lsun/dcgan.100k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/lsun/dcgan.100k.png -------------------------------------------------------------------------------- /assets/lsun/dragan.200k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/lsun/dragan.200k.png -------------------------------------------------------------------------------- /assets/lsun/ebgan.80k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/lsun/ebgan.80k.png -------------------------------------------------------------------------------- /assets/lsun/lsgan.150k.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/lsun/lsgan.150k.png
--------------------------------------------------------------------------------
/assets/lsun/wgan-gp.150k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/lsun/wgan-gp.150k.png
--------------------------------------------------------------------------------
/assets/lsun/wgan.230k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/khanrc/tf.gans-comparison/4e79266b11a9a051499fa9befc8e6a2d2f7b5bf5/assets/lsun/wgan.230k.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from models import *
2 | 
3 | 
4 | model_zoo = ['DCGAN', 'LSGAN', 'WGAN', 'WGAN-GP', 'EBGAN', 'BEGAN', 'DRAGAN', 'CoulombGAN']
5 | 
6 | def get_model(mtype, name, training):
7 |     model = None
8 |     if mtype == 'DCGAN':
9 |         model = dcgan.DCGAN
10 |     elif mtype == 'LSGAN':
11 |         model = lsgan.LSGAN
12 |     elif mtype == 'WGAN':
13 |         model = wgan.WGAN
14 |     elif mtype == 'WGAN-GP':
15 |         model = wgan_gp.WGAN_GP
16 |     elif mtype == 'EBGAN':
17 |         model = ebgan.EBGAN
18 |     elif mtype == 'BEGAN':
19 |         model = began.BEGAN
20 |     elif mtype == 'DRAGAN':
21 |         model = dragan.DRAGAN
22 |     elif mtype == 'COULOMBGAN':
23 |         model = coulombgan.CoulombGAN
24 |     else:
25 |         assert False, mtype + ' is not in the model zoo'
26 | 
27 |     assert model, mtype + ' is a work in progress'
28 | 
29 |     return model(name=name, training=training)
30 | 
31 | 
32 | def get_dataset(dataset_name):
33 |     celebA_64 = './data/celebA_tfrecords/*.tfrecord'
34 |     celebA_128 = './data/celebA_128_tfrecords/*.tfrecord'
35 |     lsun_bedroom_128 = './data/lsun/bedroom_128_tfrecords/*.tfrecord'
36 | 
37 |     if dataset_name == 'celeba':
38 |         path = celebA_128
39 |         n_examples = 202599
40 |     elif dataset_name == 'lsun':
41 |         path = lsun_bedroom_128
42 |         n_examples = 3033042
43 |     else:
44 |         raise ValueError('{} is not supported. dataset must be celeba or lsun.'.format(dataset_name))
45 | 
46 |     return path, n_examples
47 | 
48 | 
49 | def pprint_args(FLAGS):
50 |     print("\nParameters:")
51 |     for attr, value in sorted(vars(FLAGS).items()):
52 |         print("{}={}".format(attr.upper(), value))
53 |     print("")
54 | 
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | 
3 | import tensorflow as tf
4 | import numpy as np
5 | import scipy.misc
6 | import os
7 | import glob
8 | import lmdb  # needed by export_images (LSUN)
9 | 
10 | def _bytes_features(value):
11 |     return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
12 | 
13 | 
14 | def _int64_features(value):
15 |     return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
16 | 
17 | 
18 | # preproc for celebA
19 | def center_crop(im, output_size):
20 |     output_height, output_width = output_size
21 |     h, w = im.shape[:2]
22 |     if h < output_height or w < output_width:  # either side too small to crop
23 |         raise ValueError("image is too small")
24 | 
25 |     offset_h = int((h - output_height) / 2)
26 |     offset_w = int((w - output_width) / 2)
27 |     return im[offset_h:offset_h+output_height, offset_w:offset_w+output_width, :]
28 | 
29 | 
30 | def convert(source_dir, target_dir, crop_size, out_size, exts=[''], num_shards=128, tfrecords_prefix=''):
31 |     if not tf.gfile.Exists(source_dir):
32 |         print('source_dir does not exist')
33 |         return
34 | 
35 |     if tfrecords_prefix and not tfrecords_prefix.endswith('-'):
36 |         tfrecords_prefix += '-'
37 | 
38 |     if tf.gfile.Exists(target_dir):
39 |         print("{} already exists".format(target_dir))
40 |         return
41 |     else:
42 |         tf.gfile.MakeDirs(target_dir)
43 | 
44 |     # get meta-data
45 |     path_list = []
46 |     for ext in exts:
47 |         pattern = '*.' + ext if ext != '' else '*'
48 |         path = os.path.join(source_dir, pattern)
49 |         path_list.extend(glob.glob(path))
50 | 
51 |     # shuffle path_list
52 |     np.random.shuffle(path_list)
53 |     num_files = len(path_list)
54 |     num_per_shard = num_files // num_shards # Last shard will have more files
55 | 
56 |     print('# of files: {}'.format(num_files))
57 |     print('# of shards: {}'.format(num_shards))
58 |     print('# files per shards: {}'.format(num_per_shard))
59 | 
60 |     # convert to tfrecords
61 |     shard_idx = 0
62 |     writer = None
63 |     for i, path in enumerate(path_list):
64 |         if i % num_per_shard == 0 and shard_idx < num_shards:
65 |             shard_idx += 1
66 |             tfrecord_fn = '{}{:0>4d}-of-{:0>4d}.tfrecord'.format(tfrecords_prefix, shard_idx, num_shards)
67 |             tfrecord_path = os.path.join(target_dir, tfrecord_fn)
68 |             print("Writing {} ...".format(tfrecord_path))
69 |             if shard_idx > 1:
70 |                 writer.close()
71 |             writer = tf.python_io.TFRecordWriter(tfrecord_path)
72 | 
73 |         # mode='RGB' reads even a grayscale image in RGB shape
74 |         im = scipy.misc.imread(path, mode='RGB')
75 |         # preproc
76 |         try:
77 |             im = center_crop(im, crop_size)
78 |         except Exception as e:
79 |             # print("im_path: {}".format(path))
80 |             # print("im_shape: {}".format(im.shape))
81 |             print("[Exception] {}".format(e))
82 |             continue
83 | 
84 |         im = scipy.misc.imresize(im, out_size)
85 |         example = tf.train.Example(features=tf.train.Features(feature={
86 |             # "shape": _int64_features(im.shape),
87 |             "image": _bytes_features([im.tostring()])
88 |         }))
89 |         writer.write(example.SerializeToString())
90 | 
91 |     writer.close()
92 | 
93 | 
94 | ''' The function below is borrowed from https://github.com/fyu/lsun.
95 | Process: LMDB => images => tfrecords
96 | It would be more efficient to skip the intermediate images, but that is a little messy.
97 | Going through images is less efficient but convenient.
98 | '''
99 | def export_images(db_path, out_dir, flat=False, limit=-1):
100 |     print('Exporting {} to {}'.format(db_path, out_dir))
101 |     env = lmdb.open(db_path, map_size=1099511627776, max_readers=100, readonly=True)
102 |     num_images = env.stat()['entries']
103 |     count = 0
104 |     with env.begin(write=False) as txn:
105 |         cursor = txn.cursor()
106 |         for key, val in cursor:
107 |             if not flat:
108 |                 image_out_dir = os.path.join(out_dir, '/'.join(key[:6]))
109 |             else:
110 |                 image_out_dir = out_dir
111 |             if not os.path.exists(image_out_dir):
112 |                 os.makedirs(image_out_dir)
113 |             image_out_path = os.path.join(image_out_dir, key + '.webp')
114 |             with open(image_out_path, 'w') as fp:
115 |                 fp.write(val)
116 |             count += 1
117 |             if count == limit:
118 |                 break
119 |             if count % 10000 == 0:
120 |                 print('{}/{} ...'.format(count, num_images))
121 | 
122 | 
123 | if __name__ == "__main__":
124 |     # CelebA
125 |     convert('./data/celebA', './data/celebA_128_tfrecords', crop_size=[128, 128], out_size=[128, 128],
126 |             exts=['jpg'], num_shards=128, tfrecords_prefix='celebA')
127 | 
128 |     # LSUN
129 |     # export_images('./tf.gans-comparison/data/lsun/bedroom_val_lmdb/',
130 |     #     './tf.gans-comparison/data/lsun/bedroom_val_images/', flat=True)
131 |     # convert('./data/lsun/bedroom_train_images', './data/lsun/bedroom_128_tfrecords', crop_size=[128, 128],
132 |     #     out_size=[128, 128], exts=['webp'], num_shards=128, tfrecords_prefix='lsun_bedroom')
133 | 
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import gzip
5 | import json
6 | import shutil
7 | import zipfile
8 | import argparse
9 | import requests
10 | import subprocess
11 | from tqdm import tqdm
12 | from six.moves import urllib
13 | 
14 | """
15 | Borrowed from https://github.com/carpedm20/DCGAN-tensorflow/blob/master/download.py
16 | """
17 | 
18 | """
19 | Modification of https://github.com/stanfordnlp/treelstm/blob/master/scripts/download.py
20 | Downloads the following:
21 | - Celeb-A dataset
22 | - LSUN dataset
23 | - MNIST dataset
24 | """
25 | 
26 | 
27 | parser = argparse.ArgumentParser(description='Download dataset for DCGAN.')
28 | parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist'],
29 |                     help='name of dataset to download [celebA, lsun, mnist]')
30 | 
31 | def download(url, dirpath):
32 |     filename = url.split('/')[-1]
33 |     filepath = os.path.join(dirpath, filename)
34 |     u = urllib.request.urlopen(url)
35 |     f = open(filepath, 'wb')
36 |     filesize = int(u.headers["Content-Length"])
37 |     print("Downloading: %s Bytes: %s" % (filename, filesize))
38 | 
39 |     downloaded = 0
40 |     block_sz = 8192
41 |     status_width = 70
42 |     while True:
43 |         buf = u.read(block_sz)
44 |         if not buf:
45 |             print('')
46 |             break
47 |         else:
48 |             print('', end='\r')
49 |         downloaded += len(buf)
50 |         f.write(buf)
51 |         status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") %
52 |             ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100.
/ filesize)) 53 | print(status, end='') 54 | sys.stdout.flush() 55 | f.close() 56 | return filepath 57 | 58 | def download_file_from_google_drive(id, destination): 59 | URL = "https://docs.google.com/uc?export=download" 60 | session = requests.Session() 61 | 62 | response = session.get(URL, params={ 'id': id }, stream=True) 63 | token = get_confirm_token(response) 64 | 65 | if token: 66 | params = { 'id' : id, 'confirm' : token } 67 | response = session.get(URL, params=params, stream=True) 68 | 69 | save_response_content(response, destination) 70 | 71 | def get_confirm_token(response): 72 | for key, value in response.cookies.items(): 73 | if key.startswith('download_warning'): 74 | return value 75 | return None 76 | 77 | def save_response_content(response, destination, chunk_size=32*1024): 78 | total_size = int(response.headers.get('content-length', 0)) 79 | with open(destination, "wb") as f: 80 | for chunk in tqdm(response.iter_content(chunk_size), total=total_size, 81 | unit='B', unit_scale=True, desc=destination): 82 | if chunk: # filter out keep-alive new chunks 83 | f.write(chunk) 84 | 85 | def unzip(filepath): 86 | print("Extracting: " + filepath) 87 | dirpath = os.path.dirname(filepath) 88 | with zipfile.ZipFile(filepath) as zf: 89 | zf.extractall(dirpath) 90 | os.remove(filepath) 91 | 92 | def download_celeb_a(dirpath): 93 | data_dir = 'celebA' 94 | if os.path.exists(os.path.join(dirpath, data_dir)): 95 | print('Found Celeb-A - skip') 96 | return 97 | 98 | filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" 99 | save_path = os.path.join(dirpath, filename) 100 | 101 | if os.path.exists(save_path): 102 | print('[*] {} already exists'.format(save_path)) 103 | else: 104 | download_file_from_google_drive(drive_id, save_path) 105 | 106 | zip_dir = '' 107 | with zipfile.ZipFile(save_path) as zf: 108 | zip_dir = zf.namelist()[0] 109 | zf.extractall(dirpath) 110 | os.remove(save_path) 111 | os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir)) 112 | 113 | def _list_categories(tag): 114 | url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag 115 | f = urllib.request.urlopen(url) 116 | return json.loads(f.read()) 117 | 118 | def _download_lsun(out_dir, category, set_name, tag): 119 | url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \ 120 | '&category={category}&set={set_name}'.format(**locals()) 121 | print(url) 122 | if set_name == 'test': 123 | out_name = 'test_lmdb.zip' 124 | else: 125 | out_name = '{category}_{set_name}_lmdb.zip'.format(**locals()) 126 | out_path = os.path.join(out_dir, out_name) 127 | cmd = ['curl', url, '-o', out_path] 128 | print('Downloading', category, set_name, 'set') 129 | subprocess.call(cmd) 130 | 131 | def download_lsun(dirpath): 132 | data_dir = os.path.join(dirpath, 'lsun') 133 | if os.path.exists(data_dir): 134 | print('Found LSUN - skip') 135 | return 136 | else: 137 | os.mkdir(data_dir) 138 | 139 | tag = 'latest' 140 | #categories = _list_categories(tag) 141 | categories = ['bedroom'] 142 | 143 | for category in categories: 144 | _download_lsun(data_dir, category, 'train', tag) 145 | _download_lsun(data_dir, category, 'val', tag) 146 | _download_lsun(data_dir, '', 'test', tag) 147 | 148 | def download_mnist(dirpath): 149 | data_dir = os.path.join(dirpath, 'mnist') 150 | if os.path.exists(data_dir): 151 | print('Found MNIST - skip') 152 | return 153 | else: 154 | os.mkdir(data_dir) 155 | url_base = 'http://yann.lecun.com/exdb/mnist/' 156 | file_names = ['train-images-idx3-ubyte.gz', 157 | 
'train-labels-idx1-ubyte.gz',
158 |                   't10k-images-idx3-ubyte.gz',
159 |                   't10k-labels-idx1-ubyte.gz']
160 |     for file_name in file_names:
161 |         url = (url_base+file_name).format(**locals())
162 |         print(url)
163 |         out_path = os.path.join(data_dir,file_name)
164 |         cmd = ['curl', url, '-o', out_path]
165 |         print('Downloading ', file_name)
166 |         subprocess.call(cmd)
167 |         cmd = ['gzip', '-d', out_path]
168 |         print('Decompressing ', file_name)
169 |         subprocess.call(cmd)
170 | 
171 | def prepare_data_dir(path = './data'):
172 |     if not os.path.exists(path):
173 |         os.mkdir(path)
174 | 
175 | if __name__ == '__main__':
176 |     args = parser.parse_args()
177 |     prepare_data_dir()
178 | 
179 |     if any(name in args.datasets for name in ['CelebA', 'celebA', 'celeba']):
180 |         download_celeb_a('./data')
181 |     if 'lsun' in args.datasets:
182 |         download_lsun('./data')
183 |     if 'mnist' in args.datasets:
184 |         download_mnist('./data')
185 | 
186 | 
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | #coding: utf-8
2 | import tensorflow as tf
3 | import numpy as np
4 | import utils
5 | import config
6 | import os, glob
7 | import scipy.misc
8 | from argparse import ArgumentParser
9 | slim = tf.contrib.slim
10 | 
11 | 
12 | def build_parser():
13 |     parser = ArgumentParser()
14 |     models_str = ' / '.join(config.model_zoo)
15 |     parser.add_argument('--model', help=models_str, required=True)
16 |     parser.add_argument('--name', help='default: name=model')
17 |     parser.add_argument('--dataset', '-D', help='CelebA / LSUN', required=True)
18 |     parser.add_argument('--sample_size', '-N', help='# of samples. It should be a square number. (default: 16)',
19 |                         default=16, type=int)
20 | 
21 |     return parser
22 | 
23 | 
24 | def sample_z(shape):
25 |     return np.random.normal(size=shape)
26 | 
27 | 
28 | def get_all_checkpoints(ckpt_dir, force=False):
29 |     '''
30 |     When training is interrupted and resumed, not all checkpoints can be fetched with get_checkpoint_state
31 |     (the checkpoint state is rewritten from the point of resuming).
32 |     This function fetches all checkpoints forcibly when force=True.
33 |     '''
34 | 
35 |     if force:
36 |         ckpts = os.listdir(ckpt_dir) # get all fns
37 |         ckpts = map(lambda p: os.path.splitext(p)[0], ckpts) # del ext
38 |         ckpts = set(ckpts) # unique
39 |         ckpts = filter(lambda x: x.split('-')[-1].isdigit(), ckpts) # filter non-ckpt
40 |         ckpts = sorted(ckpts, key=lambda x: int(x.split('-')[-1])) # sort
41 |         ckpts = map(lambda x: os.path.join(ckpt_dir, x), ckpts) # fn => path
42 |     else:
43 |         ckpts = tf.train.get_checkpoint_state(ckpt_dir).all_model_checkpoint_paths
44 | 
45 |     return ckpts
46 | 
47 | 
48 | def eval(model, name, dataset, sample_shape=[4,4], load_all_ckpt=True):
49 |     if name is None:
50 |         name = model.name
51 |     dir_name = os.path.join('eval', dataset, name)
52 |     if tf.gfile.Exists(dir_name):
53 |         tf.gfile.DeleteRecursively(dir_name)
54 |     tf.gfile.MakeDirs(dir_name)
55 | 
56 |     restorer = tf.train.Saver(slim.get_model_variables())
57 | 
58 |     sess_config = tf.ConfigProto() # renamed from `config` to avoid shadowing the config module
59 |     best_gpu = utils.get_best_gpu()
60 |     sess_config.gpu_options.visible_device_list = str(best_gpu)
61 |     with tf.Session(config=sess_config) as sess:
62 |         ckpt_path = os.path.join('checkpoints', dataset, name)
63 |         ckpts = get_all_checkpoints(ckpt_path, force=load_all_ckpt)
64 |         size = sample_shape[0] * sample_shape[1]
65 | 
66 |         z_ = sample_z([size, model.z_dim])
67 | 
68 |         for v in ckpts:
69 |             print("Evaluating {} ...".format(v))
70 |             restorer.restore(sess, v)
71 |             global_step = int(v.split('/')[-1].split('-')[-1])
72 | 
73 |             fake_samples = sess.run(model.fake_sample, {model.z: z_})
74 | 
75 |             # inverse transform: [-1, 1] => [0, 1]
76 |             fake_samples = (fake_samples + 1.) / 2.
77 |             merged_samples = utils.merge(fake_samples, size=sample_shape)
78 |             fn = "{:0>6d}.png".format(global_step)
79 |             scipy.misc.imsave(os.path.join(dir_name, fn), merged_samples)
80 | 
81 | 
82 | '''
83 | You can create a gif movie through imagemagick on the command line:
84 | $ convert -delay 20 eval/* movie.gif
85 | '''
86 | # def to_gif(dir_name='eval'):
87 | #     images = []
88 | #     for path in glob.glob(os.path.join(dir_name, '*.png')):
89 | #         im = scipy.misc.imread(path)
90 | #         images.append(im)
91 | 
92 | #     # make_gif(images, dir_name + '/movie.gif', duration=10, true_image=True)
93 | #     imageio.mimsave('movie.gif', images, duration=0.2)
94 | 
95 | 
96 | if __name__ == "__main__":
97 |     parser = build_parser()
98 |     FLAGS = parser.parse_args()
99 |     FLAGS.model = FLAGS.model.upper()
100 |     FLAGS.dataset = FLAGS.dataset.lower()
101 |     if FLAGS.name is None:
102 |         FLAGS.name = FLAGS.model.lower()
103 |     config.pprint_args(FLAGS)
104 | 
105 |     N = FLAGS.sample_size**0.5
106 |     assert N == int(N), 'sample size should be a square number'
107 | 
108 |     # training=False => build generator only
109 |     model = config.get_model(FLAGS.model, FLAGS.name, training=False)
110 |     eval(model, dataset=FLAGS.dataset, name=FLAGS.name, sample_shape=[int(N),int(N)], load_all_ckpt=True)
--------------------------------------------------------------------------------
/inputpipe.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import tensorflow as tf
3 | 
4 | 
5 | def read_parse_preproc(filename_queue):
6 |     ''' read, parse, and preproc a single example.
''' 7 | with tf.variable_scope('read_parse_preproc'): 8 | reader = tf.TFRecordReader() 9 | key, records = reader.read(filename_queue) 10 | 11 | # parse records 12 | features = tf.parse_single_example( 13 | records, 14 | features={ 15 | "image": tf.FixedLenFeature([], tf.string) 16 | } 17 | ) 18 | 19 | image = tf.decode_raw(features["image"], tf.uint8) 20 | image = tf.reshape(image, [128, 128, 3]) # The image_shape must be explicitly specified 21 | image = tf.image.resize_images(image, [64, 64]) 22 | image = tf.cast(image, tf.float32) 23 | image = image / 127.5 - 1.0 # preproc - normalize 24 | 25 | return [image] 26 | 27 | 28 | # https://www.tensorflow.org/programmers_guide/reading_data 29 | def get_batch(tfrecords_list, batch_size, shuffle=False, num_threads=1, min_after_dequeue=None, num_epochs=None): 30 | name = "batch" if not shuffle else "shuffle_batch" 31 | with tf.variable_scope(name): 32 | filename_queue = tf.train.string_input_producer(tfrecords_list, shuffle=shuffle, num_epochs=num_epochs) 33 | data_point = read_parse_preproc(filename_queue) 34 | 35 | if min_after_dequeue is None: 36 | min_after_dequeue = batch_size * 10 37 | capacity = min_after_dequeue + 3*batch_size 38 | if shuffle: 39 | batch = tf.train.shuffle_batch(data_point, batch_size=batch_size, capacity=capacity, 40 | min_after_dequeue=min_after_dequeue, num_threads=num_threads, allow_smaller_final_batch=True) 41 | else: 42 | batch = tf.train.batch(data_point, batch_size, capacity=capacity, num_threads=num_threads, 43 | allow_smaller_final_batch=True) 44 | 45 | return batch 46 | 47 | 48 | def get_batch_join(tfrecords_list, batch_size, shuffle=False, num_threads=1, min_after_dequeue=None, num_epochs=None): 49 | name = "batch_join" if not shuffle else "shuffle_batch_join" 50 | with tf.variable_scope(name): 51 | filename_queue = tf.train.string_input_producer(tfrecords_list, shuffle=shuffle, num_epochs=num_epochs) 52 | example_list = [read_parse_preproc(filename_queue) for _ in range(num_threads)] 53 | 54 | if min_after_dequeue is None: 55 | min_after_dequeue = batch_size * 10 56 | capacity = min_after_dequeue + 3*batch_size 57 | if shuffle: 58 | batch = tf.train.shuffle_batch_join(tensors_list=example_list, batch_size=batch_size, capacity=capacity, 59 | min_after_dequeue=min_after_dequeue, allow_smaller_final_batch=True) 60 | else: 61 | batch = tf.train.batch_join(example_list, batch_size, capacity=capacity, allow_smaller_final_batch=True) 62 | 63 | return batch 64 | 65 | 66 | # interfaces 67 | def shuffle_batch_join(tfrecords_list, batch_size, num_threads, num_epochs, min_after_dequeue=None): 68 | return get_batch_join(tfrecords_list, batch_size, shuffle=True, num_threads=num_threads, 69 | num_epochs=num_epochs, min_after_dequeue=min_after_dequeue) 70 | 71 | def batch_join(tfrecords_list, batch_size, num_threads, num_epochs, min_after_dequeue=None): 72 | return get_batch_join(tfrecords_list, batch_size, shuffle=False, num_threads=num_threads, 73 | num_epochs=num_epochs, min_after_dequeue=min_after_dequeue) 74 | 75 | def shuffle_batch(tfrecords_list, batch_size, num_threads, num_epochs, min_after_dequeue=None): 76 | return get_batch(tfrecords_list, batch_size, shuffle=True, num_threads=num_threads, 77 | num_epochs=num_epochs, min_after_dequeue=min_after_dequeue) 78 | 79 | def batch(tfrecords_list, batch_size, num_threads, num_epochs, min_after_dequeue=None): 80 | return get_batch(tfrecords_list, batch_size, shuffle=False, num_threads=num_threads, 81 | num_epochs=num_epochs, min_after_dequeue=min_after_dequeue) 82 | 
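# Example usage (an illustrative note, not called anywhere in this file;
# the tfrecord path below is the CelebA default from config.py):
#
#   import glob
#   batch = shuffle_batch_join(glob.glob('./data/celebA_128_tfrecords/*.tfrecord'),
#                              batch_size=128, num_threads=4, num_epochs=20)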
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile
2 | import glob
3 | 
4 | 
5 | def get_all_modules_cwd():
6 |     modules = glob.glob(dirname(__file__)+"/*.py")
7 |     return [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
8 | 
9 | 
10 | __all__ = get_all_modules_cwd()
--------------------------------------------------------------------------------
/models/basemodel.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | 
3 | '''BaseModel for Generative Adversarial Networks.
4 | '''
5 | 
6 | import tensorflow as tf
7 | slim = tf.contrib.slim
8 | 
9 | 
10 | class BaseModel(object):
11 |     FAKE_MAX_OUTPUT = 6
12 | 
13 |     def __init__(self, name, training, D_lr, G_lr, image_shape=[64, 64, 3], z_dim=100):
14 |         self.name = name
15 |         self.shape = image_shape
16 |         self.bn_params = {
17 |             "decay": 0.99,
18 |             "epsilon": 1e-5,
19 |             "scale": True,
20 |             "is_training": training
21 |         }
22 |         self.z_dim = z_dim
23 |         self.D_lr = D_lr
24 |         self.G_lr = G_lr
25 |         self.args = vars(self).copy() # dict
26 | 
27 |         if training:
28 |             self._build_train_graph()
29 |         else:
30 |             self._build_gen_graph()
31 | 
32 | 
33 |     def _build_gen_graph(self):
34 |         '''build computational graph for generation (evaluation)'''
35 |         with tf.variable_scope(self.name):
36 |             self.z = tf.placeholder(tf.float32, [None, self.z_dim])
37 |             self.fake_sample = tf.clip_by_value(self._generator(self.z), -1., 1.)
38 | 
39 | 
40 |     def _build_train_graph(self):
41 |         '''build computational graph for training (implemented by each subclass)'''
42 |         pass
--------------------------------------------------------------------------------
/models/began.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import tensorflow as tf
3 | slim = tf.contrib.slim
4 | from utils import expected_shape
5 | import ops
6 | from basemodel import BaseModel
7 | 
8 | 
9 | class BEGAN(BaseModel):
10 |     def __init__(self, name, training, D_lr=1e-4, G_lr=1e-4, image_shape=[64, 64, 3], z_dim=64, gamma=0.5):
11 |         self.gamma = gamma
12 |         self.decay_step = 3000
13 |         self.decay_rate = 0.95
14 |         self.beta1 = 0.5
15 |         self.lambd_k = 0.001
16 |         self.nf = 128
17 |         self.lr_lower_bound = 2e-5
18 |         super(BEGAN, self).__init__(name=name, training=training, D_lr=D_lr, G_lr=G_lr,
19 |             image_shape=image_shape, z_dim=z_dim)
20 | 
21 |     def _build_train_graph(self):
22 |         with tf.variable_scope(self.name):
23 |             X = tf.placeholder(tf.float32, [None] + self.shape)
24 |             z = tf.placeholder(tf.float32, [None, self.z_dim])
25 |             global_step = tf.Variable(0, name='global_step', trainable=False)
26 | 
27 |             G = self._generator(z)
28 |             # The discriminator is not called an energy function in the BEGAN paper; the naming follows EBGAN.
29 | D_real_energy = self._discriminator(X) 30 | D_fake_energy = self._discriminator(G, reuse=True) 31 | 32 | k = tf.Variable(0., name='k', trainable=False) 33 | with tf.variable_scope('D_loss'): 34 | D_loss = D_real_energy - k * D_fake_energy 35 | with tf.variable_scope('G_loss'): 36 | G_loss = D_fake_energy 37 | with tf.variable_scope('balance'): 38 | balance = self.gamma*D_real_energy - D_fake_energy 39 | with tf.variable_scope('M'): 40 | M = D_real_energy + tf.abs(balance) 41 | 42 | D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/D/') 43 | G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/G/') 44 | 45 | D_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/D/') 46 | G_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/G/') 47 | 48 | # The authors suggest decaying the learning rate by a factor of 0.5 when the convergence measure stalls 49 | # carpedm20 decays by 0.5 per 100000 steps 50 | # Heumi decays by 0.95 per 2000 steps (https://github.com/Heumi/BEGAN-tensorflow/) 51 | D_lr = tf.train.exponential_decay(self.D_lr, global_step, self.decay_step, self.decay_rate, staircase=True) 52 | D_lr = tf.maximum(D_lr, self.lr_lower_bound) 53 | G_lr = tf.train.exponential_decay(self.G_lr, global_step, self.decay_step, self.decay_rate, staircase=True) 54 | G_lr = tf.maximum(G_lr, self.lr_lower_bound) 55 | 56 | with tf.variable_scope('D_train_op'): 57 | with tf.control_dependencies(D_update_ops): 58 | D_train_op = tf.train.AdamOptimizer(learning_rate=D_lr, beta1=self.beta1).\ 59 | minimize(D_loss, var_list=D_vars) 60 | with tf.variable_scope('G_train_op'): 61 | with tf.control_dependencies(G_update_ops): 62 | G_train_op = tf.train.AdamOptimizer(learning_rate=G_lr, beta1=self.beta1).\ 63 | minimize(G_loss, var_list=G_vars, global_step=global_step) 64 | 65 | # Ops must be defined inside the control_dependencies block for the dependency to take effect 66 | with tf.control_dependencies([D_train_op]): # should be iterable 67 | with tf.variable_scope('update_k'): 68 | update_k = tf.assign(k, tf.clip_by_value(k + self.lambd_k * balance, 0., 1.)) # defined here 69 | D_train_op = update_k # run op 70 | 71 | # summaries 72 | # per-step summary 73 | self.summary_op = tf.summary.merge([ 74 | tf.summary.scalar('G_loss', G_loss), 75 | tf.summary.scalar('D_loss', D_loss), 76 | tf.summary.scalar('D_energy/real', D_real_energy), 77 | tf.summary.scalar('D_energy/fake', D_fake_energy), 78 | tf.summary.scalar('convergence_measure', M), 79 | tf.summary.scalar('balance', balance), 80 | tf.summary.scalar('k', k), 81 | tf.summary.scalar('D_lr', D_lr), 82 | tf.summary.scalar('G_lr', G_lr) 83 | ]) 84 | 85 | # sparse-step summary 86 | # The generator of BEGAN does not use a tanh output activation, 87 | # so the generated (fake) samples can exceed the image bound [-1, 1]. 88 | fake_sample = tf.clip_by_value(G, -1., 1.)
89 | tf.summary.image('fake_sample', fake_sample, max_outputs=self.FAKE_MAX_OUTPUT) 90 | tf.summary.histogram('G_hist', G) # for checking out of bound 91 | # histogram of all variables 92 | # for var in tf.trainable_variables(): 93 | # tf.summary.histogram(var.op.name, var) 94 | 95 | self.all_summary_op = tf.summary.merge_all() 96 | 97 | # accessible points 98 | self.X = X 99 | self.z = z 100 | self.D_train_op = D_train_op 101 | self.G_train_op = G_train_op 102 | self.fake_sample = fake_sample 103 | self.global_step = global_step 104 | 105 | def _encoder(self, X, reuse=False): 106 | with tf.variable_scope('encoder', reuse=reuse): 107 | nf = self.nf 108 | nh = self.z_dim 109 | 110 | with slim.arg_scope([slim.conv2d], kernel_size=[3,3], padding='SAME', activation_fn=tf.nn.elu): 111 | net = slim.conv2d(X, nf) 112 | 113 | net = slim.conv2d(net, nf) 114 | net = slim.conv2d(net, nf) 115 | net = slim.conv2d(net, nf*2, stride=2) # 32x32 116 | 117 | net = slim.conv2d(net, nf*2) 118 | net = slim.conv2d(net, nf*2) 119 | net = slim.conv2d(net, nf*3, stride=2) # 16x16 120 | 121 | net = slim.conv2d(net, nf*3) 122 | net = slim.conv2d(net, nf*3) 123 | net = slim.conv2d(net, nf*4, stride=2) # 8x8 124 | 125 | net = slim.conv2d(net, nf*4) 126 | net = slim.conv2d(net, nf*4) 127 | net = slim.conv2d(net, nf*4) 128 | 129 | net = slim.flatten(net) 130 | h = slim.fully_connected(net, nh, activation_fn=None) 131 | 132 | return h 133 | 134 | def _decoder(self, h, reuse=False): 135 | with tf.variable_scope('decoder', reuse=reuse): 136 | nf = self.nf 137 | nh = self.z_dim 138 | 139 | h0 = slim.fully_connected(h, 8*8*nf, activation_fn=None) # h0 140 | net = tf.reshape(h0, [-1, 8, 8, nf]) 141 | 142 | with slim.arg_scope([slim.conv2d], kernel_size=[3,3], padding='SAME', activation_fn=tf.nn.elu): 143 | net = slim.conv2d(net, nf) 144 | net = slim.conv2d(net, nf) 145 | net = tf.image.resize_nearest_neighbor(net, [16, 16]) # upsampling 146 | 147 | net = slim.conv2d(net, nf) 148 | net = slim.conv2d(net, nf) 149 | net = tf.image.resize_nearest_neighbor(net, [32, 32]) 150 | 151 | net = slim.conv2d(net, nf) 152 | net = slim.conv2d(net, nf) 153 | net = tf.image.resize_nearest_neighbor(net, [64, 64]) 154 | 155 | net = slim.conv2d(net, nf) 156 | net = slim.conv2d(net, nf) 157 | 158 | net = slim.conv2d(net, 3, activation_fn=None) 159 | 160 | return net 161 | 162 | def _discriminator(self, X, reuse=False): 163 | with tf.variable_scope('D', reuse=reuse): 164 | h = self._encoder(X, reuse=reuse) 165 | x_recon = self._decoder(h, reuse=reuse) 166 | 167 | energy = tf.abs(X-x_recon) # L1 loss 168 | energy = tf.reduce_mean(energy) 169 | 170 | return energy 171 | 172 | def _generator(self, z, reuse=False): 173 | with tf.variable_scope('G', reuse=reuse): 174 | x_fake = self._decoder(z, reuse=reuse) 175 | 176 | return x_fake 177 | -------------------------------------------------------------------------------- /models/coulombgan.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Reference code: https://github.com/bioinf-jku/coulomb_gan 3 | import tensorflow as tf 4 | import numpy as np 5 | slim = tf.contrib.slim 6 | from utils import expected_shape 7 | import ops 8 | from basemodel import BaseModel 9 | 10 | 11 | def sd_matrix(a, b, name='square_distance_matrix'): 12 | with tf.variable_scope(name): 13 | '''Square distance matrix 14 | a, b: [N, tensor] (N = batch size) 15 | return: [N, N] (square distance matrix for every pair of tensors) 16 | ''' 17 | batch_size = tf.shape(a)[0] 18 | a = 
tf.reshape(a, [batch_size, 1, -1]) 19 | b = tf.reshape(b, [1, batch_size, -1]) 20 | return tf.reduce_sum((b-a)**2, axis=2) 21 | 22 | 23 | def plummer_kernel(a, b, kernel_dim, kernel_eps, name='plummer_kernel'): 24 | # The Plummer kernel represents `influence`. 25 | with tf.variable_scope(name): 26 | r = sd_matrix(a, b) + kernel_eps**2 27 | d = kernel_dim-2 28 | return r**(-d/2.) 29 | 30 | 31 | # Borrowed from the reference code and modified to match the paper's notation. 32 | def get_potentials(x, y, kernel_dim, kernel_eps): 33 | ''' 34 | This is almost the same as `calculate_potential`, but 35 | px, py = get_potentials(x, y) 36 | is faster than: 37 | px = calculate_potential(x, y, x) 38 | py = calculate_potential(x, y, y) 39 | because we calculate the cross terms only once. 40 | ''' 41 | x_fixed = tf.stop_gradient(x) 42 | y_fixed = tf.stop_gradient(y) 43 | pk_xx = plummer_kernel(x_fixed, x, kernel_dim, kernel_eps) 44 | pk_yx = plummer_kernel(y, x, kernel_dim, kernel_eps) 45 | pk_yy = plummer_kernel(y_fixed, y, kernel_dim, kernel_eps) 46 | batch_size = tf.shape(x)[0] 47 | pk_xx = tf.matrix_set_diag(pk_xx, tf.ones(shape=[batch_size], dtype=pk_xx.dtype)) 48 | pk_yy = tf.matrix_set_diag(pk_yy, tf.ones(shape=[batch_size], dtype=pk_yy.dtype)) 49 | kxx = tf.reduce_mean(pk_xx, axis=0) 50 | kyx = tf.reduce_mean(pk_yx, axis=0) 51 | kxy = tf.reduce_mean(pk_yx, axis=1) 52 | kyy = tf.reduce_mean(pk_yy, axis=0) 53 | pot_x = kyx - kxx 54 | pot_y = kyy - kyx 55 | pot_x = tf.reshape(pot_x, [batch_size, -1]) 56 | pot_y = tf.reshape(pot_y, [batch_size, -1]) 57 | return pot_x, pot_y 58 | 59 | 60 | def calc_potential(x, y, a, kernel_dim, kernel_eps, name='potential'): 61 | '''Paper notation is used in this function 62 | x: fake 63 | y: real 64 | 65 | return: potential of a 66 | ''' 67 | 68 | with tf.variable_scope(name): 69 | # Why does stop_gradient not apply to a? 70 | x = tf.stop_gradient(x) 71 | y = tf.stop_gradient(y) 72 | kxa = tf.reduce_mean(plummer_kernel(x, a, kernel_dim, kernel_eps), axis=0) 73 | kya = tf.reduce_mean(plummer_kernel(y, a, kernel_dim, kernel_eps), axis=0) 74 | # kxa: influence of fake on a 75 | # kya: influence of real on a 76 | p = kya - kxa 77 | p = tf.reshape(p, [-1, 1]) 78 | return p 79 | 80 | 81 | ''' 82 | Originally, D_lr=5e-5 and G_lr=1e-4 in the paper. 83 | It takes too long to train, so I used higher learning rates (5 times each). 84 | ''' 85 | class CoulombGAN(BaseModel): 86 | def __init__(self, name, training, D_lr=25e-5, G_lr=5e-4, image_shape=[64, 64, 3], z_dim=32): 87 | self.beta1 = 0.5 88 | self.kernel_dim = 3 89 | self.kernel_eps = 1. 90 | super(CoulombGAN, self).__init__(name=name, training=training, D_lr=D_lr, G_lr=G_lr, 91 | image_shape=image_shape, z_dim=z_dim) 92 | 93 | def _build_train_graph(self): 94 | with tf.variable_scope(self.name): 95 | X = tf.placeholder(tf.float32, [None] + self.shape) 96 | z = tf.placeholder(tf.float32, [None, self.z_dim]) 97 | global_step = tf.Variable(0, name='global_step', trainable=False) 98 | 99 | G = self._generator(z) 100 | D_real = self._discriminator(X) 101 | D_fake = self._discriminator(G, reuse=True) 102 | 103 | ''' 104 | D estimates the potential and G maximizes D_fake (the estimated potential of fake samples). 105 | That means minimizing the distance between real and fake samples 106 | while maximizing the distance between fake samples. 107 | 108 | P(a) = k(a, real) - k(a, fake). 109 | So, 110 | P(real) = k(real, real) - k(real, fake), 111 | P(fake) = k(fake, real) - k(fake, fake). 
112 | ''' 113 | 114 | # The get_potentials function is more efficient, but it is more readable and intuitive 115 | # to calculate the potential for the real and fake samples separately. 116 | # Furthermore, experiments showed no significant difference in efficiency. 117 | P_real = calc_potential(G, X, X, kernel_dim=self.kernel_dim, kernel_eps=self.kernel_eps, name='P_real') 118 | P_fake = calc_potential(G, X, G, kernel_dim=self.kernel_dim, kernel_eps=self.kernel_eps, name='P_fake') 119 | D_loss_real = tf.losses.mean_squared_error(D_real, P_real) 120 | D_loss_fake = tf.losses.mean_squared_error(D_fake, P_fake) 121 | D_loss = D_loss_real + D_loss_fake 122 | G_loss = -tf.reduce_mean(D_fake) 123 | 124 | D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/D/') 125 | G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/G/') 126 | 127 | D_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/D/') 128 | G_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/G/') 129 | 130 | with tf.control_dependencies(D_update_ops): 131 | D_train_op = tf.train.AdamOptimizer(learning_rate=self.D_lr, beta1=self.beta1).\ 132 | minimize(D_loss, var_list=D_vars) 133 | with tf.control_dependencies(G_update_ops): 134 | G_train_op = tf.train.AdamOptimizer(learning_rate=self.G_lr, beta1=self.beta1).\ 135 | minimize(G_loss, var_list=G_vars, global_step=global_step) 136 | 137 | # summaries 138 | # per-step summary 139 | self.summary_op = tf.summary.merge([ 140 | tf.summary.scalar('G_loss', G_loss), 141 | tf.summary.scalar('D_loss', D_loss), 142 | tf.summary.scalar('potential/real_mean', tf.reduce_mean(P_real)), 143 | tf.summary.scalar('potential/fake_mean', tf.reduce_mean(P_fake)) 144 | # tf.summary.scalar('potential/real', P_real), 145 | # tf.summary.scalar('potential/fake', P_fake), 146 | # tf.summary.scalar('disc/real', D_real), 147 | # tf.summary.scalar('disc/fake', D_fake) 148 | ]) 149 | 150 | # sparse-step summary 151 | tf.summary.image('fake_sample', G, max_outputs=self.FAKE_MAX_OUTPUT) 152 | tf.summary.histogram('potential/real', P_real) 153 | tf.summary.histogram('potential/fake', P_fake) 154 | self.all_summary_op = tf.summary.merge_all() 155 | 156 | # accessible points 157 | self.X = X 158 | self.z = z 159 | self.D_train_op = D_train_op 160 | self.G_train_op = G_train_op 161 | self.fake_sample = G 162 | self.global_step = global_step 163 | 164 | # The discriminator of CoulombGAN uses twice as many channels as DCGAN's 165 | def _discriminator(self, X, reuse=False): 166 | with tf.variable_scope('D', reuse=reuse): 167 | net = X 168 | 169 | with slim.arg_scope([slim.conv2d], kernel_size=[5,5], stride=2, padding='SAME', activation_fn=ops.lrelu, 170 | normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params): 171 | net = slim.conv2d(net, 128, normalizer_fn=None) 172 | net = slim.conv2d(net, 256) 173 | net = slim.conv2d(net, 512) 174 | net = slim.conv2d(net, 1024) 175 | expected_shape(net, [4, 4, 1024]) 176 | 177 | net = slim.flatten(net) 178 | logits = slim.fully_connected(net, 1, activation_fn=None) 179 | 180 | return logits # potential 181 | 182 | def _generator(self, z, reuse=False): 183 | with tf.variable_scope('G', reuse=reuse): 184 | net = z 185 | net = slim.fully_connected(net, 4*4*1024, activation_fn=tf.nn.relu) 186 | net = tf.reshape(net, [-1, 4, 4, 1024]) 187 | 188 | with slim.arg_scope([slim.conv2d_transpose], kernel_size=[5,5], stride=2, padding='SAME', 189 | activation_fn=tf.nn.relu, 
normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params): 190 | net = slim.conv2d_transpose(net, 512) 191 | expected_shape(net, [8, 8, 512]) 192 | net = slim.conv2d_transpose(net, 256) 193 | expected_shape(net, [16, 16, 256]) 194 | net = slim.conv2d_transpose(net, 128) 195 | expected_shape(net, [32, 32, 128]) 196 | net = slim.conv2d_transpose(net, 3, activation_fn=tf.nn.tanh, normalizer_fn=None) 197 | expected_shape(net, [64, 64, 3]) 198 | 199 | return net 200 | -------------------------------------------------------------------------------- /models/dcgan.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import tensorflow as tf 3 | slim = tf.contrib.slim 4 | from utils import expected_shape 5 | import ops 6 | from basemodel import BaseModel 7 | 8 | '''Original hyperparams: 9 | optimizer - Adam (lr 2e-4, beta1 0.5) 10 | init - stddev 0.02 11 | ''' 12 | 13 | class DCGAN(BaseModel): 14 | def __init__(self, name, training, D_lr=2e-4, G_lr=2e-4, image_shape=[64, 64, 3], z_dim=100): 15 | self.beta1 = 0.5 16 | super(DCGAN, self).__init__(name=name, training=training, D_lr=D_lr, G_lr=G_lr, 17 | image_shape=image_shape, z_dim=z_dim) 18 | 19 | def _build_train_graph(self): 20 | with tf.variable_scope(self.name): 21 | X = tf.placeholder(tf.float32, [None] + self.shape) 22 | z = tf.placeholder(tf.float32, [None, self.z_dim]) 23 | global_step = tf.Variable(0, name='global_step', trainable=False) 24 | 25 | G = self._generator(z) 26 | D_real_prob, D_real_logits = self._discriminator(X) 27 | D_fake_prob, D_fake_logits = self._discriminator(G, reuse=True) 28 | 29 | G_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(D_fake_logits), logits=D_fake_logits) 30 | D_loss_real = tf.losses.sigmoid_cross_entropy(tf.ones_like(D_real_logits), logits=D_real_logits) 31 | D_loss_fake = tf.losses.sigmoid_cross_entropy(tf.zeros_like(D_fake_logits), logits=D_fake_logits) 32 | D_loss = D_loss_real + D_loss_fake 33 | 34 | D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/D/') 35 | G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/G/') 36 | 37 | D_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/D/') 38 | G_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/G/') 39 | 40 | with tf.control_dependencies(D_update_ops): 41 | D_train_op = tf.train.AdamOptimizer(learning_rate=self.D_lr, beta1=self.beta1).\ 42 | minimize(D_loss, var_list=D_vars) 43 | with tf.control_dependencies(G_update_ops): 44 | # learning rate 2e-4/1e-3 45 | G_train_op = tf.train.AdamOptimizer(learning_rate=self.G_lr, beta1=self.beta1).\ 46 | minimize(G_loss, var_list=G_vars, global_step=global_step) 47 | 48 | # summaries 49 | # per-step summary 50 | self.summary_op = tf.summary.merge([ 51 | tf.summary.scalar('G_loss', G_loss), 52 | tf.summary.scalar('D_loss', D_loss), 53 | tf.summary.scalar('D_loss/real', D_loss_real), 54 | tf.summary.scalar('D_loss/fake', D_loss_fake) 55 | ]) 56 | 57 | # sparse-step summary 58 | tf.summary.image('fake_sample', G, max_outputs=self.FAKE_MAX_OUTPUT) 59 | tf.summary.histogram('real_probs', D_real_prob) 60 | tf.summary.histogram('fake_probs', D_fake_prob) 61 | self.all_summary_op = tf.summary.merge_all() 62 | 63 | # accessible points 64 | self.X = X 65 | self.z = z 66 | self.D_train_op = D_train_op 67 | self.G_train_op = G_train_op 68 | self.fake_sample = G 69 | self.global_step = global_step 70 | 71 | def _discriminator(self, X, reuse=False): 72 | with tf.variable_scope('D', 
reuse=reuse): 73 | net = X 74 | 75 | with slim.arg_scope([slim.conv2d], kernel_size=[5,5], stride=2, padding='SAME', activation_fn=ops.lrelu, 76 | normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params): 77 | net = slim.conv2d(net, 64, normalizer_fn=None) 78 | expected_shape(net, [32, 32, 64]) 79 | net = slim.conv2d(net, 128) 80 | expected_shape(net, [16, 16, 128]) 81 | net = slim.conv2d(net, 256) 82 | expected_shape(net, [8, 8, 256]) 83 | net = slim.conv2d(net, 512) 84 | expected_shape(net, [4, 4, 512]) 85 | 86 | net = slim.flatten(net) 87 | logits = slim.fully_connected(net, 1, activation_fn=None) 88 | prob = tf.sigmoid(logits) 89 | 90 | return prob, logits 91 | 92 | def _generator(self, z, reuse=False): 93 | with tf.variable_scope('G', reuse=reuse): 94 | net = z 95 | net = slim.fully_connected(net, 4*4*1024, activation_fn=tf.nn.relu) 96 | net = tf.reshape(net, [-1, 4, 4, 1024]) 97 | 98 | with slim.arg_scope([slim.conv2d_transpose], kernel_size=[5,5], stride=2, padding='SAME', 99 | activation_fn=tf.nn.relu, normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params): 100 | net = slim.conv2d_transpose(net, 512) 101 | expected_shape(net, [8, 8, 512]) 102 | net = slim.conv2d_transpose(net, 256) 103 | expected_shape(net, [16, 16, 256]) 104 | net = slim.conv2d_transpose(net, 128) 105 | expected_shape(net, [32, 32, 128]) 106 | net = slim.conv2d_transpose(net, 3, activation_fn=tf.nn.tanh, normalizer_fn=None) 107 | expected_shape(net, [64, 64, 3]) 108 | 109 | return net 110 | -------------------------------------------------------------------------------- /models/dragan.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import tensorflow as tf 3 | slim = tf.contrib.slim 4 | from utils import expected_shape 5 | import ops 6 | from basemodel import BaseModel 7 | 8 | ''' 9 | DRAGAN has a gradient penalty similar to WGAN-GP's, though with a different motivation. 10 | It is also similar to DCGAN, except for the gradient penalty. 11 | ''' 12 | 13 | class DRAGAN(BaseModel): 14 | def __init__(self, name, training, D_lr=1e-4, G_lr=1e-4, image_shape=[64, 64, 3], z_dim=100): 15 | self.beta1 = 0.5 16 | self.beta2 = 0.9 17 | self.ld = 10. # lambda 18 | self.C = 0.5 19 | super(DRAGAN, self).__init__(name=name, training=training, D_lr=D_lr, G_lr=G_lr, 20 | image_shape=image_shape, z_dim=z_dim) 21 | 22 | def _build_train_graph(self): 23 | with tf.variable_scope(self.name): 24 | X = tf.placeholder(tf.float32, [None] + self.shape) 25 | z = tf.placeholder(tf.float32, [None, self.z_dim]) 26 | global_step = tf.Variable(0, name='global_step', trainable=False) 27 | 28 | G = self._generator(z) 29 | D_real_prob, D_real_logits = self._discriminator(X) 30 | D_fake_prob, D_fake_logits = self._discriminator(G, reuse=True) 31 | 32 | G_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(D_fake_logits), logits=D_fake_logits) 33 | D_loss_real = tf.losses.sigmoid_cross_entropy(tf.ones_like(D_real_logits), logits=D_real_logits) 34 | D_loss_fake = tf.losses.sigmoid_cross_entropy(tf.zeros_like(D_fake_logits), logits=D_fake_logits) 35 | D_loss = D_loss_real + D_loss_fake 36 | 37 | # Gradient Penalty (GP) 38 | # perturbed minibatch: x_noise = x_i + noise_i 39 | # x_hat = alpha*x + (1-alpha)*x_noise = x_i + (1-alpha)*noise_i 40 | 41 | shape = tf.shape(X) 42 | eps = tf.random_uniform(shape=shape, minval=0., maxval=1.)
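# Putting the pieces below together, the penalty being built is:
#   delta = C * sigma_x * eps,                eps ~ U[0, 1] elementwise
#   x_hat = clip(x + alpha * delta, -1, 1),   alpha ~ U[-1, 1] (two-sided, see the note below)
#   GP    = lambda * E[ (||grad_{x_hat} D(x_hat)||_2 - 1)^2 ]
# i.e. the gradient norm of D is regularized only in a local region around the
# data manifold, rather than along real-fake interpolation lines as in WGAN-GP.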
43 | x_mean, x_var = tf.nn.moments(X, axes=[0,1,2,3]) 44 | x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region 45 | noise = self.C*x_std*eps # delta in paper 46 | # The author suggested U[0,1] in the original paper, but admitted on GitHub that it is a bug 47 | # (https://github.com/kodalinaveen3/DRAGAN). It should be two-sided. 48 | alpha = tf.random_uniform(shape=[shape[0], 1, 1, 1], minval=-1., maxval=1.) 49 | xhat = tf.clip_by_value(X + alpha*noise, -1., 1.) # x_hat should be in the space of X 50 | 51 | D_xhat_prob, D_xhat_logits = self._discriminator(xhat, reuse=True) 52 | # Originally, the paper suggested D_xhat_prob instead of D_xhat_logits. 53 | # But D_xhat_prob (D with sigmoid) causes numerical problems (NaN in the gradient). 54 | D_xhat_grad = tf.gradients(D_xhat_logits, xhat)[0] # gradient of D(x_hat) 55 | D_xhat_grad_norm = tf.norm(slim.flatten(D_xhat_grad), axis=1) # per-sample l2 norm (flatten first, as in wgan_gp.py) 56 | # D_xhat_grad_norm = tf.sqrt(tf.reduce_sum(tf.square(D_xhat_grad), axis=[1,2,3])) 57 | GP = self.ld * tf.reduce_mean(tf.square(D_xhat_grad_norm - 1.)) 58 | D_loss += GP 59 | 60 | D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/discriminator/') 61 | G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/generator/') 62 | 63 | # DRAGAN does not use BN, so you don't need to set control dependencies for update ops. 64 | D_train_op = tf.train.AdamOptimizer(learning_rate=self.D_lr, beta1=self.beta1, beta2=self.beta2).\ 65 | minimize(D_loss, var_list=D_vars) 66 | G_train_op = tf.train.AdamOptimizer(learning_rate=self.G_lr, beta1=self.beta1, beta2=self.beta2).\ 67 | minimize(G_loss, var_list=G_vars, global_step=global_step) 68 | 69 | # summaries 70 | # per-step summary 71 | self.summary_op = tf.summary.merge([ 72 | tf.summary.scalar('G_loss', G_loss), 73 | tf.summary.scalar('D_loss', D_loss), 74 | tf.summary.scalar('GP', GP) 75 | ]) 76 | 77 | # sparse-step summary 78 | tf.summary.image('fake_sample', G, max_outputs=self.FAKE_MAX_OUTPUT) 79 | tf.summary.histogram('real_probs', D_real_prob) 80 | tf.summary.histogram('fake_probs', D_fake_prob) 81 | self.all_summary_op = tf.summary.merge_all() 82 | 83 | # accessible points 84 | self.X = X 85 | self.z = z 86 | self.D_train_op = D_train_op 87 | self.G_train_op = G_train_op 88 | self.fake_sample = G 89 | self.global_step = global_step 90 | 91 | # DRAGAN does not use BN 92 | # DCGAN architecture 93 | def _discriminator(self, X, reuse=False): 94 | with tf.variable_scope('discriminator', reuse=reuse): 95 | net = X 96 | 97 | with slim.arg_scope([slim.conv2d], kernel_size=[5,5], stride=2, activation_fn=ops.lrelu): 98 | net = slim.conv2d(net, 64) 99 | expected_shape(net, [32, 32, 64]) 100 | net = slim.conv2d(net, 128) 101 | expected_shape(net, [16, 16, 128]) 102 | net = slim.conv2d(net, 256) 103 | expected_shape(net, [8, 8, 256]) 104 | net = slim.conv2d(net, 512) 105 | expected_shape(net, [4, 4, 512]) 106 | 107 | net = slim.flatten(net) 108 | logits = slim.fully_connected(net, 1, activation_fn=None) 109 | prob = tf.nn.sigmoid(logits) 110 | 111 | return prob, logits 112 | 113 | def _generator(self, z, reuse=False): 114 | with tf.variable_scope('generator', reuse=reuse): 115 | net = z 116 | net = slim.fully_connected(net, 4*4*1024, activation_fn=tf.nn.relu) 117 | net = tf.reshape(net, [-1, 4, 4, 1024]) 118 | 119 | with slim.arg_scope([slim.conv2d_transpose], kernel_size=[5,5], stride=2, activation_fn=tf.nn.relu): 120 | net = slim.conv2d_transpose(net, 512) 121 | expected_shape(net, [8, 8, 512]) 122 | net = 
slim.conv2d_transpose(net, 256) 123 | expected_shape(net, [16, 16, 256]) 124 | net = slim.conv2d_transpose(net, 128) 125 | expected_shape(net, [32, 32, 128]) 126 | net = slim.conv2d_transpose(net, 3, activation_fn=tf.nn.tanh, normalizer_fn=None) 127 | expected_shape(net, [64, 64, 3]) 128 | 129 | return net 130 | -------------------------------------------------------------------------------- /models/ebgan.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import tensorflow as tf 3 | slim = tf.contrib.slim 4 | from utils import expected_shape 5 | import ops 6 | from basemodel import BaseModel 7 | 8 | 9 | class EBGAN(BaseModel): 10 | def __init__(self, name, training, D_lr=1e-3, G_lr=1e-3, image_shape=[64, 64, 3], z_dim=100, 11 | pt_weight=0.1, margin=20.): 12 | ''' The default values of pt_weight and margin are taken from the paper for CelebA. ''' 13 | self.pt_weight = pt_weight 14 | self.m = margin 15 | self.beta1 = 0.5 16 | super(EBGAN, self).__init__(name=name, training=training, D_lr=D_lr, G_lr=G_lr, 17 | image_shape=image_shape, z_dim=z_dim) 18 | 19 | def _build_train_graph(self): 20 | with tf.variable_scope(self.name): 21 | X = tf.placeholder(tf.float32, [None] + self.shape) 22 | z = tf.placeholder(tf.float32, [None, self.z_dim]) 23 | global_step = tf.Variable(0, name='global_step', trainable=False) 24 | 25 | G = self._generator(z) 26 | D_real_latent, D_real_energy = self._discriminator(X) 27 | D_fake_latent, D_fake_energy = self._discriminator(G, reuse=True) 28 | 29 | D_fake_hinge = tf.maximum(0., self.m - D_fake_energy) # hinge loss 30 | D_loss = D_real_energy + D_fake_hinge 31 | G_loss = D_fake_energy 32 | PT = self.pt_regularizer(D_fake_latent) 33 | pt_loss = self.pt_weight * PT 34 | G_loss += pt_loss 35 | 36 | D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/D/') 37 | G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/G/') 38 | 39 | D_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/D/') 40 | G_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/G/') 41 | 42 | with tf.control_dependencies(D_update_ops): 43 | D_train_op = tf.train.AdamOptimizer(learning_rate=self.D_lr, beta1=self.beta1).\ 44 | minimize(D_loss, var_list=D_vars) 45 | with tf.control_dependencies(G_update_ops): 46 | G_train_op = tf.train.AdamOptimizer(learning_rate=self.G_lr, beta1=self.beta1).\ 47 | minimize(G_loss, var_list=G_vars, global_step=global_step) 48 | 49 | # summaries 50 | # per-step summary 51 | self.summary_op = tf.summary.merge([ 52 | tf.summary.scalar('G_loss', G_loss), 53 | tf.summary.scalar('D_loss', D_loss), 54 | tf.summary.scalar('PT', PT), 55 | tf.summary.scalar('pt_loss', pt_loss), 56 | tf.summary.scalar('D_energy/real', D_real_energy), 57 | tf.summary.scalar('D_energy/fake', D_fake_energy), 58 | tf.summary.scalar('D_fake_hinge', D_fake_hinge) 59 | ]) 60 | 61 | # sparse-step summary 62 | tf.summary.image('fake_sample', G, max_outputs=self.FAKE_MAX_OUTPUT) 63 | self.all_summary_op = tf.summary.merge_all() 64 | 65 | # accessible points 66 | self.X = X 67 | self.z = z 68 | self.D_train_op = D_train_op 69 | self.G_train_op = G_train_op 70 | self.fake_sample = G 71 | self.global_step = global_step 72 | 73 | def _discriminator(self, X, reuse=False): 74 | with tf.variable_scope('D', reuse=reuse): 75 | net = X 76 | 77 | with slim.arg_scope([slim.conv2d, slim.conv2d_transpose], kernel_size=[4,4], stride=2, padding='SAME', 78 | 
activation_fn=ops.lrelu, normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params): 79 | # encoder 80 | net = slim.conv2d(net, 64, normalizer_fn=None) # 32x32 81 | net = slim.conv2d(net, 128) # 16x16 82 | net = slim.conv2d(net, 256) # 8x8 83 | latent = net 84 | expected_shape(latent, [8, 8, 256]) 85 | # decoder 86 | net = slim.conv2d_transpose(net, 128) # 16x16 87 | net = slim.conv2d_transpose(net, 64) # 32x32 88 | x_recon = slim.conv2d_transpose(net, 3, activation_fn=None, normalizer_fn=None) 89 | expected_shape(x_recon, [64, 64, 3]) 90 | 91 | energy = tf.sqrt(tf.reduce_sum(tf.square(X-x_recon), axis=[1,2,3])) # l2-norm error 92 | energy = tf.reduce_mean(energy) 93 | 94 | return latent, energy 95 | 96 | def _generator(self, z, reuse=False): 97 | with tf.variable_scope('G', reuse=reuse): 98 | net = z 99 | net = slim.fully_connected(net, 4*4*1024, activation_fn=tf.nn.relu) 100 | net = tf.reshape(net, [-1, 4, 4, 1024]) 101 | 102 | with slim.arg_scope([slim.conv2d_transpose], kernel_size=[4,4], stride=2, padding='SAME', 103 | activation_fn=tf.nn.relu, normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params): 104 | net = slim.conv2d_transpose(net, 512) 105 | expected_shape(net, [8, 8, 512]) 106 | net = slim.conv2d_transpose(net, 256) 107 | expected_shape(net, [16, 16, 256]) 108 | net = slim.conv2d_transpose(net, 128) 109 | expected_shape(net, [32, 32, 128]) 110 | net = slim.conv2d_transpose(net, 3, activation_fn=tf.nn.tanh, normalizer_fn=None) 111 | expected_shape(net, [64, 64, 3]) 112 | 113 | return net 114 | 115 | # lf: latent features 116 | def pt_regularizer(self, lf): 117 | eps = 1e-8 # epsilon for numerical stability 118 | lf = slim.flatten(lf) 119 | # l2_norm = tf.sqrt(tf.reduce_sum(tf.square(lf), axis=1, keep_dims=True)) 120 | l2_norm = tf.norm(lf, axis=1, keep_dims=True) 121 | expected_shape(l2_norm, [1]) 122 | unit_lf = lf / (l2_norm + eps) 123 | cos_sim = tf.square(tf.matmul(unit_lf, unit_lf, transpose_b=True)) # [N, h_dim] x [h_dim, N] = [N, N] 124 | N = tf.cast(tf.shape(lf)[0], tf.float32) # batch_size 125 | pt_loss = (tf.reduce_sum(cos_sim)-N) / (N*(N-1)) 126 | return pt_loss 127 | 128 | -------------------------------------------------------------------------------- /models/lsgan.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import tensorflow as tf 3 | slim = tf.contrib.slim 4 | from utils import expected_shape 5 | import ops 6 | from basemodel import BaseModel 7 | 8 | 9 | class LSGAN(BaseModel): 10 | def __init__(self, name, training, D_lr=1e-3, G_lr=1e-3, image_shape=[64, 64, 3], z_dim=1024, a=0., b=1., c=1.): 11 | ''' 12 | a: fake label 13 | b: real label 14 | c: real label for G (the value that G wants D to assign to its fakes - intuitively the same as the real label b) 15 | 16 | Pearson chi-square divergence: a=-1, b=1, c=0. 17 | Intuitive (real label 1, fake label 0): a=0, b=c=1. 
18 | ''' 19 | self.a = a 20 | self.b = b 21 | self.c = c 22 | self.beta1 = 0.5 23 | super(LSGAN, self).__init__(name=name, training=training, D_lr=D_lr, G_lr=G_lr, 24 | image_shape=image_shape, z_dim=z_dim) 25 | 26 | def _build_train_graph(self): 27 | with tf.variable_scope(self.name): 28 | X = tf.placeholder(tf.float32, [None] + self.shape) 29 | z = tf.placeholder(tf.float32, [None, self.z_dim]) 30 | global_step = tf.Variable(0, name='global_step', trainable=False) 31 | 32 | G = self._generator(z) 33 | D_real = self._discriminator(X) 34 | D_fake = self._discriminator(G, reuse=True) 35 | 36 | D_loss_real = 0.5 * tf.reduce_mean(tf.square(D_real - self.b)) # self.b 37 | D_loss_fake = 0.5 * tf.reduce_mean(tf.square(D_fake - self.a)) # self.a 38 | D_loss = D_loss_real + D_loss_fake 39 | G_loss = 0.5 * tf.reduce_mean(tf.square(D_fake - self.c)) # self.c 40 | 41 | D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/D/') 42 | G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/G/') 43 | 44 | D_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/D/') 45 | G_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/G/') 46 | 47 | with tf.control_dependencies(D_update_ops): 48 | D_train_op = tf.train.AdamOptimizer(learning_rate=self.D_lr, beta1=self.beta1).\ 49 | minimize(D_loss, var_list=D_vars) 50 | 51 | with tf.control_dependencies(G_update_ops): 52 | G_train_op = tf.train.AdamOptimizer(learning_rate=self.G_lr, beta1=self.beta1).\ 53 | minimize(G_loss, var_list=G_vars, global_step=global_step) 54 | 55 | # summaries 56 | # per-step summary 57 | self.summary_op = tf.summary.merge([ 58 | tf.summary.scalar('G/loss', G_loss), 59 | tf.summary.scalar('D/loss', D_loss), 60 | tf.summary.scalar('D/loss/real', D_loss_real), 61 | tf.summary.scalar('D/loss/fake', D_loss_fake) 62 | ]) 63 | 64 | # sparse-step summary 65 | tf.summary.image('G/fake_sample', G, max_outputs=self.FAKE_MAX_OUTPUT) 66 | tf.summary.histogram('D/real_value', D_real) 67 | tf.summary.histogram('D/fake_value', D_fake) 68 | 69 | self.all_summary_op = tf.summary.merge_all() 70 | 71 | # accessible points 72 | self.X = X 73 | self.z = z 74 | self.D_train_op = D_train_op 75 | self.G_train_op = G_train_op 76 | self.fake_sample = G 77 | self.global_step = global_step 78 | 79 | def _discriminator(self, X, reuse=False): 80 | with tf.variable_scope('D', reuse=reuse): 81 | net = X 82 | 83 | with slim.arg_scope([slim.conv2d], kernel_size=[5,5], stride=2, padding='SAME', activation_fn=ops.lrelu, 84 | normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params): 85 | 86 | net = slim.conv2d(net, 64, normalizer_fn=None) 87 | expected_shape(net, [32, 32, 64]) 88 | net = slim.conv2d(net, 128) 89 | expected_shape(net, [16, 16, 128]) 90 | net = slim.conv2d(net, 256) 91 | expected_shape(net, [8, 8, 256]) 92 | net = slim.conv2d(net, 512) 93 | expected_shape(net, [4, 4, 512]) 94 | 95 | net = slim.flatten(net) 96 | d_value = slim.fully_connected(net, 1, activation_fn=None) 97 | 98 | return d_value 99 | 100 | # Originally, LSGAN used 112x112 LSUN images 101 | # We used 64x64 CelebA images 102 | def _generator(self, z, reuse=False): 103 | with tf.variable_scope('G', reuse=reuse): 104 | net = z 105 | net = slim.fully_connected(net, 4*4*256, activation_fn=tf.nn.relu, normalizer_fn=slim.batch_norm, 106 | normalizer_params=self.bn_params) 107 | net = tf.reshape(net, [-1, 4, 4, 256]) 108 | 109 | with slim.arg_scope([slim.conv2d_transpose], kernel_size=[3,3], padding='SAME', 
activation_fn=tf.nn.relu, 110 | normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params): 111 | 112 | net = slim.conv2d_transpose(net, 256, stride=2) 113 | net = slim.conv2d_transpose(net, 256, stride=1) 114 | expected_shape(net, [8, 8, 256]) 115 | net = slim.conv2d_transpose(net, 256, stride=2) 116 | net = slim.conv2d_transpose(net, 256, stride=1) 117 | expected_shape(net, [16, 16, 256]) 118 | net = slim.conv2d_transpose(net, 128, stride=2) 119 | expected_shape(net, [32, 32, 128]) 120 | net = slim.conv2d_transpose(net, 64, stride=2) 121 | expected_shape(net, [64, 64, 64]) 122 | net = slim.conv2d_transpose(net, 3, stride=1, activation_fn=tf.nn.tanh, normalizer_fn=None) 123 | expected_shape(net, [64, 64, 3]) 124 | 125 | return net 126 | -------------------------------------------------------------------------------- /models/wgan.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import tensorflow as tf 3 | slim = tf.contrib.slim 4 | from utils import expected_shape 5 | import ops 6 | from basemodel import BaseModel 7 | 8 | ''' 9 | based on DCGAN. 10 | 11 | WGAN: 12 | WD = max_f [ Ex[f(x)] - Ez[f(g(z))] ] where f satisfies the K-Lipschitz constraint 13 | J = min WD (G_loss) 14 | ''' 15 | 16 | class WGAN(BaseModel): 17 | def __init__(self, name, training, D_lr=5e-5, G_lr=5e-5, image_shape=[64, 64, 3], z_dim=100): 18 | self.ld = 10. # lambda (unused here - WGAN has no gradient penalty) 19 | self.n_critic = 5 20 | super(WGAN, self).__init__(name=name, training=training, D_lr=D_lr, G_lr=G_lr, 21 | image_shape=image_shape, z_dim=z_dim) 22 | 23 | def _build_train_graph(self): 24 | with tf.variable_scope(self.name): 25 | X = tf.placeholder(tf.float32, [None] + self.shape) 26 | z = tf.placeholder(tf.float32, [None, self.z_dim]) 27 | global_step = tf.Variable(0, name='global_step', trainable=False) 28 | 29 | G = self._generator(z) 30 | C_real = self._critic(X) 31 | C_fake = self._critic(G, reuse=True) 32 | 33 | W_dist = tf.reduce_mean(C_real - C_fake) 34 | C_loss = -W_dist 35 | G_loss = tf.reduce_mean(-C_fake) 36 | 37 | C_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/critic/') 38 | G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/generator/') 39 | 40 | C_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/critic/') 41 | G_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/generator/') 42 | 43 | # In the paper, the critic network is trained n_critic times for each generator step. 44 | # Here I scale the critic learning rate instead. 45 | with tf.control_dependencies(C_update_ops): 46 | C_train_op = tf.train.RMSPropOptimizer(learning_rate=self.D_lr*self.n_critic).\ 47 | minimize(C_loss, var_list=C_vars) 48 | with tf.control_dependencies(G_update_ops): 49 | G_train_op = tf.train.RMSPropOptimizer(learning_rate=self.G_lr).\ 50 | minimize(G_loss, var_list=G_vars, global_step=global_step) 51 | 52 | # weight clipping 53 | ''' Is it right to clip gamma of the batch norm? ''' 54 | 55 | # ver 1. clips all variables in critic 56 | C_clips = [tf.assign(var, tf.clip_by_value(var, -0.01, 0.01)) for var in C_vars] # with gamma 57 | 58 | # ver 2. does not work 59 | # C_clips = [tf.assign(var, tf.clip_by_value(var, -0.01, 0.01)) for var in C_vars if 'gamma' not in var.op.name] # without gamma 60 | 61 | # ver 3. 
works but strange 62 | # C_clips = [] 63 | # for var in C_vars: 64 | # if 'gamma' not in var.op.name: 65 | # C_clips.append(tf.assign(var, tf.clip_by_value(var, -0.01, 0.01))) 66 | # else: 67 | # C_clips.append(tf.assign(var, tf.clip_by_value(var, -1.00, 1.00))) 68 | 69 | with tf.control_dependencies([C_train_op]): # should be iterable 70 | C_train_op = tf.tuple(C_clips) # tf.group ? 71 | 72 | # summaries 73 | # per-step summary 74 | self.summary_op = tf.summary.merge([ 75 | tf.summary.scalar('G_loss', G_loss), 76 | tf.summary.scalar('C_loss', C_loss), 77 | tf.summary.scalar('W_dist', W_dist) 78 | ]) 79 | 80 | # sparse-step summary 81 | tf.summary.image('fake_sample', G, max_outputs=self.FAKE_MAX_OUTPUT) 82 | # tf.summary.histogram('real_probs', D_real_prob) 83 | # tf.summary.histogram('fake_probs', D_fake_prob) 84 | self.all_summary_op = tf.summary.merge_all() 85 | 86 | # accessible points 87 | self.X = X 88 | self.z = z 89 | self.D_train_op = C_train_op # compatibility with train.py 90 | self.G_train_op = G_train_op 91 | self.fake_sample = G 92 | self.global_step = global_step 93 | 94 | def _critic(self, X, reuse=False): 95 | ''' K-Lipschitz function ''' 96 | with tf.variable_scope('critic', reuse=reuse): 97 | net = X 98 | 99 | with slim.arg_scope([slim.conv2d], kernel_size=[5,5], stride=2, activation_fn=ops.lrelu, 100 | normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params): 101 | net = slim.conv2d(net, 64, normalizer_fn=None) 102 | expected_shape(net, [32, 32, 64]) 103 | net = slim.conv2d(net, 128) 104 | expected_shape(net, [16, 16, 128]) 105 | net = slim.conv2d(net, 256) 106 | expected_shape(net, [8, 8, 256]) 107 | net = slim.conv2d(net, 512) 108 | expected_shape(net, [4, 4, 512]) 109 | 110 | net = slim.flatten(net) 111 | net = slim.fully_connected(net, 1, activation_fn=None) 112 | 113 | return net 114 | 115 | def _generator(self, z, reuse=False): 116 | with tf.variable_scope('generator', reuse=reuse): 117 | net = z 118 | net = slim.fully_connected(net, 4*4*1024, activation_fn=tf.nn.relu) 119 | net = tf.reshape(net, [-1, 4, 4, 1024]) 120 | 121 | with slim.arg_scope([slim.conv2d_transpose], kernel_size=[5,5], stride=2, activation_fn=tf.nn.relu, 122 | normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params): 123 | net = slim.conv2d_transpose(net, 512) 124 | expected_shape(net, [8, 8, 512]) 125 | net = slim.conv2d_transpose(net, 256) 126 | expected_shape(net, [16, 16, 256]) 127 | net = slim.conv2d_transpose(net, 128) 128 | expected_shape(net, [32, 32, 128]) 129 | net = slim.conv2d_transpose(net, 3, activation_fn=tf.nn.tanh, normalizer_fn=None) 130 | expected_shape(net, [64, 64, 3]) 131 | 132 | return net 133 | -------------------------------------------------------------------------------- /models/wgan_gp.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import tensorflow as tf 3 | slim = tf.contrib.slim 4 | from utils import expected_shape 5 | import ops 6 | from basemodel import BaseModel 7 | 8 | ''' 9 | WGAN: 10 | WD = max_f [ Ex[f(x)] - Ez[f(g(z))] ] where f satisfies the K-Lipschitz constraint 11 | J = min WD (G_loss) 12 | 13 | + GP: 14 | Instead of weight clipping, WGAN-GP proposed gradient penalty. 15 | ''' 16 | 17 | class WGAN_GP(BaseModel): 18 | def __init__(self, name, training, D_lr=1e-4, G_lr=1e-4, image_shape=[64, 64, 3], z_dim=100): 19 | self.beta1 = 0.0 20 | self.beta2 = 0.9 21 | self.ld = 10. 
# lambda 22 | self.n_critic = 5 23 | super(WGAN_GP, self).__init__(name=name, training=training, D_lr=D_lr, G_lr=G_lr, 24 | image_shape=image_shape, z_dim=z_dim) 25 | 26 | def _build_train_graph(self): 27 | with tf.variable_scope(self.name): 28 | X = tf.placeholder(tf.float32, [None] + self.shape) 29 | z = tf.placeholder(tf.float32, [None, self.z_dim]) 30 | global_step = tf.Variable(0, name='global_step', trainable=False) 31 | 32 | # `critic` is the name from WGAN (WGAN-GP uses the term `discriminator` rather than `critic`) 33 | G = self._generator(z) 34 | C_real = self._critic(X) 35 | C_fake = self._critic(G, reuse=True) 36 | 37 | W_dist = tf.reduce_mean(C_real - C_fake) 38 | C_loss = -W_dist 39 | G_loss = tf.reduce_mean(-C_fake) 40 | 41 | # Gradient Penalty (GP) 42 | eps = tf.random_uniform(shape=[tf.shape(X)[0], 1, 1, 1], minval=0., maxval=1.) 43 | x_hat = eps*X + (1.-eps)*G 44 | C_xhat = self._critic(x_hat, reuse=True) 45 | C_xhat_grad = tf.gradients(C_xhat, x_hat)[0] # gradient of D(x_hat) 46 | C_xhat_grad_norm = tf.norm(slim.flatten(C_xhat_grad), axis=1) # l2 norm 47 | GP = self.ld * tf.reduce_mean(tf.square(C_xhat_grad_norm - 1.)) 48 | C_loss += GP 49 | 50 | C_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/critic/') 51 | G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name+'/generator/') 52 | 53 | C_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/critic/') 54 | G_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name+'/generator/') 55 | 56 | # As in wgan.py, the n_critic critic updates per generator update are 57 | # approximated by scaling the critic learning rate by self.n_critic. 58 | with tf.control_dependencies(C_update_ops): 59 | C_train_op = tf.train.AdamOptimizer(learning_rate=self.D_lr*self.n_critic, beta1=self.beta1, beta2=self.beta2).\ 60 | minimize(C_loss, var_list=C_vars) 61 | with tf.control_dependencies(G_update_ops): 62 | G_train_op = tf.train.AdamOptimizer(learning_rate=self.G_lr, beta1=self.beta1, beta2=self.beta2).\ 63 | minimize(G_loss, var_list=G_vars, global_step=global_step) 64 | 65 | # summaries 66 | # per-step summary 67 | self.summary_op = tf.summary.merge([ 68 | tf.summary.scalar('G_loss', G_loss), 69 | tf.summary.scalar('C_loss', C_loss), 70 | tf.summary.scalar('W_dist', W_dist), 71 | tf.summary.scalar('GP', GP) 72 | ]) 73 | 74 | # sparse-step summary 75 | tf.summary.image('fake_sample', G, max_outputs=self.FAKE_MAX_OUTPUT) 76 | # tf.summary.histogram('real_probs', D_real_prob) 77 | # tf.summary.histogram('fake_probs', D_fake_prob) 78 | self.all_summary_op = tf.summary.merge_all() 79 | 80 | # accessible points 81 | self.X = X 82 | self.z = z 83 | self.D_train_op = C_train_op # for accessibility from train.py... hmm... this is ugly... 84 | self.G_train_op = G_train_op 85 | self.fake_sample = G 86 | self.global_step = global_step 87 | 88 | def _critic(self, X, reuse=False): 89 | return self._good_critic(X, reuse) 90 | 91 | def _generator(self, z, reuse=False): 92 | return self._good_generator(z, reuse) 93 | 94 | def _dcgan_critic(self, X, reuse=False): 95 | ''' 96 | K-Lipschitz function. 97 | WGAN-GP does not use batch norm in the critic. 
98 | ''' 99 | with tf.variable_scope('critic', reuse=reuse): 100 | net = X 101 | 102 | with slim.arg_scope([slim.conv2d], kernel_size=[5,5], stride=2, padding='SAME', activation_fn=ops.lrelu): 103 | net = slim.conv2d(net, 64) 104 | expected_shape(net, [32, 32, 64]) 105 | net = slim.conv2d(net, 128) 106 | expected_shape(net, [16, 16, 128]) 107 | net = slim.conv2d(net, 256) 108 | expected_shape(net, [8, 8, 256]) 109 | net = slim.conv2d(net, 512) 110 | expected_shape(net, [4, 4, 512]) 111 | 112 | net = slim.flatten(net) 113 | net = slim.fully_connected(net, 1, activation_fn=None) 114 | 115 | return net 116 | 117 | def _dcgan_generator(self, z, reuse=False): 118 | with tf.variable_scope('generator', reuse=reuse): 119 | net = z 120 | net = slim.fully_connected(net, 4*4*1024, activation_fn=tf.nn.relu) 121 | net = tf.reshape(net, [-1, 4, 4, 1024]) 122 | 123 | with slim.arg_scope([slim.conv2d_transpose], kernel_size=[5,5], stride=2, activation_fn=tf.nn.relu, 124 | normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params): 125 | net = slim.conv2d_transpose(net, 512) 126 | expected_shape(net, [8, 8, 512]) 127 | net = slim.conv2d_transpose(net, 256) 128 | expected_shape(net, [16, 16, 256]) 129 | net = slim.conv2d_transpose(net, 128) 130 | expected_shape(net, [32, 32, 128]) 131 | net = slim.conv2d_transpose(net, 3, activation_fn=tf.nn.tanh, normalizer_fn=None) 132 | expected_shape(net, [64, 64, 3]) 133 | 134 | return net 135 | 136 | ''' 137 | ResNet architecture from appendix C in the paper. 138 | https://github.com/igul222/improved_wgan_training/blob/master/gan_64x64.py - GoodGenerator / GoodDiscriminator 139 | layer norm in D, batch norm in G. 140 | some details are ignored in this implementation. 141 | ''' 142 | def _residual_block(self, X, nf_output, resample, kernel_size=[3,3], name='res_block'): 143 | with tf.variable_scope(name): 144 | input_shape = X.shape 145 | nf_input = input_shape[-1] 146 | if resample == 'down': # Downsample 147 | shortcut = slim.avg_pool2d(X, [2,2]) 148 | shortcut = slim.conv2d(shortcut, nf_output, kernel_size=[1,1], activation_fn=None) # init xavier 149 | 150 | net = slim.layer_norm(X, activation_fn=tf.nn.relu) 151 | net = slim.conv2d(net, nf_input, kernel_size=kernel_size, biases_initializer=None) # skip bias 152 | net = slim.layer_norm(net, activation_fn=tf.nn.relu) 153 | net = slim.conv2d(net, nf_output, kernel_size=kernel_size) 154 | net = slim.avg_pool2d(net, [2,2]) 155 | 156 | return net + shortcut 157 | elif resample == 'up': # Upsample 158 | upsample_shape = [int(x)*2 for x in input_shape[1:3]] # doubled H, W 159 | shortcut = tf.image.resize_nearest_neighbor(X, upsample_shape) 160 | shortcut = slim.conv2d(shortcut, nf_output, kernel_size=[1,1], activation_fn=None) 161 | 162 | net = slim.batch_norm(X, activation_fn=tf.nn.relu, **self.bn_params) 163 | net = tf.image.resize_nearest_neighbor(net, upsample_shape) 164 | net = slim.conv2d(net, nf_output, kernel_size=kernel_size, biases_initializer=None) # skip bias 165 | net = slim.batch_norm(net, activation_fn=tf.nn.relu, **self.bn_params) 166 | net = slim.conv2d(net, nf_output, kernel_size=kernel_size) 167 | 168 | return net + shortcut 169 | else: 170 | raise Exception('invalid resample value') 171 | 172 | def _good_generator(self, z, reuse=False): 173 | with tf.variable_scope('generator', reuse=reuse): 174 | nf = 64 175 | net = slim.fully_connected(z, 4*4*8*nf, activation_fn=None) # 4x4x512 176 | net = tf.reshape(net, [-1, 4, 4, 8*nf]) 177 | net = self._residual_block(net, 8*nf, resample='up', name='res_block1') # 
8x8x512 178 | net = self._residual_block(net, 4*nf, resample='up', name='res_block2') # 16x16x256 179 | net = self._residual_block(net, 2*nf, resample='up', name='res_block3') # 32x32x128 180 | net = self._residual_block(net, 1*nf, resample='up', name='res_block4') # 64x64x64 181 | expected_shape(net, [64, 64, 64]) 182 | net = slim.batch_norm(net, activation_fn=tf.nn.relu, **self.bn_params) 183 | net = slim.conv2d(net, 3, kernel_size=[3,3], activation_fn=tf.nn.tanh) 184 | expected_shape(net, [64, 64, 3]) 185 | 186 | return net 187 | 188 | def _good_critic(self, X, reuse=False): 189 | with tf.variable_scope('critic', reuse=reuse): 190 | nf = 64 191 | net = slim.conv2d(X, nf, [3,3], activation_fn=None) # 64x64x64 192 | net = self._residual_block(net, 2*nf, resample='down', name='res_block1') # 32x32x128 193 | net = self._residual_block(net, 4*nf, resample='down', name='res_block2') # 16x16x256 194 | net = self._residual_block(net, 8*nf, resample='down', name='res_block3') # 8x8x512 195 | net = self._residual_block(net, 8*nf, resample='down', name='res_block4') # 4x4x512 196 | expected_shape(net, [4, 4, 512]) 197 | net = slim.flatten(net) 198 | net = slim.fully_connected(net, 1, activation_fn=None) 199 | 200 | return net 201 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import tensorflow as tf 4 | slim = tf.contrib.slim 5 | 6 | 7 | def lrelu(inputs, leak=0.2, scope="lrelu"): 8 | """ 9 | https://github.com/tensorflow/tensorflow/issues/4079 10 | """ 11 | with tf.variable_scope(scope): 12 | f1 = 0.5 * (1 + leak) 13 | f2 = 0.5 * (1 - leak) 14 | return f1 * inputs + f2 * abs(inputs) 15 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import tensorflow as tf 4 | from tqdm import tqdm 5 | import numpy as np 6 | import inputpipe as ip 7 | import glob, os, sys 8 | from argparse import ArgumentParser 9 | import utils, config 10 | 11 | 12 | def build_parser(): 13 | parser = ArgumentParser() 14 | parser.add_argument('--num_epochs', default=20, help='default: 20', type=int) 15 | parser.add_argument('--batch_size', default=128, help='default: 128', type=int) 16 | parser.add_argument('--num_threads', default=4, help='# of data read threads (default: 4)', type=int) 17 | models_str = ' / '.join(config.model_zoo) 18 | parser.add_argument('--model', help=models_str, required=True) # DRAGAN, CramerGAN 19 | parser.add_argument('--name', help='default: name=model') 20 | parser.add_argument('--dataset', '-D', help='CelebA / LSUN', required=True) 21 | parser.add_argument('--ckpt_step', default=5000, help='# of steps for saving checkpoint (default: 5000)', type=int) 22 | parser.add_argument('--renew', action='store_true', help='train model from scratch - \ 23 | clean saved checkpoints and summaries', default=False) 24 | 25 | return parser 26 | 27 | 28 | def input_pipeline(glob_pattern, batch_size, num_threads, num_epochs): 29 | tfrecords_list = glob.glob(glob_pattern) 30 | # num_examples = utils.num_examples_from_tfrecords(tfrecords_list) # takes too long for lsun 31 | X = ip.shuffle_batch_join(tfrecords_list, batch_size=batch_size, num_threads=num_threads, num_epochs=num_epochs) 32 | return X 33 | 34 | 35 | def sample_z(shape): 36 | return np.random.normal(size=shape) 37 | 38 | 39 | def train(model, dataset, 
input_op, num_epochs, batch_size, n_examples, ckpt_step, renew=False): 40 | # n_examples = 202599 # same as util.num_examples_from_tfrecords(glob.glob('./data/celebA_tfrecords/*.tfrecord')) 41 | # 1 epoch = 1583 steps 42 | print("\n# of examples: {}".format(n_examples)) 43 | print("steps per epoch: {}\n".format(n_examples//batch_size)) 44 | 45 | summary_path = os.path.join('./summary/', dataset, model.name) 46 | ckpt_path = os.path.join('./checkpoints', dataset, model.name) 47 | if renew: 48 | if os.path.exists(summary_path): 49 | tf.gfile.DeleteRecursively(summary_path) 50 | if os.path.exists(ckpt_path): 51 | tf.gfile.DeleteRecursively(ckpt_path) 52 | if not os.path.exists(ckpt_path): 53 | tf.gfile.MakeDirs(ckpt_path) 54 | 55 | config = tf.ConfigProto() 56 | best_gpu = utils.get_best_gpu() 57 | config.gpu_options.visible_device_list = str(best_gpu) # Works the same as CUDA_VISIBLE_DEVICES! 58 | with tf.Session(config=config) as sess: 59 | sess.run(tf.global_variables_initializer()) 60 | sess.run(tf.local_variables_initializer()) # for epochs 61 | 62 | coord = tf.train.Coordinator() 63 | threads = tf.train.start_queue_runners(coord=coord) 64 | 65 | # https://github.com/tensorflow/tensorflow/issues/10972 66 | # TensorFlow 1.2 has many bugs in text summaries 67 | # create config_summary before defining summary_writer - bypasses a TensorBoard bug 68 | 69 | # It seems that batch_size should be part of the model config ... 70 | total_steps = int(np.ceil(n_examples * num_epochs / float(batch_size))) # total global step 71 | config_list = [ 72 | ('num_epochs', num_epochs), 73 | ('total_iteration', total_steps), 74 | ('batch_size', batch_size), 75 | ('dataset', dataset) 76 | ] 77 | model_config_list = [[k, str(w)] for k, w in sorted(model.args.items()) + config_list] 78 | model_config_summary_op = tf.summary.text(model.name + '/config', tf.convert_to_tensor(model_config_list), 79 | collections=[]) 80 | model_config_summary = sess.run(model_config_summary_op) 81 | 82 | # print to console 83 | print("\n====== Process info =======") 84 | print("argv: {}".format(' '.join(sys.argv))) 85 | print("PID: {}".format(os.getpid())) 86 | print("====== Model configs ======") 87 | for k, v in model_config_list: 88 | print("{}: {}".format(k, v)) 89 | print("===========================\n") 90 | 91 | summary_writer = tf.summary.FileWriter(summary_path, flush_secs=30, graph=sess.graph) 92 | summary_writer.add_summary(model_config_summary) 93 | pbar = tqdm(total=total_steps, desc='global_step') 94 | saver = tf.train.Saver(max_to_keep=9999) # save all checkpoints 95 | global_step = 0 96 | 97 | ckpt = tf.train.get_checkpoint_state(ckpt_path) 98 | if ckpt: 99 | saver.restore(sess, ckpt.model_checkpoint_path) 100 | global_step = sess.run(model.global_step) 101 | print('\n[!] Restore from {} ... starting global step is {}\n'.format(ckpt.model_checkpoint_path, global_step)) 102 | pbar.update(global_step) 103 | 104 | try: 105 | # If the training process is resumed from a checkpoint, the input pipeline cannot detect 106 | # when training should stop, so we need the `global_step < total_steps` condition. 
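# Worked example with the CelebA numbers noted above (n_examples=202599,
# batch_size=128, num_epochs=20):
#   total_steps = ceil(202599 * 20 / 128) = 31657, i.e. ~1583 steps per epoch.
# Each iteration of the loop below runs one D step and one G step, reusing the
# same z batch for both.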
107 | while not coord.should_stop() and global_step < total_steps: 108 | # model.all_summary_op contains histogram and image summaries, which are heavy ops - run it sparsely 109 | summary_op = model.all_summary_op if global_step % 100 == 0 else model.summary_op 110 | 111 | batch_X = sess.run(input_op) 112 | batch_z = sample_z([batch_size, model.z_dim]) 113 | 114 | _, summary = sess.run([model.D_train_op, summary_op], {model.X: batch_X, model.z: batch_z}) 115 | _, global_step = sess.run([model.G_train_op, model.global_step], {model.z: batch_z}) 116 | 117 | summary_writer.add_summary(summary, global_step=global_step) 118 | 119 | if global_step % 10 == 0: 120 | pbar.update(10) 121 | 122 | if global_step % ckpt_step == 0: 123 | saver.save(sess, ckpt_path+'/'+model.name, global_step=global_step) 124 | 125 | except tf.errors.OutOfRangeError: 126 | print('\nDone -- epoch limit reached\n') 127 | finally: 128 | coord.request_stop() 129 | 130 | coord.join(threads) 131 | summary_writer.close() 132 | pbar.close() 133 | 134 | 135 | if __name__ == "__main__": 136 | parser = build_parser() 137 | FLAGS = parser.parse_args() 138 | FLAGS.model = FLAGS.model.upper() 139 | FLAGS.dataset = FLAGS.dataset.lower() 140 | if FLAGS.name is None: 141 | FLAGS.name = FLAGS.model.lower() 142 | config.pprint_args(FLAGS) 143 | 144 | # get information for dataset 145 | dataset_pattern, n_examples = config.get_dataset(FLAGS.dataset) 146 | # input pipeline 147 | X = input_pipeline(dataset_pattern, batch_size=FLAGS.batch_size, 148 | num_threads=FLAGS.num_threads, num_epochs=FLAGS.num_epochs) 149 | 150 | model = config.get_model(FLAGS.model, FLAGS.name, training=True) 151 | train(model=model, dataset=FLAGS.dataset, input_op=X, num_epochs=FLAGS.num_epochs, batch_size=FLAGS.batch_size, 152 | n_examples=n_examples, ckpt_step=FLAGS.ckpt_step, renew=FLAGS.renew) 153 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import tensorflow as tf 3 | import tensorflow.contrib.slim as slim 4 | '''https://stackoverflow.com/questions/37604289/tkinter-tclerror-no-display-name-and-no-display-environment-variable 5 | Matplotlib chooses the Xwindows backend by default. You need to tell matplotlib not to use the Xwindows backend. 6 | - `matplotlib.use('Agg')` 7 | - Or add the line `backend : Agg` to .config/matplotlib/matplotlibrc. 8 | - Or, when connecting to the server, use the `ssh -X ...` command to use Xwindows. 9 | ''' 10 | import matplotlib 11 | matplotlib.use('Agg') 12 | import matplotlib.pyplot as plt 13 | import matplotlib.gridspec as gridspec 14 | import scipy.misc 15 | import numpy as np 16 | 17 | 18 | def get_best_gpu(): 19 | '''Dependency: pynvml (for GPU memory information) 20 | return type is integer (gpu_id) 21 | ''' 22 | try: 23 | from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetName, nvmlDeviceGetMemoryInfo 24 | except Exception as e: 25 | print('[!] {} => Use default GPU settings ...\n'.format(e)) 26 | return '' 27 | 28 | print('\n===== Check GPU memory =====') 29 | 30 | # byte to megabyte 31 | def to_mb(x): 32 | return int(x/1024./1024.) 33 | 34 | best_idx = -1 35 | best_free = 0. 
36 |     nvmlInit()
37 |     n_gpu = nvmlDeviceGetCount()
38 |     for i in range(n_gpu):
39 |         handle = nvmlDeviceGetHandleByIndex(i)
40 |         name = nvmlDeviceGetName(handle)
41 |         mem = nvmlDeviceGetMemoryInfo(handle)
42 | 
43 |         total = to_mb(mem.total)
44 |         free = to_mb(mem.free)
45 |         used = to_mb(mem.used)
46 |         free_ratio = mem.free / float(mem.total)
47 | 
48 |         print("{} - {}/{} MB (free: {} MB - {:.2%})".format(name, used, total, free, free_ratio))
49 | 
50 |         if free > best_free:
51 |             best_free = free
52 |             best_idx = i
53 | 
54 |     print('\nSelected GPU is gpu:{}'.format(best_idx))
55 |     print('============================\n')
56 | 
57 |     return best_idx
58 | 
59 | 
60 | # Iterate the whole dataset and count the examples
61 | # CelebA contains about 200k examples in 128 tfrecord files and takes about 1.5s to iterate
62 | def num_examples_from_tfrecords(tfrecords_list):
63 |     num_examples = 0
64 |     for path in tfrecords_list:
65 |         num_examples += sum(1 for _ in tf.python_io.tf_record_iterator(path))
66 |     return num_examples
67 | 
68 | 
69 | def expected_shape(tensor, expected):
70 |     """Assert a tensor's shape, where the batch size N shouldn't be included in `expected`.
71 |     You can pass the shape of a tensor instead of the tensor itself.
72 | 
73 |     Usage:
74 |         # batch size N is skipped.
75 |         expected_shape(tensor, [28, 28, 1])
76 |         expected_shape(tensor.shape, [28, 28, 1])
77 |     """
78 |     if isinstance(tensor, tf.Tensor):
79 |         shape = tensor.shape[1:]
80 |     else:
81 |         shape = tensor[1:]
82 |     shape = list(map(lambda x: x.value, shape))  # list() keeps the equality check below valid in Python 3
83 |     err_msg = 'wrong shape {} (expected shape is {})'.format(shape, expected)
84 |     assert shape == expected, err_msg
85 |     # if not shape == expected:
86 |     #     warnings.warn('wrong shape {} (expected shape is {})'.format(shape, expected))
87 | 
88 | 
89 | def plot(samples, shape=(4,4), figratio=0.75):
90 |     """Plot samples in a shape[0] x shape[1] grid, e.g. 16 samples for shape=(4, 4).
91 |     Each sample is assumed to be square (h == w).
92 |     figratio: small size = 0.75 (default) / big size = 1.0
93 |     """
94 |     if len(samples) != shape[0]*shape[1]:
95 |         print("Error: # of samples = {} but shape is {}".format(len(samples), shape))
96 |         return
97 | 
98 |     h_figsize = shape[0] * figratio
99 |     w_figsize = shape[1] * figratio
100 |     fig = plt.figure(figsize=(w_figsize, h_figsize))
101 |     gs = gridspec.GridSpec(shape[0], shape[1])
102 |     gs.update(wspace=0.05, hspace=0.05)
103 | 
104 |     for i, sample in enumerate(samples):
105 |         ax = plt.subplot(gs[i])
106 |         plt.axis('off')
107 |         ax.set_xticklabels([])
108 |         ax.set_yticklabels([])
109 |         ax.set_aspect('equal')
110 |         plt.imshow(sample)  # note: check the cmap for grayscale samples
111 | 
112 |     return fig
113 | 
114 | 
115 | def show_all_variables():
116 |     model_vars = tf.trainable_variables()
117 |     slim.model_analyzer.analyze_vars(model_vars, print_info=True)
118 | 
119 | 
120 | def merge(images, size):
121 |     """Merge a batch of images into a single grid image - borrowed from @carpedm20.
122 | 
123 |     checklist before/after imsave:
124 |     * are images post-processed? for example, denormalization
125 |     * is np.squeeze required? maybe for grayscale...
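    * hypothetical example: 64 RGB samples of 64x64 stacked as a (64, 64, 64, 3)
      array give merge(samples, size=(8, 8)) -> grid of shape (512, 512, 3)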
126 |     """
127 |     h, w = images.shape[1], images.shape[2]
128 |     if images.shape[3] in (3, 4):
129 |         c = images.shape[3]
130 |         img = np.zeros((h * size[0], w * size[1], c))
131 |         for idx, image in enumerate(images):
132 |             i = idx % size[1]
133 |             j = idx // size[1]
134 |             img[j * h:j * h + h, i * w:i * w + w, :] = image
135 |         return img
136 |     elif images.shape[3] == 1:
137 |         img = np.zeros((h * size[0], w * size[1]))
138 |         for idx, image in enumerate(images):
139 |             i = idx % size[1]
140 |             j = idx // size[1]
141 |             img[j * h:j * h + h, i * w:i * w + w] = image[:, :, 0]
142 |         return img
143 |     else:
144 |         raise ValueError('in merge(images, size), images must be a 4-D batch of shape NxHxWx1, NxHxWx3, or NxHxWx4')
145 | 
146 | 
147 | '''Sugar for gradient histograms: expand minimize() into compute/apply so the gradients can be logged.
148 | # D_train_op = tf.train.AdamOptimizer(learning_rate=self.D_lr, beta1=self.beta1, beta2=self.beta2).\
149 | #     minimize(D_loss, var_list=D_vars)
150 | D_opt = tf.train.AdamOptimizer(learning_rate=self.D_lr, beta1=self.beta1, beta2=self.beta2)
151 | D_grads = tf.gradients(D_loss, D_vars)
152 | D_grads_and_vars = list(zip(D_grads, D_vars))
153 | D_train_op = D_opt.apply_gradients(grads_and_vars=D_grads_and_vars)
154 | 
155 | # G_train_op = tf.train.AdamOptimizer(learning_rate=self.G_lr, beta1=self.beta1, beta2=self.beta2).\
156 | #     minimize(G_loss, var_list=G_vars, global_step=global_step)
157 | G_opt = tf.train.AdamOptimizer(learning_rate=self.G_lr, beta1=self.beta1, beta2=self.beta2)
158 | G_grads = tf.gradients(G_loss, G_vars)
159 | G_grads_and_vars = list(zip(G_grads, G_vars))
160 | G_train_op = G_opt.apply_gradients(grads_and_vars=G_grads_and_vars, global_step=global_step)
161 | 
162 | 
163 | for var in tf.trainable_variables():
164 |     tf.summary.histogram(var.op.name, var)
165 | 
166 | for grad, var in D_grads_and_vars:
167 |     tf.summary.histogram('D/' + var.name + '/gradient', grad)
168 | for grad, var in G_grads_and_vars:
169 |     tf.summary.histogram('G/' + var.name + '/gradient', grad)
170 | '''
171 | 
--------------------------------------------------------------------------------