├── .gitignore ├── LICENSE ├── README.md ├── StackGAN.py ├── assets ├── result.png └── teaser.png ├── main.py ├── ops.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Junho Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## StackGAN — Simple TensorFlow Implementation [[Paper]](https://arxiv.org/abs/1612.03242) 2 | ### : Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks 3 | 4 |
5 | 6 |
7 | 8 | ## Dataset 9 | ### char-CNN-RNN text embedding 10 | * [birds](https://drive.google.com/open?id=0B3y_msrWZaXLT1BZdVdycDY5TEE) 11 | * [flowers](https://drive.google.com/open?id=0B3y_msrWZaXLaUc0UXpmcnhaVmM) 12 | 13 | ### Image 14 | * [birds](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) 15 | * [flowers](http://www.robots.ox.ac.uk/~vgg/data/flowers/102/) 16 | 17 | ## Usage 18 | ``` 19 | ├── dataset 20 |    └── YOUR_DATASET_NAME 21 |    ├── images 22 |           ├── domain1 (domain folder) 23 | ├── xxx.jpg (domain1 image) 24 | ├── yyy.png 25 | ├── ... 26 | ├── domain2 27 | ├── aaa.jpg (domain2 image) 28 | ├── bbb.png 29 | ├── ... 30 | ├── domain3 31 | ├── ... 32 |    ├── text 33 | ├── char-CNN-RNN-embeddings.pickle 34 | ├── filenames.pickle 35 | ``` 36 | 37 | ### Train 38 | ``` 39 | python main.py --dataset birds --phase train 40 | ``` 41 | 42 | ### Test 43 | ``` 44 | python main.py --dataset birds --phase test 45 | ``` 46 | 47 | ## Results 48 |
49 | 50 |
51 | 52 | ## Author 53 | [Junho Kim](http://bit.ly/jhkim_ai) 54 | -------------------------------------------------------------------------------- /StackGAN.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | from utils import * 3 | import time 4 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch 5 | import numpy as np 6 | 7 | 8 | class StackGAN(): 9 | def __init__(self, sess, args): 10 | 11 | self.phase = args.phase 12 | self.model_name = 'StackGAN' 13 | 14 | self.sess = sess 15 | self.checkpoint_dir = args.checkpoint_dir 16 | self.result_dir = args.result_dir 17 | self.log_dir = args.log_dir 18 | self.dataset_name = args.dataset 19 | self.augment_flag = args.augment_flag 20 | 21 | self.iteration = args.iteration 22 | self.decay_flag = args.decay_flag 23 | self.decay_iter = args.decay_iter 24 | 25 | self.batch_size = args.batch_size 26 | self.print_freq = args.print_freq 27 | self.save_freq = args.save_freq 28 | 29 | self.init_lr = args.lr 30 | 31 | self.gan_type = args.gan_type 32 | 33 | self.condition_dim = 128 34 | self.df_dim = 96 35 | self.gf_dim = 128 36 | self.text_dim = 1024 37 | self.z_dim = 100 38 | 39 | 40 | """ Weight """ 41 | self.adv_weight = args.adv_weight 42 | self.kl_weight = args.kl_weight 43 | 44 | 45 | """ Generator """ 46 | 47 | """ Discriminator """ 48 | self.sn = args.sn 49 | 50 | self.img_height = args.img_height 51 | self.img_width = args.img_width 52 | 53 | self.img_ch = args.img_ch 54 | 55 | self.sample_dir = os.path.join(args.sample_dir, self.model_dir) 56 | check_folder(self.sample_dir) 57 | 58 | self.dataset_path = os.path.join('./dataset', self.dataset_name) 59 | 60 | print() 61 | 62 | print("##### Information #####") 63 | print("# dataset : ", self.dataset_name) 64 | print("# batch_size : ", self.batch_size) 65 | print("# max iteration : ", self.iteration) 66 | 67 | print() 68 | 69 | print("##### Generator #####") 70 | 71 | print() 
72 | 73 | print("##### Discriminator #####") 74 | print("# spectral normalization : ", self.sn) 75 | 76 | print() 77 | 78 | print("##### Weight #####") 79 | print("# adv_weight : ", self.adv_weight) 80 | print("# kl_weight : ", self.kl_weight) 81 | 82 | print() 83 | 84 | ################################################################################## 85 | # Generator 86 | ################################################################################## 87 | 88 | def generator_1(self, text_embedding, noise, is_training=True, reuse=tf.AUTO_REUSE, scope='generator_1'): 89 | channels = self.gf_dim * 8 # 1024 90 | with tf.variable_scope(scope, reuse=reuse): 91 | mu = fully_connected(text_embedding, units=self.condition_dim, use_bias=True, sn=self.sn, scope='mu_fc') 92 | mu = relu(mu) 93 | 94 | logvar = fully_connected(text_embedding, units=self.condition_dim, use_bias=True, sn=self.sn, scope='logvar_fc') 95 | logvar = relu(logvar) 96 | 97 | condition = reparametrize(mu, logvar) 98 | 99 | z = tf.concat([noise, condition], axis=-1) 100 | z = fully_connected(z, units=channels * 4 * 4, use_bias=False, sn=self.sn) 101 | z = batch_norm(z, is_training) 102 | z = relu(z) 103 | z = tf.reshape(z, shape=[-1, 4, 4, channels]) 104 | 105 | x = z 106 | for i in range(4) : 107 | x = up_block(x, channels=channels // 2, is_training=is_training, use_bias=False, sn=self.sn, scope='up_block_' + str(i)) 108 | channels = channels // 2 109 | 110 | x = conv(x, channels=self.img_ch, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=True, sn=self.sn, scope='g_logit') 111 | x = tanh(x) 112 | 113 | return x, mu, logvar 114 | 115 | def generator_2(self, x_init, text_embedding, is_training=True, reuse=tf.AUTO_REUSE, scope='generator_2'): 116 | channels = self.gf_dim 117 | with tf.variable_scope(scope, reuse=reuse): 118 | 119 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=True, sn=self.sn, scope='conv') 120 | x = relu(x) 121 | 122 | for i in range(2): 123 
| x = conv(x, channels * 2, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv_' + str(i)) 124 | x = batch_norm(x, is_training, scope='batch_norm_' + str(i)) 125 | x = relu(x) 126 | 127 | channels = channels * 2 128 | 129 | mu = fully_connected(text_embedding, units=self.condition_dim, use_bias=True, sn=self.sn, scope='mu_fc') 130 | mu = relu(mu) 131 | 132 | logvar = fully_connected(text_embedding, units=self.condition_dim, use_bias=True, sn=self.sn, scope='logvar_fc') 133 | logvar = relu(logvar) 134 | 135 | condition = reparametrize(mu, logvar) 136 | condition = tf.reshape(condition, shape=[-1, 1, 1, self.condition_dim]) 137 | condition = tf.tile(condition, multiples=[1, 16, 16, 1]) 138 | 139 | x = tf.concat([x, condition], axis=-1) 140 | 141 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='joint_conv') 142 | x = batch_norm(x, is_training, scope='joint_batch_norm') 143 | x = relu(x) 144 | 145 | for i in range(2): 146 | x = resblock(x, channels, is_training, use_bias=False, sn=self.sn, scope='resblock_' + str(i)) 147 | 148 | for i in range(4): 149 | x = up_block(x, channels=channels // 2, is_training=is_training, use_bias=False, sn=self.sn, scope='up_block_' + str(i)) 150 | channels = channels // 2 151 | 152 | x = conv(x, channels=self.img_ch, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=True, sn=self.sn, scope='g_logit') 153 | x = tanh(x) 154 | 155 | return x, mu, logvar 156 | 157 | ################################################################################## 158 | # Discriminator 159 | ################################################################################## 160 | 161 | def discriminator_1(self, x_init, mu, is_training=True, reuse=tf.AUTO_REUSE, scope="discriminator_1"): 162 | channel = self.df_dim 163 | 164 | with tf.variable_scope(scope, reuse=reuse): 165 | x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=True, 
sn=self.sn, scope='conv') 166 | x = lrelu(x, 0.2) 167 | 168 | for i in range(3) : 169 | x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv_' + str(i)) 170 | x = batch_norm(x, is_training, scope='batch_norm_' + str(i)) 171 | x = lrelu(x, 0.2) 172 | 173 | channel = channel * 2 174 | 175 | mu = tf.reshape(mu, shape=[-1, 1, 1, self.condition_dim]) 176 | mu = tf.tile(mu, multiples=[1, 4, 4, 1]) 177 | 178 | x = tf.concat([x, mu], axis=-1) 179 | 180 | x = conv(x, channel, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv_last') 181 | x = batch_norm(x, is_training, scope='batch_norm_last') 182 | x = lrelu(x, 0.2) 183 | 184 | x = conv(x, channels=1, kernel=4, stride=4, use_bias=True, sn=self.sn, scope='d_logit') 185 | 186 | return x 187 | 188 | def discriminator_2(self, x_init, mu, is_training=True, reuse=tf.AUTO_REUSE, scope="discriminator_2"): 189 | channel = self.df_dim 190 | 191 | with tf.variable_scope(scope, reuse=reuse): 192 | x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=self.sn, scope='conv') 193 | x = lrelu(x, 0.2) 194 | 195 | for i in range(5) : 196 | x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv_' + str(i)) 197 | x = batch_norm(x, is_training, scope='batch_norm_' + str(i)) 198 | x = lrelu(x, 0.2) 199 | 200 | channel = channel * 2 201 | 202 | for i in range(2): 203 | x = conv(x, channel // 2, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=False, sn=self.sn, scope='conv3x3_' + str(i)) 204 | x = batch_norm(x, is_training, scope='batch_norm3x3_' + str(i)) 205 | x = lrelu(x, 0.2) 206 | 207 | channel = channel // 2 208 | 209 | mu = tf.reshape(mu, shape=[-1, 1, 1, self.condition_dim]) 210 | mu = tf.tile(mu, multiples=[1, 4, 4, 1]) 211 | 212 | x = tf.concat([x, mu], axis=-1) 213 | 214 | x = conv(x, channel, kernel=3, stride=1, pad=1, pad_type='reflect', 
use_bias=False, sn=self.sn, scope='conv_last') 215 | x = batch_norm(x, is_training, scope='batch_norm_last') 216 | x = lrelu(x, 0.2) 217 | 218 | x = conv(x, channels=1, kernel=4, stride=4, use_bias=True, sn=self.sn, scope='d_logit') 219 | 220 | return x 221 | 222 | ################################################################################## 223 | # Model 224 | ################################################################################## 225 | 226 | 227 | def build_model(self): 228 | 229 | if self.phase == 'train' : 230 | self.lr = tf.placeholder(tf.float32, name='learning_rate') 231 | """ Input Image""" 232 | img_data_class = Image_data(self.img_height, self.img_width, self.img_ch, self.dataset_path, self.augment_flag) 233 | img_data_class.preprocess() 234 | 235 | self.dataset_num = len(img_data_class.image_list) 236 | 237 | 238 | img_and_embedding = tf.data.Dataset.from_tensor_slices((img_data_class.image_list, img_data_class.embedding)) 239 | 240 | gpu_device = '/gpu:0' 241 | img_and_embedding = img_and_embedding.apply(shuffle_and_repeat(self.dataset_num)).apply( 242 | map_and_batch(img_data_class.image_processing, batch_size=self.batch_size, num_parallel_batches=16, 243 | drop_remainder=True)).apply(prefetch_to_device(gpu_device, None)) 244 | 245 | 246 | img_and_embedding_iterator = img_and_embedding.make_one_shot_iterator() 247 | 248 | self.real_img_256, self.embedding = img_and_embedding_iterator.get_next() 249 | sentence_index = tf.random.uniform(shape=[], minval=0, maxval=10, dtype=tf.int32) 250 | self.embedding = tf.gather(self.embedding, indices=sentence_index, axis=1) #[bs, 1024] 251 | 252 | noise = tf.random_normal(shape=[self.batch_size, self.z_dim]) 253 | self.fake_img_64, mu_64, logvar_64 = self.generator_1(self.embedding, noise) 254 | self.fake_img_256, mu_256, logvar_256 = self.generator_2(self.fake_img_64, self.embedding) 255 | self.real_img_64 = tf.image.resize_bilinear(self.real_img_256, size=[64, 64]) 256 | 257 | self.real_img = 
[self.real_img_64, self.real_img_256] 258 | self.fake_img = [self.fake_img_64, self.fake_img_256] 259 | 260 | real_logit_64 = self.discriminator_1(self.real_img_64, mu_64) 261 | fake_logit_64 = self.discriminator_1(self.fake_img_64, mu_64) 262 | 263 | real_logit_256 = self.discriminator_2(self.real_img_256, mu_256) 264 | fake_logit_256 = self.discriminator_2(self.fake_img_256, mu_256) 265 | 266 | g_adv_loss_64 = generator_loss(self.gan_type, fake_logit_64) * self.adv_weight 267 | g_kl_loss_64 = kl_loss(mu_64, logvar_64) * self.kl_weight 268 | 269 | d_adv_loss_64 = discriminator_loss(self.gan_type, real_logit_64, fake_logit_64) * self.adv_weight 270 | 271 | g_loss_64 = g_adv_loss_64 + g_kl_loss_64 272 | d_loss_64 = d_adv_loss_64 273 | 274 | g_adv_loss_256 = generator_loss(self.gan_type, fake_logit_256) * self.adv_weight 275 | g_kl_loss_256 = kl_loss(mu_256, logvar_256) * self.kl_weight 276 | 277 | d_adv_loss_256 = discriminator_loss(self.gan_type, real_logit_256, fake_logit_256) * self.adv_weight 278 | 279 | g_loss_256 = g_adv_loss_256 + g_kl_loss_256 280 | d_loss_256 = d_adv_loss_256 281 | 282 | self.g_loss = [g_loss_64, g_loss_256] 283 | self.d_loss = [d_loss_64, d_loss_256] 284 | 285 | 286 | """ Training """ 287 | t_vars = tf.trainable_variables() 288 | G1_vars = [var for var in t_vars if 'generator_1' in var.name] 289 | G2_vars = [var for var in t_vars if 'generator_2' in var.name] 290 | D1_vars = [var for var in t_vars if 'discriminator_1' in var.name] 291 | D2_vars = [var for var in t_vars if 'discriminator_2' in var.name] 292 | 293 | g1_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(g_loss_64, var_list=G1_vars) 294 | g2_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(g_loss_256, var_list=G2_vars) 295 | 296 | d1_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(d_loss_64,var_list=D1_vars) 297 | d2_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(d_loss_256, 
var_list=D2_vars) 298 | 299 | self.g_optim = [g1_optim, g2_optim] 300 | self.d_optim = [d1_optim, d2_optim] 301 | 302 | 303 | """" Summary """ 304 | self.summary_g_loss_64 = tf.summary.scalar("g_loss_64", g_loss_64) 305 | self.summary_g_loss_256 = tf.summary.scalar("g_loss_256", g_loss_256) 306 | self.summary_d_loss_64 = tf.summary.scalar("d_loss_64", d_loss_64) 307 | self.summary_d_loss_256 = tf.summary.scalar("d_loss_256", d_loss_256) 308 | 309 | self.summary_g_adv_loss_64 = tf.summary.scalar("g_adv_loss_64", g_adv_loss_64) 310 | self.summary_g_adv_loss_256 = tf.summary.scalar("g_adv_loss_256", g_adv_loss_256) 311 | self.summary_g_kl_loss_64 = tf.summary.scalar("g_kl_loss_64", g_kl_loss_64) 312 | self.summary_g_kl_loss_256 = tf.summary.scalar("g_kl_loss_256", g_kl_loss_256) 313 | 314 | self.summary_d_adv_loss_64 = tf.summary.scalar("d_adv_loss_64", d_adv_loss_64) 315 | self.summary_d_adv_loss_256 = tf.summary.scalar("d_adv_loss_256", d_adv_loss_256) 316 | 317 | 318 | g_summary_list = [self.summary_g_loss_64, self.summary_g_loss_256, 319 | self.summary_g_adv_loss_64, self.summary_g_adv_loss_256, 320 | self.summary_g_kl_loss_64, self.summary_g_kl_loss_256] 321 | 322 | d_summary_list = [self.summary_d_loss_64, self.summary_d_loss_256, 323 | self.summary_d_adv_loss_64, self.summary_d_adv_loss_256] 324 | 325 | self.summary_merge_g_loss = tf.summary.merge(g_summary_list) 326 | self.summary_merge_d_loss = tf.summary.merge(d_summary_list) 327 | 328 | else : 329 | """ Test """ 330 | """ Input Image""" 331 | img_data_class = Image_data(self.img_height, self.img_width, self.img_ch, self.dataset_path, augment_flag=False) 332 | img_data_class.preprocess() 333 | 334 | self.dataset_num = len(img_data_class.image_list) 335 | 336 | img_and_embedding = tf.data.Dataset.from_tensor_slices( 337 | (img_data_class.image_list, img_data_class.embedding)) 338 | 339 | gpu_device = '/gpu:0' 340 | img_and_embedding = img_and_embedding.apply(shuffle_and_repeat(self.dataset_num)).apply( 341 | 
map_and_batch(img_data_class.image_processing, batch_size=5, num_parallel_batches=16, 342 | drop_remainder=True)).apply(prefetch_to_device(gpu_device, None)) 343 | 344 | img_and_embedding_iterator = img_and_embedding.make_one_shot_iterator() 345 | 346 | self.real_img_256, self.embedding = img_and_embedding_iterator.get_next() 347 | sentence_index = tf.random.uniform(shape=[], minval=0, maxval=10, dtype=tf.int32) 348 | self.embedding = tf.gather(self.embedding, indices=sentence_index, axis=1) # [bs, 1024] 349 | 350 | noise = tf.random_normal(shape=[self.batch_size, self.z_dim]) 351 | self.fake_img_64, mu_64, logvar_64 = self.generator_1(self.embedding, noise, is_training=False) 352 | self.fake_img_256, mu_256, logvar_256 = self.generator_2(self.fake_img_64, self.embedding, is_training=False) 353 | 354 | self.test_fake_img = self.fake_img_256 355 | self.test_real_img = self.real_img_256 356 | 357 | 358 | def train(self): 359 | # initialize all variables 360 | tf.global_variables_initializer().run() 361 | 362 | # saver to save model 363 | self.saver = tf.train.Saver(max_to_keep=10) 364 | 365 | # summary writer 366 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph) 367 | 368 | # restore check-point if it exits 369 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 370 | if could_load: 371 | counter = checkpoint_counter 372 | init_stage = counter // self.iteration 373 | if init_stage == 1 : 374 | start_batch_id = checkpoint_counter - self.iteration 375 | else : 376 | start_batch_id = checkpoint_counter 377 | print(" [*] Load SUCCESS") 378 | 379 | else: 380 | start_batch_id = 0 381 | counter = 1 382 | init_stage = 0 383 | print(" [!] 
Load failed...") 384 | 385 | # loop for epoch 386 | start_time = time.time() 387 | 388 | for stage in range(init_stage, 2) : 389 | lr = self.init_lr 390 | for idx in range(start_batch_id, self.iteration): 391 | 392 | if self.decay_flag : 393 | if idx > 0 and (idx % self.decay_iter) == 0 : 394 | lr = self.init_lr * pow(0.5, idx // self.decay_iter) 395 | 396 | train_feed_dict = { 397 | self.lr : lr 398 | } 399 | 400 | # Update D 401 | _, d_loss, summary_str = self.sess.run([self.d_optim[stage], self.d_loss[stage], self.summary_merge_d_loss], feed_dict=train_feed_dict) 402 | self.writer.add_summary(summary_str, counter) 403 | 404 | # Update G 405 | real_images, fake_images, _, g_loss, summary_str = self.sess.run( 406 | [self.real_img[stage], self.fake_img[stage], 407 | self.g_optim[stage], 408 | self.g_loss[stage], self.summary_merge_g_loss], feed_dict=train_feed_dict) 409 | 410 | self.writer.add_summary(summary_str, counter) 411 | 412 | 413 | # display training status 414 | counter += 1 415 | print("Stage: [%1d] [%6d/%6d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (stage, idx, self.iteration, time.time() - start_time, d_loss, g_loss)) 416 | 417 | if np.mod(idx + 1, self.print_freq) == 0: 418 | real_images = real_images[:5] 419 | fake_images = fake_images[:5] 420 | 421 | merge_real_images = np.expand_dims(return_images(real_images, [5, 1]), axis=0) 422 | merge_fake_images = np.expand_dims(return_images(fake_images, [5, 1]), axis=0) 423 | 424 | merge_images = np.concatenate([merge_real_images, merge_fake_images], axis=0) 425 | 426 | save_images(merge_images, [1, 2], 427 | './{}/merge_stage{}_{:07d}.jpg'.format(self.sample_dir, stage, idx + 1)) 428 | 429 | 430 | if np.mod(counter - 1, self.save_freq) == 0: 431 | self.save(self.checkpoint_dir, counter) 432 | 433 | # save model for final step 434 | self.save(self.checkpoint_dir, counter) 435 | 436 | @property 437 | def model_dir(self): 438 | if self.sn: 439 | sn = '_sn' 440 | else: 441 | sn = '' 442 | 443 | return 
"{}_{}_{}_{}adv_{}kl{}".format(self.model_name, self.dataset_name, self.gan_type, 444 | self.adv_weight, self.kl_weight, 445 | sn) 446 | 447 | def save(self, checkpoint_dir, step): 448 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 449 | 450 | if not os.path.exists(checkpoint_dir): 451 | os.makedirs(checkpoint_dir) 452 | 453 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 454 | 455 | def load(self, checkpoint_dir): 456 | print(" [*] Reading checkpoints...") 457 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 458 | 459 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 460 | if ckpt and ckpt.model_checkpoint_path: 461 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 462 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 463 | counter = int(ckpt_name.split('-')[-1]) 464 | print(" [*] Success to read {}".format(ckpt_name)) 465 | return True, counter 466 | else: 467 | print(" [*] Failed to find a checkpoint") 468 | return False, 0 469 | 470 | def test(self): 471 | tf.global_variables_initializer().run() 472 | 473 | self.saver = tf.train.Saver() 474 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 475 | self.result_dir = os.path.join(self.result_dir, self.model_dir) 476 | check_folder(self.result_dir) 477 | 478 | if could_load: 479 | print(" [*] Load SUCCESS") 480 | else: 481 | print(" [!] 
Load failed...") 482 | 483 | # write html for visual comparisondkssjg 484 | index_path = os.path.join(self.result_dir, 'index.html') 485 | index = open(index_path, 'w') 486 | index.write("") 487 | index.write("") 488 | 489 | real_images, fake_images = self.sess.run([self.test_real_img, self.test_fake_img]) 490 | for i in range(5) : 491 | real_path = os.path.join(self.result_dir, 'real_{}.jpg'.format(i)) 492 | fake_path = os.path.join(self.result_dir, 'fake_{}.jpg'.format(i)) 493 | 494 | real_image = np.expand_dims(real_images[i], axis=0) 495 | fake_image = np.expand_dims(fake_images[i], axis=0) 496 | 497 | save_images(real_image, [1, 1], real_path) 498 | save_images(fake_image, [1, 1], fake_path) 499 | 500 | index.write("" % os.path.basename(real_path)) 501 | index.write("" % (real_path if os.path.isabs(real_path) else ( 502 | '../..' + os.path.sep + real_path), self.img_width, self.img_height)) 503 | 504 | index.write("" % (fake_path if os.path.isabs(fake_path) else ( 505 | '../..' + os.path.sep + fake_path), self.img_width, self.img_height)) 506 | index.write("") 507 | 508 | index.close() -------------------------------------------------------------------------------- /assets/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/StackGAN-Tensorflow/1a5ffed6613049d8fd43c8ff5cbe34061394975e/assets/result.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/StackGAN-Tensorflow/1a5ffed6613049d8fd43c8ff5cbe34061394975e/assets/teaser.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from StackGAN import StackGAN 2 | import argparse 3 | from utils import * 4 | 5 | """parsing and 
configuration""" 6 | def parse_args(): 7 | desc = "Tensorflow implementation of StackGAN" 8 | parser = argparse.ArgumentParser(description=desc) 9 | parser.add_argument('--phase', type=str, default='train', choices=('train', 'test'), help='phase name') 10 | parser.add_argument('--dataset', type=str, default='birds', help='dataset_name') 11 | 12 | parser.add_argument('--iteration', type=int, default=500000, help='The number of training iterations') 13 | parser.add_argument('--decay_flag', type=str2bool, default=True, help='The decay_flag') 14 | parser.add_argument('--decay_iter', type=int, default=100000, help='decay epoch') 15 | 16 | parser.add_argument('--batch_size', type=int, default=32, help='The size of batch size for each gpu') 17 | parser.add_argument('--print_freq', type=int, default=1000, help='The number of image_print_freq') 18 | parser.add_argument('--save_freq', type=int, default=10000, help='The number of ckpt_save_freq') 19 | 20 | parser.add_argument('--lr', type=float, default=0.0002, help='The learning rate') 21 | 22 | parser.add_argument('--gan_type', type=str, default='gan', help='[gan / lsgan / hinge]') 23 | 24 | parser.add_argument('--adv_weight', type=int, default=1, help='Weight about GAN') 25 | parser.add_argument('--kl_weight', type=int, default=2, help='Weight about kl_loss') 26 | 27 | parser.add_argument('--sn', type=str2bool, default=False, help='using spectral norm') 28 | 29 | parser.add_argument('--img_height', type=int, default=256, help='The height size of image') 30 | parser.add_argument('--img_width', type=int, default=256, help='The width size of image ') 31 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel') 32 | parser.add_argument('--augment_flag', type=str2bool, default=True, help='Image augmentation use or not') 33 | 34 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 35 | help='Directory name to save the checkpoints') 36 | parser.add_argument('--result_dir', 
type=str, default='results', 37 | help='Directory name to save the generated images') 38 | parser.add_argument('--log_dir', type=str, default='logs', 39 | help='Directory name to save training logs') 40 | parser.add_argument('--sample_dir', type=str, default='samples', 41 | help='Directory name to save the samples on training') 42 | 43 | return check_args(parser.parse_args()) 44 | 45 | """checking arguments""" 46 | def check_args(args): 47 | # --checkpoint_dir 48 | check_folder(args.checkpoint_dir) 49 | 50 | # --result_dir 51 | check_folder(args.result_dir) 52 | 53 | # --log_dir 54 | check_folder(args.log_dir) 55 | 56 | # --sample_dir 57 | check_folder(args.sample_dir) 58 | 59 | # --batch_size 60 | try: 61 | assert args.batch_size >= 1 62 | except: 63 | print('batch size must be larger than or equal to one') 64 | return args 65 | 66 | """main""" 67 | def main(): 68 | # parse arguments 69 | args = parse_args() 70 | if args is None: 71 | exit() 72 | 73 | # open session 74 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 75 | gan = StackGAN(sess, args) 76 | 77 | # build graph 78 | gan.build_model() 79 | 80 | # show network architecture 81 | show_all_variables() 82 | 83 | if args.phase == 'train' : 84 | gan.train() 85 | print(" [*] Training finished!") 86 | 87 | if args.phase == 'test' : 88 | gan.test() 89 | print(" [*] Test finished!") 90 | 91 | 92 | 93 | if __name__ == '__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tf_contrib 3 | from utils import pytorch_kaiming_weight_factor 4 | 5 | ################################################################################## 6 | # Initialization 7 | ################################################################################## 8 | 9 | # factor, mode, uniform = 
pytorch_kaiming_weight_factor(a=0.0, uniform=False)
# Alternative initializer (uses the factor/mode/uniform computed above), kept for reference:
# weight_init = tf_contrib.layers.variance_scaling_initializer(factor=factor, mode=mode, uniform=uniform)
weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)

weight_regularizer = None
weight_regularizer_fully = None

##################################################################################
# Layer
##################################################################################


def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    """2-D convolution with optional explicit padding and spectral normalization.

    Args:
        x: input tensor. Strides and paddings are applied to axes 1 and 2, so it is
            presumably NHWC.
        channels: number of output channels.
        kernel: square kernel size.
        stride: stride used on both spatial axes.
        pad: if > 0, explicit padding is added before a 'VALID' convolution. When the
            input height is divisible by `stride` the total padding is `2 * pad`,
            otherwise `max(kernel - h % stride, 0)`.
        pad_type: 'zero' or 'reflect'; any other value adds no padding.
        use_bias: whether to add a learned bias.
        sn: if True, the kernel is spectrally normalized.
        scope: variable scope name.

    Returns:
        The convolved tensor.
    """
    with tf.variable_scope(scope):
        if pad > 0:
            # NOTE(review): the amount is derived from the height only and reused for
            # the width, which assumes square inputs.
            h = x.get_shape().as_list()[1]
            if h % stride == 0:
                pad = pad * 2
            else:
                pad = max(kernel - (h % stride), 0)

            # Split the total padding; an odd remainder goes to the bottom/right side.
            pad_top = pad // 2
            pad_bottom = pad - pad_top
            pad_left = pad // 2
            pad_right = pad - pad_left

            paddings = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
            if pad_type == 'zero':
                x = tf.pad(x, paddings)
            if pad_type == 'reflect':
                x = tf.pad(x, paddings, mode='REFLECT')

        if sn:
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels],
                                initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)
        else:
            # tf.layers.conv2d defaults to 'valid' padding, matching the sn branch.
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=weight_regularizer,
                                 strides=stride, use_bias=use_bias)

        return x


def fully_connected(x, units, use_bias=True, sn=False, scope='linear'):
    """Flatten `x` and apply a dense layer, optionally spectrally normalized.

    Args:
        x: input tensor; everything after the first axis is flattened.
        units: number of output units.
        use_bias: whether to add a learned bias.
        sn: if True, the weight matrix is spectrally normalized.
        scope: variable scope name.

    Returns:
        A 2-D tensor with `units` columns.
    """
    with tf.variable_scope(scope):
        x = flatten(x)
        shape = x.get_shape().as_list()
        channels = shape[-1]

        if sn:
            w = tf.get_variable("kernel", [channels, units], tf.float32,
                                initializer=weight_init, regularizer=weight_regularizer_fully)
            if use_bias:
                bias = tf.get_variable("bias", [units],
                                       initializer=tf.constant_initializer(0.0))
                x = tf.matmul(x, spectral_norm(w)) + bias
            else:
                x = tf.matmul(x, spectral_norm(w))
        else:
            x = tf.layers.dense(x, units=units, kernel_initializer=weight_init,
                                kernel_regularizer=weight_regularizer_fully,
                                use_bias=use_bias)

        return x


def flatten(x):
    """Flatten all axes except the first."""
    return tf.layers.flatten(x)


##################################################################################
# Residual-block
##################################################################################


def resblock(x_init, channels, is_training=True, use_bias=True, sn=False, scope='resblock'):
    """Two 3x3 reflect-padded conv + batch-norm layers with an identity skip.

    The output is `relu(F(x_init) + x_init)`, so `channels` must equal the channel
    count of `x_init` for the addition to be valid.
    """
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
            x = batch_norm(x, is_training)
            x = relu(x)

        with tf.variable_scope('res2'):
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
            x = batch_norm(x, is_training)

        return relu(x + x_init)


def up_block(x_init, channels, is_training=True, use_bias=True, sn=False, scope='up_block'):
    """2x nearest-neighbour upsample, then 3x3 conv, batch norm and ReLU."""
    with tf.variable_scope(scope):
        x = up_sample(x_init, scale_factor=2)
        x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
        x = batch_norm(x, is_training)
        x = relu(x)

        return x


##################################################################################
# Sampling
##################################################################################


def up_sample(x, scale_factor=2):
    """Nearest-neighbour resize of axes 1 and 2 by an integer `scale_factor`.

    Requires statically known spatial dimensions.
    """
    _, h, w, _ = x.get_shape().as_list()
    new_size = [h * scale_factor, w * scale_factor]
    return tf.image.resize_nearest_neighbor(x, size=new_size)


def down_sample_avg(x, scale_factor=2):
    """3x3 average pooling with stride `scale_factor` and 'SAME' padding."""
    return tf.layers.average_pooling2d(x, pool_size=3, strides=scale_factor, padding='SAME')


def global_avg_pooling(x):
    """Mean over axes 1 and 2, keeping them as size-1 dimensions."""
    gap = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    return gap


def reparametrize(mean, logvar):
    """Reparameterization trick: sample `mean + exp(logvar / 2) * eps`, eps ~ N(0, 1)."""
    eps = tf.random_normal(tf.shape(mean), mean=0.0, stddev=1.0, dtype=tf.float32)

    return mean + tf.exp(logvar * 0.5) * eps


##################################################################################
# Activation function
##################################################################################


def lrelu(x, alpha=0.01):
    """Leaky ReLU; the default slope 0.01 matches PyTorch's default."""
    return tf.nn.leaky_relu(x, alpha)


def relu(x):
    """Rectified linear unit."""
    return tf.nn.relu(x)


def tanh(x):
    """Hyperbolic tangent."""
    return tf.tanh(x)


##################################################################################
# Normalization function
##################################################################################


def instance_norm(x, scope='instance_norm'):
    """Instance normalization with learned scale and offset (epsilon 1e-5)."""
    return tf_contrib.layers.instance_norm(x,
                                           epsilon=1e-05,
                                           center=True, scale=True,
                                           scope=scope)


def batch_norm(x, is_training=False, scope='batch_norm'):
    """Batch normalization (decay 0.9, epsilon 1e-5) with learned scale and offset.

    `updates_collections=None` makes the moving mean/variance update in place as
    part of this op, so no extra training-op wiring is needed.

    If this is replaced by `tf.layers.batch_normalization`, the moving statistics
    are instead placed in UPDATE_OPS and the train op must depend on them:

        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            train_op = optimizer.minimize(loss)
    """
    return tf_contrib.layers.batch_norm(x,
                                        decay=0.9, epsilon=1e-05,
                                        center=True, scale=True, updates_collections=None,
                                        is_training=is_training, scope=scope)


def param_free_norm(x, epsilon=1e-5):
    """Normalize `x` to zero mean / unit variance over axes 1 and 2, without parameters."""
    x_mean, x_var = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
    x_std = tf.sqrt(x_var + epsilon)

    return (x - x_mean) / x_std


def adaptive_instance_norm(content, gamma, beta, epsilon=1e-5):
    """AdaIN: parameter-free normalize `content`, then apply `gamma * x + beta`.

    `gamma`/`beta` are style statistics, e.g. predicted by an MLP.
    """
    x = param_free_norm(content, epsilon)

    return gamma * x + beta


def spectral_norm(w, iteration=1):
    """Divide `w` by an estimate of its largest singular value.

    The estimate uses `iteration` steps of power iteration. A persistent,
    non-trainable variable "u" is created in the current variable scope and updated
    each time the returned tensor is evaluated.

    Args:
        w: weight tensor; it is reshaped to 2-D as [-1, last_dim] for the estimate.
        iteration: number of power-iteration steps (must be >= 1; usually 1).

    Returns:
        The normalized weight, reshaped back to the shape of `w`.

    Raises:
        ValueError: if `iteration` < 1.
    """
    if iteration < 1:
        raise ValueError('spectral_norm requires iteration >= 1, got {!r}'.format(iteration))

    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)

    u_hat = u
    v_hat = None
    for _ in range(iteration):
        # Power iteration; one step is usually enough because `u` persists across runs.
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = tf.nn.l2_normalize(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = tf.nn.l2_normalize(u_)

    # The singular-vector estimates are treated as constants for backprop.
    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))

    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = w / sigma
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm


##################################################################################
# Loss function
##################################################################################


def L1_loss(x, y):
    """Mean absolute error over all elements of `x - y`."""
    loss = tf.reduce_mean(tf.abs(x - y))

    return loss


def discriminator_loss(gan_type, real_logit, fake_logit):
    """Discriminator adversarial loss.

    Args:
        gan_type: 'lsgan', 'gan' (sigmoid cross-entropy) or 'hinge'.
        real_logit: discriminator output on real samples.
        fake_logit: discriminator output on generated samples.

    Returns:
        Scalar `real_loss + fake_loss`.

    Raises:
        ValueError: for an unsupported `gan_type`. Previously this silently returned
            a constant 0, i.e. no adversarial training signal.
    """
    if gan_type == 'lsgan':
        real_loss = tf.reduce_mean(tf.squared_difference(real_logit, 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake_logit))
    elif gan_type == 'gan':
        real_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real_logit), logits=real_logit))
        fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake_logit), logits=fake_logit))
    elif gan_type == 'hinge':
        real_loss = tf.reduce_mean(relu(1 - real_logit))
        fake_loss = tf.reduce_mean(relu(1 + fake_logit))
    else:
        raise ValueError('Unsupported gan_type: {!r}'.format(gan_type))

    return real_loss + fake_loss


def generator_loss(gan_type, fake_logit):
    """Generator adversarial loss for 'lsgan', 'gan' or 'hinge'.

    Raises:
        ValueError: for an unsupported `gan_type`. Previously this silently returned 0.
    """
    if gan_type == 'lsgan':
        fake_loss = tf.reduce_mean(tf.squared_difference(fake_logit, 1.0))
    elif gan_type == 'gan':
        fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake_logit), logits=fake_logit))
    elif gan_type == 'hinge':
        fake_loss = -tf.reduce_mean(fake_logit)
    else:
        raise ValueError('Unsupported gan_type: {!r}'.format(gan_type))

    return fake_loss


def regularization_loss(scope_name):
    """Sum the REGULARIZATION_LOSSES entries whose name contains `scope_name`.

    Usage, if regularization is enabled:
        g_loss += regularization_loss('generator')
        d_loss += regularization_loss('discriminator')
    """
    collection_regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

    loss = [item for item in collection_regularization if scope_name in item.name]

    return tf.reduce_sum(loss)


def kl_loss(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, 1)), summed over the last axis and averaged over the rest.

    Presumably `mean`/`logvar` are [batch_size, channel].
    """
    loss = 0.5 * tf.reduce_sum(tf.square(mean) + tf.exp(logvar) - 1 - logvar, axis=-1)
    loss = tf.reduce_mean(loss)

    return loss
# --------------------------------------------------------------------------------
# /utils.py
# --------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.contrib import slim
import os
import numpy as np
from glob import glob
import cv2
import pickle


class Image_data:
    """Text-to-image dataset: JPEG paths paired with text embeddings.

    Expected layout under `dataset_path`:
        images/<folder>/<file>.jpg
        text/char-CNN-RNN-embeddings.pickle
        text/filenames.pickle
    """

    def __init__(self, img_height, img_width, channels, dataset_path, augment_flag):
        self.img_height = img_height
        self.img_width = img_width
        self.channels = channels
        self.augment_flag = augment_flag

        self.dataset_path = dataset_path
        self.image_path = os.path.join(dataset_path, 'images')
        self.text_path = os.path.join(dataset_path, 'text')

        self.embedding_pickle = os.path.join(self.text_path, 'char-CNN-RNN-embeddings.pickle')
        self.image_filename_pickle = os.path.join(self.text_path, 'filenames.pickle')

        # Filled by preprocess().
        self.image_list = []

    def image_processing(self, filename, vector):
        """Decode, resize and scale one JPEG to [-1, 1], with optional augmentation.

        Returns:
            (image tensor, `vector` unchanged).
        """
        x = tf.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=self.channels, dct_method='INTEGER_ACCURATE')
        img = tf.image.resize_images(x_decode, [self.img_height, self.img_width])
        img = tf.cast(img, tf.float32) / 127.5 - 1

        if self.augment_flag:
            # Enlarge by 30 px for 256-sized axes, otherwise by 10%, before the random crop.
            augment_height_size = self.img_height + (30 if self.img_height == 256 else int(self.img_height * 0.1))
            augment_width_size = self.img_width + (30 if self.img_width == 256 else int(self.img_width * 0.1))

            # Augment with probability 0.5.
            img = tf.cond(pred=tf.greater_equal(tf.random_uniform(shape=[], minval=0.0, maxval=1.0), 0.5),
                          true_fn=lambda: augmentation(img, augment_height_size, augment_width_size),
                          false_fn=lambda: img)

        return img, vector

    def preprocess(self):
        """Load the text embeddings and build `self.image_list` from the filename pickle.

        NOTE(review): pickle can execute arbitrary code while loading; only use this
        on trusted dataset files.
        """
        with open(self.embedding_pickle, 'rb') as f:
            # latin1 lets Python 3 read pickles that were written by Python 2.
            # The original comment records the shape as (8855, 10, 1024) for the bundled dataset.
            self.embedding = np.array(pickle.load(f, encoding='latin1'))

        with open(self.image_filename_pickle, 'rb') as f:
            # e.g. ['002.Laysan_Albatross/Laysan_Albatross_0002_1027', ...]
            x_list = pickle.load(f)

        for x in x_list:
            parts = x.split('/')
            folder_name = parts[0]
            file_name = parts[1] + '.jpg'

            self.image_list.append(os.path.join(self.image_path, folder_name, file_name))


def load_test_image(image_path, img_width, img_height, img_channel):
    """Read an image with OpenCV, resize it and scale it to [-1, 1] with a batch axis.

    Returns:
        (1, H, W, 1) for img_channel == 1, otherwise (1, H, W, 3) in RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read `image_path`.
    """
    if img_channel == 1:
        img = cv2.imread(image_path, flags=cv2.IMREAD_GRAYSCALE)
    else:
        img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)

    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise FileNotFoundError('Could not read image: {}'.format(image_path))

    if img_channel != 1:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    img = cv2.resize(img, dsize=(img_width, img_height))

    img = np.expand_dims(img, axis=0)
    if img_channel == 1:
        img = np.expand_dims(img, axis=-1)

    img = img / 127.5 - 1

    return img


def preprocessing(x):
    """Map [0, 255] to [-1, 1]."""
    x = x / 127.5 - 1
    return x


def preprocess_fit_train_image(images, height, width):
    """Bilinearly resize `images` and map [0, 255] to [-1, 1]."""
    images = tf.image.resize(images, size=[height, width], method=tf.image.ResizeMethod.BILINEAR)
    images = adjust_dynamic_range(images)

    return images


def adjust_dynamic_range(images):
    """Linearly map values from [0, 255] to [-1, 1]."""
    drange_in = [0.0, 255.0]
    drange_out = [-1.0, 1.0]
    scale = (drange_out[1] - drange_out[0]) / (drange_in[1] - drange_in[0])
    bias = drange_out[0] - drange_in[0] * scale
    images = images * scale + bias
    return images


def augmentation(image, augment_height, augment_width):
    """Random horizontal flip, upscale to (augment_height, augment_width), then random crop back.

    The seed is drawn once when the graph is built and shared by the flip and the crop ops.
    """
    seed = np.random.randint(0, 2 ** 31 - 1)

    ori_image_shape = tf.shape(image)
    image = tf.image.random_flip_left_right(image, seed=seed)
    image = tf.image.resize(image, size=[augment_height, augment_width], method=tf.image.ResizeMethod.BILINEAR)
    image = tf.random_crop(image, ori_image_shape, seed=seed)

    return image


def save_images(images, size, image_path):
    """Map a [-1, 1] batch to [0, 255], tile it into a `size` grid and write it to disk."""
    return imsave(inverse_transform(images), size, image_path)


def inverse_transform(images):
    """Map [-1, 1] to [0, 255]."""
    return ((images + 1.) / 2) * 255.0


def imsave(images, size, path):
    """Tile an RGB batch into a grid and write it with OpenCV; returns cv2.imwrite's result."""
    images = merge(images, size)
    images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)

    return cv2.imwrite(path, images)


def post_process_generator_output(generator_output):
    """Map [-1, 1] generator output to [0, 255] (+0.5 for rounding) and clip."""
    drange_min, drange_max = -1.0, 1.0
    scale = 255.0 / (drange_max - drange_min)

    scaled_image = generator_output * scale + (0.5 - drange_min * scale)
    scaled_image = np.clip(scaled_image, 0, 255)

    return scaled_image


def merge(images, size):
    """Tile an (N, H, W, C) batch into a size[0]-row by size[1]-column grid, row-major."""
    h, w = images.shape[1], images.shape[2]
    c = images.shape[3]
    img = np.zeros((h * size[0], w * size[1], c))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[h * j:h * (j + 1), w * i:w * (i + 1), :] = image

    return img


def return_images(images, size):
    """Alias for merge()."""
    x = merge(images, size)

    return x


def show_all_variables():
    """Print a summary of all trainable variables."""
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)


def check_folder(log_dir):
    """Create `log_dir` (and parents) if it does not exist; return it."""
    if not os.path.exists(log_dir):
        # exist_ok guards against another process creating it in between.
        os.makedirs(log_dir, exist_ok=True)
    return log_dir


def str2bool(x):
    """Parse a command-line boolean; case-insensitive 'true', 't', 'yes', 'y' or '1' are True.

    The previous `x.lower() in ('true')` was a substring test on a string (not a
    tuple), so '', 'e' or 'rue' were also treated as True.
    """
    return x.lower() in ('true', 't', 'yes', 'y', '1')


def get_one_hot(targets, nb_classes):
    """One-hot encode integer `targets` into rows of length `nb_classes`."""
    x = np.eye(nb_classes)[targets]

    return x


def pytorch_xavier_weight_factor(gain=0.02, uniform=False):
    """Return (factor, mode, uniform) for a variance-scaling initializer mimicking PyTorch's Xavier init.

    For the non-uniform case the factor is divided by 1.3, presumably to cancel the
    1.3 multiplier that tf.contrib's variance_scaling_initializer applies to its
    truncated-normal stddev.
    """
    factor = gain * gain if uniform else (gain * gain) / 1.3
    mode = 'FAN_AVG'

    return factor, mode, uniform


def pytorch_kaiming_weight_factor(a=0.0, activation_function='leaky_relu', uniform=False):
    """Return (factor, mode, uniform) for a variance-scaling initializer mimicking PyTorch's Kaiming init.

    The gain follows the activation: relu -> sqrt(2), leaky_relu -> sqrt(2 / (1 + a^2)),
    tanh -> 5/3, otherwise 1. The non-uniform factor is divided by 1.3 as in
    pytorch_xavier_weight_factor.
    """
    if activation_function == 'relu':
        gain = np.sqrt(2.0)
    elif activation_function == 'leaky_relu':
        gain = np.sqrt(2.0 / (1 + a ** 2))
    elif activation_function == 'tanh':
        gain = 5.0 / 3
    else:
        gain = 1.0

    factor = gain * gain if uniform else (gain * gain) / 1.3
    mode = 'FAN_IN'

    return factor, mode, uniform
# --------------------------------------------------------------------------------
namecontentstyleoutput
%s