├── .gitignore
├── LICENSE
├── README.md
├── StackGAN.py
├── assets
├── result.png
└── teaser.png
├── main.py
├── ops.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Junho Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## StackGAN — Simple TensorFlow Implementation [[Paper]](https://arxiv.org/abs/1612.03242)
2 | ### Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks
3 |
4 |
5 | ![teaser](./assets/teaser.png)
6 |
7 |
8 | ## Dataset
9 | ### char-CNN-RNN text embedding
10 | * [birds](https://drive.google.com/open?id=0B3y_msrWZaXLT1BZdVdycDY5TEE)
11 | * [flowers](https://drive.google.com/open?id=0B3y_msrWZaXLaUc0UXpmcnhaVmM)
12 |
13 | ### Image
14 | * [birds](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)
15 | * [flowers](http://www.robots.ox.ac.uk/~vgg/data/flowers/102/)
16 |
17 | ## Usage
18 | ```
19 | ├── dataset
20 |     └── YOUR_DATASET_NAME
21 |         ├── images
22 |         │   ├── domain1 (domain folder)
23 |         │   │   ├── xxx.jpg (domain1 image)
24 |         │   │   ├── yyy.png
25 |         │   │   ├── ...
26 |         │   ├── domain2
27 |         │   │   ├── aaa.jpg (domain2 image)
28 |         │   │   ├── bbb.png
29 |         │   │   ├── ...
30 |         │   ├── domain3
31 |         │   ├── ...
32 |         └── text
33 |             ├── char-CNN-RNN-embeddings.pickle
34 |             └── filenames.pickle
35 | ```
36 |
37 | ### Train
38 | ```
39 | python main.py --dataset birds --phase train
40 | ```
41 |
42 | ### Test
43 | ```
44 | python main.py --dataset birds --phase test
45 | ```
46 |
47 | ## Results
48 |
49 | ![result](./assets/result.png)
50 |
51 |
52 | ## Author
53 | [Junho Kim](http://bit.ly/jhkim_ai)
54 |
--------------------------------------------------------------------------------
/StackGAN.py:
--------------------------------------------------------------------------------
1 | from ops import *
2 | from utils import *
3 | import time
4 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
5 | import numpy as np
6 |
7 |
class StackGAN():
    """Two-stage StackGAN text-to-image model built as a TensorFlow 1.x graph.

    Stage 0 trains ``generator_1`` / ``discriminator_1`` on 64x64 images and
    stage 1 trains ``generator_2`` / ``discriminator_2`` on the full-size
    images, each for ``args.iteration`` steps.
    """

    # Number of samples drawn (and written to disk) in the test phase.
    TEST_BATCH_SIZE = 5

    def __init__(self, sess, args):
        """Copy the run configuration from ``args`` and prepare output folders.

        Args:
            sess: the ``tf.Session`` used for every ``run`` call.
            args: parsed command-line namespace (see ``parse_args`` in main.py).
        """
        self.phase = args.phase
        self.model_name = 'StackGAN'

        self.sess = sess
        self.checkpoint_dir = args.checkpoint_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir
        self.dataset_name = args.dataset
        self.augment_flag = args.augment_flag

        self.iteration = args.iteration
        self.decay_flag = args.decay_flag
        self.decay_iter = args.decay_iter

        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq

        self.init_lr = args.lr

        self.gan_type = args.gan_type

        # Fixed architecture sizes.
        self.condition_dim = 128
        self.df_dim = 96
        self.gf_dim = 128
        self.text_dim = 1024
        self.z_dim = 100

        # Loss weights.
        self.adv_weight = args.adv_weight
        self.kl_weight = args.kl_weight

        # Spectral normalization switch, forwarded to every layer as `sn`.
        self.sn = args.sn

        self.img_height = args.img_height
        self.img_width = args.img_width

        self.img_ch = args.img_ch

        # model_dir is derived from the attributes set above, so this must
        # come after them.
        self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        self.dataset_path = os.path.join('./dataset', self.dataset_name)

        print()

        print("##### Information #####")
        print("# dataset : ", self.dataset_name)
        print("# batch_size : ", self.batch_size)
        print("# max iteration : ", self.iteration)

        print()

        print("##### Generator #####")

        print()

        print("##### Discriminator #####")
        print("# spectral normalization : ", self.sn)

        print()

        print("##### Weight #####")
        print("# adv_weight : ", self.adv_weight)
        print("# kl_weight : ", self.kl_weight)

        print()

    ##################################################################################
    # Generator
    ##################################################################################

    def generator_1(self, text_embedding, noise, is_training=True, reuse=tf.AUTO_REUSE, scope='generator_1'):
        """Stage-I generator: text embedding + noise -> low-resolution image.

        The embedding is projected to ``mu`` / ``logvar``, a condition vector
        is sampled with ``reparametrize`` and concatenated with ``noise``. The
        result is projected to a 4x4 feature map and passed through four
        ``up_block`` layers before the tanh output convolution.

        Returns:
            Tuple ``(image, mu, logvar)``.
        """
        channels = self.gf_dim * 8  # 1024
        with tf.variable_scope(scope, reuse=reuse):
            mu = fully_connected(text_embedding, units=self.condition_dim, use_bias=True, sn=self.sn, scope='mu_fc')
            mu = relu(mu)

            logvar = fully_connected(text_embedding, units=self.condition_dim, use_bias=True, sn=self.sn, scope='logvar_fc')
            logvar = relu(logvar)

            condition = reparametrize(mu, logvar)

            z = tf.concat([noise, condition], axis=-1)
            z = fully_connected(z, units=channels * 4 * 4, use_bias=False, sn=self.sn)
            z = batch_norm(z, is_training)
            z = relu(z)
            z = tf.reshape(z, shape=[-1, 4, 4, channels])

            x = z
            # Each up_block halves the channel count.
            for i in range(4):
                x = up_block(x, channels=channels // 2, is_training=is_training, use_bias=False, sn=self.sn, scope='up_block_' + str(i))
                channels = channels // 2

            x = conv(x, channels=self.img_ch, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=True, sn=self.sn, scope='g_logit')
            x = tanh(x)

            return x, mu, logvar

    def generator_2(self, x_init, text_embedding, is_training=True, reuse=tf.AUTO_REUSE, scope='generator_2'):
        """Stage-II generator: refine a stage-I image using the text embedding.

        ``x_init`` is encoded with two stride-2 convolutions, joined with a
        freshly sampled condition vector tiled to 16x16, passed through two
        residual blocks and upsampled by four ``up_block`` layers.

        Returns:
            Tuple ``(image, mu, logvar)``.
        """
        channels = self.gf_dim
        with tf.variable_scope(scope, reuse=reuse):

            x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=True, sn=self.sn, scope='conv')
            x = relu(x)

            for i in range(2):
                x = conv(x, channels * 2, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv_' + str(i))
                x = batch_norm(x, is_training, scope='batch_norm_' + str(i))
                x = relu(x)

                channels = channels * 2

            mu = fully_connected(text_embedding, units=self.condition_dim, use_bias=True, sn=self.sn, scope='mu_fc')
            mu = relu(mu)

            logvar = fully_connected(text_embedding, units=self.condition_dim, use_bias=True, sn=self.sn, scope='logvar_fc')
            logvar = relu(logvar)

            condition = reparametrize(mu, logvar)
            condition = tf.reshape(condition, shape=[-1, 1, 1, self.condition_dim])
            # The tile size assumes the encoded feature map is 16x16,
            # i.e. a 64x64 x_init.
            condition = tf.tile(condition, multiples=[1, 16, 16, 1])

            x = tf.concat([x, condition], axis=-1)

            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='joint_conv')
            x = batch_norm(x, is_training, scope='joint_batch_norm')
            x = relu(x)

            for i in range(2):
                x = resblock(x, channels, is_training, use_bias=False, sn=self.sn, scope='resblock_' + str(i))

            for i in range(4):
                x = up_block(x, channels=channels // 2, is_training=is_training, use_bias=False, sn=self.sn, scope='up_block_' + str(i))
                channels = channels // 2

            x = conv(x, channels=self.img_ch, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=True, sn=self.sn, scope='g_logit')
            x = tanh(x)

            return x, mu, logvar

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator_1(self, x_init, mu, is_training=True, reuse=tf.AUTO_REUSE, scope="discriminator_1"):
        """Stage-I conditional discriminator.

        Downsamples ``x_init`` with four stride-2 convolutions, concatenates
        the conditioning vector ``mu`` tiled to 4x4 and returns the logit
        produced by the final convolution.
        """
        channel = self.df_dim

        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=self.sn, scope='conv')
            x = lrelu(x, 0.2)

            for i in range(3):
                x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv_' + str(i))
                x = batch_norm(x, is_training, scope='batch_norm_' + str(i))
                x = lrelu(x, 0.2)

                channel = channel * 2

            mu = tf.reshape(mu, shape=[-1, 1, 1, self.condition_dim])
            # The tile size assumes a 4x4 feature map, i.e. a 64x64 x_init.
            mu = tf.tile(mu, multiples=[1, 4, 4, 1])

            x = tf.concat([x, mu], axis=-1)

            x = conv(x, channel, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv_last')
            x = batch_norm(x, is_training, scope='batch_norm_last')
            x = lrelu(x, 0.2)

            x = conv(x, channels=1, kernel=4, stride=4, use_bias=True, sn=self.sn, scope='d_logit')

            return x

    def discriminator_2(self, x_init, mu, is_training=True, reuse=tf.AUTO_REUSE, scope="discriminator_2"):
        """Stage-II conditional discriminator.

        Downsamples ``x_init`` with six stride-2 convolutions, reduces the
        channel count with two 3x3 convolutions, concatenates the conditioning
        vector ``mu`` tiled to 4x4 and returns the logit produced by the final
        convolution.
        """
        channel = self.df_dim

        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=self.sn, scope='conv')
            x = lrelu(x, 0.2)

            for i in range(5):
                x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv_' + str(i))
                x = batch_norm(x, is_training, scope='batch_norm_' + str(i))
                x = lrelu(x, 0.2)

                channel = channel * 2

            for i in range(2):
                x = conv(x, channel // 2, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv3x3_' + str(i))
                x = batch_norm(x, is_training, scope='batch_norm3x3_' + str(i))
                x = lrelu(x, 0.2)

                channel = channel // 2

            mu = tf.reshape(mu, shape=[-1, 1, 1, self.condition_dim])
            # The tile size assumes a 4x4 feature map, i.e. a 256x256 x_init.
            mu = tf.tile(mu, multiples=[1, 4, 4, 1])

            x = tf.concat([x, mu], axis=-1)

            x = conv(x, channel, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv_last')
            x = batch_norm(x, is_training, scope='batch_norm_last')
            x = lrelu(x, 0.2)

            x = conv(x, channels=1, kernel=4, stride=4, use_bias=True, sn=self.sn, scope='d_logit')

            return x

    ##################################################################################
    # Model
    ##################################################################################

    def _build_input_pipeline(self, batch_size, augment_flag):
        """Create the image/embedding input pipeline shared by both phases.

        Sets ``self.dataset_num`` and returns ``(real_img_256, embedding)``,
        where ``embedding`` holds one randomly chosen sentence per sample.
        """
        img_data_class = Image_data(self.img_height, self.img_width, self.img_ch, self.dataset_path, augment_flag)
        img_data_class.preprocess()

        self.dataset_num = len(img_data_class.image_list)

        img_and_embedding = tf.data.Dataset.from_tensor_slices(
            (img_data_class.image_list, img_data_class.embedding))

        gpu_device = '/gpu:0'
        img_and_embedding = img_and_embedding.apply(shuffle_and_repeat(self.dataset_num)).apply(
            map_and_batch(img_data_class.image_processing, batch_size=batch_size, num_parallel_batches=16,
                          drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))

        img_and_embedding_iterator = img_and_embedding.make_one_shot_iterator()

        real_img_256, embedding = img_and_embedding_iterator.get_next()

        # Pick one sentence index along axis 1 for the whole batch.
        # Assumes 10 sentence embeddings per image - TODO confirm against the dataset.
        sentence_index = tf.random.uniform(shape=[], minval=0, maxval=10, dtype=tf.int32)
        embedding = tf.gather(embedding, indices=sentence_index, axis=1)  # [bs, 1024]

        return real_img_256, embedding

    def _build_train_graph(self):
        """Build both stages' networks, losses, optimizers and summaries."""
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        # Input image
        self.real_img_256, self.embedding = self._build_input_pipeline(self.batch_size, self.augment_flag)

        noise = tf.random_normal(shape=[self.batch_size, self.z_dim])
        self.fake_img_64, mu_64, logvar_64 = self.generator_1(self.embedding, noise)
        self.fake_img_256, mu_256, logvar_256 = self.generator_2(self.fake_img_64, self.embedding)
        self.real_img_64 = tf.image.resize_bilinear(self.real_img_256, size=[64, 64])

        # Indexed by training stage (0 -> 64x64, 1 -> 256x256).
        self.real_img = [self.real_img_64, self.real_img_256]
        self.fake_img = [self.fake_img_64, self.fake_img_256]

        real_logit_64 = self.discriminator_1(self.real_img_64, mu_64)
        fake_logit_64 = self.discriminator_1(self.fake_img_64, mu_64)

        real_logit_256 = self.discriminator_2(self.real_img_256, mu_256)
        fake_logit_256 = self.discriminator_2(self.fake_img_256, mu_256)

        g_adv_loss_64 = generator_loss(self.gan_type, fake_logit_64) * self.adv_weight
        g_kl_loss_64 = kl_loss(mu_64, logvar_64) * self.kl_weight

        d_adv_loss_64 = discriminator_loss(self.gan_type, real_logit_64, fake_logit_64) * self.adv_weight

        g_loss_64 = g_adv_loss_64 + g_kl_loss_64
        d_loss_64 = d_adv_loss_64

        g_adv_loss_256 = generator_loss(self.gan_type, fake_logit_256) * self.adv_weight
        g_kl_loss_256 = kl_loss(mu_256, logvar_256) * self.kl_weight

        d_adv_loss_256 = discriminator_loss(self.gan_type, real_logit_256, fake_logit_256) * self.adv_weight

        g_loss_256 = g_adv_loss_256 + g_kl_loss_256
        d_loss_256 = d_adv_loss_256

        self.g_loss = [g_loss_64, g_loss_256]
        self.d_loss = [d_loss_64, d_loss_256]

        # Training: each optimizer only updates its own network's variables.
        t_vars = tf.trainable_variables()
        G1_vars = [var for var in t_vars if 'generator_1' in var.name]
        G2_vars = [var for var in t_vars if 'generator_2' in var.name]
        D1_vars = [var for var in t_vars if 'discriminator_1' in var.name]
        D2_vars = [var for var in t_vars if 'discriminator_2' in var.name]

        g1_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(g_loss_64, var_list=G1_vars)
        g2_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(g_loss_256, var_list=G2_vars)

        d1_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(d_loss_64, var_list=D1_vars)
        d2_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(d_loss_256, var_list=D2_vars)

        self.g_optim = [g1_optim, g2_optim]
        self.d_optim = [d1_optim, d2_optim]

        # Summary
        self.summary_g_loss_64 = tf.summary.scalar("g_loss_64", g_loss_64)
        self.summary_g_loss_256 = tf.summary.scalar("g_loss_256", g_loss_256)
        self.summary_d_loss_64 = tf.summary.scalar("d_loss_64", d_loss_64)
        self.summary_d_loss_256 = tf.summary.scalar("d_loss_256", d_loss_256)

        self.summary_g_adv_loss_64 = tf.summary.scalar("g_adv_loss_64", g_adv_loss_64)
        self.summary_g_adv_loss_256 = tf.summary.scalar("g_adv_loss_256", g_adv_loss_256)
        self.summary_g_kl_loss_64 = tf.summary.scalar("g_kl_loss_64", g_kl_loss_64)
        self.summary_g_kl_loss_256 = tf.summary.scalar("g_kl_loss_256", g_kl_loss_256)

        self.summary_d_adv_loss_64 = tf.summary.scalar("d_adv_loss_64", d_adv_loss_64)
        self.summary_d_adv_loss_256 = tf.summary.scalar("d_adv_loss_256", d_adv_loss_256)

        g_summary_list = [self.summary_g_loss_64, self.summary_g_loss_256,
                          self.summary_g_adv_loss_64, self.summary_g_adv_loss_256,
                          self.summary_g_kl_loss_64, self.summary_g_kl_loss_256]

        d_summary_list = [self.summary_d_loss_64, self.summary_d_loss_256,
                          self.summary_d_adv_loss_64, self.summary_d_adv_loss_256]

        self.summary_merge_g_loss = tf.summary.merge(g_summary_list)
        self.summary_merge_d_loss = tf.summary.merge(d_summary_list)

    def _build_test_graph(self):
        """Build the inference-only graph used by ``test``."""
        self.real_img_256, self.embedding = self._build_input_pipeline(self.TEST_BATCH_SIZE, False)

        # The noise batch must match the dataset batch (TEST_BATCH_SIZE, not
        # self.batch_size), otherwise the concat inside generator_1 fails.
        noise = tf.random_normal(shape=[self.TEST_BATCH_SIZE, self.z_dim])
        self.fake_img_64, mu_64, logvar_64 = self.generator_1(self.embedding, noise, is_training=False)
        self.fake_img_256, mu_256, logvar_256 = self.generator_2(self.fake_img_64, self.embedding, is_training=False)

        self.test_fake_img = self.fake_img_256
        self.test_real_img = self.real_img_256

    def build_model(self):
        """Build the training graph when ``phase == 'train'``, else the test graph."""
        if self.phase == 'train':
            self._build_train_graph()
        else:
            self._build_test_graph()

    def train(self):
        """Run the two-stage training loop, resuming from a checkpoint if one exists."""
        # initialize all variables
        tf.global_variables_initializer().run()

        # saver to save model
        self.saver = tf.train.Saver(max_to_keep=10)

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore checkpoint if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            counter = checkpoint_counter
            # The global counter runs across both stages.
            init_stage = counter // self.iteration
            if init_stage == 1:
                start_batch_id = checkpoint_counter - self.iteration
            else:
                start_batch_id = checkpoint_counter
            print(" [*] Load SUCCESS")

        else:
            start_batch_id = 0
            counter = 1
            init_stage = 0
            print(" [!] Load failed...")

        start_time = time.time()

        for stage in range(init_stage, 2):
            for idx in range(start_batch_id, self.iteration):

                # Derive the rate from idx on every step so that a resumed run
                # picks up the decayed rate instead of restarting at init_lr.
                if self.decay_flag:
                    lr = self.init_lr * pow(0.5, idx // self.decay_iter)
                else:
                    lr = self.init_lr

                train_feed_dict = {
                    self.lr: lr
                }

                # Update D
                _, d_loss, summary_str = self.sess.run(
                    [self.d_optim[stage], self.d_loss[stage], self.summary_merge_d_loss],
                    feed_dict=train_feed_dict)
                self.writer.add_summary(summary_str, counter)

                # Update G
                real_images, fake_images, _, g_loss, summary_str = self.sess.run(
                    [self.real_img[stage], self.fake_img[stage],
                     self.g_optim[stage],
                     self.g_loss[stage], self.summary_merge_g_loss], feed_dict=train_feed_dict)

                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Stage: [%1d] [%6d/%6d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (stage, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

                if np.mod(idx + 1, self.print_freq) == 0:
                    real_images = real_images[:5]
                    fake_images = fake_images[:5]

                    merge_real_images = np.expand_dims(return_images(real_images, [5, 1]), axis=0)
                    merge_fake_images = np.expand_dims(return_images(fake_images, [5, 1]), axis=0)

                    merge_images = np.concatenate([merge_real_images, merge_fake_images], axis=0)

                    save_images(merge_images, [1, 2],
                                './{}/merge_stage{}_{:07d}.jpg'.format(self.sample_dir, stage, idx + 1))

                if np.mod(counter - 1, self.save_freq) == 0:
                    self.save(self.checkpoint_dir, counter)

            # Only the resumed stage starts mid-way; later stages must start
            # from their first iteration.
            start_batch_id = 0

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        """Directory name that encodes the model, dataset and loss settings."""
        if self.sn:
            sn = '_sn'
        else:
            sn = ''

        return "{}_{}_{}_{}adv_{}kl{}".format(self.model_name, self.dataset_name, self.gan_type,
                                              self.adv_weight, self.kl_weight,
                                              sn)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint for ``step`` under ``checkpoint_dir/model_dir``."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint under ``checkpoint_dir/model_dir``.

        Returns:
            ``(True, step)`` when a checkpoint was restored, where ``step`` is
            parsed from the checkpoint file name; ``(False, 0)`` otherwise.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def test(self):
        """Generate one test batch and write the images plus an ``index.html``.

        Real and generated images are saved as ``real_<i>.jpg`` /
        ``fake_<i>.jpg`` under ``result_dir/model_dir``.
        """
        tf.global_variables_initializer().run()

        self.saver = tf.train.Saver()
        could_load, _ = self.load(self.checkpoint_dir)
        self.result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(self.result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        real_images, fake_images = self.sess.run([self.test_real_img, self.test_fake_img])

        # write html for visual comparison
        index_path = os.path.join(self.result_dir, 'index.html')
        with open(index_path, 'w') as index:
            index.write("<html><body><table><tr>")
            # NOTE(review): the header names four columns while each row below
            # writes three cells (name, real image, generated image).
            index.write("<th>name</th><th>content</th><th>style</th><th>output</th></tr>")

            for i in range(self.TEST_BATCH_SIZE):
                real_path = os.path.join(self.result_dir, 'real_{}.jpg'.format(i))
                fake_path = os.path.join(self.result_dir, 'fake_{}.jpg'.format(i))

                real_image = np.expand_dims(real_images[i], axis=0)
                fake_image = np.expand_dims(fake_images[i], axis=0)

                save_images(real_image, [1, 1], real_path)
                save_images(fake_image, [1, 1], fake_path)

                index.write("<tr>")
                index.write("<td>%s</td>" % os.path.basename(real_path))
                # Relative paths are rewritten so they resolve from the folder
                # that contains index.html.
                index.write("<td><img src='%s' width='%d' height='%d'></td>" % (real_path if os.path.isabs(real_path) else (
                    '../..' + os.path.sep + real_path), self.img_width, self.img_height))

                index.write("<td><img src='%s' width='%d' height='%d'></td>" % (fake_path if os.path.isabs(fake_path) else (
                    '../..' + os.path.sep + fake_path), self.img_width, self.img_height))
                index.write("</tr>")

            index.write("</table></body></html>")
--------------------------------------------------------------------------------
/assets/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/StackGAN-Tensorflow/1a5ffed6613049d8fd43c8ff5cbe34061394975e/assets/result.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/StackGAN-Tensorflow/1a5ffed6613049d8fd43c8ff5cbe34061394975e/assets/teaser.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from StackGAN import StackGAN
2 | import argparse
3 | from utils import *
4 |
def parse_args():
    """Parse the command-line configuration and pass it through check_args."""
    parser = argparse.ArgumentParser(description="Tensorflow implementation of StackGAN")
    add = parser.add_argument

    # Run selection
    add('--phase', type=str, default='train', choices=('train', 'test'), help='phase name')
    add('--dataset', type=str, default='birds', help='dataset_name')

    # Schedule
    add('--iteration', type=int, default=500000, help='The number of training iterations')
    add('--decay_flag', type=str2bool, default=True, help='The decay_flag')
    add('--decay_iter', type=int, default=100000, help='decay epoch')

    add('--batch_size', type=int, default=32, help='The size of batch size for each gpu')
    add('--print_freq', type=int, default=1000, help='The number of image_print_freq')
    add('--save_freq', type=int, default=10000, help='The number of ckpt_save_freq')

    add('--lr', type=float, default=0.0002, help='The learning rate')

    # Losses
    add('--gan_type', type=str, default='gan', help='[gan / lsgan / hinge]')

    add('--adv_weight', type=int, default=1, help='Weight about GAN')
    add('--kl_weight', type=int, default=2, help='Weight about kl_loss')

    add('--sn', type=str2bool, default=False, help='using spectral norm')

    # Image geometry
    add('--img_height', type=int, default=256, help='The height size of image')
    add('--img_width', type=int, default=256, help='The width size of image ')
    add('--img_ch', type=int, default=3, help='The size of image channel')
    add('--augment_flag', type=str2bool, default=True, help='Image augmentation use or not')

    # Output locations
    add('--checkpoint_dir', type=str, default='checkpoint',
        help='Directory name to save the checkpoints')
    add('--result_dir', type=str, default='results',
        help='Directory name to save the generated images')
    add('--log_dir', type=str, default='logs',
        help='Directory name to save training logs')
    add('--sample_dir', type=str, default='samples',
        help='Directory name to save the samples on training')

    parsed = parser.parse_args()
    return check_args(parsed)
44 |
def check_args(args):
    """Prepare the output folders and validate the parsed arguments.

    Calls ``check_folder`` on the checkpoint, result, log and sample
    directories, then checks the batch size.

    Returns:
        ``args`` when the configuration is valid, or ``None`` (after printing
        a message) when ``args.batch_size`` is smaller than one. ``main``
        exits when it receives ``None``.
    """
    for folder in (args.checkpoint_dir, args.result_dir, args.log_dir, args.sample_dir):
        check_folder(folder)

    # Explicit comparison instead of assert + bare except: asserts are
    # stripped under `python -O`, and a bare except hides unrelated errors.
    if args.batch_size < 1:
        print('batch size must be larger than or equal to one')
        return None

    return args
65 |
def main():
    """Parse the arguments, build the StackGAN graph and run the chosen phase."""
    args = parse_args()
    if args is None:
        exit()

    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        gan = StackGAN(sess, args)

        # The graph has to exist before the variables can be listed.
        gan.build_model()
        show_all_variables()

        if args.phase == 'train':
            gan.train()
            print(" [*] Training finished!")
        elif args.phase == 'test':
            gan.test()
            print(" [*] Test finished!")
90 |
91 |
92 |
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
95 |
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib as tf_contrib
3 | from utils import pytorch_kaiming_weight_factor
4 |
5 | ##################################################################################
6 | # Initialization
7 | ##################################################################################
8 |
# Alternative initializer kept for reference (variance scaling with the factor
# returned by utils.pytorch_kaiming_weight_factor); currently disabled:
# factor, mode, uniform = pytorch_kaiming_weight_factor(a=0.0, uniform=False)
# weight_init = tf_contrib.layers.variance_scaling_initializer(factor=factor, mode=mode, uniform=uniform)

# Kernel initializer shared by conv and fully_connected: N(0, 0.02).
weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)

# Kernel regularizers for conv and fully_connected; None disables them.
weight_regularizer = None
weight_regularizer_fully = None
15 |
16 | ##################################################################################
17 | # Layer
18 | ##################################################################################
19 |
def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    """2-D convolution with explicit padding and optional spectral norm.

    When ``pad > 0`` the input is padded manually (``'zero'`` or ``'reflect'``)
    before the convolution. With ``sn`` the kernel is created here and wrapped
    in ``spectral_norm``; otherwise ``tf.layers.conv2d`` is used.
    """
    with tf.variable_scope(scope):
        if pad > 0:
            height = x.get_shape().as_list()[1]
            remainder = height % stride
            # Total padding per spatial axis; height and width share it.
            total_pad = pad * 2 if remainder == 0 else max(kernel - remainder, 0)

            pad_before = total_pad // 2
            pad_after = total_pad - pad_before
            paddings = [[0, 0], [pad_before, pad_after], [pad_before, pad_after], [0, 0]]

            if pad_type == 'zero':
                x = tf.pad(x, paddings)
            elif pad_type == 'reflect':
                x = tf.pad(x, paddings, mode='REFLECT')

        if not sn:
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=weight_regularizer,
                                 strides=stride, use_bias=use_bias)
        else:
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels],
                                initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)

        return x
55 |
56 |
def fully_connected(x, units, use_bias=True, sn=False, scope='linear'):
    """Flatten the input and apply a dense layer with ``units`` outputs.

    With ``sn=True`` an explicit kernel variable is created, spectrally
    normalised via spectral_norm and applied with tf.matmul (plus an
    optional bias); otherwise tf.layers.dense is used.
    """
    with tf.variable_scope(scope):
        x = flatten(x)
        in_features = x.get_shape().as_list()[-1]

        if not sn:
            return tf.layers.dense(x, units=units, kernel_initializer=weight_init,
                                   kernel_regularizer=weight_regularizer_fully,
                                   use_bias=use_bias)

        w = tf.get_variable("kernel", [in_features, units], tf.float32,
                            initializer=weight_init, regularizer=weight_regularizer_fully)

        if use_bias:
            bias = tf.get_variable("bias", [units],
                                   initializer=tf.constant_initializer(0.0))
            return tf.matmul(x, spectral_norm(w)) + bias

        return tf.matmul(x, spectral_norm(w))
80 |
81 |
def flatten(x):
    """Flatten ``x`` with tf.layers.flatten."""
    flat = tf.layers.flatten(x)
    return flat
84 |
85 |
86 | ##################################################################################
87 | # Residual-block
88 | ##################################################################################
89 |
90 |
def resblock(x_init, channels, is_training=True, use_bias=True, sn=False, scope='resblock'):
    """Residual block: two 3x3 reflect-padded conv + batch-norm stages with a skip connection.

    The first stage is followed by ReLU; the second stage's output is added
    to ``x_init`` and the sum is passed through ReLU.
    """
    conv_kwargs = dict(kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)

    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            branch = conv(x_init, channels, **conv_kwargs)
            branch = batch_norm(branch, is_training)
            branch = relu(branch)

        with tf.variable_scope('res2'):
            branch = conv(branch, channels, **conv_kwargs)
            branch = batch_norm(branch, is_training)

        return relu(branch + x_init)
103 |
def up_block(x_init, channels, is_training=True, use_bias=True, sn=False, scope='up_block'):
    """Upsample by a factor of 2, then apply a 3x3 reflect-padded conv, batch-norm and ReLU."""
    with tf.variable_scope(scope):
        upsampled = up_sample(x_init, scale_factor=2)
        features = conv(upsampled, channels, kernel=3, stride=1, pad=1,
                        pad_type='reflect', use_bias=use_bias, sn=sn)
        normalized = batch_norm(features, is_training)

        return relu(normalized)
112 |
113 | ##################################################################################
114 | # Sampling
115 | ##################################################################################
116 |
def up_sample(x, scale_factor=2):
    """Nearest-neighbour resize of axes 1 and 2 by ``scale_factor``.

    The target size is computed from the static shape of ``x``, which must
    have exactly four dimensions.
    """
    _, height, width, _ = x.get_shape().as_list()
    target_size = [height * scale_factor, width * scale_factor]
    return tf.image.resize_nearest_neighbor(x, size=target_size)
121 |
122 |
def down_sample_avg(x, scale_factor=2):
    """Average-pool with a 3x3 window, stride ``scale_factor`` and SAME padding."""
    pooled = tf.layers.average_pooling2d(x, pool_size=3, strides=scale_factor, padding='SAME')
    return pooled
125 |
def global_avg_pooling(x):
    """Mean over axes 1 and 2, keeping those axes with size 1."""
    return tf.reduce_mean(x, axis=[1, 2], keepdims=True)
129 |
def reparametrize(mean, logvar):
    """Draw ``mean + exp(0.5 * logvar) * eps`` with ``eps`` sampled from a standard normal."""
    noise = tf.random_normal(tf.shape(mean), mean=0.0, stddev=1.0, dtype=tf.float32)
    std = tf.exp(logvar * 0.5)

    return mean + std * noise
134 |
135 | ##################################################################################
136 | # Activation function
137 | ##################################################################################
138 |
def lrelu(x, alpha=0.01):
    """Leaky ReLU with negative slope ``alpha``.

    The default of 0.01 is chosen to match PyTorch's default slope.
    """
    return tf.nn.leaky_relu(x, alpha=alpha)
142 |
143 |
def relu(x):
    """Element-wise ReLU (tf.nn.relu)."""
    activated = tf.nn.relu(x)
    return activated
146 |
147 |
def tanh(x):
    """Element-wise hyperbolic tangent (tf.tanh)."""
    activated = tf.tanh(x)
    return activated
150 |
151 |
152 | ##################################################################################
153 | # Normalization function
154 | ##################################################################################
155 |
def instance_norm(x, scope='instance_norm'):
    """Instance normalization via tf.contrib with learned center and scale (epsilon 1e-5)."""
    normalized = tf_contrib.layers.instance_norm(
        x,
        epsilon=1e-05,
        center=True,
        scale=True,
        scope=scope,
    )
    return normalized
161 |
def batch_norm(x, is_training=False, scope='batch_norm'):
    """Batch normalization via tf.contrib (decay 0.9, epsilon 1e-5, learned center and scale).

    ``updates_collections=None`` is passed to the contrib layer. If this were
    switched to tf.layers.batch_normalization, the train op would have to be
    built under the UPDATE_OPS collection:

        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            train_op = optimizer.minimize(loss)
    """
    normalized = tf_contrib.layers.batch_norm(
        x,
        decay=0.9,
        epsilon=1e-05,
        center=True,
        scale=True,
        updates_collections=None,
        is_training=is_training,
        scope=scope,
    )
    return normalized
174 |
175 | # return tf.layers.batch_normalization(x, momentum=0.9, epsilon=1e-05, center=True, scale=True, training=is_training, name=scope)
176 |
177 |
def param_free_norm(x, epsilon=1e-5):
    """Normalize ``x`` by its mean and standard deviation over axes 1 and 2.

    No learned scale or offset is applied; ``epsilon`` is added to the
    variance before the square root.
    """
    mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
    std = tf.sqrt(variance + epsilon)

    return (x - mean) / std
183 |
def adaptive_instance_norm(content, gamma, beta, epsilon=1e-5):
    """Apply param_free_norm to ``content``, then scale by ``gamma`` and shift by ``beta``.

    NOTE(review): the original comment says gamma/beta are style statistics
    produced by an MLP - confirm against the caller.
    """
    normalized = param_free_norm(content, epsilon)

    return gamma * normalized + beta
190 |
def spectral_norm(w, iteration=1):
    """Return ``w`` divided by an estimate of its largest singular value.

    ``w`` is reshaped to 2-D (all leading axes collapsed, last axis kept) and
    ``iteration`` rounds of power iteration are run starting from the
    non-trainable variable ``u``. The refreshed ``u`` is assigned back as a
    control dependency of the returned tensor, which has the original shape
    of ``w``.
    """
    original_shape = w.shape.as_list()
    out_dim = original_shape[-1]
    w = tf.reshape(w, [-1, out_dim])

    u = tf.get_variable("u", [1, out_dim], initializer=tf.random_normal_initializer(), trainable=False)

    u_hat, v_hat = u, None
    # Power iteration; a single round is usually enough.
    for _ in range(iteration):
        v_hat = tf.nn.l2_normalize(tf.matmul(u_hat, tf.transpose(w)))
        u_hat = tf.nn.l2_normalize(tf.matmul(v_hat, w))

    # The singular vectors are treated as constants when differentiating.
    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))

    with tf.control_dependencies([u.assign(u_hat)]):
        normalized = w / sigma
        normalized = tf.reshape(normalized, original_shape)

    return normalized
220 |
221 |
222 | ##################################################################################
223 | # Loss function
224 | ##################################################################################
225 |
def L1_loss(x, y):
    """Mean absolute difference between ``x`` and ``y`` over all elements."""
    absolute_error = tf.abs(x - y)

    return tf.reduce_mean(absolute_error)
230 |
def discriminator_loss(gan_type, real_logit, fake_logit):
    """Discriminator loss (real term + fake term) for the chosen ``gan_type``.

    Recognised values are 'lsgan', 'gan' and 'hinge'; any other value
    yields a loss of 0.
    """
    if gan_type == 'lsgan':
        real_loss = tf.reduce_mean(tf.squared_difference(real_logit, 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake_logit))
    elif gan_type == 'gan':
        real_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real_logit), logits=real_logit))
        fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake_logit), logits=fake_logit))
    elif gan_type == 'hinge':
        real_loss = tf.reduce_mean(relu(1 - real_logit))
        fake_loss = tf.reduce_mean(relu(1 + fake_logit))
    else:
        real_loss, fake_loss = 0, 0

    return real_loss + fake_loss
249 |
250 |
def generator_loss(gan_type, fake_logit):
    """Generator loss for the chosen ``gan_type``.

    Recognised values are 'lsgan', 'gan' and 'hinge'; any other value
    yields a loss of 0.
    """
    if gan_type == 'lsgan':
        return tf.reduce_mean(tf.squared_difference(fake_logit, 1.0))

    if gan_type == 'gan':
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake_logit), logits=fake_logit))

    if gan_type == 'hinge':
        return -tf.reduce_mean(fake_logit)

    return 0
264 |
265 |
def regularization_loss(scope_name):
    """Sum the collected regularization losses whose name contains ``scope_name``.

    Intended usage:
        g_loss += regularization_loss('generator')
        d_loss += regularization_loss('discriminator')
    """
    all_terms = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    matching = [term for term in all_terms if scope_name in term.name]

    return tf.reduce_sum(matching)
280 |
def kl_loss(mean, logvar):
    """Return mean over samples of ``0.5 * sum(mean^2 + exp(logvar) - 1 - logvar)``.

    The sum runs over the last axis; the mean runs over what remains.
    NOTE(review): the original comment gives the inputs as
    [batch_size, channel] - confirm against the caller.
    """
    per_element = tf.square(mean) + tf.exp(logvar) - 1 - logvar
    per_sample = 0.5 * tf.reduce_sum(per_element, axis=-1)

    return tf.reduce_mean(per_sample)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import slim
3 | import os
4 | import numpy as np
5 | from glob import glob
6 | import cv2
7 | import pickle
8 |
class Image_data:
    """Locate a dataset's images/embeddings on disk and build per-image tensors.

    The dataset directory is expected to contain an ``images`` folder and a
    ``text`` folder holding ``char-CNN-RNN-embeddings.pickle`` and
    ``filenames.pickle``. Call preprocess() to populate ``embedding`` and
    ``image_list``.
    """

    def __init__(self, img_height, img_width, channels, dataset_path, augment_flag):
        """Store the image geometry and derive the dataset paths; no I/O is done here."""
        self.img_height = img_height
        self.img_width = img_width
        self.channels = channels
        self.augment_flag = augment_flag

        self.dataset_path = dataset_path
        self.image_path = os.path.join(dataset_path, 'images')
        self.text_path = os.path.join(dataset_path, 'text')

        self.embedding_pickle = os.path.join(self.text_path, 'char-CNN-RNN-embeddings.pickle')
        self.image_filename_pickle = os.path.join(self.text_path, 'filenames.pickle')

        # Full paths of the JPEG files, filled in by preprocess().
        self.image_list = []

    def image_processing(self, filename, vector):
        """Read and decode a JPEG, resize it and scale it to the range [-1, 1].

        When ``augment_flag`` is set, augmentation() is applied with
        probability 0.5. ``vector`` is returned unchanged next to the image.
        """
        x = tf.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=self.channels, dct_method='INTEGER_ACCURATE')
        img = tf.image.resize_images(x_decode, [self.img_height, self.img_width])
        img = tf.cast(img, tf.float32) / 127.5 - 1

        if self.augment_flag:
            # Enlarge by 30 px for 256-sized images, otherwise by 10 % of the side.
            augment_height_size = self.img_height + (30 if self.img_height == 256 else int(self.img_height * 0.1))
            augment_width_size = self.img_width + (30 if self.img_width == 256 else int(self.img_width * 0.1))

            img = tf.cond(pred=tf.greater_equal(tf.random_uniform(shape=[], minval=0.0, maxval=1.0), 0.5),
                          true_fn=lambda: augmentation(img, augment_height_size, augment_width_size),
                          false_fn=lambda: img)

        return img, vector

    def preprocess(self):
        """Load the embedding array and rebuild the list of image paths.

        The list is rebuilt from scratch on every call, so calling this
        method again does not duplicate entries.

        NOTE: pickle can execute arbitrary code while loading; only use
        dataset files obtained from a trusted source.
        """
        with open(self.embedding_pickle, 'rb') as f:
            # 'latin1' lets pickles written by Python 2 be decoded.
            embedding = pickle.load(f, encoding='latin1')
        # e.g. (8855, 10, 1024) according to the original author's note.
        self.embedding = np.array(embedding)

        with open(self.image_filename_pickle, 'rb') as f:
            # Entries look like '002.Laysan_Albatross/Laysan_Albatross_0002_1027'.
            x_list = pickle.load(f)

        image_list = []
        for x in x_list:
            parts = x.split('/')
            folder_name = parts[0]
            file_name = parts[1] + '.jpg'
            image_list.append(os.path.join(self.image_path, folder_name, file_name))

        self.image_list = image_list
63 |
64 |
def load_test_image(image_path, img_width, img_height, img_channel):
    """Load one image with OpenCV as a batch of size 1 scaled to [-1, 1].

    A single-channel request (``img_channel == 1``) reads the file as
    grayscale and appends a trailing channel axis; otherwise the file is
    read in colour and converted from BGR to RGB. The image is resized to
    ``(img_width, img_height)`` and a leading batch axis is added.

    Raises:
        IOError: if OpenCV cannot read ``image_path`` (cv2.imread returns
            None for missing or undecodable files).
    """
    grayscale = img_channel == 1
    read_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR

    img = cv2.imread(image_path, flags=read_flag)
    if img is None:
        # Fail with a clear message instead of an opaque OpenCV assertion later on.
        raise IOError('Could not read image: {}'.format(image_path))

    if not grayscale:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    img = cv2.resize(img, dsize=(img_width, img_height))

    img = np.expand_dims(img, axis=0)
    if grayscale:
        img = np.expand_dims(img, axis=-1)

    img = img / 127.5 - 1

    return img
84 |
85 |
def preprocessing(x):
    """Map values from [0, 255] to [-1, 1] via ``x / 127.5 - 1``."""
    return x / 127.5 - 1
89 |
def preprocess_fit_train_image(images, height, width):
    """Bilinearly resize ``images`` to ``height`` x ``width`` and apply adjust_dynamic_range."""
    resized = tf.image.resize(images, size=[height, width], method=tf.image.ResizeMethod.BILINEAR)

    return adjust_dynamic_range(resized)
95 |
def adjust_dynamic_range(images):
    """Linearly map values from the range [0, 255] to [-1, 1]."""
    in_low, in_high = 0.0, 255.0
    out_low, out_high = -1.0, 1.0

    scale = (out_high - out_low) / (in_high - in_low)
    bias = out_low - in_low * scale

    return images * scale + bias
103 |
def augmentation(image, augment_height, augment_width):
    """Random horizontal flip, bilinear resize to the augment size, then random crop back.

    The crop restores the shape ``image`` had on entry. One seed, drawn from
    NumPy when this function is called, is passed to both random ops.
    """
    seed = np.random.randint(0, 2 ** 31 - 1)

    original_shape = tf.shape(image)
    flipped = tf.image.random_flip_left_right(image, seed=seed)
    enlarged = tf.image.resize(flipped, size=[augment_height, augment_width],
                               method=tf.image.ResizeMethod.BILINEAR)

    return tf.random_crop(enlarged, original_shape, seed=seed)
114 |
def save_images(images, size, image_path):
    """Apply inverse_transform to ``images`` and write them as one grid via imsave."""
    restored = inverse_transform(images)
    return imsave(restored, size, image_path)
117 |
def inverse_transform(images):
    """Map values from [-1, 1] to [0, 255]."""
    unit_range = (images + 1.) / 2
    return unit_range * 255.0
120 |
def imsave(images, size, path):
    """Tile ``images`` with merge(), convert RGB to BGR as uint8 and write to ``path``.

    Returns whatever cv2.imwrite returns.
    """
    grid = merge(images, size)
    grid_bgr = cv2.cvtColor(grid.astype('uint8'), cv2.COLOR_RGB2BGR)

    return cv2.imwrite(path, grid_bgr)
126 |
127 |
def post_process_generator_output(generator_output):
    """Scale values from [-1, 1] to [0, 255] with a +0.5 offset, then clip to [0, 255]."""
    low, high = -1.0, 1.0
    scale = 255.0 / (high - low)
    offset = 0.5 - low * scale

    return np.clip(generator_output * scale + offset, 0, 255)
137 |
def merge(images, size):
    """Tile a batch of images into one grid of ``size[0]`` rows by ``size[1]`` columns.

    Images are placed row by row; unfilled cells stay zero.
    """
    tile_h, tile_w, depth = images.shape[1], images.shape[2], images.shape[3]
    rows, cols = size[0], size[1]

    canvas = np.zeros((tile_h * rows, tile_w * cols, depth))
    for index, tile in enumerate(images):
        row, col = divmod(index, cols)
        top, left = tile_h * row, tile_w * col
        canvas[top:top + tile_h, left:left + tile_w, :] = tile

    return canvas
148 |
def return_images(images, size):
    """Return the grid produced by merge() for ``images`` and ``size``."""
    return merge(images, size)
153 |
def show_all_variables():
    """Print an analysis of all trainable variables using slim's model analyzer."""
    trainable = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(trainable, print_info=True)
157 |
def check_folder(log_dir):
    """Create ``log_dir`` (including parents) if the path does not exist; return ``log_dir``."""
    if os.path.exists(log_dir):
        return log_dir

    os.makedirs(log_dir)
    return log_dir
162 |
def str2bool(x):
    """Return True only when ``x`` equals 'true', ignoring case.

    The comparison is an exact match. The previous ``in ('true')`` form
    tested membership in a plain string, so substrings such as '', 't' or
    'ru' were also accepted.
    """
    return x.lower() == 'true'
165 |
def get_one_hot(targets, nb_classes):
    """One-hot encode ``targets`` by selecting rows of an ``nb_classes`` identity matrix."""
    identity = np.eye(nb_classes)

    return identity[targets]
171 |
def pytorch_xavier_weight_factor(gain=0.02, uniform=False):
    """Return ``(factor, mode, uniform)`` for a Xavier-style variance-scaling initializer.

    ``factor`` is ``gain**2`` when ``uniform`` is truthy and ``gain**2 / 1.3``
    otherwise; ``mode`` is always 'FAN_AVG'.
    """
    squared_gain = gain * gain
    factor = squared_gain if uniform else squared_gain / 1.3

    return factor, 'FAN_AVG', uniform
182 |
def pytorch_kaiming_weight_factor(a=0.0, activation_function='leaky_relu', uniform=False):
    """Return ``(factor, mode, uniform)`` for a Kaiming-style variance-scaling initializer.

    The gain depends on ``activation_function``: sqrt(2) for 'relu',
    sqrt(2 / (1 + a**2)) for 'leaky_relu', 5/3 for 'tanh' and 1 for anything
    else. ``factor`` is ``gain**2`` when ``uniform`` is truthy and
    ``gain**2 / 1.3`` otherwise; ``mode`` is always 'FAN_IN'.
    """
    if activation_function == 'relu':
        gain = np.sqrt(2.0)
    elif activation_function == 'leaky_relu':
        gain = np.sqrt(2.0 / (1 + a ** 2))
    elif activation_function == 'tanh':
        gain = 5.0 / 3
    else:
        gain = 1.0

    squared_gain = gain * gain
    factor = squared_gain if uniform else squared_gain / 1.3

    return factor, 'FAN_IN', uniform
202 |
--------------------------------------------------------------------------------