├── README.md
├── checkpoint
│   └── readme.txt
├── code
│   ├── lib
│   │   ├── __init__.py
│   │   ├── cgan_triple_alpha.py
│   │   ├── ops.py
│   │   ├── utils.py
│   │   └── vae_cgan_double_alpha.py
│   └── main_cgan_triple_alpha_celeba.py
├── data
│   ├── download.py
│   └── preprocess.py
└── samples
    └── readme.txt
/README.md:
--------------------------------------------------------------------------------
1 | # Generating images part by part with composite generative adversarial networks
2 | Tensorflow implementation of the paper "Generating images part by part with composite generative adversarial networks" [1].
3 | 
4 | The Composite GAN (CGAN) disentangles complicated factors of images with multiple generators, where each generator generates some part of the image. Those parts are combined by an alpha blending process to create a single new image. For example, it can generate background, face, and hair sequentially with three generators. There is no supervision on what each generator should generate.
5 | 
6 | *(sample images)*
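Each generator outputs an RGB image and an alpha mask in [0, 1], and the parts are blended back to front. A minimal NumPy sketch of the blending rule used in `code/lib/cgan_triple_alpha.py` (the array names and shapes here are illustrative):

```
import numpy as np

def composite(rgb1, a1, rgb2, a2, rgb3, a3):
    # rgb*: (64, 64, 3) arrays in [0, 1]; a*: (64, 64, 1) alpha masks in [0, 1]
    # each later part is alpha-blended over the composite of the earlier ones
    out = rgb1 * a1 * (1 - a2) + rgb2 * a2   # part 2 over part 1
    out = out * (1 - a3) + rgb3 * a3         # part 3 over the rest
    return out
```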
13 | 
14 | 
15 | #### Dependencies
16 | * tensorflow 0.10.0rc0+
17 | * h5py 2.3.1+
18 | * scipy 0.18.0+
19 | * Pillow 3.1.0+
20 | 
21 | #### How to use
22 | First, go into the 'data' directory and download the celebA dataset.
23 | 
24 | ```
25 | cd data
26 | python download.py celebA
27 | ```
28 | 
29 | Preprocess the celebA dataset to create an HDF5 file. This resizes the images to 64x64.
30 | 
31 | ```
32 | python preprocess.py
33 | ```
34 | 
35 | Finally, go into the 'code' directory and run 'main_cgan_triple_alpha_celeba.py'.
36 | 
37 | ```
38 | cd ../code
39 | python main_cgan_triple_alpha_celeba.py
40 | ```
41 | 
42 | You can see the samples in the 'samples' directory.
43 | 
44 | ### References
45 | 
46 | [1] Hanock Kwak and Byoung-Tak Zhang. "Generating Images Part by Part with Composite Generative Adversarial Networks." arXiv preprint arXiv:1607.05387 (2016).
47 | 
--------------------------------------------------------------------------------
/checkpoint/readme.txt:
--------------------------------------------------------------------------------
1 | This folder is where the models are saved!
2 | 
--------------------------------------------------------------------------------
/code/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hanock/generating_images_part_by_part/f934aa3f6d18bf635adc2311522fd9d7504f13e8/code/lib/__init__.py
--------------------------------------------------------------------------------
/code/lib/cgan_triple_alpha.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from glob import glob
4 | import tensorflow as tf
5 | 
6 | from ops import *
7 | from utils import *
8 | 
9 | class GAN(object):
10 |     def __init__(self, sess, config):
11 | 
12 |         self.sess = sess
13 |         self.batch_size = config.batch_size
14 |         self.train_size = config.train_size
15 |         self.image_size = 64
16 |         self.image_shape = [64, 64, 3]
17 |         # checkerboard background used to visualize the alpha masks
18 |         self.checked_img = self.create_checked_img(self.image_size)
19 | 
20 |         self.z_dim = 128
21 | 
22 |         self.checkpoint_dir = config.checkpoint_dir
23 |         self.sample_dir = config.sample_dir
24 |         self.build_model()
25 | 
26 |     def build_model(self):
27 | 
28 |         self.images = tf.placeholder(tf.float32, [self.batch_size] + self.image_shape, name='real_images')
29 |         self.sample_images = tf.placeholder(tf.float32, [self.batch_size] + self.image_shape, name='sample_images')
30 |         self.z1 = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z1')
31 |         self.z2 = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z2')
32 |         self.z3 = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z3')
33 | 
34 |         # counter
35 |         self.global_step = tf.Variable(0, name='global_step', trainable=False)
36 | 
37 |         # per-part previews (O1-O3), composited image (G), and alpha masks (A1-A3)
38 |         self.O1, self.O2, self.O3, self.G, self.A1, self.A2, self.A3 = self.generator(self.z1, self.z2, self.z3, feature=True)
39 |         self.D = self.discriminator(self.images)
40 | 
41 |         self.D_ = self.discriminator(self.G, reuse=True)
42 | 
43 |         # alpha loss: the *_1 terms keep each mask's area near 30% of the image, the *_2 terms push alpha values toward 0 or 1
44 |         self.alpha_loss1_1 = tf.reduce_sum(tf.abs(tf.reduce_sum(self.A1, [1, 2, 3]) - self.image_size*self.image_size*0.3))
45 |         self.alpha_loss1_2 = tf.reduce_sum(tf.reduce_sum(-tf.square(self.A1 - 0.5) + 0.25, [1, 2, 3]))
46 |         self.alpha_loss2_1 = tf.reduce_sum(tf.abs(tf.reduce_sum(self.A2, [1, 2, 3]) - self.image_size*self.image_size*0.3))
47 |         self.alpha_loss2_2 = tf.reduce_sum(tf.reduce_sum(-tf.square(self.A2 - 0.5) + 0.25, [1, 2, 3]))
48 |         self.alpha_loss3_1 = tf.reduce_sum(tf.abs(tf.reduce_sum(self.A3, [1, 2, 3]) - self.image_size*self.image_size*0.3))
49 |         self.alpha_loss3_2 = tf.reduce_sum(tf.reduce_sum(-tf.square(self.A3 - 0.5) + 0.25, [1, 2, 3]))
50 |         self.alpha_loss = self.alpha_loss1_1 + self.alpha_loss1_2 + self.alpha_loss2_1 + self.alpha_loss2_2 + self.alpha_loss3_1 + self.alpha_loss3_2
51 | 
52 |         # d loss
53 |         self.d_loss_real = binary_cross_entropy_with_logits(tf.ones_like(self.D), self.D)
54 |         self.d_loss_fake = binary_cross_entropy_with_logits(tf.zeros_like(self.D_), self.D_)
55 |         self.d_loss = self.d_loss_real + self.d_loss_fake
56 | 
57 |         # g loss
58 |         self.g_gan_loss = binary_cross_entropy_with_logits(tf.ones_like(self.D_), self.D_)
59 |         self.g_loss = self.g_gan_loss + 0.000005*self.alpha_loss
60 | 
61 | 
62 |         t_vars = tf.trainable_variables()
63 | 
64 |         self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
65 |         self.g_vars = [var for var in t_vars if 'generator' in var.name]
66 | 
67 |         self.saver = tf.train.Saver()
68 | 
69 |     def train(self, data):
70 |         """Train CGAN"""
71 | 
72 |         data_size = len(data)
73 | 
74 |         d_optim = tf.train.AdamOptimizer(0.0002, beta1=0.5).minimize(self.d_loss, var_list=self.d_vars)
75 |         g_optim = tf.train.AdamOptimizer(0.0004, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars)
76 |         tf.initialize_all_variables().run()
77 | 
78 |         self.saver = tf.train.Saver()
79 | 
80 |         sample_images = self.transform(data[0:self.batch_size]).astype(np.float32)
81 |         save_images(sample_images[0:100], [10, 10], '%s/sample_images.png' % (self.sample_dir))
82 | 
83 | 
84 |         start_time = time.time()
85 | 
86 |         if self.load(self.checkpoint_dir):
87 |             print(" [*] Load SUCCESS")
88 |             counter = self.global_step.eval()
89 |         else:
90 |             counter = 0  # no checkpoint found; start training from scratch
91 |         for inf_loop in xrange(1000000):
92 | 
93 |             errD = 0
94 |             errG = 0
95 | 
96 |             # random mini-batch
97 |             i = np.random.randint(0, self.train_size - self.batch_size)
98 |             batch_images = self.transform(data[i:i + self.batch_size]).astype(np.float32)
99 | 
100 |             # Update D network
101 |             batch_z1 = self.random_z()
102 |             batch_z2 = self.random_z()
103 |             batch_z3 = self.random_z()
104 |             self.sess.run(d_optim, feed_dict={ self.images: batch_images, self.z1: batch_z1, self.z2: batch_z2, self.z3: batch_z3 })
105 | 
106 |             # Update G network
107 |             self.sess.run(g_optim, feed_dict={ self.z1: batch_z1, self.z2: batch_z2, self.z3: batch_z3 })
108 | 
109 |             counter += 1
110 | 
111 |             if np.mod(counter, 10) == 1:
112 |                 # random mini-batch
113 |                 i = np.random.randint(0, self.train_size - self.batch_size)
114 |                 batch_images = self.transform(data[i:i + self.batch_size]).astype(np.float32)
115 |                 batch_z1 = self.random_z()
116 |                 batch_z2 = self.random_z()
117 |                 batch_z3 = self.random_z()
118 | 
119 |                 d_loss_fake = self.d_loss_fake.eval({self.z1: batch_z1, self.z2: batch_z2, self.z3: batch_z3 })
120 |                 d_loss_real = self.d_loss_real.eval({self.images: batch_images})
121 |                 errD = d_loss_fake + d_loss_real
122 |                 errG = self.g_loss.eval({self.z1: batch_z1, self.z2: batch_z2, self.z3: batch_z3})
123 | 
124 |                 print("[%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
125 |                     % (counter, time.time() - start_time, errD, errG))
126 | 
127 |             if np.mod(counter, 200) == 1:
128 |                 # dump each generator's preview (over the checkerboard) and the composited samples
129 |                 samp_o1, samp_o2, samp_o3, samp = self.sess.run(
130 |                     [self.O1, self.O2, self.O3, self.G],
131 |                     feed_dict={self.z1: self.random_z(), self.z2: self.random_z(), self.z3: self.random_z()})
132 | 
133 | 
134 |                 samples = np.concatenate((samp_o1, samp_o2, samp_o3, samp), axis=0)
135 | 
136 |                 save_images(samples, [32, 16], '%s/train_%05d.png' % (self.sample_dir, counter))
137 | 
138 |             if np.mod(counter, 500) == 1:
139 |                 self.save(self.checkpoint_dir, counter)
140 | 
141 |     def discriminator(self, image, reuse=False):
142 | 
143 |         with tf.variable_scope('discriminator', reuse=reuse):
144 |             h0 = lrelu(conv2d(image, 64, name='d_h0_conv'))
145 | 
146 |             d_bn1 = batch_norm(self.batch_size, name='d_bn1')
147 |             h1 = lrelu(d_bn1(conv2d(h0, 64*2, name='d_h1_conv')))
148 | 
149 |             d_bn2 = batch_norm(self.batch_size, name='d_bn2')
150 |             h2 = lrelu(d_bn2(conv2d(h1, 64*4, name='d_h2_conv')))
151 | 
152 |             d_bn3 = batch_norm(self.batch_size, name='d_bn3')
153 |             h3 = lrelu(d_bn3(conv2d(h2, 64*8, name='d_h3_conv')))
154 | 
155 |             h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h3_lin')
156 | 
157 |             return tf.nn.sigmoid(h4)
158 | 
159 |     def generator(self, z1, z2, z3, reuse=False, feature=False):
160 | 
161 |         def gen(h, gen_name='gen', reuse=False):
162 |             with tf.variable_scope(gen_name, reuse=reuse):
163 |                 h0 = linear(h, 512*4*4, 'g_h0_lin')
164 |                 h0 = tf.reshape(h0, [-1, 4, 4, 512])
165 |                 h0 = tf.nn.relu(h0)
166 | 
167 |                 h1 = deconv2d(h0, [self.batch_size, 8, 8, 256], name='g_h1')
168 |                 g_bn1 = batch_norm(self.batch_size, name='g_bn1')
169 |                 h1 = tf.nn.relu(g_bn1(h1))
170 | 
171 |                 h2 = deconv2d(h1, [self.batch_size, 16, 16, 128], name='g_h2')
172 |                 g_bn2 = batch_norm(self.batch_size, name='g_bn2')
173 |                 h2 = tf.nn.relu(g_bn2(h2))
174 | 
175 |                 h3 = deconv2d(h2, [self.batch_size, 32, 32, 64], name='g_h3')
176 |                 g_bn3 = batch_norm(self.batch_size, name='g_bn3')
177 |                 h3 = tf.nn.relu(g_bn3(h3))
178 | 
179 |                 h4 = deconv2d(h3, [self.batch_size, 64, 64, 3], name='g_h4')
180 |                 alpha = deconv2d(h3, [self.batch_size, 64, 64, 1], name='g_a')
181 | 
182 |                 return tf.sigmoid(h4), tf.sigmoid(alpha)
183 | 
184 |         with tf.variable_scope('generator', reuse=reuse):
185 |             lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(128, forget_bias=0.0)
186 |             state = lstm_cell.zero_state(self.batch_size, tf.float32)
187 |             # the LSTM chains the three latent codes so later parts can depend on earlier ones
188 |             with tf.variable_scope('g_rnn'):
189 |                 (cell_output1, state) = lstm_cell(z1, state)
190 |                 tf.get_variable_scope().reuse_variables()
191 |                 (cell_output2, state) = lstm_cell(z2, state)
192 |                 (cell_output3, state) = lstm_cell(z3, state)
193 | 
194 |             rgb1, alpha1 = gen(cell_output1, 'gen1')
195 |             rgb2, alpha2 = gen(cell_output2, 'gen2')
196 |             rgb3, alpha3 = gen(cell_output3, 'gen3')
197 |             # normalize the masks by their total coverage so each part can be previewed on its own
198 |             a_norm = (alpha1*(1 - alpha2) + alpha2)*(1 - alpha3) + alpha3
199 |             a1 = alpha1/a_norm
200 |             a2 = alpha2/a_norm
201 |             a3 = alpha3/a_norm
202 |             # o1-o3 show each part over the checkerboard; o is the final back-to-front composite
203 |             o1 = self.checked_img*(1 - a1) + rgb1*a1
204 |             o2 = self.checked_img*(1 - a2) + rgb2*a2
205 |             o3 = self.checked_img*(1 - a3) + rgb3*a3
206 |             o = (rgb1*alpha1*(1 - alpha2) + rgb2*alpha2)*(1 - alpha3) + rgb3*alpha3
207 | 
208 |             if feature:
209 |                 return o1*2 - 1, o2*2 - 1, o3*2 - 1, o*2 - 1, alpha1, alpha2, alpha3
210 | 
211 |             return o*2 - 1
212 | 
213 |     def random_z(self):
214 |         return np.random.normal(size=(self.batch_size, self.z_dim))
215 | 
216 |     def random_fix_z(self):
217 |         r = np.zeros([self.batch_size, self.z_dim])
218 |         r[0:, :] = np.random.normal(size=(1, self.z_dim))  # broadcast a single z across the batch
219 |         return r
220 | 
221 |     def save(self, checkpoint_dir, step):
222 |         self.sess.run(self.global_step.assign(step))
223 | 
224 |         model_name = "GAN.model"
225 | 
226 |         if not os.path.exists(checkpoint_dir):
227 |             os.makedirs(checkpoint_dir)
228 | 
229 |         self.saver.save(self.sess,
230 |                         os.path.join(checkpoint_dir, model_name),
231 |                         global_step=step)
232 | 
233 |     def load(self, checkpoint_dir):
234 |         print(" [*] Reading checkpoints...")
235 | 
236 |         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
237 |         if ckpt and ckpt.model_checkpoint_path:
238 |             ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
239 |             self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
240 |             return True
241 |         else:
242 |             return False
243 | 
244 |     def transform(self, X):
245 |         return X*2 - 1
246 | 
247 |     def inverse_transform(self, X):
248 |         X = (X+1.)/2.
249 |         return X
250 | 
251 |     def create_checked_img(self, size):
252 |         arr = np.arange(size)
253 |         chk1 = arr[np.mod(arr, size/8) < size/16]
254 |         chk2 = arr[np.mod(arr, size/8) >= size/16]
255 |         a = np.meshgrid(chk1, chk1)
256 |         b = np.meshgrid(chk2, chk2)
257 |         # start from white and black out alternating blocks to leave a checkerboard pattern
258 |         img = np.ones([size, size, 3], dtype=np.float32)
259 | 
260 |         img[a[0], a[1]] = 0
261 |         img[b[0], b[1]] = 0
262 | 
263 |         return img
--------------------------------------------------------------------------------
/code/lib/ops.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import tensorflow as tf
4 | 
5 | from tensorflow.python.framework import ops
6 | from tensorflow.python import control_flow_ops
7 | 
8 | from utils import *
9 | 
10 | class batch_norm(object):
11 |     """Code modification of http://stackoverflow.com/a/33950177"""
12 |     def __init__(self, batch_size, epsilon=1e-5, momentum=0.1, name="batch_norm"):
13 |         with tf.variable_scope(name) as scope:
14 |             self.epsilon = epsilon
15 |             self.momentum = momentum
16 |             self.batch_size = batch_size
17 | 
18 |             self.ema = tf.train.ExponentialMovingAverage(decay=self.momentum)
19 |             self.name = name
20 | 
21 |     def __call__(self, x, train=True):
22 |         shape = x.get_shape().as_list()
23 | 
24 |         with tf.variable_scope(self.name) as scope:
25 |             self.gamma = tf.get_variable("gamma", [shape[-1]],
26 |                 initializer=tf.random_normal_initializer(1., 0.02))
27 |             self.beta = tf.get_variable("beta", [shape[-1]],
28 |                 initializer=tf.constant_initializer(0.))
29 |             # note: the moving average created in __init__ is never applied; batch statistics are used even when train is False
30 |             self.mean, self.variance = tf.nn.moments(x, [0, 1, 2])
31 | 
32 |             return tf.nn.batch_norm_with_global_normalization(
33 |                 x, self.mean, self.variance, self.beta, self.gamma, self.epsilon,
34 |                 scale_after_normalization=True)
35 | 
36 | 
37 | def batch_norm2(x, n_out, phase_train, scope='bn', affine=True):
38 |     """
39 |     http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow/33950177
40 |     Batch normalization on convolutional maps.
41 |     Args:
42 |         x: Tensor, 4D BHWD input maps
43 |         n_out: integer, depth (channel) of input maps
44 |         phase_train: boolean tf.Variable, true indicates training phase
45 |         scope: string, variable scope
46 |         affine: whether to affine-transform outputs
47 |     Return:
48 |         normed: batch-normalized maps
49 |     """
50 |     with tf.variable_scope(scope):
51 |         beta = tf.Variable(tf.constant(0.0, shape=[n_out]),
52 |             name='beta', trainable=True)
53 |         gamma = tf.Variable(tf.constant(1.0, shape=[n_out]),
54 |             name='gamma', trainable=affine)
55 | 
56 |         batch_mean, batch_var = tf.nn.moments(x, [0,1,2], name='moments')
57 |         ema = tf.train.ExponentialMovingAverage(decay=0.9)
58 |         ema_apply_op = ema.apply([batch_mean, batch_var])
59 |         ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
60 |         def mean_var_with_update():
61 |             with tf.control_dependencies([ema_apply_op]):
62 |                 return tf.identity(batch_mean), tf.identity(batch_var)
63 |         mean, var = control_flow_ops.cond(phase_train,
64 |             mean_var_with_update,
65 |             lambda: (ema_mean, ema_var))
66 | 
67 |         normed = tf.nn.batch_norm_with_global_normalization(x, mean, var,
68 |             beta, gamma, 1e-3, affine)
69 |     return normed
70 | 
71 | 
72 | def binary_cross_entropy_with_logits(logits, targets, name=None):
73 |     """Computes binary cross entropy. Despite the name, `targets` is expected to hold probabilities (e.g. sigmoid outputs) and `logits` the 0/1 labels.
74 |     For brevity, let `x = logits`, `z = targets`. The loss is
75 |         loss(x, z) = - mean_i (x[i] * log(z[i]) + (1 - x[i]) * log(1 - z[i]))
76 |     Args:
77 |         logits: A `Tensor` of type `float32` or `float64`.
78 |         targets: A `Tensor` of the same type and shape as `logits`.
79 |     """
80 |     eps = 1e-12
81 |     with ops.op_scope([logits, targets], name, "bce_loss") as name:
82 |         logits = ops.convert_to_tensor(logits, name="logits")
83 |         targets = ops.convert_to_tensor(targets, name="targets")
84 |         return tf.reduce_mean(-(logits * tf.log(targets + eps) +
85 |                                 (1. - logits) * tf.log(1. - targets + eps)))
86 | 
87 | # zero-based
88 | def dense_to_one_hot(labels_dense, num_classes=10):
89 |     """Convert class labels from scalars to one-hot vectors."""
90 |     num_labels = labels_dense.shape[0]
91 |     index_offset = np.arange(num_labels) * num_classes
92 |     labels_one_hot = np.zeros((num_labels, num_classes))
93 |     labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
94 |     return labels_one_hot
95 | 
96 | def conv_cond_concat(x, y):
97 |     """Concatenate conditioning vector on feature map axis."""
98 |     x_shapes = x.get_shape()
99 |     y_shapes = y.get_shape()
100 |     return tf.concat(3, [x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])])
101 | 
102 | def conv2d(input_, output_dim,
103 |            k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
104 |            name="conv2d"):
105 |     with tf.variable_scope(name):
106 |         w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
107 |             initializer=tf.truncated_normal_initializer(stddev=stddev))
108 |         conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
109 | 
110 |         biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
111 |         conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
112 | 
113 |         return conv
114 | 
115 | def conv2d_v2(input_, output_dim,
116 |               k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
117 |               name="conv2d"):
118 |     with tf.variable_scope(name):
119 | 
120 |         w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
121 |             initializer=tf.truncated_normal_initializer(stddev=stddev))
122 |         conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
123 |         biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
124 | 
125 |         return tf.nn.bias_add(conv, biases)
126 | 
127 | def deconv2d(input_, output_shape,
128 |              k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
129 |              name="deconv2d", with_w=False):
130 |     with tf.variable_scope(name):
131 |         # filter : [height, width, output_channels, in_channels]
132 |         w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
133 |             initializer=tf.random_normal_initializer(stddev=stddev))
134 | 
135 |         try:
136 |             deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
137 |                 strides=[1, d_h, d_w, 1])
138 | 
139 |         # Support for versions of TensorFlow before 0.7.0
140 |         except AttributeError:
141 |             deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
142 |                 strides=[1, d_h, d_w, 1])
143 | 
144 |         biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
145 |         deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
146 | 
147 |         if with_w:
148 |             return deconv, w, biases
149 |         else:
150 |             return deconv
151 | 
152 | def lrelu(x, leak=0.2, name="lrelu"):
153 |     with tf.variable_scope(name):
154 |         f1 = 0.5 * (1 + leak)
155 |         f2 = 0.5 * (1 - leak)
156 |         return f1 * x + f2 * abs(x)  # equivalent to max(x, leak*x)
157 | 
158 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
159 |     shape = input_.get_shape().as_list()
160 | 
161 |     with tf.variable_scope(scope or "Linear"):
162 |         matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
163 |             tf.random_normal_initializer(stddev=stddev))
164 |         bias = tf.get_variable("bias", [output_size],
165 |             initializer=tf.constant_initializer(bias_start))
166 |         if with_w:
167 |             return tf.matmul(input_, matrix) + bias, matrix, bias
168 |         else:
169 |             return tf.matmul(input_, matrix) + bias
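
# A minimal sketch of how the GAN classes in this repo call
# binary_cross_entropy_with_logits: despite its name, the second argument is a
# probability (a sigmoid output), not a logit, and the first argument supplies
# the 0/1 labels. The tensor names below are illustrative.
#
#     D  = discriminator(images)             # sigmoid output in (0, 1)
#     D_ = discriminator(G, reuse=True)
#     d_loss_real = binary_cross_entropy_with_logits(tf.ones_like(D), D)
#     d_loss_fake = binary_cross_entropy_with_logits(tf.zeros_like(D_), D_)
#     g_gan_loss  = binary_cross_entropy_with_logits(tf.ones_like(D_), D_)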
-------------------------------------------------------------------------------- /code/lib/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some codes from https://github.com/Newmu/dcgan_code 3 | """ 4 | import math 5 | import json 6 | import random 7 | import pprint 8 | import scipy.misc 9 | import numpy as np 10 | from cStringIO import StringIO 11 | from time import gmtime, strftime 12 | 13 | pp = pprint.PrettyPrinter() 14 | 15 | get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1]) 16 | 17 | 18 | 19 | def save_sentence_pairs(X, Y, dic, fname, CODE_EOF=1): 20 | data_num = X.shape[0] 21 | max_time = X.shape[1] 22 | 23 | file_str = StringIO() 24 | for i in xrange(data_num): 25 | write_sentence(file_str, dic, X[i, :], CODE_EOF) 26 | write_sentence(file_str, dic, Y[i, :], CODE_EOF) 27 | file_str.write('\n') 28 | 29 | with open(fname, 'w') as fp: 30 | fp.write(file_str.getvalue()) 31 | 32 | def save_sentences(X, dic, fname, CODE_EOF=1): 33 | data_num = X.shape[0] 34 | max_time = X.shape[1] 35 | 36 | file_str = StringIO() 37 | for i in xrange(data_num): 38 | write_sentence(file_str, dic, X[i, :], CODE_EOF) 39 | file_str.write('\n') 40 | 41 | with open(fname, 'w') as fp: 42 | fp.write(file_str.getvalue()) 43 | 44 | def write_sentence(fp, dic, sentence, CODE_EOF=1): 45 | max_time = sentence.shape[0] 46 | 47 | for i in xrange(max_time): 48 | v = sentence[i] 49 | if v == 0: 50 | break 51 | if v == CODE_EOF: 52 | fp.write('EOF') 53 | break 54 | 55 | try: 56 | w = dic.keys()[dic.values().index(v)] 57 | except ValueError: 58 | w = 'NULL' 59 | fp.write(w + ' ') 60 | fp.write('\n') 61 | 62 | def get_image(image_path, image_size, is_crop=True): 63 | return transform(imread(image_path), image_size, is_crop) 64 | 65 | def save_images(images, size, image_path): 66 | return imsave(inverse_transform(images), size, image_path) 67 | 68 | def imread(path): 69 | return scipy.misc.imread(path).astype(np.float) 70 | 71 | def merge_images(images, size): 72 | return inverse_transform(images) 73 | 74 | def merge(images, size): 75 | h, w = images.shape[1], images.shape[2] 76 | img = np.zeros((h * size[0], w * size[1], 3)) 77 | 78 | for idx, image in enumerate(images): 79 | i = idx % size[1] 80 | j = idx / size[1] 81 | img[j*h:j*h+h, i*w:i*w+w, :] = image 82 | 83 | return img 84 | 85 | def imsave(images, size, path): 86 | return scipy.misc.imsave(path, merge(images, size)) 87 | 88 | def center_crop(x, crop_h, crop_w=None, resize_w=64): 89 | if crop_w is None: 90 | crop_w = crop_h 91 | h, w = x.shape[:2] 92 | j = int(round((h - crop_h)/2.)) 93 | i = int(round((w - crop_w)/2.)) 94 | return scipy.misc.imresize(x[j:j+crop_h, i:i+crop_w], 95 | [resize_w, resize_w]) 96 | 97 | def transform(image, npx=64, is_crop=True): 98 | # npx : # of pixels width/height of image 99 | if is_crop: 100 | cropped_image = center_crop(image, npx) 101 | else: 102 | cropped_image = image 103 | return np.array(cropped_image)/127.5 - 1. 104 | 105 | def inverse_transform(images): 106 | return (images+1.)/2. 
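
# A small usage sketch of the grid-saving convention (illustrative values):
# save_images expects a batch of N images in [-1, 1] and a [rows, cols] grid
# with rows*cols == N; inverse_transform maps the batch back to [0, 1] first.
#
#     batch = np.random.uniform(-1, 1, size=(512, 64, 64, 3))
#     save_images(batch, [32, 16], 'grid.png')   # 32 rows x 16 columns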
107 | 108 | 109 | def to_json(output_path, *layers): 110 | with open(output_path, "w") as layer_f: 111 | lines = "" 112 | for w, b, bn in layers: 113 | layer_idx = w.name.split('/')[0].split('h')[1] 114 | 115 | B = b.eval() 116 | 117 | if "lin/" in w.name: 118 | W = w.eval() 119 | depth = W.shape[1] 120 | else: 121 | W = np.rollaxis(w.eval(), 2, 0) 122 | depth = W.shape[0] 123 | 124 | biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]} 125 | if bn != None: 126 | gamma = bn.gamma.eval() 127 | beta = bn.beta.eval() 128 | 129 | gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]} 130 | beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]} 131 | else: 132 | gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []} 133 | beta = {"sy": 1, "sx": 1, "depth": 0, "w": []} 134 | 135 | if "lin/" in w.name: 136 | fs = [] 137 | for w in W.T: 138 | fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]}) 139 | 140 | lines += """ 141 | var layer_%s = { 142 | "layer_type": "fc", 143 | "sy": 1, "sx": 1, 144 | "out_sx": 1, "out_sy": 1, 145 | "stride": 1, "pad": 0, 146 | "out_depth": %s, "in_depth": %s, 147 | "biases": %s, 148 | "gamma": %s, 149 | "beta": %s, 150 | "filters": %s 151 | };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs) 152 | else: 153 | fs = [] 154 | for w_ in W: 155 | fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]}) 156 | 157 | lines += """ 158 | var layer_%s = { 159 | "layer_type": "deconv", 160 | "sy": 5, "sx": 5, 161 | "out_sx": %s, "out_sy": %s, 162 | "stride": 2, "pad": 1, 163 | "out_depth": %s, "in_depth": %s, 164 | "biases": %s, 165 | "gamma": %s, 166 | "beta": %s, 167 | "filters": %s 168 | };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2), 169 | W.shape[0], W.shape[3], biases, gamma, beta, fs) 170 | layer_f.write(" ".join(lines.replace("'","").split())) 171 | 172 | def make_gif(images, fname, duration=2, true_image=False): 173 | import moviepy.editor as mpy 174 | 175 | def make_frame(t): 176 | try: 177 | x = images[int(len(images)/duration*t)] 178 | except: 179 | x = images[-1] 180 | 181 | if true_image: 182 | return x.astype(np.uint8) 183 | else: 184 | return ((x+1)/2*255).astype(np.uint8) 185 | 186 | clip = mpy.VideoClip(make_frame, duration=duration) 187 | clip.write_gif(fname, fps = len(images) / duration) 188 | 189 | def visualize(sess, dcgan, config, option): 190 | if option == 0: 191 | z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim)) 192 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 193 | save_images(samples[0:121], [11, 11], './samples/test_%s.png' % strftime("%Y-%m-%d %H:%M:%S", gmtime())) 194 | elif option == 1: 195 | values = np.arange(0, 1, 1./config.batch_size) 196 | for idx in xrange(100): 197 | print(" [*] %d" % idx) 198 | z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 199 | for kdx, z in enumerate(z_sample): 200 | z[idx] = values[kdx] 201 | 202 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 203 | save_images(samples, [8, 8], './samples/test_arange_%s.png' % (idx)) 204 | 205 | elif option == 2: 206 | values = np.arange(0, 1, 1./config.batch_size) 207 | for idx in [random.randint(0, 99) for _ in xrange(100)]: 208 | print(" [*] %d" % idx) 209 | z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim)) 210 | z_sample = np.tile(z, (config.batch_size, 1)) 211 | #z_sample = 
np.zeros([config.batch_size, dcgan.z_dim]) 212 | for kdx, z in enumerate(z_sample): 213 | z[idx] = values[kdx] 214 | 215 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 216 | make_gif(samples, './samples/test_gif_%s.gif' % (idx)) 217 | 218 | elif option == 3: 219 | values = np.arange(0, 1, 1./config.batch_size) 220 | for idx in xrange(100): 221 | print(" [*] %d" % idx) 222 | z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 223 | for kdx, z in enumerate(z_sample): 224 | z[idx] = values[kdx] 225 | 226 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 227 | make_gif(samples, './samples/test_gif_%s.gif' % (idx)) 228 | 229 | elif option == 4: 230 | image_set = [] 231 | values = np.arange(0, 1, 1./config.batch_size) 232 | 233 | for idx in xrange(100): 234 | print(" [*] %d" % idx) 235 | z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 236 | for kdx, z in enumerate(z_sample): 237 | z[idx] = values[kdx] 238 | 239 | image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})) 240 | make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx)) 241 | 242 | new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) for idx in range(64) + range(63, -1, -1)] 243 | make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8) -------------------------------------------------------------------------------- /code/lib/vae_cgan_double_alpha.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from glob import glob 4 | import tensorflow as tf 5 | 6 | from ops import * 7 | from utils import * 8 | 9 | class GAN(object): 10 | def __init__(self, sess, config): 11 | 12 | self.sess = sess 13 | self.batch_size = config.batch_size 14 | self.train_size = config.train_size 15 | self.image_size = 64 16 | self.image_shape = [64, 64, 3] 17 | 18 | self.checked_img = self.create_checked_img(self.image_size) 19 | 20 | self.z_dim = 128 21 | 22 | self.checkpoint_dir = config.checkpoint_dir 23 | self.sample_dir = config.sample_dir 24 | self.build_model() 25 | 26 | def build_model(self): 27 | 28 | self.images = tf.placeholder(tf.float32, [self.batch_size] + self.image_shape, name='real_images') 29 | self.sample_images= tf.placeholder(tf.float32, [self.batch_size] + self.image_shape, name='sample_images') 30 | self.z1 = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z1') 31 | self.z2 = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z2') 32 | 33 | # counter 34 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 35 | 36 | # 37 | self.mean1, self.stddev1, self.mean2, self.stddev2 = self.encoder(self.images) 38 | self.vae_output, self.O1, self.O2, self.A1, self.A2 = self.generator(self.mean1, self.stddev1, self.mean2, self.stddev2) 39 | self.D = self.discriminator(self.images) 40 | 41 | self.G, _, _, _, _ = self.generator(self.z1, tf.zeros([self.batch_size, self.z_dim]), self.z2, tf.zeros([self.batch_size, self.z_dim]), reuse=True) 42 | self.D_ = self.discriminator(self.G, reuse=True) 43 | 44 | self.feature_out = self.discriminator(self.vae_output, reuse=True, feature=True) 45 | self.feature_img = self.discriminator(self.images, reuse=True, feature=True) 46 | 47 | 48 | # alpha loss 49 | self.alpha_loss1_1 = tf.reduce_sum(tf.abs(tf.reduce_sum(self.A1, [1, 2, 3]) - self.image_size*self.image_size*0.3)) 50 | self.alpha_loss1_2 = tf.reduce_sum(tf.reduce_sum(-tf.square(self.A1 - 0.5) + 0.25, [1, 2, 3])) 51 | self.alpha_loss2_1 = 
tf.reduce_sum(tf.abs(tf.reduce_sum(self.A2, [1, 2, 3]) - self.image_size*self.image_size*0.3))
52 |         self.alpha_loss2_2 = tf.reduce_sum(tf.reduce_sum(-tf.square(self.A2 - 0.5) + 0.25, [1, 2, 3]))
53 |         self.alpha_loss = self.alpha_loss1_1 + self.alpha_loss1_2 + self.alpha_loss2_1 + self.alpha_loss2_2
54 | 
55 |         # prior loss: KL divergence between N(mean, stddev^2) and the standard normal prior
56 |         self.prior_loss1 = tf.reduce_sum(0.5 * (tf.square(self.mean1) + tf.square(self.stddev1) - 2.0 * tf.log(self.stddev1 + 1e-8) - 1.0))
57 |         self.prior_loss2 = tf.reduce_sum(0.5 * (tf.square(self.mean2) + tf.square(self.stddev2) - 2.0 * tf.log(self.stddev2 + 1e-8) - 1.0))
58 |         self.prior_loss = self.prior_loss1 + self.prior_loss2
59 | 
60 |         # recon loss: pixel-space and discriminator-feature reconstruction errors
61 |         self.recon_loss_raw = tf.reduce_sum(tf.square(self.vae_output - self.images))
62 |         self.recon_loss_adv = tf.reduce_sum(tf.square(self.feature_out - self.feature_img))
63 | 
64 |         # e loss
65 |         self.e_loss = self.prior_loss + self.recon_loss_adv + 0.2*self.recon_loss_raw + 0.2*self.alpha_loss
66 | 
67 |         # g loss
68 |         self.g_gan_loss = binary_cross_entropy_with_logits(tf.ones_like(self.D_), self.D_)
69 |         self.g_loss = self.g_gan_loss + 0.0005*self.recon_loss_adv + 0.0001*self.recon_loss_raw + 0.0001*self.alpha_loss
70 | 
71 |         # d loss
72 |         self.d_loss_real = binary_cross_entropy_with_logits(tf.ones_like(self.D), self.D)
73 |         self.d_loss_fake = binary_cross_entropy_with_logits(tf.zeros_like(self.D_), self.D_)
74 |         self.d_loss = self.d_loss_real + self.d_loss_fake
75 | 
76 | 
77 |         t_vars = tf.trainable_variables()
78 | 
79 |         self.e_vars = [var for var in t_vars if 'encoder' in var.name]
80 |         self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
81 |         self.g_vars = [var for var in t_vars if 'generator' in var.name]
82 | 
83 |         self.saver = tf.train.Saver()
84 | 
85 |     def train(self, data):
86 |         """Train VAE-CGAN"""
87 | 
88 |         data_size = len(data)
89 | 
90 |         e_optim = tf.train.AdamOptimizer(0.0002, beta1=0.5).minimize(self.e_loss, var_list=self.e_vars)
91 |         d_optim = tf.train.AdamOptimizer(0.0002, beta1=0.5).minimize(self.d_loss, var_list=self.d_vars)
92 |         g_optim = tf.train.AdamOptimizer(0.0004, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars)
93 |         tf.initialize_all_variables().run()
94 | 
95 |         self.saver = tf.train.Saver()
96 | 
97 |         sample_images = self.transform(data[0:self.batch_size]).astype(np.float32)
98 |         save_images(sample_images, [8, 16], '%s/sample_images.png' % (self.sample_dir))
99 | 
100 | 
101 |         start_time = time.time()
102 | 
103 |         if self.load(self.checkpoint_dir):
104 |             print(" [*] Load SUCCESS")
105 |             counter = self.global_step.eval()
106 |         else:
107 |             counter = 0  # no checkpoint found; start training from scratch
108 |         for inf_loop in xrange(1000000):
109 | 
110 |             errD = 0
111 |             errG = 0
112 | 
113 |             # random mini-batch
114 |             i = np.random.randint(0, self.train_size - self.batch_size)
115 |             batch_images = self.transform(data[i:i + self.batch_size]).astype(np.float32)
116 | 
117 |             # Update E network
118 |             self.sess.run([e_optim], feed_dict={ self.images: batch_images })
119 | 
120 |             # Update D network
121 |             batch_z1 = self.random_z()
122 |             batch_z2 = self.random_z()
123 |             self.sess.run(d_optim, feed_dict={ self.images: batch_images, self.z1: batch_z1, self.z2: batch_z2 })
124 | 
125 |             # Update G network
126 |             self.sess.run(g_optim, feed_dict={ self.images: batch_images, self.z1: batch_z1, self.z2: batch_z2 })
127 | 
128 |             counter += 1
129 | 
130 |             if np.mod(counter, 10) == 1:
131 |                 # random mini-batch
132 |                 i = np.random.randint(0, self.train_size - self.batch_size)
133 |                 batch_images = self.transform(data[i:i + self.batch_size]).astype(np.float32)
134 |                 batch_z1 = 
self.random_z() 135 | batch_z2 = self.random_z() 136 | 137 | recon_loss_adv = self.recon_loss_adv.eval({self.images: batch_images}) 138 | recon_loss_raw = self.recon_loss_raw.eval({self.images: batch_images}) 139 | alpha_loss = self.alpha_loss.eval({self.images: batch_images}) 140 | g_gan_loss = self.g_gan_loss.eval({self.z1: batch_z1, self.z2: batch_z2}) 141 | d_loss = self.d_loss.eval({self.z1: batch_z1, self.z2: batch_z2, self.images: batch_images}) 142 | 143 | print("[%5d] time: %4.4f, adv: %.1f, raw: %.1f, alpha: %.1f, d: %.8f, g_gan: %.8f" \ 144 | % (counter, time.time() - start_time, recon_loss_adv, recon_loss_raw, alpha_loss, d_loss, g_gan_loss)) 145 | 146 | if np.mod(counter, 100) == 1: 147 | #for gi in np.arange(1): 148 | samp_o1, samp_o2, samp = self.sess.run( 149 | [self.O1, self.O2, self.vae_output], 150 | feed_dict={self.images: sample_images}) 151 | 152 | samp_fix = self.sess.run( 153 | self.G, 154 | feed_dict={self.z1: self.random_fix_z(), self.z2: self.random_z()}) 155 | 156 | samples = np.concatenate((samp_fix, samp_o1, samp_o2, samp), axis=0) 157 | 158 | save_images(samples, [32, 16], '%s/train_%05d.png' % (self.sample_dir, counter)) 159 | 160 | 161 | if np.mod(counter, 500) == 1: 162 | self.save(self.checkpoint_dir, counter) 163 | 164 | def encoder(self, image, reuse=False): 165 | 166 | def enc(image, name='enc', reuse=False): 167 | with tf.variable_scope(name, reuse=reuse): 168 | h0 = lrelu(conv2d(image, 64, name='e_h0_conv')) 169 | 170 | d_bn1 = batch_norm(self.batch_size, name='e_bn1') 171 | h1 = lrelu(d_bn1(conv2d(h0, 64*2, name='e_h1_conv'))) 172 | 173 | d_bn2 = batch_norm(self.batch_size, name='e_bn2') 174 | h2 = lrelu(d_bn2(conv2d(h1, 64*4, name='e_h2_conv'))) 175 | 176 | d_bn3 = batch_norm(self.batch_size, name='e_bn3') 177 | h3 = lrelu(d_bn3(conv2d(h2, 64*8, name='e_h3_conv'))) 178 | 179 | h3_flat = tf.reshape(h3, [self.batch_size, -1]) 180 | u = linear(h3_flat, self.z_dim, 'e_u_lin') 181 | s = tf.sqrt(tf.exp(linear(h3_flat, self.z_dim, 'e_s_lin'))) 182 | 183 | return u, s 184 | 185 | with tf.variable_scope('encoder', reuse=reuse): 186 | u1, s1 = enc(image, 'enc1', reuse) 187 | u2, s2 = enc(image, 'enc2', reuse) 188 | 189 | return u1, s1, u2, s2 190 | 191 | def discriminator(self, image, feature=False, reuse=False): 192 | 193 | with tf.variable_scope('discriminator', reuse=reuse): 194 | h0 = lrelu(conv2d(image, 64, name='d_h0_conv')) 195 | 196 | d_bn1 = batch_norm(self.batch_size, name='d_bn1') 197 | h1 = lrelu(d_bn1(conv2d(h0, 64*2, name='d_h1_conv'))) 198 | 199 | d_bn2 = batch_norm(self.batch_size, name='d_bn2') 200 | h2 = lrelu(d_bn2(conv2d(h1, 64*4, name='d_h2_conv'))) 201 | 202 | d_bn3 = batch_norm(self.batch_size, name='d_bn3') 203 | h3 = lrelu(d_bn3(conv2d(h2, 64*8, name='d_h3_conv'))) 204 | 205 | h3_flat = tf.reshape(h3, [self.batch_size, -1]) 206 | h4 = linear(h3_flat, 1, 'd_h3_lin') 207 | 208 | if feature: 209 | return h3_flat 210 | 211 | return tf.nn.sigmoid(h4) 212 | 213 | def generator(self, u1, s1, u2, s2, name='generator', reuse=False, feature=False): 214 | 215 | def gen(h, gen_name='gen', reuse=False): 216 | with tf.variable_scope(gen_name, reuse=reuse): 217 | h0 = linear(h, 512*4*4, 'g_h0_lin') 218 | h0 = tf.reshape(h0, [-1, 4, 4, 512]) 219 | h0 = tf.nn.relu(h0) 220 | 221 | h1 = deconv2d(h0, [self.batch_size, 8, 8, 256], name='g_h1') 222 | g_bn1 = batch_norm(self.batch_size, name='g_bn1') 223 | h1 = tf.nn.relu(g_bn1(h1)) 224 | 225 | h2 = deconv2d(h1,[self.batch_size, 16, 16, 128], name='g_h2') 226 | g_bn2 = batch_norm(self.batch_size, 
name='g_bn2') 227 | h2 = tf.nn.relu(g_bn2(h2)) 228 | 229 | h3 = deconv2d(h2, [self.batch_size, 32, 32, 64], name='g_h3') 230 | g_bn3 = batch_norm(self.batch_size, name='g_bn3') 231 | h3 = tf.nn.relu(g_bn3(h3)) 232 | 233 | h4 = deconv2d(h3, [self.batch_size, 64, 64, 3], name='g_h4') 234 | alpha = deconv2d(h3, [self.batch_size, 64, 64, 1], name='g_a') 235 | 236 | return tf.sigmoid(h4), tf.sigmoid(alpha) 237 | 238 | e = tf.random_normal([self.batch_size, self.z_dim]) 239 | samples1 = u1 + e * s1 240 | samples2 = u2 + e * s2 241 | 242 | with tf.variable_scope(name, reuse=reuse): 243 | lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.z_dim, forget_bias=0.0) 244 | state = lstm_cell.zero_state(self.batch_size, tf.float32) 245 | 246 | with tf.variable_scope('g_rnn'): 247 | (cell_output1, state) = lstm_cell(samples1, state) 248 | tf.get_variable_scope().reuse_variables() 249 | (cell_output2, state) = lstm_cell(samples2, state) 250 | 251 | rgb1, alpha1 = gen(cell_output1, 'gen1') 252 | rgb2, alpha2 = gen(cell_output2, 'gen2') 253 | 254 | a_norm = alpha1*(1 - alpha2) + alpha2 255 | a1 = alpha1/a_norm 256 | a2 = alpha2/a_norm 257 | 258 | o1 = self.checked_img*(1 - a1) + rgb1*a1 259 | o2 = self.checked_img*(1 - a2) + rgb2*a2 260 | o = rgb1*alpha1*(1 - alpha2) + rgb2*alpha2 261 | 262 | return o*2 - 1, o1*2 - 1, o2*2 - 1, alpha1, alpha2 263 | 264 | def random_z(self): 265 | return np.random.normal(size=(self.batch_size, self.z_dim)) 266 | 267 | def random_fix_z(self): 268 | r = np.zeros([self.batch_size, self.z_dim]) 269 | r[0:,:] = np.random.normal(size=(1, self.z_dim)) 270 | return r 271 | 272 | def save(self, checkpoint_dir, step): 273 | self.sess.run(self.global_step.assign(step)) 274 | 275 | model_name = "GAN.model" 276 | 277 | if not os.path.exists(checkpoint_dir): 278 | os.makedirs(checkpoint_dir) 279 | 280 | self.saver.save(self.sess, 281 | os.path.join(checkpoint_dir, model_name), 282 | global_step=step) 283 | 284 | def load(self, checkpoint_dir): 285 | print(" [*] Reading checkpoints...") 286 | 287 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 288 | if ckpt and ckpt.model_checkpoint_path: 289 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 290 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 291 | return True 292 | else: 293 | return False 294 | 295 | def transform(self, X): 296 | return X*2 - 1 297 | 298 | def inverse_transform(self, X): 299 | X = (X+1.)/2. 
300 | return X 301 | 302 | def create_checked_img(self, size): 303 | arr = np.arange(size) 304 | chk1 = arr[np.mod(arr, size/8) < size/16] 305 | chk2 = arr[np.mod(arr, size/8) >= size/16] 306 | a = np.meshgrid(chk1, chk1) 307 | b = np.meshgrid(chk2, chk2) 308 | 309 | img = np.ones([size, size, 3], dtype=np.float32) 310 | 311 | img[a[0], a[1]] = 0 312 | img[b[0], b[1]] = 0 313 | 314 | return img -------------------------------------------------------------------------------- /code/main_cgan_triple_alpha_celeba.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | import h5py 5 | 6 | from lib.cgan_triple_alpha import GAN 7 | 8 | flags = tf.app.flags 9 | flags.DEFINE_integer("batch_size", 128, "The size of batch images") 10 | flags.DEFINE_integer("train_size", 160000, "The size of train images") 11 | flags.DEFINE_string("dataset", "celeba", "The name of dataset") 12 | flags.DEFINE_string("data_dir", "../data/celeba.hdf5", "Directory of dataset") 13 | flags.DEFINE_string("checkpoint_dir", "../checkpoint/cgan_triple_alpha_celeba", "Directory name to save the checkpoints [checkpoint]") 14 | flags.DEFINE_string("sample_dir", "../samples/cgan_triple_alpha_celeba", "Directory name to save the image samples [samples]") 15 | FLAGS = flags.FLAGS 16 | 17 | def main(_): 18 | 19 | if not os.path.exists(FLAGS.checkpoint_dir): 20 | os.makedirs(FLAGS.checkpoint_dir) 21 | if not os.path.exists(FLAGS.sample_dir): 22 | os.makedirs(FLAGS.sample_dir) 23 | 24 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8) 25 | 26 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 27 | gan = GAN(sess, FLAGS) 28 | 29 | HDF5_FILE = FLAGS.data_dir 30 | f = h5py.File(HDF5_FILE, 'r') 31 | tr_data = f['images'] # data_num * height * width * channel (float32) 32 | 33 | gan.train(tr_data) 34 | f.close() 35 | 36 | 37 | if __name__ == '__main__': 38 | tf.app.run() 39 | -------------------------------------------------------------------------------- /data/download.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modification of https://github.com/stanfordnlp/treelstm/blob/master/scripts/download.py 3 | Downloads the following: 4 | - Celeb-A dataset 5 | - LSUN dataset 6 | - MNIST dataset 7 | """ 8 | 9 | from __future__ import print_function 10 | import os 11 | import sys 12 | import gzip 13 | import json 14 | import shutil 15 | import zipfile 16 | import argparse 17 | import subprocess 18 | from six.moves import urllib 19 | 20 | parser = argparse.ArgumentParser(description='Download dataset for DCGAN.') 21 | parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist'], 22 | help='name of dataset to download [celebA, lsun, mnist]') 23 | 24 | def download(url, dirpath): 25 | filename = url.split('/')[-1] 26 | filepath = os.path.join(dirpath, filename) 27 | u = urllib.request.urlopen(url) 28 | f = open(filepath, 'wb') 29 | filesize = int(u.headers["Content-Length"]) 30 | print("Downloading: %s Bytes: %s" % (filename, filesize)) 31 | 32 | downloaded = 0 33 | block_sz = 8192 34 | status_width = 70 35 | while True: 36 | buf = u.read(block_sz) 37 | if not buf: 38 | print('') 39 | break 40 | else: 41 | print('', end='\r') 42 | downloaded += len(buf) 43 | f.write(buf) 44 | status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") % 45 | ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. 
/ filesize)) 46 | print(status, end='') 47 | sys.stdout.flush() 48 | f.close() 49 | return filepath 50 | 51 | def unzip(filepath): 52 | print("Extracting: " + filepath) 53 | dirpath = os.path.dirname(filepath) 54 | with zipfile.ZipFile(filepath) as zf: 55 | zf.extractall(dirpath) 56 | os.remove(filepath) 57 | 58 | def download_celeb_a(dirpath): 59 | data_dir = 'celebA' 60 | if os.path.exists(os.path.join(dirpath, data_dir)): 61 | print('Found Celeb-A - skip') 62 | return 63 | url = 'https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=1&pv=1' 64 | filepath = download(url, dirpath) 65 | zip_dir = '' 66 | with zipfile.ZipFile(filepath) as zf: 67 | zip_dir = zf.namelist()[0] 68 | zf.extractall(dirpath) 69 | os.remove(filepath) 70 | os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir)) 71 | 72 | def _list_categories(tag): 73 | url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag 74 | f = urllib.request.urlopen(url) 75 | return json.loads(f.read()) 76 | 77 | def _download_lsun(out_dir, category, set_name, tag): 78 | url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \ 79 | '&category={category}&set={set_name}'.format(**locals()) 80 | print(url) 81 | if set_name == 'test': 82 | out_name = 'test_lmdb.zip' 83 | else: 84 | out_name = '{category}_{set_name}_lmdb.zip'.format(**locals()) 85 | out_path = os.path.join(out_dir, out_name) 86 | cmd = ['curl', url, '-o', out_path] 87 | print('Downloading', category, set_name, 'set') 88 | subprocess.call(cmd) 89 | 90 | def download_lsun(dirpath): 91 | data_dir = os.path.join(dirpath, 'lsun') 92 | if os.path.exists(data_dir): 93 | print('Found LSUN - skip') 94 | return 95 | else: 96 | os.mkdir(data_dir) 97 | 98 | tag = 'latest' 99 | #categories = _list_categories(tag) 100 | categories = ['bedroom'] 101 | 102 | for category in categories: 103 | _download_lsun(data_dir, category, 'train', tag) 104 | _download_lsun(data_dir, category, 'val', tag) 105 | _download_lsun(data_dir, '', 'test', tag) 106 | 107 | def download_mnist(dirpath): 108 | data_dir = os.path.join(dirpath, 'mnist') 109 | if os.path.exists(data_dir): 110 | print('Found MNIST - skip') 111 | return 112 | else: 113 | os.mkdir(data_dir) 114 | url_base = 'http://yann.lecun.com/exdb/mnist/' 115 | file_names = ['train-images-idx3-ubyte.gz','train-labels-idx1-ubyte.gz','t10k-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz'] 116 | for file_name in file_names: 117 | url = (url_base+file_name).format(**locals()) 118 | print(url) 119 | out_path = os.path.join(data_dir,file_name) 120 | cmd = ['curl', url, '-o', out_path] 121 | print('Downloading ', file_name) 122 | subprocess.call(cmd) 123 | cmd = ['gzip', '-d', out_path] 124 | print('Decompressing ', file_name) 125 | subprocess.call(cmd) 126 | 127 | def prepare_data_dir(path = './data'): 128 | if not os.path.exists(path): 129 | os.mkdir(path) 130 | 131 | if __name__ == '__main__': 132 | args = parser.parse_args() 133 | prepare_data_dir() 134 | 135 | if 'celebA' in args.datasets: 136 | download_celeb_a('./data') 137 | if 'lsun' in args.datasets: 138 | download_lsun('./data') 139 | if 'mnist' in args.datasets: 140 | download_mnist('./data') -------------------------------------------------------------------------------- /data/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import glob 4 | import h5py 5 | from PIL import Image 6 | 7 | 8 | IMG_SIZE = 64 9 | 10 | 11 | def 
load_Img(fname): 12 | img = Image.open(fname) 13 | img = img.resize((IMG_SIZE, IMG_SIZE), Image.ANTIALIAS) 14 | img = np.asarray(img)/255. 15 | 16 | if len(img.shape) < 3: 17 | rgb_img = np.zeros([IMG_SIZE, IMG_SIZE, 3], dtype=np.float32) 18 | rgb_img[:, :, 0] = rgb_img[:, :, 1] = rgb_img[:, :, 2] = img 19 | return rgb_img 20 | 21 | return img 22 | 23 | 24 | file_list = glob.glob('./data/celebA/*.jpg') 25 | IMG_NUM = len(file_list) 26 | 27 | rand_idx = np.arange(IMG_NUM) 28 | np.random.shuffle(rand_idx) 29 | 30 | HDF5_FILE_WRITE = 'celeba.hdf5' 31 | fw = h5py.File(HDF5_FILE_WRITE, 'w') 32 | images = fw.create_dataset('images', (IMG_NUM, IMG_SIZE, IMG_SIZE, 3), dtype='float32') 33 | 34 | for i in xrange(IMG_NUM): 35 | idx = rand_idx[i] 36 | images[i] = load_Img(file_list[idx]) 37 | 38 | if i % 1000 == 0: 39 | print '%.1f %% preprocessed.' % (100.*i/IMG_NUM) 40 | 41 | fw.close() 42 | 43 | -------------------------------------------------------------------------------- /samples/readme.txt: -------------------------------------------------------------------------------- 1 | This directory is where the samples are saved. 2 | --------------------------------------------------------------------------------
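A quick sanity check of the preprocessed dataset before training (a minimal sketch; the file path and the 'images' key follow preprocess.py and main_cgan_triple_alpha_celeba.py, and the path assumes the script is run from the 'code' directory):

```
import h5py

f = h5py.File('../data/celeba.hdf5', 'r')
images = f['images']
print(images.shape)    # (num_images, 64, 64, 3)
print(images.dtype)    # float32, values in [0, 1]
f.close()
```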