├── LICENSE
├── PGGAN.py
├── README.md
├── download.py
├── h5tool.py
├── images
├── figure.png
├── hs_sample_128.jpg
├── hs_sample_64.jpg
├── sample.png
└── sample_128.png
├── main.py
├── ops.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 JiChao Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/PGGAN.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from ops import lrelu, conv2d, fully_connect, upscale, Pixl_Norm, downscale2d, MinibatchstateConcat
3 | from utils import save_images
4 | import numpy as np
5 | from scipy.ndimage.interpolation import zoom
6 |
class PGGAN(object):
    """Progressive-growing GAN for a single resolution stage, built with the TF1 graph API.

    One instance covers stage ``PG``, whose output resolution is ``4 * 2 ** (PG - 1)``.
    It can optionally run the stage's fade-in ("transition") phase ``t``.
    The losses are:
      - WGAN critic/generator losses;
      - plus a WGAN-GP gradient penalty (``lam_gp``);
      - plus a ``lam_eps * mean(D(real) ** 2)`` penalty on the real logits.
    """

    # build model
    def __init__(self, batch_size, max_iters, model_path, read_model_path, data, sample_size, sample_path, log_dir,
                 learn_rate, lam_gp, lam_eps, PG, t, use_wscale, is_celeba):
        """Store hyper-parameters and create the input placeholders.

        Args:
            batch_size: images per step; the graph is built for this fixed batch size.
            max_iters: training iterations for this stage; also the denominator of the
                fade-in coefficient ``alpha_tra``.
            model_path: checkpoint path written by ``train``.
            read_model_path: checkpoint restored by ``train`` for stages other than 1 and 7.
            data: batch provider. It must expose:
                - ``channel`` and ``getNextBatch``;
                - ``getShapeForData`` as well, when ``is_celeba`` is true.
            sample_size: length of the latent vector ``z``.
            sample_path: directory for the sample images written during training.
            log_dir: TensorBoard summary directory.
            learn_rate: Adam learning rate for both networks.
            lam_gp: weight of the gradient penalty.
            lam_eps: weight of the squared-real-logit penalty.
            PG: stage index (1 -> 4x4, 2 -> 8x8, ...).
            t: whether this stage is a fade-in (transition) stage.
            use_wscale: forwarded to the ``conv2d`` / ``fully_connect`` ops.
            is_celeba: selects the CelebA batch-loading path in ``train``.
        """
        self.batch_size = batch_size
        self.max_iters = max_iters
        self.gan_model_path = model_path
        self.read_model_path = read_model_path
        self.data_In = data
        self.sample_size = sample_size
        self.sample_path = sample_path
        self.log_dir = log_dir
        self.learning_rate = learn_rate
        self.lam_gp = lam_gp
        self.lam_eps = lam_eps
        self.pg = PG
        self.trans = t
        self.log_vars = []
        self.channel = self.data_In.channel
        self.output_size = 4 * pow(2, PG - 1)
        self.use_wscale = use_wscale
        self.is_celeba = is_celeba
        self.images = tf.placeholder(tf.float32, [batch_size, self.output_size, self.output_size, self.channel])
        self.z = tf.placeholder(tf.float32, [self.batch_size, self.sample_size])
        # Fade-in coefficient; updated from the training step inside ``train``.
        self.alpha_tra = tf.Variable(initial_value=0.0, trainable=False, name='alpha_tra')

    @staticmethod
    def _count_params(variables):
        """Print each variable's name and shape and return the total number of elements."""
        total = 0
        for variable in variables:
            shape = variable.get_shape()
            print (variable.name, shape)
            count = 1
            for dim in shape:
                count *= dim.value
            total += count
        return total

    def build_model_PGGan(self):
        """Build the generator, discriminator, losses, variable lists, savers and summaries."""
        self.fake_images = self.generate(self.z, pg=self.pg, t=self.trans, alpha_trans=self.alpha_tra)
        _, self.D_pro_logits = self.discriminate(self.images, reuse=False, pg=self.pg, t=self.trans,
                                                 alpha_trans=self.alpha_tra)
        _, self.G_pro_logits = self.discriminate(self.fake_images, reuse=True, pg=self.pg, t=self.trans,
                                                 alpha_trans=self.alpha_tra)

        # the definition of the (WGAN) losses for D and G
        self.D_loss = tf.reduce_mean(self.G_pro_logits) - tf.reduce_mean(self.D_pro_logits)
        self.G_loss = -tf.reduce_mean(self.G_pro_logits)

        # gradient penalty from WGAN-GP, evaluated at random points between real and fake images
        self.differences = self.fake_images - self.images
        self.alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
        interpolates = self.images + (self.alpha * self.differences)
        _, discri_logits = self.discriminate(interpolates, reuse=True, pg=self.pg, t=self.trans,
                                             alpha_trans=self.alpha_tra)
        gradients = tf.gradients(discri_logits, [interpolates])[0]

        # per-sample L2 norm of the gradient
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2, 3]))
        self.gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
        tf.summary.scalar("gp_loss", self.gradient_penalty)

        self.D_origin_loss = self.D_loss
        self.D_loss += self.lam_gp * self.gradient_penalty
        # keeps the real logits close to zero, weighted by lam_eps
        self.D_loss += self.lam_eps * tf.reduce_mean(tf.square(self.D_pro_logits))

        self.log_vars.append(("generator_loss", self.G_loss))
        self.log_vars.append(("discriminator_loss", self.D_loss))

        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'dis' in var.name]
        print ("The total para of D", self._count_params(self.d_vars))

        self.g_vars = [var for var in t_vars if 'gen' in var.name]
        print ("The total para of G", self._count_params(self.g_vars))

        # Layer names embed their resolution; those tagged with the current output size
        # are the layers newly added in this stage.
        res_tag = '{}'.format(self.output_size)

        # save the variables, which remain unchanged
        self.d_vars_n = [var for var in self.d_vars if 'dis_n' in var.name]
        self.g_vars_n = [var for var in self.g_vars if 'gen_n' in var.name]

        # remove the new variables for the new model
        self.d_vars_n_read = [var for var in self.d_vars_n if res_tag not in var.name]
        self.g_vars_n_read = [var for var in self.g_vars_n if res_tag not in var.name]

        # save the rgb variables, which remain unchanged
        self.d_vars_n_2 = [var for var in self.d_vars if 'dis_y_rgb_conv' in var.name]
        self.g_vars_n_2 = [var for var in self.g_vars if 'gen_y_rgb_conv' in var.name]

        self.d_vars_n_2_rgb = [var for var in self.d_vars_n_2 if res_tag not in var.name]
        self.g_vars_n_2_rgb = [var for var in self.g_vars_n_2 if res_tag not in var.name]

        print ("d_vars", len(self.d_vars))
        print ("g_vars", len(self.g_vars))

        print ("self.d_vars_n_read", len(self.d_vars_n_read))
        print ("self.g_vars_n_read", len(self.g_vars_n_read))

        print ("d_vars_n_2_rgb", len(self.d_vars_n_2_rgb))
        print ("g_vars_n_2_rgb", len(self.g_vars_n_2_rgb))

        self.g_d_w = [var for var in self.d_vars + self.g_vars if 'bias' not in var.name]

        print ("self.g_d_w", len(self.g_d_w))

        self.saver = tf.train.Saver(self.d_vars + self.g_vars)
        self.r_saver = tf.train.Saver(self.d_vars_n_read + self.g_vars_n_read)

        # Only created when there are earlier-resolution RGB layers to restore.
        if len(self.d_vars_n_2_rgb + self.g_vars_n_2_rgb):
            self.rgb_saver = tf.train.Saver(self.d_vars_n_2_rgb + self.g_vars_n_2_rgb)

        for k, v in self.log_vars:
            tf.summary.scalar(k, v)

    # do train
    def train(self):
        """Train this stage and save the final checkpoint.

        - For stages other than 1 and 7, first restore ``read_model_path``: only the
          unchanged layers during a transition stage, every variable otherwise.
        - Each iteration runs one D step and one G step (Adam), then writes one summary.
        - Every 400 steps, print the losses and save real/fake sample grids.
        - Every 4000 steps, and at the end, save a checkpoint.
        - On exit, reset the default graph.
        """
        step_pl = tf.placeholder(tf.float32, shape=None)
        # Fade-in coefficient grows linearly with the step count.
        alpha_tra_assign = self.alpha_tra.assign(step_pl / self.max_iters)

        opti_D = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=0.0, beta2=0.99).minimize(
            self.D_loss, var_list=self.d_vars)
        opti_G = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=0.0, beta2=0.99).minimize(
            self.G_loss, var_list=self.g_vars)

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        # Sample grids are 2 rows; integer division keeps the column count an int on Python 3.
        half_batch = self.batch_size // 2

        with tf.Session(config=config) as sess:
            sess.run(init)
            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            if self.pg != 1 and self.pg != 7:
                if self.trans:
                    self.r_saver.restore(sess, self.read_model_path)
                    self.rgb_saver.restore(sess, self.read_model_path)
                else:
                    self.saver.restore(sess, self.read_model_path)

            n_critic = 1  # discriminator updates per generator update
            step = 0
            batch_num = 0
            while step <= self.max_iters:
                # optimization D
                for _ in range(n_critic):
                    sample_z = np.random.normal(size=[self.batch_size, self.sample_size])
                    if self.is_celeba:
                        train_list = self.data_In.getNextBatch(batch_num, self.batch_size)
                        realbatch_array = self.data_In.getShapeForData(train_list, resize_w=self.output_size)
                    else:
                        realbatch_array = self.data_In.getNextBatch(self.batch_size, resize_w=self.output_size)
                        # net axis permutation (0, 2, 3, 1); presumably NCHW -> NHWC, TODO confirm
                        realbatch_array = np.transpose(realbatch_array, axes=[0, 3, 2, 1]).transpose([0, 2, 1, 3])

                    if self.trans and self.pg != 0:
                        # float() instead of np.float, which NumPy >= 1.24 no longer provides
                        alpha = float(step) / self.max_iters
                        # blend real images with a down- then up-sampled copy during fade-in
                        low_realbatch_array = zoom(realbatch_array, zoom=[1, 0.5, 0.5, 1], mode='nearest')
                        low_realbatch_array = zoom(low_realbatch_array, zoom=[1, 2, 2, 1], mode='nearest')
                        realbatch_array = alpha * realbatch_array + (1 - alpha) * low_realbatch_array

                    sess.run(opti_D, feed_dict={self.images: realbatch_array, self.z: sample_z})
                    batch_num += 1

                # optimization G
                sess.run(opti_G, feed_dict={self.z: sample_z})

                summary_str = sess.run(summary_op, feed_dict={self.images: realbatch_array, self.z: sample_z})
                # one summary event per step (it used to be written twice)
                summary_writer.add_summary(summary_str, step)
                # the alpha of fake_in process
                sess.run(alpha_tra_assign, feed_dict={step_pl: step})

                if step % 400 == 0:
                    D_loss, G_loss, D_origin_loss, alpha_tra = sess.run(
                        [self.D_loss, self.G_loss, self.D_origin_loss, self.alpha_tra],
                        feed_dict={self.images: realbatch_array, self.z: sample_z})
                    print("PG %d, step %d: D loss=%.7f G loss=%.7f, D_or loss=%.7f, opt_alpha_tra=%.7f" % (
                        self.pg, step, D_loss, G_loss, D_origin_loss, alpha_tra))

                    realbatch_array = np.clip(realbatch_array, -1, 1)
                    save_images(realbatch_array[0:self.batch_size], [2, half_batch],
                                '{}/{:02d}_real.jpg'.format(self.sample_path, step))

                    if self.trans and self.pg != 0:
                        low_realbatch_array = np.clip(low_realbatch_array, -1, 1)
                        save_images(low_realbatch_array[0:self.batch_size], [2, half_batch],
                                    '{}/{:02d}_real_lower.jpg'.format(self.sample_path, step))

                    fake_image = sess.run(self.fake_images,
                                          feed_dict={self.images: realbatch_array, self.z: sample_z})
                    fake_image = np.clip(fake_image, -1, 1)
                    save_images(fake_image[0:self.batch_size], [2, half_batch],
                                '{}/{:02d}_train.jpg'.format(self.sample_path, step))

                if np.mod(step, 4000) == 0 and step != 0:
                    self.saver.save(sess, self.gan_model_path)

                step += 1

            save_path = self.saver.save(sess, self.gan_model_path)
            print ("Model saved in file: %s" % save_path)

        # Must run outside the session context (the session's graph is still entered there).
        tf.reset_default_graph()

    def discriminate(self, conv, reuse=False, pg=1, t=False, alpha_trans=0.01):
        """Discriminator for stage ``pg``; returns ``(sigmoid(logits), logits)``.

        When ``t`` is true, the output of the first block is blended with a
        downscaled-input path: ``alpha_trans * new + (1 - alpha_trans) * old``.
        """
        with tf.variable_scope("discriminator") as scope:

            if reuse:
                scope.reuse_variables()
            if t:
                conv_iden = downscale2d(conv)
                # fromRGB for the previous (lower) resolution
                conv_iden = lrelu(conv2d(conv_iden, output_dim=self.get_nf(pg - 2), k_w=1, k_h=1, d_h=1, d_w=1,
                                         use_wscale=self.use_wscale,
                                         name='dis_y_rgb_conv_{}'.format(conv_iden.shape[1])))
            # fromRGB
            conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 1), k_w=1, k_h=1, d_w=1, d_h=1,
                                use_wscale=self.use_wscale, name='dis_y_rgb_conv_{}'.format(conv.shape[1])))

            for i in range(pg - 1):
                conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 1 - i), d_h=1, d_w=1,
                                    use_wscale=self.use_wscale, name='dis_n_conv_1_{}'.format(conv.shape[1])))
                conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 2 - i), d_h=1, d_w=1,
                                    use_wscale=self.use_wscale, name='dis_n_conv_2_{}'.format(conv.shape[1])))
                conv = downscale2d(conv)
                if i == 0 and t:
                    conv = alpha_trans * conv + (1 - alpha_trans) * conv_iden

            conv = MinibatchstateConcat(conv)
            conv = lrelu(
                conv2d(conv, output_dim=self.get_nf(1), k_w=3, k_h=3, d_h=1, d_w=1, use_wscale=self.use_wscale,
                       name='dis_n_conv_1_{}'.format(conv.shape[1])))
            conv = lrelu(
                conv2d(conv, output_dim=self.get_nf(1), k_w=4, k_h=4, d_h=1, d_w=1, use_wscale=self.use_wscale,
                       padding='VALID', name='dis_n_conv_2_{}'.format(conv.shape[1])))
            conv = tf.reshape(conv, [self.batch_size, -1])

            # for D
            output = fully_connect(conv, output_size=1, use_wscale=self.use_wscale, gain=1, name='dis_n_fully')

            return tf.nn.sigmoid(output), output

    def generate(self, z_var, pg=1, t=False, alpha_trans=0.0):
        """Generator for stage ``pg``; maps ``z_var`` to a ``[batch, res, res, self.channel]`` image tensor.

        When ``t`` is true and ``pg > 1``, the result blends the upscaled toRGB output
        of the previous resolution with the new one:
        ``(1 - alpha_trans) * old + alpha_trans * new``.
        """
        with tf.variable_scope('generator') as scope:

            de = tf.reshape(Pixl_Norm(z_var), [self.batch_size, 1, 1, int(self.get_nf(1))])
            de = conv2d(de, output_dim=self.get_nf(1), k_h=4, k_w=4, d_w=1, d_h=1, use_wscale=self.use_wscale,
                        gain=np.sqrt(2) / 4, padding='Other', name='gen_n_1_conv')
            de = Pixl_Norm(lrelu(de))
            de = tf.reshape(de, [self.batch_size, 4, 4, int(self.get_nf(1))])
            de = conv2d(de, output_dim=self.get_nf(1), d_w=1, d_h=1, use_wscale=self.use_wscale, name='gen_n_2_conv')
            de = Pixl_Norm(lrelu(de))

            for i in range(pg - 1):
                if i == pg - 2 and t:
                    # To RGB at the previous resolution; channel count matches the real images.
                    de_iden = conv2d(de, output_dim=self.channel, k_w=1, k_h=1, d_w=1, d_h=1,
                                     use_wscale=self.use_wscale, name='gen_y_rgb_conv_{}'.format(de.shape[1]))
                    de_iden = upscale(de_iden, 2)

                de = upscale(de, 2)
                de = Pixl_Norm(lrelu(
                    conv2d(de, output_dim=self.get_nf(i + 1), d_w=1, d_h=1, use_wscale=self.use_wscale,
                           name='gen_n_conv_1_{}'.format(de.shape[1]))))
                de = Pixl_Norm(lrelu(
                    conv2d(de, output_dim=self.get_nf(i + 1), d_w=1, d_h=1, use_wscale=self.use_wscale,
                           name='gen_n_conv_2_{}'.format(de.shape[1]))))

            # To RGB (previously hard-coded to 3 channels, which broke non-RGB data)
            de = conv2d(de, output_dim=self.channel, k_w=1, k_h=1, d_w=1, d_h=1, use_wscale=self.use_wscale,
                        gain=1, name='gen_y_rgb_conv_{}'.format(de.shape[1]))

            if pg == 1:
                return de
            if t:
                de = (1 - alpha_trans) * de_iden + alpha_trans * de

            return de

    def get_nf(self, stage):
        """Return the number of feature maps for ``stage``: ``1024 / 2 ** stage``, capped at 512.

        Floor division keeps the result an ``int`` on Python 3 too (e.g. stage 3 -> 128).
        """
        return min(1024 // (2 ** stage), 512)

    def sample_z(self, mu, log_var):
        """Return ``mu + exp(log_var / 2) * eps`` with ``eps`` drawn from a standard normal."""
        eps = tf.random_normal(shape=tf.shape(mu))
        return mu + tf.exp(log_var / 2) * eps
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PGGAN-tensorflow
2 | The tensorflow implementation of [PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION](https://arxiv.org/abs/1710.10196).
3 |
4 | ### The generative process of PG-GAN
5 |
6 |
7 |
8 |
9 |
10 | ## Differences with the original paper.
11 |
12 | - Currently, only 64x64 and 128x128 pixel samples are generated.
13 |
14 | ## Setup
15 |
16 | ### Prerequisites
17 |
18 | - TensorFlow >= 1.4
19 | - python 2.7 or 3
20 |
21 | ### Getting Started
22 | - Clone this repo:
23 | ```bash
24 | git clone https://github.com/zhangqianhui/progressive_growing_of_gans_tensorflow.git
25 | cd progressive_growing_of_gans_tensorflow
26 | ```
27 | - Download the CelebA dataset
28 |
29 | You can download the [CelebA dataset](https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AAB06FXaQRUNtjW9ntaoPGvCa?dl=0)
30 | and unzip CelebA into a directory. Note that this directory must not contain any sub-directories.
31 |
32 | - The method for creating CelebA-HQ can be found on [Method](https://github.com/github-pengge/PyTorch-progressive_growing_of_gans#how-to-create-celeba-hq-dataset)
33 |
34 | - Train the model on CelebA dataset
35 |
36 | ```bash
37 | python main.py --path=/path/to/celeba --celeba=True
38 | ```
39 |
40 | - Train the model on CelebA-HQ dataset
41 |
42 | ```bash
43 | python main.py --path=/path/to/celeba-hq --celeba=False
44 | ```
45 |
46 | ## Results on celebA dataset
47 | Here are the generated 64x64 results (left: generated; right: real):
48 |
49 |
50 |
51 |
52 |
53 | Here are the generated 128x128 results (left: generated; right: real):
54 |
55 |
56 |
57 |
58 |
59 | ## Results on CelebA-HQ dataset
60 | Here are the generated 64x64 results (left: real; right: generated):
61 |
62 |
63 |
64 |
65 |
66 | Here are the generated 128x128 results (left: real; right: generated):
67 |
68 |
69 |
70 |
71 | ## Issue
72 | If you find a bug, please report it by opening an issue.
73 |
74 | ## Reference code
75 |
76 | [PGGAN Theano](https://github.com/tkarras/progressive_growing_of_gans)
77 |
78 | [PGGAN Pytorch](https://github.com/github-pengge/PyTorch-progressive_growing_of_gans)
79 |
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import json
5 | import zipfile
6 | import argparse
7 | import subprocess
8 | from six.moves import urllib
9 |
# Command-line parser for choosing which dataset(s) to fetch (currently unused by __main__).
parser = argparse.ArgumentParser(description='Download dataset for DCGAN.')
parser.add_argument(
    'datasets',
    metavar='N',
    type=str,
    nargs='+',
    choices=['celebA', 'lsun', 'mnist'],
    default='mnist',
    help='name of dataset to download [celebA, lsun, mnist]',
)
13 |
def download(url, dirpath):
    """Download ``url`` into directory ``dirpath`` and return the written file's path.

    The file name is the last component of the URL path. Any query string or fragment
    is dropped, e.g. ``.../img_align_celeba.zip?dl=1`` -> ``img_align_celeba.zip``.
    A progress bar is printed to stdout. When the server sends no ``Content-Length``
    (or a zero length), only the number of bytes received is shown.
    """
    filename = urllib.parse.urlsplit(url).path.rsplit('/', 1)[-1] or 'download'
    filepath = os.path.join(dirpath, filename)
    u = urllib.request.urlopen(url)
    try:
        length = u.headers.get("Content-Length")
        filesize = int(length) if length else None
        print("Downloading: %s Bytes: %s" % (filename, filesize))

        downloaded = 0
        block_sz = 8192
        status_width = 70
        # `with` guarantees the partially written file is closed even if reading fails.
        with open(filepath, 'wb') as f:
            while True:
                buf = u.read(block_sz)
                if not buf:
                    print('')
                    break
                print('', end='\r')
                downloaded += len(buf)
                f.write(buf)
                if filesize:
                    status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") %
                              ('=' * int(float(downloaded) / filesize * status_width) + '>',
                               downloaded * 100. / filesize))
                else:
                    # Unknown total size: a percentage cannot be computed.
                    status = "%d bytes" % downloaded
                print(status, end='')
                sys.stdout.flush()
    finally:
        u.close()
    return filepath
41 |
def unzip(filepath):
    """Extract the zip archive at ``filepath`` into its own directory, then delete the archive."""
    print("Extracting: " + filepath)
    with zipfile.ZipFile(filepath) as archive:
        archive.extractall(os.path.dirname(filepath))
    os.remove(filepath)
48 |
def download_celeb_a(dirpath):
    """Fetch the aligned CelebA images into ``<dirpath>/celebA`` unless that path already exists.

    The first entry of the downloaded archive is treated as its top-level folder and
    renamed to ``celebA``. The zip file is deleted after extraction.
    """
    target = os.path.join(dirpath, 'celebA')
    if os.path.exists(target):
        print('Found Celeb-A - skip')
        return
    archive_path = download(
        'https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=1&pv=1',
        dirpath)
    with zipfile.ZipFile(archive_path) as archive:
        top_level = archive.namelist()[0]
        archive.extractall(dirpath)
    os.remove(archive_path)
    os.rename(os.path.join(dirpath, top_level), target)
62 |
def _list_categories(tag):
    """Return the LSUN category list for release ``tag``, parsed from the server's JSON reply."""
    import contextlib
    url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag
    # closing() releases the HTTP response (previously left open); it works on Python 2 and 3.
    with contextlib.closing(urllib.request.urlopen(url)) as f:
        return json.loads(f.read())
67 |
def _download_lsun(out_dir, category, set_name, tag):
    """Download one LSUN lmdb zip into ``out_dir`` by running ``curl``.

    The test split is saved as ``test_lmdb.zip``. Other splits are saved as
    ``<category>_<set_name>_lmdb.zip``. curl's exit status is not checked.
    """
    url = ('http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}'
           '&category={category}&set={set_name}').format(tag=tag, category=category, set_name=set_name)
    print(url)
    out_name = ('test_lmdb.zip' if set_name == 'test'
                else '{category}_{set_name}_lmdb.zip'.format(category=category, set_name=set_name))
    out_path = os.path.join(out_dir, out_name)
    print('Downloading', category, set_name, 'set')
    subprocess.call(['curl', url, '-o', out_path])
80 |
def download_lsun(dirpath):
    """Download LSUN into ``<dirpath>/lsun``, doing nothing if that folder already exists.

    Fetches the bedroom train and val splits and the shared test split.
    """
    data_dir = os.path.join(dirpath, 'lsun')
    if os.path.exists(data_dir):
        print('Found LSUN - skip')
        return
    os.mkdir(data_dir)

    tag = 'latest'
    # The full list is available via _list_categories(tag); only bedrooms are used here.
    for category in ['bedroom']:
        for split in ('train', 'val'):
            _download_lsun(data_dir, category, split, tag)
    _download_lsun(data_dir, '', 'test', tag)
97 |
def download_mnist(dirpath):
    """Download the four MNIST idx files into ``<dirpath>/mnist`` and gunzip them.

    Does nothing if ``<dirpath>/mnist`` already exists. Uses the ``curl`` and ``gzip``
    command-line tools and does not check their exit status.
    """
    data_dir = os.path.join(dirpath, 'mnist')
    if os.path.exists(data_dir):
        print('Found MNIST - skip')
        return
    os.mkdir(data_dir)
    url_base = 'http://yann.lecun.com/exdb/mnist/'
    file_names = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
                  't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
    for file_name in file_names:
        url = url_base + file_name
        print(url)
        out_path = os.path.join(data_dir, file_name)
        print('Downloading ', file_name)
        subprocess.call(['curl', url, '-o', out_path])
        print('Decompressing ', file_name)
        subprocess.call(['gzip', '-d', out_path])
117 |
def prepare_data_dir(path='./data'):
    """Create directory ``path`` (a single level) unless it already exists."""
    if os.path.exists(path):
        return
    os.mkdir(path)
121 |
if __name__ == '__main__':
    # Dataset selection through `parser` is currently disabled: CelebA is always fetched.
    # To re-enable it, parse the args and dispatch on `args.datasets`
    # to download_celeb_a / download_lsun / download_mnist.
    prepare_data_dir()
    download_celeb_a('./data')
--------------------------------------------------------------------------------
/h5tool.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import io
4 | import glob
5 | import pickle
6 | import argparse
7 | import threading
8 | import Queue
9 | import traceback
10 | import numpy as np
11 | import scipy.ndimage
12 | import PIL.Image
13 | import h5py # conda install h5py
14 |
15 | # ----------------------------------------------------------------------------
16 |
class HDF5Exporter:
    """Write square images to an HDF5 file as a pyramid of levels of detail (LODs).

    For every power-of-two size ``r`` from ``resolution`` down to 1, a gzip-compressed
    uint8 dataset ``data{r}x{r}`` of shape ``(N, channels, r, r)`` is created.
    Images are staged in per-LOD in-memory buffers and appended to the datasets in batches.
    Call ``close`` to flush the remaining buffered images.
    """

    def __init__(self, h5_filename, resolution, channels=3):
        rlog2 = int(np.floor(np.log2(resolution)))
        assert resolution == 2 ** rlog2
        self.resolution = resolution
        self.channels = channels
        self.h5_file = h5py.File(h5_filename, 'w')
        self.h5_lods = []       # datasets, highest resolution first
        self.buffers = []       # per-LOD staging arrays
        self.buffer_sizes = []  # number of valid rows in each staging array
        c = channels
        # range (not xrange) so this also runs on Python 3.
        for level in range(rlog2, -1, -1):
            r = 2 ** level
            bytes_per_item = c * (r ** 2)
            chunk_size = int(np.ceil(128.0 / bytes_per_item))
            # Staging buffer sized to hold about 512 MiB of images for this LOD.
            buffer_size = int(np.ceil(512.0 * np.exp2(20) / bytes_per_item))
            dataset = self.h5_file.create_dataset('data%dx%d' % (r, r), shape=(0, c, r, r), dtype=np.uint8,
                                                  maxshape=(None, c, r, r), chunks=(chunk_size, c, r, r),
                                                  compression='gzip', compression_opts=4)
            self.h5_lods.append(dataset)
            self.buffers.append(np.zeros((buffer_size, c, r, r), dtype=np.uint8))
            self.buffer_sizes.append(0)

    def close(self):
        """Flush every staging buffer and close the HDF5 file."""
        for lod in range(len(self.h5_lods)):
            self.flush_lod(lod)
        self.h5_file.close()

    def add_images(self, img):
        """Append a batch of images shaped ``(N, channels, R, R)``.

        ``R`` must be a power of two that is at least ``resolution``. Each LOD is derived
        by repeated 2x2 averaging, then rounded and clipped to uint8.
        """
        assert img.ndim == 4 and img.shape[1] == self.channels and img.shape[2] == img.shape[3]
        assert img.shape[2] >= self.resolution and img.shape[2] == 2 ** int(np.floor(np.log2(img.shape[2])))
        for lod in range(len(self.h5_lods)):
            target = self.resolution // (2 ** lod)
            while img.shape[2] > target:
                img = img.astype(np.float32)
                # 2x2 box filter: average each non-overlapping 2x2 block.
                img = (img[:, :, 0::2, 0::2] + img[:, :, 0::2, 1::2] +
                       img[:, :, 1::2, 0::2] + img[:, :, 1::2, 1::2]) * 0.25
            quant = np.uint8(np.clip(np.round(img), 0, 255))
            ofs = 0
            while ofs < quant.shape[0]:
                # Copy as much as fits into the staging buffer; flush it when full.
                num = min(quant.shape[0] - ofs, self.buffers[lod].shape[0] - self.buffer_sizes[lod])
                self.buffers[lod][self.buffer_sizes[lod]: self.buffer_sizes[lod] + num] = quant[ofs: ofs + num]
                self.buffer_sizes[lod] += num
                if self.buffer_sizes[lod] == self.buffers[lod].shape[0]:
                    self.flush_lod(lod)
                ofs += num

    def num_images(self):
        """Return the number of images added so far (written plus still buffered)."""
        return self.h5_lods[0].shape[0] + self.buffer_sizes[0]

    def flush_lod(self, lod):
        """Append the buffered images of LOD index ``lod`` to its dataset and empty the buffer."""
        num = self.buffer_sizes[lod]
        if num > 0:
            self.h5_lods[lod].resize(self.h5_lods[lod].shape[0] + num, axis=0)
            self.h5_lods[lod][-num:] = self.buffers[lod][:num]
        self.buffer_sizes[lod] = 0
72 |
73 |
74 | # ----------------------------------------------------------------------------
75 |
class ExceptionInfo(object):
    """Snapshot of the exception currently being handled.

    Captures the exception type, value and formatted traceback, so the error can be
    passed from a worker thread back to the caller.
    """

    def __init__(self):
        exc_type, exc_value = sys.exc_info()[:2]
        self.type = exc_type
        self.value = exc_value
        self.traceback = traceback.format_exc()
80 |
81 |
82 | # ----------------------------------------------------------------------------
83 |
class WorkerThread(threading.Thread):
    """Thread that executes ``(func, args, result_queue)`` tasks taken from a shared queue.

    A task whose ``func`` is ``None`` stops the thread. Every other task posts
    ``(result, args)`` to its own result queue. If ``func`` raised, ``result`` is an
    ``ExceptionInfo``.
    """

    def __init__(self, task_queue):
        threading.Thread.__init__(self)
        self.task_queue = task_queue

    def run(self):
        while True:
            func, args, result_queue = self.task_queue.get()
            if func is None:
                return
            try:
                outcome = func(*args)
            except:  # deliberately catch everything so the consumer always receives a result
                outcome = ExceptionInfo()
            result_queue.put((outcome, args))
99 |
100 |
101 | # ----------------------------------------------------------------------------
102 |
class ThreadPool(object):
    """Fixed-size pool of daemon ``WorkerThread``s that share one task queue.

    Results are routed to a separate result queue per submitted function, so callers
    retrieve them with ``get_result(func)``. Leaving a ``with`` block (or calling
    ``finish``) sends every worker a stop sentinel.
    """

    def __init__(self, num_threads):
        assert num_threads >= 1
        self.task_queue = Queue.Queue()
        self.result_queues = dict()
        self.num_threads = num_threads
        # range instead of xrange so the pool also works on Python 3.
        for _ in range(self.num_threads):
            thread = WorkerThread(self.task_queue)
            thread.daemon = True
            thread.start()

    def add_task(self, func, args=()):
        """Queue ``func(*args)`` for execution on a worker thread."""
        assert hasattr(func, '__call__')  # must be a function
        if func not in self.result_queues:
            self.result_queues[func] = Queue.Queue()
        self.task_queue.put((func, args, self.result_queues[func]))

    def get_result(self, func, verbose_exceptions=True):  # returns (result, args)
        """Block until one task of ``func`` finishes and return ``(result, args)``.

        Raises:
            Exception: if the task raised. The message contains the original type and
                value, and the traceback is printed first when ``verbose_exceptions`` is true.
        """
        result, args = self.result_queues[func].get()
        if isinstance(result, ExceptionInfo):
            if verbose_exceptions:
                print('\n\nWorker thread caught an exception:\n' + result.traceback + '\n')
            raise Exception('%s, %s' % (result.type, result.value))
        return result, args

    def finish(self):
        """Send one stop sentinel per worker thread."""
        for _ in range(self.num_threads):
            self.task_queue.put((None, (), None))

    def __enter__(self):  # for 'with' statement
        return self

    def __exit__(self, *excinfo):
        self.finish()

    def process_items_concurrently(self, item_iterator, process_func=lambda x: x, pre_func=lambda x: x,
                                   post_func=lambda x: x, max_items_in_flight=None):
        """Yield ``post_func(process_func(pre_func(item)))`` for every item, in input order.

        ``pre_func`` and ``post_func`` run on the calling thread; ``process_func`` runs on
        the workers. The number of unretired items is bounded by ``max_items_in_flight``
        (default: 4 per thread).
        """
        if max_items_in_flight is None:
            max_items_in_flight = self.num_threads * 4
        assert max_items_in_flight >= 1
        # Marks slots whose task has not finished yet. A dedicated sentinel is needed
        # because process_func may legitimately return None; previously that left the
        # retire loop stuck and the generator blocked forever.
        pending = object()
        results = []
        retire_idx = [0]  # boxed so the nested generator can advance it (no `nonlocal` on Python 2)

        def task_func(prepared, idx):
            return process_func(prepared)

        def retire_result():
            # Wait for any one task, then yield every result that is now ready, in order.
            processed, (prepared, idx) = self.get_result(task_func)
            results[idx] = processed
            while retire_idx[0] < len(results) and results[retire_idx[0]] is not pending:
                yield post_func(results[retire_idx[0]])
                results[retire_idx[0]] = None  # release the retired result
                retire_idx[0] += 1

        for idx, item in enumerate(item_iterator):
            prepared = pre_func(item)
            results.append(pending)
            self.add_task(func=task_func, args=(prepared, idx))
            while retire_idx[0] < idx - max_items_in_flight + 2:
                for res in retire_result():
                    yield res
        while retire_idx[0] < len(results):
            for res in retire_result():
                yield res
        # task_func is unique to this call; drop its now-empty result queue.
        self.result_queues.pop(task_func, None)
164 |
165 |
166 | # ----------------------------------------------------------------------------
167 |
def inspect(h5_filename):
    """Print a summary of a dataset HDF5 file: size, image count, resolution, channels.

    Prints warnings when the number of ``data*`` LOD datasets, or their image counts,
    look inconsistent.
    """
    print('%-20s%s' % ('HDF5 filename', h5_filename))
    file_size = os.stat(h5_filename).st_size
    print('%-20s%.2f GB' % ('Total size', float(file_size) / np.exp2(30)))

    # .items() works with h5py on both Python 2 and 3 (.iteritems() is Python 2 only);
    # the context manager closes the file even if reading fails.
    with h5py.File(h5_filename, 'r') as h5:
        lods = sorted([value for key, value in h5.items() if key.startswith('data')],
                      key=lambda lod: -lod.shape[3])
        shapes = [lod.shape for lod in lods]
    if not shapes:
        print('Error: No "data*" datasets found in %s' % h5_filename)
        return
    shape = shapes[0]
    print('%-20s%d' % ('Total images', shape[0]))
    print('%-20s%dx%d' % ('Resolution', shape[3], shape[2]))
    print('%-20s%d' % ('Color channels', shape[1]))
    if shape[0] > 0:  # avoid dividing by zero for an empty dataset
        print('%-20s%.2f KB' % ('Size per image', float(file_size) / shape[0] / np.exp2(10)))

    if len(lods) != int(np.log2(shape[3])) + 1:
        print('Warning: The HDF5 file contains incorrect number of LODs')
    if any(s[0] != shape[0] for s in shapes):
        print('Warning: The HDF5 file contains inconsistent number of images in different LODs')
        print('Perhaps the dataset creation script was terminated abruptly?')
188 |
189 |
190 | # ----------------------------------------------------------------------------
191 |
def compare(first_h5, second_h5):
    """Compare the highest-resolution images of two dataset HDF5 files and report differences.

    Reports:
      - mismatched channel counts or resolutions;
      - otherwise, each differing image index among the first ``min(N_a, N_b)`` images,
        plus a summary count.
    """
    print('Comparing %s vs. %s' % (first_h5, second_h5))
    # Both files are closed even if the comparison fails; .items() replaces the
    # Python-2-only .iteritems().
    with h5py.File(first_h5, 'r') as h5_a, h5py.File(second_h5, 'r') as h5_b:
        lods_a = sorted([value for key, value in h5_a.items() if key.startswith('data')],
                        key=lambda lod: -lod.shape[3])
        lods_b = sorted([value for key, value in h5_b.items() if key.startswith('data')],
                        key=lambda lod: -lod.shape[3])
        shape_a = lods_a[0].shape
        shape_b = lods_b[0].shape

        if shape_a[1] != shape_b[1]:
            print('The datasets have different number of color channels: %d vs. %d' % (shape_a[1], shape_b[1]))
        elif shape_a[3] != shape_b[3] or shape_a[2] != shape_b[2]:
            print('The datasets have different resolution: %dx%d vs. %dx%d' % (
                shape_a[3], shape_a[2], shape_b[3], shape_b[2]))
        else:
            min_images = min(shape_a[0], shape_b[0])
            num_diffs = 0
            for idx in range(min_images):
                print('%d / %d\r' % (idx, min_images))
                if np.any(lods_a[0][idx] != lods_b[0][idx]):
                    print('%-40s\r' % '')
                    print('Different image: %d' % idx)
                    num_diffs += 1
            if shape_a[0] != shape_b[0]:
                print('The datasets contain different number of images: %d vs. %d' % (shape_a[0], shape_b[0]))
            if num_diffs == 0:
                print('All %d images are identical.' % min_images)
            else:
                print('%d images out of %d are different.' % (num_diffs, min_images))
224 |
225 |
226 | # ----------------------------------------------------------------------------
227 |
def display(h5_filename, start=None, stop=None, step=None):
    """Show images from the highest-resolution LOD one at a time in an OpenCV window.

    ``start``/``stop``/``step`` slice the image indices. ESC stops early; any other key
    advances to the next image.
    """
    print('Displaying images from %s' % h5_filename)
    # The file is closed even if the cv2 import or the display loop fails;
    # .items() replaces the Python-2-only .iteritems().
    with h5py.File(h5_filename, 'r') as h5:
        lods = sorted([value for key, value in h5.items() if key.startswith('data')],
                      key=lambda lod: -lod.shape[3])
        indices = range(lods[0].shape[0])
        indices = indices[start: stop: step]

        import cv2  # pip install opencv-python
        window_name = 'h5tool'
        cv2.namedWindow(window_name)
        print('Press SPACE or ENTER to advance, ESC to exit.')

        for idx in indices:
            print('%d / %d\r' % (idx, lods[0].shape[0]))
            img = lods[0][idx]
            img = img.transpose(1, 2, 0)  # CHW => HWC
            img = img[:, :, ::-1]  # RGB => BGR
            cv2.imshow(window_name, img)
            c = cv2.waitKey()
            if c == 27:
                break

    print('%-40s\r' % '')
    print('Done.')
253 |
254 |
255 | # ----------------------------------------------------------------------------
256 |
def extract(h5_filename, output_dir, start=None, stop=None, step=None):
    """Save images from the highest-resolution LOD as ``img%08d.png`` files in ``output_dir``.

    ``start``/``stop``/``step`` slice the image indices. Single-channel images are saved
    as grayscale (``'L'``); all others are saved as ``'RGB'``.
    """
    print('Extracting images from %s to %s' % (h5_filename, output_dir))
    # .items() replaces the Python-2-only .iteritems(); the file is closed even on error.
    with h5py.File(h5_filename, 'r') as h5:
        lods = sorted([value for key, value in h5.items() if key.startswith('data')],
                      key=lambda lod: -lod.shape[3])
        shape = lods[0].shape
        indices = range(shape[0])[start: stop: step]
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        for idx in indices:
            print('%d / %d\r' % (idx, shape[0]))
            img = lods[0][idx]
            if img.shape[0] == 1:
                img = PIL.Image.fromarray(img[0], 'L')
            else:
                img = PIL.Image.fromarray(img.transpose(1, 2, 0), 'RGB')
            img.save(os.path.join(output_dir, 'img%08d.png' % idx))

    print('%-40s\r' % '')
    print('Extracted %d images.' % len(indices))
278 |
279 |
280 | # ----------------------------------------------------------------------------
281 |
def create_custom(h5_filename, image_dir):
    """Create an HDF5 dataset from every file in ``image_dir`` (in sorted filename order).

    The first image fixes the resolution and channel count. It must be square,
    have a power-of-two side length, and be grayscale (1 channel) or RGB
    (3 channels). If any of these checks fails, an error is printed and the
    function returns without creating the dataset.
    """
    print('Creating custom dataset %s from %s' % (h5_filename, image_dir))
    glob_pattern = os.path.join(image_dir, '*')
    image_filenames = sorted(glob.glob(glob_pattern))
    if len(image_filenames) == 0:
        print('Error: No input images found in %s' % glob_pattern)
        return

    img = np.asarray(PIL.Image.open(image_filenames[0]))
    resolution = img.shape[0]
    channels = img.shape[2] if img.ndim == 3 else 1
    if img.shape[1] != resolution:
        print('Error: Input images must have the same width and height')
        return
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        print('Error: Input image resolution must be a power-of-two')
        return
    if channels not in [1, 3]:
        print('Error: Input images must be stored as RGB or grayscale')
        # Without this return, e.g. RGBA input went on to build an invalid dataset.
        return

    h5 = HDF5Exporter(h5_filename, resolution, channels)
    for idx in xrange(len(image_filenames)):
        print('%d / %d\r' % (idx, len(image_filenames)))
        img = np.asarray(PIL.Image.open(image_filenames[idx]))
        if channels == 1:
            img = img[np.newaxis, :, :]  # HW => CHW
        else:
            img = img.transpose(2, 0, 1)  # HWC => CHW
        h5.add_images(img[np.newaxis])

    print('%-40s\r' % 'Flushing data...')
    h5.close()
    print('%-40s\r' % '')
    print('Added %d images.' % len(image_filenames))
316 |
317 |
318 | # ----------------------------------------------------------------------------
319 |
def create_mnist(h5_filename, mnist_dir, export_labels=False):
    """Create a 32x32 grayscale HDF5 dataset from the MNIST training set.

    Reads ``train-images-idx3-ubyte.gz`` and ``train-labels-idx1-ubyte.gz`` from
    ``mnist_dir``. Each 28x28 digit is zero-padded by 2 pixels per side, and all
    60000 images are stored with shape (1, 32, 32). With ``export_labels``, the
    labels are also saved as a float32 one-hot matrix in
    ``<h5 name without extension>-labels.npy``.
    """
    print('Loading MNIST data from %s' % mnist_dir)
    import gzip

    def read_idx(name, header_size):
        # Skip the fixed-size IDX header and view the rest as raw uint8 values.
        with gzip.open(os.path.join(mnist_dir, name), 'rb') as stream:
            return np.frombuffer(stream.read(), np.uint8, offset=header_size)

    images = read_idx('train-images-idx3-ubyte.gz', 16)
    labels = read_idx('train-labels-idx1-ubyte.gz', 8)
    padding = [(0, 0), (0, 0), (2, 2), (2, 2)]
    images = np.pad(images.reshape(-1, 1, 28, 28), padding, 'constant', constant_values=0)
    assert images.shape == (60000, 1, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (60000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    print('Creating %s' % h5_filename)
    exporter = HDF5Exporter(h5_filename, 32, 1)
    exporter.add_images(images)
    exporter.close()

    if export_labels:
        npy_filename = os.path.splitext(h5_filename)[0] + '-labels.npy'
        print('Creating %s' % npy_filename)
        # Row k of the identity matrix is the one-hot vector for class k.
        num_classes = int(np.max(labels)) + 1
        np.save(npy_filename, np.eye(num_classes, dtype=np.float32)[labels])
    print('Added %d images.' % images.shape[0])
346 |
347 |
348 | # ----------------------------------------------------------------------------
349 |
def create_mnist_rgb(h5_filename, mnist_dir, num_images=1000000, random_seed=123):
    """Create a synthetic 3-channel 32x32 HDF5 dataset from MNIST digits.

    Each of the ``num_images`` outputs stacks three zero-padded MNIST training
    digits, drawn uniformly with replacement via ``np.random`` seeded with
    ``random_seed``, as its three channels.
    """
    print('Loading MNIST data from %s' % mnist_dir)
    import gzip
    digits_path = os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz')
    with gzip.open(digits_path, 'rb') as stream:
        raw = stream.read()
    digits = np.frombuffer(raw, np.uint8, offset=16).reshape(-1, 28, 28)
    digits = np.pad(digits, [(0, 0), (2, 2), (2, 2)], 'constant', constant_values=0)
    assert digits.shape == (60000, 32, 32) and digits.dtype == np.uint8
    assert np.min(digits) == 0 and np.max(digits) == 255

    print('Creating %s' % h5_filename)
    exporter = HDF5Exporter(h5_filename, 32, 3)
    np.random.seed(random_seed)
    for idx in xrange(num_images):
        if idx % 100 == 0:
            print('%d / %d\r' % (idx, num_images))
        picks = np.random.randint(digits.shape[0], size=3)
        # Shape (1, 3, 32, 32): one image whose channels are the three picked digits.
        exporter.add_images(digits[np.newaxis, picks])

    print('%-40s\r' % 'Flushing data...')
    exporter.close()
    print('%-40s\r' % '')
    print('Added %d images.' % num_images)
372 |
373 |
374 | # ----------------------------------------------------------------------------
375 |
def create_cifar10(h5_filename, cifar10_dir, export_labels=False):
    """Create a 32x32 RGB HDF5 dataset from the five CIFAR-10 training batches.

    Loads ``data_batch_1`` .. ``data_batch_5`` from ``cifar10_dir`` (50000 images,
    stored as (3, 32, 32) uint8). With ``export_labels``, the labels are also saved
    as a float32 one-hot matrix in ``<h5 name without extension>-labels.npy``.
    """
    print('Loading CIFAR-10 data from %s' % cifar10_dir)
    image_chunks = []
    label_chunks = []
    for batch in xrange(1, 6):
        batch_path = os.path.join(cifar10_dir, 'data_batch_%d' % batch)
        # NOTE: pickle.load can execute arbitrary code embedded in the file; only
        # point this at a trusted CIFAR-10 download.
        with open(batch_path, 'rb') as stream:
            entry = pickle.load(stream)
        image_chunks.append(entry['data'].reshape(-1, 3, 32, 32))
        label_chunks.append(np.uint8(entry['labels']))
    images = np.concatenate(image_chunks)
    labels = np.concatenate(label_chunks)

    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    print('Creating %s' % h5_filename)
    exporter = HDF5Exporter(h5_filename, 32, 3)
    exporter.add_images(images)
    exporter.close()

    if export_labels:
        npy_filename = os.path.splitext(h5_filename)[0] + '-labels.npy'
        print('Creating %s' % npy_filename)
        # Row k of the identity matrix is the one-hot vector for class k.
        num_classes = int(np.max(labels)) + 1
        np.save(npy_filename, np.eye(num_classes, dtype=np.float32)[labels])
    print('Added %d images.' % images.shape[0])
405 |
406 |
407 | # ----------------------------------------------------------------------------
408 |
def create_lsun(h5_filename, lmdb_dir, resolution=256, max_images=None):
    """Create an HDF5 dataset from one LSUN category stored as an LMDB database.

    Each record is processed as follows:
      * decoded with OpenCV, falling back to PIL if that fails;
      * center-cropped to a square;
      * resized to ``resolution`` x ``resolution`` with antialiasing;
      * stored as a CHW RGB image.

    At most ``max_images`` images are written (all records by default). A failure
    on any record is printed and re-raised. The LMDB environment is always closed.
    """
    print('Creating LSUN dataset %s from %s' % (h5_filename, lmdb_dir))
    import lmdb  # pip install lmdb
    import cv2  # pip install opencv-python
    env = lmdb.open(lmdb_dir, readonly=True)
    try:
        with env.begin(write=False) as txn:
            total_images = txn.stat()['entries']
            if max_images is None:
                max_images = total_images

            h5 = HDF5Exporter(h5_filename, resolution, 3)
            for idx, (key, value) in enumerate(txn.cursor()):
                print('%d / %d\r' % (h5.num_images(), min(h5.num_images() + total_images - idx, max_images)))
                try:
                    try:
                        img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1)
                        if img is None:
                            raise IOError('cv2.imdecode failed')
                        img = img[:, :, ::-1]  # BGR => RGB
                    except IOError:
                        img = np.asarray(PIL.Image.open(io.BytesIO(value)))
                    # Center-crop to a square. Floor division keeps the slice bounds
                    # integral on Python 3 as well (true division gave float indices).
                    crop = np.min(img.shape[:2])
                    top = (img.shape[0] - crop) // 2
                    left = (img.shape[1] - crop) // 2
                    img = img[top: top + crop, left: left + crop]
                    img = PIL.Image.fromarray(img, 'RGB')
                    img = img.resize((resolution, resolution), PIL.Image.ANTIALIAS)
                    img = np.asarray(img)
                    img = img.transpose(2, 0, 1)  # HWC => CHW
                    h5.add_images(img[np.newaxis])
                except Exception:
                    print('%-40s\r' % '')
                    print(sys.exc_info()[1])
                    raise
                if h5.num_images() == max_images:
                    break
    finally:
        # Release the LMDB environment even when processing fails.
        env.close()

    print('%-40s\r' % 'Flushing data...')
    num_added = h5.num_images()
    h5.close()
    print('%-40s\r' % '')
    print('Added %d images.' % num_added)
449 |
450 |
451 | # ----------------------------------------------------------------------------
452 |
def create_celeba(h5_filename, celeba_dir, cx=89, cy=121):
    """Create a 128x128 RGB HDF5 dataset from aligned CelebA PNGs.

    Expects exactly 202599 files matching ``img_align_celeba_png/*.png`` under
    ``celeba_dir``, each 218x178 RGB. A 128x128 window centred on (``cx``, ``cy``)
    is cut from every image and stored as CHW.
    """
    print('Creating CelebA dataset %s from %s' % (h5_filename, celeba_dir))
    glob_pattern = os.path.join(celeba_dir, 'img_align_celeba_png', '*.png')
    image_filenames = sorted(glob.glob(glob_pattern))
    num_images = 202599
    if len(image_filenames) != num_images:
        print('Error: Expected to find %d images in %s' % (num_images, glob_pattern))
        return

    half = 64  # half the side of the 128x128 crop window
    rows = slice(cy - half, cy + half)
    cols = slice(cx - half, cx + half)
    exporter = HDF5Exporter(h5_filename, 128, 3)
    for idx, path in enumerate(image_filenames):
        print('%d / %d\r' % (idx, num_images))
        pixels = np.asarray(PIL.Image.open(path))
        assert pixels.shape == (218, 178, 3)
        chw = pixels[rows, cols].transpose(2, 0, 1)  # HWC => CHW
        exporter.add_images(chw[np.newaxis])

    print('%-40s\r' % 'Flushing data...')
    exporter.close()
    print('%-40s\r' % '')
    print('Added %d images.' % num_images)
475 |
476 |
477 | # ----------------------------------------------------------------------------
478 |
def create_celeba_hq(h5_filename, celeba_dir, delta_dir, num_threads=4, num_tasks=100):
    """Create the 1024x1024 CelebA-HQ HDF5 dataset from CelebA JPEGs plus encrypted deltas.

    Each row of ``delta_dir/image_list.txt`` is processed as follows:
      * the matching CelebA JPEG is aligned and cropped using its facial landmarks;
      * the crop is resampled to 1024x1024;
      * a per-image delta is added. The delta is bz2-compressed and Fernet-encrypted
        with a key derived (PBKDF2-HMAC-SHA256) from the original JPEG's bytes.

    Args:
        h5_filename: HDF5 file to create.
        celeba_dir: directory with the 202599 CelebA ``*.jpg`` files and
            ``list_landmarks_celeba.txt``.
        delta_dir: directory with ``image_list.txt`` and the 30 ``delta*.zip`` archives.
        num_threads: thread count passed to ``ThreadPool``.
        num_tasks: ``max_items_in_flight`` for ``process_items_concurrently``.

    Prints an error and returns without writing anything if the expected number
    of JPEGs or delta archives is not found. NOTE: both MD5 assertions below are
    commented out, so mismatching inputs are not detected.
    """
    print('Loading CelebA data from %s' % celeba_dir)
    glob_pattern = os.path.join(celeba_dir, '*.jpg')
    glob_expected = 202599
    if len(glob.glob(glob_pattern)) != glob_expected:
        print('Error: Expected to find %d images in %s' % (glob_expected, glob_pattern))
        return
    # Skip the two header lines. Every other line is "<filename> followed by 10 numbers",
    # reshaped into 5 (x, y) landmark points per image.
    with open(os.path.join(celeba_dir, 'list_landmarks_celeba.txt'), 'rt') as file:
        landmarks = [[float(value) for value in line.split()[1:]] for line in file.readlines()[2:]]
        landmarks = np.float32(landmarks).reshape(-1, 5, 2)

    print('Loading CelebA-HQ deltas from %s' % delta_dir)
    import hashlib
    import bz2
    import zipfile
    import base64
    import cryptography.hazmat.primitives.hashes
    import cryptography.hazmat.backends
    import cryptography.hazmat.primitives.kdf.pbkdf2
    import cryptography.fernet
    glob_pattern = os.path.join(delta_dir, 'delta*.zip')
    glob_expected = 30
    if len(glob.glob(glob_pattern)) != glob_expected:
        print('Error: Expected to find %d zips in %s' % (glob_expected, glob_pattern))
        return
    # image_list.txt is a whitespace-separated table whose first row holds column
    # names. Columns ending in 'idx' are parsed as int, the rest kept as str.
    # (`type` shadows the builtin within this function.)
    with open(os.path.join(delta_dir, 'image_list.txt'), 'rt') as file:
        lines = [line.split() for line in file]
        fields = dict()
        for idx, field in enumerate(lines[0]):
            type = int if field.endswith('idx') else str
            fields[field] = [type(line[idx]) for line in lines[1:]]

    def rot90(v):
        # Rotate a 2-D vector by 90 degrees: (a, b) -> (-b, a).
        return np.array([-v[1], v[0]])

    def process_func(idx):
        """Rebuild CelebA-HQ image ``idx``; returns ``(idx, uint8 array of shape (3, 1024, 1024))``."""
        # Load original image.
        orig_idx = fields['orig_idx'][idx]
        orig_file = fields['orig_file'][idx]
        orig_path = os.path.join(celeba_dir, orig_file)
        img = PIL.Image.open(orig_path)

        # Choose oriented crop rectangle.
        # Landmarks 0/1 are averaged as the eyes and 3/4 as the mouth. x is the
        # half-width axis of the square crop and y is x rotated by 90 degrees.
        lm = landmarks[orig_idx]
        eye_avg = (lm[0] + lm[1]) * 0.5 + 0.5
        mouth_avg = (lm[3] + lm[4]) * 0.5 + 0.5
        eye_to_eye = lm[1] - lm[0]
        eye_to_mouth = mouth_avg - eye_avg
        x = eye_to_eye - rot90(eye_to_mouth)
        x /= np.hypot(*x)
        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
        y = rot90(x)
        c = eye_avg + eye_to_mouth * 0.1
        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
        # Output pixels per source pixel: the quad side (2*|x|) maps onto 1024 px.
        zoom = 1024 / (np.hypot(*x) * 2)

        # Shrink.
        # If the quad is far larger than the output, pre-downsample by an integer factor.
        shrink = int(np.floor(0.5 / zoom))
        if shrink > 1:
            size = (int(np.round(float(img.size[0]) / shrink)), int(np.round(float(img.size[1]) / shrink)))
            img = img.resize(size, PIL.Image.ANTIALIAS)
            quad /= shrink
            zoom *= shrink

        # Crop.
        # Cut to the quad's bounding box plus a border, clamped to the image bounds.
        border = max(int(np.round(1024 * 0.1 / zoom)), 3)
        crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                int(np.ceil(max(quad[:, 1]))))
        crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
                min(crop[3] + border, img.size[1]))
        if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
            img = img.crop(crop)
            quad -= crop[0:2]

        # Simulate super-resolution.
        # Upsample by the smallest power of two >= zoom, so the final transform
        # never has to magnify.
        superres = int(np.exp2(np.ceil(np.log2(zoom))))
        if superres > 1:
            img = img.resize((img.size[0] * superres, img.size[1] * superres), PIL.Image.ANTIALIAS)
            quad *= superres
            zoom /= superres

        # Pad.
        # pad = how far the quad (plus border) sticks out past each image edge
        # (left, top, right, bottom).
        pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
               int(np.ceil(max(quad[:, 1]))))
        pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
               max(pad[3] - img.size[1] + border, 0))
        if max(pad) > border - 4:
            pad = np.maximum(pad, int(np.round(1024 * 0.3 / zoom)))
            img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
            h, w, _ = img.shape
            y, x, _ = np.mgrid[:h, :w, :1]
            # mask is 1 on the outer edge of the padded image, 0 at the original
            # image border, and negative inside the original image.
            mask = 1.0 - np.minimum(np.minimum(np.float32(x) / pad[0], np.float32(y) / pad[1]),
                                    np.minimum(np.float32(w - 1 - x) / pad[2], np.float32(h - 1 - y) / pad[3]))
            blur = 1024 * 0.02 / zoom
            # Blend the reflected padding toward a blurred copy, then toward the median color.
            img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
            img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
            img = PIL.Image.fromarray(np.uint8(np.clip(np.round(img), 0, 255)), 'RGB')
            quad += pad[0:2]

        # Transform.
        # Sample the quad into 4096x4096 (bilinear), downsample to 1024x1024, HWC => CHW.
        img = img.transform((4096, 4096), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
        img = img.resize((1024, 1024), PIL.Image.ANTIALIAS)
        img = np.asarray(img).transpose(2, 0, 1)

        # Verify MD5.
        # NOTE(review): the digest is computed but the check is disabled.
        md5 = hashlib.md5()
        md5.update(img.tobytes())
        # assert md5.hexdigest() == fields['proc_md5'][idx] # disable md5 verify

        # Load delta image and original JPG.
        # Deltas are grouped 1000 per archive: deltas00000.zip, deltas01000.zip, ...
        with zipfile.ZipFile(os.path.join(delta_dir, 'deltas%05d.zip' % (idx - idx % 1000)), 'r') as zip:
            delta_bytes = zip.read('delta%05d.dat' % idx)
        with open(orig_path, 'rb') as file:
            orig_bytes = file.read()

        # Decrypt delta image, using original JPG data as decryption key.
        # NOTE(review): orig_file is a str (parsed with `str` above); the
        # cryptography library appears to require a bytes salt on Python 3 — verify.
        algorithm = cryptography.hazmat.primitives.hashes.SHA256()
        backend = cryptography.hazmat.backends.default_backend()
        kdf = cryptography.hazmat.primitives.kdf.pbkdf2.PBKDF2HMAC(algorithm=algorithm, length=32, salt=orig_file,
                                                                   iterations=100000, backend=backend)
        key = base64.urlsafe_b64encode(kdf.derive(orig_bytes))
        delta = np.frombuffer(bz2.decompress(cryptography.fernet.Fernet(key).decrypt(delta_bytes)),
                              dtype=np.uint8).reshape(3, 1024, 1024)

        # Apply delta image.
        # uint8 + uint8: the sum wraps modulo 256.
        img = img + delta

        # Verify MD5.
        # NOTE(review): digest computed but the check is disabled.
        md5 = hashlib.md5()
        md5.update(img.tobytes())
        # assert md5.hexdigest() == fields['final_md5'][idx] # disable md5 verify
        return idx, img

    print('Creating %s' % h5_filename)
    h5 = HDF5Exporter(h5_filename, 1024, 3)
    with ThreadPool(num_threads) as pool:
        print('%d / %d\r' % (0, len(fields['idx'])))
        # process_func returns its idx alongside the image; that idx drives the progress line.
        for idx, img in pool.process_items_concurrently(fields['idx'], process_func=process_func,
                                                        max_items_in_flight=num_tasks):
            h5.add_images(img[np.newaxis])
            print('%d / %d\r' % (idx + 1, len(fields['idx'])))

    print('%-40s\r' % 'Flushing data...')
    h5.close()
    print('%-40s\r' % '')
    print('Added %d images.' % len(fields['idx']))
625 |
626 |
627 | # ----------------------------------------------------------------------------
628 |
def execute_cmdline(argv):
    """Parse ``argv`` (program name first) and run the selected h5tool sub-command.

    Each sub-command name equals a module-level function name. The parsed options
    (minus ``command``) are passed to that function as keyword arguments.

    Omitting the sub-command is reported as a usage error (SystemExit with status 2).
    Previously this failed with ``KeyError: None`` on Python 3, where argparse
    sub-commands are optional by default.
    """
    prog = argv[0]
    parser = argparse.ArgumentParser(
        prog=prog,
        description='Tool for creating, extracting, and visualizing HDF5 datasets.',
        epilog='Type "%s -h" for more information.' % prog)

    subparsers = parser.add_subparsers(dest='command')
    # Python 3 argparse makes sub-commands optional unless required is set.
    subparsers.required = True

    def add_command(cmd, desc, example=None):
        epilog = 'Example: %s %s' % (prog, example) if example is not None else None
        return subparsers.add_parser(cmd, description=desc, help=desc, epilog=epilog)

    p = add_command('inspect', 'Print information about HDF5 dataset.',
                    'inspect mnist-32x32.h5')
    p.add_argument('h5_filename', help='HDF5 file to inspect')

    p = add_command('compare', 'Compare two HDF5 datasets.',
                    'compare mydataset.h5 mnist-32x32.h5')
    p.add_argument('first_h5', help='First HDF5 file to compare')
    p.add_argument('second_h5', help='Second HDF5 file to compare')

    p = add_command('display', 'Display images in HDF5 dataset.',
                    'display mnist-32x32.h5')
    p.add_argument('h5_filename', help='HDF5 file to visualize')
    p.add_argument('--start', help='Start index (inclusive)', type=int, default=None)
    p.add_argument('--stop', help='Stop index (exclusive)', type=int, default=None)
    p.add_argument('--step', help='Step between consecutive indices', type=int, default=None)

    p = add_command('extract', 'Extract images from HDF5 dataset.',
                    'extract mnist-32x32.h5 cifar10-images')
    p.add_argument('h5_filename', help='HDF5 file to extract')
    p.add_argument('output_dir', help='Directory to extract the images into')
    p.add_argument('--start', help='Start index (inclusive)', type=int, default=None)
    p.add_argument('--stop', help='Stop index (exclusive)', type=int, default=None)
    p.add_argument('--step', help='Step between consecutive indices', type=int, default=None)

    p = add_command('create_custom', 'Create HDF5 dataset for custom images.',
                    'create_custom mydataset.h5 myimagedir')
    p.add_argument('h5_filename', help='HDF5 file to create')
    p.add_argument('image_dir', help='Directory to read the images from')

    p = add_command('create_mnist', 'Create HDF5 dataset for MNIST.',
                    'create_mnist mnist-32x32.h5 ~/mnist --export_labels')
    p.add_argument('h5_filename', help='HDF5 file to create')
    p.add_argument('mnist_dir', help='Directory to read MNIST data from')
    p.add_argument('--export_labels', help='Create *-labels.npy alongside the HDF5', action='store_true')

    p = add_command('create_mnist_rgb', 'Create HDF5 dataset for MNIST-RGB.',
                    'create_mnist_rgb mnist-rgb-32x32.h5 ~/mnist')
    p.add_argument('h5_filename', help='HDF5 file to create')
    p.add_argument('mnist_dir', help='Directory to read MNIST data from')
    p.add_argument('--num_images', help='Number of composite images to create (default: 1000000)', type=int,
                   default=1000000)
    p.add_argument('--random_seed', help='Random seed (default: 123)', type=int, default=123)

    p = add_command('create_cifar10', 'Create HDF5 dataset for CIFAR-10.',
                    'create_cifar10 cifar-10-32x32.h5 ~/cifar10 --export_labels')
    p.add_argument('h5_filename', help='HDF5 file to create')
    p.add_argument('cifar10_dir', help='Directory to read CIFAR-10 data from')
    p.add_argument('--export_labels', help='Create *-labels.npy alongside the HDF5', action='store_true')

    p = add_command('create_lsun', 'Create HDF5 dataset for single LSUN category.',
                    'create_lsun lsun-airplane-256x256-100k.h5 ~/lsun/airplane_lmdb --resolution 256 --max_images 100000')
    p.add_argument('h5_filename', help='HDF5 file to create')
    p.add_argument('lmdb_dir', help='Directory to read LMDB database from')
    p.add_argument('--resolution', help='Output resolution (default: 256)', type=int, default=256)
    p.add_argument('--max_images', help='Maximum number of images (default: none)', type=int, default=None)

    p = add_command('create_celeba', 'Create HDF5 dataset for CelebA.',
                    'create_celeba celeba-128x128.h5 ~/celeba')
    p.add_argument('h5_filename', help='HDF5 file to create')
    p.add_argument('celeba_dir', help='Directory to read CelebA data from')
    p.add_argument('--cx', help='Center X coordinate (default: 89)', type=int, default=89)
    p.add_argument('--cy', help='Center Y coordinate (default: 121)', type=int, default=121)

    p = add_command('create_celeba_hq', 'Create HDF5 dataset for CelebA-HQ.',
                    'create_celeba_hq celeba-hq-1024x1024.h5 ~/celeba ~/celeba-hq-deltas')
    p.add_argument('h5_filename', help='HDF5 file to create')
    p.add_argument('celeba_dir', help='Directory to read CelebA data from')
    p.add_argument('delta_dir', help='Directory to read CelebA-HQ deltas from')
    p.add_argument('--num_threads', help='Number of concurrent threads (default: 4)', type=int, default=4)
    p.add_argument('--num_tasks', help='Number of concurrent processing tasks (default: 100)', type=int, default=100)

    args = parser.parse_args(argv[1:])
    # Dispatch to the module-level function that shares the sub-command's name.
    func = globals()[args.command]
    del args.command
    func(**vars(args))
717 |
718 |
719 | # ----------------------------------------------------------------------------
720 |
if __name__ == "__main__":
    # Script entry point: hand the full argv (program name included) to the CLI dispatcher.
    cli_argv = sys.argv
    execute_cmdline(cli_argv)
--------------------------------------------------------------------------------
/images/figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangqianhui/progressive_growing_of_gans_tensorflow/e16b097117169e9521104138d4d461b2aef5a5fb/images/figure.png
--------------------------------------------------------------------------------
/images/hs_sample_128.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangqianhui/progressive_growing_of_gans_tensorflow/e16b097117169e9521104138d4d461b2aef5a5fb/images/hs_sample_128.jpg
--------------------------------------------------------------------------------
/images/hs_sample_64.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangqianhui/progressive_growing_of_gans_tensorflow/e16b097117169e9521104138d4d461b2aef5a5fb/images/hs_sample_64.jpg
--------------------------------------------------------------------------------
/images/sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangqianhui/progressive_growing_of_gans_tensorflow/e16b097117169e9521104138d4d461b2aef5a5fb/images/sample.png
--------------------------------------------------------------------------------
/images/sample_128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangqianhui/progressive_growing_of_gans_tensorflow/e16b097117169e9521104138d4d461b2aef5a5fb/images/sample_128.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
"""Train a progressive-growing GAN (PGGAN) on CelebA or CelebA-HQ.

Runs ``FLAGS.flag`` consecutive stages. Each stage:
  * builds a fresh ``PGGAN`` for one growth level;
  * writes checkpoints to ./output/<OPER_NAME>/model_pggan_<OPER_FLAG>/<level>/;
  * receives the directory of the level in the read schedule as ``read_model_path``.
"""
import os

import tensorflow as tf

from utils import mkdir_p, CelebA, CelebA_HQ
from PGGAN import PGGAN

flags = tf.app.flags

# Expose only GPU 0 to TensorFlow.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

flags.DEFINE_string("OPER_NAME", "Experiment_6_30_1", "the name of experiments")
flags.DEFINE_integer("OPER_FLAG", 0, "Flag of opertion: 0 is for training ")
flags.DEFINE_string("path", '?', "Path of training data, for example /home/hehe/")
flags.DEFINE_integer("batch_size", 16, "Batch size")
flags.DEFINE_integer("sample_size", 512, "Size of sample")
flags.DEFINE_integer("max_iters", 40000, "Maxmization of training number")
flags.DEFINE_float("learn_rate", 0.001, "Learning rate for G and D networks")
flags.DEFINE_integer("lam_gp", 10, "Weight of gradient penalty term")
flags.DEFINE_float("lam_eps", 0.001, "Weight for the epsilon term")
flags.DEFINE_integer("flag", 11, "FLAG of gan training process")
flags.DEFINE_boolean("use_wscale", True, "Using the scale of weight")
flags.DEFINE_boolean("celeba", True, "Whether using celeba or using CelebA-HQ")

FLAGS = flags.FLAGS
if __name__ == "__main__":

    root_log_dir = "./output/{}/logs/".format(FLAGS.OPER_NAME)
    mkdir_p(root_log_dir)

    dataset = CelebA(FLAGS.path) if FLAGS.celeba else CelebA_HQ(FLAGS.path)
    print ("the num of dataset", len(dataset.image_list))

    if FLAGS.OPER_FLAG == 0:
        # Stage i trains level write_levels[i] and is handed the checkpoint
        # directory of level read_levels[i]; odd stages run with t=True.
        write_levels = [1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]
        read_levels = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6]

        for stage in range(FLAGS.flag):
            transition = stage % 2 == 1
            level = write_levels[stage]

            pggan_checkpoint_dir_write = "./output/{}/model_pggan_{}/{}/".format(FLAGS.OPER_NAME, FLAGS.OPER_FLAG, level)
            sample_path = "./output/{}/{}/sample_{}_{}".format(FLAGS.OPER_NAME, FLAGS.OPER_FLAG, level, transition)
            mkdir_p(pggan_checkpoint_dir_write)
            mkdir_p(sample_path)
            pggan_checkpoint_dir_read = "./output/{}/model_pggan_{}/{}/".format(FLAGS.OPER_NAME, FLAGS.OPER_FLAG, read_levels[stage])

            model = PGGAN(batch_size=FLAGS.batch_size, max_iters=FLAGS.max_iters,
                          model_path=pggan_checkpoint_dir_write, read_model_path=pggan_checkpoint_dir_read,
                          data=dataset, sample_size=FLAGS.sample_size,
                          sample_path=sample_path, log_dir=root_log_dir, learn_rate=FLAGS.learn_rate,
                          lam_gp=FLAGS.lam_gp, lam_eps=FLAGS.lam_eps, PG=level,
                          t=transition, use_wscale=FLAGS.use_wscale, is_celeba=FLAGS.celeba)

            model.build_model_PGGan()
            model.train()
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.layers.python.layers import batch_norm
3 | import numpy as np
4 |
5 | # the implements of leakyRelu
def lrelu(x, alpha=0.2, name="LeakyReLU"):
    """Leaky ReLU: element-wise ``max(x, alpha * x)``, built inside name scope ``name``."""
    with tf.name_scope(name):
        scaled = alpha * x
        return tf.maximum(x, scaled)
9 |
def get_weight(shape, gain=np.sqrt(2), use_wscale=False, fan_in=None):
    """Create (or reuse) the variable 'weight' in the current variable scope with He-style init.

    Args:
        shape: variable shape. The product of all but the last dimension is the
            default fan-in.
        gain: numerator of the standard deviation ``gain / sqrt(fan_in)``.
        use_wscale: if True, the variable is drawn from N(0, 1) and multiplied by the
            constant std at run time; otherwise it is initialised from N(0, std).
        fan_in: explicit fan-in overriding the one derived from ``shape``.

    Returns:
        A tensor of shape ``shape``.
    """
    if fan_in is None:
        fan_in = np.prod(shape[:-1])
    # The Python-2-only debug `print` that used to be here is gone: it made this
    # module a SyntaxError on Python 3 and printed a line for every layer.
    std = gain / np.sqrt(fan_in)  # He init

    if use_wscale:
        wscale = tf.constant(np.float32(std), name='wscale')
        return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal()) * wscale
    return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal(0, std))
21 |
def conv2d(input_, output_dim,
           k_h=3, k_w=3, d_h=2, d_w=2, gain=np.sqrt(2), use_wscale=False, padding='SAME',
           name="conv2d", with_w=False):
    """2-D convolution of an NHWC tensor plus a bias, inside variable scope ``name``.

    ``padding`` may be any TF padding string, or 'Other'. 'Other' zero-pads 3
    pixels on each side of H and W and then convolves with 'VALID'.
    Returns ``conv``, or ``(conv, w, biases)`` when ``with_w`` is True.
    """
    with tf.variable_scope(name):
        in_channels = input_.shape[-1].value
        kernel_shape = [k_h, k_w, in_channels, output_dim]
        w = tf.cast(get_weight(kernel_shape, gain=gain, use_wscale=use_wscale), input_.dtype)

        if padding == 'Other':
            # Explicit 3-pixel zero border on H and W, then an unpadded convolution.
            input_ = tf.pad(input_, [[0, 0], [3, 3], [3, 3], [0, 0]], "CONSTANT")
            padding = 'VALID'

        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding=padding)
        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())

        return (conv, w, biases) if with_w else conv
46 |
def fully_connect(input_, output_size, gain=np.sqrt(2), use_wscale=False, name=None, with_w=False):
    """Dense layer ``input_ @ w + bias`` on a rank-2 input, in variable scope ``name`` ("Linear" by default).

    Returns ``output``, or ``(output, w, bias)`` when ``with_w`` is True.
    """
    shape = input_.get_shape().as_list()
    with tf.variable_scope(name or "Linear"):
        w = get_weight([shape[1], output_size], gain=gain, use_wscale=use_wscale)
        w = tf.cast(w, input_.dtype)
        bias = tf.get_variable("bias", [output_size], initializer=tf.constant_initializer(0.0))

        output = tf.matmul(input_, w) + bias

        if with_w:
            # Return the weight tensor itself (as conv2d does), not the boolean flag.
            return output, w, bias
        return output
62 |
def conv_cond_concat(x, y):
    """Concatenate conditioning tensor ``y`` onto ``x`` along axis 3 (channels).

    ``y`` is broadcast to [x0, x1, x2, y3] by multiplying with a ones tensor of that
    shape. NOTE(review): presumably y is [batch, 1, 1, y_dim] — confirm against callers.
    """
    x_shapes = x.get_shape().as_list()
    y_shapes = y.get_shape().as_list()
    tiled_y = y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])
    # TF >= 1.0 signature is tf.concat(values, axis); the old (axis, values) order fails.
    return tf.concat([x, tiled_y], axis=3)
67 |
def batch_normal(input, scope="scope", reuse=False):
    """Batch normalisation through tf.contrib ``batch_norm`` (epsilon 1e-5, decay 0.9, learned scale).

    ``updates_collections=None`` runs the moving-average updates in place rather
    than via the UPDATE_OPS collection.
    """
    options = dict(epsilon=1e-5, decay=0.9, scale=True, updates_collections=None)
    return batch_norm(input, scope=scope, reuse=reuse, **options)
70 |
def resize_nearest_neighbor(x, new_size):
    """Resize image batch ``x`` to ``new_size`` (height, width) using nearest-neighbour sampling."""
    return tf.image.resize_nearest_neighbor(x, new_size)
74 |
def upscale(x, scale):
    """Nearest-neighbour upsample a rank-4 (N, H, W, C) tensor by integer factor ``scale`` on H and W."""
    batch, height, width, channels = get_conv_shape(x)
    target = (height * scale, width * scale)
    return resize_nearest_neighbor(x, target)
78 |
def get_conv_shape(tensor):
    """Return the static shape of ``tensor`` as a list of ints (unknown dimensions as -1)."""
    return int_shape(tensor)
82 |
def int_shape(tensor):
    """Return ``tensor.get_shape().as_list()`` with every unknown (None) dimension replaced by -1."""
    dims = tensor.get_shape().as_list()
    return [-1 if dim is None else dim for dim in dims]
86 |
def downscale2d(x, k=2):
    """Average-pool an NHWC tensor over non-overlapping k x k windows (downsample H and W by ``k``)."""
    window = [1, k, k, 1]
    return tf.nn.avg_pool(x, ksize=window, strides=window, padding='VALID')
91 |
def Pixl_Norm(x, eps=1e-8):
    """Pixel-wise normalisation: divide each feature vector by its root-mean-square.

    The mean of squares is taken over axis 3 for inputs of rank > 2, and over axis 1
    for rank-2 inputs. ``eps`` is added before the rsqrt to keep it finite.
    """
    feature_axis = 3 if len(x.shape) > 2 else 1
    with tf.variable_scope('PixelNorm'):
        mean_square = tf.reduce_mean(tf.square(x), axis=feature_axis, keep_dims=True)
        return x * tf.rsqrt(mean_square + eps)
99 |
def MinibatchstateConcat(input, averaging='all'):
    """Append a minibatch standard-deviation feature map as an extra channel of a rank-4 tensor.

    Steps:
      1. Compute the std over the batch axis at every position (1e-8 added inside
         the sqrt).
      2. Average it to one scalar.
      3. Tile that scalar to [N, H, W, 1] and concatenate it on axis 3.

    Raises:
        ValueError: if ``averaging`` is anything other than 'all', the only mode
            implemented. Other values used to print "nothing" and then tile an
            un-averaged tensor into a shape that could not be concatenated.
    """
    if averaging != 'all':
        raise ValueError("Unsupported averaging mode %r; only 'all' is implemented" % (averaging,))

    def adjusted_std(x, **kwargs):
        # sqrt(mean((x - mean(x))^2) + 1e-8); the epsilon keeps sqrt away from 0.
        centered = x - tf.reduce_mean(x, **kwargs)
        return tf.sqrt(tf.reduce_mean(centered ** 2, **kwargs) + 1e-8)

    s = input.shape
    vals = adjusted_std(input, axis=0, keep_dims=True)  # [1, H, W, C]
    vals = tf.reduce_mean(vals, keep_dims=True)  # [1, 1, 1, 1]
    vals = tf.tile(vals, multiples=[s[0], s[1], s[2], 1])  # [N, H, W, 1]
    return tf.concat([input, vals], axis=3)
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import errno
3 | import numpy as np
4 | import scipy
5 | import scipy.misc
6 | import h5py
7 |
def mkdir_p(path):
    """Create directory ``path`` and any missing parents, like ``mkdir -p``.

    An already existing directory is not an error. Any other OSError (including
    ``path`` existing as a regular file) is re-raised.
    """
    try:
        os.makedirs(path)
    except OSError as err:
        already_a_dir = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_dir:
            raise
16 |
class CelebA(object):
    """CelebA image source that serves minibatches of image paths, shuffled once per pass.

    ``image_list`` holds whatever ``read_image_list(image_path)`` returns.
    ``getShapeForData`` turns a batch of those entries into preprocessed image arrays.
    """

    def __init__(self, image_path):

        self.dataname = "CelebA"
        self.channel = 3
        self.image_list = self.load_celebA(image_path=image_path)

    def load_celebA(self, image_path):
        """Return the image list produced by ``read_image_list(image_path)``."""
        images_list = read_image_list(image_path)
        return images_list

    def getShapeForData(self, filenames, resize_w=64):
        """Load ``filenames`` into one array of images scaled to [-1, 1].

        Each file goes through ``get_image`` with a 128-pixel center crop (randomly
        mirrored) and is resized to ``resize_w`` x ``resize_w``.
        """
        array = [get_image(batch_file, 128, is_crop=True, resize_w=resize_w,
                           is_grayscale=False) for batch_file in filenames]

        sample_images = np.array(array)
        return sample_images

    def getNextBatch(self, batch_num=0, batch_size=64):
        """Return minibatch number ``batch_num`` of ``batch_size`` entries.

        Batches cycle through ``ro_num = len(image_list) // batch_size - 1`` slots
        (at least 1). The list is reshuffled whenever a cycle starts, i.e. when
        ``batch_num % ro_num == 0``.
        """
        # Floor division keeps ro_num an int on Python 3 (it is used as a modulus and
        # in slice bounds). Clamping to 1 avoids modulo-by-zero for datasets with
        # fewer than 2 * batch_size entries.
        ro_num = max(len(self.image_list) // batch_size - 1, 1)
        slot = batch_num % ro_num
        if slot == 0:
            perm = np.arange(len(self.image_list))
            np.random.shuffle(perm)
            self.image_list = np.array(self.image_list)[perm]

            print ("images shuffle")

        return self.image_list[slot * batch_size: (slot + 1) * batch_size]
51 |
52 | return self.image_list[(batch_num % ro_num) * batch_size: (batch_num % ro_num + 1) * batch_size]
53 |
class CelebA_HQ(object):
    """CelebA-HQ source backed by a multi-resolution HDF5 file.

    ``<image_path>/celebA_hq`` must contain the datasets 'data2x2', 'data4x4', ...,
    'data1024x1024'. ``getNextBatch`` samples random images from one of them.
    """

    def __init__(self, image_path):
        self.dataname = "CelebA_HQ"
        self.channel = 3
        self.image_list = self.load_celeba_hq(image_path=image_path)
        self._base_key = 'data'
        # Number of stored images at every power-of-two level, 2x2 up to 1024x1024.
        level_keys = ['data%dx%d' % (2 ** p, 2 ** p) for p in range(1, 11)]
        self._len = {level: len(self.image_list[level]) for level in level_keys}

    def load_celeba_hq(self, image_path):
        """Open ``<image_path>/celebA_hq`` read-only with h5py and return the file handle."""
        h5_path = os.path.join(image_path, "celebA_hq")
        return h5py.File(h5_path, 'r')

    def getNextBatch(self, batch_size=64, resize_w=64):
        """Sample ``batch_size`` random images (with replacement) at ``resize_w`` resolution.

        Each image is mapped through ``v / 127.5 - 1.0`` and the batch is returned as
        float32, keeping the per-image layout in which it is stored.
        """
        key = '{}{}x{}'.format(self._base_key, resize_w, resize_w)
        picks = np.random.randint(self._len[key], size=batch_size)
        level = self.image_list[key]
        scaled = [level[i] / 127.5 - 1.0 for i in picks]
        return np.array(scaled, dtype=np.float32)
74 |
75 | return batch_x
76 |
def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale=False):
    """Read an image file with ``imread`` and preprocess it with ``transform``."""
    raw = imread(image_path, is_grayscale)
    return transform(raw, image_size, is_crop, resize_w)
79 |
def get_image_dat(image_path, image_size, is_crop=True, resize_w=64, is_grayscale=False):
    """Load an array saved with NumPy (via ``imread_dat``) and preprocess it with ``transform``."""
    raw = imread_dat(image_path, is_grayscale)
    return transform(raw, image_size, is_crop, resize_w)
82 |
def transform(image, npx=64, is_crop=False, resize_w=64):
    """Bring ``image`` to ``resize_w`` x ``resize_w`` and rescale its values.

    Args:
        image: input image array.
        npx: side length in pixels of the crop; used only when ``is_crop``.
        is_crop: if True, delegate to ``center_crop``, which crops, randomly
            flips and resizes. Otherwise resize the whole image with
            ``scipy.misc.imresize``.
        resize_w: output side length.

    Returns:
        numpy array equal to ``pixels / 127.5 - 1``.
    """
    if not is_crop:
        resized = scipy.misc.imresize(image, [resize_w, resize_w])
    else:
        resized = center_crop(image, npx, resize_w=resize_w)
    return np.array(resized) / 127.5 - 1
93 |
def center_crop(x, crop_h, crop_w=None, resize_w=64, flip=True):
    """Crop the central ``crop_h`` x ``crop_w`` window of ``x`` and resize it.

    Despite the name, this also performs data augmentation: when ``flip`` is
    True, the image is mirrored left-right with probability 0.5 before
    cropping, using the global NumPy RNG.

    Args:
        x: image array whose first two axes are height and width.
        crop_h: crop height in pixels.
        crop_w: crop width in pixels; defaults to ``crop_h``.
        resize_w: side length of the square output.
        flip: whether to apply the random horizontal flip (default True).

    Returns:
        The result of ``scipy.misc.imresize`` on the cropped window, sized
        ``resize_w`` x ``resize_w``.
    """
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    # Clamp the offsets: a crop larger than the image would otherwise give a
    # negative start index, which slices from the wrong end of the axis.
    top = max(int(round((h - crop_h) / 2.)), 0)
    left = max(int(round((w - crop_w) / 2.)), 0)

    # One RNG draw per call when flipping is enabled; none otherwise.
    if flip and np.random.uniform(0, 1, size=1) < 0.5:
        x = np.fliplr(x)

    return scipy.misc.imresize(x[top:top + crop_h, left:left + crop_w],
                               [resize_w, resize_w])
112 |
def save_images(images, size, image_path):
    """Convert ``images`` from [-1, 1] back to uint8 and save them as one grid.

    ``size`` is the (rows, cols) grid layout passed through to ``imsave``.
    """
    pixels = inverse_transform(images)
    return imsave(pixels, size, image_path)
115 |
def imread(path, is_grayscale=False):
    """Read the image file at ``path`` into a float64 numpy array.

    Args:
        path: image file path.
        is_grayscale: if True, read with ``flatten=True``.

    Returns:
        The array from ``scipy.misc.imread`` cast to float64.
    """
    # ``np.float`` was only an alias of the builtin ``float`` and was removed
    # in NumPy 1.24; ``float`` selects the same float64 dtype.
    if is_grayscale:
        return scipy.misc.imread(path, flatten=True).astype(float)
    return scipy.misc.imread(path).astype(float)
121 |
def imread_dat(path, is_grayscale=False):
    """Load a pre-decoded image array from a NumPy file with ``np.load``.

    ``is_grayscale`` is accepted only so the signature mirrors ``imread``.
    It is ignored. It now defaults to False, like ``imread``.
    """
    return np.load(path)
124 |
def imsave(images, size, path):
    """Tile ``images`` into a ``size`` = (rows, cols) grid and write it to ``path``."""
    grid = merge(images, size)
    return scipy.misc.imsave(path, grid)
127 |
def merge(images, size):
    """Tile a batch of images into one 3-channel float64 canvas.

    Args:
        images: array indexed as (N, h, w, ...). Each tile is assigned into
            an (h, w, 3) slot, so it must broadcast to that shape.
        size: (rows, cols) of the grid. Tiles fill it row by row.

    Returns:
        numpy array of shape (h * rows, w * cols, 3).
    """
    tile_h, tile_w = images.shape[1:3]
    n_rows, n_cols = size[0], size[1]
    canvas = np.zeros((tile_h * n_rows, tile_w * n_cols, 3))
    for k, tile in enumerate(images):
        row, col = divmod(k, n_cols)
        top, left = row * tile_h, col * tile_w
        canvas[top:top + tile_h, left:left + tile_w, :] = tile
    return canvas
136 |
def inverse_transform(image):
    """Map values from [-1, 1] back to uint8 pixels in [0, 255].

    This inverts the ``x / 127.5 - 1`` normalization applied when loading
    images. Results are clipped to [0, 255] before the cast, because casting
    out-of-range floats to uint8 wraps around (or is platform-dependent)
    instead of saturating. Fractional values are truncated by the cast.
    """
    return np.clip((image + 1.) * 127.5, 0, 255).astype(np.uint8)
139 |
def read_image_list(category):
    """Return the paths of the ``.jpg`` files in directory ``category``, shuffled.

    Paths are built as ``category + "/" + name``. The directory listing is
    sorted before shuffling, so the order depends only on the directory
    contents and the global NumPy RNG state.

    Returns:
        numpy array of path strings.
    """
    print("list file")
    # Match the file extension rather than a "jpg" substring anywhere in the
    # name, and avoid shadowing the builtin ``list``.
    filenames = [category + "/" + name
                 for name in sorted(os.listdir(category))
                 if name.endswith(".jpg")]
    print("list file ending!")
    perm = np.arange(len(filenames))
    np.random.shuffle(perm)
    return np.array(filenames)[perm]
156 |
157 |
158 |
159 |
160 |
--------------------------------------------------------------------------------