├── .gitignore
├── README.md
├── images
│   └── model.PNG
├── layers.py
├── main.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
output/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fader-Networks-Tensorflow

A TensorFlow implementation of Fader Networks, described in the [paper](https://arxiv.org/pdf/1706.00409.pdf) by Lample et al. The paper disentangles the salient information of a face image from its attributes, so that the image can be modified simply by varying those attributes.

> This property could allow for applications where users can modify an image using sliding knobs, like faders on a mixing console, to change the facial expression of a portrait, or to update the color of some objects

The authors train on the CelebA dataset, which is freely available from its project site.

The encoder is designed so that, given an input image X, it outputs an embedding of the image that is independent of the attributes of X. Suppose two images X and X' (with attribute lists Y and Y') differ in exactly one attribute — say the subject's eyes are open in X and closed in X'. Then ideally Enc(X, Y) = Enc(X', Y'). This is of course the ideal case, but it is the main idea behind Fader Networks.
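In this implementation the attributes are binary: each CelebA ±1 attribute is first mapped to {0, 1} and then expanded into a two-way one-hot pair before being fed to the decoder. A minimal sketch of that encoding (mirroring `load_dataset` and `transform_attr` in `main.py`):

```python
import numpy as np

raw = np.array([-1, 1, -1])        # three CelebA attributes for one image
y = (raw + 1) // 2                 # -> [0, 1, 0]
one_hot = np.zeros(2 * len(y))
for j, v in enumerate(y):
    one_hot[2 * j + v] = 1         # -> [1., 0., 0., 1., 1., 0.]
```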
The picture below depicts the architecture.

<div align="center">
<img src="images/model.PNG" alt="Fader Networks architecture">
<br>
<em>Image from original paper</em>
</div>

The architecture is GAN-like: the encoder tries to produce an embedding X_emb from which the discriminator cannot guess the actual attributes of the image, while the discriminator is simultaneously trained to guess the attributes of X from X_emb anyway. It is a two-player game in which each player's progress sharpens the other during training. In the end we obtain an encoder whose embedding of X does not depend on the attributes of X.

The decoder then simply reconstructs an image from the embedding and the attributes desired for the output. During training it is fed the original attributes Y and asked to reproduce the original image.

The whole model is thus an autoencoder combined with an adversarial network.
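In pseudo-losses the game looks roughly like this — a minimal NumPy sketch with made-up names (`loss_setup` in `main.py` implements the same idea with a log-distance formulation):

```python
import numpy as np

def disc_objective(p, y, eps=1e-8):
    # discriminator: maximise the likelihood of the true attributes y
    return -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))

def enc_dec_objective(x, x_rec, p, y, lam, eps=1e-8):
    # encoder/decoder: reconstruct x while pushing the discriminator
    # towards the flipped attributes 1 - y
    mse = np.mean((x - x_rec) ** 2)
    fool = -np.mean((1 - y) * np.log(p + eps) + y * np.log(1 - p + eps))
    return mse + lam * fool
```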
Training:

Before training, you need to download the CelebA dataset, along with its list of attributes.

To start training the network, run:

``` python main.py ```

You can specify the dataset directory with the ```--dataset_dir``` argument. The output will be created in the output folder.
--------------------------------------------------------------------------------
/images/model.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardikbansal/Fader-Networks-Tensorflow/22cc9bd2c76737a28c51d9965639fad176910066/images/model.PNG
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np


def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False):

    with tf.variable_scope(name):
        if alt_relu_impl:
            # 0.5*(1+leak)*x + 0.5*(1-leak)*|x| equals x for x >= 0
            # and leak*x for x < 0, i.e. the same leaky ReLU.
            f1 = 0.5 * (1 + leak)
            f2 = 0.5 * (1 - leak)
            return f1 * x + f2 * abs(x)
        else:
            return tf.maximum(x, leak * x)


def instance_norm(x):

    with tf.variable_scope("instance_norm"):
        epsilon = 1e-5
        # Normalize each sample over its spatial dimensions only.
        mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
        scale = tf.get_variable('scale', [x.get_shape()[-1]],
                                initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02))
        offset = tf.get_variable('offset', [x.get_shape()[-1]],
                                 initializer=tf.constant_initializer(0.0))
        out = scale * tf.div(x - mean, tf.sqrt(var + epsilon)) + offset

        return out


def linear1d(inputlin, inputdim, outputdim, name="linear1d", std=0.02, mn=0.0):

    with tf.variable_scope(name):
        weight = tf.get_variable("weight", [inputdim, outputdim],
                                 initializer=tf.truncated_normal_initializer(mean=mn, stddev=std))
        bias = tf.get_variable("bias", [outputdim], dtype=tf.float32,
                               initializer=tf.constant_initializer(0.0))

        return tf.matmul(inputlin, weight) + bias


def general_conv2d(inputconv, output_dim=64, filter_height=4, filter_width=4, stride_height=2, stride_width=2, stddev=0.02, padding="SAME", name="conv2d", do_norm=True, norm_type='batch_norm', do_relu=True, relufactor=0):

    with tf.variable_scope(name):
        # tf.contrib.layers.conv2d takes kernel_size and stride as [height, width].
        conv = tf.contrib.layers.conv2d(inputconv, output_dim,
                                        [filter_height, filter_width],
                                        [stride_height, stride_width],
                                        padding, activation_fn=None,
                                        weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
                                        biases_initializer=tf.constant_initializer(0.0))
        if do_norm:
            if norm_type == 'instance_norm':
                conv = instance_norm(conv)
            elif norm_type == 'batch_norm':
                conv = tf.contrib.layers.batch_norm(conv, decay=0.9, updates_collections=None,
                                                    epsilon=1e-5, scale=True, scope="batch_norm")

        if do_relu:
            if relufactor == 0:
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")

        return conv


def general_deconv2d(inputconv, output_dim=64, filter_height=4, filter_width=4, stride_height=2, stride_width=2, stddev=0.02, padding="SAME", name="deconv2d", do_norm=True, norm_type='batch_norm', do_relu=False, relufactor=0):

    with tf.variable_scope(name):
        conv = tf.contrib.layers.conv2d_transpose(inputconv, output_dim,
                                                  [filter_height, filter_width],
                                                  [stride_height, stride_width],
                                                  padding, activation_fn=None,
                                                  weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
                                                  biases_initializer=tf.constant_initializer(0.0))

        if do_norm:
            # Same key as in general_conv2d.
            if norm_type == 'instance_norm':
                conv = instance_norm(conv)
            elif norm_type == 'batch_norm':
                conv = tf.contrib.layers.batch_norm(conv, decay=0.9, updates_collections=None,
                                                    epsilon=1e-5, scale=True, scope="batch_norm")

        if do_relu:
            if relufactor == 0:
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")

        return conv
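# Quick shape check for the two helpers above (a sketch, assuming TF 1.x;
# the "demo_*" names exist only for this check):
if __name__ == "__main__":
    x = tf.placeholder(tf.float32, [1, 256, 256, 3])
    down = general_conv2d(x, 16, name="demo_conv")       # stride 2 halves H and W
    up = general_deconv2d(down, 3, name="demo_deconv")   # stride 2 doubles them back
    print(down.get_shape())   # (1, 128, 128, 16)
    print(up.get_shape())     # (1, 256, 256, 3)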
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import optparse
import os
import time
import glob

from layers import *

from scipy.misc import imsave
from scipy.misc import imresize
from PIL import Image


class Fader():

    def run_parser(self):

        self.parser = optparse.OptionParser()

        self.parser.add_option('--num_iter', type='int', default=1000, dest='num_iter')
        self.parser.add_option('--batch_size', type='int', default=32, dest='batch_size')
        self.parser.add_option('--img_width', type='int', default=256, dest='img_width')
        self.parser.add_option('--img_height', type='int', default=256, dest='img_height')
        self.parser.add_option('--img_depth', type='int', default=3, dest='img_depth')
        self.parser.add_option('--num_attr', type='int', default=40, dest='num_attr')
        self.parser.add_option('--max_epoch', type='int', default=64, dest='max_epoch')
        self.parser.add_option('--num_train_images', type='int', default=150000, dest='num_train_images')
        self.parser.add_option('--num_test_images', type='int', default=50000, dest='num_test_images')
        self.parser.add_option('--test', action="store_true", default=False, dest="test")
        self.parser.add_option('--steps', type='int', default=10, dest='steps')
        self.parser.add_option('--enc_size', type='int', default=256, dest='enc_size')
        self.parser.add_option('--dec_size', type='int', default=256, dest='dec_size')
        self.parser.add_option('--model', type='string', default="draw_attn", dest='model_type')
        self.parser.add_option('--dataset', type='string', default="celebA", dest='dataset')
        self.parser.add_option('--dataset_dir', type='string', default="../datasets/img_align_celeba", dest='dataset_dir')
        self.parser.add_option('--test_dataset_dir', type='string', default="../datasets/img_align_celeba", dest='test_dataset_dir')

    def initialize(self):

        self.run_parser()

        opt = self.parser.parse_args()[0]

        self.max_epoch = opt.max_epoch
        self.batch_size = opt.batch_size
        self.dataset = opt.dataset

        self.img_width = opt.img_width
        self.img_height = opt.img_height
        self.img_depth = opt.img_depth

        self.img_size = self.img_width*self.img_height*self.img_depth
        self.num_attr = opt.num_attr
        self.num_train_images = opt.num_train_images
        self.num_test_images = opt.num_test_images
        self.model = "Fader"
        self.to_test = opt.test
        self.load_checkpoint = False
        self.do_setup = True
        self.dataset_dir = opt.dataset_dir
        self.test_dataset_dir = opt.test_dataset_dir

        self.tensorboard_dir = "./output/" + self.model + "/" + self.dataset + "/tensorboard"
        self.check_dir = "./output/" + self.model + "/" + self.dataset + "/checkpoints"
        self.images_dir = "./output/" + self.model + "/" + self.dataset + "/imgs"

    def normalize_input(self, imgs):
        # Map pixel values from [0, 255] to [-1, 1] (the decoder ends in tanh).
        return imgs/127.5 - 1.0

    def transform_attr(self, attr):

        # Expand each binary attribute into a two-way one-hot pair:
        # value 0 -> [1, 0], value 1 -> [0, 1], i.e. 2*num_attr entries per image.
        temp_shape = len(attr)

        final_attr = np.zeros([temp_shape, 2*self.num_attr])

        for i in range(0, temp_shape):
            for j in range(0, self.num_attr):
                final_attr[i][2*j + attr[i][j]] = 1

        return final_attr

    def load_dataset(self, mode='train'):

        if mode == "train":

            self.train_attr = []

            imageFolderPath = self.dataset_dir
            self.imagePath = glob.glob(imageFolderPath + '/*.jpg')

            dictn = []

            with open(self.dataset_dir + "/train_attr.txt") as f:
                for lines in f:
                    temp = lines.split()
                    dictn.append(temp[1:])

            for i in range(self.num_train_images):
                # CelebA stores attributes as ±1; map them to {0, 1}.
                self.train_attr.append(((np.array(dictn[i]).astype(np.int32)+1)/2).astype(np.int32))

            self.train_attr_1h = self.transform_attr(self.train_attr)

        elif mode == "test":

            self.test_attr = []

            imageFolderPath = self.test_dataset_dir
            self.imagePath = glob.glob(imageFolderPath + '/*.jpg')

            dictn = []

            with open(self.dataset_dir + "/test_attr.txt") as f:
                for lines in f:
                    temp = lines.split()
                    dictn.append(temp[1:])

            for i in range(self.num_test_images):
                # Apply the same ±1 -> {0, 1} conversion as for training.
                self.test_attr.append(((np.array(dictn[i]).astype(np.int32)+1)/2).astype(np.int32))

            self.test_attr_1h = self.transform_attr(self.test_attr)

    def load_batch(self, batch_num, batch_sz, mode="train"):

        if mode == "train":
            offset = 0
        elif mode == "test":
            # Test images follow the training images in the sorted glob.
            offset = self.num_train_images

        temp = []
        for i in range(batch_sz):
            # Crop the image width (columns 39 onward), then resize to
            # 256x256 and normalize to [-1, 1].
            img = np.array(Image.open(self.imagePath[offset + i + batch_sz*batch_num]), 'f')[:, 39:216, :]
            temp.append(self.normalize_input(imresize(img, size=(256, 256), interp="bilinear")))
        return temp
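    # Worked example of the attribute pipeline for one row of train_attr.txt:
    #   raw line   "000001.jpg -1 1 ..."  ->  dictn entry ['-1', '1', ...]
    #   (x+1)/2    [0, 1, ...]            ->  train_attr
    #   one-hot    [1, 0, 0, 1, ...]      ->  train_attr_1h (via transform_attr)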
    def encoder(self, input_enc, name="Encoder"):

        with tf.variable_scope(name):

            # Seven stride-2 convolutions: 256x256x3 -> 2x2x512.
            o_c1 = general_conv2d(input_enc, 16, name="C16", relufactor=0.2)
            o_c2 = general_conv2d(o_c1, 32, name="C32", relufactor=0.2)
            o_c3 = general_conv2d(o_c2, 64, name="C64", relufactor=0.2)
            o_c4 = general_conv2d(o_c3, 128, name="C128", relufactor=0.2)
            o_c5 = general_conv2d(o_c4, 256, name="C256", relufactor=0.2)
            o_c6 = general_conv2d(o_c5, 512, name="C512_1", relufactor=0.2)
            o_c7 = general_conv2d(o_c6, 512, name="C512_2", relufactor=0.2)

            return o_c7

    def decoder(self, input_dec, attr, name="Decoder"):

        with tf.variable_scope(name):

            # At every scale, tile the one-hot attribute vector into a
            # [batch, s, s, 2*num_attr] map and concatenate it onto the
            # feature maps before the next deconvolution.
            attr_1 = tf.transpose(tf.stack([attr]*4), [1, 0, 2])
            o_d0 = tf.concat([input_dec, tf.reshape(attr_1, [-1, 2, 2, 2*self.num_attr])], 3)
            o_d1 = general_deconv2d(o_d0, 512, name="D512_2")

            attr_2 = tf.concat([attr_1]*4, axis=1)
            o_d1 = tf.concat([o_d1, tf.reshape(attr_2, [-1, 4, 4, 2*self.num_attr])], 3)
            o_d2 = general_deconv2d(o_d1, 256, name="D512_1")

            attr_3 = tf.concat([attr_2]*4, axis=1)
            o_d2 = tf.concat([o_d2, tf.reshape(attr_3, [-1, 8, 8, 2*self.num_attr])], 3)
            o_d3 = general_deconv2d(o_d2, 128, name="D256")

            attr_4 = tf.concat([attr_3]*4, axis=1)
            o_d3 = tf.concat([o_d3, tf.reshape(attr_4, [-1, 16, 16, 2*self.num_attr])], 3)
            o_d4 = general_deconv2d(o_d3, 64, name="D128")

            attr_5 = tf.concat([attr_4]*4, axis=1)
            o_d4 = tf.concat([o_d4, tf.reshape(attr_5, [-1, 32, 32, 2*self.num_attr])], 3)
            o_d5 = general_deconv2d(o_d4, 32, name="D64")

            attr_6 = tf.concat([attr_5]*4, axis=1)
            o_d5 = tf.concat([o_d5, tf.reshape(attr_6, [-1, 64, 64, 2*self.num_attr])], 3)
            o_d6 = general_deconv2d(o_d5, 16, name="D32")

            attr_7 = tf.concat([attr_6]*4, axis=1)
            o_d6 = tf.concat([o_d6, tf.reshape(attr_7, [-1, 128, 128, 2*self.num_attr])], 3)
            o_d7 = general_deconv2d(o_d6, 3, name="D16")

            o_d7 = tf.nn.tanh(o_d7)

            return o_d7
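    # How the attribute vector is broadcast spatially in decoder() above
    # (B = batch size, A = 2*num_attr):
    #   attr                                    [B, A]
    #   attr_1 = transpose(stack([attr]*4))     [B, 4, A]   -> reshape [B, 2, 2, A]
    #   attr_2 = concat([attr_1]*4, axis=1)     [B, 16, A]  -> reshape [B, 4, 4, A]
    #   attr_3 = concat([attr_2]*4, axis=1)     [B, 64, A]  -> reshape [B, 8, 8, A]
    # Each deconv doubles the spatial size, so the attribute map is re-tiled
    # by a factor of 4 before being concatenated onto every decoder input.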
    def discriminator(self, input_disc, name="Discriminator"):

        with tf.variable_scope(name):

            # One more stride-2 conv takes the 2x2x512 embedding to 1x1x512.
            o_disc1 = general_conv2d(input_disc, 512, name="C512")
            # Dropout on the discriminator, as in the paper (tf.layers.dropout
            # is a no-op unless training=True).
            o_disc1 = tf.layers.dropout(o_disc1, rate=0.3, training=True)
            o_flat = tf.reshape(o_disc1, [self.batch_size, 512])
            o_disc2 = linear1d(o_flat, 512, 512, name="fc1")
            o_disc3 = linear1d(o_disc2, 512, self.num_attr, name="fc2")

            # One probability per attribute.
            return tf.nn.sigmoid(o_disc3)

    def celeb_model_setup(self):

        self.input_imgs = tf.placeholder(tf.float32, [self.batch_size, self.img_height, self.img_width, self.img_depth])
        self.input_attr = tf.placeholder(tf.float32, [self.batch_size, self.num_attr])
        self.input_attr_1h = tf.placeholder(tf.float32, [self.batch_size, 2*self.num_attr])
        self.lmda = tf.placeholder(tf.float32, [1])

        self.o_enc = self.encoder(self.input_imgs)
        self.o_dec = self.decoder(self.o_enc, self.input_attr_1h)
        self.o_disc = self.discriminator(self.o_enc)

    def model_setup(self):

        # Build everything under a single "Model" scope.
        with tf.variable_scope("Model"):
            self.celeb_model_setup()

        self.model_vars = tf.trainable_variables()
        for var in self.model_vars:
            print(var.name, var.get_shape())

        self.do_setup = False

    def generation_loss(self, input_img, output_img, loss_type='mse'):

        if loss_type == 'mse':
            return tf.reduce_sum(tf.squared_difference(input_img, output_img), [1, 2, 3])
        elif loss_type == 'log_diff':
            epsilon = 1e-8
            return -tf.reduce_sum(input_img*tf.log(output_img + epsilon)
                                  + (1 - input_img)*tf.log(epsilon + 1 - output_img), [1, 2, 3])

    def discriminator_loss(self, out_attr, inp_attr):

        # For binary attributes, |out - t| is the probability mass the
        # discriminator puts on the label opposite to t, so this returns
        # the log-likelihood of predicting "not inp_attr".
        epsilon = 1e-8
        return tf.reduce_sum(tf.log(tf.abs(out_attr - inp_attr) + epsilon), 1)
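    # Numeric sanity check of the log-distance trick for one binary attribute:
    # with true label y = 1 and discriminator output p = 0.9,
    #   |p - y|       = 0.1  -> log(0.1)  (enc_loss term: encoder is not fooling it)
    #   |p - (1 - y)| = 0.9  -> log(0.9)  (disc_loss term: discriminator is right)
    # Minimising disc_loss = -E[log|p - (1-y)|] trains the discriminator, while
    # minimising -lmda * E[log|p - y|] pushes the encoder to fool it.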
    def loss_setup(self):

        self.img_loss = tf.reduce_mean(self.generation_loss(self.input_imgs, self.o_dec))
        # The encoder is rewarded when the discriminator predicts the wrong
        # attributes, the discriminator when it predicts the true ones.
        self.enc_loss = tf.reduce_mean(self.discriminator_loss(self.o_disc, self.input_attr))
        self.disc_loss = -tf.reduce_mean(self.discriminator_loss(self.o_disc, 1 - self.input_attr))
        self.enc_dec_loss = self.img_loss - self.lmda*self.enc_loss

        # Learning rate and beta1 follow the paper.
        optimizer = tf.train.AdamOptimizer(0.002, beta1=0.5)

        # 'coder' matches both the Encoder and Decoder variable scopes.
        enc_dec_vars = [var for var in self.model_vars if 'coder' in var.name]
        disc_vars = [var for var in self.model_vars if 'Discriminator' in var.name]

        self.enc_dec_loss_optimizer = optimizer.minimize(self.enc_dec_loss, var_list=enc_dec_vars)
        self.disc_loss_optimizer = optimizer.minimize(self.disc_loss, var_list=disc_vars)

        self.img_loss_summ = tf.summary.scalar("img_loss", self.img_loss)
        self.enc_loss_summ = tf.summary.scalar("enc_loss", self.enc_loss)
        self.disc_loss_summ = tf.summary.scalar("disc_loss", self.disc_loss)

    def train(self):

        self.model_setup()
        self.loss_setup()
        self.load_dataset()

        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        if not os.path.exists(self.images_dir + "/train/"):
            os.makedirs(self.images_dir + "/train/")
        if not os.path.exists(self.check_dir):
            os.makedirs(self.check_dir)

        with tf.Session() as sess:

            sess.run(init)
            writer = tf.summary.FileWriter(self.tensorboard_dir)
            writer.add_graph(sess.graph)

            if self.load_checkpoint:
                chkpt_fname = tf.train.latest_checkpoint(self.check_dir)
                saver.restore(sess, chkpt_fname)

            per_epoch_steps = int(self.num_train_images/self.batch_size)

            t = time.time()

            for epoch in range(0, self.max_epoch):
                for itr in range(0, per_epoch_steps):

                    # Ramp the adversarial weight linearly up to 1e-4.
                    temp_lmd = 0.0001*(epoch*per_epoch_steps + itr)/(per_epoch_steps*self.max_epoch)

                    imgs = self.load_batch(itr, self.batch_size)
                    attrs = self.train_attr[itr*self.batch_size:(itr+1)*self.batch_size]
                    attrs_1h = self.train_attr_1h[itr*self.batch_size:(itr+1)*self.batch_size]

                    print("In iteration", itr, "of epoch", epoch)
                    print(time.time() - t)

                    # Update the encoder/decoder on this batch...
                    _, temp_tot_loss, temp_img_loss, temp_enc_loss, img_loss_str, enc_loss_str = sess.run(
                        [self.enc_dec_loss_optimizer, self.enc_dec_loss, self.img_loss, self.enc_loss, self.img_loss_summ, self.enc_loss_summ],
                        feed_dict={self.input_imgs: imgs, self.input_attr_1h: attrs_1h, self.input_attr: attrs, self.lmda: [temp_lmd]})

                    # ...then the discriminator on the same batch.
                    _, temp_disc_loss, disc_loss_str, temp_o_disc = sess.run(
                        [self.disc_loss_optimizer, self.disc_loss, self.disc_loss_summ, self.o_disc],
                        feed_dict={self.input_imgs: imgs, self.input_attr_1h: attrs_1h, self.input_attr: attrs, self.lmda: [temp_lmd]})

                    writer.add_summary(img_loss_str, epoch*per_epoch_steps + itr)
                    writer.add_summary(enc_loss_str, epoch*per_epoch_steps + itr)
                    writer.add_summary(disc_loss_str, epoch*per_epoch_steps + itr)

                    print(temp_tot_loss, temp_img_loss, temp_enc_loss, temp_disc_loss)

                saver.save(sess, os.path.join(self.check_dir, "Fader"), global_step=epoch)
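    # lmda schedule in train() above: temp_lmd ramps linearly from 0 to 1e-4
    # over the whole run. With the defaults (150000 train images, batch 32,
    # 64 epochs) there are 4687 steps per epoch, so halfway through training
    # (epoch 32, step 0) temp_lmd = 0.0001 * (32*4687)/(4687*64) = 5e-5.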
    def test(self):

        self.model_setup()
        self.load_dataset(mode="test")

        saver = tf.train.Saver()

        if not os.path.exists(self.images_dir + "/test/"):
            os.makedirs(self.images_dir + "/test/")
        if not os.path.exists(self.check_dir):
            os.makedirs(self.check_dir)

        with tf.Session() as sess:

            chkpt_fname = tf.train.latest_checkpoint(self.check_dir)
            saver.restore(sess, chkpt_fname)

            for itr in range(0, int(self.num_test_images/self.batch_size)):

                imgs = self.load_batch(itr, self.batch_size, mode="test")
                # The decoder consumes the one-hot attributes.
                attrs_1h = self.test_attr_1h[itr*self.batch_size:(itr+1)*self.batch_size]

                temp_output = sess.run(self.o_dec, feed_dict={self.input_imgs: imgs, self.input_attr_1h: attrs_1h})

                # Save the reconstructions, mapped back from [-1, 1] to [0, 255].
                for i in range(self.batch_size):
                    imsave(self.images_dir + "/test/" + str(itr*self.batch_size + i) + ".jpg",
                           ((temp_output[i] + 1.0)*127.5).astype(np.uint8))


def main():

    model = Fader()
    model.initialize()

    if model.to_test:
        model.test()
    else:
        model.train()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np

def flat_batch(input_batch, batch_size, final_width, final_height):
    # Tile a batch of 256x256 single-channel images into one
    # final_width x final_height grid.
    if final_height*final_width != batch_size:
        print("Error flattening batch. batch_size != final_width*final_height")
    else:
        h = 256
        w = 256
        output = np.zeros((final_height*h, final_width*w))

        for idx, image in enumerate(input_batch):
            i = idx % final_width       # column
            j = int(idx / final_width)  # row
            output[j*h:j*h+h, i*w:i*w+w] = image.reshape((h, w))

        return output
--------------------------------------------------------------------------------