├── .gitignore
├── LICENSE
├── README.md
├── datasets
│   └── download_dataset.sh
├── download_pretrained_model.sh
├── imgs
│   ├── examples
│   │   ├── combine_A_and_B_2.py
│   │   ├── handdraw
│   │   │   ├── 01.png
│   │   │   ├── 02.png
│   │   │   ├── 03.png
│   │   │   ├── 04.png
│   │   │   ├── 05.png
│   │   │   ├── 06.png
│   │   │   ├── 07.png
│   │   │   ├── 08.png
│   │   │   ├── 09.png
│   │   │   ├── 10.png
│   │   │   ├── 11.png
│   │   │   ├── 12.png
│   │   │   ├── 13.png
│   │   │   ├── 14.png
│   │   │   ├── 15.png
│   │   │   ├── 16.png
│   │   │   └── 17.png
│   │   ├── match
│   │   │   ├── eyeglasses
│   │   │   │   ├── 01.png
│   │   │   │   ├── 02.png
│   │   │   │   ├── 03.png
│   │   │   │   ├── 04.png
│   │   │   │   ├── 05.png
│   │   │   │   ├── 06.png
│   │   │   │   ├── 07.png
│   │   │   │   └── 08.png
│   │   │   ├── lipstick
│   │   │   │   ├── 01.png
│   │   │   │   ├── 02.png
│   │   │   │   ├── 03.png
│   │   │   │   ├── 04.png
│   │   │   │   ├── 05.png
│   │   │   │   ├── 06.png
│   │   │   │   ├── 07.png
│   │   │   │   └── 08.png
│   │   │   └── mustache
│   │   │       ├── 01.png
│   │   │       ├── 02.png
│   │   │       ├── 03.png
│   │   │       ├── 04.png
│   │   │       ├── 05.png
│   │   │       ├── 06.png
│   │   │       ├── 07.png
│   │   │       └── 08.png
│   │   └── mismatch
│   │       ├── 01.png
│   │       ├── 02.png
│   │       ├── 03.png
│   │       ├── 04.png
│   │       ├── 05.png
│   │       ├── 06.png
│   │       ├── 07.png
│   │       ├── 08.png
│   │       ├── 09.png
│   │       ├── 10.png
│   │       ├── 11.png
│   │       ├── 12.png
│   │       ├── 13.png
│   │       ├── 14.png
│   │       ├── 15.png
│   │       └── 16.png
│   ├── handdraw
│   │   ├── bird_handdraw.png
│   │   ├── face_handdraw.png
│   │   └── shoe_handdraw.png
│   └── problem.png
├── model.py
├── ops.py
├── test.py
├── test.sh
├── train.py
├── train.sh
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Shangzhe Wu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attribute-Guided Sketch-to-Image Generation
2 |
3 |
4 | ## Introduction
5 | **Note: The implementation in this repo is largely outdated. You may refer to my [presentation slides](https://goo.gl/S2JjAn) for your own implementation.**
6 |
7 | This is based on one of our prior works, [Sketch-to-Image Generation](https://arxiv.org/abs/1711.08972). Freehand sketches can be highly abstract (examples shown below), and learning representations of sketches is not trivial. In contrast to other cross-domain learning approaches, like [pix2pix](https://phillipi.github.io/pix2pix/) and [CycleGAN](https://junyanz.github.io/CycleGAN/), where a mapping from representations in one domain to those in another is learned using translation networks, in [Sketch-to-Image Generation](https://arxiv.org/abs/1711.08972) we propose to learn a joint representation of sketch and image.
8 |
9 | In this project we intend to add text constraints to sketch-to-image generation, where the text provides the content and the sketch controls the shape. **So far, I have only tried adding attribute guidance instead of using text embeddings as additional conditions, and this repo demonstrates results on _Attribute-Guided Sketch-to-Image Generation_.**
10 |
11 | face | bird | shoe
12 | :--------------------------:|:--------------------------:|:--------------------------:
13 | ![](imgs/handdraw/face_handdraw.png) | ![](imgs/handdraw/bird_handdraw.png) | ![](imgs/handdraw/shoe_handdraw.png)
14 |
15 | * A few freehand sketches were collected from volunteers.
16 |
17 | #### Contributors:
18 | - Major Contributor: [Shangzhe Wu](https://elliottwu.github.io/) (HKUST)
19 | - Supervisor: Yu-wing Tai (Tencent), [Chi-Keung Tang](http://www.cs.ust.hk/~cktang/) (HKUST)
20 | - Mentor in MLJejuCamp2017: Hyungjoo Cho
21 |
22 | ### MLJejuCamp2017
23 | This project was developed in [Machine Learning Camp Jeju 2017](http://jeju.dlcamp.org/2017/) within one month. More interesting projects can be found in the [final presentations](https://github.com/TensorFlowKR/dlcampjeju/blob/master/2017/github/04_FinalPresentation.md) and the [program GitHub](https://github.com/MLJejuCamp2017). The final presentation video can be watched [here](https://www.youtube.com/watch?v=X6ieGv82PYU) (partial recording). Camp 2018 has been launched; more details can be found [here](http://jeju.dlcamp.org/2018/).
24 |
25 | ## Get Started
26 | ### Prerequisites
27 | - Python 3.5
28 | - [TensorFlow 0.12.1](https://github.com/tensorflow/tensorflow/tree/r0.12)
29 | - [SciPy](https://www.scipy.org/install.html)
30 |
31 | ### Setup
32 | - Clone this repo:
33 | ```bash
34 | git clone https://github.com/elliottwu/sText2Image.git
35 | cd sText2Image
36 | ```
37 |
38 | - Download preprocessed CelebA data (~3GB):
39 | ```bash
40 | sh ./datasets/download_dataset.sh
41 | ```
42 |
43 | ### Train
44 | ```bash
45 | sh train.sh
46 | ```
47 | - To monitor training with TensorBoard, run the following in your terminal and open `localhost:8888` in your browser:
48 | ```bash
49 | tensorboard --logdir=logs_face --port=8888
50 | ```
51 |
52 | ### Test
53 | ```bash
54 | sh test.sh
55 | ```
56 |
57 | ### Pretrained Model
58 | - Download pretrained model:
59 | ```bash
60 | sh download_pretrained_model.sh
61 | ```
62 |
63 | - Test pretrained model on CelebA dataset:
64 | ```bash
65 | python test.py ./datasets/celeba/test/* --checkpointDir checkpoints_face_pretrained --maskType right --batchSize 64 --lam1 100 --lam2 1 --lam3 0.1 --lr 0.001 --nIter 1000 --outDir results_face_pretrained --text_vector_dim 18 --text_path datasets/celeba/imAttrs.pkl
66 | ```
67 |
68 | ## Experiments
69 | We test our framework with 3 kinds of data: face ([CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)), bird ([CUB](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)), and flower ([Oxford-102](http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html)). So far, we have only experimented with face images, using attribute vectors as the text information. Here are some preliminary results:
70 |
71 | ### 1. Face
72 | We used the CelebA dataset, which also provides 40 attribute annotations per image. Like text descriptions, the attributes control specific details of the generated images. We chose 18 attributes for training; a minimal sketch of the assumed ±1 encoding follows.
73 |
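The exact attribute list and the layout of `imAttrs.pkl` are not documented in this repo; below is a minimal sketch of the ±1 encoding that `model.py` appears to assume (it derives attribute frequencies via `(1 + mean) / 2` and samples "wrong" attributes from {-1, +1}). The attribute names and the pickle layout here are illustrative assumptions, not the repo's actual values:

```python
# Hypothetical subset of the 18 CelebA attributes chosen for training.
ATTRS = ['Male', 'Eyeglasses', 'Mustache', 'Wearing_Lipstick']

def encode(present_attrs):
    # +1 if the attribute holds for the image, -1 otherwise.
    return [1 if a in present_attrs else -1 for a in ATTRS]

# model.py looks vectors up by image filename via get_text_batch, so
# imAttrs.pkl presumably maps {filename: attribute vector}, e.g.:
#   text_data = pickle.load(open('datasets/celeba/imAttrs.pkl', 'rb'))
#   vec = text_data['000001.png']
print(encode({'Male', 'Mustache'}))  # -> [1, -1, 1, -1]
```
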
74 | #### a). Attributes match sketch:
75 | The following images were generated given sketches and the **corresponding** attributes.
76 |
77 | ##### Mustache
78 | attributes |sketch / generated / gt |attributes | sketch / generated / gt
79 | :-----------:|:----------------------:|:-----------:|:-----------------------:
80 | Male, 5_o_Clock_Shadow, Mouth_Open, Pointy_Nose | ![](imgs/examples/match/mustache/01.png) | Male, 5_o_Clock_Shadow, Big_Nose, Mustache | ![](imgs/examples/match/mustache/02.png)
81 | Male, Big_Lips, Big_Nose, Chubby, Goatee, High_Cheekbones, Smiling | ![](imgs/examples/match/mustache/03.png) | Male, Mustache | ![](imgs/examples/match/mustache/04.png)
82 | Male, Goatee, Mouth_Open, Smiling | ![](imgs/examples/match/mustache/05.png) | Male, Big_Nose, Goatee, Smiling | ![](imgs/examples/match/mustache/06.png)
83 | Male, 5_o_Clock_Shadow, Big_Lips, Big_Nose, Goatee, High_Cheekbones, Mouth_Open, Rosy_Cheeks, Smiling | ![](imgs/examples/match/mustache/07.png) | Male, 5_o_Clock_Shadow, Big_Nose, Narrow_Eyes | ![](imgs/examples/match/mustache/08.png)
84 |
85 | ##### Eyeglasses
86 | attributes |sketch / generated / gt |attributes | sketch / generated / gt
87 | :-----------:|:----------------------:|:-----------:|:-----------------------:
88 | Male, Big_Nose, Eyeglasses, Goatee | ![](imgs/examples/match/eyeglasses/01.png) | Female, Eyeglasses | ![](imgs/examples/match/eyeglasses/02.png)
89 | Female, Eyeglasses, High_Cheekbones, Mouth_Open, Smiling | ![](imgs/examples/match/eyeglasses/03.png) | Male, 5_o_Clock_Shadow, Big_Nose, Eyeglasses, Mouth_Open, Smiling | ![](imgs/examples/match/eyeglasses/04.png)
90 | Male, Big_Nose, Double_Chin, Eyeglasses, Mouth_Open, Pointy_Nose, Smiling | ![](imgs/examples/match/eyeglasses/05.png) | Male, Eyeglasses, High_Cheekbones, Mouth_Open, Smiling | ![](imgs/examples/match/eyeglasses/06.png)
91 | Male, 5_o_Clock_Shadow, Eyeglasses, Mouth_Open, Smiling | ![](imgs/examples/match/eyeglasses/07.png) | Male, Big_Lips, Big_Nose, Eyeglasses, Goatee, Mouth_Open | ![](imgs/examples/match/eyeglasses/08.png)
92 |
93 |
94 | ##### Lipstick
95 | attributes |sketch / generated / gt |attributes | sketch / generated / gt
96 | :-----------:|:----------------------:|:-----------:|:-----------------------:
97 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Pointy_Nose, Smiling, Wearing_Lipstick | ![](imgs/examples/match/lipstick/01.png) | Female, Heavy_Makeup, Mouth_Open, Wearing_Lipstick | ![](imgs/examples/match/lipstick/02.png)
98 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Pointy_Nose, Smiling, Wearing_Lipstick | ![](imgs/examples/match/lipstick/03.png) | Female, Heavy_Makeup, Pointy_Nose, Smiling, Wearing_Lipstick | ![](imgs/examples/match/lipstick/04.png)
99 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Pointy_Nose, Smiling, Wearing_Lipstick | ![](imgs/examples/match/lipstick/05.png) | Female, Big_Lips, Big_Nose, Heavy_Makeup, High_Cheekbones, Mouth_Open, Rosy_Cheeks, Smiling, Wearing_Lipstick | ![](imgs/examples/match/lipstick/06.png)
100 | Female, Heavy_Makeup, Pointy_Nose, Wearing_Lipstick | ![](imgs/examples/match/lipstick/07.png) | Female, Heavy_Makeup, Mouth_Open, Smiling, Wearing_Lipstick | ![](imgs/examples/match/lipstick/08.png)
101 |
102 | #### b). Attributes mismatch sketch:
103 | The following images were generated given sketches and **random** attributes. The controlling effect of the attributes is still being improved.
104 |
105 | attributes |sketch / generated |attributes | sketch / generated |attributes | sketch / generated
106 | :-----------:|:----------------------:|:-----------:|:-----------------------:|:-----------:|:-----------------------:
107 | Female, Big_Lips, Heavy_Makeup, Wearing_Lipstick | *(image)* | Female, Big_Lips, Heavy_Makeup, Wearing_Lipstick | *(image)* | Male, Big_Nose, No_Eyeglasses | *(image)*
108 | Male, Big_Nose, Chubby, Double_Chin, High_Cheekbones, Smiling | *(image)* | Male, Big_Nose, Chubby, Double_Chin, High_Cheekbones, Mouth_Open, Smiling | *(image)* | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | *(image)*
109 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | *(image)* | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | *(image)* | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | *(image)*
110 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick, No_Eyeglasses | *(image)* | Male | *(image)* | Female, Heavy_Makeup, Pale_Skin, Wearing_Lipstick | *(image)*
111 | Female, Heavy_Makeup, High_Cheekbones, Pointy_Nose, Smiling, Wearing_Lipstick, No_Eyeglasses | *(image)* | Female, Heavy_Makeup, High_Cheekbones, Pointy_Nose, Smiling, Wearing_Lipstick | *(image)* | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Pointy_Nose, Rosy_Cheeks, Smiling, Wearing_Lipstick | *(image)*
112 |
113 | #### c). Freehand sketch:
114 | The following images were generated given freehand sketches and **random** attributes. The controlling effect of the attributes is still being improved.
115 |
116 | attributes |sketch / generated |attributes | sketch / generated |attributes | sketch / generated
117 | :-----------:|:----------------------:|:-----------:|:-----------------------:|:-----------:|:-----------------------:
118 | Female, Big_Lips, Heavy_Makeup, Wearing_Lipstick | *(image)* | Male, Big_Nose | *(image)* | Male, Big_Nose, Chubby, Double_Chin, High_Cheekbones, Mouth_Open, Smiling | *(image)*
119 | Male, Big_Nose, Chubby, Double_Chin, High_Cheekbones, Mouth_Open, Smiling | *(image)* | Male, Big_Nose, Chubby, Double_Chin, High_Cheekbones, Mouth_Open, Smiling | *(image)* | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | *(image)*
120 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | *(image)* | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | *(image)* | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | *(image)*
121 | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | *(image)* | Female, Big_Lips, Heavy_Makeup, High_Cheekbones, Mouth_Open, Narrow_Eyes, Smiling, Wearing_Lipstick | *(image)* | Female, Big_Lips, Heavy_Makeup, High_Cheekbones, Mouth_Open, Narrow_Eyes, Smiling, Wearing_Lipstick | *(image)*
122 | Female, Heavy_Makeup, High_Cheekbones, Pointy_Nose, Smiling, Wearing_Lipstick | *(image)* | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | *(image)* | Female, Heavy_Makeup, High_Cheekbones, Mouth_Open, Smiling, Wearing_Lipstick | *(image)*
123 |
124 | ## Acknowledgement
125 | Codes are based on [DCGAN](https://github.com/carpedm20/DCGAN-tensorflow) and [dcgan-completion](https://github.com/bamos/dcgan-completion.tensorflow).
126 |
127 | ## Citation
128 | Consider citing the following paper if you find this repo helpful:
129 | ```
130 | @InProceedings{Lu_2018_ECCV,
131 | author = {Lu, Yongyi and Wu, Shangzhe and Tai, Yu-Wing and Tang, Chi-Keung},
132 | title = {Image Generation from Sketch Constraint Using Contextual GAN},
133 | booktitle = {The European Conference on Computer Vision (ECCV)},
134 | month = {September},
135 | year = {2018}
136 | }
137 | ```
138 |
--------------------------------------------------------------------------------
/datasets/download_dataset.sh:
--------------------------------------------------------------------------------
1 | # celeba dataset
2 | download_celeba(){
3 | echo "----------------------- downloading celeba dataset -----------------------"
4 | wget -O celeba.tar.gz "https://www.robots.ox.ac.uk/~szwu/storage/18_sketch/celeba.tar.gz"
5 | tar xzvf celeba.tar.gz -C ./datasets/
6 | rm celeba.tar.gz
7 | }
8 |
9 | # all datasets
10 | download_all(){
11 | download_celeba
12 | echo "----------------------- done -----------------------"
13 | }
14 |
15 | download_all
16 |
--------------------------------------------------------------------------------
/download_pretrained_model.sh:
--------------------------------------------------------------------------------
1 | # face model
2 | download_face(){
3 | echo "----------------------- downloading face pretrained model -----------------------"
4 | wget -O face_pretrained.tar.gz "https://www.robots.ox.ac.uk/~szwu/storage/18_sketch/face_pretrained.tar.gz"
5 | tar xzvf face_pretrained.tar.gz
6 | rm face_pretrained.tar.gz
7 | }
8 |
9 | # all models
10 | download_all(){
11 | download_face
12 | echo "----------------------- done -----------------------"
13 | }
14 |
15 | download_all
16 |
--------------------------------------------------------------------------------
/imgs/examples/combine_A_and_B_2.py:
--------------------------------------------------------------------------------
1 | from pdb import set_trace as st
2 | import os
3 | import numpy as np
4 | import cv2
5 |
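# Stitches each '<name>_A.png' / '<name>_B.png' pair found in `fold` into a
# numbered side-by-side image ('01.png', '02.png', ...) separated by a
# 5-pixel white gap (presumably how the example figures in this folder
# were prepared).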
6 | fold = 'handdraw'
7 |
8 | l = ['2_60', '3_54', '5_29', '5_41', '5_62', '8_53', '9_28', '9_39', '9_40', '11_08', '11_20', '11_30', '13_05', '13_17', '15_44', '17_07', '17_51']
9 |
10 | num_imgs = len(l)
11 | for n in range(num_imgs):
12 | print('['+str(n)+'/'+str(num_imgs)+']')
13 | name = '{:02d}'.format(n+1)
14 | path_A = os.path.join(fold, l[n]+'_A.png')
15 | path_B = os.path.join(fold, l[n]+'_B.png')
16 | #path_C = os.path.join(fold, l[n]+'_C.png')
17 | if os.path.isfile(path_A) and os.path.isfile(path_B):
18 | path_ABC = os.path.join(fold, name+'.png')
19 | im_A = cv2.imread(path_A)
20 | im_B = cv2.imread(path_B)
21 | #im_C = cv2.imread(path_C)
22 | h, w, c = im_A.shape
23 | gap = np.ones((h, 5, c), dtype=im_A.dtype) * 255  # 5-pixel white gap, same dtype as the images
24 | im_ABC = np.concatenate([im_A, gap, im_B], 1)
25 | cv2.imwrite(path_ABC, im_ABC)
26 |
27 |
--------------------------------------------------------------------------------
/imgs/examples/handdraw/01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/01.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/02.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/03.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/04.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/05.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/06.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/07.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/08.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/09.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/09.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/10.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/11.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/12.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/13.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/14.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/15.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/16.png
--------------------------------------------------------------------------------
/imgs/examples/handdraw/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/handdraw/17.png
--------------------------------------------------------------------------------
/imgs/examples/match/eyeglasses/01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/01.png
--------------------------------------------------------------------------------
/imgs/examples/match/eyeglasses/02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/02.png
--------------------------------------------------------------------------------
/imgs/examples/match/eyeglasses/03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/03.png
--------------------------------------------------------------------------------
/imgs/examples/match/eyeglasses/04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/04.png
--------------------------------------------------------------------------------
/imgs/examples/match/eyeglasses/05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/05.png
--------------------------------------------------------------------------------
/imgs/examples/match/eyeglasses/06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/06.png
--------------------------------------------------------------------------------
/imgs/examples/match/eyeglasses/07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/07.png
--------------------------------------------------------------------------------
/imgs/examples/match/eyeglasses/08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/eyeglasses/08.png
--------------------------------------------------------------------------------
/imgs/examples/match/lipstick/01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/01.png
--------------------------------------------------------------------------------
/imgs/examples/match/lipstick/02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/02.png
--------------------------------------------------------------------------------
/imgs/examples/match/lipstick/03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/03.png
--------------------------------------------------------------------------------
/imgs/examples/match/lipstick/04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/04.png
--------------------------------------------------------------------------------
/imgs/examples/match/lipstick/05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/05.png
--------------------------------------------------------------------------------
/imgs/examples/match/lipstick/06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/06.png
--------------------------------------------------------------------------------
/imgs/examples/match/lipstick/07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/07.png
--------------------------------------------------------------------------------
/imgs/examples/match/lipstick/08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/lipstick/08.png
--------------------------------------------------------------------------------
/imgs/examples/match/mustache/01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/01.png
--------------------------------------------------------------------------------
/imgs/examples/match/mustache/02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/02.png
--------------------------------------------------------------------------------
/imgs/examples/match/mustache/03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/03.png
--------------------------------------------------------------------------------
/imgs/examples/match/mustache/04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/04.png
--------------------------------------------------------------------------------
/imgs/examples/match/mustache/05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/05.png
--------------------------------------------------------------------------------
/imgs/examples/match/mustache/06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/06.png
--------------------------------------------------------------------------------
/imgs/examples/match/mustache/07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/07.png
--------------------------------------------------------------------------------
/imgs/examples/match/mustache/08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/match/mustache/08.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/01.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/02.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/03.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/04.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/05.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/06.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/07.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/08.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/09.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/09.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/10.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/11.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/12.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/13.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/14.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/15.png
--------------------------------------------------------------------------------
/imgs/examples/mismatch/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/examples/mismatch/16.png
--------------------------------------------------------------------------------
/imgs/handdraw/bird_handdraw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/handdraw/bird_handdraw.png
--------------------------------------------------------------------------------
/imgs/handdraw/face_handdraw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/handdraw/face_handdraw.png
--------------------------------------------------------------------------------
/imgs/handdraw/shoe_handdraw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/handdraw/shoe_handdraw.png
--------------------------------------------------------------------------------
/imgs/problem.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elliottwu/sText2Image/f3a32e0a93b84dda8447d5c959056405d0d9894a/imgs/problem.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Original Version: Taehoon Kim (http://carpedm20.github.io)
2 | # + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/model.py
3 | # + License: MIT
4 | # [2016-08-05] Modifications for Completion: Brandon Amos (http://bamos.github.io)
5 | # + License: MIT
6 | # [2017-07] Modifications for sText2Image: Shangzhe Wu
7 | # + License: MIT
8 |
9 | from __future__ import division
10 | import os
11 | import time
12 | from glob import glob
13 | import tensorflow as tf
14 | import pickle
15 | from six.moves import xrange
16 | from scipy.stats import entropy
17 |
18 | from ops import *
19 | from utils import *
20 |
21 | #import pdb
22 |
23 | class GAN(object):
24 | def __init__(self, sess, image_size=64, is_crop=False,
25 | batch_size=64, text_vector_dim=100,
26 | z_dim=100, t_dim=256, gf_dim=64, df_dim=64, c_dim=3,
27 | checkpoint_dir=None, sample_dir=None, log_dir=None,
28 | lam1=0.1, lam2=0.1, lam3=0.1):
29 | """
30 |
31 | Args:
32 | sess: TensorFlow session
33 | batch_size: The size of batch. Should be specified before training.
34 | z_dim: (optional) Dimension of dim for Z. [100]
35 | t_dim: (optional) Dimension of text features. [256]
36 | gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
37 | df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
38 | c_dim: (optional) Dimension of image color. [3]
39 | lam1: (optional) Hyperparameter for contextual loss. [0.1]
40 | lam2: (optional) Hyperparameter for perceptual loss. [0.1]
41 | lam3: (optional) Hyperparameter for wrong examples [0.1]
42 | """
43 | self.sess = sess
44 | self.is_crop = is_crop
45 | self.batch_size = batch_size
46 | self.text_vector_dim = text_vector_dim
47 | self.image_size = image_size
48 | self.image_shape = [image_size, image_size * 2, 3]
49 | # self.image_shape = [image_size, image_size, 3]
50 |
51 | self.sample_freq = int(100*64/batch_size)
52 | self.save_freq = int(500*64/batch_size)
53 |
54 | self.z_dim = z_dim
55 | self.t_dim = t_dim
56 |
57 | self.gf_dim = gf_dim
58 | self.df_dim = df_dim
59 |
60 | self.lam1 = lam1
61 | self.lam2 = lam2
62 | self.lam3 = lam3
63 |
64 | self.c_dim = c_dim
65 |
66 | # batch normalization : deals with poor initialization helps gradient flow
67 | self.d_bn1 = batch_norm(name='d_bn1')
68 | self.d_bn2 = batch_norm(name='d_bn2')
69 | self.d_bn3 = batch_norm(name='d_bn3')
70 | self.d_bn4 = batch_norm(name='d_bn4')
71 |
72 | self.g_bn0 = batch_norm(name='g_bn0')
73 | self.g_bn1 = batch_norm(name='g_bn1')
74 | self.g_bn2 = batch_norm(name='g_bn2')
75 | self.g_bn3 = batch_norm(name='g_bn3')
76 |
77 | self.checkpoint_dir = checkpoint_dir
78 | self.sample_dir = sample_dir
79 | self.log_dir = log_dir
80 |
81 | self.build_model()
82 |
83 | self.model_name = "GAN"
84 |
85 |
86 | def build_model(self):
87 | self.images = tf.placeholder(
88 | tf.float32, [self.batch_size] + self.image_shape, name='real_images')
89 | self.sample_images= tf.placeholder(
90 | tf.float32, [self.batch_size] + self.image_shape, name='sample_images')
91 |
92 | self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z')
93 | self.z_sum = tf.summary.histogram("z", self.z)
94 |
95 | self.t = tf.placeholder(tf.float32, [self.batch_size, self.text_vector_dim], name='t')
96 | self.t_sum = tf.summary.histogram("t", self.t)
97 |
98 | self.t_wr = tf.placeholder(tf.float32, [self.batch_size, self.text_vector_dim], name='t_wr')
99 | self.t_wr_sum = tf.summary.histogram("t_wr", self.t_wr)
100 |
101 | #self.images_wr = tf.placeholder(
102 | # tf.float32, [self.batch_size] + self.image_shape, name='wrong_images')
103 |
104 | self.G = self.generator(self.z, self.t)
105 | self.D_rl, self.D_logits_rl = self.discriminator(self.images, self.t)
106 | self.D_fk, self.D_logits_fk = self.discriminator(self.G, self.t, reuse=True)
107 | self.D_wr, self.D_logits_wr = self.discriminator(self.images, self.t_wr, reuse=True)
108 |
109 | self.sampler = self.sampler(self.z, self.t)  # note: rebinds the sampler method to its output tensor, so the graph can only be built once
110 |
111 | self.G_sum = tf.image_summary("G", self.G)
112 | self.d_rl_sum = tf.summary.histogram("d", self.D_rl)
113 | self.d_fk_sum = tf.summary.histogram("d_", self.D_fk)
114 | self.d_wr_sum = tf.summary.histogram("d_wr", self.D_wr)
115 |
116 | # cross entropy loss
117 | self.g_loss = tf.reduce_mean(
118 | tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_fk,
119 | tf.ones_like(self.D_fk)))
120 | self.d_loss_real = tf.reduce_mean(
121 | tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_rl,
122 | tf.ones_like(self.D_rl)))
123 | self.d_loss_fake = tf.reduce_mean(
124 | tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_fk,
125 | tf.zeros_like(self.D_fk)))
126 | self.d_loss_wrong = tf.reduce_mean(
127 | tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_wr,
128 | tf.zeros_like(self.D_wr)))
129 |
130 | '''
131 | # least square loss
132 | self.d_loss_real = 0.5 * tf.reduce_mean((self.D_logits_rl - tf.ones_like(self.D_logits_rl))**2)
133 | self.d_loss_fake = 0.5 * tf.reduce_mean((self.D_logits_fk - tf.zeros_like(self.D_logits_fk))**2)
134 | self.d_loss_wrong = 0.5 * tf.reduce_mean((self.D_logits_wr - tf.zeros_like(self.D_logits_wr))**2)
135 | self.g_loss = 0.5 * tf.reduce_mean((self.D_logits_fk - tf.ones_like(self.D_logits_fk))**2)
136 | '''
137 |
138 | self.d_loss_real_sum = tf.scalar_summary("d_loss_real", self.d_loss_real)
139 | self.d_loss_fake_sum = tf.scalar_summary("d_loss_fake", self.d_loss_fake)
140 | self.d_loss_wrong_sum = tf.scalar_summary("d_loss_wrong", self.d_loss_wrong)
141 |
142 | self.d_loss = self.d_loss_real + self.d_loss_fake + self.lam3 * self.d_loss_wrong
143 |
144 | self.g_loss_sum = tf.scalar_summary("g_loss", self.g_loss)
145 | self.d_loss_sum = tf.scalar_summary("d_loss", self.d_loss)
146 |
147 | t_vars = tf.trainable_variables()
148 |
149 | self.d_vars = [var for var in t_vars if 'd_' in var.name]
150 | self.g_vars = [var for var in t_vars if 'g_' in var.name]
151 |
152 | self.saver = tf.train.Saver(max_to_keep=50)
153 |
154 | # mask to generate
155 | self.mask = tf.placeholder(tf.float32, [None] + self.image_shape, name='mask')
156 |
157 | # l1
158 | #self.contextual_loss = tf.reduce_sum(
159 | # tf.contrib.layers.flatten(
160 | # tf.abs(tf.mul(self.mask, self.G) - tf.mul(self.mask, self.images))), 1)
161 |
162 | # kl divergence
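# The sketch occupies the left image_size columns of each 64x128 pair.
# Both the generated and the real sketch halves are converted to grayscale,
# shifted from [-1, 1] into a positive range, and compared per image with a
# KL divergence, i.e. the pixels are treated as unnormalized distributions.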
163 | self.contextual_loss = kl_divergence(
164 | tf.divide(tf.add(tf.contrib.layers.flatten(tf.image.rgb_to_grayscale(
165 | tf.slice(self.G, [0,0,0,0], [self.batch_size,self.image_size,self.image_size,self.c_dim]))), 1), 2),
166 | tf.divide(tf.add(tf.contrib.layers.flatten(tf.image.rgb_to_grayscale(
167 | tf.slice(self.images, [0,0,0,0], [self.batch_size,self.image_size,self.image_size,self.c_dim]))), 1), 2))
168 |
169 | self.perceptual_loss = self.g_loss
170 | self.complete_loss = self.lam1*self.contextual_loss + self.lam2*self.perceptual_loss
171 | self.grad_complete_loss = tf.gradients(self.complete_loss, self.z)
172 |
173 |
174 | def train(self, config):
175 | image_data = glob(os.path.join(config.dataset, "*.png"))
176 | #np.random.shuffle(data)
177 | print (os.path.join(config.dataset, "*.png"))
178 | assert(len(image_data) > 0)
179 |
180 | text_data = pickle.load(open(config.text_path, 'rb'))
181 |
182 | ######### for face attributes #########
183 | attr_sum = np.sum(text_data, 0)
184 | attr_percent = (1 + attr_sum/len(text_data)) / 2
185 | print ("selected attribute percentages:\n", attr_percent)
186 | ######### for face attributes #########
187 |
188 | d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
189 | .minimize(self.d_loss, var_list=self.d_vars)
190 | g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
191 | .minimize(self.g_loss, var_list=self.g_vars)
192 | tf.initialize_all_variables().run()
193 |
194 | self.g_sum = tf.merge_summary(
195 | [self.z_sum, self.t_sum, self.d_fk_sum, self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
196 | self.d_sum = tf.merge_summary(
197 | [self.z_sum, self.t_sum, self.d_rl_sum, self.d_wr_sum, self.d_loss_real_sum, self.d_loss_wrong_sum, self.d_loss_sum])
198 | self.writer = tf.train.SummaryWriter(self.log_dir, self.sess.graph)
199 |
200 |
201 | #++++++++ training sample ++++++++#
202 |
203 | sample_z = np.random.uniform(-1, 1, size=(self.batch_size , self.z_dim))
204 | sample_files = image_data[0:self.batch_size]
205 | sample = [get_image(sample_file, self.image_size, is_crop=self.is_crop) for sample_file in sample_files]
206 | sample_images = np.array(sample).astype(np.float32)
207 | sample_t_ = [get_text_batch(os.path.basename(sample_file), text_data) for sample_file in sample_files]
208 | sample_t = np.array(sample_t_).astype(np.float32)
209 | nRows = int(np.ceil(self.batch_size/8))
210 | nCols = min(8, self.batch_size) #8
211 |
212 | ######### for face attributes #########
213 | with open(os.path.join(self.sample_dir, 'sampled_texts.txt'), 'wb') as f:
214 | np.savetxt(f, sample_t, fmt='%i', delimiter='\t')
215 | ######### for face attributes #########
216 |
217 | #-------- training sample --------#
218 |
219 | counter = 1
220 | start_time = time.time()
221 |
222 | if self.load(self.checkpoint_dir):
223 | print("""
224 |
225 | ============
226 | An existing model was found in the checkpoint directory.
227 | If you just cloned this repository, it's Brandon Amos'
228 | trained model for faces that's used in the post.
229 | If you want to train a new model from scratch,
230 | delete the checkpoint directory or specify a different
231 | --checkpoint_dir argument.
232 | ============
233 |
234 | """)
235 | else:
236 | print("""
237 |
238 | ============
239 | An existing model was not found in the checkpoint directory.
240 | Initializing a new one.
241 | ============
242 |
243 | """)
244 |
245 | for epoch in xrange(config.epoch):
246 | image_data = glob(os.path.join(config.dataset, "*.png"))
247 | batch_idxs = min(len(image_data), config.train_size) // self.batch_size
248 |
249 | for idx in xrange(0, batch_idxs):
250 |
251 | #++++++++ data loading ++++++++#
252 |
253 | data_start_time = time.time()
254 | batch_files = image_data[idx*config.batch_size:(idx+1)*config.batch_size]
255 | batch = [get_image(batch_file, self.image_size, is_crop=self.is_crop)
256 | for batch_file in batch_files]
257 | batch_images = np.array(batch).astype(np.float32)
258 |
259 | batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \
260 | .astype(np.float32)
261 |
262 | batch_t_ = [get_text_batch(os.path.basename(batch_file), text_data)
263 | for batch_file in batch_files]
264 | batch_t = np.array(batch_t_).astype(np.float32)
265 |
266 | ######### for face attributes #########
267 | # randomly generated wrong face attributes
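# Each attribute is drawn independently from {-1, +1} with its empirical
# frequency attr_percent, so the "wrong" vectors follow the dataset's
# marginal attribute statistics rather than uniform noise.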
268 | batch_t_wr_ = [np.random.choice(np.arange(2), size=config.batch_size,
269 | p=[1-attr_percent[i], attr_percent[i]]) * 2 - 1
270 | for i in xrange(self.text_vector_dim)]
271 | batch_t_wr = np.transpose(batch_t_wr_).astype(np.float32)
272 | ######### for face attributes #########
273 |
274 | '''
275 | # randomly select wrong images
276 | idx_wr = np.random.randint(batch_idxs)
277 | while (idx_wr == idx):
278 | idx_wr = np.random.randint(batch_idxs)
279 |
280 | batch_files_wr = image_data[idx_wr*config.batch_size:(idx_wr+1)*config.batch_size]
281 | batch_wr = [get_image(batch_file_wr, self.image_size, is_crop=self.is_crop)
282 | for batch_file_wr in batch_files_wr]
283 | batch_images_wr = np.array(batch_wr).astype(np.float32)
284 | '''
285 | data_time = time.time() - data_start_time
286 |
287 | #-------- data loading --------#
288 |
289 |
290 | #++++++++ training ++++++++#
291 |
292 | # Update D network
293 | _, summary_str = self.sess.run([d_optim, self.d_sum],
294 | feed_dict={ self.images: batch_images, self.z: batch_z, self.t: batch_t, self.t_wr: batch_t_wr })
295 | self.writer.add_summary(summary_str, counter)
296 |
297 | # Update G network
298 | _, summary_str = self.sess.run([g_optim, self.g_sum],
299 | feed_dict={ self.z: batch_z, self.t: batch_t })
300 | self.writer.add_summary(summary_str, counter)
301 |
302 | # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
303 | _, summary_str = self.sess.run([g_optim, self.g_sum],
304 | feed_dict={ self.z: batch_z, self.t: batch_t })
305 | self.writer.add_summary(summary_str, counter)
306 |
307 | errD_fake = self.d_loss_fake.eval({self.z: batch_z, self.t: batch_t})
308 | errD_real = self.d_loss_real.eval({self.images: batch_images, self.t: batch_t})
309 | errG = self.g_loss.eval({self.z: batch_z, self.t: batch_t})
310 | #-------- training --------#
311 |
312 | counter += 1
313 | print("Epoch: [%2d] [%4d/%4d] data_time: %4.4f, time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
314 | % (epoch, idx, batch_idxs, data_time,
315 | time.time() - start_time, errD_fake+errD_real, errG))
316 |
317 | if np.mod(counter, self.sample_freq) == 1:
318 | samples = self.sess.run(
319 | [self.sampler], feed_dict={self.z: sample_z, self.t: sample_t})
320 | save_images(samples[0], [nRows, nCols],
321 | os.path.join(self.sample_dir, 'train_{:02d}_{:04d}.png'.format(epoch, idx)))
322 | #print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))
323 |
324 | if np.mod(counter, self.save_freq) == 2:
325 | self.save(config.checkpoint_dir, counter)
326 |
327 |
328 | def test(self, config):
329 | tf.initialize_all_variables().run()
330 |
331 | isLoaded = self.load(self.checkpoint_dir)
332 | assert(isLoaded)
333 |
334 | # image_data = glob(os.path.join(config.dataset, "*.png"))
335 | nImgs = len(config.imgs)
336 |
337 | batch_idxs = int(np.ceil(nImgs/self.batch_size))
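# Each image is a sketch/photo pair stored side by side (64x128, cf.
# imgs/examples/combine_A_and_B_2.py); zeroing one half of the mask hides
# that half from the contextual loss so completion can fill it in.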
338 | if config.maskType == 'right':
339 | mask = np.ones(self.image_shape)
340 | mask[:,self.image_size:,:] = 0.0
341 | elif config.maskType == 'left':
342 | mask = np.ones(self.image_shape)
343 | mask[:,:self.image_size,:] = 0.0
344 | else:
345 | assert(False)
346 |
347 | text_data = pickle.load(open(config.text_path, 'rb'))
348 |
349 | num_batch = int(np.ceil(nImgs/self.batch_size))
350 | for idx in xrange(0, num_batch):
351 | print('batch no. ' + str(idx+1) + ':\n')
352 |
353 | l = idx*self.batch_size
354 | u = min((idx+1)*self.batch_size, nImgs)
355 | batchSz = u-l
356 | batch_files = config.imgs[l:u]
357 | batch = [get_image(batch_file, self.image_size, is_crop=self.is_crop)
358 | for batch_file in batch_files]
359 | batch_images = np.array(batch).astype(np.float32)
360 | batch_mask = np.resize(mask, [self.batch_size] + self.image_shape)
361 |
362 | os.makedirs(os.path.join(config.outDir, 'hats_imgs_{:04d}'.format(idx)))
363 | os.makedirs(os.path.join(config.outDir, 'completed_{:04d}'.format(idx)))
364 |
365 | #++++++++ for face attributes ++++++++#
366 |
367 | # attributes to be loaded
368 | if (config.attributes[0] is None):
369 | batch_t_ = [get_text_batch(os.path.basename(batch_file), text_data)
370 | for batch_file in batch_files]
371 | batch_t = np.array(batch_t_).astype(np.float32)
372 |
373 | # user-defined attributes
374 | else:
375 | attr_v = config.attributes
376 | assert len(attr_v) == self.text_vector_dim, "attribute vector must have the given length"
377 | print('using attributes: ', attr_v)
378 | batch_t = np.array([attr_v,]*batchSz).astype(np.float32)
379 |
380 | with open(os.path.join(config.outDir, 'completed_{:04d}/texts.txt'.format(idx)), 'wb') as f:
381 | np.savetxt(f, batch_t, fmt='%i', delimiter='\t')
382 |
383 | #-------- for face attributes --------#
384 |
385 | # last batch
386 | if batchSz < self.batch_size:
387 | print(batchSz)
388 | padSz = ((0, int(self.batch_size-batchSz)), (0,0), (0,0), (0,0))
389 | batch_images = np.pad(batch_images, padSz, 'wrap')
390 | batch_images = batch_images.astype(np.float32)
391 | batch_t = np.pad(batch_t, ((0, int(self.batch_size-batchSz)), (0,0)), 'wrap')
392 |
393 |
394 | nRows = int(np.ceil(batchSz/8))
395 | nCols = min(8, batchSz) #8
396 |
397 |
398 | #++++++++ z initialization ++++++++#
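# Draw z at random 30 times; for each image keep the draw whose generated
# sketch half has the lowest KL divergence from the input sketch (computed
# on grayscale pixel distributions), and use that z to seed the completion.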
399 |
400 | zhats_init = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim)).astype(np.float32)
401 | zhats_ = zhats_init.copy()
402 | kl_div = np.full(len(zhats_), np.inf)
403 | in_flat = [rgb2gray(img[:,:self.image_size,:]).flatten() for img in batch_images]
404 | in_flat = np.array(in_flat) + 1
405 | kld_avg = 0
406 |
407 | kld_f = open(os.path.join(config.outDir, 'hats_imgs_{:04d}/kld_init.txt'.format(idx)), 'w')
408 | kld_f.write('average kl divergence of initializations:\n')
409 | for i in xrange(30):
410 | G_imgs = self.sess.run([self.G], feed_dict={ self.z: zhats_, self.t: batch_t })
411 | save_images(G_imgs[0][:batchSz,:,:,:], [nRows, nCols],
412 | os.path.join(config.outDir, 'hats_imgs_{:04d}/init_{:02d}.png'.format(idx, i)))
413 |
414 | out_flat = [rgb2gray(img[:,:self.image_size,:]).flatten() for img in G_imgs[0]]
415 | out_flat = np.array(out_flat) + 1
416 |
417 | # choose lowest kl divergence
418 | for j in xrange(self.batch_size):
419 | kl_d = entropy(in_flat[j], out_flat[j])
420 | if (kl_d < kl_div[j]):
421 | zhats_init[j] = zhats_[j]
422 | kl_div[j] = kl_d
423 |
424 | kld_avg = kl_div.mean()
425 | print('average KL divergence:', kld_avg)
426 | kld_f.write('{:02d}: {:04.4f}\n'.format(i, kld_avg))
427 |
428 | zhats_ = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim)).astype(np.float32)
429 |
430 | print('choosing min KL divergence:', kld_avg)
431 | kld_f.write('choosing min KL divergence: {:04.4f}\n'.format(kld_avg))
432 | kld_f.close()
433 |
434 | G_imgs = self.sess.run([self.G], feed_dict={ self.z: zhats_init, self.t: batch_t })
435 | save_images(G_imgs[0][:batchSz,:,:,:], [nRows, nCols],
436 | os.path.join(config.outDir, 'hats_imgs_{:04d}/chosen_init.png'.format(idx)))
437 |
438 | #-------- z initialization --------#
439 |
440 |
441 | #++++++++ completion ++++++++#
442 |
443 | zhats = zhats_init.copy().astype(np.float32)
444 | v = 0
445 |
446 | save_images(batch_images[:batchSz,:,:,:], [nRows,nCols],
447 | os.path.join(config.outDir, 'hats_imgs_{:04d}/gt.png'.format(idx)))
448 | masked_images = np.multiply(batch_images, batch_mask)
449 | save_images(masked_images[:batchSz,:,:,:], [nRows,nCols],
450 | os.path.join(config.outDir, 'hats_imgs_{:04d}/masked.png'.format(idx)))
451 |
452 | for i in xrange(config.nIter):
453 | fd = {
454 | self.z: zhats,
455 | self.mask: batch_mask,
456 | self.images: batch_images,
457 | self.t: batch_t,
458 | }
459 | run = [self.complete_loss, self.grad_complete_loss, self.G]
460 | loss, g, G_imgs = self.sess.run(run, feed_dict=fd)
461 |
462 | # update zhats
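# Nesterov-style momentum update on z (as in bamos/dcgan-completion):
# v is the velocity, and -momentum * v_prev + (1 + momentum) * v applies
# the lookahead correction before z is clipped back to [-1, 1].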
463 | v_prev = np.copy(v)
464 | v = config.momentum*v - config.lr*g[0]
465 | zhats += -config.momentum * v_prev + (1+config.momentum)*v
466 | zhats = np.clip(zhats, -1, 1)
467 |
468 | # save images
469 | if i % 20 == 0:
470 | print(i, np.mean(loss[0:batchSz]))
471 | imgName = os.path.join(config.outDir,
472 | 'hats_imgs_{:04d}/{:04d}.png'.format(idx, i))
473 | save_images(G_imgs[:batchSz,:,:,:], [nRows,nCols], imgName)
474 |
475 | inv_masked_hat_images = np.multiply(G_imgs, 1.0-batch_mask)
476 | completed = masked_images + inv_masked_hat_images
477 | imgName = os.path.join(config.outDir,
478 | 'completed_{:04d}/{:04d}.png'.format(idx, i))
479 | save_images(completed[:batchSz,:,:,:], [nRows,nCols], imgName)
480 |
481 | #-------- completion --------#
482 |
483 |
484 | #++++++++ interpolation visualization ++++++++#
485 |
486 | zhats_final = np.copy(zhats)
487 | diff = zhats_final - zhats_init
488 | step = 5
489 | for i in xrange(step):
490 | z_ = zhats_init + diff / (step-1) * i
491 | G_imgs = self.sess.run([self.G], feed_dict={ self.z: z_, self.t: batch_t })
492 | imgName = os.path.join(config.outDir, 'hats_imgs_{:04d}/{:01d}_interp.png'.format(idx, i))
493 | save_images(G_imgs[0][:batchSz,:,:,:], [nRows,nCols], imgName)
494 |
495 | #-------- interpolation visualization --------#
496 |
497 |
498 | def discriminator(self, image, t, reuse=False):
499 | if reuse:
500 | tf.get_variable_scope().reuse_variables()
501 |
502 | h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
503 |
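# Condition the discriminator on the attribute vector: broadcast t to a
# 1x1 spatial map, tile it over the 32x64 feature grid, and concatenate
# it channel-wise with the convolutional features.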
504 | t_ = tf.expand_dims(t, 1)
505 | t_ = tf.expand_dims(t_, 2)
506 | t_tiled = tf.tile(t_, [1,32,64,1], name='tiled_t')
507 | h0_concat = tf.concat(3, [h0, t_tiled], name='h0_concat')
508 |
509 | h1 = lrelu(self.d_bn1(conv2d(h0_concat, self.df_dim*2, name='d_h1_conv')))
510 | h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv')))
511 | h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, name='d_h3_conv')))
512 |
513 | #h4 = linear(tf.reshape(h3, [-1, 8192*2]), 1, 'd_h3_lin')
514 | # conv to 512x1x1
515 | h4 = conv2d(h3, self.df_dim*8, 4, 8, 1, 1, name='d_h4_conv')
516 |
517 | return tf.nn.sigmoid(h4), h4
518 |
519 |
520 | def generator(self, z, t):
521 |
522 | self.z_, self.h0_lin_w, self.h0_lin_b = linear(z, self.gf_dim*4*8, 'g_h0_lin', with_w=True)
523 | z_ = tf.reshape(self.z_, [-1, 4, 8, self.gf_dim])
524 |
525 | t_ = tf.expand_dims(tf.expand_dims(t, 1), 2)
526 | t_tiled = tf.tile(t_, [1,4,8,1])
527 |
528 | h0_concat = tf.concat(3, [z_, t_tiled])
529 |
530 | self.h0, self.h0_w, self.h0_b = conv2d_transpose(h0_concat,
531 | [self.batch_size, 4, 8, self.gf_dim*8], 1, 1, 1, 1, name='g_h0', with_w=True)
532 | h0 = tf.nn.relu(self.g_bn0(self.h0))
533 |
534 | self.h1, self.h1_w, self.h1_b = conv2d_transpose(h0,
535 | [self.batch_size, 8, 16, self.gf_dim*4], name='g_h1', with_w=True)
536 | h1 = tf.nn.relu(self.g_bn1(self.h1))
537 |
538 | h2, self.h2_w, self.h2_b = conv2d_transpose(h1,
539 | [self.batch_size, 16, 32, self.gf_dim*2], name='g_h2', with_w=True)
540 | h2 = tf.nn.relu(self.g_bn2(h2))
541 |
542 | h3, self.h3_w, self.h3_b = conv2d_transpose(h2,
543 | [self.batch_size, 32, 64, self.gf_dim*1], name='g_h3', with_w=True)
544 | h3 = tf.nn.relu(self.g_bn3(h3))
545 |
546 | h4, self.h4_w, self.h4_b = conv2d_transpose(h3,
547 | [self.batch_size, 64, 128, 3], name='g_h4', with_w=True)
548 |
549 | return tf.nn.tanh(h4)
550 |
551 |
552 | def sampler(self, z, t, y=None):
553 | tf.get_variable_scope().reuse_variables()
554 |
555 | z_ = tf.reshape(linear(z, self.gf_dim*4*8, 'g_h0_lin'), [-1, 4, 8, self.gf_dim])
556 |
557 | t_ = tf.expand_dims(tf.expand_dims(t, 1), 2)
558 | t_tiled = tf.tile(t_, [1,4,8,1])
559 |
560 | h0_concat = tf.concat(3, [z_, t_tiled])
561 |
562 | h0 = conv2d_transpose(h0_concat,
563 | [self.batch_size, 4, 8, self.gf_dim*8], 1, 1, 1, 1, name='g_h0')
564 | h0 = tf.nn.relu(self.g_bn0(h0, train=False))
565 |
566 | h1 = conv2d_transpose(h0, [self.batch_size, 8, 16, self.gf_dim*4], name='g_h1')
567 | h1 = tf.nn.relu(self.g_bn1(h1, train=False))
568 |
569 | h2 = conv2d_transpose(h1, [self.batch_size, 16, 32, self.gf_dim*2], name='g_h2')
570 | h2 = tf.nn.relu(self.g_bn2(h2, train=False))
571 |
572 | h3 = conv2d_transpose(h2, [self.batch_size, 32, 64, self.gf_dim*1], name='g_h3')
573 | h3 = tf.nn.relu(self.g_bn3(h3, train=False))
574 |
575 | h4 = conv2d_transpose(h3, [self.batch_size, 64, 128, 3], name='g_h4')
576 |
577 | return tf.nn.tanh(h4)
578 |
579 |
580 | def save(self, checkpoint_dir, step):
581 | if not os.path.exists(checkpoint_dir):
582 | os.makedirs(checkpoint_dir)
583 |
584 | self.saver.save(self.sess,
585 | os.path.join(checkpoint_dir, self.model_name),
586 | global_step=step)
587 |
588 |
589 | def load(self, checkpoint_dir):
590 | print(" [*] Reading checkpoints...")
591 |
592 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
593 | if ckpt and ckpt.model_checkpoint_path:
594 | self.saver.restore(self.sess, ckpt.model_checkpoint_path)
595 | return True
596 | else:
597 | return False
598 |
--------------------------------------------------------------------------------
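A note on the z-space optimization above: the completion loop in model.py updates `zhats` with a Nesterov-momentum gradient step and then clips back into the `[-1, 1]` prior support. Below is a minimal NumPy sketch of that update, isolated from the session plumbing; the shapes, `lr`, and `momentum` here are illustrative stand-ins for the config values, not repo code.

```python
import numpy as np

def nesterov_step(zhats, v, grad, lr=0.01, momentum=0.9):
    """One Nesterov-momentum step on the latent codes, projected to [-1, 1]."""
    v_prev = v.copy()
    v = momentum * v - lr * grad                             # velocity update
    zhats = zhats - momentum * v_prev + (1 + momentum) * v   # look-ahead correction
    return np.clip(zhats, -1, 1), v                          # clip to the prior's support

# Usage with placeholder shapes: a batch of 64 latent codes of dimension 100.
zhats = np.random.uniform(-1, 1, size=(64, 100)).astype(np.float32)
v = np.zeros_like(zhats)
grad = np.random.randn(*zhats.shape).astype(np.float32)  # stands in for grad_complete_loss
zhats, v = nesterov_step(zhats, v, grad)
```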
/ops.py:
--------------------------------------------------------------------------------
1 | # Original Version: Taehoon Kim (http://carpedm20.github.io)
2 | # + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/ops.py
3 | # + License: MIT
4 | # [2017-07] Modifications for sText2Image: Shangzhe Wu
5 | # + License: MIT
6 |
7 | import math
8 | import numpy as np
9 | import tensorflow as tf
10 |
11 | from tensorflow.python.framework import ops
12 |
13 | from utils import *
14 |
15 | class batch_norm(object):
16 | """Code modification of http://stackoverflow.com/a/33950177"""
17 | def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):
18 | with tf.variable_scope(name):
19 | self.epsilon = epsilon
20 | self.momentum = momentum
21 |
22 | self.ema = tf.train.ExponentialMovingAverage(decay=self.momentum)
23 | self.name = name
24 |
25 | def __call__(self, x, train=True):
26 | shape = x.get_shape().as_list()
27 |
28 | if train:
29 | with tf.variable_scope(self.name) as scope:
30 | self.beta = tf.get_variable("beta", [shape[-1]],
31 | initializer=tf.constant_initializer(0.))
32 | self.gamma = tf.get_variable("gamma", [shape[-1]],
33 | initializer=tf.random_normal_initializer(1., 0.02))
34 |
35 | batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
36 | with tf.variable_scope(tf.get_variable_scope(), reuse=False):
37 | ema_apply_op = self.ema.apply([batch_mean, batch_var])
38 | self.ema_mean, self.ema_var = self.ema.average(batch_mean), self.ema.average(batch_var)
39 |
40 | with tf.control_dependencies([ema_apply_op]):
41 | mean, var = tf.identity(batch_mean), tf.identity(batch_var)
42 | else:
43 |                 mean, var = self.ema_mean, self.ema_var  # EMA stats captured by an earlier train=True call
44 |
45 | normed = tf.nn.batch_norm_with_global_normalization(
46 | x, mean, var, self.beta, self.gamma, self.epsilon, scale_after_normalization=True)
47 |
48 | return normed
49 |
50 | def binary_cross_entropy(preds, targets, name=None):
51 | """Computes binary cross entropy given `preds`.
52 |
53 |     For brevity, let `x = preds`, `z = targets`. The logistic loss is
54 |
55 |         loss(x, z) = - sum_i (z[i] * log(x[i]) + (1 - z[i]) * log(1 - x[i]))
56 |
57 | Args:
58 | preds: A `Tensor` of type `float32` or `float64`.
59 | targets: A `Tensor` of the same type and shape as `preds`.
60 | """
61 | eps = 1e-12
62 | with ops.op_scope([preds, targets], name, "bce_loss") as name:
63 | preds = ops.convert_to_tensor(preds, name="preds")
64 | targets = ops.convert_to_tensor(targets, name="targets")
65 | return tf.reduce_mean(-(targets * tf.log(preds + eps) +
66 | (1. - targets) * tf.log(1. - preds + eps)))
67 |
68 | def conv_cond_concat(x, y):
69 | """Concatenate conditioning vector on feature map axis."""
70 | x_shapes = x.get_shape()
71 | y_shapes = y.get_shape()
72 | return tf.concat(3, [x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])])
73 |
74 | def conv2d(input_, output_dim,
75 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
76 | name="conv2d"):
77 | with tf.variable_scope(name):
78 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
79 | initializer=tf.truncated_normal_initializer(stddev=stddev))
80 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
81 |
82 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
83 | # conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
84 | conv = tf.nn.bias_add(conv, biases)
85 |
86 | return conv
87 |
88 | def conv2d_transpose(input_, output_shape,
89 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
90 | name="conv2d_transpose", with_w=False):
91 | with tf.variable_scope(name):
92 | # filter : [height, width, output_channels, in_channels]
93 |         w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
94 | initializer=tf.random_normal_initializer(stddev=stddev))
95 |
96 | try:
97 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
98 | strides=[1, d_h, d_w, 1])
99 |
100 |             # Support for versions of TensorFlow before 0.7.0
101 | except AttributeError:
102 | deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
103 | strides=[1, d_h, d_w, 1])
104 |
105 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
106 | # deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
107 | deconv = tf.nn.bias_add(deconv, biases)
108 |
109 | if with_w:
110 | return deconv, w, biases
111 | else:
112 | return deconv
113 |
114 | def lrelu(x, leak=0.2, name="lrelu"):
115 | with tf.variable_scope(name):
116 | f1 = 0.5 * (1 + leak)
117 | f2 = 0.5 * (1 - leak)
118 | return f1 * x + f2 * abs(x)
119 |
120 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
121 | shape = input_.get_shape().as_list()
122 |
123 | with tf.variable_scope(scope or "Linear"):
124 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
125 | tf.random_normal_initializer(stddev=stddev))
126 | bias = tf.get_variable("bias", [output_size],
127 | initializer=tf.constant_initializer(bias_start))
128 | if with_w:
129 | return tf.matmul(input_, matrix) + bias, matrix, bias
130 | else:
131 | return tf.matmul(input_, matrix) + bias
132 |
133 |
134 | ######## Elliott ########
135 | def kl_divergence(p, q):
136 | tf.assert_rank(p,2)
137 | tf.assert_rank(q,2)
138 |
139 | p_shape = tf.shape(p)
140 | q_shape = tf.shape(q)
141 | tf.assert_equal(p_shape, q_shape)
142 |
143 | # normalize sum to 1
144 | p_ = tf.divide(p, tf.tile(tf.expand_dims(tf.reduce_sum(p,axis=1), 1), [1,p_shape[1]]))
145 |     q_ = tf.divide(q, tf.tile(tf.expand_dims(tf.reduce_sum(q,axis=1), 1), [1,q_shape[1]]))
146 |
147 | return tf.reduce_sum(tf.multiply(p_, tf.log(tf.divide(p_, q_))), axis=1)
148 |
--------------------------------------------------------------------------------
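As a numerical sanity check on the two loss helpers in ops.py, here is a hedged NumPy sketch reproducing `binary_cross_entropy` and the row-normalized `kl_divergence` on hand-picked values; the helper names and inputs below are illustrative, not repo code.

```python
import numpy as np

def bce(preds, targets, eps=1e-12):
    # same reduce-mean logistic loss as binary_cross_entropy above
    return np.mean(-(targets * np.log(preds + eps) +
                     (1. - targets) * np.log(1. - preds + eps)))

def kl_rows(p, q):
    # same normalization as kl_divergence: each row rescaled to sum to 1
    p_ = p / p.sum(axis=1, keepdims=True)
    q_ = q / q.sum(axis=1, keepdims=True)
    return (p_ * np.log(p_ / q_)).sum(axis=1)

preds = np.array([0.9, 0.1, 0.8])
targets = np.array([1.0, 0.0, 1.0])
print(bce(preds, targets))   # ~0.145: mean of -log(0.9), -log(0.9), -log(0.8)

p = np.array([[1., 1.], [3., 1.]])
q = np.array([[1., 1.], [1., 1.]])
print(kl_rows(p, q))         # [0., ~0.131]: identical rows give zero divergence
```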
/test.py:
--------------------------------------------------------------------------------
1 | # [2016-08-05] Modifications for Completion: Brandon Amos (http://bamos.github.io)
2 | # + License: MIT
3 | # [2017-07] Modifications for sText2Image: Shangzhe Wu
4 | # + License: MIT
5 |
6 | import argparse
7 | import os
8 | import tensorflow as tf
9 |
10 | from model import GAN
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--lr', type=float, default=0.01)
14 | parser.add_argument('--momentum', type=float, default=0.9)
15 | parser.add_argument('--nIter', type=int, default=1000)
16 | parser.add_argument('--imgSize', type=int, default=64)
17 | parser.add_argument('--batchSize', type=int, default=64)
18 | parser.add_argument('--text_vector_dim', type=int, default=100)
19 | parser.add_argument('--lam1', type=float, default=0.1) # Hyperparameter for contextual loss [0.1]
20 | parser.add_argument('--lam2', type=float, default=0.1) # Hyperparameter for perceptual loss [0.1]
21 | #parser.add_argument('--lam3', type=float, default=0.1) # Hyperparameter for wrong example [0.1]
22 | parser.add_argument('--checkpointDir', type=str, default='checkpoint')
23 | parser.add_argument('--outDir', type=str, default='results')
24 | parser.add_argument('--text_path', type=str, default='text_embeddings.pkl')
25 | parser.add_argument('--maskType', type=str,
26 | choices=['random', 'center', 'left', 'right', 'full'],
27 | default='right')
28 | parser.add_argument('--attributes', nargs='+', type=int, default=[None])
29 | parser.add_argument('imgs', type=str, nargs='+')
30 |
31 | args = parser.parse_args()
32 |
33 | assert(os.path.exists(args.checkpointDir))
34 |
35 | config = tf.ConfigProto()
36 | config.gpu_options.allow_growth = True
37 | with tf.Session(config=config) as sess:
38 | model = GAN(sess,
39 | image_size=args.imgSize,
40 | batch_size=args.batchSize,
41 | text_vector_dim=args.text_vector_dim,
42 | checkpoint_dir=args.checkpointDir,
43 | lam1=args.lam1,
44 | lam2=args.lam2,
45 | )
46 | model.test(args)
47 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | python test.py ./datasets/celeba/test/* --checkpointDir checkpoints_face --maskType right --batchSize 64 --lam1 100 --lam2 1 --lr 0.001 --nIter 1000 --outDir results_face --text_vector_dim 18 --text_path datasets/celeba/imAttrs.pkl
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Original Version: Taehoon Kim (http://carpedm20.github.io)
2 | # + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/main.py
3 | # + License: MIT
4 | # [2016-08-05] Modifications for Inpainting: Brandon Amos (http://bamos.github.io)
5 | # + License: MIT
6 | # [2017-07] Modifications for sText2Image: Shangzhe Wu
7 | # + License: MIT
8 |
9 | import os
10 | import scipy.misc
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 | from model import GAN
15 |
16 | flags = tf.app.flags
17 | flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
18 | flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for adam [0.0002]")
19 | flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
20 | flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
21 | flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
22 | flags.DEFINE_integer("image_size", 64, "The size of image to use [64]")
23 | flags.DEFINE_integer("text_vector_dim", 100, "The dimension of input text vector [100]")
24 | flags.DEFINE_string("dataset", "datasets/celeba/train", "Dataset directory.")
25 | flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
26 | flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
27 | flags.DEFINE_string("log_dir", "logs", "Directory name to save the logs [logs]")
28 | flags.DEFINE_string("text_path", "text_embeddings.pkl", "Path of the text embeddings [text_embeddings.pkl]")
29 | #flags.DEFINE_float("lam1", 0.1, "Hyperparameter for contextual loss [0.1]")
30 | #flags.DEFINE_float("lam2", 0.1, "Hyperparameter for perceptual loss [0.1]")
31 | flags.DEFINE_float("lam3", 0.1, "Hyperparameter for wrong examples [0.1]")
32 | FLAGS = flags.FLAGS
33 |
34 | if not os.path.exists(FLAGS.checkpoint_dir):
35 | os.makedirs(FLAGS.checkpoint_dir)
36 | if not os.path.exists(FLAGS.sample_dir):
37 | os.makedirs(FLAGS.sample_dir)
38 |
39 | config = tf.ConfigProto()
40 | config.gpu_options.allow_growth = True
41 | with tf.Session(config=config) as sess:
42 | model = GAN(sess,
43 | image_size=FLAGS.image_size,
44 | batch_size=FLAGS.batch_size,
45 | text_vector_dim=FLAGS.text_vector_dim,
46 | checkpoint_dir=FLAGS.checkpoint_dir,
47 | sample_dir=FLAGS.sample_dir,
48 | log_dir=FLAGS.log_dir,
49 | is_crop=False,
50 | lam3=FLAGS.lam3,
51 | )
52 |
53 | model.train(FLAGS)
54 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python train.py --dataset ./datasets/celeba/train --batch_size 64 --checkpoint_dir checkpoints_face --sample_dir samples_face --log_dir logs_face --epoch 50 --learning_rate 0.0002 --lam3 0.1 --text_path datasets/celeba/imAttrs.pkl --text_vector_dim 18
2 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Original Version: Taehoon Kim (http://carpedm20.github.io)
2 | # + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/utils.py
3 | # + License: MIT
4 | # [2017-07] Modifications for sText2Image: Shangzhe Wu
5 | # + License: MIT
6 |
7 | """
8 | Some codes from https://github.com/Newmu/dcgan_code
9 | """
10 | from __future__ import division
11 | import math
12 | import json
13 | import random
14 | import pprint
15 | import scipy.misc
16 | import numpy as np
17 | from time import gmtime, strftime
18 |
19 | import pdb
20 |
21 | pp = pprint.PrettyPrinter()
22 |
23 | get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])
24 |
25 | def get_image(image_path, image_size, is_crop=True):
26 | return transform(imread(image_path), image_size, is_crop)
27 |
28 | def save_images(images, size, image_path):
29 | return imsave(inverse_transform(images), size, image_path)
30 |
31 | def imread(path):
32 | return scipy.misc.imread(path).astype(np.float)
33 |
34 | def merge_images(images, size):
35 | return inverse_transform(images)
36 |
37 | def merge(images, size):
38 | h, w = images.shape[1], images.shape[2]
39 | img = np.zeros((int(h * size[0]), int(w * size[1]), 3))
40 | for idx, image in enumerate(images):
41 | i = idx % size[1]
42 | j = idx // size[1]
43 | img[j*h:j*h+h, i*w:i*w+w, :] = image
44 |
45 | return img
46 |
47 | def imsave(images, size, path):
48 | return scipy.misc.imsave(path, merge(images, size))
49 |
50 | def center_crop(x, crop_h, crop_w=None, resize_w=64):
51 | if crop_w is None:
52 | crop_w = crop_h
53 | h, w = x.shape[:2]
54 | j = int(round((h - crop_h)/2.))
55 | i = int(round((w - crop_w)/2.))
56 | return scipy.misc.imresize(x[j:j+crop_h, i:i+crop_w],
57 | [resize_w, resize_w])
58 |
59 | def transform(image, npx=64, is_crop=True):
60 | # npx : # of pixels width/height of image
61 | if is_crop:
62 | cropped_image = center_crop(image, npx)
63 | else:
64 | cropped_image = image
65 | return np.array(cropped_image)/127.5 - 1.
66 |
67 | def inverse_transform(images):
68 | return (images+1.)/2.
69 |
70 |
71 | def to_json(output_path, *layers):
72 | with open(output_path, "w") as layer_f:
73 | lines = ""
74 | for w, b, bn in layers:
75 | layer_idx = w.name.split('/')[0].split('h')[1]
76 |
77 | B = b.eval()
78 |
79 | if "lin/" in w.name:
80 | W = w.eval()
81 | depth = W.shape[1]
82 | else:
83 | W = np.rollaxis(w.eval(), 2, 0)
84 | depth = W.shape[0]
85 |
86 | biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]}
87 |             if bn is not None:
88 | gamma = bn.gamma.eval()
89 | beta = bn.beta.eval()
90 |
91 | gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]}
92 | beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]}
93 | else:
94 | gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []}
95 | beta = {"sy": 1, "sx": 1, "depth": 0, "w": []}
96 |
97 | if "lin/" in w.name:
98 | fs = []
99 | for w in W.T:
100 | fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]})
101 |
102 | lines += """
103 | var layer_%s = {
104 | "layer_type": "fc",
105 | "sy": 1, "sx": 1,
106 | "out_sx": 1, "out_sy": 1,
107 | "stride": 1, "pad": 0,
108 | "out_depth": %s, "in_depth": %s,
109 | "biases": %s,
110 | "gamma": %s,
111 | "beta": %s,
112 | "filters": %s
113 | };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs)
114 | else:
115 | fs = []
116 | for w_ in W:
117 | fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]})
118 |
119 | lines += """
120 | var layer_%s = {
121 | "layer_type": "deconv",
122 | "sy": 5, "sx": 5,
123 | "out_sx": %s, "out_sy": %s,
124 | "stride": 2, "pad": 1,
125 | "out_depth": %s, "in_depth": %s,
126 | "biases": %s,
127 | "gamma": %s,
128 | "beta": %s,
129 | "filters": %s
130 | };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2),
131 | W.shape[0], W.shape[3], biases, gamma, beta, fs)
132 | layer_f.write(" ".join(lines.replace("'","").split()))
133 |
134 | def make_gif(images, fname, duration=2, true_image=False):
135 | import moviepy.editor as mpy
136 |
137 | def make_frame(t):
138 | try:
139 | x = images[int(len(images)/duration*t)]
140 |         except IndexError:
141 | x = images[-1]
142 |
143 | if true_image:
144 | return x.astype(np.uint8)
145 | else:
146 | return ((x+1)/2*255).astype(np.uint8)
147 |
148 | clip = mpy.VideoClip(make_frame, duration=duration)
149 | clip.write_gif(fname, fps = len(images) / duration)
150 |
151 | def visualize(sess, dcgan, config, option):
152 | if option == 0:
153 | z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
154 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
155 | save_images(samples, [8, 8], './samples/test_%s.png' % strftime("%Y-%m-%d %H:%M:%S", gmtime()))
156 | elif option == 1:
157 | values = np.arange(0, 1, 1./config.batch_size)
158 | for idx in xrange(100):
159 | print(" [*] %d" % idx)
160 | z_sample = np.zeros([config.batch_size, dcgan.z_dim])
161 | for kdx, z in enumerate(z_sample):
162 | z[idx] = values[kdx]
163 |
164 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
165 | save_images(samples, [8, 8], './samples/test_arange_%s.png' % (idx))
166 | elif option == 2:
167 | values = np.arange(0, 1, 1./config.batch_size)
168 | for idx in [random.randint(0, 99) for _ in xrange(100)]:
169 | print(" [*] %d" % idx)
170 | z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
171 | z_sample = np.tile(z, (config.batch_size, 1))
172 | #z_sample = np.zeros([config.batch_size, dcgan.z_dim])
173 | for kdx, z in enumerate(z_sample):
174 | z[idx] = values[kdx]
175 |
176 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
177 | make_gif(samples, './samples/test_gif_%s.gif' % (idx))
178 | elif option == 3:
179 | values = np.arange(0, 1, 1./config.batch_size)
180 | for idx in xrange(100):
181 | print(" [*] %d" % idx)
182 | z_sample = np.zeros([config.batch_size, dcgan.z_dim])
183 | for kdx, z in enumerate(z_sample):
184 | z[idx] = values[kdx]
185 |
186 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
187 | make_gif(samples, './samples/test_gif_%s.gif' % (idx))
188 | elif option == 4:
189 | image_set = []
190 | values = np.arange(0, 1, 1./config.batch_size)
191 |
192 | for idx in xrange(100):
193 | print(" [*] %d" % idx)
194 | z_sample = np.zeros([config.batch_size, dcgan.z_dim])
195 | for kdx, z in enumerate(z_sample): z[idx] = values[kdx]
196 |
197 | image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
198 | make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))
199 |
200 | new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \
201 | for idx in range(64) + range(63, -1, -1)]
202 | make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8)
203 |
204 |
205 | ########## Elliott ##########
206 | def get_text_batch(image_path, text_data):
207 | try:
208 | #pdb.set_trace()
209 | idx = int(image_path[0:6])-1
210 |     except ValueError:
211 |         print("image_path format is unexpected: %s" % image_path)
212 | else:
213 | return text_data[idx]
214 |
215 | def rgb2gray(rgb):
216 | return np.dot(rgb[...,:3], [0.2989, 0.5870, 0.1140])
--------------------------------------------------------------------------------
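For reference, the save path in utils.py is just a rescale plus a grid tile: generator outputs live in `[-1, 1]`, `inverse_transform` maps them to `[0, 1]`, and `merge` lays the batch out row-major on a `size[0] x size[1]` grid. A minimal NumPy sketch of that pipeline follows; the batch and grid shapes are assumptions for illustration.

```python
import numpy as np

batch = np.random.uniform(-1, 1, size=(8, 64, 128, 3))  # stand-in generator output
grid_rows, grid_cols = 2, 4

rescaled = (batch + 1.) / 2.                 # inverse_transform: [-1,1] -> [0,1]
h, w = rescaled.shape[1], rescaled.shape[2]
grid = np.zeros((grid_rows * h, grid_cols * w, 3))
for idx, image in enumerate(rescaled):       # same row-major tiling as merge()
    i = idx % grid_cols                      # column
    j = idx // grid_cols                     # row
    grid[j*h:j*h+h, i*w:i*w+w, :] = image

assert grid.shape == (128, 512, 3)           # 2x4 grid of 64x128 images
```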