├── LICENSE
├── README.md
├── data.py
├── discoGAN.py
├── main.py
└── utils.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Josh Miller

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DiscoGAN for TensorFlow
A TensorFlow implementation of [Learning to Discover Cross-Domain Relations with Generative Adversarial Networks](https://arxiv.org/abs/1703.05192).

## Requirements
- TensorFlow 1.0.1
- scipy

## Training
`python main.py`

## Training details
The data utilities currently build two domains from the CelebA dataset by splitting on an attribute column (see `data.py`).

## Remarks
Much of the model has been refactored and extracted into `discoGAN.py`. Command-line arguments, automatic dataset downloads, and similar conveniences are planned. As noted above, there are now some bare-bones utilities for working with the CelebA dataset.
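
If you would rather drive the model from your own script, here is a minimal sketch using only the constructor and `train` arguments that already exist in `discoGAN.py` (the values shown are the defaults):

```python
from discoGAN import DiscoGAN

network = DiscoGAN(batch_size=10, im_size=64, channels=3)
network.train(LR=2e-4, B1=0.5, B2=0.999, iterations=50000)
```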

--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import numpy as np
import zipfile
import os
import random
import scipy.misc
from io import BytesIO

domainA = []
domainB = []

im_size = 64

"""
Preload the CelebA dataset, splitting filenames into two domains
according to the sign of the third attribute column
"""
def load_celeba():
    print("loading celeba")
    with open('images/list_attr_celeba.txt') as f:
        f.readline()  # image count
        f.readline()  # attribute names
        for line in f:
            strs = line.split()
            fname = strs[0]
            att = int(strs[3])
            if att == -1:
                domainA.append(fname)
            else:
                domainB.append(fname)


def set_im_size(size):
    global im_size  # without the global declaration, this assignment would only bind a local
    im_size = size

"""
Get a batch of preprocessed images from domain 'a' or 'b'
"""
def get_batch(size, domain):
    if len(domainA) == 0 or len(domainB) == 0:
        load_celeba()
    samples = domainA if domain == 'a' else domainB
    images = []
    indices = random.sample(range(len(samples)), size)
    with zipfile.ZipFile('images/celeba.zip', 'r') as zfile:
        for i in indices:
            data = zfile.read('img_align_celeba/' + samples[i])
            img = scipy.misc.imread(BytesIO(data))  # decode directly from the zip bytes
            img = preprocess(img)
            images.append(img)
    return np.asarray(images)


"""
Preprocess image: resize and scale pixel values to [-1, 1]
"""
def preprocess(img):
    img = scipy.misc.imresize(img, [im_size, im_size])
    img = img.astype(np.float32)/127.5 - 1.
    return img

"""
Postprocess image: scale pixel values back to [0, 1]
"""
def postprocess(img):
    return (img+1.)/2

"""
Save image
"""
def save(filename, img):
    scipy.misc.imsave(filename, postprocess(img))
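
# Illustrative smoke test (not part of the original module); it assumes
# images/celeba.zip and images/list_attr_celeba.txt are present as above.
if __name__ == '__main__':
    batch = get_batch(4, 'a')
    print('domain A batch shape:', batch.shape)  # expect (4, 64, 64, 3)
    save('sample_a.png', batch[0])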

--------------------------------------------------------------------------------
/discoGAN.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import utils
import data
import os

"""
This object represents a DiscoGAN model. It comes with
functions to train a DiscoGAN and to restore saved weights.
"""
class DiscoGAN(object):
    def __init__(self, batch_size=10, im_size=64, channels=3, dtype=tf.float32, analytics=True):
        self.analytics = analytics
        self.batch_size = batch_size
        data.set_im_size(im_size)  # keep the data pipeline's output size in sync with the placeholders

        self.x_a = tf.placeholder(dtype, [None, im_size, im_size, channels], name='xa')
        self.x_b = tf.placeholder(dtype, [None, im_size, im_size, channels], name='xb')

        # Generator networks: A->B and B->A
        self.g_ab = utils.generator(self.x_a, name="gen_AB", im_size=im_size)
        self.g_ba = utils.generator(self.x_b, name="gen_BA", im_size=im_size)

        # Secondary (reconstruction) generator networks, reusing the parameters of the previous two
        self.g_aba = utils.generator(self.g_ab, name="gen_BA", im_size=im_size, reuse=True)
        self.g_bab = utils.generator(self.g_ba, name="gen_AB", im_size=im_size, reuse=True)

        # Discriminator for domain A
        self.disc_a_real = utils.discriminator(self.x_a, name="disc_a", im_size=im_size)
        self.disc_a_fake = utils.discriminator(self.g_ba, name="disc_a", im_size=im_size, reuse=True)

        # Discriminator for domain B
        self.disc_b_real = utils.discriminator(self.x_b, name="disc_b", im_size=im_size)
        self.disc_b_fake = utils.discriminator(self.g_ab, name="disc_b", im_size=im_size, reuse=True)

        # Reconstruction (cycle-consistency) loss for the generators
        self.l_const_a = tf.reduce_mean(utils.huber_loss(self.g_aba, self.x_a))
        self.l_const_b = tf.reduce_mean(utils.huber_loss(self.g_bab, self.x_b))

        # Adversarial loss for the generators
        self.l_gan_a = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_a_fake, labels=tf.ones_like(self.disc_a_fake)))
        self.l_gan_b = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_b_fake, labels=tf.ones_like(self.disc_b_fake)))

        # Real-example loss for the discriminators
        self.l_disc_a_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_a_real, labels=tf.ones_like(self.disc_a_real)))
        self.l_disc_b_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_b_real, labels=tf.ones_like(self.disc_b_real)))

        # Fake-example loss for the discriminators
        self.l_disc_a_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_a_fake, labels=tf.zeros_like(self.disc_a_fake)))
        self.l_disc_b_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_b_fake, labels=tf.zeros_like(self.disc_b_fake)))

        # Combined loss for the individual discriminators
        self.l_disc_a = self.l_disc_a_real + self.l_disc_a_fake
        self.l_disc_b = self.l_disc_b_real + self.l_disc_b_fake

        # Total discriminator loss
        self.l_disc = self.l_disc_a + self.l_disc_b

        # Combined loss for the individual generators
        self.l_ga = self.l_gan_a + self.l_const_b
        self.l_gb = self.l_gan_b + self.l_const_a

        # Total generator loss
        self.l_g = self.l_ga + self.l_gb
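
        # Added note on the correspondence with the paper's notation: l_gb is
        # L_G_AB = L_GAN_B + L_CONST_A, l_ga is L_G_BA = L_GAN_A + L_CONST_B,
        # so l_g is the full generator objective and l_disc = L_D_A + L_D_B.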

        # Parameter lists
        self.disc_params = []
        self.gen_params = []

        for v in tf.trainable_variables():
            if 'disc' in v.name:
                self.disc_params.append(v)
            if 'gen' in v.name:
                self.gen_params.append(v)

        if self.analytics:
            self.init_analytics()

        self.gen_a_dir = 'generator a->b'
        self.gen_b_dir = 'generator b->a'
        self.rec_a_dir = 'reconstruct a'
        self.rec_b_dir = 'reconstruct b'
        self.model_directory = "models"

        if not os.path.exists(self.gen_a_dir):
            os.makedirs(self.gen_a_dir)
        if not os.path.exists(self.gen_b_dir):
            os.makedirs(self.gen_b_dir)
        if not os.path.exists(self.rec_b_dir):
            os.makedirs(self.rec_b_dir)
        if not os.path.exists(self.rec_a_dir):
            os.makedirs(self.rec_a_dir)

        self.sess = tf.Session()
        self.saver = tf.train.Saver()

    """
    Enable logging of analytics for TensorBoard
    """
    def init_analytics(self):
        # Scalars for all losses
        tf.summary.scalar("loss_g", self.l_g)
        tf.summary.scalar("loss_ga", self.l_ga)
        tf.summary.scalar("loss_gb", self.l_gb)
        tf.summary.scalar("loss_d", self.l_disc)
        tf.summary.scalar("loss_d_a", self.l_disc_a)
        tf.summary.scalar("loss_d_b", self.l_disc_b)
        tf.summary.scalar("l_const_a", self.l_const_a)
        tf.summary.scalar("l_const_b", self.l_const_b)

        # Histograms for all trainable variables
        for v in tf.trainable_variables():
            tf.summary.histogram(v.name, v)

        self.merged_summary_op = tf.summary.merge_all()

    """
    Train the DiscoGAN
    """
    def train(self, LR=2e-4, B1=0.5, B2=0.999, iterations=50000, sample_frequency=10,
              sample_overlap=500, save_frequency=1000, domain_a="a", domain_b="b"):
        self.trainer_D = tf.train.AdamOptimizer(LR, beta1=B1, beta2=B2).minimize(self.l_disc, var_list=self.disc_params)
        self.trainer_G = tf.train.AdamOptimizer(LR, beta1=B1, beta2=B2).minimize(self.l_g, var_list=self.gen_params)

        with self.sess as sess:
            sess.run(tf.global_variables_initializer())
            if self.analytics:
                if not os.path.exists("logs"):
                    os.makedirs("logs")
                self.summary_writer = tf.summary.FileWriter(os.getcwd() + '/logs', graph=sess.graph)
            for i in range(iterations):
                realA = data.get_batch(self.batch_size, domain_a)
                realB = data.get_batch(self.batch_size, domain_b)
                # One discriminator step and one generator step on the same batch;
                # summaries are only fetched when analytics are enabled
                op_list = [self.trainer_D, self.l_disc, self.trainer_G, self.l_g]
                if self.analytics:
                    _, dLoss, _, gLoss, summary_str = sess.run(op_list + [self.merged_summary_op], feed_dict={self.x_a: realA, self.x_b: realB})
                else:
                    _, dLoss, _, gLoss = sess.run(op_list, feed_dict={self.x_a: realA, self.x_b: realB})

                # Second generator step on a fresh batch
                realA = data.get_batch(self.batch_size, domain_a)
                realB = data.get_batch(self.batch_size, domain_b)

                _, gLoss = sess.run([self.trainer_G, self.l_g], feed_dict={self.x_a: realA, self.x_b: realB})

                if i % 10 == 0:
                    if self.analytics:
                        self.summary_writer.add_summary(summary_str, i)

                    print("Generator Loss: " + str(gLoss) + "\tDiscriminator Loss: " + str(dLoss))

                if i % sample_frequency == 0:
                    realA = data.get_batch(1, domain_a)
                    realB = data.get_batch(1, domain_b)
                    ops = [self.g_ba, self.g_ab, self.g_aba, self.g_bab]
                    out_ba, out_ab, out_aba, out_bab = sess.run(ops, feed_dict={self.x_a: realA, self.x_b: realB})
                    data.save(self.gen_a_dir + "/img" + str(i % sample_overlap) + '.png', out_ab[0])
                    data.save(self.gen_b_dir + "/img" + str(i % sample_overlap) + '.png', out_ba[0])
                    data.save(self.rec_a_dir + "/img" + str(i % sample_overlap) + '.png', out_aba[0])
                    data.save(self.rec_b_dir + "/img" + str(i % sample_overlap) + '.png', out_bab[0])
                if i % save_frequency == 0:
                    if not os.path.exists(self.model_directory):
                        os.makedirs(self.model_directory)
                    self.saver.save(sess, self.model_directory + '/model-' + str(i) + '.ckpt')
                    print("Saved Model")

    """
    Restore previously saved weights from a trained / in-progress model
    """
    def restore(self):
        try:
            self.saver.restore(self.sess, tf.train.latest_checkpoint(self.model_directory))
        except Exception:
            print("Previous weights not found")
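
# Illustrative usage sketch (not part of the original file): translating one
# domain-A image using a previously saved checkpoint from models/.
#
#   import data
#   from discoGAN import DiscoGAN
#
#   gan = DiscoGAN()
#   gan.restore()
#   batch_a = data.get_batch(1, 'a')
#   fake_b = gan.sess.run(gan.g_ab, feed_dict={gan.x_a: batch_a})
#   data.save('translated.png', fake_b[0])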

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from discoGAN import DiscoGAN

print("Building Model")
network = DiscoGAN()
print("Beginning training")
network.train()

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import tensorflow as tf

class batch_norm(object):
    def __init__(self, epsilon=1e-3, momentum=0.9, name="batch_norm"):
        with tf.variable_scope(name):
            self.epsilon = epsilon
            self.momentum = momentum
            self.name = name

    def __call__(self, x, train=True):
        return tf.contrib.layers.batch_norm(x,
                                            decay=self.momentum,
                                            updates_collections=None,
                                            epsilon=self.epsilon,
                                            scale=True,
                                            is_training=train,
                                            scope=self.name)

"""
Helper for a convolution given an input and conv weights
"""
def conv2d(x, w, stride=2, padding="SAME"):
    return tf.nn.conv2d(x, w, strides=[1, stride, stride, 1], padding=padding)

"""
Helper for a transposed convolution given an input, conv weights, and an output shape
"""
def conv2d_t(x, w, shape, stride=2, padding="SAME"):
    return tf.nn.conv2d_transpose(x, w, shape, strides=[1, stride, stride, 1], padding=padding)


"""
Helper for a weight variable of the given shape, with truncated-normal initialization
"""
def weight_var(shape, name='w', init=tf.truncated_normal_initializer(stddev=0.02)):
    return tf.get_variable(name, shape, initializer=init)

"""
Same as above, but for a bias (zero-initialized)
"""
def bias_var(shape, name='b'):
    return tf.get_variable(name, shape, initializer=tf.constant_initializer(0.0))

"""
The activation function used throughout: leaky ReLU
"""
def lrelu(x, alpha=0.2):
    return tf.maximum(alpha*x, x)

"""
Helper for creating convolutional layers
"""
def conv_layer(x, w_shape, b_shape, activation=tf.nn.relu, batch_norm=None, stride=2, name="conv2d", reuse=False):
    with tf.variable_scope(name) as scope:
        if reuse:
            scope.reuse_variables()
        w = weight_var(w_shape, name="w"+name)
        b = bias_var(b_shape, name="b"+name)
        h = conv2d(x, w, stride=stride) + b
        if batch_norm is not None:
            h = batch_norm(h)
        if activation:
            h = activation(h)
        return h

"""
Helper for creating transposed-convolution layers
"""
def conv_layer_t(x, w_shape, b_shape, shape, activation=tf.nn.relu, batch_norm=None, stride=2, name="deconv2d", reuse=False):
    with tf.variable_scope(name) as scope:
        if reuse:
            scope.reuse_variables()
        w = weight_var(w_shape, name="w"+name)
        b = bias_var(b_shape, name="b"+name)
        # Output-shape entries may arrive as floats; cast them to ints
        for i in range(1, len(shape)):
            if isinstance(shape[i], float):
                shape[i] = int(shape[i])
        h = conv2d_t(x, w, shape) + b
        if batch_norm is not None:
            h = batch_norm(h)
        if activation:
            h = activation(h)
        return h

"""
Another helper for batch norm, as a plain function
"""
def batch_norm_layer(x, is_training=True):
    return tf.contrib.layers.batch_norm(x, is_training=is_training, epsilon=1e-4, trainable=True)

"""
Helper for creating fully connected layers.
Dropout and batch norm are optional; batch_norm must be a callable
(e.g. a batch_norm instance) if given
"""
def fc_layer(x, w_shape, b_shape, activation=tf.nn.relu, batch_norm=None, dropout=False, keep_prob=0.5, name="linear", reuse=False):
    with tf.variable_scope(name) as scope:
        if reuse:
            scope.reuse_variables()
        w = weight_var(w_shape)
        b = bias_var(b_shape)
        h = tf.matmul(x, w) + b
        if activation:
            h = activation(h)
        if batch_norm is not None:
            h = batch_norm(h)
        if dropout:
            h = tf.nn.dropout(h, keep_prob)
        return h

"""
Huber loss: quadratic where |labels - logits| < max_gradient, linear beyond
"""
def huber_loss(logits, labels, max_gradient=1.0):
    err = tf.abs(labels - logits)
    mg = tf.constant(max_gradient)
    lin = mg*(err - 0.5*mg)
    quad = 0.5*err*err
    return tf.where(err < mg, quad, lin)

"""
Helper function to build the generator network
"""
def generator(x, name="generator", im_size=64, channels=3, reuse=False):
    with tf.variable_scope(name) as scope:
        if reuse:
            scope.reuse_variables()

        g_bn0 = batch_norm(name='g_bn0')
        g_bn1 = batch_norm(name='g_bn1')
        g_bn2 = batch_norm(name='g_bn2')
        g_bn3 = batch_norm(name='g_bn3')
        g_bn4 = batch_norm(name='g_bn4')
        g_bn5 = batch_norm(name='g_bn5')

        im_div_2 = im_size//2 if im_size//2 >= 1 else 1
        im_div_4 = im_size//4 if im_size//4 >= 1 else 1
        im_div_8 = im_size//8 if im_size//8 >= 1 else 1

        # Encoder: four stride-2 convolutions
        conv_1 = conv_layer(x, [4,4,int(x.get_shape()[-1]),im_size], [im_size], activation=lrelu, batch_norm=None, name="g_conv_1", reuse=reuse)
        conv_2 = conv_layer(conv_1, [4,4,int(conv_1.get_shape()[-1]),im_size*2], [im_size*2], activation=lrelu, batch_norm=g_bn0, name="g_conv_2", reuse=reuse)
        conv_3 = conv_layer(conv_2, [4,4,int(conv_2.get_shape()[-1]),im_size*4], [im_size*4], activation=lrelu, batch_norm=g_bn1, name="g_conv_3", reuse=reuse)
        conv_4 = conv_layer(conv_3, [4,4,int(conv_3.get_shape()[-1]),im_size*8], [im_size*8], activation=lrelu, batch_norm=g_bn2, name="g_conv_4", reuse=reuse)
        # Decoder: four stride-2 transposed convolutions back up to im_size x im_size
        conv_t_1 = conv_layer_t(conv_4, [4,4,im_size,int(conv_4.get_shape()[-1])], [im_size], [tf.shape(x)[0],im_div_8,im_div_8,im_size], activation=lrelu, batch_norm=g_bn3, name="g_deconv_1", reuse=reuse)
        conv_t_2 = conv_layer_t(conv_t_1, [4,4,im_div_2,int(conv_t_1.get_shape()[-1])], [im_div_2], [tf.shape(x)[0],im_div_4,im_div_4,im_div_2], activation=lrelu, batch_norm=g_bn4, name="g_deconv_2", reuse=reuse)
        conv_t_3 = conv_layer_t(conv_t_2, [4,4,im_div_4,int(conv_t_2.get_shape()[-1])], [im_div_4], [tf.shape(x)[0],im_div_2,im_div_2,im_div_4], activation=lrelu, batch_norm=g_bn5, name="g_deconv_3", reuse=reuse)
        conv_t_4 = conv_layer_t(conv_t_3, [4,4,channels,int(conv_t_3.get_shape()[-1])], [channels], [tf.shape(x)[0],im_size,im_size,channels], activation=None, batch_norm=None, name="g_deconv_4", reuse=reuse)

        out = conv_t_4

        return out

"""
Helper function to build the discriminator network
"""
def discriminator(x, name="discriminator", im_size=64, reuse=False):
    with tf.variable_scope(name) as scope:
        if reuse:
            scope.reuse_variables()

        d_bn0 = batch_norm(name='d_bn0')
        d_bn1 = batch_norm(name='d_bn1')
        d_bn2 = batch_norm(name='d_bn2')

        im_div_2 = im_size//2 if im_size//2 >= 1 else 1
        im_div_4 = im_size//4 if im_size//4 >= 1 else 1
        im_div_8 = im_size//8 if im_size//8 >= 1 else 1
        im_div_16 = im_size//16 if im_size//16 >= 1 else 1

        conv_1 = conv_layer(x, [4,4,int(x.get_shape()[-1]),im_div_2], [im_div_2], activation=lrelu, batch_norm=None, name="d_conv_1", reuse=reuse)
        conv_2 = conv_layer(conv_1, [4,4,int(conv_1.get_shape()[-1]),im_div_4], [im_div_4], activation=lrelu, batch_norm=d_bn0, name="d_conv_2", reuse=reuse)
        conv_3 = conv_layer(conv_2, [4,4,int(conv_2.get_shape()[-1]),im_div_8], [im_div_8], activation=lrelu, batch_norm=d_bn1, name="d_conv_3", reuse=reuse)
        conv_4 = conv_layer(conv_3, [4,4,int(conv_3.get_shape()[-1]),im_div_16], [im_div_16], activation=lrelu, batch_norm=d_bn2, name="d_conv_4", reuse=reuse)
        # Final stride-4 convolution collapses the remaining 4x4 map to a single logit
        out = conv_layer(conv_4, [4,4,int(conv_4.get_shape()[-1]),1], [1], activation=None, stride=4, batch_norm=None, name="d_conv_5", reuse=reuse)
        out = tf.squeeze(out)
        return out

--------------------------------------------------------------------------------