├── LICENSE ├── PGGAN.py ├── README.md ├── download.py ├── h5tool.py ├── images ├── figure.png ├── hs_sample_128.jpg ├── hs_sample_64.jpg ├── sample.png └── sample_128.png ├── main.py ├── ops.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 JiChao Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
# 22 | -------------------------------------------------------------------------------- /PGGAN.py: --------------------------------------------------------------------------------
import tensorflow as tf
from ops import lrelu, conv2d, fully_connect, upscale, Pixl_Norm, downscale2d, MinibatchstateConcat
from utils import save_images
import numpy as np
# ``scipy.ndimage.interpolation`` is a deprecated alias; ``zoom`` lives in ``scipy.ndimage``.
from scipy.ndimage import zoom


class PGGAN(object):
    """Progressive-growing GAN trained with the WGAN-GP objective.

    One instance builds and trains the networks for a single growth stage:
    ``PG`` selects the output resolution (``4 * 2 ** (PG - 1)`` pixels) and
    ``t`` selects a fade-in ("transition") stage, in which the newest,
    highest-resolution layers are blended with the upsampled output of the
    previous stage using the non-trainable weight ``alpha_tra``.
    """

    # build model
    def __init__(self, batch_size, max_iters, model_path, read_model_path, data, sample_size, sample_path, log_dir,
                 learn_rate, lam_gp, lam_eps, PG, t, use_wscale, is_celeba):
        """Store hyper-parameters and create the input placeholders.

        Args:
            batch_size: images per training step (fixed in the placeholders).
            max_iters: training steps for this stage; also the denominator of
                the fade-in coefficient.
            model_path: checkpoint path written by ``train``.
            read_model_path: checkpoint restored at the start of ``train``.
            data: dataset object; must expose ``channel`` and ``getNextBatch``
                (and ``getShapeForData`` when ``is_celeba``).
            sample_size: dimensionality of the latent vector z.
            sample_path: directory for sample image grids.
            log_dir: TensorBoard summary directory.
            learn_rate: Adam learning rate for both networks.
            lam_gp: weight of the WGAN-GP gradient penalty.
            lam_eps: weight of the squared-real-logit (drift) penalty.
            PG: growth stage (1 -> 4x4, 2 -> 8x8, ...).
            t: True for a fade-in stage.
            use_wscale: forwarded to the ops in ops.py (presumably the
                equalized-learning-rate switch — see ops.py).
            is_celeba: selects the CelebA loader instead of the HDF5 one.
        """
        self.batch_size = batch_size
        self.max_iters = max_iters
        self.gan_model_path = model_path
        self.read_model_path = read_model_path
        self.data_In = data
        self.sample_size = sample_size
        self.sample_path = sample_path
        self.log_dir = log_dir
        self.learning_rate = learn_rate
        self.lam_gp = lam_gp
        self.lam_eps = lam_eps
        self.pg = PG
        self.trans = t
        self.log_vars = []
        self.channel = self.data_In.channel
        self.output_size = 4 * pow(2, PG - 1)
        self.use_wscale = use_wscale
        self.is_celeba = is_celeba
        # NHWC image batch and latent batch; both sizes are fixed at graph build time.
        self.images = tf.placeholder(tf.float32, [batch_size, self.output_size, self.output_size, self.channel])
        self.z = tf.placeholder(tf.float32, [self.batch_size, self.sample_size])
        # Fade-in weight of the new layers; updated from the step counter in ``train``.
        self.alpha_tra = tf.Variable(initial_value=0.0, trainable=False, name='alpha_tra')

    @staticmethod
    def _print_param_count(variables, label):
        """Print each variable's name and shape, then the total parameter count."""
        total_para = 0
        for variable in variables:
            shape = variable.get_shape()
            print(variable.name, shape)
            variable_para = 1
            for dim in shape:
                variable_para *= dim.value
            total_para += variable_para
        print("The total para of " + label, total_para)

    def build_model_PGGan(self):
        """Build the G/D graphs, the WGAN-GP losses, variable lists and savers."""
        self.fake_images = self.generate(self.z, pg=self.pg, t=self.trans, alpha_trans=self.alpha_tra)
        _, self.D_pro_logits = self.discriminate(self.images, reuse=False, pg=self.pg, t=self.trans,
                                                 alpha_trans=self.alpha_tra)
        _, self.G_pro_logits = self.discriminate(self.fake_images, reuse=True, pg=self.pg, t=self.trans,
                                                 alpha_trans=self.alpha_tra)

        # the definition of loss for D and G (Wasserstein critic losses)
        self.D_loss = tf.reduce_mean(self.G_pro_logits) - tf.reduce_mean(self.D_pro_logits)
        self.G_loss = -tf.reduce_mean(self.G_pro_logits)

        # gradient penalty from WGAN-GP, evaluated at random points on the
        # straight lines between real and fake samples
        self.differences = self.fake_images - self.images
        self.alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
        interpolates = self.images + (self.alpha * self.differences)
        _, discri_logits = self.discriminate(interpolates, reuse=True, pg=self.pg, t=self.trans,
                                             alpha_trans=self.alpha_tra)
        gradients = tf.gradients(discri_logits, [interpolates])[0]

        # per-sample L2 norm of the critic gradient
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2, 3]))
        self.gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
        tf.summary.scalar("gp_loss", self.gradient_penalty)

        self.D_origin_loss = self.D_loss
        self.D_loss += self.lam_gp * self.gradient_penalty
        # drift term: penalises large real logits
        self.D_loss += self.lam_eps * tf.reduce_mean(tf.square(self.D_pro_logits))

        self.log_vars.append(("generator_loss", self.G_loss))
        self.log_vars.append(("discriminator_loss", self.D_loss))

        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'dis' in var.name]
        self._print_param_count(self.d_vars, "D")

        self.g_vars = [var for var in t_vars if 'gen' in var.name]
        self._print_param_count(self.g_vars, "G")

        # save the variables, which remain unchanged
        self.d_vars_n = [var for var in self.d_vars if 'dis_n' in var.name]
        self.g_vars_n = [var for var in self.g_vars if 'gen_n' in var.name]

        # remove the new variables for the new model (layer names embed their resolution)
        self.d_vars_n_read = [var for var in self.d_vars_n if '{}'.format(self.output_size) not in var.name]
        self.g_vars_n_read = [var for var in self.g_vars_n if '{}'.format(self.output_size) not in var.name]

        # save the rgb variables, which remain unchanged
        self.d_vars_n_2 = [var for var in self.d_vars if 'dis_y_rgb_conv' in var.name]
        self.g_vars_n_2 = [var for var in self.g_vars if 'gen_y_rgb_conv' in var.name]

        self.d_vars_n_2_rgb = [var for var in self.d_vars_n_2 if '{}'.format(self.output_size) not in var.name]
        self.g_vars_n_2_rgb = [var for var in self.g_vars_n_2 if '{}'.format(self.output_size) not in var.name]

        print("d_vars", len(self.d_vars))
        print("g_vars", len(self.g_vars))

        print("self.d_vars_n_read", len(self.d_vars_n_read))
        print("self.g_vars_n_read", len(self.g_vars_n_read))

        print("d_vars_n_2_rgb", len(self.d_vars_n_2_rgb))
        print("g_vars_n_2_rgb", len(self.g_vars_n_2_rgb))

        self.g_d_w = [var for var in self.d_vars + self.g_vars if 'bias' not in var.name]

        print("self.g_d_w", len(self.g_d_w))

        self.saver = tf.train.Saver(self.d_vars + self.g_vars)
        self.r_saver = tf.train.Saver(self.d_vars_n_read + self.g_vars_n_read)

        # ``rgb_saver`` only exists when there are previous-resolution RGB layers to restore.
        if self.d_vars_n_2_rgb + self.g_vars_n_2_rgb:
            self.rgb_saver = tf.train.Saver(self.d_vars_n_2_rgb + self.g_vars_n_2_rgb)

        for k, v in self.log_vars:
            tf.summary.scalar(k, v)

    # do train
    def train(self):
        """Run ``max_iters + 1`` optimisation steps for this stage and save a checkpoint.

        Restores ``read_model_path`` first unless ``pg`` is 1 or 7, logs
        summaries every step, writes sample grids every 400 steps and
        intermediate checkpoints every 4000 steps. Resets the default graph
        on exit.
        """
        step_pl = tf.placeholder(tf.float32, shape=None)
        alpha_tra_assign = self.alpha_tra.assign(step_pl / self.max_iters)

        opti_D = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=0.0, beta2=0.99).minimize(
            self.D_loss, var_list=self.d_vars)
        opti_G = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=0.0, beta2=0.99).minimize(
            self.G_loss, var_list=self.g_vars)

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        # Sample grids are 2 rows; integer division keeps the grid size an int
        # under Python 3 (Python 2 already divided integers this way).
        grid = [2, self.batch_size // 2]

        with tf.Session(config=config) as sess:
            sess.run(init)
            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            # NOTE(review): pg == 7 is excluded from restoring as well as pg == 1 — confirm intent.
            if self.pg != 1 and self.pg != 7:
                if self.trans:
                    # Fade-in stage: restore only the layers shared with the previous model.
                    self.r_saver.restore(sess, self.read_model_path)
                    self.rgb_saver.restore(sess, self.read_model_path)
                else:
                    self.saver.restore(sess, self.read_model_path)

            step = 0
            batch_num = 0
            # number of critic (D) updates per generator update
            n_critic = 1
            while step <= self.max_iters:
                # optimization D
                for _ in range(n_critic):
                    sample_z = np.random.normal(size=[self.batch_size, self.sample_size])
                    if self.is_celeba:
                        train_list = self.data_In.getNextBatch(batch_num, self.batch_size)
                        realbatch_array = self.data_In.getShapeForData(train_list, resize_w=self.output_size)
                    else:
                        realbatch_array = self.data_In.getNextBatch(self.batch_size, resize_w=self.output_size)
                        # NCHW -> NHWC (same permutation as the former two chained transposes)
                        realbatch_array = np.transpose(realbatch_array, axes=[0, 2, 3, 1])

                    if self.trans and self.pg != 0:
                        # Blend real images with a 2x down- then up-zoomed copy, at the same
                        # fade-in rate as the generator. ``np.float`` was removed from NumPy.
                        alpha = float(step) / self.max_iters
                        low_realbatch_array = zoom(realbatch_array, zoom=[1, 0.5, 0.5, 1], mode='nearest')
                        low_realbatch_array = zoom(low_realbatch_array, zoom=[1, 2, 2, 1], mode='nearest')
                        realbatch_array = alpha * realbatch_array + (1 - alpha) * low_realbatch_array

                    sess.run(opti_D, feed_dict={self.images: realbatch_array, self.z: sample_z})
                    batch_num += 1

                # optimization G
                sess.run(opti_G, feed_dict={self.z: sample_z})

                summary_str = sess.run(summary_op, feed_dict={self.images: realbatch_array, self.z: sample_z})
                # Write each step's summary once (it used to be added twice).
                summary_writer.add_summary(summary_str, step)
                # the alpha of fade-in process
                sess.run(alpha_tra_assign, feed_dict={step_pl: step})

                if step % 400 == 0:
                    D_loss, G_loss, D_origin_loss, alpha_tra = sess.run(
                        [self.D_loss, self.G_loss, self.D_origin_loss, self.alpha_tra],
                        feed_dict={self.images: realbatch_array, self.z: sample_z})
                    print("PG %d, step %d: D loss=%.7f G loss=%.7f, D_or loss=%.7f, opt_alpha_tra=%.7f" % (
                        self.pg, step, D_loss, G_loss, D_origin_loss, alpha_tra))

                    realbatch_array = np.clip(realbatch_array, -1, 1)
                    save_images(realbatch_array[0:self.batch_size], grid,
                                '{}/{:02d}_real.jpg'.format(self.sample_path, step))

                    if self.trans and self.pg != 0:
                        low_realbatch_array = np.clip(low_realbatch_array, -1, 1)
                        save_images(low_realbatch_array[0:self.batch_size], grid,
                                    '{}/{:02d}_real_lower.jpg'.format(self.sample_path, step))

                    fake_image = sess.run(self.fake_images,
                                          feed_dict={self.images: realbatch_array, self.z: sample_z})
                    fake_image = np.clip(fake_image, -1, 1)
                    save_images(fake_image[0:self.batch_size], grid,
                                '{}/{:02d}_train.jpg'.format(self.sample_path, step))

                if np.mod(step, 4000) == 0 and step != 0:
                    self.saver.save(sess, self.gan_model_path)

                step += 1

            save_path = self.saver.save(sess, self.gan_model_path)
            print("Model saved in file: %s" % save_path)

        tf.reset_default_graph()

    def discriminate(self, conv, reuse=False, pg=1, t=False, alpha_trans=0.01):
        """Critic network for growth stage ``pg``.

        Args:
            conv: NHWC image batch at ``self.output_size`` resolution.
            reuse: reuse the variables created by an earlier call.
            pg: growth stage; ``pg - 1`` conv/downscale blocks are stacked.
            t: fade-in stage; the first block's output is blended with a
                fromRGB branch fed by the 2x-downscaled input.
            alpha_trans: blend weight of the new path (float or tensor).

        Returns:
            ``(sigmoid(logits), logits)`` where ``logits`` come from a
            1-unit fully connected layer.
        """
        with tf.variable_scope("discriminator") as scope:

            if reuse:
                scope.reuse_variables()
            if t:
                conv_iden = downscale2d(conv)
                # fromRGB for the previous (half) resolution
                conv_iden = lrelu(conv2d(conv_iden, output_dim=self.get_nf(pg - 2), k_w=1, k_h=1, d_h=1, d_w=1,
                                         use_wscale=self.use_wscale,
                                         name='dis_y_rgb_conv_{}'.format(conv_iden.shape[1])))
            # fromRGB
            conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 1), k_w=1, k_h=1, d_w=1, d_h=1,
                                use_wscale=self.use_wscale, name='dis_y_rgb_conv_{}'.format(conv.shape[1])))

            for i in range(pg - 1):
                conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 1 - i), d_h=1, d_w=1,
                                    use_wscale=self.use_wscale, name='dis_n_conv_1_{}'.format(conv.shape[1])))
                conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 2 - i), d_h=1, d_w=1,
                                    use_wscale=self.use_wscale, name='dis_n_conv_2_{}'.format(conv.shape[1])))
                conv = downscale2d(conv)
                if i == 0 and t:
                    conv = alpha_trans * conv + (1 - alpha_trans) * conv_iden

            conv = MinibatchstateConcat(conv)
            conv = lrelu(
                conv2d(conv, output_dim=self.get_nf(1), k_w=3, k_h=3, d_h=1, d_w=1, use_wscale=self.use_wscale,
                       name='dis_n_conv_1_{}'.format(conv.shape[1])))
            conv = lrelu(
                conv2d(conv, output_dim=self.get_nf(1), k_w=4, k_h=4, d_h=1, d_w=1, use_wscale=self.use_wscale,
                       padding='VALID', name='dis_n_conv_2_{}'.format(conv.shape[1])))
            conv = tf.reshape(conv, [self.batch_size, -1])

            # for D
            output = fully_connect(conv, output_size=1, use_wscale=self.use_wscale, gain=1, name='dis_n_fully')

            return tf.nn.sigmoid(output), output

    def generate(self, z_var, pg=1, t=False, alpha_trans=0.0):
        """Generator network for growth stage ``pg``.

        Maps the latent batch ``z_var`` to a 4x4 feature map, then applies
        ``pg - 1`` upscale+conv blocks and a final toRGB layer. The output has
        ``self.channel`` channels so it matches the ``images`` placeholder.
        In a fade-in stage (``t``) the result is blended with the upscaled
        toRGB output of the previous resolution.
        """
        with tf.variable_scope('generator'):

            de = tf.reshape(Pixl_Norm(z_var), [self.batch_size, 1, 1, int(self.get_nf(1))])
            de = conv2d(de, output_dim=self.get_nf(1), k_h=4, k_w=4, d_w=1, d_h=1, use_wscale=self.use_wscale,
                        gain=np.sqrt(2) / 4, padding='Other', name='gen_n_1_conv')
            de = Pixl_Norm(lrelu(de))
            de = tf.reshape(de, [self.batch_size, 4, 4, int(self.get_nf(1))])
            de = conv2d(de, output_dim=self.get_nf(1), d_w=1, d_h=1, use_wscale=self.use_wscale, name='gen_n_2_conv')
            de = Pixl_Norm(lrelu(de))

            for i in range(pg - 1):
                if i == pg - 2 and t:
                    # To RGB at the previous resolution, then upscale for the fade-in blend
                    de_iden = conv2d(de, output_dim=self.channel, k_w=1, k_h=1, d_w=1, d_h=1,
                                     use_wscale=self.use_wscale,
                                     name='gen_y_rgb_conv_{}'.format(de.shape[1]))
                    de_iden = upscale(de_iden, 2)

                de = upscale(de, 2)
                de = Pixl_Norm(lrelu(
                    conv2d(de, output_dim=self.get_nf(i + 1), d_w=1, d_h=1, use_wscale=self.use_wscale,
                           name='gen_n_conv_1_{}'.format(de.shape[1]))))
                de = Pixl_Norm(lrelu(
                    conv2d(de, output_dim=self.get_nf(i + 1), d_w=1, d_h=1, use_wscale=self.use_wscale,
                           name='gen_n_conv_2_{}'.format(de.shape[1]))))

            # To RGB
            de = conv2d(de, output_dim=self.channel, k_w=1, k_h=1, d_w=1, d_h=1, use_wscale=self.use_wscale,
                        gain=1, name='gen_y_rgb_conv_{}'.format(de.shape[1]))

            if pg == 1:
                return de
            if t:
                de = (1 - alpha_trans) * de_iden + alpha_trans * de

            return de

    def get_nf(self, stage):
        """Number of feature maps at ``stage``: ``1024 / 2**stage``, capped at 512.

        Integer division keeps the result an ``int`` under Python 3 as well
        (Python 2 already divided integers this way), so it can be used
        directly as a layer width.
        """
        return min(1024 // (2 ** stage), 512)

    def sample_z(self, mu, log_var):
        """Reparameterised sample ``mu + exp(log_var / 2) * eps`` with ``eps ~ N(0, 1)``."""
        eps = tf.random_normal(shape=tf.shape(mu))
        return mu + tf.exp(log_var / 2) * eps

# 299 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PGGAN-tensorflow 2 | The tensorflow implementation of [PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION](https://arxiv.org/abs/1710.10196). 3 | 4 | ### The generative process of PG-GAN 5 | 6 |

7 | 8 |

9 | 10 | ## Differences from the original paper 11 | 12 | - Currently, only 64x64 and 128x128 pixel samples are generated. 13 | 14 | ## Setup 15 | 16 | ### Prerequisites 17 | 18 | - TensorFlow >= 1.4 19 | - Python 2.7 or 3 20 | 21 | ### Getting Started 22 | - Clone this repo: 23 | ```bash 24 | git clone https://github.com/zhangqianhui/progressive_growing_of_gans_tensorflow.git 25 | cd progressive_growing_of_gans_tensorflow 26 | ``` 27 | - Download the CelebA dataset 28 | 29 | You can download the [CelebA dataset](https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AAB06FXaQRUNtjW9ntaoPGvCa?dl=0) 30 | and unzip CelebA into a directory. Note that this directory must not contain any sub-directories. 31 | 32 | - The method for creating CelebA-HQ is described [here](https://github.com/github-pengge/PyTorch-progressive_growing_of_gans#how-to-create-celeba-hq-dataset). 33 | 34 | - Train the model on the CelebA dataset 35 | 36 | ```bash 37 | python main.py --path=/path/to/celeba --celeba=True 38 | ``` 39 | 40 | - Train the model on the CelebA-HQ dataset 41 | 42 | ```bash 43 | python main.py --path=/path/to/celeba-hq --celeba=False 44 | ``` 45 | 46 | ## Results on the CelebA dataset 47 | Here are the generated 64x64 results (Left: generated; Right: Real): 48 |

50 | 51 |

52 | 53 | Here are the generated 128x128 results (Left: generated; Right: Real): 54 |

55 | 56 |

57 | 58 | 59 | ## Results on the CelebA-HQ dataset 60 | Here are the generated 64x64 results (Left: Real; Right: Generated): 61 | 62 |

63 | 64 |

65 | 66 | Here are the generated 128x128 results (Left: Real; Right: Generated): 67 |

68 | 69 |

70 | 71 | ## Issue 72 | If you find some bugs, Thanks for your issue to propose it. 73 | 74 | ## Reference code 75 | 76 | [PGGAN Theano](https://github.com/tkarras/progressive_growing_of_gans) 77 | 78 | [PGGAN Pytorch](https://github.com/github-pengge/PyTorch-progressive_growing_of_gans) 79 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import json 5 | import zipfile 6 | import argparse 7 | import subprocess 8 | from six.moves import urllib 9 | 10 | parser = argparse.ArgumentParser(description='Download dataset for DCGAN.') 11 | parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist'], 12 | help='name of dataset to download [celebA, lsun, mnist]' , default='mnist') 13 | 14 | def download(url, dirpath): 15 | 16 | filename = url.split('/')[-1] 17 | filepath = os.path.join(dirpath, filename) 18 | u = urllib.request.urlopen(url) 19 | f = open(filepath, 'wb') 20 | filesize = int(u.headers["Content-Length"]) 21 | print("Downloading: %s Bytes: %s" % (filename, filesize)) 22 | 23 | downloaded = 0 24 | block_sz = 8192 25 | status_width = 70 26 | while True: 27 | buf = u.read(block_sz) 28 | if not buf: 29 | print('') 30 | break 31 | else: 32 | print('', end='\r') 33 | downloaded += len(buf) 34 | f.write(buf) 35 | status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") % 36 | ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. 
/ filesize)) 37 | print(status, end='') 38 | sys.stdout.flush() 39 | f.close() 40 | return filepath 41 | 42 | def unzip(filepath): 43 | print("Extracting: " + filepath) 44 | dirpath = os.path.dirname(filepath) 45 | with zipfile.ZipFile(filepath) as zf: 46 | zf.extractall(dirpath) 47 | os.remove(filepath) 48 | 49 | def download_celeb_a(dirpath): 50 | data_dir = 'celebA' 51 | if os.path.exists(os.path.join(dirpath, data_dir)): 52 | print('Found Celeb-A - skip') 53 | return 54 | url = 'https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=1&pv=1' 55 | filepath = download(url, dirpath) 56 | zip_dir = '' 57 | with zipfile.ZipFile(filepath) as zf: 58 | zip_dir = zf.namelist()[0] 59 | zf.extractall(dirpath) 60 | os.remove(filepath) 61 | os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir)) 62 | 63 | def _list_categories(tag): 64 | url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag 65 | f = urllib.request.urlopen(url) 66 | return json.loads(f.read()) 67 | 68 | def _download_lsun(out_dir, category, set_name, tag): 69 | url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \ 70 | '&category={category}&set={set_name}'.format(**locals()) 71 | print(url) 72 | if set_name == 'test': 73 | out_name = 'test_lmdb.zip' 74 | else: 75 | out_name = '{category}_{set_name}_lmdb.zip'.format(**locals()) 76 | out_path = os.path.join(out_dir, out_name) 77 | cmd = ['curl', url, '-o', out_path] 78 | print('Downloading', category, set_name, 'set') 79 | subprocess.call(cmd) 80 | 81 | def download_lsun(dirpath): 82 | data_dir = os.path.join(dirpath, 'lsun') 83 | if os.path.exists(data_dir): 84 | print('Found LSUN - skip') 85 | return 86 | else: 87 | os.mkdir(data_dir) 88 | 89 | tag = 'latest' 90 | #categories = _list_categories(tag) 91 | categories = ['bedroom'] 92 | 93 | for category in categories: 94 | _download_lsun(data_dir, category, 'train', tag) 95 | _download_lsun(data_dir, category, 'val', tag) 96 | 
_download_lsun(data_dir, '', 'test', tag) 97 | 98 | def download_mnist(dirpath): 99 | data_dir = os.path.join(dirpath, 'mnist') 100 | if os.path.exists(data_dir): 101 | print('Found MNIST - skip') 102 | return 103 | else: 104 | os.mkdir(data_dir) 105 | url_base = 'http://yann.lecun.com/exdb/mnist/' 106 | file_names = ['train-images-idx3-ubyte.gz','train-labels-idx1-ubyte.gz','t10k-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz'] 107 | for file_name in file_names: 108 | url = (url_base+file_name).format(**locals()) 109 | print(url) 110 | out_path = os.path.join(data_dir,file_name) 111 | cmd = ['curl', url, '-o', out_path] 112 | print('Downloading ', file_name) 113 | subprocess.call(cmd) 114 | cmd = ['gzip', '-d', out_path] 115 | print('Decompressing ', file_name) 116 | subprocess.call(cmd) 117 | 118 | def prepare_data_dir(path = './data'): 119 | if not os.path.exists(path): 120 | os.mkdir(path) 121 | 122 | if __name__ == '__main__': 123 | #args = parser.parse_args() 124 | prepare_data_dir() 125 | 126 | #if 'celebA' in args.datasets: 127 | download_celeb_a('./data') 128 | # if 'lsun' in args.datasets: 129 | # download_lsun('./data') 130 | # if 'mnist' in args.datasets: 131 | # download_mnist('./data') -------------------------------------------------------------------------------- /h5tool.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import io 4 | import glob 5 | import pickle 6 | import argparse 7 | import threading 8 | import Queue 9 | import traceback 10 | import numpy as np 11 | import scipy.ndimage 12 | import PIL.Image 13 | import h5py # conda install h5py 14 | 15 | # ---------------------------------------------------------------------------- 16 | 17 | class HDF5Exporter: 18 | def __init__(self, h5_filename, resolution, channels=3): 19 | rlog2 = int(np.floor(np.log2(resolution))) 20 | assert resolution == 2 ** rlog2 21 | self.resolution = resolution 22 | self.channels = channels 23 | 
self.h5_file = h5py.File(h5_filename, 'w') 24 | self.h5_lods = [] 25 | self.buffers = [] 26 | self.buffer_sizes = [] 27 | for lod in xrange(rlog2, -1, -1): 28 | r = 2 ** lod; 29 | c = channels 30 | bytes_per_item = c * (r ** 2) 31 | chunk_size = int(np.ceil(128.0 / bytes_per_item)) 32 | buffer_size = int(np.ceil(512.0 * np.exp2(20) / bytes_per_item)) 33 | lod = self.h5_file.create_dataset('data%dx%d' % (r, r), shape=(0, c, r, r), dtype=np.uint8, 34 | maxshape=(None, c, r, r), chunks=(chunk_size, c, r, r), 35 | compression='gzip', compression_opts=4) 36 | self.h5_lods.append(lod) 37 | self.buffers.append(np.zeros((buffer_size, c, r, r), dtype=np.uint8)) 38 | self.buffer_sizes.append(0) 39 | 40 | def close(self): 41 | for lod in xrange(len(self.h5_lods)): 42 | self.flush_lod(lod) 43 | self.h5_file.close() 44 | 45 | def add_images(self, img): 46 | assert img.ndim == 4 and img.shape[1] == self.channels and img.shape[2] == img.shape[3] 47 | assert img.shape[2] >= self.resolution and img.shape[2] == 2 ** int(np.floor(np.log2(img.shape[2]))) 48 | for lod in xrange(len(self.h5_lods)): 49 | while img.shape[2] > self.resolution / (2 ** lod): 50 | img = img.astype(np.float32) 51 | img = (img[:, :, 0::2, 0::2] + img[:, :, 0::2, 1::2] + img[:, :, 1::2, 0::2] + img[:, :, 1::2, 52 | 1::2]) * 0.25 53 | quant = np.uint8(np.clip(np.round(img), 0, 255)) 54 | ofs = 0 55 | while ofs < quant.shape[0]: 56 | num = min(quant.shape[0] - ofs, self.buffers[lod].shape[0] - self.buffer_sizes[lod]) 57 | self.buffers[lod][self.buffer_sizes[lod]: self.buffer_sizes[lod] + num] = quant[ofs: ofs + num] 58 | self.buffer_sizes[lod] += num 59 | if self.buffer_sizes[lod] == self.buffers[lod].shape[0]: 60 | self.flush_lod(lod) 61 | ofs += num 62 | 63 | def num_images(self): 64 | return self.h5_lods[0].shape[0] + self.buffer_sizes[0] 65 | 66 | def flush_lod(self, lod): 67 | num = self.buffer_sizes[lod] 68 | if num > 0: 69 | self.h5_lods[lod].resize(self.h5_lods[lod].shape[0] + num, axis=0) 70 | 
self.h5_lods[lod][-num:] = self.buffers[lod][:num] 71 | self.buffer_sizes[lod] = 0 72 | 73 | 74 | # ---------------------------------------------------------------------------- 75 | 76 | class ExceptionInfo(object): 77 | def __init__(self): 78 | self.type, self.value = sys.exc_info()[:2] 79 | self.traceback = traceback.format_exc() 80 | 81 | 82 | # ---------------------------------------------------------------------------- 83 | 84 | class WorkerThread(threading.Thread): 85 | def __init__(self, task_queue): 86 | threading.Thread.__init__(self) 87 | self.task_queue = task_queue 88 | 89 | def run(self): 90 | while True: 91 | func, args, result_queue = self.task_queue.get() 92 | if func is None: 93 | break 94 | try: 95 | result = func(*args) 96 | except: 97 | result = ExceptionInfo() 98 | result_queue.put((result, args)) 99 | 100 | 101 | # ---------------------------------------------------------------------------- 102 | 103 | class ThreadPool(object): 104 | def __init__(self, num_threads): 105 | assert num_threads >= 1 106 | self.task_queue = Queue.Queue() 107 | self.result_queues = dict() 108 | self.num_threads = num_threads 109 | for idx in xrange(self.num_threads): 110 | thread = WorkerThread(self.task_queue) 111 | thread.daemon = True 112 | thread.start() 113 | 114 | def add_task(self, func, args=()): 115 | assert hasattr(func, '__call__') # must be a function 116 | if func not in self.result_queues: 117 | self.result_queues[func] = Queue.Queue() 118 | self.task_queue.put((func, args, self.result_queues[func])) 119 | 120 | def get_result(self, func, verbose_exceptions=True): # returns (result, args) 121 | result, args = self.result_queues[func].get() 122 | if isinstance(result, ExceptionInfo): 123 | if verbose_exceptions: 124 | print('\n\nWorker thread caught an exception:\n' + result.traceback + '\n') 125 | raise Exception('%s, %s' % (result.type, result.value)) 126 | return result, args 127 | 128 | def finish(self): 129 | for idx in xrange(self.num_threads): 
130 | self.task_queue.put((None, (), None)) 131 | 132 | def __enter__(self): # for 'with' statement 133 | return self 134 | 135 | def __exit__(self, *excinfo): 136 | self.finish() 137 | 138 | def process_items_concurrently(self, item_iterator, process_func=lambda x: x, pre_func=lambda x: x, 139 | post_func=lambda x: x, max_items_in_flight=None): 140 | if max_items_in_flight is None: max_items_in_flight = self.num_threads * 4 141 | assert max_items_in_flight >= 1 142 | results = [] 143 | retire_idx = [0] 144 | 145 | def task_func(prepared, idx): 146 | return process_func(prepared) 147 | 148 | def retire_result(): 149 | processed, (prepared, idx) = self.get_result(task_func) 150 | results[idx] = processed 151 | while retire_idx[0] < len(results) and results[retire_idx[0]] is not None: 152 | yield post_func(results[retire_idx[0]]) 153 | results[retire_idx[0]] = None 154 | retire_idx[0] += 1 155 | 156 | for idx, item in enumerate(item_iterator): 157 | prepared = pre_func(item) 158 | results.append(None) 159 | self.add_task(func=task_func, args=(prepared, idx)) 160 | while retire_idx[0] < idx - max_items_in_flight + 2: 161 | for res in retire_result(): yield res 162 | while retire_idx[0] < len(results): 163 | for res in retire_result(): yield res 164 | 165 | 166 | # ---------------------------------------------------------------------------- 167 | 168 | def inspect(h5_filename): 169 | print('%-20s%s' % ('HDF5 filename', h5_filename)) 170 | file_size = os.stat(h5_filename).st_size 171 | print('%-20s%.2f GB' % ('Total size', float(file_size) / np.exp2(30))) 172 | 173 | h5 = h5py.File(h5_filename, 'r') 174 | lods = sorted([value for key, value in h5.iteritems() if key.startswith('data')], key=lambda lod: -lod.shape[3]) 175 | shapes = [lod.shape for lod in lods] 176 | shape = shapes[0] 177 | h5.close() 178 | print('%-20s%d' % ('Total images', shape[0])) 179 | print('%-20s%dx%d' % ('Resolution', shape[3], shape[2])) 180 | print('%-20s%d' % ('Color channels', shape[1])) 181 | 
print('%-20s%.2f KB' % ('Size per image', float(file_size) / shape[0] / np.exp2(10))) 182 | 183 | if len(lods) != int(np.log2(shape[3])) + 1: 184 | print('Warning: The HDF5 file contains incorrect number of LODs') 185 | if any(s[0] != shape[0] for s in shapes): 186 | print('Warning: The HDF5 file contains inconsistent number of images in different LODs') 187 | print('Perhaps the dataset creation script was terminated abruptly?') 188 | 189 | 190 | # ---------------------------------------------------------------------------- 191 | 192 | def compare(first_h5, second_h5): 193 | print('Comparing %s vs. %s' % (first_h5, second_h5)) 194 | h5_a = h5py.File(first_h5, 'r') 195 | h5_b = h5py.File(second_h5, 'r') 196 | lods_a = sorted([value for key, value in h5_a.iteritems() if key.startswith('data')], key=lambda lod: -lod.shape[3]) 197 | lods_b = sorted([value for key, value in h5_b.iteritems() if key.startswith('data')], key=lambda lod: -lod.shape[3]) 198 | shape_a = lods_a[0].shape 199 | shape_b = lods_b[0].shape 200 | 201 | if shape_a[1] != shape_b[1]: 202 | print('The datasets have different number of color channels: %d vs. %d' % (shape_a[1], shape_b[1])) 203 | elif shape_a[3] != shape_b[3] or shape_a[2] != shape_b[2]: 204 | print( 205 | 'The datasets have different resolution: %dx%d vs. %dx%d' % (shape_a[3], shape_a[2], shape_b[3], shape_b[2])) 206 | else: 207 | min_images = min(shape_a[0], shape_b[0]) 208 | num_diffs = 0 209 | for idx in range(min_images): 210 | print('%d / %d\r' % (idx, min_images)) 211 | if np.any(lods_a[0][idx] != lods_b[0][idx]): 212 | print('%-40s\r' % '') 213 | print('Different image: %d' % idx) 214 | num_diffs += 1 215 | if shape_a[0] != shape_b[0]: 216 | print('The datasets contain different number of images: %d vs. %d' % (shape_a[0], shape_b[0])) 217 | if num_diffs == 0: 218 | print('All %d images are identical.' % min_images) 219 | else: 220 | print('%d images out of %d are different.' 
% (num_diffs, min_images)) 221 | 222 | h5_a.close() 223 | h5_b.close() 224 | 225 | 226 | # ---------------------------------------------------------------------------- 227 | 228 | def display(h5_filename, start=None, stop=None, step=None): 229 | print('Displaying images from %s' % h5_filename) 230 | h5 = h5py.File(h5_filename, 'r') 231 | lods = sorted([value for key, value in h5.iteritems() if key.startswith('data')], key=lambda lod: -lod.shape[3]) 232 | indices = range(lods[0].shape[0]) 233 | indices = indices[start: stop: step] 234 | 235 | import cv2 # pip install opencv-python 236 | window_name = 'h5tool' 237 | cv2.namedWindow(window_name) 238 | print('Press SPACE or ENTER to advance, ESC to exit.') 239 | 240 | for idx in indices: 241 | print('%d / %d\r' % (idx, lods[0].shape[0])) 242 | img = lods[0][idx] 243 | img = img.transpose(1, 2, 0) # CHW => HWC 244 | img = img[:, :, ::-1] # RGB => BGR 245 | cv2.imshow(window_name, img) 246 | c = cv2.waitKey() 247 | if c == 27: 248 | break 249 | 250 | h5.close() 251 | print('%-40s\r' % '') 252 | print('Done.') 253 | 254 | 255 | # ---------------------------------------------------------------------------- 256 | 257 | def extract(h5_filename, output_dir, start=None, stop=None, step=None): 258 | print('Extracting images from %s to %s' % (h5_filename, output_dir)) 259 | h5 = h5py.File(h5_filename, 'r') 260 | lods = sorted([value for key, value in h5.iteritems() if key.startswith('data')], key=lambda lod: -lod.shape[3]) 261 | shape = lods[0].shape 262 | indices = range(shape[0])[start: stop: step] 263 | if not os.path.isdir(output_dir): 264 | os.makedirs(output_dir) 265 | 266 | for idx in indices: 267 | print('%d / %d\r' % (idx, shape[0])) 268 | img = lods[0][idx] 269 | if img.shape[0] == 1: 270 | img = PIL.Image.fromarray(img[0], 'L') 271 | else: 272 | img = PIL.Image.fromarray(img.transpose(1, 2, 0), 'RGB') 273 | img.save(os.path.join(output_dir, 'img%08d.png' % idx)) 274 | 275 | h5.close() 276 | print('%-40s\r' % '') 277 | 
print('Extracted %d images.' % len(indices)) 278 | 279 | 280 | # ---------------------------------------------------------------------------- 281 | 282 | def create_custom(h5_filename, image_dir): 283 | print('Creating custom dataset %s from %s' % (h5_filename, image_dir)) 284 | glob_pattern = os.path.join(image_dir, '*') 285 | image_filenames = sorted(glob.glob(glob_pattern)) 286 | if len(image_filenames) == 0: 287 | print('Error: No input images found in %s' % glob_pattern) 288 | return 289 | 290 | img = np.asarray(PIL.Image.open(image_filenames[0])) 291 | resolution = img.shape[0] 292 | channels = img.shape[2] if img.ndim == 3 else 1 293 | if img.shape[1] != resolution: 294 | print('Error: Input images must have the same width and height') 295 | return 296 | if resolution != 2 ** int(np.floor(np.log2(resolution))): 297 | print('Error: Input image resolution must be a power-of-two') 298 | return 299 | if channels not in [1, 3]: 300 | print('Error: Input images must be stored as RGB or grayscale') 301 | 302 | h5 = HDF5Exporter(h5_filename, resolution, channels) 303 | for idx in xrange(len(image_filenames)): 304 | print('%d / %d\r' % (idx, len(image_filenames))) 305 | img = np.asarray(PIL.Image.open(image_filenames[idx])) 306 | if channels == 1: 307 | img = img[np.newaxis, :, :] # HW => CHW 308 | else: 309 | img = img.transpose(2, 0, 1) # HWC => CHW 310 | h5.add_images(img[np.newaxis]) 311 | 312 | print('%-40s\r' % 'Flushing data...') 313 | h5.close() 314 | print('%-40s\r' % '') 315 | print('Added %d images.' 
% len(image_filenames)) 316 | 317 | 318 | # ---------------------------------------------------------------------------- 319 | 320 | def create_mnist(h5_filename, mnist_dir, export_labels=False): 321 | print('Loading MNIST data from %s' % mnist_dir) 322 | import gzip 323 | with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file: 324 | images = np.frombuffer(file.read(), np.uint8, offset=16) 325 | with gzip.open(os.path.join(mnist_dir, 'train-labels-idx1-ubyte.gz'), 'rb') as file: 326 | labels = np.frombuffer(file.read(), np.uint8, offset=8) 327 | images = images.reshape(-1, 1, 28, 28) 328 | images = np.pad(images, [(0, 0), (0, 0), (2, 2), (2, 2)], 'constant', constant_values=0) 329 | assert images.shape == (60000, 1, 32, 32) and images.dtype == np.uint8 330 | assert labels.shape == (60000,) and labels.dtype == np.uint8 331 | assert np.min(images) == 0 and np.max(images) == 255 332 | assert np.min(labels) == 0 and np.max(labels) == 9 333 | 334 | print('Creating %s' % h5_filename) 335 | h5 = HDF5Exporter(h5_filename, 32, 1) 336 | h5.add_images(images) 337 | h5.close() 338 | 339 | if export_labels: 340 | npy_filename = os.path.splitext(h5_filename)[0] + '-labels.npy' 341 | print('Creating %s' % npy_filename) 342 | onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32) 343 | onehot[np.arange(labels.size), labels] = 1.0 344 | np.save(npy_filename, onehot) 345 | print('Added %d images.' 
% images.shape[0]) 346 | 347 | 348 | # ---------------------------------------------------------------------------- 349 | 350 | def create_mnist_rgb(h5_filename, mnist_dir, num_images=1000000, random_seed=123): 351 | print('Loading MNIST data from %s' % mnist_dir) 352 | import gzip 353 | with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file: 354 | images = np.frombuffer(file.read(), np.uint8, offset=16) 355 | images = images.reshape(-1, 28, 28) 356 | images = np.pad(images, [(0, 0), (2, 2), (2, 2)], 'constant', constant_values=0) 357 | assert images.shape == (60000, 32, 32) and images.dtype == np.uint8 358 | assert np.min(images) == 0 and np.max(images) == 255 359 | 360 | print('Creating %s' % h5_filename) 361 | h5 = HDF5Exporter(h5_filename, 32, 3) 362 | np.random.seed(random_seed) 363 | for idx in xrange(num_images): 364 | if idx % 100 == 0: 365 | print('%d / %d\r' % (idx, num_images)) 366 | h5.add_images(images[np.newaxis, np.random.randint(images.shape[0], size=3)]) 367 | 368 | print('%-40s\r' % 'Flushing data...') 369 | h5.close() 370 | print('%-40s\r' % '') 371 | print('Added %d images.' 
% num_images) 372 | 373 | 374 | # ---------------------------------------------------------------------------- 375 | 376 | def create_cifar10(h5_filename, cifar10_dir, export_labels=False): 377 | print('Loading CIFAR-10 data from %s' % cifar10_dir) 378 | images = [] 379 | labels = [] 380 | for batch in xrange(1, 6): 381 | with open(os.path.join(cifar10_dir, 'data_batch_%d' % batch), 'rb') as file: 382 | data = pickle.load(file) 383 | images.append(data['data'].reshape(-1, 3, 32, 32)) 384 | labels.append(np.uint8(data['labels'])) 385 | images = np.concatenate(images) 386 | labels = np.concatenate(labels) 387 | 388 | assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8 389 | assert labels.shape == (50000,) and labels.dtype == np.uint8 390 | assert np.min(images) == 0 and np.max(images) == 255 391 | assert np.min(labels) == 0 and np.max(labels) == 9 392 | 393 | print('Creating %s' % h5_filename) 394 | h5 = HDF5Exporter(h5_filename, 32, 3) 395 | h5.add_images(images) 396 | h5.close() 397 | 398 | if export_labels: 399 | npy_filename = os.path.splitext(h5_filename)[0] + '-labels.npy' 400 | print('Creating %s' % npy_filename) 401 | onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32) 402 | onehot[np.arange(labels.size), labels] = 1.0 403 | np.save(npy_filename, onehot) 404 | print('Added %d images.' 
% images.shape[0]) 405 | 406 | 407 | # ---------------------------------------------------------------------------- 408 | 409 | def create_lsun(h5_filename, lmdb_dir, resolution=256, max_images=None): 410 | print('Creating LSUN dataset %s from %s' % (h5_filename, lmdb_dir)) 411 | import lmdb # pip install lmdb 412 | import cv2 # pip install opencv-python 413 | with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn: 414 | total_images = txn.stat()['entries'] 415 | if max_images is None: 416 | max_images = total_images 417 | 418 | h5 = HDF5Exporter(h5_filename, resolution, 3) 419 | for idx, (key, value) in enumerate(txn.cursor()): 420 | print('%d / %d\r' % (h5.num_images(), min(h5.num_images() + total_images - idx, max_images))) 421 | try: 422 | try: 423 | img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1) 424 | if img is None: 425 | raise IOError('cv2.imdecode failed') 426 | img = img[:, :, ::-1] # BGR => RGB 427 | except IOError: 428 | img = np.asarray(PIL.Image.open(io.BytesIO(value))) 429 | crop = np.min(img.shape[:2]) 430 | img = img[(img.shape[0] - crop) / 2: (img.shape[0] + crop) / 2, 431 | (img.shape[1] - crop) / 2: (img.shape[1] + crop) / 2] 432 | img = PIL.Image.fromarray(img, 'RGB') 433 | img = img.resize((resolution, resolution), PIL.Image.ANTIALIAS) 434 | img = np.asarray(img) 435 | img = img.transpose(2, 0, 1) # HWC => CHW 436 | h5.add_images(img[np.newaxis]) 437 | except: 438 | print('%-40s\r' % '') 439 | print(sys.exc_info()[1]) 440 | raise 441 | if h5.num_images() == max_images: 442 | break 443 | 444 | print('%-40s\r' % 'Flushing data...') 445 | num_added = h5.num_images() 446 | h5.close() 447 | print('%-40s\r' % '') 448 | print('Added %d images.' 
% num_added) 449 | 450 | 451 | # ---------------------------------------------------------------------------- 452 | 453 | def create_celeba(h5_filename, celeba_dir, cx=89, cy=121): 454 | print('Creating CelebA dataset %s from %s' % (h5_filename, celeba_dir)) 455 | glob_pattern = os.path.join(celeba_dir, 'img_align_celeba_png', '*.png') 456 | image_filenames = sorted(glob.glob(glob_pattern)) 457 | num_images = 202599 458 | if len(image_filenames) != num_images: 459 | print('Error: Expected to find %d images in %s' % (num_images, glob_pattern)) 460 | return 461 | 462 | h5 = HDF5Exporter(h5_filename, 128, 3) 463 | for idx in xrange(num_images): 464 | print('%d / %d\r' % (idx, num_images)) 465 | img = np.asarray(PIL.Image.open(image_filenames[idx])) 466 | assert img.shape == (218, 178, 3) 467 | img = img[cy - 64: cy + 64, cx - 64: cx + 64] 468 | img = img.transpose(2, 0, 1) # HWC => CHW 469 | h5.add_images(img[np.newaxis]) 470 | 471 | print('%-40s\r' % 'Flushing data...') 472 | h5.close() 473 | print('%-40s\r' % '') 474 | print('Added %d images.' 
% num_images) 475 | 476 | 477 | # ---------------------------------------------------------------------------- 478 | 479 | def create_celeba_hq(h5_filename, celeba_dir, delta_dir, num_threads=4, num_tasks=100): 480 | print('Loading CelebA data from %s' % celeba_dir) 481 | glob_pattern = os.path.join(celeba_dir, '*.jpg') 482 | glob_expected = 202599 483 | if len(glob.glob(glob_pattern)) != glob_expected: 484 | print('Error: Expected to find %d images in %s' % (glob_expected, glob_pattern)) 485 | return 486 | with open(os.path.join(celeba_dir, 'list_landmarks_celeba.txt'), 'rt') as file: 487 | landmarks = [[float(value) for value in line.split()[1:]] for line in file.readlines()[2:]] 488 | landmarks = np.float32(landmarks).reshape(-1, 5, 2) 489 | 490 | print('Loading CelebA-HQ deltas from %s' % delta_dir) 491 | import hashlib 492 | import bz2 493 | import zipfile 494 | import base64 495 | import cryptography.hazmat.primitives.hashes 496 | import cryptography.hazmat.backends 497 | import cryptography.hazmat.primitives.kdf.pbkdf2 498 | import cryptography.fernet 499 | glob_pattern = os.path.join(delta_dir, 'delta*.zip') 500 | glob_expected = 30 501 | if len(glob.glob(glob_pattern)) != glob_expected: 502 | print('Error: Expected to find %d zips in %s' % (glob_expected, glob_pattern)) 503 | return 504 | with open(os.path.join(delta_dir, 'image_list.txt'), 'rt') as file: 505 | lines = [line.split() for line in file] 506 | fields = dict() 507 | for idx, field in enumerate(lines[0]): 508 | type = int if field.endswith('idx') else str 509 | fields[field] = [type(line[idx]) for line in lines[1:]] 510 | 511 | def rot90(v): 512 | return np.array([-v[1], v[0]]) 513 | 514 | def process_func(idx): 515 | # Load original image. 516 | orig_idx = fields['orig_idx'][idx] 517 | orig_file = fields['orig_file'][idx] 518 | orig_path = os.path.join(celeba_dir, orig_file) 519 | img = PIL.Image.open(orig_path) 520 | 521 | # Choose oriented crop rectangle. 
522 | lm = landmarks[orig_idx] 523 | eye_avg = (lm[0] + lm[1]) * 0.5 + 0.5 524 | mouth_avg = (lm[3] + lm[4]) * 0.5 + 0.5 525 | eye_to_eye = lm[1] - lm[0] 526 | eye_to_mouth = mouth_avg - eye_avg 527 | x = eye_to_eye - rot90(eye_to_mouth) 528 | x /= np.hypot(*x) 529 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 530 | y = rot90(x) 531 | c = eye_avg + eye_to_mouth * 0.1 532 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 533 | zoom = 1024 / (np.hypot(*x) * 2) 534 | 535 | # Shrink. 536 | shrink = int(np.floor(0.5 / zoom)) 537 | if shrink > 1: 538 | size = (int(np.round(float(img.size[0]) / shrink)), int(np.round(float(img.size[1]) / shrink))) 539 | img = img.resize(size, PIL.Image.ANTIALIAS) 540 | quad /= shrink 541 | zoom *= shrink 542 | 543 | # Crop. 544 | border = max(int(np.round(1024 * 0.1 / zoom)), 3) 545 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 546 | int(np.ceil(max(quad[:, 1])))) 547 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), 548 | min(crop[3] + border, img.size[1])) 549 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 550 | img = img.crop(crop) 551 | quad -= crop[0:2] 552 | 553 | # Simulate super-resolution. 554 | superres = int(np.exp2(np.ceil(np.log2(zoom)))) 555 | if superres > 1: 556 | img = img.resize((img.size[0] * superres, img.size[1] * superres), PIL.Image.ANTIALIAS) 557 | quad *= superres 558 | zoom /= superres 559 | 560 | # Pad. 
561 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 562 | int(np.ceil(max(quad[:, 1])))) 563 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), 564 | max(pad[3] - img.size[1] + border, 0)) 565 | if max(pad) > border - 4: 566 | pad = np.maximum(pad, int(np.round(1024 * 0.3 / zoom))) 567 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 568 | h, w, _ = img.shape 569 | y, x, _ = np.mgrid[:h, :w, :1] 570 | mask = 1.0 - np.minimum(np.minimum(np.float32(x) / pad[0], np.float32(y) / pad[1]), 571 | np.minimum(np.float32(w - 1 - x) / pad[2], np.float32(h - 1 - y) / pad[3])) 572 | blur = 1024 * 0.02 / zoom 573 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 574 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) 575 | img = PIL.Image.fromarray(np.uint8(np.clip(np.round(img), 0, 255)), 'RGB') 576 | quad += pad[0:2] 577 | 578 | # Transform. 579 | img = img.transform((4096, 4096), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 580 | img = img.resize((1024, 1024), PIL.Image.ANTIALIAS) 581 | img = np.asarray(img).transpose(2, 0, 1) 582 | 583 | # Verify MD5. 584 | md5 = hashlib.md5() 585 | md5.update(img.tobytes()) 586 | # assert md5.hexdigest() == fields['proc_md5'][idx] # disable md5 verify 587 | 588 | # Load delta image and original JPG. 589 | with zipfile.ZipFile(os.path.join(delta_dir, 'deltas%05d.zip' % (idx - idx % 1000)), 'r') as zip: 590 | delta_bytes = zip.read('delta%05d.dat' % idx) 591 | with open(orig_path, 'rb') as file: 592 | orig_bytes = file.read() 593 | 594 | # Decrypt delta image, using original JPG data as decryption key. 
595 | algorithm = cryptography.hazmat.primitives.hashes.SHA256() 596 | backend = cryptography.hazmat.backends.default_backend() 597 | kdf = cryptography.hazmat.primitives.kdf.pbkdf2.PBKDF2HMAC(algorithm=algorithm, length=32, salt=orig_file, 598 | iterations=100000, backend=backend) 599 | key = base64.urlsafe_b64encode(kdf.derive(orig_bytes)) 600 | delta = np.frombuffer(bz2.decompress(cryptography.fernet.Fernet(key).decrypt(delta_bytes)), 601 | dtype=np.uint8).reshape(3, 1024, 1024) 602 | 603 | # Apply delta image. 604 | img = img + delta 605 | 606 | # Verify MD5. 607 | md5 = hashlib.md5() 608 | md5.update(img.tobytes()) 609 | # assert md5.hexdigest() == fields['final_md5'][idx] # disable md5 verify 610 | return idx, img 611 | 612 | print('Creating %s' % h5_filename) 613 | h5 = HDF5Exporter(h5_filename, 1024, 3) 614 | with ThreadPool(num_threads) as pool: 615 | print('%d / %d\r' % (0, len(fields['idx']))) 616 | for idx, img in pool.process_items_concurrently(fields['idx'], process_func=process_func, 617 | max_items_in_flight=num_tasks): 618 | h5.add_images(img[np.newaxis]) 619 | print('%d / %d\r' % (idx + 1, len(fields['idx']))) 620 | 621 | print('%-40s\r' % 'Flushing data...') 622 | h5.close() 623 | print('%-40s\r' % '') 624 | print('Added %d images.' % len(fields['idx'])) 625 | 626 | 627 | # ---------------------------------------------------------------------------- 628 | 629 | def execute_cmdline(argv): 630 | prog = argv[0] 631 | parser = argparse.ArgumentParser( 632 | prog=prog, 633 | description='Tool for creating, extracting, and visualizing HDF5 datasets.', 634 | epilog='Type "%s -h" for more information.' 
% prog) 635 | 636 | subparsers = parser.add_subparsers(dest='command') 637 | 638 | def add_command(cmd, desc, example=None): 639 | epilog = 'Example: %s %s' % (prog, example) if example is not None else None 640 | return subparsers.add_parser(cmd, description=desc, help=desc, epilog=epilog) 641 | 642 | p = add_command('inspect', 'Print information about HDF5 dataset.', 643 | 'inspect mnist-32x32.h5') 644 | p.add_argument('h5_filename', help='HDF5 file to inspect') 645 | 646 | p = add_command('compare', 'Compare two HDF5 datasets.', 647 | 'compare mydataset.h5 mnist-32x32.h5') 648 | p.add_argument('first_h5', help='First HDF5 file to compare') 649 | p.add_argument('second_h5', help='Second HDF5 file to compare') 650 | 651 | p = add_command('display', 'Display images in HDF5 dataset.', 652 | 'display mnist-32x32.h5') 653 | p.add_argument('h5_filename', help='HDF5 file to visualize') 654 | p.add_argument('--start', help='Start index (inclusive)', type=int, default=None) 655 | p.add_argument('--stop', help='Stop index (exclusive)', type=int, default=None) 656 | p.add_argument('--step', help='Step between consecutive indices', type=int, default=None) 657 | 658 | p = add_command('extract', 'Extract images from HDF5 dataset.', 659 | 'extract mnist-32x32.h5 cifar10-images') 660 | p.add_argument('h5_filename', help='HDF5 file to extract') 661 | p.add_argument('output_dir', help='Directory to extract the images into') 662 | p.add_argument('--start', help='Start index (inclusive)', type=int, default=None) 663 | p.add_argument('--stop', help='Stop index (exclusive)', type=int, default=None) 664 | p.add_argument('--step', help='Step between consecutive indices', type=int, default=None) 665 | 666 | p = add_command('create_custom', 'Create HDF5 dataset for custom images.', 667 | 'create_custom mydataset.h5 myimagedir') 668 | p.add_argument('h5_filename', help='HDF5 file to create') 669 | p.add_argument('image_dir', help='Directory to read the images from') 670 | 671 | p = 
add_command('create_mnist', 'Create HDF5 dataset for MNIST.', 672 | 'create_mnist mnist-32x32.h5 ~/mnist --export_labels') 673 | p.add_argument('h5_filename', help='HDF5 file to create') 674 | p.add_argument('mnist_dir', help='Directory to read MNIST data from') 675 | p.add_argument('--export_labels', help='Create *-labels.npy alongside the HDF5', action='store_true') 676 | 677 | p = add_command('create_mnist_rgb', 'Create HDF5 dataset for MNIST-RGB.', 678 | 'create_mnist_rgb mnist-rgb-32x32.h5 ~/mnist') 679 | p.add_argument('h5_filename', help='HDF5 file to create') 680 | p.add_argument('mnist_dir', help='Directory to read MNIST data from') 681 | p.add_argument('--num_images', help='Number of composite images to create (default: 1000000)', type=int, 682 | default=1000000) 683 | p.add_argument('--random_seed', help='Random seed (default: 123)', type=int, default=123) 684 | 685 | p = add_command('create_cifar10', 'Create HDF5 dataset for CIFAR-10.', 686 | 'create_cifar10 cifar-10-32x32.h5 ~/cifar10 --export_labels') 687 | p.add_argument('h5_filename', help='HDF5 file to create') 688 | p.add_argument('cifar10_dir', help='Directory to read CIFAR-10 data from') 689 | p.add_argument('--export_labels', help='Create *-labels.npy alongside the HDF5', action='store_true') 690 | 691 | p = add_command('create_lsun', 'Create HDF5 dataset for single LSUN category.', 692 | 'create_lsun lsun-airplane-256x256-100k.h5 ~/lsun/airplane_lmdb --resolution 256 --max_images 100000') 693 | p.add_argument('h5_filename', help='HDF5 file to create') 694 | p.add_argument('lmdb_dir', help='Directory to read LMDB database from') 695 | p.add_argument('--resolution', help='Output resolution (default: 256)', type=int, default=256) 696 | p.add_argument('--max_images', help='Maximum number of images (default: none)', type=int, default=None) 697 | 698 | p = add_command('create_celeba', 'Create HDF5 dataset for CelebA.', 699 | 'create_celeba celeba-128x128.h5 ~/celeba') 700 | 
p.add_argument('h5_filename', help='HDF5 file to create') 701 | p.add_argument('celeba_dir', help='Directory to read CelebA data from') 702 | p.add_argument('--cx', help='Center X coordinate (default: 89)', type=int, default=89) 703 | p.add_argument('--cy', help='Center Y coordinate (default: 121)', type=int, default=121) 704 | 705 | p = add_command('create_celeba_hq', 'Create HDF5 dataset for CelebA-HQ.', 706 | 'create_celeba_hq celeba-hq-1024x1024.h5 ~/celeba ~/celeba-hq-deltas') 707 | p.add_argument('h5_filename', help='HDF5 file to create') 708 | p.add_argument('celeba_dir', help='Directory to read CelebA data from') 709 | p.add_argument('delta_dir', help='Directory to read CelebA-HQ deltas from') 710 | p.add_argument('--num_threads', help='Number of concurrent threads (default: 4)', type=int, default=4) 711 | p.add_argument('--num_tasks', help='Number of concurrent processing tasks (default: 100)', type=int, default=100) 712 | 713 | args = parser.parse_args(argv[1:]) 714 | func = globals()[args.command] 715 | del args.command 716 | func(**vars(args)) 717 | 718 | 719 | # ---------------------------------------------------------------------------- 720 | 721 | if __name__ == "__main__": 722 | execute_cmdline(sys.argv) -------------------------------------------------------------------------------- /images/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangqianhui/progressive_growing_of_gans_tensorflow/e16b097117169e9521104138d4d461b2aef5a5fb/images/figure.png -------------------------------------------------------------------------------- /images/hs_sample_128.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangqianhui/progressive_growing_of_gans_tensorflow/e16b097117169e9521104138d4d461b2aef5a5fb/images/hs_sample_128.jpg -------------------------------------------------------------------------------- 
# --------------------------------------------------------------------------------
# /images/hs_sample_64.jpg:
# https://raw.githubusercontent.com/zhangqianhui/progressive_growing_of_gans_tensorflow/e16b097117169e9521104138d4d461b2aef5a5fb/images/hs_sample_64.jpg
# --------------------------------------------------------------------------------
# /images/sample.png:
# https://raw.githubusercontent.com/zhangqianhui/progressive_growing_of_gans_tensorflow/e16b097117169e9521104138d4d461b2aef5a5fb/images/sample.png
# --------------------------------------------------------------------------------
# /images/sample_128.png:
# https://raw.githubusercontent.com/zhangqianhui/progressive_growing_of_gans_tensorflow/e16b097117169e9521104138d4d461b2aef5a5fb/images/sample_128.png
# --------------------------------------------------------------------------------
# /main.py:
# --------------------------------------------------------------------------------
# Entry point: trains a PGGAN stage by stage on CelebA (or CelebA-HQ).
import os

import tensorflow as tf

from PGGAN import PGGAN
from utils import CelebA, CelebA_HQ, mkdir_p

flags = tf.app.flags

# Default to GPU 0, but respect a device list the caller already exported
# (previously this unconditionally overwrote the environment variable).
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

flags.DEFINE_string("OPER_NAME", "Experiment_6_30_1", "the name of experiments")
flags.DEFINE_integer("OPER_FLAG", 0, "Flag of opertion: 0 is for training ")
flags.DEFINE_string("path", '?', "Path of training data, for example /home/hehe/")
flags.DEFINE_integer("batch_size", 16, "Batch size")
flags.DEFINE_integer("sample_size", 512, "Size of sample")
flags.DEFINE_integer("max_iters", 40000, "Maxmization of training number")
flags.DEFINE_float("learn_rate", 0.001, "Learning rate for G and D networks")
flags.DEFINE_integer("lam_gp", 10, "Weight of gradient penalty term")
flags.DEFINE_float("lam_eps", 0.001, "Weight for the epsilon term")
flags.DEFINE_integer("flag", 11, "FLAG of gan training process")
flags.DEFINE_boolean("use_wscale", True, "Using the scale of weight")
flags.DEFINE_boolean("celeba", True, "Whether using celeba or using CelebA-HQ")

FLAGS = flags.FLAGS


def _stage_schedule(i):
    """Return ``(pg, read_pg, t)`` for training stage ``i`` (0-based).

    ``pg`` is the growth level trained in this stage (PGGAN's output size is
    ``4 * 2 ** (pg - 1)``), ``read_pg`` is the level whose checkpoint
    directory is read, and ``t`` is True on odd stages (the ``t``/transition
    flag passed to PGGAN).

    For i = 0..10 this reproduces the former hard-coded tables
    fl = [1,2,2,3,3,4,4,5,5,6,6] and r_fl = [1,1,2,2,3,3,4,4,5,5,6], and unlike
    them it does not raise IndexError when --flag exceeds 11.
    """
    return (i + 1) // 2 + 1, i // 2 + 1, i % 2 == 1


if __name__ == "__main__":

    root_log_dir = "./output/{}/logs/".format(FLAGS.OPER_NAME)
    mkdir_p(root_log_dir)

    data_In = CelebA(FLAGS.path) if FLAGS.celeba else CelebA_HQ(FLAGS.path)

    print("the num of dataset", len(data_In.image_list))

    if FLAGS.OPER_FLAG == 0:

        for i in range(FLAGS.flag):

            pg, read_pg, t = _stage_schedule(i)
            pggan_checkpoint_dir_write = "./output/{}/model_pggan_{}/{}/".format(FLAGS.OPER_NAME, FLAGS.OPER_FLAG, pg)
            sample_path = "./output/{}/{}/sample_{}_{}".format(FLAGS.OPER_NAME, FLAGS.OPER_FLAG, pg, t)
            mkdir_p(pggan_checkpoint_dir_write)
            mkdir_p(sample_path)
            pggan_checkpoint_dir_read = "./output/{}/model_pggan_{}/{}/".format(FLAGS.OPER_NAME, FLAGS.OPER_FLAG, read_pg)

            pggan = PGGAN(batch_size=FLAGS.batch_size, max_iters=FLAGS.max_iters,
                          model_path=pggan_checkpoint_dir_write, read_model_path=pggan_checkpoint_dir_read,
                          data=data_In, sample_size=FLAGS.sample_size,
                          sample_path=sample_path, log_dir=root_log_dir, learn_rate=FLAGS.learn_rate,
                          lam_gp=FLAGS.lam_gp, lam_eps=FLAGS.lam_eps, PG=pg,
                          t=t, use_wscale=FLAGS.use_wscale, is_celeba=FLAGS.celeba)

            pggan.build_model_PGGan()
            pggan.train()

# --------------------------------------------------------------------------------
# /ops.py:
# --------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import batch_norm
import numpy as np
# Leaky ReLU implementation.
def lrelu(x, alpha=0.2, name="LeakyReLU"):
    """Leaky ReLU computed as elementwise ``max(x, alpha * x)`` (for 0 <= alpha <= 1)."""
    with tf.name_scope(name):
        return tf.maximum(x, alpha * x)


def get_weight(shape, gain=np.sqrt(2), use_wscale=False, fan_in=None):
    """Get the variable named ``'weight'`` in the current variable scope.

    He-style init: ``std = gain / sqrt(fan_in)``, where ``fan_in`` defaults to
    the product of all but the last entry of ``shape``.

    With ``use_wscale`` the variable is initialised from N(0, 1) and multiplied
    by the constant ``std`` at run time (runtime weight scaling); otherwise the
    variable itself is initialised from N(0, std).
    """
    if fan_in is None:
        fan_in = np.prod(shape[:-1])
    # One pre-formatted string prints the same text on Python 2 and 3; the old
    # `print "..."` statement was a SyntaxError on Python 3.
    print("current %s %s" % (shape[:-1], fan_in))
    std = gain / np.sqrt(fan_in)  # He init

    if use_wscale:
        wscale = tf.constant(np.float32(std), name='wscale')
        return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal()) * wscale
    return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal(0, std))


def conv2d(input_, output_dim,
           k_h=3, k_w=3, d_h=2, d_w=2, gain=np.sqrt(2), use_wscale=False, padding='SAME',
           name="conv2d", with_w=False):
    """2-D convolution plus bias on an NHWC tensor.

    ``padding`` is forwarded to ``tf.nn.conv2d`` ('SAME' / 'VALID'), except for
    the special value 'Other', which zero-pads 3 pixels on each spatial side
    and then convolves with 'VALID'.

    Returns the output tensor, or ``(output, weight, biases)`` when ``with_w``.
    """
    with tf.variable_scope(name):
        w = get_weight([k_h, k_w, input_.shape[-1].value, output_dim], gain=gain, use_wscale=use_wscale)
        w = tf.cast(w, input_.dtype)

        if padding == 'Other':
            padding = 'VALID'
            input_ = tf.pad(input_, [[0, 0], [3, 3], [3, 3], [0, 0]], "CONSTANT")

        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding=padding)
        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        # bias_add keeps the shape, so the former reshape-to-own-static-shape
        # was a no-op -- and it failed whenever a dimension (e.g. batch size)
        # was unknown at graph-construction time.
        conv = tf.nn.bias_add(conv, biases)

        if with_w:
            return conv, w, biases
        return conv


def fully_connect(input_, output_size, gain=np.sqrt(2), use_wscale=False, name=None, with_w=False):
    """Dense layer ``input_ @ W + b`` for a rank-2 ``input_``.

    Returns the output tensor, or ``(output, weight, bias)`` when ``with_w``.
    """
    shape = input_.get_shape().as_list()
    with tf.variable_scope(name or "Linear"):
        w = get_weight([shape[1], output_size], gain=gain, use_wscale=use_wscale)
        w = tf.cast(w, input_.dtype)
        bias = tf.get_variable("bias", [output_size], initializer=tf.constant_initializer(0.0))

        output = tf.matmul(input_, w) + bias

        if with_w:
            # Bug fix: this previously returned the boolean `with_w` flag
            # instead of the weight tensor.
            return output, w, bias
        return output


def conv_cond_concat(x, y):
    """Broadcast conditioning tensor ``y`` over x's spatial grid and concat on channels.

    Assumes 4-D NHWC tensors whose y spatial dims broadcast against x's
    (e.g. 1x1) -- NOTE(review): confirm with callers.
    """
    x_shapes = x.get_shape()
    y_shapes = y.get_shape()
    # TF >= 1.0 signature is tf.concat(values, axis); the old
    # tf.concat(axis, values) order used here previously is rejected there.
    return tf.concat([x, y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)


def batch_normal(input, scope="scope", reuse=False):
    """Batch norm (contrib) with in-place moving-average updates (updates_collections=None)."""
    return batch_norm(input, epsilon=1e-5, decay=0.9, scale=True, scope=scope, reuse=reuse, updates_collections=None)


def resize_nearest_neighbor(x, new_size):
    """Nearest-neighbour resize of a 4-D image tensor to ``new_size`` (h, w)."""
    x = tf.image.resize_nearest_neighbor(x, new_size)
    return x


def upscale(x, scale):
    """Nearest-neighbour upsample of an NHWC tensor by an integer ``scale``."""
    _, h, w, _ = get_conv_shape(x)
    return resize_nearest_neighbor(x, (h * scale, w * scale))


def get_conv_shape(tensor):
    """Static shape of ``tensor`` as a list (unknown dims -> -1)."""
    shape = int_shape(tensor)
    return shape


def int_shape(tensor):
    """Static shape of ``tensor`` as a list of ints, with None replaced by -1."""
    shape = tensor.get_shape().as_list()
    return [num if num is not None else -1 for num in shape]


def downscale2d(x, k=2):
    """Average-pool an NHWC tensor by a factor ``k`` (kernel = stride = k, 'VALID')."""
    return tf.nn.avg_pool(x, ksize=[1, k, k, 1], strides=[1, k, k, 1],
                          padding='VALID')


def Pixl_Norm(x, eps=1e-8):
    """Pixelwise feature normalisation: divide by the RMS over the channel axis.

    The channel axis is 3 for rank > 2 inputs (NHWC) and 1 for rank-2 inputs.
    """
    if len(x.shape) > 2:
        axis_ = 3
    else:
        axis_ = 1
    with tf.variable_scope('PixelNorm'):
        return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=axis_, keep_dims=True) + eps)


def MinibatchstateConcat(input, averaging='all'):
    """Append a minibatch standard-deviation feature map to an NHWC tensor.

    The per-feature std is taken over the batch axis. With ``averaging='all'``
    it is reduced to one scalar and tiled to (N, H, W, 1), then concatenated
    as an extra channel.
    NOTE(review): any other ``averaging`` value only prints "nothing"; the
    subsequent tile then scales H and W as well, so the concat shapes do not
    match for H, W > 1.
    """
    s = input.shape
    adjusted_std = lambda x, **kwargs: tf.sqrt(tf.reduce_mean((x - tf.reduce_mean(x, **kwargs)) ** 2, **kwargs) + 1e-8)
    vals = adjusted_std(input, axis=0, keep_dims=True)
    if averaging == 'all':
        vals = tf.reduce_mean(vals, keep_dims=True)
    else:
        print("nothing")

    vals = tf.tile(vals, multiples=[s[0], s[1], s[2], 1])
    return tf.concat([input, vals], axis=3)

# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
import os
import errno
import numpy as np
import scipy
import scipy.misc
import h5py


def mkdir_p(path):
    """Create ``path`` and any missing parents; succeed silently if it is already a directory."""
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        # Only "already exists as a directory" is tolerated; anything else propagates.
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise


class CelebA(object):
    """CelebA dataset represented as a shuffled array of ``.jpg`` file paths."""

    def __init__(self, image_path):

        self.dataname = "CelebA"
        self.channel = 3
        self.image_list = self.load_celebA(image_path=image_path)

    def load_celebA(self, image_path):
        """Return the shuffled image paths found by ``read_image_list(image_path)``."""
        return read_image_list(image_path)

    def getShapeForData(self, filenames, resize_w=64):
        """Load ``filenames`` as an array of images scaled to [-1, 1].

        Each image goes through ``get_image`` with a 128-pixel center crop
        (which also randomly flips it horizontally) and is resized to
        ``resize_w`` x ``resize_w``.
        """
        array = [get_image(batch_file, 128, is_crop=True, resize_w=resize_w,
                           is_grayscale=False) for batch_file in filenames]
        return np.array(array)

    def getNextBatch(self, batch_num=0, batch_size=64):
        """Return the file paths of batch ``batch_num``, cycling over the dataset.

        The path list is reshuffled whenever ``batch_num`` wraps to the start of
        a cycle.
        """
        # Floor division: with "/" Python 3 produced a float here, and the float
        # slice bounds below then raised TypeError. The clamp avoids a
        # ZeroDivisionError for datasets holding fewer than two batches.
        ro_num = max(len(self.image_list) // batch_size - 1, 1)
        if batch_num % ro_num == 0:
            perm = np.arange(len(self.image_list))
            np.random.shuffle(perm)
            self.image_list = np.array(self.image_list)[perm]

            print("images shuffle")

        start = (batch_num % ro_num) * batch_size
        return self.image_list[start: start + batch_size]


class CelebA_HQ(object):
    """CelebA-HQ dataset read from the HDF5 file ``<image_path>/celebA_hq``."""

    def __init__(self, image_path):
        self.dataname = "CelebA_HQ"
        resolution = ['data2x2', 'data4x4', 'data8x8', 'data16x16', 'data32x32', 'data64x64',
                      'data128x128', 'data256x256', 'data512x512', 'data1024x1024']
        self.channel = 3
        self.image_list = self.load_celeba_hq(image_path=image_path)
        self._base_key = 'data'
        # Image count per resolution dataset; every key above must exist in the file.
        self._len = {k: len(self.image_list[k]) for k in resolution}

    def load_celeba_hq(self, image_path):
        """Open the HDF5 file read-only; the handle is kept open on the instance."""
        return h5py.File(os.path.join(image_path, "celebA_hq"), 'r')

    def getNextBatch(self, batch_size=64, resize_w=64):
        """Sample ``batch_size`` random images (with replacement) at ``resize_w`` resolution.

        Each stored value v is mapped to ``v / 127.5 - 1`` as float32 (i.e.
        [0, 255] -> [-1, 1]). The array layout is whatever the HDF5 stores.
        NOTE(review): h5tool.py writes CHW images while PGGAN's placeholder is
        NHWC -- confirm the caller transposes.
        """
        key = self._base_key + '{}x{}'.format(resize_w, resize_w)
        idx = np.random.randint(self._len[key], size=batch_size)
        batch_x = np.array([self.image_list[key][i] / 127.5 - 1.0 for i in idx], dtype=np.float32)

        return batch_x


def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale=False):
    """Read an image file and ``transform`` it to a ``resize_w`` square in [-1, 1]."""
    return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w)


def get_image_dat(image_path, image_size, is_crop=True, resize_w=64, is_grayscale=False):
    """Like ``get_image`` but loads a NumPy array via ``imread_dat``."""
    return transform(imread_dat(image_path, is_grayscale), image_size, is_crop, resize_w)


def transform(image, npx=64, is_crop=False, resize_w=64):
    """Crop/resize ``image`` to ``resize_w`` x ``resize_w`` and scale to [-1, 1].

    npx: side length of the center crop, used only when ``is_crop``.
    """
    if is_crop:
        # center_crop already resizes to resize_w x resize_w.
        cropped_image = center_crop(image, npx, resize_w=resize_w)
    else:
        cropped_image = scipy.misc.imresize(image, [resize_w, resize_w])

    return np.array(cropped_image) / 127.5 - 1


def center_crop(x, crop_h, crop_w=None, resize_w=64):
    """Center-crop ``x`` to ``crop_h`` x ``crop_w`` and resize to ``resize_w`` square.

    NOTE: despite the name, the image is also flipped horizontally with
    probability 0.5 (random augmentation) before cropping.
    """
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    j = int(round((h - crop_h) / 2.))
    i = int(round((w - crop_w) / 2.))

    rate = np.random.uniform(0, 1, size=1)
    if rate < 0.5:
        x = np.fliplr(x)

    return scipy.misc.imresize(x[j:j + crop_h, i:i + crop_w],
                               [resize_w, resize_w])


def save_images(images, size, image_path):
    """Write a batch of [-1, 1] images to ``image_path`` as a ``size`` grid."""
    return imsave(inverse_transform(images), size, image_path)

def imread(path, is_grayscale=False):
    """Read an image file as a float64 array (luminance-flattened to 2-D when ``is_grayscale``)."""
    # np.float was merely an alias of the builtin float (float64) and was
    # removed in NumPy 1.24; np.float64 keeps the exact same dtype.
    if is_grayscale:
        return scipy.misc.imread(path, flatten=True).astype(np.float64)
    return scipy.misc.imread(path).astype(np.float64)


def imread_dat(path, is_grayscale):
    """Load a NumPy array saved with ``np.save``; ``is_grayscale`` is ignored."""
    return np.load(path)


def imsave(images, size, path):
    """Tile ``images`` into a grid with ``merge`` and write it to ``path``."""
    return scipy.misc.imsave(path, merge(images, size))


def merge(images, size):
    """Tile a batch of images into a ``size[0]`` x ``size[1]`` grid, row-major.

    images: (N, h, w, c) or (N, h, w) with N <= size[0] * size[1].
    4-D batches give an (H, W, 3) canvas as before (single-channel images are
    broadcast across the 3 channels); 3-D (grayscale) batches give an (H, W)
    canvas. The canvas keeps the input dtype, so uint8 batches are saved as-is
    instead of being min-max rescaled by scipy's float-image handling.
    """
    h, w = images.shape[1], images.shape[2]
    if images.ndim == 3:
        canvas_shape = (h * size[0], w * size[1])
    else:
        canvas_shape = (h * size[0], w * size[1], 3)
    img = np.zeros(canvas_shape, dtype=images.dtype)
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j * h:j * h + h, i * w:i * w + w, ...] = image
    return img


def inverse_transform(image):
    """Map values in [-1, 1] back to uint8 pixels in [0, 255].

    Values are clipped first: casting out-of-range floats to uint8 is not
    well-defined (it typically wraps, turning e.g. 1.01 into a near-black pixel).
    """
    return np.clip((image + 1.) * 127.5, 0, 255).astype(np.uint8)


def read_image_list(category):
    """Return the ``.jpg`` files directly inside directory ``category`` in random order.

    Matching is on a case-insensitive ``.jpg`` extension (the former substring
    test also accepted names such as ``x.jpg.txt``). The listing is sorted
    before shuffling so a seeded RNG gives a reproducible order.
    Returns a numpy array of paths.
    """
    print("list file")
    filenames = [os.path.join(category, name)
                 for name in sorted(os.listdir(category))
                 if name.lower().endswith('.jpg')]
    print("list file ending!")
    filenames = np.array(filenames)
    return filenames[np.random.permutation(len(filenames))]