├── README.md
├── checkpoint
│   └── readme.txt
├── code
│   ├── lib
│   │   ├── __init__.py
│   │   ├── cgan_triple_alpha.py
│   │   ├── ops.py
│   │   ├── utils.py
│   │   └── vae_cgan_double_alpha.py
│   └── main_cgan_triple_alpha_celeba.py
├── data
│   ├── download.py
│   └── preprocess.py
└── samples
    └── readme.txt
/README.md:
--------------------------------------------------------------------------------
1 | # Generating images part by part with composite generative adversarial networks
2 | TensorFlow implementation of the paper "Generating Images Part by Part with Composite Generative Adversarial Networks" [1].
3 |
4 | The Composite GAN (CGAN) disentangles complicated factors of images with multiple generators, where each generator generates a part of the image. The parts are combined by an alpha blending process to produce a single new image. For example, it can generate the background, face, and hair sequentially with three generators. There is no supervision on what each generator should generate.
5 |
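6 | For intuition, here is a minimal sketch of the alpha blending step (it mirrors the composition used in 'code/lib/cgan_triple_alpha.py'; the function and array names are illustrative):
7 |
8 | ```
9 | import numpy as np
10 |
11 | def alpha_blend(rgbs, alphas):
12 |     # Composite the parts in order; for three parts this expands to
13 |     # (rgb1*a1*(1 - a2) + rgb2*a2)*(1 - a3) + rgb3*a3, as in the code.
14 |     out = np.zeros_like(rgbs[0])
15 |     for rgb, a in zip(rgbs, alphas):
16 |         out = out*(1 - a) + rgb*a
17 |     return out
18 | ```
19 |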
15 | #### Dependencies
16 | * tensorflow 0.10.0rc0+
17 | * h5py 2.3.1+
18 | * scipy 0.18.0+
19 | * Pillow 3.1.0+
20 |
21 | #### How to use
22 | First, go into the 'data' directory and download the celebA dataset.
23 |
24 | ```
25 | cd data
26 | python download.py celebA
27 | ```
28 |
29 | Preprocess the celebA dataset to create an HDF5 file ('celeba.hdf5'). This resizes the images to 64x64.
30 |
31 | ```
32 | python preprocess.py
33 | ```
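34 | To sanity-check the result (optional; this assumes the default output name 'celeba.hdf5' and the 'images' dataset written by 'preprocess.py'):
35 |
36 | ```
37 | python -c "import h5py; f = h5py.File('celeba.hdf5', 'r'); print(f['images'].shape); f.close()"
38 | ```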
34 |
35 | Finally, go into the 'code' directory and run 'main_cgan_triple_alpha_celeba.py'.
36 |
37 | ```
38 | cd ../code
39 | python main_cgan_triple_alpha_celeba.py
40 | ```
41 |
42 | You can see the samples in the 'samples' directory.
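43 | Checkpoints are written to the 'checkpoint' directory every 500 iterations, and training automatically resumes from the latest checkpoint when restarted.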
43 |
44 | ### References
45 |
46 | [1] Hanock Kwak and Byoung-Tak Zhang. "Generating Images Part by Part with Composite Generative Adversarial Networks." arXiv preprint arXiv:1607.05387 (2016).
47 |
--------------------------------------------------------------------------------
/checkpoint/readme.txt:
--------------------------------------------------------------------------------
1 | This folder is where the models are saved!
2 |
--------------------------------------------------------------------------------
/code/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hanock/generating_images_part_by_part/f934aa3f6d18bf635adc2311522fd9d7504f13e8/code/lib/__init__.py
--------------------------------------------------------------------------------
/code/lib/cgan_triple_alpha.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from glob import glob
4 | import tensorflow as tf
5 |
6 | from ops import *
7 | from utils import *
8 |
9 | class GAN(object):
10 | def __init__(self, sess, config):
11 |
12 | self.sess = sess
13 | self.batch_size = config.batch_size
14 | self.train_size = config.train_size
15 | self.image_size = 64
16 | self.image_shape = [64, 64, 3]
17 |
18 | self.checked_img = self.create_checked_img(self.image_size)
19 |
20 | self.z_dim = 128
21 |
22 | self.checkpoint_dir = config.checkpoint_dir
23 | self.sample_dir = config.sample_dir
24 | self.build_model()
25 |
26 | def build_model(self):
27 |
28 | self.images = tf.placeholder(tf.float32, [self.batch_size] + self.image_shape, name='real_images')
29 | self.sample_images= tf.placeholder(tf.float32, [self.batch_size] + self.image_shape, name='sample_images')
30 | self.z1 = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z1')
31 | self.z2 = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z2')
32 | self.z3 = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z3')
33 |
34 | # counter
35 | self.global_step = tf.Variable(0, name='global_step', trainable=False)
36 |
37 | #
38 | self.O1, self.O2, self.O3, self.G, self.A1, self.A2, self.A3 = self.generator(self.z1, self.z2, self.z3, feature=True)
39 | self.D = self.discriminator(self.images)
40 |
41 | self.D_ = self.discriminator(self.G, reuse=True)
42 |
43 | # alpha loss: for each mask, the first term pulls the total alpha area toward
44 | # 30% of the image; the second term (maximal at 0.5, zero at 0 and 1) pushes
45 | # the alpha values toward binary masks.
44 | self.alpha_loss1_1 = tf.reduce_sum(tf.abs(tf.reduce_sum(self.A1, [1, 2, 3]) - self.image_size*self.image_size*0.3))
45 | self.alpha_loss1_2 = tf.reduce_sum(tf.reduce_sum(-tf.square(self.A1 - 0.5) + 0.25, [1, 2, 3]))
46 | self.alpha_loss2_1 = tf.reduce_sum(tf.abs(tf.reduce_sum(self.A2, [1, 2, 3]) - self.image_size*self.image_size*0.3))
47 | self.alpha_loss2_2 = tf.reduce_sum(tf.reduce_sum(-tf.square(self.A2 - 0.5) + 0.25, [1, 2, 3]))
48 | self.alpha_loss3_1 = tf.reduce_sum(tf.abs(tf.reduce_sum(self.A3, [1, 2, 3]) - self.image_size*self.image_size*0.3))
49 | self.alpha_loss3_2 = tf.reduce_sum(tf.reduce_sum(-tf.square(self.A3 - 0.5) + 0.25, [1, 2, 3]))
50 | self.alpha_loss = self.alpha_loss1_1 + self.alpha_loss1_2 + self.alpha_loss2_1 + self.alpha_loss2_2 + self.alpha_loss3_1 + self.alpha_loss3_2
51 |
52 | # d loss
53 | self.d_loss_real = binary_cross_entropy_with_logits(tf.ones_like(self.D), self.D)
54 | self.d_loss_fake = binary_cross_entropy_with_logits(tf.zeros_like(self.D_), self.D_)
55 | self.d_loss = self.d_loss_real + self.d_loss_fake
56 |
57 | # g loss
58 | self.g_gan_loss = binary_cross_entropy_with_logits(tf.ones_like(self.D_), self.D_)
59 | self.g_loss = self.g_gan_loss + 0.000005*self.alpha_loss
60 |
61 |
62 | t_vars = tf.trainable_variables()
63 |
64 | self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
65 | self.g_vars = [var for var in t_vars if 'generator' in var.name]
66 |
67 | self.saver = tf.train.Saver()
68 |
69 | def train(self, data):
70 | """Train DCGAN"""
71 |
72 | data_size = len(data)
73 |
74 | d_optim = tf.train.AdamOptimizer(0.0002, beta1=0.5).minimize(self.d_loss, var_list=self.d_vars)
75 | g_optim = tf.train.AdamOptimizer(0.0004, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars)
76 | tf.initialize_all_variables().run()
77 |
78 | self.saver = tf.train.Saver()
79 |
80 | sample_images = self.transform(data[0:self.batch_size]).astype(np.float32)
81 | save_images(sample_images[0:100], [10, 10], '%s/sample_images.png' % (self.sample_dir))
82 |
83 |
84 | start_time = time.time()
85 |
86 | if self.load(self.checkpoint_dir):
87 | print(" [*] Load SUCCESS")
88 | counter = self.global_step.eval()
89 |
90 |
91 | for inf_loop in xrange(1000000):
92 |
93 | errD = 0
94 | errG = 0
95 |
96 | # random mini-batch
97 | i = np.random.randint(0, self.train_size - self.batch_size)
98 | batch_images = self.transform(data[i:i + self.batch_size]).astype(np.float32)
99 |
100 | # Update D network
101 | batch_z1 = self.random_z()
102 | batch_z2 = self.random_z()
103 | batch_z3 = self.random_z()
104 | self.sess.run(d_optim, feed_dict={ self.images: batch_images, self.z1: batch_z1, self.z2: batch_z2, self.z3: batch_z3 })
105 |
106 | # Update G network
107 | self.sess.run(g_optim, feed_dict={ self.z1: batch_z1, self.z2: batch_z2, self.z3: batch_z3 })
108 |
109 | counter += 1
110 |
111 | if np.mod(counter, 10) == 1:
112 | # random mini-batch
113 | i = np.random.randint(0, self.train_size - self.batch_size)
114 | batch_images = self.transform(data[i:i + self.batch_size]).astype(np.float32)
115 | batch_z1 = self.random_z()
116 | batch_z2 = self.random_z()
117 | batch_z3 = self.random_z()
118 |
119 | d_loss_fake = self.d_loss_fake.eval({self.z1: batch_z1, self.z2: batch_z2, self.z3: batch_z3 })
120 | d_loss_real = self.d_loss_real.eval({self.images: batch_images})
121 | errD = d_loss_fake + d_loss_real
122 | errG = self.g_loss.eval({self.z1: batch_z1, self.z2: batch_z2, self.z3: batch_z3})
123 |
124 | print("[%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
125 | % (counter, time.time() - start_time, errD, errG))
126 |
127 | if np.mod(counter, 200) == 1:
128 | #for gi in np.arange(1):
129 | samp_o1, samp_o2, samp_o3, samp = self.sess.run(
130 | [self.O1, self.O2, self.O3, self.G],
131 | feed_dict={self.z1: self.random_z(), self.z2: self.random_z(), self.z3: self.random_z()})
132 |
133 |
134 | samples = np.concatenate((samp_o1, samp_o2, samp_o3, samp), axis=0)
135 |
136 | save_images(samples, [32, 16], '%s/train_%05d.png' % (self.sample_dir, counter))
137 |
138 | if np.mod(counter, 500) == 1:
139 | self.save(self.checkpoint_dir, counter)
140 |
141 | def discriminator(self, image, reuse=False):
142 |
143 | with tf.variable_scope('discriminator', reuse=reuse):
144 | h0 = lrelu(conv2d(image, 64, name='d_h0_conv'))
145 |
146 | d_bn1 = batch_norm(self.batch_size, name='d_bn1')
147 | h1 = lrelu(d_bn1(conv2d(h0, 64*2, name='d_h1_conv')))
148 |
149 | d_bn2 = batch_norm(self.batch_size, name='d_bn2')
150 | h2 = lrelu(d_bn2(conv2d(h1, 64*4, name='d_h2_conv')))
151 |
152 | d_bn3 = batch_norm(self.batch_size, name='d_bn3')
153 | h3 = lrelu(d_bn3(conv2d(h2, 64*8, name='d_h3_conv')))
154 |
155 | h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h3_lin')
156 |
157 | return tf.nn.sigmoid(h4)
158 |
159 | def generator(self, z1, z2, z3, reuse=False, feature=False):
160 |
161 | def gen(h, gen_name='gen', reuse=False):
162 | with tf.variable_scope(gen_name, reuse=reuse):
163 | h0 = linear(h, 512*4*4, 'g_h0_lin')
164 | h0 = tf.reshape(h0, [-1, 4, 4, 512])
165 | h0 = tf.nn.relu(h0)
166 |
167 | h1 = deconv2d(h0, [self.batch_size, 8, 8, 256], name='g_h1')
168 | g_bn1 = batch_norm(self.batch_size, name='g_bn1')
169 | h1 = tf.nn.relu(g_bn1(h1))
170 |
171 | h2 = deconv2d(h1,[self.batch_size, 16, 16, 128], name='g_h2')
172 | g_bn2 = batch_norm(self.batch_size, name='g_bn2')
173 | h2 = tf.nn.relu(g_bn2(h2))
174 |
175 | h3 = deconv2d(h2, [self.batch_size, 32, 32, 64], name='g_h3')
176 | g_bn3 = batch_norm(self.batch_size, name='g_bn3')
177 | h3 = tf.nn.relu(g_bn3(h3))
178 |
179 | h4 = deconv2d(h3, [self.batch_size, 64, 64, 3], name='g_h4')
180 | alpha = deconv2d(h3, [self.batch_size, 64, 64, 1], name='g_a')
181 |
182 | return tf.sigmoid(h4), tf.sigmoid(alpha)
183 |
184 | with tf.variable_scope('generator', reuse=reuse):
185 | lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(128, forget_bias=0.0)
186 | state = lstm_cell.zero_state(self.batch_size, tf.float32)
187 |
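188 | # The shared LSTM consumes z1, z2, z3 in sequence; its state lets each generator condition on the previously generated parts.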
188 | with tf.variable_scope('g_rnn'):
189 | (cell_output1, state) = lstm_cell(z1, state)
190 | tf.get_variable_scope().reuse_variables()
191 | (cell_output2, state) = lstm_cell(z2, state)
192 | (cell_output3, state) = lstm_cell(z3, state)
193 |
194 | rgb1, alpha1 = gen(cell_output1, 'gen1')
195 | rgb2, alpha2 = gen(cell_output2, 'gen2')
196 | rgb3, alpha3 = gen(cell_output3, 'gen3')
197 |
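198 | # a_norm is the composite opacity of the full stack; dividing by it normalizes each mask for the checkerboard previews below.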
198 | a_norm = (alpha1*(1 - alpha2) + alpha2)*(1 - alpha3) + alpha3
199 | a1 = alpha1/a_norm
200 | a2 = alpha2/a_norm
201 | a3 = alpha3/a_norm
202 |
203 | o1 = self.checked_img*(1 - a1) + rgb1*a1
204 | o2 = self.checked_img*(1 - a2) + rgb2*a2
205 | o3 = self.checked_img*(1 - a3) + rgb3*a3
206 | o = (rgb1*alpha1*(1 - alpha2) + rgb2*alpha2)*(1 - alpha3) + rgb3*alpha3
207 |
208 | if feature:
209 | return o1*2 - 1, o2*2 - 1, o3*2 - 1, o*2 - 1, alpha1, alpha2, alpha3
210 |
211 | return o*2 - 1
212 |
213 | def random_z(self):
214 | return np.random.normal(size=(self.batch_size, self.z_dim))
215 |
216 | def random_fix_z(self):
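217 | # Broadcast a single random z to every row so the whole batch shares one latent code.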
217 | r = np.zeros([self.batch_size, self.z_dim])
218 | r[0:,:] = np.random.normal(size=(1, self.z_dim))
219 | return r
220 |
221 | def save(self, checkpoint_dir, step):
222 | self.sess.run(self.global_step.assign(step))
223 |
224 | model_name = "GAN.model"
225 |
226 | if not os.path.exists(checkpoint_dir):
227 | os.makedirs(checkpoint_dir)
228 |
229 | self.saver.save(self.sess,
230 | os.path.join(checkpoint_dir, model_name),
231 | global_step=step)
232 |
233 | def load(self, checkpoint_dir):
234 | print(" [*] Reading checkpoints...")
235 |
236 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
237 | if ckpt and ckpt.model_checkpoint_path:
238 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
239 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
240 | return True
241 | else:
242 | return False
243 |
244 | def transform(self, X):
245 | return X*2 - 1
246 |
247 | def inverse_transform(self, X):
248 | X = (X+1.)/2.
249 | return X
250 |
251 | def create_checked_img(self, size):
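252 | # Build a black-and-white checkerboard used as the background when previewing a single part through its alpha mask.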
252 | arr = np.arange(size)
253 | chk1 = arr[np.mod(arr, size/8) < size/16]
254 | chk2 = arr[np.mod(arr, size/8) >= size/16]
255 | a = np.meshgrid(chk1, chk1)
256 | b = np.meshgrid(chk2, chk2)
257 |
258 | img = np.ones([size, size, 3], dtype=np.float32)
259 |
260 | img[a[0], a[1]] = 0
261 | img[b[0], b[1]] = 0
262 |
263 | return img
--------------------------------------------------------------------------------
/code/lib/ops.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from tensorflow.python.framework import ops
6 | from tensorflow.python import control_flow_ops
7 |
8 | from utils import *
9 |
10 | class batch_norm(object):
11 | """Code modification of http://stackoverflow.com/a/33950177"""
12 | def __init__(self, batch_size, epsilon=1e-5, momentum = 0.1, name="batch_norm"):
13 | with tf.variable_scope(name) as scope:
14 | self.epsilon = epsilon
15 | self.momentum = momentum
16 | self.batch_size = batch_size
17 |
18 | self.ema = tf.train.ExponentialMovingAverage(decay=self.momentum)
19 | self.name=name
20 |
21 | def __call__(self, x, train=True):
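22 | # Note: batch statistics are always used here; the 'train' flag and the EMA created in __init__ are unused in this version.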
22 | shape = x.get_shape().as_list()
23 |
24 | with tf.variable_scope(self.name) as scope:
25 | self.gamma = tf.get_variable("gamma", [shape[-1]],
26 | initializer=tf.random_normal_initializer(1., 0.02))
27 | self.beta = tf.get_variable("beta", [shape[-1]],
28 | initializer=tf.constant_initializer(0.))
29 |
30 | self.mean, self.variance = tf.nn.moments(x, [0, 1, 2])
31 |
32 | return tf.nn.batch_norm_with_global_normalization(
33 | x, self.mean, self.variance, self.beta, self.gamma, self.epsilon,
34 | scale_after_normalization=True)
35 |
36 |
37 | def batch_norm2(x, n_out, phase_train, scope='bn', affine=True):
38 | """
39 | http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow/33950177
40 | Batch normalization on convolutional maps.
41 | Args:
42 | x: Tensor, 4D BHWD input maps
43 | n_out: integer, depth (channel) of input maps
44 | phase_train: boolean tf.Variable, true indicates training phase
45 | scope: string, variable scope
46 | affine: whether to affine-transform outputs
47 | Return:
48 | normed: batch-normalized maps
49 | """
50 | with tf.variable_scope(scope):
51 | beta = tf.Variable(tf.constant(0.0, shape=[n_out]),
52 | name='beta', trainable=True)
53 | gamma = tf.Variable(tf.constant(1.0, shape=[n_out]),
54 | name='gamma', trainable=affine)
55 |
56 | batch_mean, batch_var = tf.nn.moments(x, [0,1,2], name='moments')
57 | ema = tf.train.ExponentialMovingAverage(decay=0.9)
58 | ema_apply_op = ema.apply([batch_mean, batch_var])
59 | ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
60 | def mean_var_with_update():
61 | with tf.control_dependencies([ema_apply_op]):
62 | return tf.identity(batch_mean), tf.identity(batch_var)
63 | mean, var = control_flow_ops.cond(phase_train,
64 | mean_var_with_update,
65 | lambda: (ema_mean, ema_var))
66 |
67 | normed = tf.nn.batch_norm_with_global_normalization(x, mean, var,
68 | beta, gamma, 1e-3, affine)
69 | return normed
70 |
71 |
72 | def binary_cross_entropy_with_logits(logits, targets, name=None):
73 | """Computes binary cross entropy given `logits`.
74 | For brevity, let `x = logits`, `z = targets`. The logistic loss is
75 | loss(x, z) = - sum_i (x[i] * log(z[i]) + (1 - x[i]) * log(1 - z[i]))
76 | Args:
77 | logits: A `Tensor` of type `float32` or `float64`.
78 | targets: A `Tensor` of the same type and shape as `logits`.
79 | """
80 | eps = 1e-12
81 | with ops.op_scope([logits, targets], name, "bce_loss") as name:
82 | logits = ops.convert_to_tensor(logits, name="logits")
83 | targets = ops.convert_to_tensor(targets, name="targets")
84 | return tf.reduce_mean(-(logits * tf.log(targets + eps) +
85 | (1. - logits) * tf.log(1. - targets + eps)))
86 |
87 | # zero-based
88 | def dense_to_one_hot(labels_dense, num_classes=10):
89 | """Convert class labels from scalars to one-hot vectors."""
90 | num_labels = labels_dense.shape[0]
91 | index_offset = np.arange(num_labels) * num_classes
92 | labels_one_hot = np.zeros((num_labels, num_classes))
93 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
94 | return labels_one_hot
95 |
96 | def conv_cond_concat(x, y):
97 | """Concatenate conditioning vector on feature map axis."""
98 | x_shapes = x.get_shape()
99 | y_shapes = y.get_shape()
100 | return tf.concat(3, [x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])])
101 |
102 | def conv2d(input_, output_dim,
103 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
104 | name="conv2d"):
105 | with tf.variable_scope(name):
106 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
107 | initializer=tf.truncated_normal_initializer(stddev=stddev))
108 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
109 |
110 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
111 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
112 |
113 | return conv
114 |
115 | def conv2d_v2(input_, output_dim,
116 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
117 | name="conv2d"):
118 | with tf.variable_scope(name):
119 |
120 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
121 | initializer=tf.truncated_normal_initializer(stddev=stddev))
122 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
123 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
124 |
125 | return tf.nn.bias_add(conv, biases)
126 |
127 | def deconv2d(input_, output_shape,
128 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
129 | name="deconv2d", with_w=False):
130 | with tf.variable_scope(name):
131 | # filter : [height, width, output_channels, in_channels]
132 | w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
133 | initializer=tf.random_normal_initializer(stddev=stddev))
134 |
135 | try:
136 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
137 | strides=[1, d_h, d_w, 1])
138 |
139 | # Support for versions of TensorFlow before 0.7.0
140 | except AttributeError:
141 | deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
142 | strides=[1, d_h, d_w, 1])
143 |
144 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
145 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
146 |
147 | if with_w:
148 | return deconv, w, biases
149 | else:
150 | return deconv
151 |
152 | def lrelu(x, leak=0.2, name="lrelu"):
153 | with tf.variable_scope(name):
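154 | # f1*x + f2*|x| equals max(x, leak*x): a leaky ReLU written without tf.maximum.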
154 | f1 = 0.5 * (1 + leak)
155 | f2 = 0.5 * (1 - leak)
156 | return f1 * x + f2 * abs(x)
157 |
158 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
159 | shape = input_.get_shape().as_list()
160 |
161 | with tf.variable_scope(scope or "Linear"):
162 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
163 | tf.random_normal_initializer(stddev=stddev))
164 | bias = tf.get_variable("bias", [output_size],
165 | initializer=tf.constant_initializer(bias_start))
166 | if with_w:
167 | return tf.matmul(input_, matrix) + bias, matrix, bias
168 | else:
169 | return tf.matmul(input_, matrix) + bias
--------------------------------------------------------------------------------
/code/lib/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Some code from https://github.com/Newmu/dcgan_code
3 | """
4 | import math
5 | import json
6 | import random
7 | import pprint
8 | import scipy.misc
9 | import numpy as np
10 | from cStringIO import StringIO
11 | from time import gmtime, strftime
12 |
13 | pp = pprint.PrettyPrinter()
14 |
15 | get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])
16 |
17 |
18 |
19 | def save_sentence_pairs(X, Y, dic, fname, CODE_EOF=1):
20 | data_num = X.shape[0]
21 | max_time = X.shape[1]
22 |
23 | file_str = StringIO()
24 | for i in xrange(data_num):
25 | write_sentence(file_str, dic, X[i, :], CODE_EOF)
26 | write_sentence(file_str, dic, Y[i, :], CODE_EOF)
27 | file_str.write('\n')
28 |
29 | with open(fname, 'w') as fp:
30 | fp.write(file_str.getvalue())
31 |
32 | def save_sentences(X, dic, fname, CODE_EOF=1):
33 | data_num = X.shape[0]
34 | max_time = X.shape[1]
35 |
36 | file_str = StringIO()
37 | for i in xrange(data_num):
38 | write_sentence(file_str, dic, X[i, :], CODE_EOF)
39 | file_str.write('\n')
40 |
41 | with open(fname, 'w') as fp:
42 | fp.write(file_str.getvalue())
43 |
44 | def write_sentence(fp, dic, sentence, CODE_EOF=1):
45 | max_time = sentence.shape[0]
46 |
47 | for i in xrange(max_time):
48 | v = sentence[i]
49 | if v == 0:
50 | break
51 | if v == CODE_EOF:
52 | fp.write('EOF')
53 | break
54 |
55 | try:
56 | w = dic.keys()[dic.values().index(v)]
57 | except ValueError:
58 | w = 'NULL'
59 | fp.write(w + ' ')
60 | fp.write('\n')
61 |
62 | def get_image(image_path, image_size, is_crop=True):
63 | return transform(imread(image_path), image_size, is_crop)
64 |
65 | def save_images(images, size, image_path):
66 | return imsave(inverse_transform(images), size, image_path)
67 |
68 | def imread(path):
69 | return scipy.misc.imread(path).astype(np.float)
70 |
71 | def merge_images(images, size):
72 | return inverse_transform(images)
73 |
74 | def merge(images, size):
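75 | # Tile a batch of images into a size[0] x size[1] (rows x cols) grid.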
75 | h, w = images.shape[1], images.shape[2]
76 | img = np.zeros((h * size[0], w * size[1], 3))
77 |
78 | for idx, image in enumerate(images):
79 | i = idx % size[1]
80 | j = idx / size[1]
81 | img[j*h:j*h+h, i*w:i*w+w, :] = image
82 |
83 | return img
84 |
85 | def imsave(images, size, path):
86 | return scipy.misc.imsave(path, merge(images, size))
87 |
88 | def center_crop(x, crop_h, crop_w=None, resize_w=64):
89 | if crop_w is None:
90 | crop_w = crop_h
91 | h, w = x.shape[:2]
92 | j = int(round((h - crop_h)/2.))
93 | i = int(round((w - crop_w)/2.))
94 | return scipy.misc.imresize(x[j:j+crop_h, i:i+crop_w],
95 | [resize_w, resize_w])
96 |
97 | def transform(image, npx=64, is_crop=True):
98 | # npx : # of pixels width/height of image
99 | if is_crop:
100 | cropped_image = center_crop(image, npx)
101 | else:
102 | cropped_image = image
103 | return np.array(cropped_image)/127.5 - 1.
104 |
105 | def inverse_transform(images):
106 | return (images+1.)/2.
107 |
108 |
109 | def to_json(output_path, *layers):
110 | with open(output_path, "w") as layer_f:
111 | lines = ""
112 | for w, b, bn in layers:
113 | layer_idx = w.name.split('/')[0].split('h')[1]
114 |
115 | B = b.eval()
116 |
117 | if "lin/" in w.name:
118 | W = w.eval()
119 | depth = W.shape[1]
120 | else:
121 | W = np.rollaxis(w.eval(), 2, 0)
122 | depth = W.shape[0]
123 |
124 | biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]}
125 | if bn != None:
126 | gamma = bn.gamma.eval()
127 | beta = bn.beta.eval()
128 |
129 | gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]}
130 | beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]}
131 | else:
132 | gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []}
133 | beta = {"sy": 1, "sx": 1, "depth": 0, "w": []}
134 |
135 | if "lin/" in w.name:
136 | fs = []
137 | for w in W.T:
138 | fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]})
139 |
140 | lines += """
141 | var layer_%s = {
142 | "layer_type": "fc",
143 | "sy": 1, "sx": 1,
144 | "out_sx": 1, "out_sy": 1,
145 | "stride": 1, "pad": 0,
146 | "out_depth": %s, "in_depth": %s,
147 | "biases": %s,
148 | "gamma": %s,
149 | "beta": %s,
150 | "filters": %s
151 | };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs)
152 | else:
153 | fs = []
154 | for w_ in W:
155 | fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]})
156 |
157 | lines += """
158 | var layer_%s = {
159 | "layer_type": "deconv",
160 | "sy": 5, "sx": 5,
161 | "out_sx": %s, "out_sy": %s,
162 | "stride": 2, "pad": 1,
163 | "out_depth": %s, "in_depth": %s,
164 | "biases": %s,
165 | "gamma": %s,
166 | "beta": %s,
167 | "filters": %s
168 | };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2),
169 | W.shape[0], W.shape[3], biases, gamma, beta, fs)
170 | layer_f.write(" ".join(lines.replace("'","").split()))
171 |
172 | def make_gif(images, fname, duration=2, true_image=False):
173 | import moviepy.editor as mpy
174 |
175 | def make_frame(t):
176 | try:
177 | x = images[int(len(images)/duration*t)]
178 | except:
179 | x = images[-1]
180 |
181 | if true_image:
182 | return x.astype(np.uint8)
183 | else:
184 | return ((x+1)/2*255).astype(np.uint8)
185 |
186 | clip = mpy.VideoClip(make_frame, duration=duration)
187 | clip.write_gif(fname, fps = len(images) / duration)
188 |
189 | def visualize(sess, dcgan, config, option):
190 | if option == 0:
191 | z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
192 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
193 | save_images(samples[0:121], [11, 11], './samples/test_%s.png' % strftime("%Y-%m-%d %H:%M:%S", gmtime()))
194 | elif option == 1:
195 | values = np.arange(0, 1, 1./config.batch_size)
196 | for idx in xrange(100):
197 | print(" [*] %d" % idx)
198 | z_sample = np.zeros([config.batch_size, dcgan.z_dim])
199 | for kdx, z in enumerate(z_sample):
200 | z[idx] = values[kdx]
201 |
202 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
203 | save_images(samples, [8, 8], './samples/test_arange_%s.png' % (idx))
204 |
205 | elif option == 2:
206 | values = np.arange(0, 1, 1./config.batch_size)
207 | for idx in [random.randint(0, 99) for _ in xrange(100)]:
208 | print(" [*] %d" % idx)
209 | z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
210 | z_sample = np.tile(z, (config.batch_size, 1))
211 | #z_sample = np.zeros([config.batch_size, dcgan.z_dim])
212 | for kdx, z in enumerate(z_sample):
213 | z[idx] = values[kdx]
214 |
215 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
216 | make_gif(samples, './samples/test_gif_%s.gif' % (idx))
217 |
218 | elif option == 3:
219 | values = np.arange(0, 1, 1./config.batch_size)
220 | for idx in xrange(100):
221 | print(" [*] %d" % idx)
222 | z_sample = np.zeros([config.batch_size, dcgan.z_dim])
223 | for kdx, z in enumerate(z_sample):
224 | z[idx] = values[kdx]
225 |
226 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
227 | make_gif(samples, './samples/test_gif_%s.gif' % (idx))
228 |
229 | elif option == 4:
230 | image_set = []
231 | values = np.arange(0, 1, 1./config.batch_size)
232 |
233 | for idx in xrange(100):
234 | print(" [*] %d" % idx)
235 | z_sample = np.zeros([config.batch_size, dcgan.z_dim])
236 | for kdx, z in enumerate(z_sample):
237 | z[idx] = values[kdx]
238 |
239 | image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
240 | make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))
241 |
242 | new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) for idx in range(64) + range(63, -1, -1)]
243 | make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8)
--------------------------------------------------------------------------------
/code/lib/vae_cgan_double_alpha.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from glob import glob
4 | import tensorflow as tf
5 |
6 | from ops import *
7 | from utils import *
8 |
9 | class GAN(object):
10 | def __init__(self, sess, config):
11 |
12 | self.sess = sess
13 | self.batch_size = config.batch_size
14 | self.train_size = config.train_size
15 | self.image_size = 64
16 | self.image_shape = [64, 64, 3]
17 |
18 | self.checked_img = self.create_checked_img(self.image_size)
19 |
20 | self.z_dim = 128
21 |
22 | self.checkpoint_dir = config.checkpoint_dir
23 | self.sample_dir = config.sample_dir
24 | self.build_model()
25 |
26 | def build_model(self):
27 |
28 | self.images = tf.placeholder(tf.float32, [self.batch_size] + self.image_shape, name='real_images')
29 | self.sample_images= tf.placeholder(tf.float32, [self.batch_size] + self.image_shape, name='sample_images')
30 | self.z1 = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z1')
31 | self.z2 = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z2')
32 |
33 | # counter
34 | self.global_step = tf.Variable(0, name='global_step', trainable=False)
35 |
36 | #
37 | self.mean1, self.stddev1, self.mean2, self.stddev2 = self.encoder(self.images)
38 | self.vae_output, self.O1, self.O2, self.A1, self.A2 = self.generator(self.mean1, self.stddev1, self.mean2, self.stddev2)
39 | self.D = self.discriminator(self.images)
40 |
41 | self.G, _, _, _, _ = self.generator(self.z1, tf.zeros([self.batch_size, self.z_dim]), self.z2, tf.zeros([self.batch_size, self.z_dim]), reuse=True)
42 | self.D_ = self.discriminator(self.G, reuse=True)
43 |
44 | self.feature_out = self.discriminator(self.vae_output, reuse=True, feature=True)
45 | self.feature_img = self.discriminator(self.images, reuse=True, feature=True)
46 |
47 |
48 | # alpha loss (same form as in cgan_triple_alpha.py): pull each mask's area toward 30% of the image and push alpha values toward 0/1.
49 | self.alpha_loss1_1 = tf.reduce_sum(tf.abs(tf.reduce_sum(self.A1, [1, 2, 3]) - self.image_size*self.image_size*0.3))
50 | self.alpha_loss1_2 = tf.reduce_sum(tf.reduce_sum(-tf.square(self.A1 - 0.5) + 0.25, [1, 2, 3]))
51 | self.alpha_loss2_1 = tf.reduce_sum(tf.abs(tf.reduce_sum(self.A2, [1, 2, 3]) - self.image_size*self.image_size*0.3))
52 | self.alpha_loss2_2 = tf.reduce_sum(tf.reduce_sum(-tf.square(self.A2 - 0.5) + 0.25, [1, 2, 3]))
53 | self.alpha_loss = self.alpha_loss1_1 + self.alpha_loss1_2 + self.alpha_loss2_1 + self.alpha_loss2_2
54 |
55 | # prior loss: KL divergence between N(mean, stddev^2) and the standard normal prior.
56 | self.prior_loss1 = tf.reduce_sum(0.5 * (tf.square(self.mean1) + tf.square(self.stddev1) - 2.0 * tf.log(self.stddev1 + 1e-8) - 1.0))
57 | self.prior_loss2 = tf.reduce_sum(0.5 * (tf.square(self.mean2) + tf.square(self.stddev2) - 2.0 * tf.log(self.stddev2 + 1e-8) - 1.0))
58 | self.prior_loss = self.prior_loss1 + self.prior_loss2
59 |
60 | # recon loss
61 | self.recon_loss_raw = tf.reduce_sum(tf.square(self.vae_output - self.images))
62 | self.recon_loss_adv = tf.reduce_sum(tf.square(self.feature_out - self.feature_img))
63 |
64 | # e loss
65 | self.e_loss = self.prior_loss + self.recon_loss_adv + 0.2*self.recon_loss_raw + 0.2*self.alpha_loss
66 |
67 | # g loss
68 | self.g_gan_loss = binary_cross_entropy_with_logits(tf.ones_like(self.D_), self.D_)
69 | self.g_loss = self.g_gan_loss + 0.0005*self.recon_loss_adv + 0.0001*self.recon_loss_raw + 0.0001*self.alpha_loss
70 |
71 | # d loss
72 | self.d_loss_real = binary_cross_entropy_with_logits(tf.ones_like(self.D), self.D)
73 | self.d_loss_fake = binary_cross_entropy_with_logits(tf.zeros_like(self.D_), self.D_)
74 | self.d_loss = self.d_loss_real + self.d_loss_fake
75 |
76 |
77 | t_vars = tf.trainable_variables()
78 |
79 | self.e_vars = [var for var in t_vars if 'encoder' in var.name]
80 | self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
81 | self.g_vars = [var for var in t_vars if 'generator' in var.name]
82 |
83 | self.saver = tf.train.Saver()
84 |
85 | def train(self, data):
86 | """Train DCGAN"""
87 |
88 | data_size = len(data)
89 |
90 | e_optim = tf.train.AdamOptimizer(0.0002, beta1=0.5).minimize(self.e_loss, var_list=self.e_vars)
91 | d_optim = tf.train.AdamOptimizer(0.0002, beta1=0.5).minimize(self.d_loss, var_list=self.d_vars)
92 | g_optim = tf.train.AdamOptimizer(0.0004, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars)
93 | tf.initialize_all_variables().run()
94 |
95 | self.saver = tf.train.Saver()
96 |
97 | sample_images = self.transform(data[0:self.batch_size]).astype(np.float32)
98 | save_images(sample_images, [8, 16], '%s/sample_images.png' % (self.sample_dir))
99 |
100 |
101 | start_time = time.time()
102 |
103 | if self.load(self.checkpoint_dir):
104 | print(" [*] Load SUCCESS")
105 | counter = self.global_step.eval()
106 |
107 |
108 | for inf_loop in xrange(1000000):
109 |
110 | errD = 0
111 | errG = 0
112 |
113 | # random mini-batch
114 | i = np.random.randint(0, self.train_size - self.batch_size)
115 | batch_images = self.transform(data[i:i + self.batch_size]).astype(np.float32)
116 |
117 | # Update E network
118 | self.sess.run([e_optim], feed_dict={ self.images: batch_images })
119 |
120 | # Update D network
121 | batch_z1 = self.random_z()
122 | batch_z2 = self.random_z()
123 | self.sess.run(d_optim, feed_dict={ self.images: batch_images, self.z1: batch_z1, self.z2: batch_z2 })
124 |
125 | # Update G network
126 | self.sess.run(g_optim, feed_dict={ self.images: batch_images, self.z1: batch_z1, self.z2: batch_z2 })
127 |
128 | counter += 1
129 |
130 | if np.mod(counter, 10) == 1:
131 | # random mini-batch
132 | i = np.random.randint(0, self.train_size - self.batch_size)
133 | batch_images = self.transform(data[i:i + self.batch_size]).astype(np.float32)
134 | batch_z1 = self.random_z()
135 | batch_z2 = self.random_z()
136 |
137 | recon_loss_adv = self.recon_loss_adv.eval({self.images: batch_images})
138 | recon_loss_raw = self.recon_loss_raw.eval({self.images: batch_images})
139 | alpha_loss = self.alpha_loss.eval({self.images: batch_images})
140 | g_gan_loss = self.g_gan_loss.eval({self.z1: batch_z1, self.z2: batch_z2})
141 | d_loss = self.d_loss.eval({self.z1: batch_z1, self.z2: batch_z2, self.images: batch_images})
142 |
143 | print("[%5d] time: %4.4f, adv: %.1f, raw: %.1f, alpha: %.1f, d: %.8f, g_gan: %.8f" \
144 | % (counter, time.time() - start_time, recon_loss_adv, recon_loss_raw, alpha_loss, d_loss, g_gan_loss))
145 |
146 | if np.mod(counter, 100) == 1:
147 | #for gi in np.arange(1):
148 | samp_o1, samp_o2, samp = self.sess.run(
149 | [self.O1, self.O2, self.vae_output],
150 | feed_dict={self.images: sample_images})
151 |
152 | samp_fix = self.sess.run(
153 | self.G,
154 | feed_dict={self.z1: self.random_fix_z(), self.z2: self.random_z()})
155 |
156 | samples = np.concatenate((samp_fix, samp_o1, samp_o2, samp), axis=0)
157 |
158 | save_images(samples, [32, 16], '%s/train_%05d.png' % (self.sample_dir, counter))
159 |
160 |
161 | if np.mod(counter, 500) == 1:
162 | self.save(self.checkpoint_dir, counter)
163 |
164 | def encoder(self, image, reuse=False):
165 |
166 | def enc(image, name='enc', reuse=False):
167 | with tf.variable_scope(name, reuse=reuse):
168 | h0 = lrelu(conv2d(image, 64, name='e_h0_conv'))
169 |
170 | d_bn1 = batch_norm(self.batch_size, name='e_bn1')
171 | h1 = lrelu(d_bn1(conv2d(h0, 64*2, name='e_h1_conv')))
172 |
173 | d_bn2 = batch_norm(self.batch_size, name='e_bn2')
174 | h2 = lrelu(d_bn2(conv2d(h1, 64*4, name='e_h2_conv')))
175 |
176 | d_bn3 = batch_norm(self.batch_size, name='e_bn3')
177 | h3 = lrelu(d_bn3(conv2d(h2, 64*8, name='e_h3_conv')))
178 |
179 | h3_flat = tf.reshape(h3, [self.batch_size, -1])
180 | u = linear(h3_flat, self.z_dim, 'e_u_lin')
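181 | # 'e_s_lin' below predicts the log-variance; sqrt(exp(.)) converts it to a stddev.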
181 | s = tf.sqrt(tf.exp(linear(h3_flat, self.z_dim, 'e_s_lin')))
182 |
183 | return u, s
184 |
185 | with tf.variable_scope('encoder', reuse=reuse):
186 | u1, s1 = enc(image, 'enc1', reuse)
187 | u2, s2 = enc(image, 'enc2', reuse)
188 |
189 | return u1, s1, u2, s2
190 |
191 | def discriminator(self, image, feature=False, reuse=False):
192 |
193 | with tf.variable_scope('discriminator', reuse=reuse):
194 | h0 = lrelu(conv2d(image, 64, name='d_h0_conv'))
195 |
196 | d_bn1 = batch_norm(self.batch_size, name='d_bn1')
197 | h1 = lrelu(d_bn1(conv2d(h0, 64*2, name='d_h1_conv')))
198 |
199 | d_bn2 = batch_norm(self.batch_size, name='d_bn2')
200 | h2 = lrelu(d_bn2(conv2d(h1, 64*4, name='d_h2_conv')))
201 |
202 | d_bn3 = batch_norm(self.batch_size, name='d_bn3')
203 | h3 = lrelu(d_bn3(conv2d(h2, 64*8, name='d_h3_conv')))
204 |
205 | h3_flat = tf.reshape(h3, [self.batch_size, -1])
206 | h4 = linear(h3_flat, 1, 'd_h3_lin')
207 |
208 | if feature:
209 | return h3_flat
210 |
211 | return tf.nn.sigmoid(h4)
212 |
213 | def generator(self, u1, s1, u2, s2, name='generator', reuse=False, feature=False):
214 |
215 | def gen(h, gen_name='gen', reuse=False):
216 | with tf.variable_scope(gen_name, reuse=reuse):
217 | h0 = linear(h, 512*4*4, 'g_h0_lin')
218 | h0 = tf.reshape(h0, [-1, 4, 4, 512])
219 | h0 = tf.nn.relu(h0)
220 |
221 | h1 = deconv2d(h0, [self.batch_size, 8, 8, 256], name='g_h1')
222 | g_bn1 = batch_norm(self.batch_size, name='g_bn1')
223 | h1 = tf.nn.relu(g_bn1(h1))
224 |
225 | h2 = deconv2d(h1,[self.batch_size, 16, 16, 128], name='g_h2')
226 | g_bn2 = batch_norm(self.batch_size, name='g_bn2')
227 | h2 = tf.nn.relu(g_bn2(h2))
228 |
229 | h3 = deconv2d(h2, [self.batch_size, 32, 32, 64], name='g_h3')
230 | g_bn3 = batch_norm(self.batch_size, name='g_bn3')
231 | h3 = tf.nn.relu(g_bn3(h3))
232 |
233 | h4 = deconv2d(h3, [self.batch_size, 64, 64, 3], name='g_h4')
234 | alpha = deconv2d(h3, [self.batch_size, 64, 64, 1], name='g_a')
235 |
236 | return tf.sigmoid(h4), tf.sigmoid(alpha)
237 |
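238 | # Reparameterization trick: z = mean + e*stddev, with the same noise e shared by both latent codes.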
238 | e = tf.random_normal([self.batch_size, self.z_dim])
239 | samples1 = u1 + e * s1
240 | samples2 = u2 + e * s2
241 |
242 | with tf.variable_scope(name, reuse=reuse):
243 | lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.z_dim, forget_bias=0.0)
244 | state = lstm_cell.zero_state(self.batch_size, tf.float32)
245 |
246 | with tf.variable_scope('g_rnn'):
247 | (cell_output1, state) = lstm_cell(samples1, state)
248 | tf.get_variable_scope().reuse_variables()
249 | (cell_output2, state) = lstm_cell(samples2, state)
250 |
251 | rgb1, alpha1 = gen(cell_output1, 'gen1')
252 | rgb2, alpha2 = gen(cell_output2, 'gen2')
253 |
254 | a_norm = alpha1*(1 - alpha2) + alpha2
255 | a1 = alpha1/a_norm
256 | a2 = alpha2/a_norm
257 |
258 | o1 = self.checked_img*(1 - a1) + rgb1*a1
259 | o2 = self.checked_img*(1 - a2) + rgb2*a2
260 | o = rgb1*alpha1*(1 - alpha2) + rgb2*alpha2
261 |
262 | return o*2 - 1, o1*2 - 1, o2*2 - 1, alpha1, alpha2
263 |
264 | def random_z(self):
265 | return np.random.normal(size=(self.batch_size, self.z_dim))
266 |
267 | def random_fix_z(self):
268 | r = np.zeros([self.batch_size, self.z_dim])
269 | r[0:,:] = np.random.normal(size=(1, self.z_dim))
270 | return r
271 |
272 | def save(self, checkpoint_dir, step):
273 | self.sess.run(self.global_step.assign(step))
274 |
275 | model_name = "GAN.model"
276 |
277 | if not os.path.exists(checkpoint_dir):
278 | os.makedirs(checkpoint_dir)
279 |
280 | self.saver.save(self.sess,
281 | os.path.join(checkpoint_dir, model_name),
282 | global_step=step)
283 |
284 | def load(self, checkpoint_dir):
285 | print(" [*] Reading checkpoints...")
286 |
287 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
288 | if ckpt and ckpt.model_checkpoint_path:
289 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
290 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
291 | return True
292 | else:
293 | return False
294 |
295 | def transform(self, X):
296 | return X*2 - 1
297 |
298 | def inverse_transform(self, X):
299 | X = (X+1.)/2.
300 | return X
301 |
302 | def create_checked_img(self, size):
303 | arr = np.arange(size)
304 | chk1 = arr[np.mod(arr, size/8) < size/16]
305 | chk2 = arr[np.mod(arr, size/8) >= size/16]
306 | a = np.meshgrid(chk1, chk1)
307 | b = np.meshgrid(chk2, chk2)
308 |
309 | img = np.ones([size, size, 3], dtype=np.float32)
310 |
311 | img[a[0], a[1]] = 0
312 | img[b[0], b[1]] = 0
313 |
314 | return img
--------------------------------------------------------------------------------
/code/main_cgan_triple_alpha_celeba.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | import h5py
5 |
6 | from lib.cgan_triple_alpha import GAN
7 |
8 | flags = tf.app.flags
9 | flags.DEFINE_integer("batch_size", 128, "The size of batch images")
10 | flags.DEFINE_integer("train_size", 160000, "The size of train images")
11 | flags.DEFINE_string("dataset", "celeba", "The name of dataset")
12 | flags.DEFINE_string("data_dir", "../data/celeba.hdf5", "Directory of dataset")
13 | flags.DEFINE_string("checkpoint_dir", "../checkpoint/cgan_triple_alpha_celeba", "Directory name to save the checkpoints [checkpoint]")
14 | flags.DEFINE_string("sample_dir", "../samples/cgan_triple_alpha_celeba", "Directory name to save the image samples [samples]")
15 | FLAGS = flags.FLAGS
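16 | # Any flag above can be overridden from the command line, e.g. --batch_size 64.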
16 |
17 | def main(_):
18 |
19 | if not os.path.exists(FLAGS.checkpoint_dir):
20 | os.makedirs(FLAGS.checkpoint_dir)
21 | if not os.path.exists(FLAGS.sample_dir):
22 | os.makedirs(FLAGS.sample_dir)
23 |
24 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
25 |
26 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
27 | gan = GAN(sess, FLAGS)
28 |
29 | HDF5_FILE = FLAGS.data_dir
30 | f = h5py.File(HDF5_FILE, 'r')
31 | tr_data = f['images'] # data_num * height * width * channel (float32)
32 |
33 | gan.train(tr_data)
34 | f.close()
35 |
36 |
37 | if __name__ == '__main__':
38 | tf.app.run()
39 |
--------------------------------------------------------------------------------
/data/download.py:
--------------------------------------------------------------------------------
1 | """
2 | Modification of https://github.com/stanfordnlp/treelstm/blob/master/scripts/download.py
3 | Downloads the following:
4 | - Celeb-A dataset
5 | - LSUN dataset
6 | - MNIST dataset
7 | """
8 |
9 | from __future__ import print_function
10 | import os
11 | import sys
12 | import gzip
13 | import json
14 | import shutil
15 | import zipfile
16 | import argparse
17 | import subprocess
18 | from six.moves import urllib
19 |
20 | parser = argparse.ArgumentParser(description='Download dataset for DCGAN.')
21 | parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist'],
22 | help='name of dataset to download [celebA, lsun, mnist]')
23 |
24 | def download(url, dirpath):
25 | filename = url.split('/')[-1]
26 | filepath = os.path.join(dirpath, filename)
27 | u = urllib.request.urlopen(url)
28 | f = open(filepath, 'wb')
29 | filesize = int(u.headers["Content-Length"])
30 | print("Downloading: %s Bytes: %s" % (filename, filesize))
31 |
32 | downloaded = 0
33 | block_sz = 8192
34 | status_width = 70
35 | while True:
36 | buf = u.read(block_sz)
37 | if not buf:
38 | print('')
39 | break
40 | else:
41 | print('', end='\r')
42 | downloaded += len(buf)
43 | f.write(buf)
44 | status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") %
45 | ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. / filesize))
46 | print(status, end='')
47 | sys.stdout.flush()
48 | f.close()
49 | return filepath
50 |
51 | def unzip(filepath):
52 | print("Extracting: " + filepath)
53 | dirpath = os.path.dirname(filepath)
54 | with zipfile.ZipFile(filepath) as zf:
55 | zf.extractall(dirpath)
56 | os.remove(filepath)
57 |
58 | def download_celeb_a(dirpath):
59 | data_dir = 'celebA'
60 | if os.path.exists(os.path.join(dirpath, data_dir)):
61 | print('Found Celeb-A - skip')
62 | return
63 | url = 'https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=1&pv=1'
64 | filepath = download(url, dirpath)
65 | zip_dir = ''
66 | with zipfile.ZipFile(filepath) as zf:
67 | zip_dir = zf.namelist()[0]
68 | zf.extractall(dirpath)
69 | os.remove(filepath)
70 | os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))
71 |
72 | def _list_categories(tag):
73 | url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag
74 | f = urllib.request.urlopen(url)
75 | return json.loads(f.read())
76 |
77 | def _download_lsun(out_dir, category, set_name, tag):
78 | url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \
79 | '&category={category}&set={set_name}'.format(**locals())
80 | print(url)
81 | if set_name == 'test':
82 | out_name = 'test_lmdb.zip'
83 | else:
84 | out_name = '{category}_{set_name}_lmdb.zip'.format(**locals())
85 | out_path = os.path.join(out_dir, out_name)
86 | cmd = ['curl', url, '-o', out_path]
87 | print('Downloading', category, set_name, 'set')
88 | subprocess.call(cmd)
89 |
90 | def download_lsun(dirpath):
91 | data_dir = os.path.join(dirpath, 'lsun')
92 | if os.path.exists(data_dir):
93 | print('Found LSUN - skip')
94 | return
95 | else:
96 | os.mkdir(data_dir)
97 |
98 | tag = 'latest'
99 | #categories = _list_categories(tag)
100 | categories = ['bedroom']
101 |
102 | for category in categories:
103 | _download_lsun(data_dir, category, 'train', tag)
104 | _download_lsun(data_dir, category, 'val', tag)
105 | _download_lsun(data_dir, '', 'test', tag)
106 |
107 | def download_mnist(dirpath):
108 | data_dir = os.path.join(dirpath, 'mnist')
109 | if os.path.exists(data_dir):
110 | print('Found MNIST - skip')
111 | return
112 | else:
113 | os.mkdir(data_dir)
114 | url_base = 'http://yann.lecun.com/exdb/mnist/'
115 | file_names = ['train-images-idx3-ubyte.gz','train-labels-idx1-ubyte.gz','t10k-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz']
116 | for file_name in file_names:
117 | url = (url_base+file_name).format(**locals())
118 | print(url)
119 | out_path = os.path.join(data_dir,file_name)
120 | cmd = ['curl', url, '-o', out_path]
121 | print('Downloading ', file_name)
122 | subprocess.call(cmd)
123 | cmd = ['gzip', '-d', out_path]
124 | print('Decompressing ', file_name)
125 | subprocess.call(cmd)
126 |
127 | def prepare_data_dir(path = './data'):
128 | if not os.path.exists(path):
129 | os.mkdir(path)
130 |
131 | if __name__ == '__main__':
132 | args = parser.parse_args()
133 | prepare_data_dir()
134 |
135 | if 'celebA' in args.datasets:
136 | download_celeb_a('./data')
137 | if 'lsun' in args.datasets:
138 | download_lsun('./data')
139 | if 'mnist' in args.datasets:
140 | download_mnist('./data')
--------------------------------------------------------------------------------
/data/preprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import glob
4 | import h5py
5 | from PIL import Image
6 |
7 |
8 | IMG_SIZE = 64
9 |
10 |
11 | def load_Img(fname):
12 | img = Image.open(fname)
13 | img = img.resize((IMG_SIZE, IMG_SIZE), Image.ANTIALIAS)
14 | img = np.asarray(img)/255.
15 |
16 | if len(img.shape) < 3:
17 | rgb_img = np.zeros([IMG_SIZE, IMG_SIZE, 3], dtype=np.float32)
18 | rgb_img[:, :, 0] = rgb_img[:, :, 1] = rgb_img[:, :, 2] = img
19 | return rgb_img
20 |
21 | return img
22 |
23 |
24 | file_list = glob.glob('./data/celebA/*.jpg')
25 | IMG_NUM = len(file_list)
26 |
27 | rand_idx = np.arange(IMG_NUM)
28 | np.random.shuffle(rand_idx)
29 |
30 | HDF5_FILE_WRITE = 'celeba.hdf5'
31 | fw = h5py.File(HDF5_FILE_WRITE, 'w')
32 | images = fw.create_dataset('images', (IMG_NUM, IMG_SIZE, IMG_SIZE, 3), dtype='float32')
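33 | # Images are stored as float32 in [0, 1]; the models' transform() later rescales them to [-1, 1].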
33 |
34 | for i in xrange(IMG_NUM):
35 | idx = rand_idx[i]
36 | images[i] = load_Img(file_list[idx])
37 |
38 | if i % 1000 == 0:
39 | print '%.1f %% preprocessed.' % (100.*i/IMG_NUM)
40 |
41 | fw.close()
42 |
43 |
--------------------------------------------------------------------------------
/samples/readme.txt:
--------------------------------------------------------------------------------
1 | This directory is where the samples are saved.
2 |
--------------------------------------------------------------------------------