├── README.md
├── download.py
├── img
│   ├── real.png
│   └── recon.png
├── main.py
├── msssim.py
├── ops.py
├── utils.py
└── vaegan.py

/README.md:
--------------------------------------------------------------------------------
# VAE/GAN

TensorFlow code for [Autoencoding beyond pixels using a learned similarity metric](https://arxiv.org/abs/1512.09300v2).

The paper is among the first to combine the Variational Autoencoder (VAE) with Generative Adversarial Networks (GAN), using the GAN discriminator as a learned perceptual loss in place of the pixel-wise loss of the original VAE. VAE/GAN can also be used for image reconstruction and visual attribute manipulation.

## About training instability

I found the training to be very unstable, so I updated the code to stabilize the adversarial training of VAE/GAN. The details are below.

- Added one-sided label smoothing, a trick from [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498); see the sketch below.
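As a reference, here is a minimal sketch of the trick, assuming TF 1.x; the constant mirrors `d_scale_factor = 0.25` in `vaegan.py`, and the logits tensor is a stand-in for real discriminator outputs:

    import tensorflow as tf

    d_scale_factor = 0.25  # mirrors vaegan.py
    logits = tf.random_normal([64, 1])  # stand-in for D's outputs on real images

    # The real-sample target is moved from 1.0 down to 1.0 - d_scale_factor,
    # which keeps the discriminator from becoming over-confident.
    smoothed_real_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(logits) - d_scale_factor, logits=logits))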
## Pretrained models

The checkpoint files can be downloaded from [Google Drive](https://drive.google.com/open?id=1E5FWN6Xqg65bmXT5mtY8nmuLREz4gLoZ). Please unzip the files inside the project directory. I will upload new models after more training iterations.

## Prerequisites

- tensorflow >= 1.4

## Dataset requirement

You can download the [Align and Cropped CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) and unzip it into a directory. Note that the images must sit directly in this directory, without sub-directories.

## Usage

Train:

    $ python main.py --op 0 --path <your_data_path>

Test:

    $ python main.py --op 1 --path <your_data_path>

## Experimental visual results

Input:

![](img/real.png)

Reconstruction:

![](img/recon.png)


## Issues

If you find a bug or problem, please open an issue.

## Reference code

[DCGAN](https://github.com/carpedm20/DCGAN-tensorflow)

[autoencoding_beyond_pixels](https://github.com/andersbll/autoencoding_beyond_pixels)
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import os
import sys
import json
import zipfile
import argparse
import subprocess
from six.moves import urllib

parser = argparse.ArgumentParser(description='Download datasets.')
parser.add_argument('datasets', metavar='N', type=str, nargs='+',
                    choices=['celebA', 'lsun', 'mnist'],
                    help='name of dataset to download [celebA, lsun, mnist]')

def download(url, dirpath):

    filename = url.split('/')[-1]
    filepath = os.path.join(dirpath, filename)
    u = urllib.request.urlopen(url)
    f = open(filepath, 'wb')
    filesize = int(u.headers["Content-Length"])
    print("Downloading: %s Bytes: %s" % (filename, filesize))

    downloaded = 0
    block_sz = 8192
    status_width = 70
    while True:
        buf = u.read(block_sz)
        if not buf:
            print('')
            break
        else:
            print('', end='\r')
        downloaded += len(buf)
        f.write(buf)
        # Draw a simple in-place progress bar.
        status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") %
                  ('=' * int(float(downloaded) / filesize * status_width) + '>',
                   downloaded * 100. / filesize))
        print(status, end='')
        sys.stdout.flush()
    f.close()
    return filepath

def unzip(filepath):
    print("Extracting: " + filepath)
    dirpath = os.path.dirname(filepath)
    with zipfile.ZipFile(filepath) as zf:
        zf.extractall(dirpath)
    os.remove(filepath)

def download_celeb_a(dirpath):
    data_dir = 'celebA'
    if os.path.exists(os.path.join(dirpath, data_dir)):
        print('Found Celeb-A - skip')
        return
    url = 'https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=1&pv=1'
    filepath = download(url, dirpath)
    zip_dir = ''
    with zipfile.ZipFile(filepath) as zf:
        zip_dir = zf.namelist()[0]
        zf.extractall(dirpath)
    os.remove(filepath)
    os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))
def _list_categories(tag):
    url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag
    f = urllib.request.urlopen(url)
    return json.loads(f.read())

def _download_lsun(out_dir, category, set_name, tag):
    url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \
          '&category={category}&set={set_name}'.format(**locals())
    print(url)
    if set_name == 'test':
        out_name = 'test_lmdb.zip'
    else:
        out_name = '{category}_{set_name}_lmdb.zip'.format(**locals())
    out_path = os.path.join(out_dir, out_name)
    cmd = ['curl', url, '-o', out_path]
    print('Downloading', category, set_name, 'set')
    subprocess.call(cmd)

def download_lsun(dirpath):
    data_dir = os.path.join(dirpath, 'lsun')
    if os.path.exists(data_dir):
        print('Found LSUN - skip')
        return
    else:
        os.mkdir(data_dir)

    tag = 'latest'
    #categories = _list_categories(tag)
    categories = ['bedroom']

    for category in categories:
        _download_lsun(data_dir, category, 'train', tag)
        _download_lsun(data_dir, category, 'val', tag)
    _download_lsun(data_dir, '', 'test', tag)

def download_mnist(dirpath):
    data_dir = os.path.join(dirpath, 'mnist')
    if os.path.exists(data_dir):
        print('Found MNIST - skip')
        return
    else:
        os.mkdir(data_dir)
    url_base = 'http://yann.lecun.com/exdb/mnist/'
    file_names = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
                  't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
    for file_name in file_names:
        url = url_base + file_name
        print(url)
        out_path = os.path.join(data_dir, file_name)
        cmd = ['curl', url, '-o', out_path]
        print('Downloading ', file_name)
        subprocess.call(cmd)
        cmd = ['gzip', '-d', out_path]
        print('Decompressing ', file_name)
        subprocess.call(cmd)

def prepare_data_dir(path='./data'):
    if not os.path.exists(path):
        os.mkdir(path)

if __name__ == '__main__':
    args = parser.parse_args()
    prepare_data_dir()

    if 'celebA' in args.datasets:
        download_celeb_a('./data')
    if 'lsun' in args.datasets:
        download_lsun('./data')
    if 'mnist' in args.datasets:
        download_mnist('./data')
--------------------------------------------------------------------------------
/img/real.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangqianhui/vae-gan-tensorflow/636f805090cffa5afcc4ab561524fb8435d1a850/img/real.png
--------------------------------------------------------------------------------
/img/recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangqianhui/vae-gan-tensorflow/636f805090cffa5afcc4ab561524fb8435d1a850/img/recon.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import tensorflow as tf

from utils import mkdir_p, CelebA
from vaegan import vaegan

flags = tf.app.flags

flags.DEFINE_integer("batch_size", 64, "batch size")
flags.DEFINE_integer("max_iters", 600000, "maximum number of training iterations")
flags.DEFINE_integer("latent_dim", 128, "dimension of the latent code")
flags.DEFINE_float("learn_rate_init", 0.0003, "initial learning rate")
# Please set the number of repeats according to the size of your dataset
# (see the sizing sketch below).
flags.DEFINE_integer("repeat", 10000, "number of times the dataset is repeated")
flags.DEFINE_string("path", '/home/?/data/', "directory of your celebA data, e.g. '/home/jack/data/'")
flags.DEFINE_integer("op", 0, "0 for training, 1 for testing")
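# A rough sizing sketch (an assumption, based on how vaegan.py consumes this
# flag): the input pipeline applies .repeat(repeat) and drops remainder
# batches, so it yields about num_images * repeat / batch_size batches in
# total. To cover max_iters training steps you therefore need roughly
#     repeat >= max_iters * batch_size / num_images
# e.g. for CelebA (~202,599 images), max_iters=600000 and batch_size=64:
#     600000 * 64 / 202599 ~= 190, so the default of 10000 is more than enough.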
FLAGS = flags.FLAGS

if __name__ == "__main__":

    root_log_dir = "./vaeganlogs/logs/celeba_test"
    vaegan_checkpoint_dir = "./model_vaegan/model.ckpt"
    sample_path = "./vaeganSample/sample"

    mkdir_p(root_log_dir)
    mkdir_p('./model_vaegan/')
    mkdir_p(sample_path)

    model_path = vaegan_checkpoint_dir

    batch_size = FLAGS.batch_size
    max_iters = FLAGS.max_iters
    latent_dim = FLAGS.latent_dim
    data_repeat = FLAGS.repeat

    learn_rate_init = FLAGS.learn_rate_init
    cb_ob = CelebA(FLAGS.path)

    vaeGan = vaegan(batch_size=batch_size, max_iters=max_iters, repeat=data_repeat,
                    model_path=model_path, data_ob=cb_ob, latent_dim=latent_dim,
                    sample_path=sample_path, log_dir=root_log_dir, learnrate_init=learn_rate_init)

    if FLAGS.op == 0:
        vaeGan.build_model_vaegan()
        vaeGan.train()
    else:
        vaeGan.build_model_vaegan()
        vaeGan.test()
--------------------------------------------------------------------------------
/msssim.py:
--------------------------------------------------------------------------------
"""Python implementation of MS-SSIM.

Usage:
  python msssim.py --original_image=original.png --compared_image=distorted.png
"""
import os

import numpy as np
from scipy import signal
from scipy.ndimage.filters import convolve
import tensorflow as tf

ori_path = "./vaeganCeleba2/sample/test00_0000_r.png"
com_path = "./vaeganCeleba2/sample/test00_0000.png"

tf.flags.DEFINE_string('original_image', ori_path, 'Path to PNG image.')
tf.flags.DEFINE_string('compared_image', com_path, 'Path to PNG image.')
FLAGS = tf.flags.FLAGS


def _FSpecialGauss(size, sigma):
    """Function to mimic the 'fspecial' gaussian MATLAB function."""
    radius = size // 2
    offset = 0.0
    start, stop = -radius, radius + 1
    if size % 2 == 0:
        offset = 0.5
        stop -= 1
    x, y = np.mgrid[offset + start:stop, offset + start:stop]
    assert len(x) == size
    g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
    return g / g.sum()
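# Quick sanity sketch (not part of the original script): like MATLAB's
# fspecial('gaussian', ...), the returned window is normalized and symmetric.
#   w = _FSpecialGauss(11, 1.5)
#   assert abs(w.sum() - 1.0) < 1e-10
#   assert np.allclose(w, w.T)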
def _SSIMForMultiScale(img1, img2, max_val=255, filter_size=11,
                       filter_sigma=1.5, k1=0.01, k2=0.03):
    """Return the Structural Similarity Map between `img1` and `img2`.

    This function attempts to match the functionality of ssim_index_new.m by
    Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip

    Arguments:
      img1: Numpy array holding the first RGB image batch.
      img2: Numpy array holding the second RGB image batch.
      max_val: the dynamic range of the images (i.e., the difference between
        the maximum and minimum allowed values).
      filter_size: Size of blur kernel to use (will be reduced for small images).
      filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced
        for small images).
      k1: Constant used to maintain stability in the SSIM calculation (0.01 in
        the original paper).
      k2: Constant used to maintain stability in the SSIM calculation (0.03 in
        the original paper).

    Returns:
      Pair containing the mean SSIM and contrast sensitivity between `img1` and
      `img2`.

    Raises:
      RuntimeError: If input images don't have the same shape or don't have four
        dimensions: [batch_size, height, width, depth].
    """
    if img1.shape != img2.shape:
        raise RuntimeError('Input images must have the same shape (%s vs. %s).' %
                           (img1.shape, img2.shape))
    if img1.ndim != 4:
        raise RuntimeError('Input images must have four dimensions, not %d' %
                           img1.ndim)

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    _, height, width, _ = img1.shape

    # Filter size can't be larger than height or width of images.
    size = min(filter_size, height, width)

    # Scale down sigma if a smaller filter size is used.
    sigma = size * filter_sigma / filter_size if filter_size else 0

    if filter_size:
        window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1))
        mu1 = signal.fftconvolve(img1, window, mode='valid')
        mu2 = signal.fftconvolve(img2, window, mode='valid')
        sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid')
        sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid')
        sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid')
    else:
        # Empty blur kernel, so no need to convolve.
        mu1, mu2 = img1, img2
        sigma11 = img1 * img1
        sigma22 = img2 * img2
        sigma12 = img1 * img2

    mu11 = mu1 * mu1
    mu22 = mu2 * mu2
    mu12 = mu1 * mu2
    sigma11 -= mu11
    sigma22 -= mu22
    sigma12 -= mu12

    # Calculate intermediate values used by both ssim and cs_map.
    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2
    v1 = 2.0 * sigma12 + c2
    v2 = sigma11 + sigma22 + c2
    ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2)))
    cs = np.mean(v1 / v2)
    return ssim, cs


def MultiScaleSSIM(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5,
                   k1=0.01, k2=0.03, weights=None):
    """Return the MS-SSIM score between `img1` and `img2`.

    This function implements Multi-Scale Structural Similarity (MS-SSIM) Image
    Quality Assessment according to Zhou Wang's paper, "Multi-scale structural
    similarity for image quality assessment" (2003).
    Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
    Author's MATLAB implementation:
    http://www.cns.nyu.edu/~lcv/ssim/msssim.zip

    Arguments:
      img1: Numpy array holding the first RGB image batch.
      img2: Numpy array holding the second RGB image batch.
      max_val: the dynamic range of the images (i.e., the difference between
        the maximum and minimum allowed values).
      filter_size: Size of blur kernel to use (will be reduced for small images).
      filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced
        for small images).
      k1: Constant used to maintain stability in the SSIM calculation (0.01 in
        the original paper).
      k2: Constant used to maintain stability in the SSIM calculation (0.03 in
        the original paper).
      weights: List of weights for each level; if none, use five levels and the
        weights from the original paper.

    Returns:
      MS-SSIM score between `img1` and `img2`.

    Raises:
      RuntimeError: If input images don't have the same shape or don't have four
        dimensions: [batch_size, height, width, depth].
    """
    if img1.shape != img2.shape:
        raise RuntimeError('Input images must have the same shape (%s vs. %s).' %
                           (img1.shape, img2.shape))
    if img1.ndim != 4:
        raise RuntimeError('Input images must have four dimensions, not %d' %
                           img1.ndim)

    # Note: default weights don't sum to 1.0 but do match the paper / matlab code.
    weights = np.array(weights if weights else
                       [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
    levels = weights.size
    downsample_filter = np.ones((1, 2, 2, 1)) / 4.0
    im1, im2 = [x.astype(np.float64) for x in [img1, img2]]
    mssim = np.array([])
    mcs = np.array([])
    for _ in range(levels):
        ssim, cs = _SSIMForMultiScale(
            im1, im2, max_val=max_val, filter_size=filter_size,
            filter_sigma=filter_sigma, k1=k1, k2=k2)
        mssim = np.append(mssim, ssim)
        mcs = np.append(mcs, cs)
        filtered = [convolve(im, downsample_filter, mode='reflect')
                    for im in [im1, im2]]
        im1, im2 = [x[:, ::2, ::2, :] for x in filtered]
    return (np.prod(mcs[0:levels-1] ** weights[0:levels-1]) *
            (mssim[levels-1] ** weights[levels-1]))
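# Quick self-check sketch (not in the original): MS-SSIM of a batch compared
# with itself should be 1.0, since every per-scale SSIM/CS factor equals one.
def _self_check():
    rng = np.random.RandomState(0)
    batch = rng.randint(0, 256, size=(1, 64, 64, 3)).astype(np.float64)
    assert abs(MultiScaleSSIM(batch, batch, max_val=255) - 1.0) < 1e-6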
134 | """ 135 | if img1.shape != img2.shape: 136 | raise RuntimeError('Input images must have the same shape (%s vs. %s).', 137 | img1.shape, img2.shape) 138 | if img1.ndim != 4: 139 | raise RuntimeError('Input images must have four dimensions, not %d', 140 | img1.ndim) 141 | 142 | # Note: default weights don't sum to 1.0 but do match the paper / matlab code. 143 | weights = np.array(weights if weights else 144 | [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) 145 | levels = weights.size 146 | downsample_filter = np.ones((1, 2, 2, 1)) / 4.0 147 | im1, im2 = [x.astype(np.float64) for x in [img1, img2]] 148 | mssim = np.array([]) 149 | mcs = np.array([]) 150 | for _ in range(levels): 151 | ssim, cs = _SSIMForMultiScale( 152 | im1, im2, max_val=max_val, filter_size=filter_size, 153 | filter_sigma=filter_sigma, k1=k1, k2=k2) 154 | mssim = np.append(mssim, ssim) 155 | mcs = np.append(mcs, cs) 156 | filtered = [convolve(im, downsample_filter, mode='reflect') 157 | for im in [im1, im2]] 158 | im1, im2 = [x[:, ::2, ::2, :] for x in filtered] 159 | return (np.prod(mcs[0:levels-1] ** weights[0:levels-1]) * 160 | (mssim[levels-1] ** weights[levels-1])) 161 | 162 | 163 | def compare(x, y): 164 | 165 | stat_x = os.stat(x) 166 | stat_y = os.stat(y) 167 | 168 | if stat_x.st_ctime < stat_y.st_ctime: 169 | 170 | return -1 171 | 172 | elif stat_x.st_ctime > stat_y.st_ctime: 173 | 174 | return 1 175 | 176 | else: 177 | 178 | return 0 179 | 180 | def read_image_list(category): 181 | 182 | file_ori = [] 183 | file_new = [] 184 | print("list file") 185 | 186 | list = os.listdir(category) 187 | 188 | for file in list: 189 | 190 | if string.find(file, 'r') != -1: 191 | file_ori.append(category + "/" + file) 192 | else: 193 | file_new.append(category + "/" + file) 194 | 195 | return file_ori, file_new 196 | 197 | 198 | def main(_): 199 | 200 | score = 0.0 201 | file_path = "./vaeganCeleba2/sample_ssim" 202 | ori_list , gen_list = read_image_list(file_path) 203 | ori_list.sort(compare) 204 | gen_list.sort(compare) 205 | print gen_list 206 | 207 | 208 | print ori_list 209 | 210 | 211 | for i in range(len(ori_list)): 212 | 213 | with tf.gfile.FastGFile(ori_list[i]) as image_file: 214 | img1_str = image_file.read() 215 | with tf.gfile.FastGFile(gen_list[i]) as image_file: 216 | img2_str = image_file.read() 217 | 218 | input_img = tf.placeholder(tf.string) 219 | decoded_image = tf.expand_dims(tf.image.decode_png(input_img, channels=3), 0) 220 | 221 | with tf.Session() as sess: 222 | 223 | img1 = sess.run(decoded_image, feed_dict={input_img: img1_str}) 224 | img2 = sess.run(decoded_image, feed_dict={input_img: img2_str}) 225 | 226 | print MultiScaleSSIM(img1, img2, max_val=255) 227 | 228 | 229 | score = score + MultiScaleSSIM(img1, img2, max_val=255) 230 | 231 | print score/len(ori_list) 232 | 233 | 234 | if __name__ == '__main__': 235 | tf.app.run() -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.layers.python.layers import batch_norm 3 | 4 | #the implements of leakyRelu 5 | def lrelu(x , alpha = 0.2 , name="LeakyReLU"): 6 | return tf.maximum(x , alpha*x) 7 | 8 | def conv2d(input_, output_dim, 9 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 10 | name="conv2d"): 11 | 12 | with tf.variable_scope(name): 13 | 14 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], 15 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 16 | 
def conv2d(input_, output_dim,
           k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
           name="conv2d"):

    with tf.variable_scope(name):

        w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())

        return conv

def de_conv(input_, output_shape,
            k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
            name="deconv2d", with_w=False):

    with tf.variable_scope(name):
        # filter: [height, width, output_channels, in_channels]
        w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
                            initializer=tf.random_normal_initializer(stddev=stddev))

        deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
                                        strides=[1, d_h, d_w, 1])

        biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
        deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())

        if with_w:
            return deconv, w, biases
        else:
            return deconv

def fully_connect(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    shape = input_.get_shape().as_list()
    with tf.variable_scope(scope or "Linear"):

        matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
                                 tf.random_normal_initializer(stddev=stddev))
        bias = tf.get_variable("bias", [output_size],
                               initializer=tf.constant_initializer(bias_start))

        if with_w:
            return tf.matmul(input_, matrix) + bias, matrix, bias
        else:
            return tf.matmul(input_, matrix) + bias

def conv_cond_concat(x, y):
    """Concatenate conditioning vector on feature map axis."""
    x_shapes = x.get_shape()
    y_shapes = y.get_shape()

    # tf.concat takes (values, axis) since TF 1.0.
    return tf.concat([x, y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)

def batch_normal(input, scope="scope", reuse=False):
    return batch_norm(input, epsilon=1e-5, decay=0.9, scale=True, scope=scope,
                      reuse=reuse, updates_collections=None)

def instance_norm(x):

    epsilon = 1e-9
    mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
    return tf.div(tf.subtract(x, mean), tf.sqrt(tf.add(var, epsilon)))

def residual(x, output_dims, kernel, strides, name_1, name_2):

    with tf.variable_scope('residual') as scope:

        conv1 = conv2d(x, output_dims, k_h=kernel, k_w=kernel, d_h=strides, d_w=strides, name=name_1)
        conv2 = conv2d(tf.nn.relu(conv1), output_dims, k_h=kernel, k_w=kernel, d_h=strides, d_w=strides, name=name_2)
        resi = x + conv2

        return resi

def deresidual(x, output_shape, kernel, strides, name_1, name_2):

    with tf.variable_scope('residual_un') as scope:

        deconv1 = de_conv(x, output_shape=output_shape, k_h=kernel, k_w=kernel, d_h=strides, d_w=strides, name=name_1)
        deconv2 = de_conv(tf.nn.relu(deconv1), output_shape=output_shape, k_h=kernel, k_w=kernel, d_h=strides, d_w=strides, name=name_2)
        resi = x + deconv2

        return resi
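# Usage sketch (an assumption, matching the signatures above): with stride 1
# and output_dims equal to the input's channel count, a residual block
# preserves the input shape, e.g.
#   h = residual(h, output_dims=256, kernel=3, strides=1,
#                name_1='res_c1', name_2='res_c2')
# Note that the skip connection x + conv2 requires output_dims to equal x's
# channel dimension.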
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import errno
import numpy as np
import scipy
import scipy.misc


def mkdir_p(path):
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise

def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale=False):
    return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w)

def transform(image, npx=64, is_crop=False, resize_w=64):
    # npx: number of pixels (width/height) of the image
    if is_crop:
        cropped_image = center_crop(image, npx, resize_w=resize_w)
    else:
        cropped_image = image
        cropped_image = scipy.misc.imresize(cropped_image, [resize_w, resize_w])
    # Scale pixel values to [-1, 1].
    return np.array(cropped_image) / 127.5 - 1

def center_crop(x, crop_h, crop_w=None, resize_w=64):

    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    j = int(round((h - crop_h) / 2.))
    i = int(round((w - crop_w) / 2.))
    return scipy.misc.imresize(x[j:j + crop_h, i:i + crop_w],
                               [resize_w, resize_w])

def save_images(images, size, image_path):
    return imsave(inverse_transform(images), size, image_path)

def imread(path, is_grayscale=False):
    if is_grayscale:
        return scipy.misc.imread(path, flatten=True).astype(np.float)
    else:
        return scipy.misc.imread(path).astype(np.float)

def imsave(images, size, path):
    return scipy.misc.imsave(path, merge(images, size))

def merge(images, size):
    # Tile a batch of images into a size[0] x size[1] grid.
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j * h:j * h + h, i * w: i * w + w, :] = image

    return img

def inverse_transform(image):
    # Map [-1, 1] back to [0, 255].
    return ((image + 1) * 127.5).astype(np.uint8)

class CelebA(object):
    def __init__(self, images_path):

        self.dataname = "CelebA"
        self.dims = 64 * 64
        self.shape = [64, 64, 3]
        self.image_size = 64
        self.channel = 3
        self.images_path = images_path
        self.train_data_list, self.train_lab_list = self.load_celebA()

    def load_celebA(self):

        # get the list of image paths
        return read_image_list_file(self.images_path, is_test=False)

    def load_test_celebA(self):

        # get the list of image paths
        return read_image_list_file(self.images_path, is_test=True)

def read_image_list_file(category, is_test):
    end_num = 0
    if not is_test:
        start_num = 1202
        path = category + "celebA/"
    else:
        start_num = 4
        path = category + "celeba_test/"
        end_num = 1202

    list_image = []
    list_label = []

    lines = open(category + "list_attr_celeba.txt")
    li_num = 0
    for line in lines:

        if li_num < start_num:
            li_num += 1
            continue

        if is_test and li_num >= end_num:
            break

        flag = line.split('1 ', 41)[20]  # get the label for gender
        file_name = line.split(' ', 1)[0]

        # print flag
        if flag == ' ':
            list_label.append(1)
        else:
            list_label.append(0)

        list_image.append(path + file_name)

        li_num += 1

    lines.close()

    return list_image, list_label
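# Usage sketch (an assumption: the directory layout expected by
# read_image_list_file, i.e. <path>/celebA/ plus <path>/list_attr_celeba.txt):
#   data = CelebA('/home/jack/data/')
#   img = get_image(data.train_data_list[0], 108, is_crop=True, resize_w=64)
#   # img is a (64, 64, 3) float array scaled to [-1, 1]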
--------------------------------------------------------------------------------
/vaegan.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from ops import batch_normal, de_conv, conv2d, fully_connect, lrelu
from utils import save_images, get_image
from utils import CelebA
import numpy as np
import cv2
from tensorflow.python.framework.ops import convert_to_tensor
import os

TINY = 1e-8
# One-sided label smoothing offsets (Salimans et al., 2016).
d_scale_factor = 0.25
g_scale_factor = 1 - 0.75 / 2


class vaegan(object):

    # build model
    def __init__(self, batch_size, max_iters, repeat, model_path, data_ob,
                 latent_dim, sample_path, log_dir, learnrate_init):

        self.batch_size = batch_size
        self.max_iters = max_iters
        self.repeat_num = repeat
        self.saved_model_path = model_path
        self.data_ob = data_ob
        self.latent_dim = latent_dim
        self.sample_path = sample_path
        self.log_dir = log_dir
        self.learn_rate_init = learnrate_init
        self.log_vars = []

        self.channel = 3
        self.output_size = data_ob.image_size
        self.images = tf.placeholder(tf.float32, [self.batch_size, self.output_size, self.output_size, self.channel])
        self.ep = tf.random_normal(shape=[self.batch_size, self.latent_dim])
        self.zp = tf.random_normal(shape=[self.batch_size, self.latent_dim])

        self.dataset = tf.data.Dataset.from_tensor_slices(
            convert_to_tensor(self.data_ob.train_data_list, dtype=tf.string))
        self.dataset = self.dataset.map(lambda filename: tuple(tf.py_func(self._read_by_function,
                                        [filename], [tf.double])), num_parallel_calls=16)
        self.dataset = self.dataset.repeat(self.repeat_num)
        self.dataset = self.dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))

        self.iterator = tf.data.Iterator.from_structure(self.dataset.output_types, self.dataset.output_shapes)
        self.next_x = tf.squeeze(self.iterator.get_next())
        self.training_init_op = self.iterator.make_initializer(self.dataset)

    def build_model_vaegan(self):

        self.z_mean, self.z_sigm = self.Encode(self.images)
        self.z_x = tf.add(self.z_mean, tf.sqrt(tf.exp(self.z_sigm)) * self.ep)
        self.x_tilde = self.generate(self.z_x, reuse=False)
        self.l_x_tilde, self.De_pro_tilde = self.discriminate(self.x_tilde)

        self.x_p = self.generate(self.zp, reuse=True)

        self.l_x, self.D_pro_logits = self.discriminate(self.images, True)
        _, self.G_pro_logits = self.discriminate(self.x_p, True)

        # KL loss
        self.kl_loss = self.KL_loss(self.z_mean, self.z_sigm)

        # D loss
        self.D_fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(self.G_pro_logits), logits=self.G_pro_logits))
        self.D_real_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(self.D_pro_logits) - d_scale_factor, logits=self.D_pro_logits))
        self.D_tilde_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(self.De_pro_tilde), logits=self.De_pro_tilde))

        # G loss
        self.G_fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(self.G_pro_logits) - g_scale_factor, logits=self.G_pro_logits))
        self.G_tilde_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(self.De_pro_tilde) - g_scale_factor, logits=self.De_pro_tilde))

        self.D_loss = self.D_fake_loss + self.D_real_loss + self.D_tilde_loss

        # perceptual loss (feature loss) on the discriminator's middle layer
        self.LL_loss = tf.reduce_mean(tf.reduce_sum(self.NLLNormal(self.l_x_tilde, self.l_x), [1, 2, 3]))

        # for the encoder
        self.encode_loss = self.kl_loss / (self.latent_dim * self.batch_size) - self.LL_loss / (4 * 4 * 256)

        # for the generator
        self.G_loss = self.G_fake_loss + self.G_tilde_loss - 1e-6 * self.LL_loss
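        # Objective layout sketch (a summary of the wiring above, in the
        # spirit of the paper, not extra functionality):
        #   encoder:       normalized KL(q(z|x) || p(z))  -  feature-wise LL loss
        #   generator:     fool D on samples and reconstructions  -  1e-6 * LL loss
        #   discriminator: real vs. sampled vs. reconstructed GAN losses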
        self.log_vars.append(("encode_loss", self.encode_loss))
        self.log_vars.append(("generator_loss", self.G_loss))
        self.log_vars.append(("discriminator_loss", self.D_loss))
        self.log_vars.append(("LL_loss", self.LL_loss))

        t_vars = tf.trainable_variables()

        self.d_vars = [var for var in t_vars if 'dis' in var.name]
        self.g_vars = [var for var in t_vars if 'gen' in var.name]
        self.e_vars = [var for var in t_vars if 'e_' in var.name]

        self.saver = tf.train.Saver()
        for k, v in self.log_vars:
            tf.summary.scalar(k, v)

    # do train
    def train(self):

        global_step = tf.Variable(0, trainable=False)
        add_global = global_step.assign_add(1)
        new_learning_rate = tf.train.exponential_decay(self.learn_rate_init, global_step=global_step,
                                                       decay_steps=10000, decay_rate=0.98)
        # for D
        trainer_D = tf.train.RMSPropOptimizer(learning_rate=new_learning_rate)
        gradients_D = trainer_D.compute_gradients(self.D_loss, var_list=self.d_vars)
        opti_D = trainer_D.apply_gradients(gradients_D)

        # for G
        trainer_G = tf.train.RMSPropOptimizer(learning_rate=new_learning_rate)
        gradients_G = trainer_G.compute_gradients(self.G_loss, var_list=self.g_vars)
        opti_G = trainer_G.apply_gradients(gradients_G)

        # for E
        trainer_E = tf.train.RMSPropOptimizer(learning_rate=new_learning_rate)
        gradients_E = trainer_E.compute_gradients(self.encode_loss, var_list=self.e_vars)
        opti_E = trainer_E.apply_gradients(gradients_E)

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)

            # Initialize the iterator
            sess.run(self.training_init_op)
            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)

            # self.saver.restore(sess, self.saved_model_path)
            step = 0

            while step <= self.max_iters:

                next_x_images = sess.run(self.next_x)

                fd = {self.images: next_x_images}
                # optimize E, then G, then D on the same batch
                sess.run(opti_E, feed_dict=fd)
                sess.run(opti_G, feed_dict=fd)
                sess.run(opti_D, feed_dict=fd)

                summary_str = sess.run(summary_op, feed_dict=fd)

                summary_writer.add_summary(summary_str, step)
                new_learn_rate = sess.run(new_learning_rate)

                if new_learn_rate > 0.00005:
                    sess.run(add_global)

                if step % 200 == 0:

                    D_loss, fake_loss, encode_loss, LL_loss, kl_loss, new_learn_rate \
                        = sess.run([self.D_loss, self.G_loss, self.encode_loss, self.LL_loss,
                                    self.kl_loss / (self.latent_dim * self.batch_size), new_learning_rate],
                                   feed_dict=fd)
                    print("Step %d: D: loss = %.7f G: loss=%.7f E: loss=%.7f LL loss=%.7f KL=%.7f, LR=%.7f"
                          % (step, D_loss, fake_loss, encode_loss, LL_loss, kl_loss, new_learn_rate))
                if np.mod(step, 200) == 1:

                    save_images(next_x_images[0:self.batch_size], [self.batch_size // 8, 8],
                                '{}/train_{:02d}_real.png'.format(self.sample_path, step))
                    sample_images = sess.run(self.x_tilde, feed_dict=fd)
                    save_images(sample_images[0:self.batch_size], [self.batch_size // 8, 8],
                                '{}/train_{:02d}_recon.png'.format(self.sample_path, step))

                if np.mod(step, 2000) == 1 and step != 0:

                    self.saver.save(sess, self.saved_model_path)

                step += 1

            save_path = self.saver.save(sess, self.saved_model_path)
            print("Model saved in file: %s" % save_path)

    def test(self):

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            # Initialize the iterator
            sess.run(self.training_init_op)

            sess.run(init)
            self.saver.restore(sess, self.saved_model_path)

            next_x_images = sess.run(self.next_x)

            real_images, sample_images = sess.run([self.images, self.x_tilde], feed_dict={self.images: next_x_images})
            save_images(sample_images[0:self.batch_size], [self.batch_size // 8, 8],
                        '{}/train_{:02d}_{:04d}_con.png'.format(self.sample_path, 0, 0))
            save_images(real_images[0:self.batch_size], [self.batch_size // 8, 8],
                        '{}/train_{:02d}_{:04d}_r.png'.format(self.sample_path, 0, 0))

            ri = cv2.imread('{}/train_{:02d}_{:04d}_r.png'.format(self.sample_path, 0, 0), 1)
            fi = cv2.imread('{}/train_{:02d}_{:04d}_con.png'.format(self.sample_path, 0, 0), 1)

            cv2.imshow('real_image', ri)
            cv2.imshow('reconstruction', fi)

            cv2.waitKey(-1)

    def discriminate(self, x_var, reuse=False):

        with tf.variable_scope("discriminator") as scope:

            if reuse:
                scope.reuse_variables()

            conv1 = tf.nn.relu(conv2d(x_var, output_dim=32, name='dis_conv1'))
            conv2 = tf.nn.relu(batch_normal(conv2d(conv1, output_dim=128, name='dis_conv2'), scope='dis_bn1', reuse=reuse))
            conv3 = tf.nn.relu(batch_normal(conv2d(conv2, output_dim=256, name='dis_conv3'), scope='dis_bn2', reuse=reuse))
            conv4 = conv2d(conv3, output_dim=256, name='dis_conv4')
            middle_conv = conv4  # features used for the perceptual (LL) loss
            conv4 = tf.nn.relu(batch_normal(conv4, scope='dis_bn3', reuse=reuse))
            conv4 = tf.reshape(conv4, [self.batch_size, -1])

            fl = tf.nn.relu(batch_normal(fully_connect(conv4, output_size=256, scope='dis_fully1'), scope='dis_bn4', reuse=reuse))
            output = fully_connect(fl, output_size=1, scope='dis_fully2')

            return middle_conv, output

    def generate(self, z_var, reuse=False):

        with tf.variable_scope('generator') as scope:

            if reuse:
                scope.reuse_variables()

            d1 = tf.nn.relu(batch_normal(fully_connect(z_var, output_size=8*8*256, scope='gen_fully1'), scope='gen_bn1', reuse=reuse))
            d2 = tf.reshape(d1, [self.batch_size, 8, 8, 256])
            d2 = tf.nn.relu(batch_normal(de_conv(d2, output_shape=[self.batch_size, 16, 16, 256], name='gen_deconv2'), scope='gen_bn2', reuse=reuse))
            d3 = tf.nn.relu(batch_normal(de_conv(d2, output_shape=[self.batch_size, 32, 32, 128], name='gen_deconv3'), scope='gen_bn3', reuse=reuse))
            d4 = tf.nn.relu(batch_normal(de_conv(d3, output_shape=[self.batch_size, 64, 64, 32], name='gen_deconv4'), scope='gen_bn4', reuse=reuse))
            d5 = de_conv(d4, output_shape=[self.batch_size, 64, 64, 3], name='gen_deconv5', d_h=1, d_w=1)

            return tf.nn.tanh(d5)
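    # Sampling sketch (not in the original code path): since self.zp is a
    # tf.random_normal tensor rather than a placeholder, every
    # sess.run(self.x_p) decodes a fresh batch of latent codes, e.g.
    #   samples = sess.run(vaeGan.x_p)  # [batch_size, 64, 64, 3], values in [-1, 1]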
    def Encode(self, x):

        with tf.variable_scope('encode') as scope:

            conv1 = tf.nn.relu(batch_normal(conv2d(x, output_dim=64, name='e_c1'), scope='e_bn1'))
            conv2 = tf.nn.relu(batch_normal(conv2d(conv1, output_dim=128, name='e_c2'), scope='e_bn2'))
            conv3 = tf.nn.relu(batch_normal(conv2d(conv2, output_dim=256, name='e_c3'), scope='e_bn3'))
            conv3 = tf.reshape(conv3, [self.batch_size, 256 * 8 * 8])
            fc1 = tf.nn.relu(batch_normal(fully_connect(conv3, output_size=1024, scope='e_f1'), scope='e_bn4'))
            z_mean = fully_connect(fc1, output_size=128, scope='e_f2')
            z_sigma = fully_connect(fc1, output_size=128, scope='e_f3')  # log-variance

            return z_mean, z_sigma

    def KL_loss(self, mu, log_var):
        return -0.5 * tf.reduce_sum(1 + log_var - tf.pow(mu, 2) - tf.exp(log_var))

    def sample_z(self, mu, log_var):
        eps = tf.random_normal(shape=tf.shape(mu))
        return mu + tf.exp(log_var / 2) * eps

    def NLLNormal(self, pred, target):
        # Log-density of `target` under a unit-variance Gaussian centred at `pred`.
        c = -0.5 * tf.log(2 * np.pi)
        multiplier = 1.0 / (2.0 * 1)
        tmp = tf.square(pred - target)
        tmp *= -multiplier
        tmp += c

        return tmp

    def _parse_function(self, images_filenames):

        image_string = tf.read_file(images_filenames)
        image_decoded = tf.image.decode_and_crop_jpeg(
            image_string, crop_window=[218 // 2 - 54, 178 // 2 - 54, 108, 108], channels=3)
        image_resized = tf.image.resize_images(image_decoded, [self.output_size, self.output_size])
        image_resized = image_resized / 127.5 - 1

        return image_resized

    def _read_by_function(self, filename):

        array = get_image(filename, 108, is_crop=True, resize_w=self.output_size,
                          is_grayscale=False)
        real_images = np.array(array)
        return real_images
--------------------------------------------------------------------------------