├── LICENSE
├── README.md
├── checkpoints
└── checkpoint
├── config.py
├── datasets
├── download_celebA.py
└── prepare_celeba.py
├── discriminator.py
├── generator.py
├── main.py
├── readme
├── conv_measure_vis.png
├── eq_autoencoder_loss.png
├── eq_conv_measure.png
├── eq_gamma.png
├── eq_global.png
├── eq_losses.png
├── eq_objective.png
└── generated_from_Z.png
└── utils
├── custom_ops.py
└── misc.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Arthur Goldberg
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BEGAN: Boundary Equilibrium Generative Adversarial Networks
2 |
3 | This is an implementation of the paper on Boundary Equilibrium Generative Adversarial Networks [(Berthelot, Schumm and Metz, 2017)](#references).
4 |
5 | ## Dependencies
6 |
7 | * Python 3+
8 | * numpy
9 | * Tensorflow
10 | * tqdm
11 | * h5py
12 | * scipy
13 | * matplotlib
14 | ## What are Boundary Equilibrium Generative Adversarial Networks?
15 |
16 | Unlike standard generative adversarial networks [(Goodfellow et al. 2014)](#references), boundary equilibrium generative adversarial networks (BEGAN) use an auto-encoder as a discriminator. An auto-encoder loss is defined, and an approximation of the Wasserstein distance is then computed between the pixelwise auto-encoder loss distributions of real and generated samples.
17 |
18 | ![Auto-encoder loss](readme/eq_autoencoder_loss.png)
19 |
20 |
21 |
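In code, this per-pixel reconstruction loss is just the mean absolute (L1) difference between what the discriminator is given and what it reconstructs. Below is a NumPy sketch of ```pixel_autoencoder_loss``` from ```main.py``` (the function and argument names here are illustrative):

```python
import numpy as np

def autoencoder_loss(x, x_reconstructed):
    # L(v) = |v - D(v)|, averaged over all pixels (L1 norm, i.e. eta = 1)
    return np.mean(np.abs(x - x_reconstructed))
```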
22 | With the auto-encoder loss defined (above), the Wasserstein distance approximation simplifies to a loss function wherein the discriminating auto-encoder aims to perform *well on real samples* and *poorly on generated samples*, while the generator aims to produce adversarial samples which the discriminator can't help but perform well upon.
23 |
24 | ![BEGAN losses](readme/eq_losses.png)
25 |
26 |
27 |
28 | Additionally, a hyper-parameter gamma is introduced which gives the user the power to control sample diversity by balancing the discriminator and generator.
29 |
30 | ![Gamma](readme/eq_gamma.png)
31 |
32 |
33 |
34 | Gamma is put into effect through a weighting parameter *k*, which is updated during training to adapt the loss function so that the output matches the desired diversity. The overall objective for the network is then:
35 |
36 | ![Objective](readme/eq_objective.png)
37 |
38 |
39 |
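Concretely, each training step computes the two losses and the updated *k* along these lines (a plain-Python sketch of ```BEGAN.loss``` in ```main.py```; `L_real` and `L_gen` stand for the auto-encoder loss on a real and on a generated batch):

```python
def began_losses(L_real, L_gen, k_t, gamma=0.75, lambda_k=0.001):
    d_loss = L_real - k_t * L_gen  # discriminator: reconstruct real well, generated poorly
    g_loss = L_gen                 # generator: produce samples the discriminator reconstructs well
    k_next = k_t + lambda_k * (gamma * L_real - L_gen)
    return d_loss, g_loss, min(max(k_next, 0.0), 1.0)  # k is clamped to [0, 1]
```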
40 | Unlike most generative adversarial network architectures, where we need to update *G* and *D* independently, the Boundary Equilibrium GAN has the nice property that we can define a global loss and train the network as a whole (though we still have to make sure to update the parameters with respect to their respective loss functions).
41 |
42 | ![Global loss](readme/eq_global.png)
43 |
44 |
45 |
46 | The final contribution of the paper is a derived convergence measure M, which gives a good indication of how the network is doing. We use this measure to track performance, as well as to control the learning rate.
47 |
48 | ![Convergence measure](readme/eq_conv_measure.png)
49 |
50 |
51 |
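In code form this is a one-liner (again mirroring ```BEGAN.loss``` in ```main.py```):

```python
convergence_measure = L_real + abs(gamma * L_real - L_gen)
```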
52 | The overall result is a surprisingly effective model which produces samples well beyond the previous state of the art.
53 |
54 | ![Generated samples](readme/generated_from_Z.png)
55 |
56 |
57 |
58 | *128x128 samples generated from random points in Z, from [(Berthelot, Schumm and Metz, 2017)](#references).*
59 |
60 | ## Usage
61 |
62 | ### Data Preprocessing
63 |
64 | You might want to use the CelebA dataset [(Liu et al. 2015)](#references), which can be downloaded from [the project website](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). Make sure to download the 'Aligned and Cropped' version. You can, however, adapt these instructions to use an alternative dataset.
65 |
66 | (Note: if the CelebA Dropbox is down you can alternatively use their [Google Drive](https://drive.google.com/open?id=0B7EVK8r0v71pWEZsZE9oNnFzTm8)).
67 |
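Alternatively, ```datasets/download_celebA.py``` in this repository will download and unzip the aligned images from Google Drive for you (run it from the repository root; note that it moves the extracted files into ```CelebA/images```, so adjust the glob path in the snippet below accordingly).
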
68 | The images then need to be packed into an HDF5 file, for example as follows (this code is also included as ```datasets/prepare_celeba.py```):
69 |
70 | ```python
71 | from glob import glob
72 | import os
73 | import numpy as np
74 | import h5py
75 | from tqdm import tqdm
76 | from scipy.misc import imread, imresize
77 |
78 | filenames = glob(os.path.join("img_align_celeba", "*.jpg"))
79 | filenames = np.sort(filenames)
80 | w, h = 64, 64 # Change this if you wish to use larger images
81 | data = np.zeros((len(filenames), w * h * 3), dtype = np.uint8)
82 |
83 | # This preprocessing is appropriate for CelebA but should be adapted
84 | # (or removed entirely) for other datasets.
85 |
86 | def get_image(image_path, w=64, h=64):
87 | im = imread(image_path).astype(np.float)
88 | orig_h, orig_w = im.shape[:2]
89 | new_h = int(orig_h * w / orig_w)
90 | im = imresize(im, (new_h, w))
91 | margin = int(round((new_h - h)/2))
92 | return im[margin:margin+h]
93 |
94 | for n, fname in tqdm(enumerate(filenames), total=len(filenames)):
95 | image = get_image(fname, w, h)
96 | data[n] = image.flatten()
97 |
98 | with h5py.File('datasets/celeba.h5', 'w') as f:
99 | f.create_dataset("images", data=data)
100 | ```
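
A quick sanity check that the file was written as expected (adjust the path if you changed it):

```python
import h5py

with h5py.File('datasets/celeba.h5', 'r') as f:
    print(f['images'].shape)  # expect (num_images, 64 * 64 * 3)
```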
101 |
102 |
103 | ### Training
104 |
105 | After your dataset has been created through the method above, edit config.py to point to your dataset and to your desired checkpoint directory.
106 |
107 | E.g., if your dataset is stored at ```/home/user/data/dataset.hdf5```, then alter config.py to read:
108 |
109 | ```python
110 | data_path = '/home/user/data/dataset.hdf5'
111 | checkpoint_path = './checkpoints'
112 | ```
113 |
114 | You can then begin training:
115 |
116 | ```bash
117 | python main.py --start-epoch 0 --add-epochs 100 --save-every 5
118 | ```
119 |
120 | If you have limited RAM you might need to limit the number of images loaded into memory at once, e.g.
121 |
122 | ```bash
123 | python main.py --start-epoch 0 --add-epochs 100 --save-every 5 --num-images 20000
124 | ```
125 |
126 | I have 12GB of RAM, which works for around 60,000 images.
127 |
128 | You can specify the GPU id with the ```--gpuid``` argument. If you want to run on the CPU (not recommended!) use ```--gpuid -1```.
129 |
130 | Other parameters can be tuned if you wish (run ```python main.py --help``` for the full list).
131 | The default values are the same as in the paper (though the authors point out that their choices aren't necessarily optimal).
132 |
133 | The main difference between this implementation's defaults and the original paper is the use of batch normalisation: we found that training was much slower without it.
134 |
135 | ### Running
136 |
137 | After you've trained a model and you want to generate some samples simply run
138 | ```bash
139 | python main.py --start-epoch N --add-epochs 0 --train 0
140 | ```
141 | where N is the checkpoint you want to run from.
142 | Samples will be saved to ./output/ by default (or pass the optional ```--outdir``` argument for an alternative location).
143 |
144 | ### Tracking Progress
145 |
146 | As discussed previously, the convergence measure gives a very nice way of tracking progress.
147 | It is implemented in the code via the dictionary ```loss_tracker``` (key ```convergence_measure```), as sketched below.
148 |
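For example, if you drive training from a Python session rather than the command line, you can plot the measure straight from that dictionary. A minimal sketch, assuming a dataset has already been prepared at the path set in config.py:

```python
import matplotlib.pyplot as plt
from main import BEGAN, began_train

# Train for one epoch; BEGAN.loss_tracker is filled in as training runs.
began_train(start_epoch=0, add_epochs=1, num_images=2000)

plt.plot(BEGAN.loss_tracker['convergence_measure'])
plt.xlabel('training step')
plt.ylabel('convergence measure M')
plt.show()
```
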
149 | Berthelot, Schumm and Metz show that it is a faithful indicator of convergence and sample quality:
150 |
151 | ![Convergence measure over training](readme/conv_measure_vis.png)
152 |
153 |
154 |
155 | *Convergence measure over training epochs, with generator outputs shown above, from [(Berthelot, Schumm and Metz, 2017)](#references).*
156 |
157 |
158 | ## Issues / Contributing / Todo
159 |
160 | Feel free to raise any issues in the project [issue tracker](http://github.com/artcg/BEGAN/issues), or make a [pull-request](http://github.com/artcg/BEGAN/pulls) if there is something you want to add.
161 |
162 | My next plan is to upload some pre-trained weights so beginners can run the model out-of-the-box.
163 |
164 | ## References
165 |
166 | * [Berthelot, Schumm and Metz. BEGAN: Boundary Equilibrium Generative Adversarial Networks. arXiv preprint, 2017](https://arxiv.org/abs/1703.10717)
167 |
168 | * [Goodfellow, Ian, et al. "Generative adversarial nets." Advances in neural information processing systems. 2014.](http://papers.nips.cc/paper/5423-generative-adversarial-nets)
169 |
170 | * [Liu, Ziwei, et al. "Deep Learning Face Attributes in the Wild." Proceedings of International Conference on Computer Vision. 2015.](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
171 |
--------------------------------------------------------------------------------
/checkpoints/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "BEGAN_64_64_0321.tfmod"
2 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | checkpoint_path = 'checkpoints'
2 | checkpoint_prefix = 'BEGAN_64_64'
3 | data_path = 'datasets/CelebA_64_64.h5'
4 |
--------------------------------------------------------------------------------
/datasets/download_celebA.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Akash Rana (github.com/akash9182)
3 |
4 | Modification of
5 | - https://github.com/carpedm20/DCGAN-tensorflow/blob/master/download.py
6 | - http://stackoverflow.com/a/39225039
7 | -
8 | """
9 |
10 | import os
11 | import zipfile
12 | import requests
13 | import subprocess
14 | from tqdm import tqdm
15 | from collections import OrderedDict
16 |
17 | base_path = './'
18 |
19 | def download_file_from_google_drive(id, destination):
20 | URL = "https://docs.google.com/uc?export=download"
21 | session = requests.Session()
22 |
23 | response = session.get(URL, params={ 'id': id }, stream=True)
24 | token = get_confirm_token(response)
25 |
26 | if token:
27 | params = { 'id' : id, 'confirm' : token }
28 | response = session.get(URL, params=params, stream=True)
29 |
30 | save_response_content(response, destination)
31 |
32 | def get_confirm_token(response):
33 | for key, value in response.cookies.items():
34 | if key.startswith('download_warning'):
35 | return value
36 | return None
37 |
38 | def save_response_content(response, destination, chunk_size=32*1024):
39 |     total_chunks = int(response.headers.get('content-length', 0)) // chunk_size
40 |     with open(destination, "wb") as f:
41 |         for chunk in tqdm(response.iter_content(chunk_size), total=total_chunks,
42 |                           unit='chunk', desc=destination):
43 |             if chunk:  # filter out keep-alive new chunks
44 |                 f.write(chunk)
45 |
46 | def unzip(filepath):
47 | print("Extracting: " + filepath)
48 | base_path = os.path.dirname(filepath)
49 | with zipfile.ZipFile(filepath) as zf:
50 | zf.extractall(base_path)
51 | os.remove(filepath)
52 |
53 | def download_celeb_a(base_path):
54 | data_path = os.path.join(base_path, 'CelebA')
55 | images_path = os.path.join(data_path, 'images')
56 | if os.path.exists(data_path):
57 | print('[!] Found Celeb-A - skip')
58 | return
59 |
60 | filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
61 | save_path = os.path.join(base_path, filename)
62 |
63 | if os.path.exists(save_path):
64 | print('[*] {} already exists'.format(save_path))
65 | else:
66 | download_file_from_google_drive(drive_id, save_path)
67 |
68 | zip_dir = ''
69 | with zipfile.ZipFile(save_path) as zf:
70 | zip_dir = zf.namelist()[0]
71 | zf.extractall(base_path)
72 | if not os.path.exists(data_path):
73 | os.mkdir(data_path)
74 | os.rename(os.path.join(base_path, "img_align_celeba"), images_path)
75 | os.remove(save_path)
76 |
77 |
78 |
79 | if __name__ == '__main__':
80 | download_celeb_a(base_path)
81 |
--------------------------------------------------------------------------------
/datasets/prepare_celeba.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | import os
3 | import numpy as np
4 | import h5py
5 | from tqdm import tqdm
6 | from scipy.misc import imread, imresize
7 |
8 | filenames = glob(os.path.join("img_align_celeba", "*.jpg"))
9 | filenames = np.sort(filenames)
10 | w, h = 64, 64 # Change this if you wish to use larger images
11 | data = np.zeros((len(filenames), w * h * 3), dtype=np.uint8)
12 |
13 | # This preprocessing is appropriate for CelebA but should be adapted
14 | # (or removed entirely) for other datasets.
15 |
16 |
17 | def get_image(image_path, w=64, h=64):
18 | im = imread(image_path).astype(np.float)
19 | orig_h, orig_w = im.shape[:2]
20 | new_h = int(orig_h * w / orig_w)
21 | im = imresize(im, (new_h, w))
22 | margin = int(round((new_h - h)/2))
23 | return im[margin:margin+h]
24 |
25 | for n, fname in tqdm(enumerate(filenames), total=len(filenames)):
26 | image = get_image(fname, w, h)
27 | data[n] = image.flatten()
28 |
29 | with h5py.File('datasets/celeba.h5', 'w') as f:
30 | f.create_dataset("images", data=data)
31 |
--------------------------------------------------------------------------------
/discriminator.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from utils.custom_ops import custom_fc, custom_conv2d
3 | from generator import decoder
4 |
5 |
6 | def began_discriminator(D_I, batch_size, num_filters, hidden_size, image_size,
7 | scope_name="discriminator", reuse_scope=False):
8 | '''
9 | Unlike most generative adversarial networks, the boundary
10 | equilibrium uses an autoencoder as a discriminator.
11 |
12 | For simplicity, the decoder architecture is the same as the generator.
13 |
14 | Downsampling is 3x3 convolutions with a stride of 2.
15 | Upsampling is 3x3 convolutions, with nearest neighbour resizing
16 | to the desired resolution.
17 |
18 | Args:
19 |         D_I: a batch of flattened images [batch_size, image_size * image_size * 3]
20 | batch_size: Batch size of encodings
21 | num_filters: Number of filters in convolutional layers
22 | hidden_size: Dimensionality of encoding
23 | image_size: First dimension of generated image (must be 64 or 128)
24 | scope_name: Tensorflow scope name
25 | reuse_scope: Tensorflow scope handling
26 | Returns:
27 | Flattened tensor of re-created images, with dimensionality:
28 | [batch_size, image_size * image_size * 3]
29 | '''
30 |
31 |
32 | with tf.variable_scope(scope_name) as scope:
33 | if reuse_scope:
34 | scope.reuse_variables()
35 |
36 | layer_1 = tf.reshape(D_I, [-1, image_size, image_size, 3]) # '-1' is batch size
37 |
38 | conv_0 = custom_conv2d(layer_1, 3, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec0')
39 | conv_0 = tf.nn.elu(conv_0)
40 |
41 |         conv_1 = custom_conv2d(conv_0, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec1')
42 | conv_1 = tf.nn.elu(conv_1)
43 |
44 |
45 |         conv_2 = custom_conv2d(conv_1, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec2')
46 | conv_2 = tf.nn.elu(conv_2)
47 |
48 |
49 | layer_2 = custom_conv2d(conv_2, 2 * num_filters, k_h=3, k_w=3, d_h=2, d_w=2, scope='el2')
50 | layer_2 = tf.nn.elu(layer_2)
51 |
52 | conv_3 = custom_conv2d(layer_2, 2 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec3')
53 | conv_3 = tf.nn.elu(conv_3)
54 |
55 |         conv_4 = custom_conv2d(conv_3, 2 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec4')
56 | conv_4 = tf.nn.elu(conv_4)
57 |
58 |
59 |         layer_3 = custom_conv2d(conv_4, 3 * num_filters, k_h=3, k_w=3, d_h=2, d_w=2, scope='el3')
60 | layer_3 = tf.nn.elu(layer_3)
61 |
62 | conv_5 = custom_conv2d(layer_3, 3 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec5')
63 | conv_5 = tf.nn.elu(conv_5)
64 |
65 |         conv_6 = custom_conv2d(conv_5, 3 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec6')
66 | conv_6 = tf.nn.elu(conv_6)
67 |
68 |
69 | layer_4 = custom_conv2d(conv_6, 4 * num_filters, k_h=3, k_w=3, d_h=2, d_w=2, scope='el4')
70 | layer_4 = tf.nn.elu(layer_4)
71 |
72 | conv_7 = custom_conv2d(layer_4, 4 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec7')
73 | conv_7 = tf.nn.elu(conv_7)
74 |
75 |         conv_8 = custom_conv2d(conv_7, 4 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec8')
76 | conv_8 = tf.nn.elu(conv_8)
77 |
78 | if image_size == 64:
79 | enc = custom_fc(conv_8, hidden_size, scope='enc')
80 | else:
81 | layer_5 = custom_conv2d(conv_8, 5 * num_filters, k_h=3, k_w=3, d_h=2, d_w=2, scope='el5')
82 | layer_5 = tf.nn.elu(layer_5)
83 |
84 | conv_9 = custom_conv2d(layer_5, 5 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec9')
85 | conv_9 = tf.nn.elu(conv_9)
86 |
87 |             conv_10 = custom_conv2d(conv_9, 5 * num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='ec10')
88 | conv_10 = tf.nn.elu(conv_10)
89 | enc = custom_fc(conv_10, hidden_size, scope='enc')
90 |
91 | # add elu before decoding?
92 | return decoder(enc, batch_size=batch_size, num_filters=num_filters,
93 | hidden_size=hidden_size, image_size=image_size)
94 |
--------------------------------------------------------------------------------
/generator.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from utils.custom_ops import custom_fc, custom_conv2d
3 |
6 |
7 |
8 | def decoder(Z, batch_size, num_filters, hidden_size, image_size):
9 | '''
10 | The Boundary Equilibrium GAN deliberately uses a simple generator
11 | architecture.
12 |
13 | Upsampling is 3x3 convolutions, with nearest neighbour resizing
14 | to the desired resolution.
15 |
16 | Args:
17 | Z: Latent space
18 | batch_size: Batch size of generations
19 | num_filters: Number of filters in convolutional layers
20 | hidden_size: Dimensionality of encoding
21 | image_size: First dimension of generated image (must be 64 or 128)
24 | Returns:
25 | Flattened tensor of generated images, with dimensionality:
26 | [batch_size, image_size * image_size * 3]
27 | '''
28 | layer_1 = custom_fc(Z, 8 * 8 * num_filters, scope='l1')
29 |
30 | layer_1 = tf.reshape(layer_1, [-1, 8, 8, num_filters]) # '-1' is batch size
31 |
32 | conv_1 = custom_conv2d(layer_1, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c1')
33 | conv_1 = tf.nn.elu(conv_1)
34 |
35 | conv_2 = custom_conv2d(conv_1, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c2')
36 | conv_2 = tf.nn.elu(conv_2)
37 |
38 | layer_2 = tf.image.resize_nearest_neighbor(conv_2, [16, 16])
39 |
40 | conv_3 = custom_conv2d(layer_2, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c3')
41 | conv_3 = tf.nn.elu(conv_3)
42 |
43 | conv_4 = custom_conv2d(conv_3, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c4')
44 | conv_4 = tf.nn.elu(conv_4)
45 |
46 | layer_3 = tf.image.resize_nearest_neighbor(conv_4, [32, 32])
47 |
48 | conv_5 = custom_conv2d(layer_3, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c5')
49 | conv_5 = tf.nn.elu(conv_5)
50 |
51 | conv_6 = custom_conv2d(conv_5, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c6')
52 | conv_6 = tf.nn.elu(conv_6)
53 |
54 | layer_4 = tf.image.resize_nearest_neighbor(conv_6, [64, 64])
55 |
56 | conv_7 = custom_conv2d(layer_4, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c7')
57 | conv_7 = tf.nn.elu(conv_7)
58 |
59 | conv_8 = custom_conv2d(conv_7, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c8')
60 | conv_8 = tf.nn.elu(conv_8)
61 |
62 | if image_size == 64:
63 | im = conv_8
64 | else:
65 | layer_5 = tf.image.resize_nearest_neighbor(conv_8, [128, 128])
66 |
67 | conv_9 = custom_conv2d(layer_5, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c9')
68 | conv_9 = tf.nn.elu(conv_9)
69 |
70 | conv_10 = custom_conv2d(conv_9, num_filters, k_h=3, k_w=3, d_h=1, d_w=1, scope='c10')
71 | im = tf.nn.elu(conv_10)
72 |
73 | im = custom_conv2d(im, 3, k_h=3, k_w=3, d_h=1, d_w=1, scope='im')
74 | im = tf.sigmoid(im)
75 | im = tf.reshape(im, [-1, image_size * image_size * 3])
76 | return im
77 |
78 | def began_generator(Z, batch_size, num_filters, hidden_size, image_size,
79 | scope_name="generator", reuse_scope=False):
80 | with tf.variable_scope(scope_name) as scope:
81 | if reuse_scope:
82 | scope.reuse_variables()
83 | return decoder(Z, batch_size, num_filters, hidden_size, image_size)
84 |
85 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from generator import began_generator as generator
3 | from discriminator import began_discriminator as discriminator
4 | from utils.misc import loadData, dataIterator
5 | import tqdm
6 | import numpy as np
7 | from utils.misc import plot_gens
8 | from config import checkpoint_path, checkpoint_prefix
9 |
10 |
11 | class BEGAN:
12 | loss_tracker = {'generator': [],
13 | 'discriminator': [],
14 | 'convergence_measure': []}
15 |
16 | def loss(D_real_in, D_real_out, D_gen_in, D_gen_out, k_t, gamma=0.75):
17 | '''
18 |         The Boundary Equilibrium GAN uses an approximation of the
19 |         Wasserstein loss between the distributions of pixel-wise
20 |         autoencoder loss, based on the discriminator performance on
21 |         real vs. generated data.
22 |
23 | This simplifies to reducing the L1 norm of the autoencoder loss:
24 | making the discriminator objective to perform well on real images
25 | and poorly on generated images; with the generator objective
26 | to create samples which the discriminator will perform well upon.
27 |
28 | args:
29 | D_real_in: input to discriminator with real sample.
30 | D_real_out: output from discriminator with real sample.
31 | D_gen_in: input to discriminator with generated sample.
32 | D_gen_out: output from discriminator with generated sample.
33 | k_t: weighting parameter which constantly updates during training
34 |             gamma: diversity ratio, used to control model equilibrium.
35 | returns:
36 | D_loss: discriminator loss to minimise.
37 | G_loss: generator loss to minimise.
38 | k_tp: value of k_t for next train step.
39 | convergence_measure: measure of model convergence.
40 | '''
41 | def pixel_autoencoder_loss(out, inp):
42 | '''
43 | The autoencoder loss used is the L1 norm (note that this
44 | is based on the pixel-wise distribution of losses
45 | that the authors assert approximates the Normal distribution)
46 |
47 | args:
48 | out: discriminator output
49 | inp: discriminator input
50 | returns:
51 | L1 norm of pixel-wise loss
52 | '''
53 | eta = 1 # paper uses L1 norm
54 | diff = tf.abs(out - inp)
55 | if eta == 1:
56 | return tf.reduce_mean(diff)
57 | else:
58 | return tf.reduce_mean(tf.pow(diff, eta))
59 |
60 | mu_real = pixel_autoencoder_loss(D_real_out, D_real_in)
61 | mu_gen = pixel_autoencoder_loss(D_gen_out, D_gen_in)
62 | D_loss = mu_real - k_t * mu_gen
63 | G_loss = mu_gen
64 | lam = 0.001 # 'learning rate' for k. Berthelot et al. use 0.001
65 | k_tp = k_t + lam * (gamma * mu_real - mu_gen)
66 | convergence_measure = mu_real + np.abs(gamma * mu_real - mu_gen)
67 | return D_loss, G_loss, k_tp, convergence_measure
68 |
69 | def run(x, batch_size, num_filters, hidden_size, image_size):
70 | Z = tf.random_normal((batch_size, hidden_size), 0, 1)
71 |
72 | x_tilde = generator(Z, batch_size=batch_size, num_filters=num_filters,
73 | hidden_size=hidden_size, image_size=image_size)
74 | x_tilde_d = discriminator(x_tilde, batch_size=batch_size, num_filters=num_filters,
75 | hidden_size=hidden_size, image_size=image_size)
76 |
77 | x_d = discriminator(x, reuse_scope=True, batch_size=batch_size, num_filters=num_filters,
78 | hidden_size=hidden_size, image_size=image_size)
79 |
80 | return x_tilde, x_tilde_d, x_d
81 |
82 | scopes = ['generator', 'discriminator']
83 |
84 |
85 | def began_train(num_images=50000, start_epoch=0, add_epochs=None, batch_size=16,
86 | hidden_size=64, image_size=64, gpu_id='/gpu:0',
87 | demo=False, get=False, start_learn_rate=1e-4, decay_every=100,
88 | save_every=1, batch_norm=True, gamma=0.75):
89 |
90 | num_epochs = start_epoch + add_epochs
91 | loss_tracker = BEGAN.loss_tracker
92 |
93 | graph = tf.Graph()
94 | with graph.as_default():
95 | global_step = tf.get_variable('global_step', [],
96 | initializer=tf.constant_initializer(0),
97 | trainable=False)
98 |
99 | with tf.device(gpu_id):
100 | learning_rate = tf.placeholder(tf.float32, shape=[])
101 | opt = tf.train.AdamOptimizer(learning_rate, epsilon=1.0)
102 |
103 | next_batch = tf.placeholder(tf.float32,
104 | [batch_size, image_size * image_size * 3])
105 |
106 | x_tilde, x_tilde_d, x_d = BEGAN.run(next_batch, batch_size=batch_size, num_filters=128,
107 | hidden_size=hidden_size, image_size=image_size)
108 |
109 | k_t = tf.get_variable('kt', [],
110 | initializer=tf.constant_initializer(0),
111 | trainable=False)
112 |             D_loss, G_loss, k_tp, convergence_measure = \
113 |                 BEGAN.loss(next_batch, x_d, x_tilde, x_tilde_d, k_t=k_t, gamma=gamma)
114 |
115 | params = tf.trainable_variables()
116 | tr_vars = {}
117 | for s in BEGAN.scopes:
118 | tr_vars[s] = [i for i in params if s in i.name]
119 |
120 | G_grad = opt.compute_gradients(G_loss,
121 | var_list=tr_vars['generator'])
122 |
123 | D_grad = opt.compute_gradients(D_loss,
124 | var_list=tr_vars['discriminator'])
125 |
126 | G_train = opt.apply_gradients(G_grad, global_step=global_step)
127 | D_train = opt.apply_gradients(D_grad, global_step=global_step)
128 |
129 | init = tf.global_variables_initializer()
130 | saver = tf.train.Saver()
131 | sess = tf.Session(graph=graph,
132 | config=tf.ConfigProto(allow_soft_placement=True,
133 | log_device_placement=True))
134 | sess.run(init)
135 | if start_epoch > 0:
136 | path = '{}/{}_{}.tfmod'.format(checkpoint_path,
137 | checkpoint_prefix,
138 | str(start_epoch-1).zfill(4))
139 |         saver.restore(sess, path)
140 |
141 | k_t_ = sess.run(k_t) # We initialise with k_t = 0 as in the paper.
142 | num_batches_per_epoch = num_images // batch_size
143 | for epoch in range(start_epoch, num_epochs):
144 | images = loadData(size=num_images)
145 |         print('Epoch {} / {}'.format(epoch + 1, num_epochs))
146 | for i in tqdm.tqdm(range(num_batches_per_epoch)):
147 | iter_ = dataIterator([images], batch_size)
148 |
149 | learning_rate_ = start_learn_rate * pow(0.5, epoch // decay_every)
150 | next_batch_ = next(iter_)
151 |
152 | _, _, D_loss_, G_loss_, k_t_, M_ = \
153 | sess.run([G_train, D_train, D_loss, G_loss, k_tp, convergence_measure],
154 | {learning_rate: learning_rate_,
155 | next_batch: next_batch_, k_t: min(max(k_t_, 0), 1)})
156 |
157 | loss_tracker['generator'].append(G_loss_)
158 | loss_tracker['discriminator'].append(D_loss_)
159 | loss_tracker['convergence_measure'].append(M_)
160 |
161 | if epoch % save_every == 0:
162 | path = '{}/{}_{}.tfmod'.format(checkpoint_path,
163 | checkpoint_prefix,
164 | str(epoch).zfill(4))
165 | saver.save(sess, path)
166 | if demo:
167 | batch = dataIterator([images], batch_size).__next__()
168 | ims = sess.run(x_tilde)
169 | plot_gens((ims, batch),
170 | ('Generated 64x64 samples.', 'Random training images.'),
171 | loss_tracker)
172 | if get:
173 | return ims
174 |
175 |
176 | if __name__ == '__main__':
177 | import argparse
178 | parser = argparse.ArgumentParser(description='Run BEGAN.')
179 |
180 | parser.add_argument('--gpuid', type=int, default=0,
181 | help='GPU ID to use (-1 for CPU)')
182 |
183 | parser.add_argument('--save-every', type=int, default=5,
184 | help='Frequency to save checkpoint (in epochs)')
185 |
186 | parser.add_argument('--start-epoch', type=int, default=0, required=True,
187 |                         help='Start epoch (0 to begin training from scratch, '
188 | + 'N to restore from checkpoint N)')
189 |
190 | parser.add_argument('--add-epochs', type=int, default=100, required=True,
191 | help='Number of epochs to train'
192 |                              + ' (-1 to train indefinitely)')
193 |
194 | parser.add_argument('--num-images', type=int, default=2000,
195 | help='Number of images to load into RAM at once')
196 |
197 | parser.add_argument('--gamma', type=float, default=0.75,
198 | help='Diversity ratio (read paper for more info)')
199 |
200 | parser.add_argument('--start-learn-rate', type=float, default=1e-5,
201 | help='Starting learn rate')
202 |
203 | parser.add_argument('--train', type=int, default=1,
204 | help='"1" to train; "0" to run and'
205 |                              + ' return output')
206 |
207 | parser.add_argument('--batch_size', type=int, default=16,
208 | help='Batch size for training (default 16'
209 |                              + ' as in paper)')
210 |
211 | parser.add_argument('--hidden_size', type=int, default=64,
212 | help='Dimensionality of the discriminator encoding.'
213 |                              + ' (paper does not specify a value, so we use a guess)')
214 |
215 | parser.add_argument('--batch-norm', type=int, default=1,
216 | help='Set to "0" to disable batch normalisation')
217 |
218 | parser.add_argument('--decay-every', type=int, default=-1,
219 | help='Number of epochs before learning rate decay'
220 |                              + ' (set to -1 to disable)')
221 |
222 | parser.add_argument('--outdir', type=str, default='output',
223 | help='Path to save output generations')
224 |
225 | parser.add_argument('--image-size', type=int, default=64,
226 | help='Image size (must be 64 or 128)')
227 |
228 | args = parser.parse_args()
229 | if args.gpuid == -1:
230 | args.gpuid = '/cpu:0'
231 | else:
232 | args.gpuid = '/gpu:{}'.format(args.gpuid)
233 |
234 | if args.decay_every == -1:
235 | args.decay_every = np.inf
236 |
237 | if args.train:
238 | demo = False
239 | get = False
240 | else:
241 | demo = True
242 | get = True
243 | args.add_epochs = 0
244 |
245 | im = began_train(start_epoch=args.start_epoch, add_epochs=args.add_epochs,
246 | batch_size=args.batch_size, hidden_size=args.hidden_size,
247 | gpu_id=args.gpuid, demo=demo, get=get,
248 | image_size=args.image_size,
249 | save_every=args.save_every, decay_every=args.decay_every,
250 | batch_norm=args.batch_norm, num_images=args.num_images,
251 |                      start_learn_rate=args.start_learn_rate, gamma=args.gamma)
252 |
253 | if not args.train:
254 | import matplotlib.pyplot as plt
255 | for n in range(8):
256 | im_to_save = im[n].reshape([args.image_size, args.image_size, 3])
257 | plt.imsave(args.outdir+'/out_{}.jpg'.format(n),
258 | im_to_save)
259 |
--------------------------------------------------------------------------------
/readme/conv_measure_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/conv_measure_vis.png
--------------------------------------------------------------------------------
/readme/eq_autoencoder_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/eq_autoencoder_loss.png
--------------------------------------------------------------------------------
/readme/eq_conv_measure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/eq_conv_measure.png
--------------------------------------------------------------------------------
/readme/eq_gamma.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/eq_gamma.png
--------------------------------------------------------------------------------
/readme/eq_global.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/eq_global.png
--------------------------------------------------------------------------------
/readme/eq_losses.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/eq_losses.png
--------------------------------------------------------------------------------
/readme/eq_objective.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/eq_objective.png
--------------------------------------------------------------------------------
/readme/generated_from_Z.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artcg/BEGAN/d6992dabca55f1d4c92161f9cd4dcc7722a8bbf5/readme/generated_from_Z.png
--------------------------------------------------------------------------------
/utils/custom_ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 |
6 | def leaky_rectify(x, leakiness=0.01):
7 | assert leakiness <= 1
8 | ret = tf.maximum(x, leakiness * x)
9 | return ret
10 |
11 |
12 | def custom_conv2d(input_layer, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, in_dim=None,
13 | padding='SAME', scope="conv2d"):
14 | with tf.variable_scope(scope):
15 | w = tf.get_variable('w', [k_h, k_w, in_dim or input_layer.shape[-1], output_dim],
16 | initializer=tf.truncated_normal_initializer(stddev=stddev))
17 | conv = tf.nn.conv2d(input_layer, w,
18 | strides=[1, d_h, d_w, 1], padding=padding)
19 |         b = tf.get_variable("b", shape=[output_dim], initializer=tf.constant_initializer(0.))
20 | conv = tf.nn.bias_add(conv, b)
21 | return conv
22 |
23 |
24 |
25 | def custom_fc(input_layer, output_size, scope='Linear',
26 | in_dim=None, stddev=0.02, bias_start=0.0):
27 | shape = input_layer.shape
28 | if len(shape) > 2:
29 | input_layer = tf.reshape(input_layer, [-1, int(np.prod(shape[1:]))])
30 | shape = input_layer.shape
31 | with tf.variable_scope(scope):
32 | matrix = tf.get_variable("weight",
33 | [in_dim or shape[1], output_size],
34 | dtype=tf.float32,
35 | initializer=tf.random_normal_initializer(stddev=stddev))
36 | bias = tf.get_variable("bias", [output_size], initializer=tf.constant_initializer(bias_start))
37 | return tf.nn.bias_add(tf.matmul(input_layer, matrix), bias)
38 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import scipy.misc
4 | from glob import glob
5 | from config import data_path
6 |
7 |
8 | def loadData(size):
9 | import h5py
10 | with h5py.File(data_path, 'r') as hf:
11 | faces = hf['images']
12 | choice = np.random.choice(len(faces), size, replace=False)
13 | faces = faces[sorted(choice)]
14 | faces = np.array(faces, dtype=np.float16)
15 | return faces / 255
16 |
17 |
18 | def loadJPGs(path='/home/arthur/devel/input/', width=64, height=64):
19 | filenames = glob(path+"*.jpg")
20 | filenames = np.sort(filenames)
21 |
22 | def imread(path):
23 | return scipy.misc.imread(path)
24 |
25 | def scaleHeight(x, height=64):
26 | h, w = x.shape[:2]
27 | return scipy.misc.imresize(x, [height, int((float(w)/h)*height)])
28 |
29 | def cropSides(x, width=64):
30 | w = x.shape[1]
31 | j = int(round((w - width)/2.))
32 | return x[:, j:j+width, :]
33 |
34 | def get_image(image_path, width=64, height=64):
35 | return cropSides(scaleHeight(imread(image_path), height=height),
36 | width=width)
37 |
38 | images = np.zeros((len(filenames), width * height * 3), dtype=np.uint8)
39 |
40 | for n, i in enumerate(filenames):
41 | im = get_image(i)
42 | images[n] = im.flatten()
43 | images = np.array(images, dtype=np.float16)
44 | return images / 255
45 |
46 |
47 | def dataIterator(data, batch_size):
48 | '''
49 | From great jupyter notebook by Tim Sainburg:
50 | http://github.com/timsainb/Tensorflow-MultiGPU-VAE-GAN
51 | '''
52 | batch_idx = 0
53 | while True:
54 | length = len(data[0])
55 | assert all(len(i) == length for i in data)
56 | idxs = np.arange(0, length)
57 | np.random.shuffle(idxs)
58 | for batch_idx in range(0, length, batch_size):
59 | cur_idxs = idxs[batch_idx:batch_idx + batch_size]
60 | images_batch = data[0][cur_idxs]
61 | # images_batch = images_batch.astype("float32")
62 | yield images_batch
63 |
64 |
65 | def create_image(im):
66 | '''
67 | From great jupyter notebook by Tim Sainburg:
68 | http://github.com/timsainb/Tensorflow-MultiGPU-VAE-GAN
69 | '''
70 | d1 = int(np.sqrt((np.product(im.shape) / 3)))
71 | im = np.array(im, dtype=np.float32)
72 | return np.reshape(im, (d1, d1, 3))
73 |
74 |
75 | def plot_gens(images, rowlabels, losses):
76 | '''
77 | From great jupyter notebook by Tim Sainburg:
78 | http://github.com/timsainb/Tensorflow-MultiGPU-VAE-GAN
79 | '''
80 | examples = 8
81 | fig, ax = plt.subplots(nrows=len(images), ncols=examples, figsize=(18, 8))
82 | for i in range(examples):
83 | for j in range(len(images)):
84 | ax[(j, i)].imshow(create_image(images[j][i]), cmap=plt.cm.gray,
85 | interpolation='nearest')
86 | ax[(j, i)].axis('off')
87 | title = ''
88 | for i in rowlabels:
89 | title += ' {}, '.format(i)
90 | fig.suptitle('Top to Bottom: {}'.format(title))
91 | plt.show()
92 | #fig.savefig(''.join(['imgs/test_',str(epoch).zfill(4),'.png']),dpi=100)
93 |     fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 10), linewidth=4)
94 |
95 | D_plt, = plt.semilogy((losses['discriminator']), linewidth=4, ls='-',
96 | color='b', alpha=.5, label='D')
97 | G_plt, = plt.semilogy((losses['generator']), linewidth=4, ls='-',
98 | color='k', alpha=.5, label='G')
99 |
100 | plt.gca()
101 | leg = plt.legend(handles=[D_plt, G_plt],
102 | fontsize=20)
103 | leg.get_frame().set_alpha(0.5)
104 | plt.show()
105 |
--------------------------------------------------------------------------------