├── .pylintrc
├── .gitignore
├── README.md
├── requirements.txt
└── gan.py

/.pylintrc:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.png
images
gan_olivetti.py
env
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Simple implementation of a Generative Adversarial Network (GAN) using Keras
A simple Keras implementation of a generative adversarial network, as described in https://arxiv.org/abs/1406.2661

# How to run this repo
If you have virtualenv installed, run the following commands from the repo folder:

```
virtualenv env

source env/bin/activate

pip install -r requirements.txt
```

Then you can run the code with:
```
python gan.py
```

If you don't have virtualenv installed, you can install it with:

```
pip install virtualenv
```
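
Alternatively, if you would rather drive training from your own script than run `gan.py` directly, here is a minimal sketch that simply mirrors the `__main__` block of `gan.py` (MNIST, default hyperparameters):

```
# minimal usage sketch, mirroring the __main__ block of gan.py
import numpy as np
from keras.datasets import mnist

from gan import GAN

(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5  # rescale pixels to [-1, 1]
X_train = np.expand_dims(X_train, axis=3)                # add a channel dimension

gan = GAN(width=28, height=28, channels=1)
gan.train(X_train, epochs=20000, batch=32, save_interval=100)
```
Generated samples are written to `./images/mnist_<step>.png` every `save_interval` iterations.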
# Medium article
See the companion article on Medium: https://medium.com/@mattiaspinelli/simple-generative-adversarial-network-gans-with-keras-1fe578e44a87
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==0.1.10
backports.functools-lru-cache==1.4
backports.shutil-get-terminal-size==1.0.0
backports.weakref==1.0.post1
bleach==1.5.0
cycler==0.10.0
decorator==4.2.1
enum34==1.1.6
funcsigs==1.0.2
futures==3.2.0
html5lib==0.9999999
ipython==5.5.0
ipython-genutils==0.2.0
Keras==2.1.2
Markdown==2.6.11
matplotlib==2.1.2
mock==2.0.0
numpy==1.14.0
pathlib2==2.3.0
pbr==3.1.1
pexpect==4.3.1
pickleshare==0.7.4
prompt-toolkit==1.0.15
protobuf==3.5.1
ptyprocess==0.5.2
Pygments==2.2.0
pyparsing==2.2.0
python-dateutil==2.6.1
pytz==2017.3
PyYAML==3.12
scandir==1.6
scipy==1.0.0
simplegeneric==0.8.1
six==1.11.0
subprocess32==3.2.7
tensorflow==1.5.0
tensorflow-tensorboard==1.5.0
traitlets==4.3.2
wcwidth==0.1.7
Werkzeug==0.14.1
--------------------------------------------------------------------------------
/gan.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
""" Simple implementation of a Generative Adversarial Network (GAN) """
import os
import numpy as np

from keras.datasets import mnist
from keras.layers import Dense, Reshape, Flatten
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam

import matplotlib.pyplot as plt
plt.switch_backend('agg')   # allows code to run without a system DISPLAY


class GAN(object):
    """ Generative Adversarial Network class """

    def __init__(self, width=28, height=28, channels=1):

        self.width = width
        self.height = height
        self.channels = channels

        self.shape = (self.width, self.height, self.channels)

        self.optimizer = Adam(lr=0.0002, beta_1=0.5, decay=8e-8)

        self.G = self.__generator()
        self.G.compile(loss='binary_crossentropy', optimizer=self.optimizer)

        self.D = self.__discriminator()
        self.D.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])

        self.stacked_generator_discriminator = self.__stacked_generator_discriminator()
        self.stacked_generator_discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer)

    def __generator(self):
        """ Declare generator """

        model = Sequential()
        model.add(Dense(256, input_shape=(100,)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(self.width * self.height * self.channels, activation='tanh'))
        model.add(Reshape((self.width, self.height, self.channels)))

        return model

    def __discriminator(self):
        """ Declare discriminator """

        model = Sequential()
        model.add(Flatten(input_shape=self.shape))
        model.add(Dense(self.width * self.height * self.channels))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense((self.width * self.height * self.channels) // 2))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        return model

    def __stacked_generator_discriminator(self):
        """ Stack the generator and the (frozen) discriminator for the generator update """

        self.D.trainable = False

        model = Sequential()
        model.add(self.G)
        model.add(self.D)

        return model

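    # Training scheme, per iteration of train():
    #   1. Discriminator step: sample half a batch of real images and generate half a
    #      batch with G, label them 1 (real) / 0 (fake), and train D on the combined batch.
    #   2. Generator step: train the stacked G+D model (D frozen via trainable=False above)
    #      on fresh noise with labels flipped to 1, so gradients push G towards samples
    #      that D classifies as real.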
    def train(self, X_train, epochs=20000, batch=32, save_interval=100):

        for cnt in range(epochs):

            # train discriminator on half a batch of real and half a batch of generated images
            half_batch = batch // 2
            random_index = np.random.randint(0, len(X_train) - half_batch)
            legit_images = X_train[random_index: random_index + half_batch].reshape(half_batch, self.width, self.height, self.channels)

            gen_noise = np.random.normal(0, 1, (half_batch, 100))
            synthetic_images = self.G.predict(gen_noise)

            x_combined_batch = np.concatenate((legit_images, synthetic_images))
            y_combined_batch = np.concatenate((np.ones((half_batch, 1)), np.zeros((half_batch, 1))))

            d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)

            # train generator through the stacked model, labelling the noise as "real"
            noise = np.random.normal(0, 1, (batch, 100))
            y_mislabeled = np.ones((batch, 1))

            g_loss = self.stacked_generator_discriminator.train_on_batch(noise, y_mislabeled)

            print('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f]' % (cnt, d_loss[0], g_loss))

            if cnt % save_interval == 0:
                self.plot_images(save2file=True, step=cnt)

    def plot_images(self, save2file=False, samples=16, step=0):
        """ Plot generated images """
        if not os.path.exists("./images"):
            os.makedirs("./images")
        filename = "./images/mnist_%d.png" % step
        noise = np.random.normal(0, 1, (samples, 100))

        images = self.G.predict(noise)

        plt.figure(figsize=(10, 10))

        for i in range(images.shape[0]):
            plt.subplot(4, 4, i + 1)
            image = images[i, :, :, :]
            image = np.reshape(image, [self.height, self.width])
            plt.imshow(image, cmap='gray')
            plt.axis('off')
        plt.tight_layout()

        if save2file:
            plt.savefig(filename)
            plt.close('all')
        else:
            plt.show()


if __name__ == '__main__':
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)

    gan = GAN()
    gan.train(X_train)
--------------------------------------------------------------------------------