├── .gitignore
├── README.md
├── aae_decoder.h5
├── aae_encoder.h5
├── images
│   ├── aae_grid.png
│   ├── aae_latent.png
│   ├── regular_grid.png
│   └── regular_latent.png
├── keras-aae.py
├── regular_decoder.h5
└── regular_encoder.h5
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # keras-aae
2 |
3 | Reproduces the Adversarial Autoencoder architecture from [Makhzani, Alireza, et al. "Adversarial autoencoders." arXiv preprint arXiv:1511.05644 (2015)](https://arxiv.org/abs/1511.05644) with Keras.
4 |
5 | ## Summary
6 |
7 | The Adversarial Autoencoder behaves similarly to a [Variational Autoencoder](https://arxiv.org/abs/1312.6114): it forces the latent space of an autoencoder to follow a predefined prior. The difference is that the prior can be any distribution that is easy to sample from, because the match is enforced adversarially: samples drawn from the prior and latent codes produced by the encoder are both fed to a discriminator, and the encoder is trained to fool it. A condensed sketch of the training loop follows the figures below.
8 |
9 |
10 | ![AAE latent space](images/aae_latent.png) ![Regular autoencoder latent space](images/regular_latent.png)
11 |
12 |
13 |
14 | *The left image shows the latent space of the unseen MNIST test set after training an Adversarial Autoencoder with a 2D Gaussian prior for 50 epochs. Contrast this with the latent space of a regular Autoencoder trained under the same conditions (right), which is far more irregular.*
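For reference, the per-batch training scheme implemented in `keras-aae.py` boils down to three updates. The sketch below is a condensed, illustrative version rather than a copy of the script (the script itself calls `fit` with `epochs=1` instead of `train_on_batch`, and draws the prior as a Gaussian with standard deviation 5):

```python
# Condensed sketch of one training step (illustrative; see keras-aae.py for
# the exact loop). `autoencoder`, `discriminator`, `generator` and `encoder`
# are the compiled Keras models returned by create_model().
import numpy as np

def adversarial_step(samples, autoencoder, discriminator, generator, encoder,
                     latent_dim=2, prior_std=5.0):
    batch_size = len(samples)

    # 1. Reconstruction: train encoder + decoder to reproduce the input.
    autoencoder.train_on_batch(samples, samples)

    # 2. Discriminator: label prior samples 1 ("real") and encoder outputs 0 ("fake").
    fake_latent = encoder.predict(samples)
    real_latent = np.random.randn(batch_size, latent_dim) * prior_std
    latent_batch = np.concatenate([fake_latent, real_latent])
    labels = np.concatenate([np.zeros((batch_size, 1)), np.ones((batch_size, 1))])
    discriminator.train_on_batch(latent_batch, labels)

    # 3. Regularization ("generator"): with the discriminator frozen, train the
    #    encoder so that its codes are classified as coming from the prior.
    generator.train_on_batch(samples, np.ones((batch_size, 1)))
```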
15 |
16 | ## Instructions
17 |
18 | To train a model, just run
19 |
20 | ```
21 | $ python keras-aae.py --train
22 | ```
23 |
24 | For the full list of parameters, run with the `--help` flag.
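The same script also handles evaluation and visualization. For example (flag names as defined in `keras-aae.py`; these modes expect the saved `aae_encoder.h5` / `aae_decoder.h5` weights in the working directory):

```
$ python keras-aae.py --plot                             # scatter plot of the test-set latent space
$ python keras-aae.py --reconstruct                      # show a test digit next to its reconstruction
$ python keras-aae.py --generate --latent_vec=1.0,-1.5   # decode a specific 2D latent vector
$ python keras-aae.py --generate_grid                    # decode a 20x20 grid over the latent plane
```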
25 |
26 | For comparison with a regular autoencoder, run
27 |
28 | ```
29 | $ python keras-aae.py --train --noadversarial
30 | ```
--------------------------------------------------------------------------------
/aae_decoder.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/aae_decoder.h5
--------------------------------------------------------------------------------
/aae_encoder.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/aae_encoder.h5
--------------------------------------------------------------------------------
/images/aae_grid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/images/aae_grid.png
--------------------------------------------------------------------------------
/images/aae_latent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/images/aae_latent.png
--------------------------------------------------------------------------------
/images/regular_grid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/images/regular_grid.png
--------------------------------------------------------------------------------
/images/regular_latent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/images/regular_latent.png
--------------------------------------------------------------------------------
/keras-aae.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | try:
3 |     raw_input
4 | except:
5 |     raw_input = input
6 |
7 | import numpy as np
8 | from keras.models import Sequential, Model, load_model
9 | from keras.layers import Input, Dense
10 | from keras.utils import plot_model
11 | from keras.datasets import mnist
12 | from keras.optimizers import Adam
13 | import argparse
14 | import matplotlib.pyplot as plt
15 | from matplotlib import gridspec, colors
16 | from datetime import datetime
17 | from sklearn.manifold import TSNE
18 | from absl import flags
19 | from absl import app
20 |
21 |
22 | FLAGS = flags.FLAGS
23 |
24 | # General
25 | flags.DEFINE_bool("adversarial", True, "Use Adversarial Autoencoder or regular Autoencoder")
26 | flags.DEFINE_bool("train", False, "Train")
27 | flags.DEFINE_bool("reconstruct", False, "Reconstruct image")
28 | flags.DEFINE_bool("generate", False, "Generate image from latent")
29 | flags.DEFINE_bool("generate_grid", False, "Generate grid of images from latent space (only for 2D latent)")
30 | flags.DEFINE_bool("plot", False, "Plot latent space")
31 | flags.DEFINE_integer("latent_dim", 2, "Latent dimension")
32 |
33 | # Train
34 | flags.DEFINE_integer("epochs", 50, "Number of training epochs")
35 | flags.DEFINE_integer("train_samples", 10000, "Number of training samples from MNIST")
36 | flags.DEFINE_integer("batchsize", 100, "Training batchsize")
37 |
38 | # Test
39 | flags.DEFINE_integer("test_samples", 10000, "Number of test samples from MNIST")
40 | flags.DEFINE_list("latent_vec", None, "Latent vector (use with --generate flag)")
41 |
42 |
43 | def create_model(input_dim, latent_dim, verbose=False, save_graph=False):
44 |
45 |     autoencoder_input = Input(shape=(input_dim,))
46 |     generator_input = Input(shape=(input_dim,))
47 |
48 |     encoder = Sequential()
49 |     encoder.add(Dense(1000, input_shape=(input_dim,), activation='relu'))
50 |     encoder.add(Dense(1000, activation='relu'))
51 |     encoder.add(Dense(latent_dim, activation=None))
52 |
53 |     decoder = Sequential()
54 |     decoder.add(Dense(1000, input_shape=(latent_dim,), activation='relu'))
55 |     decoder.add(Dense(1000, activation='relu'))
56 |     decoder.add(Dense(input_dim, activation='sigmoid'))
57 |
58 |     if FLAGS.adversarial:
59 |         discriminator = Sequential()
60 |         discriminator.add(Dense(1000, input_shape=(latent_dim,), activation='relu'))
61 |         discriminator.add(Dense(1000, activation='relu'))
62 |         discriminator.add(Dense(1, activation='sigmoid'))
63 |
64 |     autoencoder = Model(autoencoder_input, decoder(encoder(autoencoder_input)))
65 |     autoencoder.compile(optimizer=Adam(lr=1e-4), loss="mean_squared_error")
66 |
67 |     if FLAGS.adversarial:
68 |         discriminator.compile(optimizer=Adam(lr=1e-4), loss="binary_crossentropy")
69 |         discriminator.trainable = False
70 |         generator = Model(generator_input, discriminator(encoder(generator_input)))
71 |         generator.compile(optimizer=Adam(lr=1e-4), loss="binary_crossentropy")
72 |
73 |     if verbose:
74 |         print("Autoencoder Architecture")
75 |         print(autoencoder.summary())
76 |         if FLAGS.adversarial:
77 |             print("Discriminator Architecture")
78 |             print(discriminator.summary())
79 |             print("Generator Architecture")
80 |             print(generator.summary())
81 |
82 |     if save_graph:
83 |         plot_model(autoencoder, to_file="autoencoder_graph.png")
84 |         if FLAGS.adversarial:
85 |             plot_model(discriminator, to_file="discriminator_graph.png")
86 |             plot_model(generator, to_file="generator_graph.png")
87 |
88 |     if FLAGS.adversarial:
89 |         return autoencoder, discriminator, generator, encoder, decoder
90 |     else:
91 |         return autoencoder, None, None, encoder, decoder
92 |
93 | def train(n_samples, batch_size, n_epochs):
94 |     autoencoder, discriminator, generator, encoder, decoder = create_model(input_dim=784, latent_dim=FLAGS.latent_dim)
95 |     (x_train, y_train), (x_test, y_test) = mnist.load_data()
96 |     # Get n_samples/10 samples from each class
97 |     x_classes = {}
98 |     y_classes = {}
99 |     for i in np.arange(10):
100 |         x_classes[i] = x_train[np.where(y_train == i), :, :][0][:int(n_samples / 10), :, :]
101 |         y_classes[i] = np.ones(int(n_samples / 10)) * i
102 |     x = np.concatenate((list(x_classes.values())))
103 |     y = np.concatenate((list(y_classes.values())))
104 |     x = x.reshape(-1, 784)
105 |     normalize = colors.Normalize(0., 255.)
106 |     x = normalize(x)
107 |
108 |     rand_x = np.random.RandomState(42)
109 |     rand_y = np.random.RandomState(42)
110 |
111 |     past = datetime.now()
112 |     for epoch in np.arange(1, n_epochs + 1):
113 |         autoencoder_losses = []
114 |         if FLAGS.adversarial:
115 |             discriminator_losses = []
116 |             generator_losses = []
117 |         rand_x.shuffle(x)
118 |         rand_y.shuffle(y)
119 |         for batch in np.arange(len(x) / batch_size):
120 |             start = int(batch * batch_size)
121 |             end = int(start + batch_size)
122 |             samples = x[start:end]
123 |             autoencoder_history = autoencoder.fit(x=samples, y=samples, epochs=1, batch_size=batch_size, validation_split=0.0, verbose=0)
124 |             if FLAGS.adversarial:
125 |                 fake_latent = encoder.predict(samples)
126 |                 discriminator_input = np.concatenate((fake_latent, np.random.randn(batch_size, FLAGS.latent_dim) * 5.))
127 |                 discriminator_labels = np.concatenate((np.zeros((batch_size, 1)), np.ones((batch_size, 1))))
128 |                 discriminator_history = discriminator.fit(x=discriminator_input, y=discriminator_labels, epochs=1, batch_size=batch_size, validation_split=0.0, verbose=0)
129 |                 generator_history = generator.fit(x=samples, y=np.ones((batch_size, 1)), epochs=1, batch_size=batch_size, validation_split=0.0, verbose=0)
130 |
131 |             autoencoder_losses.append(autoencoder_history.history["loss"])
132 |             if FLAGS.adversarial:
133 |                 discriminator_losses.append(discriminator_history.history["loss"])
134 |                 generator_losses.append(generator_history.history["loss"])
135 |         now = datetime.now()
136 |         print("\nEpoch {}/{} - {:.1f}s".format(epoch, n_epochs, (now - past).total_seconds()))
137 |         print("Autoencoder Loss: {}".format(np.mean(autoencoder_losses)))
138 |         if FLAGS.adversarial:
139 |             print("Discriminator Loss: {}".format(np.mean(discriminator_losses)))
140 |             print("Generator Loss: {}".format(np.mean(generator_losses)))
141 |         past = now
142 |
143 |         if epoch % 50 == 0:
144 |             print("\nSaving models...")
145 |             # autoencoder.save('{}_autoencoder.h5'.format(desc))
146 |             encoder.save('{}_encoder.h5'.format(desc))
147 |             decoder.save('{}_decoder.h5'.format(desc))
148 |             # if FLAGS.adversarial:
149 |             #     discriminator.save('{}_discriminator.h5'.format(desc))
150 |             #     generator.save('{}_generator.h5'.format(desc))
151 |
152 |     # autoencoder.save('{}_autoencoder.h5'.format(desc))
153 |     encoder.save('{}_encoder.h5'.format(desc))
154 |     decoder.save('{}_decoder.h5'.format(desc))
155 |     # if FLAGS.adversarial:
156 |     #     discriminator.save('{}_discriminator.h5'.format(desc))
157 |     #     generator.save('{}_generator.h5'.format(desc))
158 |
159 | def reconstruct(n_samples):
160 |     encoder = load_model('{}_encoder.h5'.format(desc))
161 |     decoder = load_model('{}_decoder.h5'.format(desc))
162 |     (x_train, y_train), (x_test, y_test) = mnist.load_data()
163 |     choice = np.random.choice(np.arange(n_samples))
164 |     original = x_test[choice].reshape(1, 784)
165 |     normalize = colors.Normalize(0., 255.)
166 |     original = normalize(original)
167 |     latent = encoder.predict(original)
168 |     reconstruction = decoder.predict(latent)
169 |     draw([{"title": "Original", "image": original}, {"title": "Reconstruction", "image": reconstruction}])
170 |
171 | def generate(latent=None):
172 |     decoder = load_model('{}_decoder.h5'.format(desc))
173 |     if latent is None:
174 |         latent = np.random.randn(1, FLAGS.latent_dim)
175 |     else:
176 |         latent = np.array(latent, dtype=np.float64)  # --latent_vec values arrive as strings
177 |     sample = decoder.predict(latent.reshape(1, FLAGS.latent_dim))
178 |     draw([{"title": "Sample", "image": sample}])
179 |
180 | def draw(samples):
181 |     fig = plt.figure(figsize=(5 * len(samples), 5))
182 |     gs = gridspec.GridSpec(1, len(samples))
183 |     for i, sample in enumerate(samples):
184 |         ax = plt.Subplot(fig, gs[i])
185 |         ax.imshow((sample["image"] * 255.).reshape(28, 28), cmap='gray')
186 |         ax.set_xticks([])
187 |         ax.set_yticks([])
188 |         ax.set_aspect('equal')
189 |         ax.set_title(sample["title"])
190 |         fig.add_subplot(ax)
191 |     plt.show(block=False)
192 |     raw_input("Press Enter to Exit")
193 |
194 | def generate_grid(latent=None):
195 |     decoder = load_model('{}_decoder.h5'.format(desc))
196 |     samples = []
197 |     for i in np.arange(400):
198 |         latent = np.array([(i % 20) * 1.5 - 15., 15. - (i // 20) * 1.5])  # column, row coordinates in the latent plane
199 |         samples.append({
200 |             "image": decoder.predict(latent.reshape(1, FLAGS.latent_dim))
201 |         })
202 |     draw_grid(samples)
203 |
204 | def draw_grid(samples):
205 |     fig = plt.figure(figsize=(15, 15))
206 |     gs = gridspec.GridSpec(20, 20, wspace=-.5, hspace=0)
207 |     for i, sample in enumerate(samples):
208 |         ax = plt.Subplot(fig, gs[i])
209 |         ax.imshow((sample["image"] * 255.).reshape(28, 28), cmap='gray')
210 |         ax.set_xticks([])
211 |         ax.set_yticks([])
212 |         ax.set_aspect('equal')
213 |         # ax.set_title(sample["title"])
214 |         fig.add_subplot(ax)
215 |     plt.show(block=False)
216 |     raw_input("Press Enter to Exit")
217 |     # fig.savefig("images/{}_grid.png".format(desc), bbox_inches="tight", dpi=300)
218 |
219 | def plot(n_samples):
220 |     encoder = load_model('{}_encoder.h5'.format(desc))
221 |     (x_train, y_train), (x_test, y_test) = mnist.load_data()
222 |     x = x_test[:n_samples].reshape(n_samples, 784)
223 |     y = y_test[:n_samples]
224 |     normalize = colors.Normalize(0., 255.)
225 |     x = normalize(x)
226 |     latent = encoder.predict(x)
227 |     if FLAGS.latent_dim > 2:
228 |         tsne = TSNE()
229 |         print("\nFitting t-SNE, this will take a while...")
230 |         latent = tsne.fit_transform(latent)
231 |     fig, ax = plt.subplots()
232 |     for label in np.arange(10):
233 |         ax.scatter(latent[(y == label), 0], latent[(y == label), 1], label=label, s=3)
234 |     ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
235 |     ax.set_aspect('equal')
236 |     ax.set_title("Latent Space")
237 |     plt.show(block=False)
238 |     raw_input("Press Enter to Exit")
239 |     # fig.savefig("images/{}_latent.png".format(desc), bbox_inches="tight", dpi=300)
240 |
241 | def main(argv):
242 |     global desc
243 |     if FLAGS.adversarial:
244 |         desc = "aae"
245 |     else:
246 |         desc = "regular"
247 |     if FLAGS.train:
248 |         train(n_samples=FLAGS.train_samples, batch_size=FLAGS.batchsize, n_epochs=FLAGS.epochs)
249 |     elif FLAGS.reconstruct:
250 |         reconstruct(n_samples=FLAGS.test_samples)
251 |     elif FLAGS.generate:
252 |         if FLAGS.latent_vec:
253 |             assert len(FLAGS.latent_vec) == FLAGS.latent_dim, "Latent vector provided is of dim {}; required dim is {}".format(len(FLAGS.latent_vec), FLAGS.latent_dim)
254 |             generate(FLAGS.latent_vec)
255 |         else:
256 |             generate()
257 |     elif FLAGS.generate_grid:
258 |         generate_grid()
259 |     elif FLAGS.plot:
260 |         plot(FLAGS.test_samples)
261 |
262 |
263 | if __name__ == "__main__":
264 |     app.run(main)
265 |
--------------------------------------------------------------------------------
/regular_decoder.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/regular_decoder.h5
--------------------------------------------------------------------------------
/regular_encoder.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greentfrapp/keras-aae/758d0e72b6d49a8d34579f153ad9aee24ebac282/regular_encoder.h5
--------------------------------------------------------------------------------