├── LICENSE
├── README.md
├── checkpoints
│   └── checkpoint
├── config.py
├── datasets
│   ├── download_celebA.py
│   └── prepare_celeba.py
├── discriminator.py
├── generator.py
├── main.py
├── readme
│   ├── conv_measure_vis.png
│   ├── eq_autoencoder_loss.png
│   ├── eq_conv_measure.png
│   ├── eq_gamma.png
│   ├── eq_global.png
│   ├── eq_losses.png
│   ├── eq_objective.png
│   └── generated_from_Z.png
└── utils
    ├── custom_ops.py
    └── misc.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Arthur Goldberg 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BEGAN: Boundary Equilibrium Generative Adversarial Networks 2 | 3 | This is an implementation of the paper on Boundary Equilibrium Generative Adversarial Networks [(Berthelot, Schumm and Metz, 2017)](#references). 4 | 5 | ## Dependencies 6 | 7 | * Python 3+ 8 | * numpy 9 | * TensorFlow 10 | * tqdm 11 | * h5py 12 | * scipy (optional) 13 | 14 | ## What are Boundary Equilibrium Generative Adversarial Networks? 15 | 16 | Unlike standard generative adversarial networks [(Goodfellow et al. 2014)](#references), boundary equilibrium generative adversarial networks (BEGAN) use an auto-encoder as a discriminator. An auto-encoder loss is defined, and an approximation of the Wasserstein distance is then computed between the pixelwise auto-encoder loss distributions of real and generated samples. 17 | 18 |

19 | ![Auto-encoder loss equation](readme/eq_autoencoder_loss.png) 20 |
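Written out, this is the pixel-wise auto-encoder loss (the `pixel_autoencoder_loss` in main.py; like the paper, this implementation uses the L1 norm, i.e. eta = 1):

$$
\mathcal{L}(v) = \lvert v - D(v) \rvert^{\eta}, \qquad \eta \in \{1, 2\},
$$

where $v$ is an image and $D(v)$ is the discriminating auto-encoder's reconstruction of it.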

21 | 22 | With the auto-encoder loss defined (above), the Wasserstein distance approximation simplifies to a loss function wherein the discriminating auto-encoder aims to perform *well on real samples* and *poorly on generated samples*, while the generator aims to produce adversarial samples which the discriminator can't help but perform well upon. 23 | 24 |

25 | ![Discriminator and generator losses](readme/eq_losses.png) 26 |
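In symbols, ignoring for the moment the weighting term $k_t$ introduced below, the two objectives are

$$
\mathcal{L}_D = \mathcal{L}(x) - \mathcal{L}(G(z_D)), \qquad \mathcal{L}_G = \mathcal{L}(G(z_G)),
$$

where $x$ is a real sample and $G(z)$ a generated one: the discriminator is trained to reconstruct real images well and generated images badly, while the generator is trained to produce samples that the discriminator reconstructs well.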

27 | 28 | Additionally, a hyper-parameter gamma is introduced which gives the user the power to control sample diversity by balancing the discriminator and generator. 29 | 30 |

31 | ![Diversity ratio gamma](readme/eq_gamma.png) 32 |
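Formally, the diversity ratio is the desired ratio between the expected auto-encoder loss on generated samples and on real samples:

$$
\gamma = \frac{\mathbb{E}\left[\mathcal{L}(G(z))\right]}{\mathbb{E}\left[\mathcal{L}(x)\right]} \in [0, 1].
$$

Lower values of gamma trade sample diversity for visual fidelity; this implementation defaults to 0.75 (the `--gamma` argument).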

33 | 34 | Gamma is put into effect through the use of a weighting parameter *k* which gets updated while training to adapt the loss function so that our output matches the desired diversity. The overall objective for the network is then: 35 | 36 |

37 | ![BEGAN objective](readme/eq_objective.png) 38 |
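Written out in full, this is what `BEGAN.loss` in main.py computes, with $\lambda_k = 0.001$:

$$
\begin{aligned}
\mathcal{L}_D &= \mathcal{L}(x) - k_t \,\mathcal{L}(G(z_D)) \\
\mathcal{L}_G &= \mathcal{L}(G(z_G)) \\
k_{t+1} &= k_t + \lambda_k \left(\gamma \,\mathcal{L}(x) - \mathcal{L}(G(z_G))\right)
\end{aligned}
$$

with $k_t$ clamped to $[0, 1]$ and initialised to 0.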

39 | 40 | Unlike most generative adversarial network architectures, where we need to update *G* and *D* independently, the Boundary Equilibrium GAN has the nice property that we can define a global loss and train the network as a whole (though we still have to make sure to update parameters with respect to the relative loss functions) 41 | 42 |

43 | ![Global loss](readme/eq_global.png) 44 |
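In this implementation the two updates still share a single Adam optimiser; each loss is simply routed to the variables in its own scope. Below is a condensed sketch of what main.py does (names such as `G_loss`, `D_loss` and `learning_rate` are the ones defined there):

```python
# Schematic of the parameter updates in main.py (TensorFlow 1.x API).
opt = tf.train.AdamOptimizer(learning_rate, epsilon=1.0)

params = tf.trainable_variables()
g_vars = [v for v in params if 'generator' in v.name]
d_vars = [v for v in params if 'discriminator' in v.name]

# Each loss only updates the parameters of its own sub-network.
G_train = opt.apply_gradients(opt.compute_gradients(G_loss, var_list=g_vars))
D_train = opt.apply_gradients(opt.compute_gradients(D_loss, var_list=d_vars))
```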

45 | 46 | The final contribution of the paper is a derived convergence measure M which gives a good indicator as to how the network is doing. We use this parameter to track performance, as well as control learning rate. 47 | 48 |

49 | ![Convergence measure](readme/eq_conv_measure.png) 50 |
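Concretely, this is the `convergence_measure` value returned by `BEGAN.loss` in main.py:

$$
M_{\text{global}} = \mathcal{L}(x) + \lvert \gamma \,\mathcal{L}(x) - \mathcal{L}(G(z_G)) \rvert,
$$

which is low only when real images are reconstructed well *and* training is close to the equilibrium $\mathbb{E}[\mathcal{L}(G(z))] = \gamma\, \mathbb{E}[\mathcal{L}(x)]$.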

51 | 52 | The overall result is a surprisingly effective model which produces samples well beyond the previous state of the art. 53 | 54 |

55 | ![Samples generated from Z](readme/generated_from_Z.png) 56 |

57 | 58 | *128x128 samples generated from random points in Z, from [(Berthelot, Schumm and Metz, 2017)](#references).* 59 | 60 | ## Usage 61 | 62 | ### Data Preprocessing 63 | 64 | You might want to use the 'CelebA' dataset [(Liu et al. 2015)](#references), which can be downloaded from [the project website](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). Make sure to download the 'Aligned and Cropped' version. However, you can modify these instructions to use an alternate dataset. 65 | 66 | (Note: if the CelebA Dropbox is down you can alternatively use their [Google Drive](https://drive.google.com/open?id=0B7EVK8r0v71pWEZsZE9oNnFzTm8)). 67 | 68 | The images then need to be packed into an HDF5 file with the following script (also provided as datasets/prepare_celeba.py): 69 | 70 | ```python 71 | from glob import glob 72 | import os 73 | import numpy as np 74 | import h5py 75 | from tqdm import tqdm 76 | from scipy.misc import imread, imresize 77 | 78 | filenames = glob(os.path.join("img_align_celeba", "*.jpg")) 79 | filenames = np.sort(filenames) 80 | w, h = 64, 64 # Change this if you wish to use larger images 81 | data = np.zeros((len(filenames), w * h * 3), dtype=np.uint8) 82 | 83 | # This preprocessing is appropriate for CelebA but should be adapted 84 | # (or removed entirely) for other datasets. 85 | 86 | def get_image(image_path, w=64, h=64): 87 | im = imread(image_path).astype(np.float) 88 | orig_h, orig_w = im.shape[:2] 89 | new_h = int(orig_h * w / orig_w) 90 | im = imresize(im, (new_h, w)) 91 | margin = int(round((new_h - h)/2)) 92 | return im[margin:margin+h] 93 | 94 | for n, fname in tqdm(enumerate(filenames)): 95 | image = get_image(fname, w, h) 96 | data[n] = image.flatten() 97 | 98 | with h5py.File('datasets/celeba.h5', 'w') as f: 99 | f.create_dataset("images", data=data) 100 | ``` 101 | 102 | 103 | ### Training 104 | 105 | After your dataset has been created through the method above, edit config.py so that it points to your dataset and to your desired checkpoint directory. 106 | 107 | E.g., if your dataset is stored at ```/home/user/data/dataset.hdf5```, then alter config.py to read: 108 | 109 | ```python 110 | data_path = '/home/user/data/dataset.hdf5' 111 | checkpoint_path = './checkpoints' 112 | ``` 113 | 114 | You can then begin training: 115 | 116 | ```bash 117 | python main.py --start-epoch=0 --add-epochs=100 --save-every 5 118 | ``` 119 | 120 | If you have limited RAM, you might need to limit the number of images loaded into memory at once, e.g. 121 | 122 | ```bash 123 | python main.py --start-epoch=0 --add-epochs=100 --save-every 5 --num-images 20000 124 | ``` 125 | 126 | I have 12 GB of RAM, which works for around 60,000 images. 127 | 128 | You can specify the GPU id with the ```--gpuid``` argument. If you want to run on the CPU (not recommended!) use ```--gpuid -1```. 129 | 130 | Other parameters can be tuned if you wish (run ```python main.py --help``` for the full list). 131 | The default values are the same as in the paper (though the authors point out that their choices aren't necessarily optimal). 132 | 133 | The main difference between this implementation's defaults and the original paper is the use of batch normalisation: we found that training was much slower without it. 134 | 135 | ### Running 136 | 137 | After you've trained a model, you can generate samples by running 138 | ```bash 139 | python main.py --start-epoch=N --add-epochs=0 --train=0 140 | ``` 141 | where N is the checkpoint you want to run from. 
142 | Samples will be saved to ./output/ by default (or pass the optional ```--outdir``` argument to choose a different directory). 143 | 144 | ### Tracking Progress 145 | 146 | As discussed previously, the convergence measure gives a very useful way of tracking progress. 147 | It is implemented in the code via the dictionary ```loss_tracker``` with the key ```convergence_measure```. 148 | 149 | Berthelot, Schumm and Metz show that it tracks sample quality faithfully: 150 | 151 |

152 | ![Convergence measure during training](readme/conv_measure_vis.png) 153 |

154 | 155 | *Convergence measure over training epochs, with generator outputs showed above [(Berthelot, Schumm and Metz, 2017)](#references).* 156 | 157 | 158 | ## Issues / Contributing / Todo 159 | 160 | Feel free to raise any issues in the project [issue tracker](http://github.com/artcg/BEGAN/issues), or make a [pull-request](http://github.com/artcg/BEGAN/pulls) if there is something you want to add. 161 | 162 | My next plan is to upload some pre-trained weights so beginners can run the model out-of-the-box. 163 | 164 | ## References 165 | 166 | * [Berthelot, Schumm and Metz. BEGAN: Boundary Equilibrium Generative Adversarial Networks. arXiv preprint, 2017](https://arxiv.org/abs/1703.10717) 167 | 168 | * [Goodfellow, Ian, et al. "Generative adversarial nets." Advances in neural information processing systems. 2014.](http://papers.nips.cc/paper/5423-generative-adversarial-nets) 169 | 170 | * [Liu, Ziwei, et al. "Deep Learning Face Attributes in the Wild." Proceedings of International Conference on Computer Vision. 2015.](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) 171 | -------------------------------------------------------------------------------- /checkpoints/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "BEGAN_64_64_0321.tfmod" 2 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | checkpoint_path = 'checkpoints' 2 | checkpoint_prefix = 'BEGAN_64_64' 3 | data_path = 'datasets/CelebA_64_64.h5' 4 | -------------------------------------------------------------------------------- /datasets/download_celebA.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Akash Rana (github.com/akash9182) 3 | 4 | Modification of 5 | - https://github.com/carpedm20/DCGAN-tensorflow/blob/master/download.py 6 | - http://stackoverflow.com/a/39225039 7 | - 8 | """ 9 | 10 | import os 11 | import zipfile 12 | import requests 13 | import subprocess 14 | from tqdm import tqdm 15 | from collections import OrderedDict 16 | 17 | base_path = './' 18 | 19 | def download_file_from_google_drive(id, destination): 20 | URL = "https://docs.google.com/uc?export=download" 21 | session = requests.Session() 22 | 23 | response = session.get(URL, params={ 'id': id }, stream=True) 24 | token = get_confirm_token(response) 25 | 26 | if token: 27 | params = { 'id' : id, 'confirm' : token } 28 | response = session.get(URL, params=params, stream=True) 29 | 30 | save_response_content(response, destination) 31 | 32 | def get_confirm_token(response): 33 | for key, value in response.cookies.items(): 34 | if key.startswith('download_warning'): 35 | return value 36 | return None 37 | 38 | def save_response_content(response, destination, chunk_size=32*1024): 39 | total_size = int(response.headers.get('content-length', 0)) 40 | with open(destination, "wb") as f: 41 | for chunk in tqdm(response.iter_content(chunk_size), total=total_size, 42 | unit='B', unit_scale=True, desc=destination): 43 | if chunk: # filter out keep-alive new chunks 44 | f.write(chunk) 45 | 46 | def unzip(filepath): 47 | print("Extracting: " + filepath) 48 | base_path = os.path.dirname(filepath) 49 | with zipfile.ZipFile(filepath) as zf: 50 | zf.extractall(base_path) 51 | os.remove(filepath) 52 | 53 | def download_celeb_a(base_path): 54 | data_path = os.path.join(base_path, 'CelebA') 55 | images_path = 
os.path.join(data_path, 'images') 56 | if os.path.exists(data_path): 57 | print('[!] Found Celeb-A - skip') 58 | return 59 | 60 | filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" 61 | save_path = os.path.join(base_path, filename) 62 | 63 | if os.path.exists(save_path): 64 | print('[*] {} already exists'.format(save_path)) 65 | else: 66 | download_file_from_google_drive(drive_id, save_path) 67 | 68 | zip_dir = '' 69 | with zipfile.ZipFile(save_path) as zf: 70 | zip_dir = zf.namelist()[0] 71 | zf.extractall(base_path) 72 | if not os.path.exists(data_path): 73 | os.mkdir(data_path) 74 | os.rename(os.path.join(base_path, "img_align_celeba"), images_path) 75 | os.remove(save_path) 76 | 77 | 78 | 79 | if __name__ == '__main__': 80 | download_celeb_a(base_path) 81 | -------------------------------------------------------------------------------- /datasets/prepare_celeba.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | import numpy as np 4 | import h5py 5 | from tqdm import tqdm 6 | from scipy.misc import imread, imresize 7 | 8 | filenames = glob(os.path.join("img_align_celeba", "*.jpg")) 9 | filenames = np.sort(filenames) 10 | w, h = 64, 64 # Change this if you wish to use larger images 11 | data = np.zeros((len(filenames), w * h * 3), dtype=np.uint8) 12 | 13 | # This preprocessing is appriate for CelebA but should be adapted 14 | # (or removed entirely) for other datasets. 15 | 16 | 17 | def get_image(image_path, w=64, h=64): 18 | im = imread(image_path).astype(np.float) 19 | orig_h, orig_w = im.shape[:2] 20 | new_h = int(orig_h * w / orig_w) 21 | im = imresize(im, (new_h, w)) 22 | margin = int(round((new_h - h)/2)) 23 | return im[margin:margin+h] 24 | 25 | for n, fname in tqdm(enumerate(filenames)): 26 | image = get_image(fname, w, h) 27 | data[n] = image.flatten() 28 | 29 | with h5py.File('datasets/celeba.h5', 'w') as f: 30 | f.create_dataset("images", data=data) 31 | -------------------------------------------------------------------------------- /discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from utils.custom_ops import custom_fc, custom_conv2d 3 | from generator import decoder 4 | 5 | 6 | def began_discriminator(D_I, batch_size, num_filters, hidden_size, image_size, 7 | scope_name="discriminator", reuse_scope=False): 8 | ''' 9 | Unlike most generative adversarial networks, the boundary 10 | equilibrium uses an autoencoder as a discriminator. 11 | 12 | For simplicity, the decoder architecture is the same as the generator. 13 | 14 | Downsampling is 3x3 convolutions with a stride of 2. 15 | Upsampling is 3x3 convolutions, with nearest neighbour resizing 16 | to the desired resolution. 
17 | 18 | Args: 19 | D_I: a batch of images [batch_size, 64 x 64 x 3] 20 | batch_size: Batch size of encodings 21 | num_filters: Number of filters in convolutional layers 22 | hidden_size: Dimensionality of encoding 23 | image_size: First dimension of generated image (must be 64 or 128) 24 | scope_name: Tensorflow scope name 25 | reuse_scope: Tensorflow scope handling 26 | Returns: 27 | Flattened tensor of re-created images, with dimensionality: 28 | [batch_size, image_size * image_size * 3] 29 | ''' 30 | 31 | 32 | with tf.variable_scope(scope_name) as scope: 33 | if reuse_scope: 34 | scope.reuse_variables() 35 | 36 | layer_1 = tf.reshape(D_I, [-1, image_size, image_size, 3]) # '-1' is batch size 37 | 38 | conv_0 = custom_conv2d(layer_1, 3, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec0') 39 | conv_0 = tf.nn.elu(conv_0) 40 | 41 | conv_1 =custom_conv2d(conv_0, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec1') 42 | conv_1 = tf.nn.elu(conv_1) 43 | 44 | 45 | conv_2 =custom_conv2d(conv_1, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec2') 46 | conv_2 = tf.nn.elu(conv_2) 47 | 48 | 49 | layer_2 = custom_conv2d(conv_2, 2 * num_filters, k_h=3, k_w=3, d_h=2, d_w=2, scope='el2') 50 | layer_2 = tf.nn.elu(layer_2) 51 | 52 | conv_3 = custom_conv2d(layer_2, 2 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec3') 53 | conv_3 = tf.nn.elu(conv_3) 54 | 55 | conv_4 =custom_conv2d(conv_3, 2 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec4') 56 | conv_4 = tf.nn.elu(conv_4) 57 | 58 | 59 | layer_3 = custom_conv2d(conv_2, 3 * num_filters, k_h=3, k_w=3, d_h=2, d_w=2, scope='el3') 60 | layer_3 = tf.nn.elu(layer_3) 61 | 62 | conv_5 = custom_conv2d(layer_3, 3 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec5') 63 | conv_5 = tf.nn.elu(conv_5) 64 | 65 | conv_6 =custom_conv2d(conv_5, 3 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec6') 66 | conv_6 = tf.nn.elu(conv_6) 67 | 68 | 69 | layer_4 = custom_conv2d(conv_6, 4 * num_filters, k_h=3, k_w=3, d_h=2, d_w=2, scope='el4') 70 | layer_4 = tf.nn.elu(layer_4) 71 | 72 | conv_7 = custom_conv2d(layer_4, 4 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec7') 73 | conv_7 = tf.nn.elu(conv_7) 74 | 75 | conv_8 =custom_conv2d(conv_7, 4 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec8') 76 | conv_8 = tf.nn.elu(conv_8) 77 | 78 | if image_size == 64: 79 | enc = custom_fc(conv_8, hidden_size, scope='enc') 80 | else: 81 | layer_5 = custom_conv2d(conv_8, 5 * num_filters, k_h=3, k_w=3, d_h=2, d_w=2, scope='el5') 82 | layer_5 = tf.nn.elu(layer_5) 83 | 84 | conv_9 = custom_conv2d(layer_5, 5 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec9') 85 | conv_9 = tf.nn.elu(conv_9) 86 | 87 | conv_10 =custom_conv2d(conv_9, 5 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec10') 88 | conv_10 = tf.nn.elu(conv_10) 89 | enc = custom_fc(conv_10, hidden_size, scope='enc') 90 | 91 | # add elu before decoding? 92 | return decoder(enc, batch_size=batch_size, num_filters=num_filters, 93 | hidden_size=hidden_size, image_size=image_size) 94 | -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from utils.custom_ops import custom_fc, custom_conv2d 3 | 4 | if False: # This to silence pyflake 5 | custom_ops 6 | 7 | 8 | def decoder(Z, batch_size, num_filters, hidden_size, image_size): 9 | ''' 10 | The Boundary Equilibrium GAN deliberately uses a simple generator 11 | architecture. 
12 | 13 | Upsampling is 3x3 convolutions, with nearest neighbour resizing 14 | to the desired resolution. 15 | 16 | Args: 17 | Z: Latent space 18 | batch_size: Batch size of generations 19 | num_filters: Number of filters in convolutional layers 20 | hidden_size: Dimensionality of encoding 21 | image_size: First dimension of generated image (must be 64 or 128) 22 | scope_name: Tensorflow scope name 23 | reuse_scope: Tensorflow scope handling 24 | Returns: 25 | Flattened tensor of generated images, with dimensionality: 26 | [batch_size, image_size * image_size * 3] 27 | ''' 28 | layer_1 = custom_fc(Z, 8 * 8 * num_filters, scope='l1') 29 | 30 | layer_1 = tf.reshape(layer_1, [-1, 8, 8, num_filters]) # '-1' is batch size 31 | 32 | conv_1 = custom_conv2d(layer_1, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c1') 33 | conv_1 = tf.nn.elu(conv_1) 34 | 35 | conv_2 = custom_conv2d(conv_1, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c2') 36 | conv_2 = tf.nn.elu(conv_2) 37 | 38 | layer_2 = tf.image.resize_nearest_neighbor(conv_2, [16, 16]) 39 | 40 | conv_3 = custom_conv2d(layer_2, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c3') 41 | conv_3 = tf.nn.elu(conv_3) 42 | 43 | conv_4 = custom_conv2d(conv_3, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c4') 44 | conv_4 = tf.nn.elu(conv_4) 45 | 46 | layer_3 = tf.image.resize_nearest_neighbor(conv_4, [32, 32]) 47 | 48 | conv_5 = custom_conv2d(layer_3, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c5') 49 | conv_5 = tf.nn.elu(conv_5) 50 | 51 | conv_6 = custom_conv2d(conv_5, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c6') 52 | conv_6 = tf.nn.elu(conv_6) 53 | 54 | layer_4 = tf.image.resize_nearest_neighbor(conv_6, [64, 64]) 55 | 56 | conv_7 = custom_conv2d(layer_4, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c7') 57 | conv_7 = tf.nn.elu(conv_7) 58 | 59 | conv_8 = custom_conv2d(conv_7, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c8') 60 | conv_8 = tf.nn.elu(conv_8) 61 | 62 | if image_size == 64: 63 | im = conv_8 64 | else: 65 | layer_5 = tf.image.resize_nearest_neighbor(conv_8, [128, 128]) 66 | 67 | conv_9 = custom_conv2d(layer_5, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c9') 68 | conv_9 = tf.nn.elu(conv_9) 69 | 70 | conv_10 = custom_conv2d(conv_9, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c10') 71 | im = tf.nn.elu(conv_10) 72 | 73 | im = custom_conv2d(im, 3, k_h=3, k_w=3, d_h=1, d_w=1, scope='im') 74 | im = tf.sigmoid(im) 75 | im = tf.reshape(im, [-1, image_size * image_size * 3]) 76 | return im 77 | 78 | def began_generator(Z, batch_size, num_filters, hidden_size, image_size, 79 | scope_name="generator", reuse_scope=False): 80 | with tf.variable_scope(scope_name) as scope: 81 | if reuse_scope: 82 | scope.reuse_variables() 83 | return decoder(Z, batch_size, num_filters, hidden_size, image_size) 84 | 85 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from generator import began_generator as generator 3 | from discriminator import began_discriminator as discriminator 4 | from utils.misc import loadData, dataIterator 5 | import tqdm 6 | import numpy as np 7 | from utils.misc import plot_gens 8 | from config import checkpoint_path, checkpoint_prefix 9 | 10 | 11 | class BEGAN: 12 | loss_tracker = {'generator': [], 13 | 'discriminator': [], 14 | 'convergence_measure': []} 15 | 16 | def loss(D_real_in, D_real_out, D_gen_in, D_gen_out, k_t, gamma=0.75): 17 | ''' 18 | The 
Boundary Equibilibrium GAN uses an approximation of the 19 | Wasserstein Loss between the disitributions of pixel-wise 20 | autoencoder loss based on the discriminator performance on 21 | real vs. generated data. 22 | 23 | This simplifies to reducing the L1 norm of the autoencoder loss: 24 | making the discriminator objective to perform well on real images 25 | and poorly on generated images; with the generator objective 26 | to create samples which the discriminator will perform well upon. 27 | 28 | args: 29 | D_real_in: input to discriminator with real sample. 30 | D_real_out: output from discriminator with real sample. 31 | D_gen_in: input to discriminator with generated sample. 32 | D_gen_out: output from discriminator with generated sample. 33 | k_t: weighting parameter which constantly updates during training 34 | gamma: diversity ratio, used to control model equibilibrium. 35 | returns: 36 | D_loss: discriminator loss to minimise. 37 | G_loss: generator loss to minimise. 38 | k_tp: value of k_t for next train step. 39 | convergence_measure: measure of model convergence. 40 | ''' 41 | def pixel_autoencoder_loss(out, inp): 42 | ''' 43 | The autoencoder loss used is the L1 norm (note that this 44 | is based on the pixel-wise distribution of losses 45 | that the authors assert approximates the Normal distribution) 46 | 47 | args: 48 | out: discriminator output 49 | inp: discriminator input 50 | returns: 51 | L1 norm of pixel-wise loss 52 | ''' 53 | eta = 1 # paper uses L1 norm 54 | diff = tf.abs(out - inp) 55 | if eta == 1: 56 | return tf.reduce_mean(diff) 57 | else: 58 | return tf.reduce_mean(tf.pow(diff, eta)) 59 | 60 | mu_real = pixel_autoencoder_loss(D_real_out, D_real_in) 61 | mu_gen = pixel_autoencoder_loss(D_gen_out, D_gen_in) 62 | D_loss = mu_real - k_t * mu_gen 63 | G_loss = mu_gen 64 | lam = 0.001 # 'learning rate' for k. Berthelot et al. 
use 0.001 65 | k_tp = k_t + lam * (gamma * mu_real - mu_gen) 66 | convergence_measure = mu_real + np.abs(gamma * mu_real - mu_gen) 67 | return D_loss, G_loss, k_tp, convergence_measure 68 | 69 | def run(x, batch_size, num_filters, hidden_size, image_size): 70 | Z = tf.random_normal((batch_size, hidden_size), 0, 1) 71 | 72 | x_tilde = generator(Z, batch_size=batch_size, num_filters=num_filters, 73 | hidden_size=hidden_size, image_size=image_size) 74 | x_tilde_d = discriminator(x_tilde, batch_size=batch_size, num_filters=num_filters, 75 | hidden_size=hidden_size, image_size=image_size) 76 | 77 | x_d = discriminator(x, reuse_scope=True, batch_size=batch_size, num_filters=num_filters, 78 | hidden_size=hidden_size, image_size=image_size) 79 | 80 | return x_tilde, x_tilde_d, x_d 81 | 82 | scopes = ['generator', 'discriminator'] 83 | 84 | 85 | def began_train(num_images=50000, start_epoch=0, add_epochs=None, batch_size=16, 86 | hidden_size=64, image_size=64, gpu_id='/gpu:0', 87 | demo=False, get=False, start_learn_rate=1e-4, decay_every=100, 88 | save_every=1, batch_norm=True, gamma=0.75): 89 | 90 | num_epochs = start_epoch + add_epochs 91 | loss_tracker = BEGAN.loss_tracker 92 | 93 | graph = tf.Graph() 94 | with graph.as_default(): 95 | global_step = tf.get_variable('global_step', [], 96 | initializer=tf.constant_initializer(0), 97 | trainable=False) 98 | 99 | with tf.device(gpu_id): 100 | learning_rate = tf.placeholder(tf.float32, shape=[]) 101 | opt = tf.train.AdamOptimizer(learning_rate, epsilon=1.0) 102 | 103 | next_batch = tf.placeholder(tf.float32, 104 | [batch_size, image_size * image_size * 3]) 105 | 106 | x_tilde, x_tilde_d, x_d = BEGAN.run(next_batch, batch_size=batch_size, num_filters=128, 107 | hidden_size=hidden_size, image_size=image_size) 108 | 109 | k_t = tf.get_variable('kt', [], 110 | initializer=tf.constant_initializer(0), 111 | trainable=False) 112 | D_loss, G_loss, k_tp, convergence_measure = \ 113 | BEGAN.loss(next_batch, x_d, x_tilde, x_tilde_d, k_t=k_t) 114 | 115 | params = tf.trainable_variables() 116 | tr_vars = {} 117 | for s in BEGAN.scopes: 118 | tr_vars[s] = [i for i in params if s in i.name] 119 | 120 | G_grad = opt.compute_gradients(G_loss, 121 | var_list=tr_vars['generator']) 122 | 123 | D_grad = opt.compute_gradients(D_loss, 124 | var_list=tr_vars['discriminator']) 125 | 126 | G_train = opt.apply_gradients(G_grad, global_step=global_step) 127 | D_train = opt.apply_gradients(D_grad, global_step=global_step) 128 | 129 | init = tf.global_variables_initializer() 130 | saver = tf.train.Saver() 131 | sess = tf.Session(graph=graph, 132 | config=tf.ConfigProto(allow_soft_placement=True, 133 | log_device_placement=True)) 134 | sess.run(init) 135 | if start_epoch > 0: 136 | path = '{}/{}_{}.tfmod'.format(checkpoint_path, 137 | checkpoint_prefix, 138 | str(start_epoch-1).zfill(4)) 139 | tf.train.Saver.restore(saver, sess, path) 140 | 141 | k_t_ = sess.run(k_t) # We initialise with k_t = 0 as in the paper. 
142 | num_batches_per_epoch = num_images // batch_size 143 | for epoch in range(start_epoch, num_epochs): 144 | images = loadData(size=num_images) 145 | print('Epoch {} / {}'.format(epoch + 1, num_epochs + 1)) 146 | for i in tqdm.tqdm(range(num_batches_per_epoch)): 147 | iter_ = dataIterator([images], batch_size) 148 | 149 | learning_rate_ = start_learn_rate * pow(0.5, epoch // decay_every) 150 | next_batch_ = next(iter_) 151 | 152 | _, _, D_loss_, G_loss_, k_t_, M_ = \ 153 | sess.run([G_train, D_train, D_loss, G_loss, k_tp, convergence_measure], 154 | {learning_rate: learning_rate_, 155 | next_batch: next_batch_, k_t: min(max(k_t_, 0), 1)}) 156 | 157 | loss_tracker['generator'].append(G_loss_) 158 | loss_tracker['discriminator'].append(D_loss_) 159 | loss_tracker['convergence_measure'].append(M_) 160 | 161 | if epoch % save_every == 0: 162 | path = '{}/{}_{}.tfmod'.format(checkpoint_path, 163 | checkpoint_prefix, 164 | str(epoch).zfill(4)) 165 | saver.save(sess, path) 166 | if demo: 167 | batch = dataIterator([images], batch_size).__next__() 168 | ims = sess.run(x_tilde) 169 | plot_gens((ims, batch), 170 | ('Generated 64x64 samples.', 'Random training images.'), 171 | loss_tracker) 172 | if get: 173 | return ims 174 | 175 | 176 | if __name__ == '__main__': 177 | import argparse 178 | parser = argparse.ArgumentParser(description='Run BEGAN.') 179 | 180 | parser.add_argument('--gpuid', type=int, default=0, 181 | help='GPU ID to use (-1 for CPU)') 182 | 183 | parser.add_argument('--save-every', type=int, default=5, 184 | help='Frequency to save checkpoint (in epochs)') 185 | 186 | parser.add_argument('--start-epoch', type=int, default=0, required=True, 187 | help='Start epoch (0 to begin training from scratch,' 188 | + 'N to restore from checkpoint N)') 189 | 190 | parser.add_argument('--add-epochs', type=int, default=100, required=True, 191 | help='Number of epochs to train' 192 | + '(-1 to train indefinitely)') 193 | 194 | parser.add_argument('--num-images', type=int, default=2000, 195 | help='Number of images to load into RAM at once') 196 | 197 | parser.add_argument('--gamma', type=float, default=0.75, 198 | help='Diversity ratio (read paper for more info)') 199 | 200 | parser.add_argument('--start-learn-rate', type=float, default=1e-5, 201 | help='Starting learn rate') 202 | 203 | parser.add_argument('--train', type=int, default=1, 204 | help='"1" to train; "0" to run and' 205 | + 'return output') 206 | 207 | parser.add_argument('--batch_size', type=int, default=16, 208 | help='Batch size for training (default 16' 209 | + 'as in paper)') 210 | 211 | parser.add_argument('--hidden_size', type=int, default=64, 212 | help='Dimensionality of the discriminator encoding.' 
213 | + '(Paper doesnt specify value so we use guess)') 214 | 215 | parser.add_argument('--batch-norm', type=int, default=1, 216 | help='Set to "0" to disable batch normalisation') 217 | 218 | parser.add_argument('--decay-every', type=int, default=-1, 219 | help='Number of epochs before learning rate decay' 220 | + '(set to 0 to disable)') 221 | 222 | parser.add_argument('--outdir', type=str, default='output', 223 | help='Path to save output generations') 224 | 225 | parser.add_argument('--image-size', type=int, default=64, 226 | help='Image size (must be 64 or 128)') 227 | 228 | args = parser.parse_args() 229 | if args.gpuid == -1: 230 | args.gpuid = '/cpu:0' 231 | else: 232 | args.gpuid = '/gpu:{}'.format(args.gpuid) 233 | 234 | if args.decay_every == -1: 235 | args.decay_every = np.inf 236 | 237 | if args.train: 238 | demo = False 239 | get = False 240 | else: 241 | demo = True 242 | get = True 243 | args.add_epochs = 0 244 | 245 | im = began_train(start_epoch=args.start_epoch, add_epochs=args.add_epochs, 246 | batch_size=args.batch_size, hidden_size=args.hidden_size, 247 | gpu_id=args.gpuid, demo=demo, get=get, 248 | image_size=args.image_size, 249 | save_every=args.save_every, decay_every=args.decay_every, 250 | batch_norm=args.batch_norm, num_images=args.num_images, 251 | start_learn_rate=args.start_learn_rate) 252 | 253 | if not args.train: 254 | import matplotlib.pyplot as plt 255 | for n in range(8): 256 | im_to_save = im[n].reshape([args.image_size, args.image_size, 3]) 257 | plt.imsave(args.outdir+'/out_{}.jpg'.format(n), 258 | im_to_save) 259 | -------------------------------------------------------------------------------- /readme/conv_measure_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/conv_measure_vis.png -------------------------------------------------------------------------------- /readme/eq_autoencoder_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/eq_autoencoder_loss.png -------------------------------------------------------------------------------- /readme/eq_conv_measure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/eq_conv_measure.png -------------------------------------------------------------------------------- /readme/eq_gamma.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/eq_gamma.png -------------------------------------------------------------------------------- /readme/eq_global.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/eq_global.png -------------------------------------------------------------------------------- /readme/eq_losses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/eq_losses.png -------------------------------------------------------------------------------- /readme/eq_objective.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/eq_objective.png -------------------------------------------------------------------------------- /readme/generated_from_Z.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/generated_from_Z.png -------------------------------------------------------------------------------- /utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | 6 | def leaky_rectify(x, leakiness=0.01): 7 | assert leakiness <= 1 8 | ret = tf.maximum(x, leakiness * x) 9 | return ret 10 | 11 | 12 | def custom_conv2d(input_layer, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, in_dim=None, 13 | padding='SAME', scope="conv2d"): 14 | with tf.variable_scope(scope): 15 | w = tf.get_variable('w', [k_h, k_w, in_dim or input_layer.shape[-1], output_dim], 16 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 17 | conv = tf.nn.conv2d(input_layer, w, 18 | strides=[1, d_h, d_w, 1], padding=padding) 19 | b = tf.get_variable("b", shape=output_dim, initializer=tf.constant_initializer(0.)) 20 | conv = tf.nn.bias_add(conv, b) 21 | return conv 22 | 23 | 24 | 25 | def custom_fc(input_layer, output_size, scope='Linear', 26 | in_dim=None, stddev=0.02, bias_start=0.0): 27 | shape = input_layer.shape 28 | if len(shape) > 2: 29 | input_layer = tf.reshape(input_layer, [-1, int(np.prod(shape[1:]))]) 30 | shape = input_layer.shape 31 | with tf.variable_scope(scope): 32 | matrix = tf.get_variable("weight", 33 | [in_dim or shape[1], output_size], 34 | dtype=tf.float32, 35 | initializer=tf.random_normal_initializer(stddev=stddev)) 36 | bias = tf.get_variable("bias", [output_size], initializer=tf.constant_initializer(bias_start)) 37 | return tf.nn.bias_add(tf.matmul(input_layer, matrix), bias) 38 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import scipy.misc 4 | from glob import glob 5 | from config import data_path 6 | 7 | 8 | def loadData(size): 9 | import h5py 10 | with h5py.File(data_path, 'r') as hf: 11 | faces = hf['images'] 12 | choice = np.random.choice(len(faces), size, replace=False) 13 | faces = faces[sorted(choice)] 14 | faces = np.array(faces, dtype=np.float16) 15 | return faces / 255 16 | 17 | 18 | def loadJPGs(path='/home/arthur/devel/input/', width=64, height=64): 19 | filenames = glob(path+"*.jpg") 20 | filenames = np.sort(filenames) 21 | 22 | def imread(path): 23 | return scipy.misc.imread(path) 24 | 25 | def scaleHeight(x, height=64): 26 | h, w = x.shape[:2] 27 | return scipy.misc.imresize(x, [height, int((float(w)/h)*height)]) 28 | 29 | def cropSides(x, width=64): 30 | w = x.shape[1] 31 | j = int(round((w - width)/2.)) 32 | return x[:, j:j+width, :] 33 | 34 | def get_image(image_path, width=64, height=64): 35 | return cropSides(scaleHeight(imread(image_path), height=height), 36 | width=width) 37 | 38 | images = np.zeros((len(filenames), width * height * 3), dtype=np.uint8) 39 | 40 | for n, i in enumerate(filenames): 41 | im = get_image(i) 42 | images[n] = im.flatten() 43 | images = np.array(images, dtype=np.float16) 44 
| return images / 255 45 | 46 | 47 | def dataIterator(data, batch_size): 48 | ''' 49 | From great jupyter notebook by Tim Sainburg: 50 | http://github.com/timsainb/Tensorflow-MultiGPU-VAE-GAN 51 | ''' 52 | batch_idx = 0 53 | while True: 54 | length = len(data[0]) 55 | assert all(len(i) == length for i in data) 56 | idxs = np.arange(0, length) 57 | np.random.shuffle(idxs) 58 | for batch_idx in range(0, length, batch_size): 59 | cur_idxs = idxs[batch_idx:batch_idx + batch_size] 60 | images_batch = data[0][cur_idxs] 61 | # images_batch = images_batch.astype("float32") 62 | yield images_batch 63 | 64 | 65 | def create_image(im): 66 | ''' 67 | From great jupyter notebook by Tim Sainburg: 68 | http://github.com/timsainb/Tensorflow-MultiGPU-VAE-GAN 69 | ''' 70 | d1 = int(np.sqrt((np.product(im.shape) / 3))) 71 | im = np.array(im, dtype=np.float32) 72 | return np.reshape(im, (d1, d1, 3)) 73 | 74 | 75 | def plot_gens(images, rowlabels, losses): 76 | ''' 77 | From great jupyter notebook by Tim Sainburg: 78 | http://github.com/timsainb/Tensorflow-MultiGPU-VAE-GAN 79 | ''' 80 | examples = 8 81 | fig, ax = plt.subplots(nrows=len(images), ncols=examples, figsize=(18, 8)) 82 | for i in range(examples): 83 | for j in range(len(images)): 84 | ax[(j, i)].imshow(create_image(images[j][i]), cmap=plt.cm.gray, 85 | interpolation='nearest') 86 | ax[(j, i)].axis('off') 87 | title = '' 88 | for i in rowlabels: 89 | title += ' {}, '.format(i) 90 | fig.suptitle('Top to Bottom: {}'.format(title)) 91 | plt.show() 92 | #fig.savefig(''.join(['imgs/test_',str(epoch).zfill(4),'.png']),dpi=100) 93 | fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 10), linewidth = 4) 94 | 95 | D_plt, = plt.semilogy((losses['discriminator']), linewidth=4, ls='-', 96 | color='b', alpha=.5, label='D') 97 | G_plt, = plt.semilogy((losses['generator']), linewidth=4, ls='-', 98 | color='k', alpha=.5, label='G') 99 | 100 | plt.gca() 101 | leg = plt.legend(handles=[D_plt, G_plt], 102 | fontsize=20) 103 | leg.get_frame().set_alpha(0.5) 104 | plt.show() 105 | --------------------------------------------------------------------------------