├── .gitignore ├── LICENSE ├── README.md ├── datasets └── download_dataset.sh ├── download_pretrained_model.sh ├── imgs ├── examples │ ├── combine_A_and_B_2.py │ ├── handdraw │ │ ├── 01.png │ │ ├── 02.png │ │ ├── 03.png │ │ ├── 04.png │ │ ├── 05.png │ │ ├── 06.png │ │ ├── 07.png │ │ ├── 08.png │ │ ├── 09.png │ │ ├── 10.png │ │ ├── 11.png │ │ ├── 12.png │ │ ├── 13.png │ │ ├── 14.png │ │ ├── 15.png │ │ ├── 16.png │ │ └── 17.png │ ├── match │ │ ├── eyeglasses │ │ │ ├── 01.png │ │ │ ├── 02.png │ │ │ ├── 03.png │ │ │ ├── 04.png │ │ │ ├── 05.png │ │ │ ├── 06.png │ │ │ ├── 07.png │ │ │ └── 08.png │ │ ├── lipstick │ │ │ ├── 01.png │ │ │ ├── 02.png │ │ │ ├── 03.png │ │ │ ├── 04.png │ │ │ ├── 05.png │ │ │ ├── 06.png │ │ │ ├── 07.png │ │ │ └── 08.png │ │ └── mustache │ │ │ ├── 01.png │ │ │ ├── 02.png │ │ │ ├── 03.png │ │ │ ├── 04.png │ │ │ ├── 05.png │ │ │ ├── 06.png │ │ │ ├── 07.png │ │ │ └── 08.png │ └── mismatch │ │ ├── 01.png │ │ ├── 02.png │ │ ├── 03.png │ │ ├── 04.png │ │ ├── 05.png │ │ ├── 06.png │ │ ├── 07.png │ │ ├── 08.png │ │ ├── 09.png │ │ ├── 10.png │ │ ├── 11.png │ │ ├── 12.png │ │ ├── 13.png │ │ ├── 14.png │ │ ├── 15.png │ │ └── 16.png ├── handdraw │ ├── bird_handdraw.png │ ├── face_handdraw.png │ └── shoe_handdraw.png └── problem.png ├── model.py ├── ops.py ├── test.py ├── test.sh ├── train.py ├── train.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Shangzhe Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Attribute-Guided Sketch-to-Image Generation 2 | 3 | 4 | ## Introduction 5 | **Note: Implementation in this repo has been largely outdated. You may refer to my [presentation slides](https://goo.gl/S2JjAn) for your own implementation.** 6 | 7 | This is based on one of our prior works on [Sketch-to-Image Generation](https://arxiv.org/abs/1711.08972). Freehand sketch can be highly abstract (examples shown below), and learning representations of sketches is not trivial. 
In contrast to other cross-domain learning approaches, such as [pix2pix](https://phillipi.github.io/pix2pix/) and [CycleGAN](https://junyanz.github.io/CycleGAN/), which use translation networks to learn a mapping from representations in one domain to those in another, in [Sketch-to-Image Generation](https://arxiv.org/abs/1711.08972) we propose to learn a joint representation of sketch and image.
8 |
9 | In this project, we intend to add text constraints to sketch-to-image generation, where the text provides the content and the sketch controls the shape. **So far, I have only tried adding attribute guidance instead of using text embeddings as additional conditions, and this repo demonstrates results on _Attribute-Guided Sketch-to-Image Generation_.**
10 |
11 | face |bird |shoe
12 | :--------------------------:|:--------------------------:|:--------------------------:
13 | ![](imgs/handdraw/face_handdraw.png) |![](imgs/handdraw/bird_handdraw.png) |![](imgs/handdraw/shoe_handdraw.png)
14 |
15 | * A few freehand sketches were collected from volunteers.
16 |
17 | #### Contributors:
18 | - Major Contributor: [Shangzhe Wu](https://elliottwu.github.io/) (HKUST)
19 | - Supervisor: Yu-wing Tai (Tencent), [Chi-Keung Tang](http://www.cs.ust.hk/~cktang/) (HKUST)
20 | - Mentor in MLJejuCamp2017: Hyungjoo Cho
21 |
22 | ### MLJejuCamp2017
23 | This project was developed during [Machine Learning Camp Jeju 2017](http://jeju.dlcamp.org/2017/) within one month. More interesting projects can be found in the [final presentations](https://github.com/TensorFlowKR/dlcampjeju/blob/master/2017/github/04_FinalPresentation.md) and the [program GitHub](https://github.com/MLJejuCamp2017). The final presentation video can be watched [here](https://www.youtube.com/watch?v=X6ieGv82PYU) (partial recording). Camp 2018 has been launched, and more details can be found [here](http://jeju.dlcamp.org/2018/).
24 |
25 | ## Get Started
26 | ### Prerequisites
27 | - Python 3.5
28 | - [TensorFlow 0.12.1](https://github.com/tensorflow/tensorflow/tree/r0.12)
29 | - [SciPy](https://www.scipy.org/install.html)
30 |
31 | ### Setup
32 | - Clone this repo:
33 | ```bash
34 | git clone https://github.com/elliottwu/sText2Image.git
35 | cd sText2Image
36 | ```
37 |
38 | - Download the preprocessed CelebA data (~3GB):
39 | ```bash
40 | sh ./datasets/download_dataset.sh
41 | ```
42 |
43 | ### Train
44 | ```bash
45 | sh train.sh
46 | ```
47 | - To monitor training with TensorBoard, run the following and open `localhost:8888` in your browser:
48 | ```bash
49 | tensorboard --logdir=logs_face --port=8888
50 | ```
51 |
52 | ### Test
53 | ```bash
54 | sh test.sh
55 | ```
56 |
57 | ### Pretrained Model
58 | - Download the pretrained model:
59 | ```bash
60 | sh download_pretrained_model.sh
61 | ```
62 |
63 | - Test the pretrained model on the CelebA dataset:
64 | ```bash
65 | python test.py ./datasets/celeba/test/* --checkpointDir checkpoints_face_pretrained --maskType right --batchSize 64 --lam1 100 --lam2 1 --lam3 0.1 --lr 0.001 --nIter 1000 --outDir results_face_pretrained --text_vector_dim 18 --text_path datasets/celeba/imAttrs.pkl
66 | ```
67 |
68 | ## Experiments
69 | We test our framework on three kinds of data: faces ([CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)), birds ([CUB](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)), and flowers ([Oxford-102](http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html)). So far, we have only experimented with face images, using attribute vectors as the text information.
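Each sample is stored as a side-by-side pair with the sketch on the left and the photo on the right, so the model operates on images of shape 64×128×3 (`image_size` × `image_size * 2`). At test time the photo half is masked out (the default `--maskType right`) and completed from the sketch half plus the attribute vector. Below is a minimal sketch of how this mask is constructed, mirroring the logic in `model.py`'s `test()`:

```python
import numpy as np

image_size = 64
image_shape = [image_size, image_size * 2, 3]  # [H, W, C] of a sketch|photo pair

# --maskType right: keep the left (sketch) half, zero out the right (photo) half
mask = np.ones(image_shape)
mask[:, image_size:, :] = 0.0

# multiplying a batch of pairs by this mask leaves only the sketch halves,
# which the completion procedure then fills in on the right:
# masked_images = batch_images * mask
```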
Here are some preliminary results:
70 |
71 | ### 1. Face
72 | We used the CelebA dataset, which also provides 40 attribute annotations for each image. Similar to text information, the attributes control the specific details of the generated images. We chose 18 attributes for training.
73 |
74 | #### a). Attributes match sketch:
75 | The following images were generated given sketches and the **corresponding** attributes.
76 |
77 | ##### Mustache
78 | attributes |sketch / generated / gt |attributes | sketch / generated / gt
79 | :-----------:|:----------------------:|:-----------:|:-----------------------:
80 | Male, 5_o_Clock_Shadow, Mouth_Open, Pointy_Nose | | Male, 5_o_Clock_Shadow, Big_Nose, Mustache |
81 | Male, Big_Lips, Big_Nose, Chubby, Goatee, High_Cheekbones, Smiling | | Male, Mustache |
82 | Male, Goatee, Mouth_Open, Smiling | | Male, Big_Nose, Goatee, Smiling |
83 | Male, 5_o_Clock_Shadow, Big_Lips, Big_Nose, Goatee, High_Cheekbones, Mouth_Open, Rosy_Cheeks, Smiling | | Male, 5_o_Clock_Shadow, Big_Nose, Narrow_Eyes |
84 |
85 | ##### Eyeglasses
86 | attributes |sketch / generated / gt |attributes | sketch / generated / gt
87 | :-----------:|:----------------------:|:-----------:|:-----------------------:
88 | Male, Big_Nose, Eyeglasses, Goatee | | Female, Eyeglasses |
89 | Female, Eyeglasses, High_Cheekbones, Mouth_Open, Smiling | | Male, 5_o_Clock_Shadow, Big_Nose, Eyeglasses, Mouth_Open, Smiling |
90 | Male, Big_Nose, Double_Chin, Eyeglasses, Mouth_Open, Pointy_Nose, Smiling | | Male, Eyeglasses, High_Cheekbones, Mouth_Open, Smiling |
91 | Male, 5_o_Clock_Shadow, Eyeglasses, Mouth_Open, Smiling | | Male, Big_Lips, Big_Nose, Eyeglasses, Goatee, Mouth_Open |
92 |
93 |
94 | ##### Lipstick
95 | attributes |sketch / generated / gt |attributes | sketch / generated / gt
96 | :-----------:|:----------------------:|:-----------:|:-----------------------:
97 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Pointy_Nose, Smiling, Wearing_Lipstick | | Female, Heavy_Makeup, Mouth_Open, Wearing_Lipstick |
98 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Pointy_Nose, Smiling, Wearing_Lipstick | | Female, Heavy_Makeup, Pointy_Nose, Smiling, Wearing_Lipstick |
99 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Pointy_Nose, Smiling, Wearing_Lipstick | | Female, Big_Lips, Big_Nose, Heavy_Makeup, High_Cheekbones, Mouth_Open, Rosy_Cheeks, Smiling, Wearing_Lipstick |
100 | Female, Heavy_Makeup, Pointy_Nose, Wearing_Lipstick | | Female, Heavy_Makeup, Mouth_Open, Smiling, Wearing_Lipstick |
101 |
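In these experiments, the attribute condition is an 18-dimensional vector with entries in {−1, +1} (+1 when the attribute is present), matching the format written to `texts.txt` during testing. The sketch below shows one way to build such a vector; the attribute names are collected from the result tables in this section, but their exact ordering in `imAttrs.pkl` is an assumption:

```python
import numpy as np

# The 18 CelebA attributes used for conditioning. The names are taken from the
# result tables above; the ordering used in imAttrs.pkl may differ (assumption).
ATTRIBUTES = [
    'Male', '5_o_Clock_Shadow', 'Big_Lips', 'Big_Nose', 'Chubby',
    'Double_Chin', 'Eyeglasses', 'Goatee', 'Heavy_Makeup', 'High_Cheekbones',
    'Mouth_Open', 'Mustache', 'Narrow_Eyes', 'Pale_Skin', 'Pointy_Nose',
    'Rosy_Cheeks', 'Smiling', 'Wearing_Lipstick',
]

def attribute_vector(present):
    """Encode a set of attribute names as a +/-1 vector (+1 means present)."""
    return np.array([1.0 if a in present else -1.0 for a in ATTRIBUTES],
                    dtype=np.float32)

t = attribute_vector({'Male', 'Big_Nose', 'Goatee', 'Smiling'})
```

At test time, such a vector can also be supplied directly on the command line via the `--attributes` flag of `test.py` (18 values of `1` or `-1`).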
102 | #### b). Attributes mismatch sketch:
103 | The following images were generated given sketches and **random** attributes. The controlling effect of the attributes is still being improved.
104 |
105 | attributes |sketch / generated |attributes | sketch / generated |attributes | sketch / generated
106 | :-----------:|:----------------------:|:-----------:|:-----------------------:|:-----------:|:-----------------------:
107 | Female, Big_Lips, Heavy_Makeup, Wearing_Lipstick | | Female, Big_Lips, Heavy_Makeup, Wearing_Lipstick | | Male, Big_Nose, No_Eyeglasses |
108 | Male, Big_Nose, Chubby, Double_Chin, High_Cheekbones, Smiling | | Male, Big_Nose, Chubby, Double_Chin, High_Cheekbones, Mouth_Open, Smiling | | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick |
109 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick |
110 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick, No_Eyeglasses | | Male | | Female, Heavy_Makeup, Pale_Skin, Wearing_Lipstick |
111 | Female, Heavy_Makeup, High_Cheekbones, Pointy_Nose, Smiling, Wearing_Lipstick, No_Eyeglasses | | Female, Heavy_Makeup, High_Cheekbones, Pointy_Nose, Smiling, Wearing_Lipstick | | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Pointy_Nose, Rosy_Cheeks, Smiling, Wearing_Lipstick |
112 |
113 | #### c). Freehand sketch:
114 | The following images were generated given freehand sketches and **random** attributes. The controlling effect of the attributes is still being improved.
115 |
116 | attributes |sketch / generated |attributes | sketch / generated |attributes | sketch / generated
117 | :-----------:|:----------------------:|:-----------:|:-----------------------:|:-----------:|:-----------------------:
118 | Female, Big_Lips, Heavy_Makeup, Wearing_Lipstick | | Male, Big_Nose | | Male, Big_Nose, Chubby, Double_Chin, High_Cheekbones, Mouth_Open, Smiling |
119 | Male, Big_Nose, Chubby, Double_Chin, High_Cheekbones, Mouth_Open, Smiling | | Male, Big_Nose, Chubby, Double_Chin, High_Cheekbones, Mouth_Open, Smiling | | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick |
120 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick |
121 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | | Female, Big_Lips, Heavy_Makeup, High_Cheekbones, Mouth_Open, Narrow_Eyes, Smiling, Wearing_Lipstick | | Female, Big_Lips, Heavy_Makeup, High_Cheekbones, Mouth_Open, Narrow_Eyes, Smiling, Wearing_Lipstick |
122 | Female, Heavy_Makeup, High_Cheekbones, Pointy_Nose, Smiling, Wearing_Lipstick | | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick |
123 |
124 | ## Acknowledgement
125 | The code is based on [DCGAN](https://github.com/carpedm20/DCGAN-tensorflow) and [dcgan-completion](https://github.com/bamos/dcgan-completion.tensorflow).
126 | 127 | ## Citation 128 | Consider citing the following paper if you find this repo helpful: 129 | ``` 130 | @InProceedings{Lu_2018_ECCV, 131 | author = {Lu, Yongyi and Wu, Shangzhe and Tai, Yu-Wing and Tang, Chi-Keung}, 132 | title = {Image Generation from Sketch Constraint Using Contextual GAN}, 133 | booktitle = {The European Conference on Computer Vision (ECCV)}, 134 | month = {September}, 135 | year = {2018} 136 | } 137 | ``` 138 | -------------------------------------------------------------------------------- /datasets/download_dataset.sh: -------------------------------------------------------------------------------- 1 | # celeba dataset 2 | download_celeba(){ 3 | echo "----------------------- downloading celeba dataset -----------------------" 4 | wget -O celeba.tar.gz "https://www.robots.ox.ac.uk/~szwu/storage/18_sketch/celeba.tar.gz" 5 | tar xzvf celeba.tar.gz -C ./datasets/ 6 | rm celeba.tar.gz 7 | } 8 | 9 | # all datasets 10 | download_all(){ 11 | download_celeba 12 | echo "----------------------- done -----------------------" 13 | } 14 | 15 | download_all 16 | -------------------------------------------------------------------------------- /download_pretrained_model.sh: -------------------------------------------------------------------------------- 1 | # face model 2 | download_face(){ 3 | echo "----------------------- downloading face pretrained model -----------------------" 4 | wget -O face_pretrained.tar.gz "https://www.robots.ox.ac.uk/~szwu/storage/18_sketch/face_pretrained.tar.gz" 5 | tar xzvf face_pretrained.tar.gz 6 | rm face_pretrained.tar.gz 7 | } 8 | 9 | # all models 10 | download_all(){ 11 | download_face 12 | echo "----------------------- done -----------------------" 13 | } 14 | 15 | download_all 16 | -------------------------------------------------------------------------------- /imgs/examples/combine_A_and_B_2.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as st 2 | import os 3 | import numpy as np 4 | import cv2 5 | 6 | fold = 'handdraw' 7 | 8 | l = ['2_60', '3_54', '5_29', '5_41', '5_62', '8_53', '9_28', '9_39', '9_40', '11_08', '11_20', '11_30', '13_05', '13_17', '15_44', '17_07', '17_51'] 9 | 10 | num_imgs = len(l) 11 | for n in range(num_imgs): 12 | print('['+str(n)+'/'+str(num_imgs)+']') 13 | name = '{:02d}'.format(n+1) 14 | path_A = os.path.join(fold, l[n]+'_A.png') 15 | path_B = os.path.join(fold, l[n]+'_B.png') 16 | #path_C = os.path.join(fold, l[n]+'_C.png') 17 | if os.path.isfile(path_A) and os.path.isfile(path_B): 18 | path_ABC = os.path.join(fold, name+'.png') 19 | im_A = cv2.imread(path_A) 20 | im_B = cv2.imread(path_B) 21 | #im_C = cv2.imread(path_C) 22 | h, w, c = im_A.shape 23 | gap = np.ones((h, 5, c))*255 24 | im_ABC = np.concatenate([im_A, gap, im_B], 1) 25 | cv2.imwrite(path_ABC, im_ABC) 26 | 27 | -------------------------------------------------------------------------------- /imgs/examples/handdraw/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/01.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/02.png 
-------------------------------------------------------------------------------- /imgs/examples/handdraw/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/03.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/04.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/05.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/06.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/07.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/08.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/09.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/10.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/11.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/12.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/13.png 
-------------------------------------------------------------------------------- /imgs/examples/handdraw/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/14.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/15.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/16.png -------------------------------------------------------------------------------- /imgs/examples/handdraw/17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/17.png -------------------------------------------------------------------------------- /imgs/examples/match/eyeglasses/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/01.png -------------------------------------------------------------------------------- /imgs/examples/match/eyeglasses/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/02.png -------------------------------------------------------------------------------- /imgs/examples/match/eyeglasses/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/03.png -------------------------------------------------------------------------------- /imgs/examples/match/eyeglasses/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/04.png -------------------------------------------------------------------------------- /imgs/examples/match/eyeglasses/05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/05.png -------------------------------------------------------------------------------- /imgs/examples/match/eyeglasses/06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/06.png -------------------------------------------------------------------------------- /imgs/examples/match/eyeglasses/07.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/07.png -------------------------------------------------------------------------------- /imgs/examples/match/eyeglasses/08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/08.png -------------------------------------------------------------------------------- /imgs/examples/match/lipstick/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/01.png -------------------------------------------------------------------------------- /imgs/examples/match/lipstick/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/02.png -------------------------------------------------------------------------------- /imgs/examples/match/lipstick/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/03.png -------------------------------------------------------------------------------- /imgs/examples/match/lipstick/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/04.png -------------------------------------------------------------------------------- /imgs/examples/match/lipstick/05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/05.png -------------------------------------------------------------------------------- /imgs/examples/match/lipstick/06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/06.png -------------------------------------------------------------------------------- /imgs/examples/match/lipstick/07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/07.png -------------------------------------------------------------------------------- /imgs/examples/match/lipstick/08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/08.png -------------------------------------------------------------------------------- /imgs/examples/match/mustache/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/01.png -------------------------------------------------------------------------------- 
/imgs/examples/match/mustache/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/02.png -------------------------------------------------------------------------------- /imgs/examples/match/mustache/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/03.png -------------------------------------------------------------------------------- /imgs/examples/match/mustache/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/04.png -------------------------------------------------------------------------------- /imgs/examples/match/mustache/05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/05.png -------------------------------------------------------------------------------- /imgs/examples/match/mustache/06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/06.png -------------------------------------------------------------------------------- /imgs/examples/match/mustache/07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/07.png -------------------------------------------------------------------------------- /imgs/examples/match/mustache/08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/08.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/01.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/02.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/03.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/04.png 
-------------------------------------------------------------------------------- /imgs/examples/mismatch/05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/05.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/06.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/07.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/08.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/09.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/10.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/11.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/12.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/13.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/14.png -------------------------------------------------------------------------------- /imgs/examples/mismatch/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/15.png 
-------------------------------------------------------------------------------- /imgs/examples/mismatch/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/16.png -------------------------------------------------------------------------------- /imgs/handdraw/bird_handdraw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/handdraw/bird_handdraw.png -------------------------------------------------------------------------------- /imgs/handdraw/face_handdraw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/handdraw/face_handdraw.png -------------------------------------------------------------------------------- /imgs/handdraw/shoe_handdraw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/handdraw/shoe_handdraw.png -------------------------------------------------------------------------------- /imgs/problem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/problem.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Original Version: Taehoon Kim (http://carpedm20.github.io) 2 | # + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/model.py 3 | # + License: MIT 4 | # [2016-08-05] Modifications for Completion: Brandon Amos (http://bamos.github.io) 5 | # + License: MIT 6 | # [2017-07] Modifications for sText2Image: Shangzhe Wu 7 | # + License: MIT 8 | 9 | from __future__ import division 10 | import os 11 | import time 12 | from glob import glob 13 | import tensorflow as tf 14 | import pickle 15 | from six.moves import xrange 16 | from scipy.stats import entropy 17 | 18 | from ops import * 19 | from utils import * 20 | 21 | #import pdb 22 | 23 | class GAN(object): 24 | def __init__(self, sess, image_size=64, is_crop=False, 25 | batch_size=64, text_vector_dim=100, 26 | z_dim=100, t_dim=256, gf_dim=64, df_dim=64, c_dim=3, 27 | checkpoint_dir=None, sample_dir=None, log_dir=None, 28 | lam1=0.1, lam2=0.1, lam3=0.1): 29 | """ 30 | 31 | Args: 32 | sess: TensorFlow session 33 | batch_size: The size of batch. Should be specified before training. 34 | z_dim: (optional) Dimension of dim for Z. [100] 35 | t_dim: (optional) Dimension of text features. [256] 36 | gf_dim: (optional) Dimension of gen filters in first conv layer. [64] 37 | df_dim: (optional) Dimension of discrim filters in first conv layer. [64] 38 | c_dim: (optional) Dimension of image color. [3] 39 | lam1: (optional) Hyperparameter for contextual loss. [0.1] 40 | lam2: (optional) Hyperparameter for perceptual loss. 
[0.1] 41 | lam3: (optional) Hyperparameter for wrong examples [0.1] 42 | """ 43 | self.sess = sess 44 | self.is_crop = is_crop 45 | self.batch_size = batch_size 46 | self.text_vector_dim = text_vector_dim 47 | self.image_size = image_size 48 | self.image_shape = [image_size, image_size * 2, 3] 49 | # self.image_shape = [image_size, image_size, 3] 50 | 51 | self.sample_freq = int(100*64/batch_size) 52 | self.save_freq = int(500*64/batch_size) 53 | 54 | self.z_dim = z_dim 55 | self.t_dim = t_dim 56 | 57 | self.gf_dim = gf_dim 58 | self.df_dim = df_dim 59 | 60 | self.lam1 = lam1 61 | self.lam2 = lam2 62 | self.lam3 = lam3 63 | 64 | self.c_dim = 3 65 | 66 | # batch normalization : deals with poor initialization helps gradient flow 67 | self.d_bn1 = batch_norm(name='d_bn1') 68 | self.d_bn2 = batch_norm(name='d_bn2') 69 | self.d_bn3 = batch_norm(name='d_bn3') 70 | self.d_bn4 = batch_norm(name='d_bn4') 71 | 72 | self.g_bn0 = batch_norm(name='g_bn0') 73 | self.g_bn1 = batch_norm(name='g_bn1') 74 | self.g_bn2 = batch_norm(name='g_bn2') 75 | self.g_bn3 = batch_norm(name='g_bn3') 76 | 77 | self.checkpoint_dir = checkpoint_dir 78 | self.sample_dir = sample_dir 79 | self.log_dir = log_dir 80 | 81 | self.build_model() 82 | 83 | self.model_name = "GAN" 84 | 85 | 86 | def build_model(self): 87 | self.images = tf.placeholder( 88 | tf.float32, [self.batch_size] + self.image_shape, name='real_images') 89 | self.sample_images= tf.placeholder( 90 | tf.float32, [self.batch_size] + self.image_shape, name='sample_images') 91 | 92 | self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z') 93 | self.z_sum = tf.summary.histogram("z", self.z) 94 | 95 | self.t = tf.placeholder(tf.float32, [self.batch_size, self.text_vector_dim], name='t') 96 | self.t_sum = tf.summary.histogram("t", self.t) 97 | 98 | self.t_wr = tf.placeholder(tf.float32, [self.batch_size, self.text_vector_dim], name='t_wr') 99 | self.t_wr_sum = tf.summary.histogram("t_wr", self.t_wr) 100 | 101 | #self.images_wr = tf.placeholder( 102 | # tf.float32, [self.batch_size] + self.image_shape, name='wrong_images') 103 | 104 | self.G = self.generator(self.z, self.t) 105 | self.D_rl, self.D_logits_rl = self.discriminator(self.images, self.t) 106 | self.D_fk, self.D_logits_fk = self.discriminator(self.G, self.t, reuse=True) 107 | self.D_wr, self.D_logits_wr = self.discriminator(self.images, self.t_wr, reuse=True) 108 | 109 | self.sampler = self.sampler(self.z, self.t) 110 | 111 | self.G_sum = tf.image_summary("G", self.G) 112 | self.d_rl_sum = tf.summary.histogram("d", self.D_rl) 113 | self.d_fk_sum = tf.summary.histogram("d_", self.D_fk) 114 | self.d_wr_sum = tf.summary.histogram("d_wr", self.D_wr) 115 | 116 | # cross entropy loss 117 | self.g_loss = tf.reduce_mean( 118 | tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_fk, 119 | tf.ones_like(self.D_fk))) 120 | self.d_loss_real = tf.reduce_mean( 121 | tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_rl, 122 | tf.ones_like(self.D_rl))) 123 | self.d_loss_fake = tf.reduce_mean( 124 | tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_fk, 125 | tf.zeros_like(self.D_fk))) 126 | self.d_loss_wrong = tf.reduce_mean( 127 | tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_wr, 128 | tf.zeros_like(self.D_wr))) 129 | 130 | ''' 131 | # least square loss 132 | self.d_loss_real = 0.5 * tf.reduce_mean((self.D_logits_rl - tf.ones_like(self.D_logits_rl))**2) 133 | self.d_loss_fake = 0.5 * tf.reduce_mean((self.D_logits_fk - tf.zeros_like(self.D_logits_fk))**2) 134 | self.d_loss_wrong = 0.5 * 
tf.reduce_mean((self.D_logits_wr - tf.zeros_like(self.D_logits_wr))**2) 135 | self.g_loss = 0.5 * tf.reduce_mean((self.D_logits_fk - tf.ones_like(self.D_logits_fk))**2) 136 | ''' 137 | 138 | self.d_loss_real_sum = tf.scalar_summary("d_loss_real", self.d_loss_real) 139 | self.d_loss_fake_sum = tf.scalar_summary("d_loss_fake", self.d_loss_fake) 140 | self.d_loss_wrong_sum = tf.scalar_summary("d_loss_wrong", self.d_loss_wrong) 141 | 142 | self.d_loss = self.d_loss_real + self.d_loss_fake + self.lam3 * self.d_loss_wrong 143 | 144 | self.g_loss_sum = tf.scalar_summary("g_loss", self.g_loss) 145 | self.d_loss_sum = tf.scalar_summary("d_loss", self.d_loss) 146 | 147 | t_vars = tf.trainable_variables() 148 | 149 | self.d_vars = [var for var in t_vars if 'd_' in var.name] 150 | self.g_vars = [var for var in t_vars if 'g_' in var.name] 151 | 152 | self.saver = tf.train.Saver(max_to_keep=50) 153 | 154 | # mask to generate 155 | self.mask = tf.placeholder(tf.float32, [None] + self.image_shape, name='mask') 156 | 157 | # l1 158 | #self.contextual_loss = tf.reduce_sum( 159 | # tf.contrib.layers.flatten( 160 | # tf.abs(tf.mul(self.mask, self.G) - tf.mul(self.mask, self.images))), 1) 161 | 162 | # kl divergence 163 | self.contextual_loss = kl_divergence( 164 | tf.divide(tf.add(tf.contrib.layers.flatten(tf.image.rgb_to_grayscale( 165 | tf.slice(self.G, [0,0,0,0], [self.batch_size,self.image_size,self.image_size,self.c_dim]))), 1), 2), 166 | tf.divide(tf.add(tf.contrib.layers.flatten(tf.image.rgb_to_grayscale( 167 | tf.slice(self.images, [0,0,0,0], [self.batch_size,self.image_size,self.image_size,self.c_dim]))), 1), 2)) 168 | 169 | self.perceptual_loss = self.g_loss 170 | self.complete_loss = self.lam1*self.contextual_loss + self.lam2*self.perceptual_loss 171 | self.grad_complete_loss = tf.gradients(self.complete_loss, self.z) 172 | 173 | 174 | def train(self, config): 175 | image_data = glob(os.path.join(config.dataset, "*.png")) 176 | #np.random.shuffle(data) 177 | print (os.path.join(config.dataset, "*.png")) 178 | assert(len(image_data) > 0) 179 | 180 | text_data = pickle.load(open(config.text_path, 'rb')) 181 | 182 | ######### for face attributes ######### 183 | attr_sum = np.sum(text_data, 0) 184 | attr_percent = (1 + attr_sum/len(text_data)) / 2 185 | print ("selected attribute percentages:\n", attr_percent) 186 | ######### for face attributes ######### 187 | 188 | d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \ 189 | .minimize(self.d_loss, var_list=self.d_vars) 190 | g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \ 191 | .minimize(self.g_loss, var_list=self.g_vars) 192 | tf.initialize_all_variables().run() 193 | 194 | self.g_sum = tf.merge_summary( 195 | [self.z_sum, self.t_sum, self.d_fk_sum, self.G_sum, self.d_loss_fake_sum, self.g_loss_sum]) 196 | self.d_sum = tf.merge_summary( 197 | [self.z_sum, self.t_sum, self.d_rl_sum, self.d_wr_sum, self.d_loss_real_sum, self.d_loss_wrong_sum, self.d_loss_sum]) 198 | self.writer = tf.train.SummaryWriter(self.log_dir, self.sess.graph) 199 | 200 | 201 | #++++++++ training sample ++++++++# 202 | 203 | sample_z = np.random.uniform(-1, 1, size=(self.batch_size , self.z_dim)) 204 | sample_files = image_data[0:self.batch_size] 205 | sample = [get_image(sample_file, self.image_size, is_crop=self.is_crop) for sample_file in sample_files] 206 | sample_images = np.array(sample).astype(np.float32) 207 | sample_t_ = [get_text_batch(os.path.basename(sample_file), text_data) for sample_file in sample_files] 208 | 
sample_t = np.array(sample_t_).astype(np.float32) 209 | nRows = np.ceil(self.batch_size/8) 210 | nCols = min(8, self.batch_size) #8 211 | 212 | ######### for face attributes ######### 213 | with open(os.path.join(self.sample_dir, 'sampled_texts.txt'), 'wb') as f: 214 | np.savetxt(f, sample_t, fmt='%i', delimiter='\t') 215 | ######### for face attributes ######### 216 | 217 | #-------- training sample --------# 218 | 219 | counter = 1 220 | start_time = time.time() 221 | 222 | if self.load(self.checkpoint_dir): 223 | print(""" 224 | 225 | ============ 226 | An existing model was found in the checkpoint directory. 227 | If you just cloned this repository, it's Brandon Amos' 228 | trained model for faces that's used in the post. 229 | If you want to train a new model from scratch, 230 | delete the checkpoint directory or specify a different 231 | --checkpoint_dir argument. 232 | ============ 233 | 234 | """) 235 | else: 236 | print(""" 237 | 238 | ============ 239 | An existing model was not found in the checkpoint directory. 240 | Initializing a new one. 241 | ============ 242 | 243 | """) 244 | 245 | for epoch in xrange(config.epoch): 246 | image_data = glob(os.path.join(config.dataset, "*.png")) 247 | batch_idxs = min(len(image_data), config.train_size) // self.batch_size 248 | 249 | for idx in xrange(0, batch_idxs): 250 | 251 | #++++++++ data loading ++++++++# 252 | 253 | data_start_time = time.time() 254 | batch_files = image_data[idx*config.batch_size:(idx+1)*config.batch_size] 255 | batch = [get_image(batch_file, self.image_size, is_crop=self.is_crop) 256 | for batch_file in batch_files] 257 | batch_images = np.array(batch).astype(np.float32) 258 | 259 | batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \ 260 | .astype(np.float32) 261 | 262 | batch_t_ = [get_text_batch(os.path.basename(batch_file), text_data) 263 | for batch_file in batch_files] 264 | batch_t = np.array(batch_t_).astype(np.float32) 265 | 266 | ######### for face attributes ######### 267 | # randomly generated wrong face attributes 268 | batch_t_wr_ = [np.random.choice(np.arange(2), size=config.batch_size, 269 | p=[1-attr_percent[i], attr_percent[i]]) * 2 - 1 270 | for i in xrange(self.text_vector_dim)] 271 | batch_t_wr = np.transpose(batch_t_wr_).astype(np.float32) 272 | ######### for face attributes ######### 273 | 274 | ''' 275 | # randomly select wrong images 276 | idx_wr = np.random.randint(batch_idxs) 277 | while (idx_wr == idx): 278 | idx_wr = np.random.randint(batch_idxs) 279 | 280 | batch_files_wr = image_data[idx_wr*config.batch_size:(idx_wr+1)*config.batch_size] 281 | batch_wr = [get_image(batch_file_wr, self.image_size, is_crop=self.is_crop) 282 | for batch_file_wr in batch_files_wr] 283 | batch_images_wr = np.array(batch_wr).astype(np.float32) 284 | ''' 285 | data_time = time.time() - data_start_time 286 | 287 | #-------- data loading --------# 288 | 289 | 290 | #++++++++ training ++++++++# 291 | 292 | # Update D network 293 | _, summary_str = self.sess.run([d_optim, self.d_sum], 294 | feed_dict={ self.images: batch_images, self.z: batch_z, self.t: batch_t, self.t_wr: batch_t_wr }) 295 | self.writer.add_summary(summary_str, counter) 296 | 297 | # Update G network 298 | _, summary_str = self.sess.run([g_optim, self.g_sum], 299 | feed_dict={ self.z: batch_z, self.t: batch_t }) 300 | self.writer.add_summary(summary_str, counter) 301 | 302 | # Run g_optim twice to make sure that d_loss does not go to zero (different from paper) 303 | _, summary_str = self.sess.run([g_optim, self.g_sum], 304 | 
feed_dict={ self.z: batch_z, self.t: batch_t }) 305 | self.writer.add_summary(summary_str, counter) 306 | 307 | errD_fake = self.d_loss_fake.eval({self.z: batch_z, self.t: batch_t}) 308 | errD_real = self.d_loss_real.eval({self.images: batch_images, self.t: batch_t}) 309 | errG = self.g_loss.eval({self.z: batch_z, self.t: batch_t}) 310 | #-------- training --------# 311 | 312 | counter += 1 313 | print("Epoch: [%2d] [%4d/%4d] data_time: %4.4f, time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ 314 | % (epoch, idx, batch_idxs, data_time, 315 | time.time() - start_time, errD_fake+errD_real, errG)) 316 | 317 | if np.mod(counter, self.sample_freq) == 1: 318 | samples = self.sess.run( 319 | [self.sampler], feed_dict={self.z: sample_z, self.t: sample_t}) 320 | save_images(samples[0], [nRows, nCols], 321 | os.path.join(self.sample_dir, 'train_{:02d}_{:04d}.png'.format(epoch, idx))) 322 | #print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) 323 | 324 | if np.mod(counter, self.save_freq) == 2: 325 | self.save(config.checkpoint_dir, counter) 326 | 327 | 328 | def test(self, config): 329 | tf.initialize_all_variables().run() 330 | 331 | isLoaded = self.load(self.checkpoint_dir) 332 | assert(isLoaded) 333 | 334 | # image_data = glob(os.path.join(config.dataset, "*.png")) 335 | nImgs = len(config.imgs) 336 | 337 | batch_idxs = int(np.ceil(nImgs/self.batch_size)) 338 | if config.maskType == 'right': 339 | mask = np.ones(self.image_shape) 340 | mask[:,self.image_size:,:] = 0.0 341 | elif config.maskType == 'left': 342 | mask = np.ones(self.image_shape) 343 | mask[:,:self.image_size,:] = 0.0 344 | else: 345 | assert(False) 346 | 347 | text_data = pickle.load(open(config.text_path, 'rb')) 348 | 349 | num_batch = int(np.ceil(nImgs/self.batch_size)) 350 | for idx in xrange(0, num_batch): 351 | print('batch no. 
' + str(idx+1) + ':\n')
352 |
353 | l = idx*self.batch_size
354 | u = min((idx+1)*self.batch_size, nImgs)
355 | batchSz = u-l
356 | batch_files = config.imgs[l:u]
357 | batch = [get_image(batch_file, self.image_size, is_crop=self.is_crop)
358 | for batch_file in batch_files]
359 | batch_images = np.array(batch).astype(np.float32)
360 | batch_mask = np.resize(mask, [self.batch_size] + self.image_shape)
361 |
362 | os.makedirs(os.path.join(config.outDir, 'hats_imgs_{:04d}'.format(idx)))
363 | os.makedirs(os.path.join(config.outDir, 'completed_{:04d}'.format(idx)))
364 |
365 | #++++++++ for face attributes ++++++++#
366 |
367 | # attributes to be loaded
368 | if config.attributes[0] is None:
369 | batch_t_ = [get_text_batch(os.path.basename(batch_file), text_data)
370 | for batch_file in batch_files]
371 | batch_t = np.array(batch_t_).astype(np.float32)
372 |
373 | # user-defined attributes
374 | else:
375 | attr_v = config.attributes
376 | assert len(attr_v) == self.text_vector_dim, "attribute vector must have the given length"
377 | print('using attributes: ', attr_v)
378 | batch_t = np.array([attr_v,]*batchSz).astype(np.float32)
379 |
380 | with open(os.path.join(config.outDir, 'completed_{:04d}/texts.txt'.format(idx)), 'wb') as f:
381 | np.savetxt(f, batch_t, fmt='%i', delimiter='\t')
382 |
383 | #-------- for face attributes --------#
384 |
385 | # last batch
386 | if batchSz < self.batch_size:
387 | print(batchSz)
388 | padSz = ((0, int(self.batch_size-batchSz)), (0,0), (0,0), (0,0))
389 | batch_images = np.pad(batch_images, padSz, 'wrap')
390 | batch_images = batch_images.astype(np.float32)
391 | batch_t = np.pad(batch_t, ((0, int(self.batch_size-batchSz)), (0,0)), 'wrap')
392 |
393 |
394 | nRows = np.ceil(batchSz/8)
395 | nCols = min(8, batchSz) #8
396 |
397 |
398 | #++++++++ z initialization ++++++++#
399 |
400 | zhats_init = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim)).astype(np.float32)
401 | zhats_ = zhats_init.copy()
402 | kl_div = np.full(len(zhats_), np.inf)
403 | in_flat = [rgb2gray(img[:,:self.image_size,:]).flatten() for img in batch_images]
404 | in_flat = np.array(in_flat) + 1
405 | kld_avg = 0
406 |
407 | kld_f = open(os.path.join(config.outDir, 'hats_imgs_{:04d}/kld_init.txt'.format(idx)), 'w')
408 | kld_f.write('average kl divergence of initializations:')
409 | for i in xrange(30):
410 | G_imgs = self.sess.run([self.G], feed_dict={ self.z: zhats_, self.t: batch_t })
411 | save_images(G_imgs[0][:batchSz,:,:,:], [nRows, nCols],
412 | os.path.join(config.outDir, 'hats_imgs_{:04d}/init_{:02d}.png'.format(idx, i)))
413 |
414 | out_flat = [rgb2gray(img[:,:self.image_size,:]).flatten() for img in G_imgs[0]]
415 | out_flat = np.array(out_flat) + 1
416 |
417 | # choose lowest kl divergence
418 | for j in xrange(self.batch_size):
419 | kl_d = entropy(in_flat[j], out_flat[j])
420 | if (kl_d < kl_div[j]):
421 | zhats_init[j] = zhats_[j]
422 | kl_div[j] = kl_d
423 |
424 | kld_avg = kl_div.mean()
425 | print('average KL divergence:', kld_avg)
426 | kld_f.write('{:02d}: {:04.4f}'.format(i, kld_avg))
427 |
428 | zhats_ = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim)).astype(np.float32)
429 |
430 | print('choosing min KL divergence:', kld_avg)
431 | kld_f.write('choosing min KL divergence: {:04.4f}'.format(kld_avg))
432 | kld_f.close()
433 |
434 | G_imgs = self.sess.run([self.G], feed_dict={ self.z: zhats_init, self.t: batch_t })
435 | save_images(G_imgs[0][:batchSz,:,:,:], [nRows, nCols],
436 | os.path.join(config.outDir,
'hats_imgs_{:04d}/chosen_init.png'.format(idx))) 437 | 438 | #-------- z initialization --------# 439 | 440 | 441 | #++++++++ completion ++++++++# 442 | 443 | zhats = zhats_init.copy().astype(np.float32) 444 | v = 0 445 | 446 | save_images(batch_images[:batchSz,:,:,:], [nRows,nCols], 447 | os.path.join(config.outDir, 'hats_imgs_{:04d}/gt.png'.format(idx))) 448 | masked_images = np.multiply(batch_images, batch_mask) 449 | save_images(masked_images[:batchSz,:,:,:], [nRows,nCols], 450 | os.path.join(config.outDir, 'hats_imgs_{:04d}/masked.png'.format(idx))) 451 | 452 | for i in xrange(config.nIter): 453 | fd = { 454 | self.z: zhats, 455 | self.mask: batch_mask, 456 | self.images: batch_images, 457 | self.t: batch_t, 458 | } 459 | run = [self.complete_loss, self.grad_complete_loss, self.G] 460 | loss, g, G_imgs = self.sess.run(run, feed_dict=fd) 461 | 462 | # update zhats 463 | v_prev = np.copy(v) 464 | v = config.momentum*v - config.lr*g[0] 465 | zhats += -config.momentum * v_prev + (1+config.momentum)*v 466 | zhats = np.clip(zhats, -1, 1) 467 | 468 | # save images 469 | if i % 20 == 0: 470 | print(i, np.mean(loss[0:batchSz])) 471 | imgName = os.path.join(config.outDir, 472 | 'hats_imgs_{:04d}/{:04d}.png'.format(idx, i)) 473 | save_images(G_imgs[:batchSz,:,:,:], [nRows,nCols], imgName) 474 | 475 | inv_masked_hat_images = np.multiply(G_imgs, 1.0-batch_mask) 476 | completed = masked_images + inv_masked_hat_images 477 | imgName = os.path.join(config.outDir, 478 | 'completed_{:04d}/{:04d}.png'.format(idx, i)) 479 | save_images(completed[:batchSz,:,:,:], [nRows,nCols], imgName) 480 | 481 | #-------- completion --------# 482 | 483 | 484 | #++++++++ interpolation visualization ++++++++# 485 | 486 | zhats_final = np.copy(zhats) 487 | diff = zhats_final - zhats_init 488 | step = 5 489 | for i in xrange(step): 490 | z_ = zhats_init + diff / (step-1) * i 491 | G_imgs = self.sess.run([self.G], feed_dict={ self.z: z_, self.t: batch_t }) 492 | imgName = os.path.join(config.outDir, 'hats_imgs_{:04d}/{:01d}_interp.png'.format(idx, i)) 493 | save_images(G_imgs[0][:batchSz,:,:,:], [nRows,nCols], imgName) 494 | 495 | #-------- interpolation visualization --------# 496 | 497 | 498 | def discriminator(self, image, t, reuse=False): 499 | if reuse: 500 | tf.get_variable_scope().reuse_variables() 501 | 502 | h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv')) 503 | 504 | t_ = tf.expand_dims(t, 1) 505 | t_ = tf.expand_dims(t_, 2) 506 | t_tiled = tf.tile(t_, [1,32,64,1], name='tiled_t') 507 | h0_concat = tf.concat(3, [h0, t_tiled], name='h0_concat') 508 | 509 | h1 = lrelu(self.d_bn1(conv2d(h0_concat, self.df_dim*2, name='d_h1_conv'))) 510 | h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv'))) 511 | h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, name='d_h3_conv'))) 512 | 513 | #h4 = linear(tf.reshape(h3, [-1, 8192*2]), 1, 'd_h3_lin') 514 | # conv to 512x1x1 515 | h4 = conv2d(h3, self.df_dim*8, 4, 8, 1, 1, name='d_h4_conv') 516 | 517 | return tf.nn.sigmoid(h4), h4 518 | 519 | 520 | def generator(self, z, t): 521 | 522 | self.z_, self.h0_lin_w, self.h0_lin_b = linear(z, self.gf_dim*4*8, 'g_h0_lin', with_w=True) 523 | z_ = tf.reshape(self.z_, [-1, 4, 8, self.gf_dim]) 524 | 525 | t_ = tf.expand_dims(tf.expand_dims(t, 1), 2) 526 | t_tiled = tf.tile(t_, [1,4,8,1]) 527 | 528 | h0_concat = tf.concat(3, [z_, t_tiled]) 529 | 530 | self.h0, self.h0_w, self.h0_b = conv2d_transpose(h0_concat, 531 | [self.batch_size, 4, 8, self.gf_dim*8], 1, 1, 1, 1, name='g_h0', with_w=True) 532 | h0 = 
tf.nn.relu(self.g_bn0(self.h0)) 533 | 534 | self.h1, self.h1_w, self.h1_b = conv2d_transpose(h0, 535 | [self.batch_size, 8, 16, self.gf_dim*4], name='g_h1', with_w=True) 536 | h1 = tf.nn.relu(self.g_bn1(self.h1)) 537 | 538 | h2, self.h2_w, self.h2_b = conv2d_transpose(h1, 539 | [self.batch_size, 16, 32, self.gf_dim*2], name='g_h2', with_w=True) 540 | h2 = tf.nn.relu(self.g_bn2(h2)) 541 | 542 | h3, self.h3_w, self.h3_b = conv2d_transpose(h2, 543 | [self.batch_size, 32, 64, self.gf_dim*1], name='g_h3', with_w=True) 544 | h3 = tf.nn.relu(self.g_bn3(h3)) 545 | 546 | h4, self.h4_w, self.h4_b = conv2d_transpose(h3, 547 | [self.batch_size, 64, 128, 3], name='g_h4', with_w=True) 548 | 549 | return tf.nn.tanh(h4) 550 | 551 | 552 | def sampler(self, z, t, y=None): 553 | tf.get_variable_scope().reuse_variables() 554 | 555 | z_ = tf.reshape(linear(z, self.gf_dim*4*8, 'g_h0_lin'), [-1, 4, 8, self.gf_dim]) 556 | 557 | t_ = tf.expand_dims(tf.expand_dims(t, 1), 2) 558 | t_tiled = tf.tile(t_, [1,4,8,1]) 559 | 560 | h0_concat = tf.concat(3, [z_, t_tiled]) 561 | 562 | h0 = conv2d_transpose(h0_concat, 563 | [self.batch_size, 4, 8, self.gf_dim*8], 1, 1, 1, 1, name='g_h0') 564 | h0 = tf.nn.relu(self.g_bn0(h0, train=False)) 565 | 566 | h1 = conv2d_transpose(h0, [self.batch_size, 8, 16, self.gf_dim*4], name='g_h1') 567 | h1 = tf.nn.relu(self.g_bn1(h1, train=False)) 568 | 569 | h2 = conv2d_transpose(h1, [self.batch_size, 16, 32, self.gf_dim*2], name='g_h2') 570 | h2 = tf.nn.relu(self.g_bn2(h2, train=False)) 571 | 572 | h3 = conv2d_transpose(h2, [self.batch_size, 32, 64, self.gf_dim*1], name='g_h3') 573 | h3 = tf.nn.relu(self.g_bn3(h3, train=False)) 574 | 575 | h4 = conv2d_transpose(h3, [self.batch_size, 64, 128, 3], name='g_h4') 576 | 577 | return tf.nn.tanh(h4) 578 | 579 | 580 | def save(self, checkpoint_dir, step): 581 | if not os.path.exists(checkpoint_dir): 582 | os.makedirs(checkpoint_dir) 583 | 584 | self.saver.save(self.sess, 585 | os.path.join(checkpoint_dir, self.model_name), 586 | global_step=step) 587 | 588 | 589 | def load(self, checkpoint_dir): 590 | print(" [*] Reading checkpoints...") 591 | 592 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 593 | if ckpt and ckpt.model_checkpoint_path: 594 | self.saver.restore(self.sess, ckpt.model_checkpoint_path) 595 | return True 596 | else: 597 | return False 598 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | # Original Version: Taehoon Kim (http://carpedm20.github.io) 2 | # + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/ops.py 3 | # + License: MIT 4 | # [2017-07] Modifications for sText2Image: Shangzhe Wu 5 | # + License: MIT 6 | 7 | import math 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | from tensorflow.python.framework import ops 12 | 13 | from utils import * 14 | 15 | class batch_norm(object): 16 | """Code modification of http://stackoverflow.com/a/33950177""" 17 | def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"): 18 | with tf.variable_scope(name): 19 | self.epsilon = epsilon 20 | self.momentum = momentum 21 | 22 | self.ema = tf.train.ExponentialMovingAverage(decay=self.momentum) 23 | self.name = name 24 | 25 | def __call__(self, x, train=True): 26 | shape = x.get_shape().as_list() 27 | 28 | if train: 29 | with tf.variable_scope(self.name) as scope: 30 | self.beta = tf.get_variable("beta", [shape[-1]], 31 | 
                                            initializer=tf.constant_initializer(0.))
32 |                 self.gamma = tf.get_variable("gamma", [shape[-1]],
33 |                                              initializer=tf.random_normal_initializer(1., 0.02))
34 | 
35 |                 batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
36 |                 with tf.variable_scope(tf.get_variable_scope(), reuse=False):
37 |                     ema_apply_op = self.ema.apply([batch_mean, batch_var])
38 |                     self.ema_mean, self.ema_var = self.ema.average(batch_mean), self.ema.average(batch_var)
39 | 
40 |                 with tf.control_dependencies([ema_apply_op]):
41 |                     mean, var = tf.identity(batch_mean), tf.identity(batch_var)
42 |         else:
43 |             mean, var = self.ema_mean, self.ema_var  # requires a prior train=True call, which creates the EMA shadows
44 | 
45 |         normed = tf.nn.batch_norm_with_global_normalization(
46 |             x, mean, var, self.beta, self.gamma, self.epsilon, scale_after_normalization=True)
47 | 
48 |         return normed
49 | 
50 | def binary_cross_entropy(preds, targets, name=None):
51 |     """Computes binary cross entropy given `preds`.
52 | 
53 |     For brevity, let `x = preds`, `z = targets`. The logistic loss is
54 | 
55 |         loss(x, z) = - sum_i (z[i] * log(x[i]) + (1 - z[i]) * log(1 - x[i]))
56 | 
57 |     Args:
58 |         preds: A `Tensor` of type `float32` or `float64`.
59 |         targets: A `Tensor` of the same type and shape as `preds`.
60 |     """
61 |     eps = 1e-12
62 |     with ops.op_scope([preds, targets], name, "bce_loss") as name:
63 |         preds = ops.convert_to_tensor(preds, name="preds")
64 |         targets = ops.convert_to_tensor(targets, name="targets")
65 |         return tf.reduce_mean(-(targets * tf.log(preds + eps) +
66 |                                 (1. - targets) * tf.log(1. - preds + eps)))
67 | 
68 | def conv_cond_concat(x, y):
69 |     """Concatenate conditioning vector on feature map axis."""
70 |     x_shapes = x.get_shape()
71 |     y_shapes = y.get_shape()
72 |     return tf.concat(3, [x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])])
73 | 
74 | def conv2d(input_, output_dim,
75 |            k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
76 |            name="conv2d"):
77 |     with tf.variable_scope(name):
78 |         w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
79 |                             initializer=tf.truncated_normal_initializer(stddev=stddev))
80 |         conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
81 | 
82 |         biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
83 |         # conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
84 |         conv = tf.nn.bias_add(conv, biases)
85 | 
86 |         return conv
87 | 
88 | def conv2d_transpose(input_, output_shape,
89 |                      k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
90 |                      name="conv2d_transpose", with_w=False):
91 |     with tf.variable_scope(name):
92 |         # filter : [height, width, output_channels, in_channels]
93 |         w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
94 |                             initializer=tf.random_normal_initializer(stddev=stddev))
95 | 
96 |         try:
97 |             deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
98 |                                             strides=[1, d_h, d_w, 1])
99 | 
100 |         # Support for versions of TensorFlow before 0.7.0
101 |         except AttributeError:
102 |             deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
103 |                                     strides=[1, d_h, d_w, 1])
104 | 
105 |         biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
106 |         # deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
107 |         deconv = tf.nn.bias_add(deconv, biases)
108 | 
109 |         if with_w:
110 |             return deconv, w, biases
111 |         else:
112 |             return deconv
113 | 
114 | def lrelu(x, leak=0.2, name="lrelu"):
115 |     with tf.variable_scope(name):
116 |         f1 = 0.5 * (1 + leak)
117 |         f2 = 0.5 * (1 - leak)
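        # Why the return below equals a leaky ReLU: for x >= 0,
        # f1*x + f2*|x| = 0.5*(1+leak)*x + 0.5*(1-leak)*x = x, while for x < 0,
        # |x| = -x gives 0.5*(1+leak)*x - 0.5*(1-leak)*x = leak*x. So this
        # computes tf.maximum(x, leak*x) without a branch.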
118 |         return f1 * x + f2 * abs(x)
119 | 
120 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
121 |     shape = input_.get_shape().as_list()
122 | 
123 |     with tf.variable_scope(scope or "Linear"):
124 |         matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
125 |                                  tf.random_normal_initializer(stddev=stddev))
126 |         bias = tf.get_variable("bias", [output_size],
127 |                                initializer=tf.constant_initializer(bias_start))
128 |         if with_w:
129 |             return tf.matmul(input_, matrix) + bias, matrix, bias
130 |         else:
131 |             return tf.matmul(input_, matrix) + bias
132 | 
133 | 
134 | ######## Elliott ########
135 | def kl_divergence(p, q):
136 |     tf.assert_rank(p,2)  # note: these assert ops only fire if run or used as control dependencies
137 |     tf.assert_rank(q,2)
138 | 
139 |     p_shape = tf.shape(p)
140 |     q_shape = tf.shape(q)
141 |     tf.assert_equal(p_shape, q_shape)
142 | 
143 |     # normalize sum to 1
144 |     p_ = tf.divide(p, tf.tile(tf.expand_dims(tf.reduce_sum(p,axis=1), 1), [1,p_shape[1]]))
145 |     q_ = tf.divide(q, tf.tile(tf.expand_dims(tf.reduce_sum(q,axis=1), 1), [1,p_shape[1]]))
146 | 
147 |     return tf.reduce_sum(tf.multiply(p_, tf.log(tf.divide(p_, q_))), axis=1)  # row-wise KL(p_ || q_)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # [2016-08-05] Modifications for Completion: Brandon Amos (http://bamos.github.io)
2 | # + License: MIT
3 | # [2017-07] Modifications for sText2Image: Shangzhe Wu
4 | # + License: MIT
5 | 
6 | import argparse
7 | import os
8 | import tensorflow as tf
9 | 
10 | from model import GAN
11 | 
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--lr', type=float, default=0.01)
14 | parser.add_argument('--momentum', type=float, default=0.9)
15 | parser.add_argument('--nIter', type=int, default=1000)
16 | parser.add_argument('--imgSize', type=int, default=64)
17 | parser.add_argument('--batchSize', type=int, default=64)
18 | parser.add_argument('--text_vector_dim', type=int, default=100)
19 | parser.add_argument('--lam1', type=float, default=0.1) # Hyperparameter for contextual loss [0.1]
20 | parser.add_argument('--lam2', type=float, default=0.1) # Hyperparameter for perceptual loss [0.1]
21 | #parser.add_argument('--lam3', type=float, default=0.1) # Hyperparameter for wrong example [0.1]
22 | parser.add_argument('--checkpointDir', type=str, default='checkpoint')
23 | parser.add_argument('--outDir', type=str, default='results')
24 | parser.add_argument('--text_path', type=str, default='text_embeddings.pkl')
25 | parser.add_argument('--maskType', type=str,
26 |                     choices=['random', 'center', 'left', 'right', 'full'],
27 |                     default='right')
28 | parser.add_argument('--attributes', nargs='+', type=int, default=[None])
29 | parser.add_argument('imgs', type=str, nargs='+')
30 | 
31 | args = parser.parse_args()
32 | 
33 | assert(os.path.exists(args.checkpointDir))
34 | 
35 | config = tf.ConfigProto()
36 | config.gpu_options.allow_growth = True
37 | with tf.Session(config=config) as sess:
38 |     model = GAN(sess,
39 |                 image_size=args.imgSize,
40 |                 batch_size=args.batchSize,
41 |                 text_vector_dim=args.text_vector_dim,
42 |                 checkpoint_dir=args.checkpointDir,
43 |                 lam1=args.lam1,
44 |                 lam2=args.lam2,
45 |                 )
46 |     model.test(args)
47 | 
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | python test.py ./datasets/celeba/test/* --checkpointDir checkpoints_face --maskType right --batchSize 64 --lam1 100 --lam2 1 --lr 0.001 --nIter 1000 --outDir results_face --text_vector_dim 18 --text_path datasets/celeba/imAttrs.pkl
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Original Version: Taehoon Kim (http://carpedm20.github.io)
2 | # + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/main.py
3 | # + License: MIT
4 | # [2016-08-05] Modifications for Inpainting: Brandon Amos (http://bamos.github.io)
5 | # + License: MIT
6 | # [2017-07] Modifications for sText2Image: Shangzhe Wu
7 | # + License: MIT
8 | 
9 | import os
10 | import scipy.misc
11 | import numpy as np
12 | import tensorflow as tf
13 | 
14 | from model import GAN
15 | 
16 | flags = tf.app.flags
17 | flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
18 | flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for adam [0.0002]")
19 | flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
20 | flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
21 | flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
22 | flags.DEFINE_integer("image_size", 64, "The size of image to use [64]")
23 | flags.DEFINE_integer("text_vector_dim", 100, "The dimension of input text vector [100]")
24 | flags.DEFINE_string("dataset", "datasets/celeba/train", "Dataset directory.")
25 | flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
26 | flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
27 | flags.DEFINE_string("log_dir", "logs", "Directory name to save the logs [logs]")
28 | flags.DEFINE_string("text_path", "text_embeddings.pkl", "Path of the text embeddings [text_embeddings.pkl]")
29 | #flags.DEFINE_float("lam1", 0.1, "Hyperparameter for contextual loss [0.1]")
30 | #flags.DEFINE_float("lam2", 0.1, "Hyperparameter for perceptual loss [0.1]")
31 | flags.DEFINE_float("lam3", 0.1, "Hyperparameter for wrong examples [0.1]")
32 | FLAGS = flags.FLAGS
33 | 
34 | if not os.path.exists(FLAGS.checkpoint_dir):
35 |     os.makedirs(FLAGS.checkpoint_dir)
36 | if not os.path.exists(FLAGS.sample_dir):
37 |     os.makedirs(FLAGS.sample_dir)
38 | 
39 | config = tf.ConfigProto()
40 | config.gpu_options.allow_growth = True
41 | with tf.Session(config=config) as sess:
42 |     model = GAN(sess,
43 |                 image_size=FLAGS.image_size,
44 |                 batch_size=FLAGS.batch_size,
45 |                 text_vector_dim=FLAGS.text_vector_dim,
46 |                 checkpoint_dir=FLAGS.checkpoint_dir,
47 |                 sample_dir=FLAGS.sample_dir,
48 |                 log_dir=FLAGS.log_dir,
49 |                 is_crop=False,
50 |                 lam3=FLAGS.lam3,
51 |                 )
52 | 
53 |     model.train(FLAGS)
54 | 
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python train.py --dataset ./datasets/celeba/train --batch_size 64 --checkpoint_dir checkpoints_face --sample_dir samples_face --log_dir logs_face --epoch 50 --learning_rate 0.0002 --lam3 0.1 --text_path datasets/celeba/imAttrs.pkl --text_vector_dim 18
2 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Original Version: Taehoon Kim (http://carpedm20.github.io)
2 | # + Source: 
https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/utils.py 3 | # + License: MIT 4 | # [2017-07] Modifications for sText2Image: Shangzhe Wu 5 | # + License: MIT 6 | 7 | """ 8 | Some codes from https://github.com/Newmu/dcgan_code 9 | """ 10 | from __future__ import division 11 | import math 12 | import json 13 | import random 14 | import pprint 15 | import scipy.misc 16 | import numpy as np 17 | from time import gmtime, strftime 18 | 19 | import pdb 20 | 21 | pp = pprint.PrettyPrinter() 22 | 23 | get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1]) 24 | 25 | def get_image(image_path, image_size, is_crop=True): 26 | return transform(imread(image_path), image_size, is_crop) 27 | 28 | def save_images(images, size, image_path): 29 | return imsave(inverse_transform(images), size, image_path) 30 | 31 | def imread(path): 32 | return scipy.misc.imread(path).astype(np.float) 33 | 34 | def merge_images(images, size): 35 | return inverse_transform(images) 36 | 37 | def merge(images, size): 38 | h, w = images.shape[1], images.shape[2] 39 | img = np.zeros((int(h * size[0]), int(w * size[1]), 3)) 40 | for idx, image in enumerate(images): 41 | i = idx % size[1] 42 | j = idx // size[1] 43 | img[j*h:j*h+h, i*w:i*w+w, :] = image 44 | 45 | return img 46 | 47 | def imsave(images, size, path): 48 | return scipy.misc.imsave(path, merge(images, size)) 49 | 50 | def center_crop(x, crop_h, crop_w=None, resize_w=64): 51 | if crop_w is None: 52 | crop_w = crop_h 53 | h, w = x.shape[:2] 54 | j = int(round((h - crop_h)/2.)) 55 | i = int(round((w - crop_w)/2.)) 56 | return scipy.misc.imresize(x[j:j+crop_h, i:i+crop_w], 57 | [resize_w, resize_w]) 58 | 59 | def transform(image, npx=64, is_crop=True): 60 | # npx : # of pixels width/height of image 61 | if is_crop: 62 | cropped_image = center_crop(image, npx) 63 | else: 64 | cropped_image = image 65 | return np.array(cropped_image)/127.5 - 1. 66 | 67 | def inverse_transform(images): 68 | return (images+1.)/2. 
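# Usage sketch for the two transforms above (illustrative only; the file path
# is hypothetical): transform() maps pixels into [-1, 1] to match the tanh
# generator output, and inverse_transform() maps back into [0, 1] for saving:
#   img = get_image('datasets/celeba/train/000001.jpg', 64, is_crop=False)  # in [-1, 1]
#   vis = inverse_transform(img)                                            # in [0, 1]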
69 | 
70 | 
71 | def to_json(output_path, *layers):
72 |     with open(output_path, "w") as layer_f:
73 |         lines = ""
74 |         for w, b, bn in layers:
75 |             layer_idx = w.name.split('/')[0].split('h')[1]
76 | 
77 |             B = b.eval()
78 | 
79 |             if "lin/" in w.name:
80 |                 W = w.eval()
81 |                 depth = W.shape[1]
82 |             else:
83 |                 W = np.rollaxis(w.eval(), 2, 0)
84 |                 depth = W.shape[0]
85 | 
86 |             biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]}
87 |             if bn != None:
88 |                 gamma = bn.gamma.eval()
89 |                 beta = bn.beta.eval()
90 | 
91 |                 gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]}
92 |                 beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]}
93 |             else:
94 |                 gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []}
95 |                 beta = {"sy": 1, "sx": 1, "depth": 0, "w": []}
96 | 
97 |             if "lin/" in w.name:
98 |                 fs = []
99 |                 for w in W.T:
100 |                     fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]})
101 | 
102 |                 lines += """
103 |                     var layer_%s = {
104 |                         "layer_type": "fc",
105 |                         "sy": 1, "sx": 1,
106 |                         "out_sx": 1, "out_sy": 1,
107 |                         "stride": 1, "pad": 0,
108 |                         "out_depth": %s, "in_depth": %s,
109 |                         "biases": %s,
110 |                         "gamma": %s,
111 |                         "beta": %s,
112 |                         "filters": %s
113 |                     };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs)
114 |             else:
115 |                 fs = []
116 |                 for w_ in W:
117 |                     fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]})
118 | 
119 |                 lines += """
120 |                     var layer_%s = {
121 |                         "layer_type": "deconv",
122 |                         "sy": 5, "sx": 5,
123 |                         "out_sx": %s, "out_sy": %s,
124 |                         "stride": 2, "pad": 1,
125 |                         "out_depth": %s, "in_depth": %s,
126 |                         "biases": %s,
127 |                         "gamma": %s,
128 |                         "beta": %s,
129 |                         "filters": %s
130 |                     };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2),
131 |                              W.shape[0], W.shape[3], biases, gamma, beta, fs)
132 |         layer_f.write(" ".join(lines.replace("'","").split()))
133 | 
134 | def make_gif(images, fname, duration=2, true_image=False):
135 |     import moviepy.editor as mpy
136 | 
137 |     def make_frame(t):
138 |         try:
139 |             x = images[int(len(images)/duration*t)]
140 |         except IndexError:  # t == duration indexes one past the end; reuse the last frame
141 |             x = images[-1]
142 | 
143 |         if true_image:
144 |             return x.astype(np.uint8)
145 |         else:
146 |             return ((x+1)/2*255).astype(np.uint8)
147 | 
148 |     clip = mpy.VideoClip(make_frame, duration=duration)
149 |     clip.write_gif(fname, fps = len(images) / duration)
150 | 
151 | def visualize(sess, dcgan, config, option):
152 |     if option == 0:
153 |         z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
154 |         samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
155 |         save_images(samples, [8, 8], './samples/test_%s.png' % strftime("%Y-%m-%d %H:%M:%S", gmtime()))
156 |     elif option == 1:
157 |         values = np.arange(0, 1, 1./config.batch_size)
158 |         for idx in range(100):  # range, not xrange: the README requires Python 3
159 |             print(" [*] %d" % idx)
160 |             z_sample = np.zeros([config.batch_size, dcgan.z_dim])
161 |             for kdx, z in enumerate(z_sample):
162 |                 z[idx] = values[kdx]
163 | 
164 |             samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
165 |             save_images(samples, [8, 8], './samples/test_arange_%s.png' % (idx))
166 |     elif option == 2:
167 |         values = np.arange(0, 1, 1./config.batch_size)
168 |         for idx in [random.randint(0, 99) for _ in range(100)]:
169 |             print(" [*] %d" % idx)
170 |             z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
171 |             z_sample = np.tile(z, (config.batch_size, 1))
172 |             #z_sample = np.zeros([config.batch_size, dcgan.z_dim])
173 |             for kdx, z in enumerate(z_sample):
174 |                 z[idx] = values[kdx]
175 | 
176 |             samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
177 |             make_gif(samples, './samples/test_gif_%s.gif' % (idx))
178 |     elif option == 3:
179 |         values = np.arange(0, 1, 1./config.batch_size)
180 |         for idx in range(100):
181 |             print(" [*] %d" % idx)
182 |             z_sample = np.zeros([config.batch_size, dcgan.z_dim])
183 |             for kdx, z in enumerate(z_sample):
184 |                 z[idx] = values[kdx]
185 | 
186 |             samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
187 |             make_gif(samples, './samples/test_gif_%s.gif' % (idx))
188 |     elif option == 4:
189 |         image_set = []
190 |         values = np.arange(0, 1, 1./config.batch_size)
191 | 
192 |         for idx in range(100):
193 |             print(" [*] %d" % idx)
194 |             z_sample = np.zeros([config.batch_size, dcgan.z_dim])
195 |             for kdx, z in enumerate(z_sample): z[idx] = values[kdx]
196 | 
197 |             image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
198 |             make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))
199 | 
200 |         new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \
201 |                          for idx in list(range(64)) + list(range(63, -1, -1))]  # range objects cannot be added in Python 3
202 |         make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8)
203 | 
204 | 
205 | ########## Elliott ##########
206 | def get_text_batch(image_path, text_data):
207 |     try:
208 |         #pdb.set_trace()
209 |         idx = int(image_path[0:6])-1  # CelebA filenames like '000001.jpg' are 1-indexed
210 |     except ValueError:
211 |         print("image_path format is unexpected.")
212 |     else:
213 |         return text_data[idx]
214 | 
215 | def rgb2gray(rgb):
216 |     return np.dot(rgb[...,:3], [0.2989, 0.5870, 0.1140])
--------------------------------------------------------------------------------
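For reference, the latent-code update in the completion loop of `model.py` (lines 463-466 above) is Nesterov-momentum gradient descent on `z`, projected back onto `[-1, 1]` after every step. Below is a minimal NumPy sketch of the same update, where `grad_fn` is a hypothetical stand-in for evaluating `self.grad_complete_loss` in a session:

```python
import numpy as np

def complete_z(zhats, grad_fn, lr=0.01, momentum=0.9, n_iter=1000):
    """Projected Nesterov-momentum descent on the latent code z."""
    v = np.zeros_like(zhats)
    for _ in range(n_iter):
        g = grad_fn(zhats)                 # gradient of the completion loss w.r.t. z
        v_prev = v.copy()
        v = momentum * v - lr * g          # accumulate velocity
        zhats = zhats - momentum * v_prev + (1 + momentum) * v  # Nesterov correction
        zhats = np.clip(zhats, -1, 1)      # project back onto [-1, 1], as in model.py
    return zhats
```

The default values mirror `test.py`'s `--lr 0.01`, `--momentum 0.9`, and `--nIter 1000`.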