├── .gitignore
├── README.md
├── aae_decoder.h5
├── aae_encoder.h5
├── images
│   ├── aae_grid.png
│   ├── aae_latent.png
│   ├── regular_grid.png
│   └── regular_latent.png
├── keras-aae.py
├── regular_decoder.h5
└── regular_encoder.h5

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# keras-aae

Reproduces the Adversarial Autoencoder architecture from [Makhzani, Alireza, et al. "Adversarial autoencoders." arXiv preprint arXiv:1511.05644 (2015)](https://arxiv.org/abs/1511.05644) in Keras.

## Summary

Like a [Variational Autoencoder](https://arxiv.org/abs/1312.6114), the Adversarial Autoencoder forces the latent space of an autoencoder to follow a predefined prior. In the Adversarial Autoencoder, however, the prior can be almost any distribution that is easy to sample from, because prior samples only need to be fed to a discriminator network rather than evaluated in closed form.

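Because the encoder is pushed to match that prior, a latent code drawn from the prior should decode into a plausible digit. Here is a minimal sketch using the pretrained weights bundled in this repository (it assumes `aae_decoder.h5` was trained with the default 2-dimensional latent space, as in `keras-aae.py`):

```
import numpy as np
import matplotlib.pyplot as plt
from keras.models import load_model

decoder = load_model("aae_decoder.h5")  # pretrained decoder shipped in this repository
z = np.random.randn(1, 2) * 5.          # draw a latent code from the scaled 2D Gaussian prior used in training
digit = decoder.predict(z)              # flattened 28x28 image with values in [0, 1]
plt.imshow(digit.reshape(28, 28), cmap="gray")
plt.show()
```

The `--generate` and `--generate_grid` flags of `keras-aae.py` offer the same functionality from the command line.
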
<div align="center">
<img src="images/aae_latent.png" alt="Latent space from Adversarial Autoencoder">
<img src="images/regular_latent.png" alt="Latent space from regular Autoencoder">
</div>

*The left image shows the latent space of the unseen MNIST test set after training an Adversarial Autoencoder for 50 epochs with a 2D Gaussian prior; the encoded test set closely follows the prior. The right image shows the latent space of a regular Autoencoder trained under the same conditions, which is far more irregular.*

## Instructions

To train a model, just run

```
$ python keras-aae.py --train
```

For more parameters, run with the `--help` flag.

For comparison with a regular autoencoder, run

```
$ python keras-aae.py --train --noadversarial
```

--------------------------------------------------------------------------------
/aae_decoder.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/aae_decoder.h5

--------------------------------------------------------------------------------
/aae_encoder.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/aae_encoder.h5

--------------------------------------------------------------------------------
/images/aae_grid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/images/aae_grid.png

--------------------------------------------------------------------------------
/images/aae_latent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/images/aae_latent.png

--------------------------------------------------------------------------------
/images/regular_grid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/images/regular_grid.png

--------------------------------------------------------------------------------
/images/regular_latent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/images/regular_latent.png

--------------------------------------------------------------------------------
/keras-aae.py:
--------------------------------------------------------------------------------
from __future__ import print_function
try:
    raw_input
except:
    raw_input = input

import numpy as np
from keras.models import Sequential, Model, load_model
from keras.layers import Input, Dense
from keras.utils import plot_model
from keras.datasets import mnist
from keras.optimizers import Adam
import argparse
import matplotlib.pyplot as plt
from matplotlib import gridspec, colors
from datetime import datetime
from sklearn.manifold import TSNE
from absl import flags
from absl import app


FLAGS = flags.FLAGS

# General
flags.DEFINE_bool("adversarial", True, "Use Adversarial Autoencoder or regular Autoencoder")
flags.DEFINE_bool("train", False, "Train")
flags.DEFINE_bool("reconstruct", False, "Reconstruct image")
flags.DEFINE_bool("generate", False, "Generate image from latent")
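# Only one action flag (--train, --reconstruct, --generate, --generate_grid, --plot) is
# handled per run; see the if/elif dispatch in main() at the bottom of this file.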
flags.DEFINE_bool("generate_grid", False, "Generate grid of images from latent space (only for 2D latent)")
flags.DEFINE_bool("plot", False, "Plot latent space")
flags.DEFINE_integer("latent_dim", 2, "Latent dimension")

# Train
flags.DEFINE_integer("epochs", 50, "Number of training epochs")
flags.DEFINE_integer("train_samples", 10000, "Number of training samples from MNIST")
flags.DEFINE_integer("batchsize", 100, "Training batchsize")

# Test
flags.DEFINE_integer("test_samples", 10000, "Number of test samples from MNIST")
flags.DEFINE_list("latent_vec", None, "Latent vector (use with --generate flag)")


def create_model(input_dim, latent_dim, verbose=False, save_graph=False):

    autoencoder_input = Input(shape=(input_dim,))
    generator_input = Input(shape=(input_dim,))

    encoder = Sequential()
    encoder.add(Dense(1000, input_shape=(input_dim,), activation='relu'))
    encoder.add(Dense(1000, activation='relu'))
    encoder.add(Dense(latent_dim, activation=None))

    decoder = Sequential()
    decoder.add(Dense(1000, input_shape=(latent_dim,), activation='relu'))
    decoder.add(Dense(1000, activation='relu'))
    decoder.add(Dense(input_dim, activation='sigmoid'))

    if FLAGS.adversarial:
        discriminator = Sequential()
        discriminator.add(Dense(1000, input_shape=(latent_dim,), activation='relu'))
        discriminator.add(Dense(1000, activation='relu'))
        discriminator.add(Dense(1, activation='sigmoid'))

    autoencoder = Model(autoencoder_input, decoder(encoder(autoencoder_input)))
    autoencoder.compile(optimizer=Adam(lr=1e-4), loss="mean_squared_error")

    if FLAGS.adversarial:
        discriminator.compile(optimizer=Adam(lr=1e-4), loss="binary_crossentropy")
        # Freeze the discriminator inside the encoder+discriminator ("generator") model;
        # the standalone discriminator compiled above remains trainable.
        discriminator.trainable = False
        generator = Model(generator_input, discriminator(encoder(generator_input)))
        generator.compile(optimizer=Adam(lr=1e-4), loss="binary_crossentropy")

    if verbose:
        print("Autoencoder Architecture")
        print(autoencoder.summary())
        if FLAGS.adversarial:
            print("Discriminator Architecture")
            print(discriminator.summary())
            print("Generator Architecture")
            print(generator.summary())

    if save_graph:
        plot_model(autoencoder, to_file="autoencoder_graph.png")
        if FLAGS.adversarial:
            plot_model(discriminator, to_file="discriminator_graph.png")
            plot_model(generator, to_file="generator_graph.png")

    if FLAGS.adversarial:
        return autoencoder, discriminator, generator, encoder, decoder
    else:
        return autoencoder, None, None, encoder, decoder

def train(n_samples, batch_size, n_epochs):
    autoencoder, discriminator, generator, encoder, decoder = create_model(input_dim=784, latent_dim=FLAGS.latent_dim)
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Get n_samples/10 samples from each class
    x_classes = {}
    y_classes = {}
    for i in np.arange(10):
        x_classes[i] = x_train[np.where(y_train == i), :, :][0][:int(n_samples / 10), :, :]
        y_classes[i] = np.ones(int(n_samples / 10)) * i
    x = np.concatenate((list(x_classes.values())))
    y = np.concatenate((list(y_classes.values())))
    x = x.reshape(-1, 784)
    normalize = colors.Normalize(0., 255.)
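    # matplotlib's Normalize(0., 255.) rescales the raw MNIST pixel values from [0, 255] to [0, 1],
    # i.e. the same as dividing by 255.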
    x = normalize(x)

    rand_x = np.random.RandomState(42)
    rand_y = np.random.RandomState(42)

    past = datetime.now()
    for epoch in np.arange(1, n_epochs + 1):
        autoencoder_losses = []
        if FLAGS.adversarial:
            discriminator_losses = []
            generator_losses = []
        # Identical seeds keep images and labels aligned after shuffling
        rand_x.shuffle(x)
        rand_y.shuffle(y)
        for batch in np.arange(len(x) / batch_size):
            start = int(batch * batch_size)
            end = int(start + batch_size)
            samples = x[start:end]
            autoencoder_history = autoencoder.fit(x=samples, y=samples, epochs=1, batch_size=batch_size, validation_split=0.0, verbose=0)
            if FLAGS.adversarial:
                fake_latent = encoder.predict(samples)
                # Encoder outputs are labelled 0 ("fake"), samples from the scaled Gaussian prior are labelled 1 ("real")
                discriminator_input = np.concatenate((fake_latent, np.random.randn(batch_size, FLAGS.latent_dim) * 5.))
                discriminator_labels = np.concatenate((np.zeros((batch_size, 1)), np.ones((batch_size, 1))))
                discriminator_history = discriminator.fit(x=discriminator_input, y=discriminator_labels, epochs=1, batch_size=batch_size, validation_split=0.0, verbose=0)
                generator_history = generator.fit(x=samples, y=np.ones((batch_size, 1)), epochs=1, batch_size=batch_size, validation_split=0.0, verbose=0)

            autoencoder_losses.append(autoencoder_history.history["loss"])
            if FLAGS.adversarial:
                discriminator_losses.append(discriminator_history.history["loss"])
                generator_losses.append(generator_history.history["loss"])
        now = datetime.now()
        print("\nEpoch {}/{} - {:.1f}s".format(epoch, n_epochs, (now - past).total_seconds()))
        print("Autoencoder Loss: {}".format(np.mean(autoencoder_losses)))
        if FLAGS.adversarial:
            print("Discriminator Loss: {}".format(np.mean(discriminator_losses)))
            print("Generator Loss: {}".format(np.mean(generator_losses)))
        past = now

        if epoch % 50 == 0:
            print("\nSaving models...")
            # autoencoder.save('{}_autoencoder.h5'.format(desc))
            encoder.save('{}_encoder.h5'.format(desc))
            decoder.save('{}_decoder.h5'.format(desc))
            # if FLAGS.adversarial:
            #     discriminator.save('{}_discriminator.h5'.format(desc))
            #     generator.save('{}_generator.h5'.format(desc))

    # autoencoder.save('{}_autoencoder.h5'.format(desc))
    encoder.save('{}_encoder.h5'.format(desc))
    decoder.save('{}_decoder.h5'.format(desc))
    # if FLAGS.adversarial:
    #     discriminator.save('{}_discriminator.h5'.format(desc))
    #     generator.save('{}_generator.h5'.format(desc))

def reconstruct(n_samples):
    encoder = load_model('{}_encoder.h5'.format(desc))
    decoder = load_model('{}_decoder.h5'.format(desc))
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    choice = np.random.choice(np.arange(n_samples))
    original = x_test[choice].reshape(1, 784)
    normalize = colors.Normalize(0., 255.)
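    # Rescale the chosen test image to [0, 1], matching the preprocessing used during training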
    original = normalize(original)
    latent = encoder.predict(original)
    reconstruction = decoder.predict(latent)
    draw([{"title": "Original", "image": original}, {"title": "Reconstruction", "image": reconstruction}])

def generate(latent=None):
    decoder = load_model('{}_decoder.h5'.format(desc))
    if latent is None:
        latent = np.random.randn(1, FLAGS.latent_dim)
    else:
        # --latent_vec is parsed as a list of strings; cast to floats before decoding
        latent = np.array(latent, dtype=float)
    sample = decoder.predict(latent.reshape(1, FLAGS.latent_dim))
    draw([{"title": "Sample", "image": sample}])

def draw(samples):
    fig = plt.figure(figsize=(5 * len(samples), 5))
    gs = gridspec.GridSpec(1, len(samples))
    for i, sample in enumerate(samples):
        ax = plt.Subplot(fig, gs[i])
        ax.imshow((sample["image"] * 255.).reshape(28, 28), cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('equal')
        ax.set_title(sample["title"])
        fig.add_subplot(ax)
    plt.show(block=False)
    raw_input("Press Enter to Exit")

def generate_grid(latent=None):
    decoder = load_model('{}_decoder.h5'.format(desc))
    samples = []
    for i in np.arange(400):
        # Integer division for the row index keeps the 20x20 grid regular under Python 3
        latent = np.array([(i % 20) * 1.5 - 15., 15. - (i // 20) * 1.5])
        samples.append({
            "image": decoder.predict(latent.reshape(1, FLAGS.latent_dim))
        })
    draw_grid(samples)

def draw_grid(samples):
    fig = plt.figure(figsize=(15, 15))
    gs = gridspec.GridSpec(20, 20, wspace=-.5, hspace=0)
    for i, sample in enumerate(samples):
        ax = plt.Subplot(fig, gs[i])
        ax.imshow((sample["image"] * 255.).reshape(28, 28), cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('equal')
        # ax.set_title(sample["title"])
        fig.add_subplot(ax)
    plt.show(block=False)
    raw_input("Press Enter to Exit")
    # fig.savefig("images/{}_grid.png".format(desc), bbox_inches="tight", dpi=300)

def plot(n_samples):
    encoder = load_model('{}_encoder.h5'.format(desc))
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x = x_test[:n_samples].reshape(n_samples, 784)
    y = y_test[:n_samples]
    normalize = colors.Normalize(0., 255.)
    x = normalize(x)
    latent = encoder.predict(x)
    if FLAGS.latent_dim > 2:
        tsne = TSNE()
        print("\nFitting t-SNE, this will take a while...")
        latent = tsne.fit_transform(latent)
    fig, ax = plt.subplots()
    for label in np.arange(10):
        # Index with the truncated label array `y` so the mask length matches `latent`
        ax.scatter(latent[(y == label), 0], latent[(y == label), 1], label=label, s=3)
    ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
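    # With latent_dim > 2 the plotted axes are t-SNE components rather than raw latent coordinates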
    ax.set_aspect('equal')
    ax.set_title("Latent Space")
    plt.show(block=False)
    raw_input("Press Enter to Exit")
    # fig.savefig("images/{}_latent.png".format(desc), bbox_inches="tight", dpi=300)

def main(argv):
    global desc
    if FLAGS.adversarial:
        desc = "aae"
    else:
        desc = "regular"
    if FLAGS.train:
        train(n_samples=FLAGS.train_samples, batch_size=FLAGS.batchsize, n_epochs=FLAGS.epochs)
    elif FLAGS.reconstruct:
        reconstruct(n_samples=FLAGS.test_samples)
    elif FLAGS.generate:
        if FLAGS.latent_vec:
            assert len(FLAGS.latent_vec) == FLAGS.latent_dim, "Latent vector provided is of dim {}; required dim is {}".format(len(FLAGS.latent_vec), FLAGS.latent_dim)
            generate(FLAGS.latent_vec)
        else:
            generate()
    elif FLAGS.generate_grid:
        generate_grid()
    elif FLAGS.plot:
        plot(FLAGS.test_samples)


if __name__ == "__main__":
    app.run(main)

--------------------------------------------------------------------------------
/regular_decoder.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/regular_decoder.h5

--------------------------------------------------------------------------------
/regular_encoder.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/regular_encoder.h5

--------------------------------------------------------------------------------