├── .gitignore ├── .travis.yml ├── LICENSE.txt ├── README.rst ├── doc └── images │ ├── aae-epoch-099.png │ ├── bigan-epoch-000.png │ ├── bigan-epoch-099.png │ ├── gan-cifar10-epoch-000.png │ ├── gan-cifar10-epoch-099.png │ ├── gan-epoch-000.png │ └── gan-epoch-099.png ├── examples ├── cifar10_utils.py ├── example_aae.py ├── example_aae_cifar10.py ├── example_bigan.py ├── example_bigan_unrolled.py ├── example_gan.py ├── example_gan_cifar10.py ├── example_gan_convolutional.py ├── example_gan_unrolled.py ├── example_gan_unrolled_hinge.py ├── example_rock_paper_scissors.py ├── image_utils.py └── mnist_utils.py ├── keras_adversarial ├── __init__.py ├── adversarial_model.py ├── adversarial_optimizers.py ├── adversarial_utils.py ├── backend │ ├── __init__.py │ ├── tensorflow_backend.py │ ├── tensorflow_monkeypatch.py │ └── theano_backend.py ├── image_grid.py ├── image_grid_callback.py ├── legacy.py └── unrolled_optimizer.py ├── pytest.ini ├── requirements.txt ├── setup.cfg ├── setup.py └── tests └── integration └── gan_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | /output 91 | /examples/output 92 | /.idea 93 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | dist: trusty 3 | language: python 4 | matrix: 5 | include: 6 | - python: 2.7 7 | env: KERAS_BACKEND=theano TEST_MODE=PEP8 8 | - python: 2.7 9 | env: KERAS_BACKEND=theano LEGACY_KERAS=1 10 | - python: 3.5 11 | env: KERAS_BACKEND=theano LEGACY_KERAS=1 12 | - python: 2.7 13 | env: KERAS_BACKEND=theano 14 | - python: 3.5 15 | env: KERAS_BACKEND=theano 16 | - python: 2.7 17 | env: KERAS_BACKEND=tensorflow LEGACY_KERAS=1 18 | - python: 3.5 19 | env: KERAS_BACKEND=tensorflow LEGACY_KERAS=1 20 | - python: 2.7 21 | env: KERAS_BACKEND=tensorflow 22 | - python: 3.5 23 | env: KERAS_BACKEND=tensorflow 24 | install: 25 | # code below is taken from http://conda.pydata.org/docs/travis.html 26 | # We do this 
conditionally because it saves us some downloading if the 27 | # version is the same. 28 | - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then 29 | wget https://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh; 30 | else 31 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 32 | fi 33 | - bash miniconda.sh -b -p $HOME/miniconda 34 | - export PATH="$HOME/miniconda/bin:$PATH" 35 | - hash -r 36 | - conda config --set always_yes yes --set changeps1 no 37 | - conda update -q conda 38 | # Useful for debugging any issues with conda 39 | - conda info -a 40 | 41 | - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy scipy matplotlib pandas pytest h5py 42 | - source activate test-environment 43 | - pip install git+git://github.com/Theano/Theano.git 44 | - pip install pytest-pep8 45 | # install PIL for preprocessing tests 46 | - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then 47 | conda install pil; 48 | elif [[ "$TRAVIS_PYTHON_VERSION" == "3.5" ]]; then 49 | conda install Pillow; 50 | fi 51 | 52 | - pip install -e .[tests] 53 | 54 | # install TensorFlow 55 | - pip install tensorflow 56 | - pip install tqdm 57 | - if [[ "$LEGACY_KERAS" == "1" ]]; then 58 | pip install keras==1.2.2; 59 | else 60 | pip install keras; 61 | fi 62 | 63 | # command to run tests 64 | script: 65 | # run keras backend init to initialize backend config 66 | - python -c "import keras.backend" 67 | # create dataset directory to avoid concurrent directory creation at runtime 68 | - mkdir ~/.keras/datasets 69 | # set up keras backend 70 | - sed -i -e 's/"backend":[[:space:]]*"[^"]*/"backend":\ "'$KERAS_BACKEND'/g' ~/.keras/keras.json; 71 | - echo -e "Running tests with the following config:\n$(cat ~/.keras/keras.json)" 72 | - if [[ "$TEST_MODE" == "PEP8" ]]; then 73 | PYTHONPATH=$PWD:$PYTHONPATH py.test --pep8 -m pep8; 74 | else 75 | PYTHONPATH=$PWD:$PYTHONPATH py.test tests/; 76 | fi 77 | after_success: 78 | - coveralls 79 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2017 Benjamin Striner 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit 6 | persons to whom the Software is furnished to do so, subject to the following conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in all copies or substantial portions 9 | of the Software. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE 12 | WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 13 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 14 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /README.rst: --------------------------------------------------------------------------------
1 | Keras Adversarial Models
2 | ========================
3 | 
4 | **Combine multiple models into a single Keras model. GANs made easy!**
5 | 
6 | ``AdversarialModel`` simulates multi-player games. A single call to
7 | ``model.fit`` takes targets for each player and updates all of the
8 | players. Use ``AdversarialOptimizer`` for complete control of whether
9 | updates are simultaneous, alternating, or something else entirely. No
10 | more fooling with ``Trainable`` either!
11 | 
12 | Installation
13 | ------------
14 | 
15 | .. code:: shell
16 | 
17 |     git clone https://github.com/bstriner/keras_adversarial.git
18 |     cd keras_adversarial
19 |     python setup.py install
20 | 
21 | Usage
22 | -----
23 | 
24 | Please check the examples folder for example usage.
25 | 
26 | Instantiating an adversarial model
27 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
28 | 
29 | - Build separate models for each component / player, such as the generator
30 |   and discriminator.
31 | - Build a combined model. For a GAN, this might have inputs for
32 |   images and for noise, and outputs for D(fake) and for
33 |   D(real).
34 | - Pass the combined model and the separate models to the
35 |   ``AdversarialModel`` constructor.
36 | 
37 | .. code:: python
38 | 
39 |     adversarial_model = AdversarialModel(base_model=gan,
40 |                                          player_params=[generator.trainable_weights, discriminator.trainable_weights],
41 |                                          player_names=["generator", "discriminator"])
42 | 
43 | The resulting model will have the same inputs as ``gan`` but separate
44 | targets and metrics for each player. This is accomplished by copying the
45 | model for each player. If each player has a different model, use
46 | ``player_models`` (see below regarding dropout).
47 | 
48 | .. code:: python
49 | 
50 |     adversarial_model = AdversarialModel(player_models=[gan_g, gan_d],
51 |                                          player_params=[generator.trainable_weights, discriminator.trainable_weights],
52 |                                          player_names=["generator", "discriminator"])
53 | 
54 | Compiling an adversarial model
55 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
56 | 
57 | Use ``adversarial_compile`` to compile the model. The parameters are an
58 | ``AdversarialOptimizer`` and a list of ``Optimizer`` objects, one for each
59 | player. The loss is passed to ``model.compile`` for each model, so it may
60 | be a dictionary or other object. Use the same order for
61 | ``player_optimizers`` as you did for ``player_params`` and
62 | ``player_names``.
63 | 
64 | .. code:: python
65 | 
66 |     model.adversarial_compile(adversarial_optimizer=adversarial_optimizer,
67 |                               player_optimizers=[Adam(1e-4, decay=1e-4), Adam(1e-3, decay=1e-4)],
68 |                               loss='binary_crossentropy')
69 | 
70 | Training a simple adversarial model
71 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
72 | 
73 | Adversarial models can be trained using ``fit`` and callbacks just like
74 | any other Keras model. Just make sure to provide the correct targets in
75 | the correct order.
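For a standard two-player GAN, the bundled ``gan_targets`` helper builds the
label arrays in the expected order, so a minimal training call looks like this
(a sketch; ``xtrain`` and ``xtest`` stand in for your data):

.. code:: python

    from keras_adversarial import gan_targets

    # gan_targets(n) returns one label array per player per output:
    # [generator_fake=1, generator_real=0, discriminator_fake=0, discriminator_real=1]
    y = gan_targets(xtrain.shape[0])
    ytest = gan_targets(xtest.shape[0])
    model.fit(x=xtrain, y=y, validation_data=(xtest, ytest),
              nb_epoch=100, batch_size=32)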
76 | 
77 | For example, given a simple GAN named ``gan``:
78 | 
79 | - Inputs: ``[x]``
80 | - Targets: ``[y_fake, y_real]``
81 | - Metrics: ``[loss, loss_y_fake, loss_y_real]``
82 | 
83 | ``AdversarialModel(base_model=gan, player_names=['g', 'd']...)`` will have:
84 | 
85 | - Inputs: ``[x]``
86 | - Targets: ``[g_y_fake, g_y_real, d_y_fake, d_y_real]``
87 | - Metrics: ``[loss, g_loss, g_loss_y_fake, g_loss_y_real, d_loss, d_loss_y_fake, d_loss_y_real]``
88 | 
89 | Adversarial Optimizers
90 | ----------------------
91 | 
92 | There are many possible strategies for optimizing multiplayer games.
93 | ``AdversarialOptimizer`` is a base class that abstracts those strategies
94 | and is responsible for creating the training function.
95 | 
96 | - ``AdversarialOptimizerSimultaneous`` updates each player simultaneously on each batch.
97 | - ``AdversarialOptimizerAlternating`` updates the players in round-robin order:
98 |   each batch is run through each player's model in turn, so every player is trained on every batch.
99 | - ``AdversarialOptimizerScheduled`` passes each batch to a different player according to a schedule.
100 |   ``[1,1,0]`` would mean train player 1 on batches 0,1,3,4,6,7,etc. and train player 0 on batches 2,5,8,etc.
101 | - ``UnrolledAdversarialOptimizer`` unrolls updates to stabilize training
102 |   (only tested in Theano; slow to build the graph, but it runs reasonably fast).
103 | 
104 | Examples
105 | --------
106 | 
107 | MNIST Generative Adversarial Network (GAN)
108 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
109 | 
110 | `example\_gan.py <https://github.com/bstriner/keras_adversarial/blob/master/examples/example_gan.py>`__
111 | shows how to create a GAN in Keras for the MNIST dataset.
112 | 
113 | .. figure:: https://github.com/bstriner/keras_adversarial/raw/master/doc/images/gan-epoch-099.png
114 |    :alt: Example GAN
115 | 
116 |    Example GAN
117 | 
118 | CIFAR10 Generative Adversarial Network (GAN)
119 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
120 | 
121 | `example\_gan\_cifar10.py <https://github.com/bstriner/keras_adversarial/blob/master/examples/example_gan_cifar10.py>`__
122 | shows how to create a GAN in Keras for the CIFAR10 dataset.
123 | 
124 | .. figure:: https://github.com/bstriner/keras_adversarial/raw/master/doc/images/gan-cifar10-epoch-099.png
125 |    :alt: Example GAN
126 | 
127 |    Example GAN
128 | 
129 | MNIST Bi-Directional Generative Adversarial Network (BiGAN)
130 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
131 | 
132 | `example\_bigan.py <https://github.com/bstriner/keras_adversarial/blob/master/examples/example_bigan.py>`__
133 | shows how to create a BiGAN in Keras.
134 | 
135 | .. figure:: https://github.com/bstriner/keras_adversarial/raw/master/doc/images/bigan-epoch-099.png
136 |    :alt: Example BiGAN
137 | 
138 |    Example BiGAN
139 | 
140 | MNIST Adversarial Autoencoder (AAE)
141 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
142 | 
143 | An AAE is like a cross between a GAN and a Variational Autoencoder
144 | (VAE).
145 | `example\_aae.py <https://github.com/bstriner/keras_adversarial/blob/master/examples/example_aae.py>`__
146 | shows how to create an AAE in Keras.
147 | 
148 | .. figure:: https://github.com/bstriner/keras_adversarial/raw/master/doc/images/aae-epoch-099.png
149 |    :alt: Example AAE
150 | 
151 |    Example AAE
152 | 
153 | Unrolled Generative Adversarial Network
154 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
155 | 
156 | `example\_gan\_unrolled.py <https://github.com/bstriner/keras_adversarial/blob/master/examples/example_gan_unrolled.py>`__
157 | shows how to use the unrolled optimizer.
158 | 
159 | WARNING: Unrolling the discriminator 8 times takes about 6 hours to
160 | build the function on my computer, but only a few minutes per epoch of
161 | training. Be prepared to let it run a long time, or turn the depth down
162 | to around 4.
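Usage mirrors the other adversarial optimizers; this sketch follows the example
script:

.. code:: python

    from keras_adversarial.unrolled_optimizer import UnrolledAdversarialOptimizer

    # unroll each player's updates 4 steps when training the other player
    model.adversarial_compile(adversarial_optimizer=UnrolledAdversarialOptimizer(depth_g=4, depth_d=4),
                              player_optimizers=[Adam(1e-4, decay=1e-4), Adam(1e-3, decay=1e-4)],
                              loss='binary_crossentropy')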
163 | 
164 | Notes
165 | -----
166 | 
167 | Dropout
168 | ~~~~~~~
169 | 
170 | When training adversarial models using dropout, you may want to create
171 | separate models for each player.
172 | 
173 | If you want to train a discriminator with dropout, but train the
174 | generator against the discriminator without dropout, create two models:
175 | - GAN to train the generator: ``D(G(z, dropout=0.5), dropout=0)``
176 | - GAN to train the discriminator: ``D(G(z, dropout=0), dropout=0.5)``
177 | 
178 | If you create separate models, use the ``player_models`` parameter of the
179 | ``AdversarialModel`` constructor.
180 | 
181 | If you aren't using dropout, one model is sufficient: use the
182 | ``base_model`` parameter of the ``AdversarialModel`` constructor, which will
183 | duplicate the ``base_model`` for each player.
184 | 
185 | Theano and TensorFlow
186 | ~~~~~~~~~~~~~~~~~~~~~
187 | 
188 | I do most of my development in Theano but try to test TensorFlow when I
189 | have extra time. The goal is to support both. Please let me know about any
190 | issues you have with either backend.
191 | 
192 | Questions?
193 | ~~~~~~~~~~
194 | 
195 | Feel free to start an issue or a PR here or in Keras if you are having
196 | any issues or think of something that might be useful.
197 | 
-------------------------------------------------------------------------------- /doc/images/aae-epoch-099.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bstriner/keras-adversarial/6651cfad771f72521c78a5cc3a23a2313efeaa88/doc/images/aae-epoch-099.png
-------------------------------------------------------------------------------- /doc/images/bigan-epoch-000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bstriner/keras-adversarial/6651cfad771f72521c78a5cc3a23a2313efeaa88/doc/images/bigan-epoch-000.png
-------------------------------------------------------------------------------- /doc/images/bigan-epoch-099.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bstriner/keras-adversarial/6651cfad771f72521c78a5cc3a23a2313efeaa88/doc/images/bigan-epoch-099.png
-------------------------------------------------------------------------------- /doc/images/gan-cifar10-epoch-000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bstriner/keras-adversarial/6651cfad771f72521c78a5cc3a23a2313efeaa88/doc/images/gan-cifar10-epoch-000.png
-------------------------------------------------------------------------------- /doc/images/gan-cifar10-epoch-099.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bstriner/keras-adversarial/6651cfad771f72521c78a5cc3a23a2313efeaa88/doc/images/gan-cifar10-epoch-099.png
-------------------------------------------------------------------------------- /doc/images/gan-epoch-000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bstriner/keras-adversarial/6651cfad771f72521c78a5cc3a23a2313efeaa88/doc/images/gan-epoch-000.png
-------------------------------------------------------------------------------- /doc/images/gan-epoch-099.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/bstriner/keras-adversarial/6651cfad771f72521c78a5cc3a23a2313efeaa88/doc/images/gan-epoch-099.png -------------------------------------------------------------------------------- /examples/cifar10_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras.datasets import cifar10 3 | 4 | 5 | def cifar10_process(x): 6 | x = x.astype(np.float32) / 255.0 7 | return x 8 | 9 | 10 | def cifar10_data(): 11 | (xtrain, ytrain), (xtest, ytest) = cifar10.load_data() 12 | return cifar10_process(xtrain), cifar10_process(xtest) 13 | -------------------------------------------------------------------------------- /examples/example_aae.py: -------------------------------------------------------------------------------- 1 | # import os 2 | # os.environ["THEANO_FLAGS"] = "mode=FAST_COMPILE,device=cpu,floatX=float32" 3 | 4 | import matplotlib as mpl 5 | 6 | # This line allows mpl to run with no DISPLAY defined 7 | mpl.use('Agg') 8 | 9 | from keras.layers import Dense, Reshape, Flatten, Input, merge 10 | from keras.models import Sequential, Model 11 | from keras.optimizers import Adam 12 | from keras_adversarial.legacy import l1l2 13 | import keras.backend as K 14 | import pandas as pd 15 | import numpy as np 16 | from keras_adversarial.image_grid_callback import ImageGridCallback 17 | 18 | from keras_adversarial import AdversarialModel, fix_names, n_choice 19 | from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling 20 | from mnist_utils import mnist_data 21 | from keras.layers import LeakyReLU, Activation 22 | import os 23 | 24 | 25 | def model_generator(latent_dim, input_shape, hidden_dim=512, reg=lambda: l1l2(1e-7, 0)): 26 | return Sequential([ 27 | Dense(hidden_dim, name="generator_h1", input_dim=latent_dim, W_regularizer=reg()), 28 | LeakyReLU(0.2), 29 | Dense(hidden_dim, name="generator_h2", W_regularizer=reg()), 30 | LeakyReLU(0.2), 31 | Dense(np.prod(input_shape), name="generator_x_flat", W_regularizer=reg()), 32 | Activation('sigmoid'), 33 | Reshape(input_shape, name="generator_x")], 34 | name="generator") 35 | 36 | 37 | def model_encoder(latent_dim, input_shape, hidden_dim=512, reg=lambda: l1l2(1e-7, 0)): 38 | x = Input(input_shape, name="x") 39 | h = Flatten()(x) 40 | h = Dense(hidden_dim, name="encoder_h1", W_regularizer=reg())(h) 41 | h = LeakyReLU(0.2)(h) 42 | h = Dense(hidden_dim, name="encoder_h2", W_regularizer=reg())(h) 43 | h = LeakyReLU(0.2)(h) 44 | mu = Dense(latent_dim, name="encoder_mu", W_regularizer=reg())(h) 45 | log_sigma_sq = Dense(latent_dim, name="encoder_log_sigma_sq", W_regularizer=reg())(h) 46 | z = merge([mu, log_sigma_sq], mode=lambda p: p[0] + K.random_normal(K.shape(p[0])) * K.exp(p[1] / 2), 47 | output_shape=lambda p: p[0]) 48 | return Model(x, z, name="encoder") 49 | 50 | 51 | def model_discriminator(latent_dim, output_dim=1, hidden_dim=512, 52 | reg=lambda: l1l2(1e-7, 1e-7)): 53 | z = Input((latent_dim,)) 54 | h = z 55 | h = Dense(hidden_dim, name="discriminator_h1", W_regularizer=reg())(h) 56 | h = LeakyReLU(0.2)(h) 57 | h = Dense(hidden_dim, name="discriminator_h2", W_regularizer=reg())(h) 58 | h = LeakyReLU(0.2)(h) 59 | y = Dense(output_dim, name="discriminator_y", activation="sigmoid", W_regularizer=reg())(h) 60 | return Model(z, y) 61 | 62 | 63 | def example_aae(path, adversarial_optimizer): 64 | # z \in R^100 65 | latent_dim = 100 66 | # x \in R^{28x28} 67 | input_shape = (28, 28) 68 | 69 | # generator (z -> x) 70 | generator = 
model_generator(latent_dim, input_shape)
71 |     # encoder (x -> z)
72 |     encoder = model_encoder(latent_dim, input_shape)
73 |     # autoencoder (x -> x')
74 |     autoencoder = Model(encoder.inputs, generator(encoder(encoder.inputs)))
75 |     # discriminator (z -> y)
76 |     discriminator = model_discriminator(latent_dim)
77 | 
78 |     # assemble AAE
79 |     x = encoder.inputs[0]
80 |     z = encoder(x)
81 |     xpred = generator(z)
82 |     zreal = normal_latent_sampling((latent_dim,))(x)
83 |     yreal = discriminator(zreal)
84 |     yfake = discriminator(z)
85 |     aae = Model(x, fix_names([xpred, yfake, yreal], ["xpred", "yfake", "yreal"]))
86 | 
87 |     # print summary of models
88 |     generator.summary()
89 |     encoder.summary()
90 |     discriminator.summary()
91 |     autoencoder.summary()
92 | 
93 |     # build adversarial model
94 |     generative_params = generator.trainable_weights + encoder.trainable_weights
95 |     model = AdversarialModel(base_model=aae,
96 |                              player_params=[generative_params, discriminator.trainable_weights],
97 |                              player_names=["generator", "discriminator"])
98 |     model.adversarial_compile(adversarial_optimizer=adversarial_optimizer,
99 |                               player_optimizers=[Adam(1e-4, decay=1e-4), Adam(1e-3, decay=1e-4)],
100 |                               loss={"yfake": "binary_crossentropy", "yreal": "binary_crossentropy",
101 |                                     "xpred": "mean_squared_error"},
102 |                               player_compile_kwargs=[{"loss_weights": {"yfake": 1e-2, "yreal": 1e-2, "xpred": 1}}] * 2)
103 | 
104 |     # load mnist data
105 |     xtrain, xtest = mnist_data()
106 | 
107 |     # callback for image grid of generated samples
108 |     def generator_sampler():
109 |         zsamples = np.random.normal(size=(10 * 10, latent_dim))
110 |         return generator.predict(zsamples).reshape((10, 10, 28, 28))
111 | 
112 |     generator_cb = ImageGridCallback(os.path.join(path, "generated-epoch-{:03d}.png"), generator_sampler)
113 | 
114 |     # callback for image grid of autoencoded samples
115 |     def autoencoder_sampler():
116 |         xsamples = n_choice(xtest, 10)
117 |         xrep = np.repeat(xsamples, 9, axis=0)
118 |         xgen = autoencoder.predict(xrep).reshape((10, 9, 28, 28))
119 |         xsamples = xsamples.reshape((10, 1, 28, 28))
120 |         samples = np.concatenate((xsamples, xgen), axis=1)
121 |         return samples
122 | 
123 |     autoencoder_cb = ImageGridCallback(os.path.join(path, "autoencoded-epoch-{:03d}.png"), autoencoder_sampler)
124 | 
125 |     # train network
126 |     # targets: generator then discriminator; outputs: xpred, yfake, yreal
127 |     n = xtrain.shape[0]
128 |     y = [xtrain, np.ones((n, 1)), np.zeros((n, 1)), xtrain, np.zeros((n, 1)), np.ones((n, 1))]
129 |     ntest = xtest.shape[0]
130 |     ytest = [xtest, np.ones((ntest, 1)), np.zeros((ntest, 1)), xtest, np.zeros((ntest, 1)), np.ones((ntest, 1))]
131 |     history = model.fit(x=xtrain, y=y, validation_data=(xtest, ytest), callbacks=[generator_cb, autoencoder_cb],
132 |                         nb_epoch=100, batch_size=32)
133 | 
134 |     # save history
135 |     df = pd.DataFrame(history.history)
136 |     df.to_csv(os.path.join(path, "history.csv"))
137 | 
138 |     # save model
139 |     encoder.save(os.path.join(path, "encoder.h5"))
140 |     generator.save(os.path.join(path, "generator.h5"))
141 |     discriminator.save(os.path.join(path, "discriminator.h5"))
142 | 
143 | 
144 | def main():
145 |     example_aae("output/aae", AdversarialOptimizerSimultaneous())
146 | 
147 | 
148 | if __name__ == "__main__":
149 |     main()
150 | 
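The six training targets above are ordered player-major — first the generator's
targets for the outputs ``[xpred, yfake, yreal]``, then the discriminator's for
the same outputs — so the list built at line 128 can be read as:

.. code:: python

    y = [
        xtrain, np.ones((n, 1)), np.zeros((n, 1)),   # generator: reconstruct x, push yfake -> 1 and yreal -> 0
        xtrain, np.zeros((n, 1)), np.ones((n, 1)),   # discriminator: push yfake -> 0 and yreal -> 1
    ]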
-------------------------------------------------------------------------------- /examples/example_aae_cifar10.py: --------------------------------------------------------------------------------
1 | import matplotlib as mpl
2 | 
3 | # This line allows mpl to run with no DISPLAY defined
4 | mpl.use('Agg')
5 | 
6 | from keras.layers import Reshape, Flatten, Lambda
7 | from keras.layers import Input
8 | from keras.layers.convolutional import UpSampling2D, MaxPooling2D
9 | from keras.models import Sequential, Model
10 | from keras.optimizers import Adam
11 | import keras.backend as K
12 | import pandas as pd
13 | import numpy as np
14 | from keras_adversarial.image_grid_callback import ImageGridCallback
15 | from keras_adversarial.legacy import l1l2, Dense, fit, Convolution2D
16 | from keras_adversarial import AdversarialModel, fix_names, n_choice
17 | from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling
18 | from cifar10_utils import cifar10_data
19 | from keras.layers import LeakyReLU, Activation
20 | from image_utils import dim_ordering_unfix, dim_ordering_shape
21 | import os
22 | 
23 | 
24 | def model_generator(latent_dim, units=512, dropout=0.5, reg=lambda: l1l2(l1=1e-7, l2=1e-7)):
25 |     model = Sequential(name="decoder")
26 |     h = 5
27 |     model.add(Dense(units * 4 * 4, input_dim=latent_dim, W_regularizer=reg()))
28 |     model.add(Reshape(dim_ordering_shape((units, 4, 4))))
29 |     # model.add(SpatialDropout2D(dropout))
30 |     model.add(LeakyReLU(0.2))
31 |     model.add(Convolution2D(int(units / 2), h, h, border_mode='same', W_regularizer=reg()))
32 |     # model.add(SpatialDropout2D(dropout))
33 |     model.add(LeakyReLU(0.2))
34 |     model.add(UpSampling2D(size=(2, 2)))
35 |     model.add(Convolution2D(int(units / 2), h, h, border_mode='same', W_regularizer=reg()))
36 |     # model.add(SpatialDropout2D(dropout))
37 |     model.add(LeakyReLU(0.2))
38 |     model.add(UpSampling2D(size=(2, 2)))
39 |     model.add(Convolution2D(int(units / 4), h, h, border_mode='same', W_regularizer=reg()))
40 |     # model.add(SpatialDropout2D(dropout))
41 |     model.add(LeakyReLU(0.2))
42 |     model.add(UpSampling2D(size=(2, 2)))
43 |     model.add(Convolution2D(3, h, h, border_mode='same', W_regularizer=reg()))
44 |     model.add(Activation('sigmoid'))
45 |     return model
46 | 
47 | 
48 | def model_encoder(latent_dim, input_shape, units=512, reg=lambda: l1l2(l1=1e-7, l2=1e-7), dropout=0.5):
49 |     k = 5
50 |     x = Input(input_shape)
51 |     h = Convolution2D(int(units / 4), k, k, border_mode='same', W_regularizer=reg())(x)
52 |     # h = SpatialDropout2D(dropout)(h)
53 |     h = MaxPooling2D(pool_size=(2, 2))(h)
54 |     h = LeakyReLU(0.2)(h)
55 |     h = Convolution2D(int(units / 2), k, k, border_mode='same', W_regularizer=reg())(h)
56 |     # h = SpatialDropout2D(dropout)(h)
57 |     h = MaxPooling2D(pool_size=(2, 2))(h)
58 |     h = LeakyReLU(0.2)(h)
59 |     h = Convolution2D(int(units / 2), k, k, border_mode='same', W_regularizer=reg())(h)
60 |     # h = SpatialDropout2D(dropout)(h)
61 |     h = MaxPooling2D(pool_size=(2, 2))(h)
62 |     h = LeakyReLU(0.2)(h)
63 |     h = Convolution2D(units, k, k, border_mode='same', W_regularizer=reg())(h)
64 |     # h = SpatialDropout2D(dropout)(h)
65 |     h = LeakyReLU(0.2)(h)
66 |     h = Flatten()(h)
67 |     mu = Dense(latent_dim, name="encoder_mu", W_regularizer=reg())(h)
68 |     log_sigma_sq = Dense(latent_dim, name="encoder_log_sigma_sq", W_regularizer=reg())(h)
69 |     z = Lambda(lambda p: p[0] + K.random_normal(K.shape(p[0])) * K.exp(p[1] / 2),
70 |                output_shape=lambda p: p[0])([mu, log_sigma_sq])
71 |     return Model(x, z, name="encoder")
72 | 
73 | 
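# The Lambda above implements the reparameterization trick:
# z = mu + eps * exp(log_sigma_sq / 2) with eps ~ N(0, I), which keeps the
# latent sampling step differentiable with respect to mu and log_sigma_sq.
# A standalone NumPy illustration (hypothetical shapes; not part of this file):
#
#     mu = np.zeros((4, 256))                  # batch of 4 latent means
#     log_sigma_sq = np.full((4, 256), -2.0)   # encoder's log-variances
#     eps = np.random.normal(size=mu.shape)
#     z = mu + eps * np.exp(log_sigma_sq / 2)  # z ~ N(mu, sigma^2) elementwise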
74 | def model_discriminator(latent_dim, output_dim=1, units=256, reg=lambda: l1l2(1e-7, 1e-7)):
75 |     z = Input((latent_dim,))
76 |     h = z
77 |     mode = 1
78 |     h = Dense(units, name="discriminator_h1", W_regularizer=reg())(h)
79 |     # h = BatchNormalization(mode=mode)(h)
80 |     h = LeakyReLU(0.2)(h)
81 |     h = Dense(int(units / 2), name="discriminator_h2", W_regularizer=reg())(h)
82 |     # h = BatchNormalization(mode=mode)(h)
83 |     h = LeakyReLU(0.2)(h)
84 |     h = Dense(int(units / 2), name="discriminator_h3", W_regularizer=reg())(h)
85 |     # h = BatchNormalization(mode=mode)(h)
86 |     h = LeakyReLU(0.2)(h)
87 |     y = Dense(output_dim, name="discriminator_y", activation="sigmoid", W_regularizer=reg())(h)
88 |     return Model(z, y)
89 | 
90 | 
91 | def example_aae(path, adversarial_optimizer):
92 |     # z \in R^256
93 |     latent_dim = 256
94 |     units = 512
95 |     # x \in R^{32x32x3}
96 |     input_shape = dim_ordering_shape((3, 32, 32))
97 | 
98 |     # generator (z -> x)
99 |     generator = model_generator(latent_dim, units=units)
100 |     # encoder (x -> z)
101 |     encoder = model_encoder(latent_dim, input_shape, units=units)
102 |     # autoencoder (x -> x')
103 |     autoencoder = Model(encoder.inputs, generator(encoder(encoder.inputs)))
104 |     # discriminator (z -> y)
105 |     discriminator = model_discriminator(latent_dim, units=units)
106 | 
107 |     # build AAE
108 |     x = encoder.inputs[0]
109 |     z = encoder(x)
110 |     xpred = generator(z)
111 |     zreal = normal_latent_sampling((latent_dim,))(x)
112 |     yreal = discriminator(zreal)
113 |     yfake = discriminator(z)
114 |     aae = Model(x, fix_names([xpred, yfake, yreal], ["xpred", "yfake", "yreal"]))
115 | 
116 |     # print summary of models
117 |     generator.summary()
118 |     encoder.summary()
119 |     discriminator.summary()
120 |     autoencoder.summary()
121 | 
122 |     # build adversarial model
123 |     generative_params = generator.trainable_weights + encoder.trainable_weights
124 |     model = AdversarialModel(base_model=aae,
125 |                              player_params=[generative_params, discriminator.trainable_weights],
126 |                              player_names=["generator", "discriminator"])
127 |     model.adversarial_compile(adversarial_optimizer=adversarial_optimizer,
128 |                               player_optimizers=[Adam(3e-4, decay=1e-4), Adam(1e-3, decay=1e-4)],
129 |                               loss={"yfake": "binary_crossentropy", "yreal": "binary_crossentropy",
130 |                                     "xpred": "mean_squared_error"},
131 |                               player_compile_kwargs=[{"loss_weights": {"yfake": 1e-1, "yreal": 1e-1,
132 |                                                                        "xpred": 1e2}}] * 2)
133 | 
134 |     # load cifar10 data
135 |     xtrain, xtest = cifar10_data()
136 | 
137 |     # callback for image grid of generated samples
138 |     def generator_sampler():
139 |         zsamples = np.random.normal(size=(10 * 10, latent_dim))
140 |         return dim_ordering_unfix(generator.predict(zsamples)).transpose((0, 2, 3, 1)).reshape((10, 10, 32, 32, 3))
141 | 
142 |     generator_cb = ImageGridCallback(os.path.join(path, "generated-epoch-{:03d}.png"), generator_sampler)
143 | 
144 |     # callback for image grid of autoencoded samples
145 |     def autoencoder_sampler():
146 |         xsamples = n_choice(xtest, 10)
147 |         xrep = np.repeat(xsamples, 9, axis=0)
148 |         xgen = dim_ordering_unfix(autoencoder.predict(xrep)).reshape((10, 9, 3, 32, 32))
149 |         xsamples = dim_ordering_unfix(xsamples).reshape((10, 1, 3, 32, 32))
150 |         samples = np.concatenate((xsamples, xgen), axis=1)
151 |         samples = samples.transpose((0, 1, 3, 4, 2))
152 |         return samples
153 | 
154 |     autoencoder_cb = ImageGridCallback(os.path.join(path, "autoencoded-epoch-{:03d}.png"), autoencoder_sampler,
155 |                                        cmap=None)
156 | 
157 |     # train network
158 |     # targets: generator then discriminator; outputs: xpred, yfake, yreal
159 |     n = xtrain.shape[0]
160 |     y = [xtrain, np.ones((n, 1)), np.zeros((n, 1)), xtrain, np.zeros((n, 1)), np.ones((n, 1))]
161 |     ntest = xtest.shape[0]
162 |     ytest = [xtest, np.ones((ntest, 1)), np.zeros((ntest, 1)), xtest, np.zeros((ntest, 1)), np.ones((ntest, 1))]
163 |     history = fit(model, x=xtrain, y=y, validation_data=(xtest, ytest),
164 |                   callbacks=[generator_cb, autoencoder_cb],
165 |                   nb_epoch=100, batch_size=32)
166 | 
167 |     # save history
168 |     df = pd.DataFrame(history.history)
169 |     df.to_csv(os.path.join(path, "history.csv"))
170 | 
171 |     # save model
172 |     encoder.save(os.path.join(path, "encoder.h5"))
173 |     generator.save(os.path.join(path, "generator.h5"))
174 |     discriminator.save(os.path.join(path, "discriminator.h5"))
175 | 
176 | 
177 | def main():
178 |     example_aae("output/aae-cifar10", AdversarialOptimizerSimultaneous())
179 | 
180 | 
181 | if __name__ == "__main__":
182 |     main()
183 | 
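``image_utils.py`` appears in the tree above but its contents are not included
in this listing. Judging from the call sites, the helpers wrap Keras's image
dim-ordering; a rough sketch of what ``dim_ordering_shape`` and
``dim_ordering_unfix`` presumably do (names from the imports above,
implementation assumed):

.. code:: python

    import numpy as np
    import keras.backend as K

    def dim_ordering_shape(input_shape):
        # (ch, h, w) stays as-is for 'th' ordering; channels move last for 'tf'
        if K.image_dim_ordering() == 'th':
            return input_shape
        return input_shape[1:] + input_shape[:1]

    def dim_ordering_unfix(x):
        # return arrays to channels-first layout so downstream reshapes are uniform
        if K.image_dim_ordering() == 'th':
            return x
        return np.transpose(x, (0, 3, 1, 2))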
-------------------------------------------------------------------------------- /examples/example_bigan.py: --------------------------------------------------------------------------------
1 | import matplotlib as mpl
2 | 
3 | # This line allows mpl to run with no DISPLAY defined
4 | mpl.use('Agg')
5 | 
6 | from keras.layers import Dense, Flatten, Input, merge, Dropout
7 | from keras.models import Model
8 | from keras.optimizers import Adam
9 | from keras_adversarial.legacy import l1l2
10 | import keras.backend as K
11 | import pandas as pd
12 | import numpy as np
13 | from keras_adversarial.image_grid_callback import ImageGridCallback
14 | from keras_adversarial import AdversarialModel, gan_targets, fix_names, n_choice, simple_bigan
15 | from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling
16 | from mnist_utils import mnist_data
17 | from example_gan import model_generator
18 | from keras.layers import BatchNormalization, LeakyReLU
19 | import os
20 | 
21 | 
22 | def model_encoder(latent_dim, input_shape, hidden_dim=1024, reg=lambda: l1l2(1e-5, 0), batch_norm_mode=0):
23 |     x = Input(input_shape, name="x")
24 |     h = Flatten()(x)
25 |     h = Dense(hidden_dim, name="encoder_h1", W_regularizer=reg())(h)
26 |     h = BatchNormalization(mode=batch_norm_mode)(h)
27 |     h = LeakyReLU(0.2)(h)
28 |     h = Dense(int(hidden_dim / 2), name="encoder_h2", W_regularizer=reg())(h)
29 |     h = BatchNormalization(mode=batch_norm_mode)(h)
30 |     h = LeakyReLU(0.2)(h)
31 |     h = Dense(int(hidden_dim / 4), name="encoder_h3", W_regularizer=reg())(h)
32 |     h = BatchNormalization(mode=batch_norm_mode)(h)
33 |     h = LeakyReLU(0.2)(h)
34 |     mu = Dense(latent_dim, name="encoder_mu", W_regularizer=reg())(h)
35 |     log_sigma_sq = Dense(latent_dim, name="encoder_log_sigma_sq", W_regularizer=reg())(h)
36 |     z = merge([mu, log_sigma_sq], mode=lambda p: p[0] + K.random_normal(K.shape(p[0])) * K.exp(p[1] / 2),
37 |               output_shape=lambda x: x[0])
38 |     return Model(x, z, name="encoder")
39 | 
40 | 
41 | def model_discriminator(latent_dim, input_shape, output_dim=1, hidden_dim=2048,
42 |                         reg=lambda: l1l2(1e-7, 1e-7), batch_norm_mode=1, dropout=0.5):
43 |     z = Input((latent_dim,))
44 |     x = Input(input_shape, name="x")
45 |     h = merge([z, Flatten()(x)], mode='concat')
46 | 
47 |     h1 = Dense(hidden_dim, name="discriminator_h1", W_regularizer=reg())
48 |     b1 = BatchNormalization(mode=batch_norm_mode)
49 |     h2 = Dense(hidden_dim, name="discriminator_h2", W_regularizer=reg())
50 |     b2 = BatchNormalization(mode=batch_norm_mode)
51 |     h3 = Dense(hidden_dim, name="discriminator_h3", W_regularizer=reg())
52 |     b3 = BatchNormalization(mode=batch_norm_mode)
53 |     y = Dense(output_dim, name="discriminator_y", activation="sigmoid", W_regularizer=reg())
54 | 
55 |     # training model uses dropout
56 |     _h = h
57 |     _h = Dropout(dropout)(LeakyReLU(0.2)((b1(h1(_h)))))
58 |     _h = Dropout(dropout)(LeakyReLU(0.2)((b2(h2(_h)))))
59 |     _h = Dropout(dropout)(LeakyReLU(0.2)((b3(h3(_h)))))
60 |     ytrain = y(_h)
61 |     mtrain = Model([z, x], ytrain, name="discriminator_train")
62 | 
63 |     # testing model does not use dropout
64 |     _h = h
65 |     _h = LeakyReLU(0.2)((b1(h1(_h))))
66 |     _h = LeakyReLU(0.2)((b2(h2(_h))))
67 |     _h = LeakyReLU(0.2)((b3(h3(_h))))
68 |     ytest = y(_h)
69 |     mtest = Model([z, x], ytest, name="discriminator_test")
70 | 
71 |     return mtrain, mtest
72 | 
73 | 
74 | def example_bigan(path, adversarial_optimizer):
75 |     # z \in R^25
76 |     latent_dim = 25
77 |     # x \in R^{28x28}
78 |     input_shape = (28, 28)
79 | 
80 |     # generator (z -> x)
81 |     generator = model_generator(latent_dim, input_shape)
82 |     # encoder (x -> z)
83 |     encoder = model_encoder(latent_dim, input_shape)
84 |     # autoencoder (x -> x')
85 |     autoencoder = Model(encoder.inputs, generator(encoder(encoder.inputs)))
86 |     # discriminator (z, x -> y); returns train and test variants
87 |     discriminator_train, discriminator_test = model_discriminator(latent_dim, input_shape)
88 |     # bigan (z, x -> yfake, yreal)
89 |     bigan_generator = simple_bigan(generator, encoder, discriminator_test)
90 |     bigan_discriminator = simple_bigan(generator, encoder, discriminator_train)
91 |     # z generated on GPU based on batch dimension of x
92 |     x = bigan_generator.inputs[1]
93 |     z = normal_latent_sampling((latent_dim,))(x)
94 |     # eliminate z from inputs
95 |     bigan_generator = Model([x], fix_names(bigan_generator([z, x]), bigan_generator.output_names))
96 |     bigan_discriminator = Model([x], fix_names(bigan_discriminator([z, x]), bigan_discriminator.output_names))
97 | 
98 |     generative_params = generator.trainable_weights + encoder.trainable_weights
99 | 
100 |     # print summary of models
101 |     generator.summary()
102 |     encoder.summary()
103 |     discriminator_train.summary()
104 |     bigan_discriminator.summary()
105 |     autoencoder.summary()
106 | 
107 |     # build adversarial model
108 |     model = AdversarialModel(player_models=[bigan_generator, bigan_discriminator],
109 |                              player_params=[generative_params, discriminator_train.trainable_weights],
110 |                              player_names=["generator", "discriminator"])
111 |     model.adversarial_compile(adversarial_optimizer=adversarial_optimizer,
112 |                               player_optimizers=[Adam(1e-4, decay=1e-4), Adam(1e-3, decay=1e-4)],
113 |                               loss='binary_crossentropy')
114 | 
115 |     # load mnist data
116 |     xtrain, xtest = mnist_data()
117 | 
118 |     # callback for image grid of generated samples
119 |     def generator_sampler():
120 |         zsamples = np.random.normal(size=(10 * 10, latent_dim))
121 |         return generator.predict(zsamples).reshape((10, 10, 28, 28))
122 | 
123 |     generator_cb = ImageGridCallback(os.path.join(path, "generated-epoch-{:03d}.png"), generator_sampler)
124 | 
125 |     # callback for image grid of autoencoded samples
126 |     def autoencoder_sampler():
127 |         xsamples = n_choice(xtest, 10)
128 |         xrep = np.repeat(xsamples, 9, axis=0)
129 |         xgen = autoencoder.predict(xrep).reshape((10, 9, 28, 28))
130 |         xsamples = xsamples.reshape((10, 1, 28, 28))
131 |         x = np.concatenate((xsamples, xgen), axis=1)
132 |         return x
133 | 
134 |     autoencoder_cb = ImageGridCallback(os.path.join(path, "autoencoded-epoch-{:03d}.png"), autoencoder_sampler)
135 | 
136 |     # train network
137 |     y = gan_targets(xtrain.shape[0])  # [gen_fake=1, gen_real=0, disc_fake=0, disc_real=1]
138 |     ytest = gan_targets(xtest.shape[0])
139 |     history = model.fit(x=xtrain, y=y, validation_data=(xtest, ytest), callbacks=[generator_cb, autoencoder_cb],
140 |                         nb_epoch=100, batch_size=32)
141 | 
142 |     # save history
143 |     df = pd.DataFrame(history.history)
144 |     df.to_csv(os.path.join(path, "history.csv"))
145 | 
146 |     # save model
147 |     encoder.save(os.path.join(path, "encoder.h5"))
148 |     generator.save(os.path.join(path, "generator.h5"))
149 |     discriminator_train.save(os.path.join(path, "discriminator.h5"))
150 | 
151 | 
152 | def main():
153 |     example_bigan("output/bigan", AdversarialOptimizerSimultaneous())
154 | 
155 | 
156 | if __name__ == "__main__":
157 |     main()
158 | 
-------------------------------------------------------------------------------- /examples/example_bigan_unrolled.py: --------------------------------------------------------------------------------
1 | import matplotlib as mpl
2 | 
3 | # This line allows mpl to run with no DISPLAY defined
4 | mpl.use('Agg')
5 | 
6 | from keras.layers import Dense, Flatten, Input, merge, Dropout
7 | from keras.models import Model
8 | from keras.optimizers import Adam
9 | from keras.regularizers import l1
10 | import keras.backend as K
11 | import pandas as pd
12 | import numpy as np
13 | from keras_adversarial.image_grid_callback import ImageGridCallback
14 | from keras_adversarial.legacy import l1l2
15 | from keras_adversarial import AdversarialModel, gan_targets, n_choice, simple_bigan
16 | from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling
17 | from mnist_utils import mnist_data
18 | from example_gan import model_generator
19 | from keras.layers import BatchNormalization, LeakyReLU
20 | 
21 | 
22 | def model_encoder(latent_dim, input_shape, hidden_dim=1024, reg=lambda: l1(1e-5), batch_norm_mode=2):
23 |     x = Input(input_shape, name="x")
24 |     h = Flatten()(x)
25 |     h = Dense(hidden_dim, name="encoder_h1", W_regularizer=reg())(h)
26 |     h = BatchNormalization(mode=batch_norm_mode)(h)
27 |     h = LeakyReLU(0.2)(h)
28 |     h = Dense(int(hidden_dim / 2), name="encoder_h2", W_regularizer=reg())(h)
29 |     h = BatchNormalization(mode=batch_norm_mode)(h)
30 |     h = LeakyReLU(0.2)(h)
31 |     h = Dense(int(hidden_dim / 4), name="encoder_h3", W_regularizer=reg())(h)
32 |     h = BatchNormalization(mode=batch_norm_mode)(h)
33 |     h = LeakyReLU(0.2)(h)
34 |     mu = Dense(latent_dim, name="encoder_mu", W_regularizer=reg())(h)
35 |     log_sigma_sq = Dense(latent_dim, name="encoder_log_sigma_sq", W_regularizer=reg())(h)
36 |     z = merge([mu, log_sigma_sq], mode=lambda p: p[0] + K.random_normal(K.shape(p[0])) * K.exp(p[1] / 2),
37 |               output_shape=lambda x: x[0])
38 |     return Model(x, z, name="encoder")
39 | 
40 | 
41 | def model_discriminator(latent_dim, input_shape, output_dim=1, hidden_dim=1024,
42 |                         reg=lambda: l1l2(1e-4, 1e-4), batch_norm_mode=1):
43 |     z = Input((latent_dim,))
44 |     x = Input(input_shape, name="x")
45 |     h = merge([z, Flatten()(x)], mode='concat')
46 |     h = Dense(hidden_dim, name="discriminator_h1", W_regularizer=reg())(h)
47 |     h = BatchNormalization(mode=batch_norm_mode)(h)
48 |     h = LeakyReLU(0.2)(h)
49 |     h = Dropout(0.5)(h)
50 |     h = Dense(int(hidden_dim / 2), name="discriminator_h2", W_regularizer=reg())(h)
51 |     h = BatchNormalization(mode=batch_norm_mode)(h)
52 |     h = LeakyReLU(0.2)(h)
53 |     h = Dropout(0.5)(h)
54 |     h = Dense(int(hidden_dim / 4), name="discriminator_h3", W_regularizer=reg())(h)
55 |     h = BatchNormalization(mode=batch_norm_mode)(h)
56 |     h = LeakyReLU(0.2)(h)
57 |     h = Dropout(0.5)(h)
58 |     y = Dense(output_dim, name="discriminator_y", activation="sigmoid", W_regularizer=reg())(h)
59 |     return Model([z, x], y, name="discriminator")
60 | 
61 | 
62 | def main():
63 |     # z \in R^100
64 |     latent_dim = 100
65 |     # x \in R^{28x28}
66 |     input_shape = (28, 28)
67 | 
68 |     # generator (z -> x)
69 |     generator = model_generator(latent_dim, input_shape)
70 |     # encoder (x -> z)
71 |     encoder = model_encoder(latent_dim, input_shape)
72 |     # autoencoder (x -> x')
73 |     autoencoder = Model(encoder.inputs,
generator(encoder(encoder.inputs))) 74 | # discriminator (x -> y) 75 | discriminator = model_discriminator(latent_dim, input_shape) 76 | # bigan (x - > yfake, yreal), z generated on GPU 77 | bigan = simple_bigan(generator, encoder, discriminator, normal_latent_sampling((latent_dim,))) 78 | 79 | generative_params = generator.trainable_weights + encoder.trainable_weights 80 | 81 | # print summary of models 82 | generator.summary() 83 | encoder.summary() 84 | discriminator.summary() 85 | bigan.summary() 86 | autoencoder.summary() 87 | 88 | # build adversarial model 89 | model = AdversarialModel(base_model=bigan, 90 | player_params=[generative_params, discriminator.trainable_weights], 91 | player_names=["generator", "discriminator"]) 92 | model.adversarial_compile(adversarial_optimizer=AdversarialOptimizerSimultaneous(), 93 | player_optimizers=[Adam(1e-4, decay=1e-4), Adam(1e-3, decay=1e-4)], 94 | loss='binary_crossentropy') 95 | 96 | # train model 97 | xtrain, xtest = mnist_data() 98 | 99 | def generator_sampler(): 100 | zsamples = np.random.normal(size=(10 * 10, latent_dim)) 101 | return generator.predict(zsamples).reshape((10, 10, 28, 28)) 102 | 103 | generator_cb = ImageGridCallback("output/bigan/generated-epoch-{:03d}.png", generator_sampler) 104 | 105 | def autoencoder_sampler(): 106 | xsamples = n_choice(xtest, 10) 107 | xrep = np.repeat(xsamples, 9, axis=0) 108 | xgen = autoencoder.predict(xrep).reshape((10, 9, 28, 28)) 109 | xsamples = xsamples.reshape((10, 1, 28, 28)) 110 | x = np.concatenate((xsamples, xgen), axis=1) 111 | return x 112 | 113 | autoencoder_cb = ImageGridCallback("output/bigan/autoencoded-epoch-{:03d}.png", autoencoder_sampler) 114 | 115 | y = gan_targets(xtrain.shape[0]) 116 | ytest = gan_targets(xtest.shape[0]) 117 | history = model.fit(x=xtrain, y=y, validation_data=(xtest, ytest), callbacks=[generator_cb, autoencoder_cb], 118 | nb_epoch=100, batch_size=32) 119 | df = pd.DataFrame(history.history) 120 | df.to_csv("output/bigan/history.csv") 121 | 122 | encoder.save("output/bigan/encoder.h5") 123 | generator.save("output/bigan/generator.h5") 124 | discriminator.save("output/bigan/discriminator.h5") 125 | 126 | 127 | if __name__ == "__main__": 128 | main() 129 | -------------------------------------------------------------------------------- /examples/example_gan.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | 3 | # This line allows mpl to run with no DISPLAY defined 4 | mpl.use('Agg') 5 | 6 | import pandas as pd 7 | import numpy as np 8 | import os 9 | from keras.layers import Reshape, Flatten, LeakyReLU, Activation 10 | from keras.models import Sequential 11 | from keras.optimizers import Adam 12 | from keras.callbacks import TensorBoard 13 | from keras_adversarial.image_grid_callback import ImageGridCallback 14 | from keras_adversarial import AdversarialModel, simple_gan, gan_targets 15 | from keras_adversarial import normal_latent_sampling, AdversarialOptimizerSimultaneous 16 | from keras_adversarial.legacy import l1l2, Dense, fit 17 | import keras.backend as K 18 | from mnist_utils import mnist_data 19 | 20 | 21 | def model_generator(latent_dim, input_shape, hidden_dim=1024, reg=lambda: l1l2(1e-5, 1e-5)): 22 | return Sequential([ 23 | Dense(int(hidden_dim / 4), name="generator_h1", input_dim=latent_dim, W_regularizer=reg()), 24 | LeakyReLU(0.2), 25 | Dense(int(hidden_dim / 2), name="generator_h2", W_regularizer=reg()), 26 | LeakyReLU(0.2), 27 | Dense(hidden_dim, name="generator_h3", 
W_regularizer=reg()), 28 | LeakyReLU(0.2), 29 | Dense(np.prod(input_shape), name="generator_x_flat", W_regularizer=reg()), 30 | Activation('sigmoid'), 31 | Reshape(input_shape, name="generator_x")], 32 | name="generator") 33 | 34 | 35 | def model_discriminator(input_shape, hidden_dim=1024, reg=lambda: l1l2(1e-5, 1e-5), output_activation="sigmoid"): 36 | return Sequential([ 37 | Flatten(name="discriminator_flatten", input_shape=input_shape), 38 | Dense(hidden_dim, name="discriminator_h1", W_regularizer=reg()), 39 | LeakyReLU(0.2), 40 | Dense(int(hidden_dim / 2), name="discriminator_h2", W_regularizer=reg()), 41 | LeakyReLU(0.2), 42 | Dense(int(hidden_dim / 4), name="discriminator_h3", W_regularizer=reg()), 43 | LeakyReLU(0.2), 44 | Dense(1, name="discriminator_y", W_regularizer=reg()), 45 | Activation(output_activation)], 46 | name="discriminator") 47 | 48 | 49 | def example_gan(adversarial_optimizer, path, opt_g, opt_d, nb_epoch, generator, discriminator, latent_dim, 50 | targets=gan_targets, loss='binary_crossentropy'): 51 | csvpath = os.path.join(path, "history.csv") 52 | if os.path.exists(csvpath): 53 | print("Already exists: {}".format(csvpath)) 54 | return 55 | 56 | print("Training: {}".format(csvpath)) 57 | # gan (x - > yfake, yreal), z generated on GPU 58 | gan = simple_gan(generator, discriminator, normal_latent_sampling((latent_dim,))) 59 | 60 | # print summary of models 61 | generator.summary() 62 | discriminator.summary() 63 | gan.summary() 64 | 65 | # build adversarial model 66 | model = AdversarialModel(base_model=gan, 67 | player_params=[generator.trainable_weights, discriminator.trainable_weights], 68 | player_names=["generator", "discriminator"]) 69 | model.adversarial_compile(adversarial_optimizer=adversarial_optimizer, 70 | player_optimizers=[opt_g, opt_d], 71 | loss=loss) 72 | 73 | # create callback to generate images 74 | zsamples = np.random.normal(size=(10 * 10, latent_dim)) 75 | 76 | def generator_sampler(): 77 | return generator.predict(zsamples).reshape((10, 10, 28, 28)) 78 | 79 | generator_cb = ImageGridCallback(os.path.join(path, "epoch-{:03d}.png"), generator_sampler) 80 | 81 | # train model 82 | xtrain, xtest = mnist_data() 83 | y = targets(xtrain.shape[0]) 84 | ytest = targets(xtest.shape[0]) 85 | callbacks = [generator_cb] 86 | if K.backend() == "tensorflow": 87 | callbacks.append( 88 | TensorBoard(log_dir=os.path.join(path, 'logs'), histogram_freq=0, write_graph=True, write_images=True)) 89 | history = fit(model, x=xtrain, y=y, validation_data=(xtest, ytest), callbacks=callbacks, nb_epoch=nb_epoch, 90 | batch_size=32) 91 | 92 | # save history to CSV 93 | df = pd.DataFrame(history.history) 94 | df.to_csv(csvpath) 95 | 96 | # save models 97 | generator.save(os.path.join(path, "generator.h5")) 98 | discriminator.save(os.path.join(path, "discriminator.h5")) 99 | 100 | 101 | def main(): 102 | # z \in R^100 103 | latent_dim = 100 104 | # x \in R^{28x28} 105 | input_shape = (28, 28) 106 | # generator (z -> x) 107 | generator = model_generator(latent_dim, input_shape) 108 | # discriminator (x -> y) 109 | discriminator = model_discriminator(input_shape) 110 | example_gan(AdversarialOptimizerSimultaneous(), "output/gan", 111 | opt_g=Adam(1e-4, decay=1e-4), 112 | opt_d=Adam(1e-3, decay=1e-4), 113 | nb_epoch=100, generator=generator, discriminator=discriminator, 114 | latent_dim=latent_dim) 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /examples/example_gan_cifar10.py: 
-------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | 3 | # This line allows mpl to run with no DISPLAY defined 4 | mpl.use('Agg') 5 | 6 | import pandas as pd 7 | import numpy as np 8 | import os 9 | from keras.layers import Reshape, Flatten, LeakyReLU, Activation 10 | from keras.layers.convolutional import UpSampling2D, MaxPooling2D 11 | from keras.models import Sequential 12 | from keras.optimizers import Adam 13 | from keras.callbacks import TensorBoard 14 | from keras_adversarial.image_grid_callback import ImageGridCallback 15 | 16 | from keras_adversarial import AdversarialModel, simple_gan, gan_targets 17 | from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling 18 | from keras_adversarial.legacy import Dense, BatchNormalization, fit, l1l2, Convolution2D, AveragePooling2D 19 | import keras.backend as K 20 | from cifar10_utils import cifar10_data 21 | from image_utils import dim_ordering_fix, dim_ordering_unfix, dim_ordering_shape 22 | 23 | 24 | def model_generator(): 25 | model = Sequential() 26 | nch = 256 27 | reg = lambda: l1l2(l1=1e-7, l2=1e-7) 28 | h = 5 29 | model.add(Dense(nch * 4 * 4, input_dim=100, W_regularizer=reg())) 30 | model.add(BatchNormalization(mode=0)) 31 | model.add(Reshape(dim_ordering_shape((nch, 4, 4)))) 32 | model.add(Convolution2D(int(nch / 2), h, h, border_mode='same', W_regularizer=reg())) 33 | model.add(BatchNormalization(mode=0, axis=1)) 34 | model.add(LeakyReLU(0.2)) 35 | model.add(UpSampling2D(size=(2, 2))) 36 | model.add(Convolution2D(int(nch / 2), h, h, border_mode='same', W_regularizer=reg())) 37 | model.add(BatchNormalization(mode=0, axis=1)) 38 | model.add(LeakyReLU(0.2)) 39 | model.add(UpSampling2D(size=(2, 2))) 40 | model.add(Convolution2D(int(nch / 4), h, h, border_mode='same', W_regularizer=reg())) 41 | model.add(BatchNormalization(mode=0, axis=1)) 42 | model.add(LeakyReLU(0.2)) 43 | model.add(UpSampling2D(size=(2, 2))) 44 | model.add(Convolution2D(3, h, h, border_mode='same', W_regularizer=reg())) 45 | model.add(Activation('sigmoid')) 46 | return model 47 | 48 | 49 | def model_discriminator(): 50 | nch = 256 51 | h = 5 52 | reg = lambda: l1l2(l1=1e-7, l2=1e-7) 53 | 54 | c1 = Convolution2D(int(nch / 4), h, h, border_mode='same', W_regularizer=reg(), 55 | input_shape=dim_ordering_shape((3, 32, 32))) 56 | c2 = Convolution2D(int(nch / 2), h, h, border_mode='same', W_regularizer=reg()) 57 | c3 = Convolution2D(nch, h, h, border_mode='same', W_regularizer=reg()) 58 | c4 = Convolution2D(1, h, h, border_mode='same', W_regularizer=reg()) 59 | 60 | model = Sequential() 61 | model.add(c1) 62 | model.add(MaxPooling2D(pool_size=(2, 2))) 63 | model.add(LeakyReLU(0.2)) 64 | model.add(c2) 65 | model.add(MaxPooling2D(pool_size=(2, 2))) 66 | model.add(LeakyReLU(0.2)) 67 | model.add(c3) 68 | model.add(MaxPooling2D(pool_size=(2, 2))) 69 | model.add(LeakyReLU(0.2)) 70 | model.add(c4) 71 | model.add(AveragePooling2D(pool_size=(4, 4), border_mode='valid')) 72 | model.add(Flatten()) 73 | model.add(Activation('sigmoid')) 74 | return model 75 | 76 | 77 | def example_gan(adversarial_optimizer, path, opt_g, opt_d, nb_epoch, generator, discriminator, latent_dim, 78 | targets=gan_targets, loss='binary_crossentropy'): 79 | csvpath = os.path.join(path, "history.csv") 80 | if os.path.exists(csvpath): 81 | print("Already exists: {}".format(csvpath)) 82 | return 83 | 84 | print("Training: {}".format(csvpath)) 85 | # gan (x - > yfake, yreal), z is gaussian generated on GPU 86 | # can also 
experiment with uniform_latent_sampling 87 | generator.summary() 88 | discriminator.summary() 89 | gan = simple_gan(generator=generator, 90 | discriminator=discriminator, 91 | latent_sampling=normal_latent_sampling((latent_dim,))) 92 | 93 | # build adversarial model 94 | model = AdversarialModel(base_model=gan, 95 | player_params=[generator.trainable_weights, discriminator.trainable_weights], 96 | player_names=["generator", "discriminator"]) 97 | model.adversarial_compile(adversarial_optimizer=adversarial_optimizer, 98 | player_optimizers=[opt_g, opt_d], 99 | loss=loss) 100 | 101 | # create callback to generate images 102 | zsamples = np.random.normal(size=(10 * 10, latent_dim)) 103 | 104 | def generator_sampler(): 105 | xpred = dim_ordering_unfix(generator.predict(zsamples)).transpose((0, 2, 3, 1)) 106 | return xpred.reshape((10, 10) + xpred.shape[1:]) 107 | 108 | generator_cb = ImageGridCallback(os.path.join(path, "epoch-{:03d}.png"), generator_sampler, cmap=None) 109 | 110 | # train model 111 | xtrain, xtest = cifar10_data() 112 | y = targets(xtrain.shape[0]) 113 | ytest = targets(xtest.shape[0]) 114 | callbacks = [generator_cb] 115 | if K.backend() == "tensorflow": 116 | callbacks.append( 117 | TensorBoard(log_dir=os.path.join(path, 'logs'), histogram_freq=0, write_graph=True, write_images=True)) 118 | history = fit(model, x=xtrain, y=y, validation_data=(xtest, ytest), 119 | callbacks=callbacks, nb_epoch=nb_epoch, 120 | batch_size=32) 121 | 122 | # save history to CSV 123 | df = pd.DataFrame(history.history) 124 | df.to_csv(csvpath) 125 | 126 | # save models 127 | generator.save(os.path.join(path, "generator.h5")) 128 | discriminator.save(os.path.join(path, "discriminator.h5")) 129 | 130 | 131 | def main(): 132 | # z \in R^100 133 | latent_dim = 100 134 | # x \in R^{28x28} 135 | # generator (z -> x) 136 | generator = model_generator() 137 | # discriminator (x -> y) 138 | discriminator = model_discriminator() 139 | example_gan(AdversarialOptimizerSimultaneous(), "output/gan-cifar10", 140 | opt_g=Adam(1e-4, decay=1e-5), 141 | opt_d=Adam(1e-3, decay=1e-5), 142 | nb_epoch=100, generator=generator, discriminator=discriminator, 143 | latent_dim=latent_dim) 144 | 145 | 146 | if __name__ == "__main__": 147 | main() 148 | -------------------------------------------------------------------------------- /examples/example_gan_convolutional.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | 3 | # This line allows mpl to run with no DISPLAY defined 4 | mpl.use('Agg') 5 | 6 | from keras.layers import Flatten, Dropout, LeakyReLU, Input, Activation 7 | from keras.models import Model 8 | from keras.layers.convolutional import UpSampling2D 9 | from keras.optimizers import Adam 10 | from keras.datasets import mnist 11 | import pandas as pd 12 | import numpy as np 13 | import keras.backend as K 14 | from keras_adversarial.legacy import Dense, BatchNormalization, Convolution2D 15 | from keras_adversarial.image_grid_callback import ImageGridCallback 16 | from keras_adversarial import AdversarialModel, simple_gan, gan_targets 17 | from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling 18 | from image_utils import dim_ordering_fix, dim_ordering_input, dim_ordering_reshape, dim_ordering_unfix 19 | 20 | 21 | def leaky_relu(x): 22 | return K.relu(x, 0.2) 23 | 24 | 25 | def model_generator(): 26 | nch = 256 27 | g_input = Input(shape=[100]) 28 | H = Dense(nch * 14 * 14)(g_input) 29 | H = BatchNormalization(mode=2)(H) 
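# (Keras 1 BatchNormalization modes: mode=0 normalizes feature-wise and keeps
#  running averages for test time; mode=2 also normalizes feature-wise but uses
#  per-batch statistics at both train and test time.)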
30 | H = Activation('relu')(H) 31 | H = dim_ordering_reshape(nch, 14)(H) 32 | H = UpSampling2D(size=(2, 2))(H) 33 | H = Convolution2D(int(nch / 2), 3, 3, border_mode='same')(H) 34 | H = BatchNormalization(mode=2, axis=1)(H) 35 | H = Activation('relu')(H) 36 | H = Convolution2D(int(nch / 4), 3, 3, border_mode='same')(H) 37 | H = BatchNormalization(mode=2, axis=1)(H) 38 | H = Activation('relu')(H) 39 | H = Convolution2D(1, 1, 1, border_mode='same')(H) 40 | g_V = Activation('sigmoid')(H) 41 | return Model(g_input, g_V) 42 | 43 | 44 | def model_discriminator(input_shape=(1, 28, 28), dropout_rate=0.5): 45 | d_input = dim_ordering_input(input_shape, name="input_x") 46 | nch = 512 47 | # nch = 128 48 | H = Convolution2D(int(nch / 2), 5, 5, subsample=(2, 2), border_mode='same', activation='relu')(d_input) 49 | H = LeakyReLU(0.2)(H) 50 | H = Dropout(dropout_rate)(H) 51 | H = Convolution2D(nch, 5, 5, subsample=(2, 2), border_mode='same', activation='relu')(H) 52 | H = LeakyReLU(0.2)(H) 53 | H = Dropout(dropout_rate)(H) 54 | H = Flatten()(H) 55 | H = Dense(int(nch / 2))(H) 56 | H = LeakyReLU(0.2)(H) 57 | H = Dropout(dropout_rate)(H) 58 | d_V = Dense(1, activation='sigmoid')(H) 59 | return Model(d_input, d_V) 60 | 61 | 62 | def mnist_process(x): 63 | x = x.astype(np.float32) / 255.0 64 | return x 65 | 66 | 67 | def mnist_data(): 68 | (xtrain, ytrain), (xtest, ytest) = mnist.load_data() 69 | return mnist_process(xtrain), mnist_process(xtest) 70 | 71 | 72 | def generator_sampler(latent_dim, generator): 73 | def fun(): 74 | zsamples = np.random.normal(size=(10 * 10, latent_dim)) 75 | gen = dim_ordering_unfix(generator.predict(zsamples)) 76 | return gen.reshape((10, 10, 28, 28)) 77 | 78 | return fun 79 | 80 | 81 | if __name__ == "__main__": 82 | # z \in R^100 83 | latent_dim = 100 84 | # x \in R^{28x28} 85 | input_shape = (1, 28, 28) 86 | 87 | # generator (z -> x) 88 | generator = model_generator() 89 | # discriminator (x -> y) 90 | discriminator = model_discriminator(input_shape=input_shape) 91 | # gan (x - > yfake, yreal), z generated on GPU 92 | gan = simple_gan(generator, discriminator, normal_latent_sampling((latent_dim,))) 93 | 94 | # print summary of models 95 | generator.summary() 96 | discriminator.summary() 97 | gan.summary() 98 | 99 | # build adversarial model 100 | model = AdversarialModel(base_model=gan, 101 | player_params=[generator.trainable_weights, discriminator.trainable_weights], 102 | player_names=["generator", "discriminator"]) 103 | model.adversarial_compile(adversarial_optimizer=AdversarialOptimizerSimultaneous(), 104 | player_optimizers=[Adam(1e-4, decay=1e-4), Adam(1e-3, decay=1e-4)], 105 | loss='binary_crossentropy') 106 | 107 | # train model 108 | generator_cb = ImageGridCallback("output/gan_convolutional/epoch-{:03d}.png", 109 | generator_sampler(latent_dim, generator)) 110 | 111 | xtrain, xtest = mnist_data() 112 | xtrain = dim_ordering_fix(xtrain.reshape((-1, 1, 28, 28))) 113 | xtest = dim_ordering_fix(xtest.reshape((-1, 1, 28, 28))) 114 | y = gan_targets(xtrain.shape[0]) 115 | ytest = gan_targets(xtest.shape[0]) 116 | history = model.fit(x=xtrain, y=y, validation_data=(xtest, ytest), callbacks=[generator_cb], nb_epoch=100, 117 | batch_size=32) 118 | df = pd.DataFrame(history.history) 119 | df.to_csv("output/gan_convolutional/history.csv") 120 | 121 | generator.save("output/gan_convolutional/generator.h5") 122 | discriminator.save("output/gan_convolutional/discriminator.h5") 123 | -------------------------------------------------------------------------------- 
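``mnist_utils.py`` is listed in the tree above but not reproduced in this
listing. Given the identical inline loader in ``example_gan_convolutional.py``
above, it presumably contains the same MNIST helper (a sketch, not the
verbatim file):

.. code:: python

    import numpy as np
    from keras.datasets import mnist

    def mnist_process(x):
        return x.astype(np.float32) / 255.0

    def mnist_data():
        (xtrain, ytrain), (xtest, ytest) = mnist.load_data()
        return mnist_process(xtrain), mnist_process(xtest)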
/examples/example_gan_unrolled.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | 3 | # This line allows mpl to run with no DISPLAY defined 4 | mpl.use('Agg') 5 | from example_gan import example_gan 6 | from keras_adversarial.unrolled_optimizer import UnrolledAdversarialOptimizer 7 | from keras.optimizers import Adam 8 | from example_gan import model_generator, model_discriminator 9 | import os 10 | 11 | 12 | def example_gan_unrolled(path, depth_g, depth_d): 13 | # z \in R^100 14 | latent_dim = 100 15 | # x \in R^{28x28} 16 | input_shape = (28, 28) 17 | # generator (z -> x) 18 | generator = model_generator(latent_dim, input_shape, hidden_dim=512, batch_norm_mode=1) 19 | # discriminator (x -> y) 20 | discriminator = model_discriminator(input_shape, hidden_dim=512, dropout=0, batch_norm_mode=1) 21 | example_gan(UnrolledAdversarialOptimizer(depth_g=depth_g, depth_d=depth_d), path, 22 | opt_g=Adam(1e-4, decay=1e-4), 23 | opt_d=Adam(1e-3, decay=1e-4), 24 | nb_epoch=50, generator=generator, discriminator=discriminator, 25 | latent_dim=latent_dim) 26 | 27 | 28 | def example(name, depth_g, depth_d): 29 | path = "output/unrolled_gan" 30 | example_gan_unrolled(os.path.join(path, name), depth_g, depth_d) 31 | 32 | 33 | if __name__ == "__main__": 34 | example("k_0_0", 0, 0) 35 | example("k_8_8", 8, 8) 36 | example("k_16_16", 16, 16) 37 | example("k_8_0", 8, 0) 38 | example("k_0_8", 0, 8) 39 | example("k_16_8", 16, 8) 40 | example("k_32_32", 32, 32) 41 | -------------------------------------------------------------------------------- /examples/example_gan_unrolled_hinge.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | 3 | # This line allows mpl to run with no DISPLAY defined 4 | mpl.use('Agg') 5 | from example_gan import example_gan 6 | from keras_adversarial.unrolled_optimizer import UnrolledAdversarialOptimizer 7 | from keras.optimizers import Adam 8 | from example_gan import model_generator, model_discriminator 9 | from keras_adversarial import gan_targets_hinge 10 | import os 11 | 12 | 13 | def example_gan_unrolled_hinge(path, depth_g, depth_d, clipvalue=2.0): 14 | # z \in R^100 15 | latent_dim = 100 16 | # x \in R^{28x28} 17 | input_shape = (28, 28) 18 | # generator (z -> x) 19 | generator = model_generator(latent_dim, input_shape, hidden_dim=512, batch_norm_mode=-1) 20 | # discriminator (x -> y) 21 | discriminator = model_discriminator(input_shape, output_activation='linear', hidden_dim=512, batch_norm_mode=-1, 22 | dropout=0) 23 | example_gan(UnrolledAdversarialOptimizer(depth_g=depth_g, depth_d=depth_d), path, 24 | opt_g=Adam(1e-4, decay=1e-4, clipvalue=clipvalue), 25 | opt_d=Adam(1e-3, decay=1e-4, clipvalue=clipvalue), 26 | nb_epoch=50, generator=generator, discriminator=discriminator, 27 | latent_dim=latent_dim, loss="squared_hinge", targets=gan_targets_hinge) 28 | 29 | 30 | def example(name, depth_g, depth_d, clipvalue=2.0): 31 | path = "output/unrolled_gan_hinge" 32 | example_gan_unrolled_hinge(os.path.join(path, name), depth_g, depth_d, clipvalue) 33 | 34 | 35 | if __name__ == "__main__": 36 | example("k_0_0", 0, 0) 37 | example("k_8_8_clip_2", 8, 8, 2) 38 | example("k_8_8_clip_0.5", 8, 8, 0.5) 39 | example("k_8_8_clip_0", 8, 8, 0) 40 | example("k_16_16", 16, 16) 41 | example("k_16_16_clip_0", 16, 16, 0) 42 | example("k_16_16_clip_0.5", 16, 16, 0.5) 43 | example("k_16_16_clip_10", 16, 16, 10) 44 | example("k_32_32", 32, 32) 45 | example("k_1_1", 1, 1) 46 |
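# generator-only unrolling (depth_d=0) at increasing depth_g, default clipvalue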
example("k_2_0", 2, 0) 47 | example("k_4_0", 4, 0) 48 | example("k_8_0", 8, 0) 49 | -------------------------------------------------------------------------------- /examples/example_rock_paper_scissors.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["THEANO_FLAGS"] = "mode=FAST_COMPILE,device=cpu,floatX=float32" 4 | 5 | """Example of a two player game, rock paper scissors. 6 | 7 | This game does not converge under simple alternating or simultaneous descent, 8 | but converges using UnrolledAdversarialOptimizer. 9 | 10 | """ 11 | from keras_adversarial.adversarial_optimizers import AdversarialOptimizerSimultaneous, AdversarialOptimizerAlternating 12 | from keras_adversarial.unrolled_optimizer import UnrolledAdversarialOptimizer 13 | from keras_adversarial.adversarial_model import AdversarialModel 14 | from keras.layers import Dense, merge, Input 15 | from keras.models import Model 16 | from keras.optimizers import SGD 17 | from keras.callbacks import LambdaCallback 18 | from keras.regularizers import l2 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | import os 22 | 23 | 24 | def rps_chart(path, a, b): 25 | """Bar chart of two players in rock, paper, scissors""" 26 | fig, ax = plt.subplots() 27 | n = 3 28 | width = 0.35 29 | pad = 1.0 - 2 * width 30 | ind = np.arange(n) 31 | ba = plt.bar(pad / 2 + ind, a, width=width, color='r') 32 | bb = plt.bar(pad / 2 + ind + width, b, width=width, color='g') 33 | ax.set_ylabel('Frequency') 34 | ax.set_xticks(pad / 2 + ind + width) 35 | ax.set_xticklabels(("Rock", "Paper", "Scissors")) 36 | fig.legend((ba, bb), ("Player A", "Player B")) 37 | ax.set_ylim([0, 1]) 38 | if not os.path.exists(os.path.dirname(path)): 39 | os.makedirs(os.path.dirname(path)) 40 | fig.savefig(path) 41 | plt.close(fig) 42 | 43 | 44 | def experiment(opt, path): 45 | """Train two players to play rock, paper, scissors using a given optimizer""" 46 | x = Input((1,), name="x") 47 | player_a = Dense(3, activation='softmax', name="player_a", bias=False, W_regularizer=l2(1e-2)) 48 | player_b = Dense(3, activation='softmax', name="player_b", bias=False, W_regularizer=l2(1e-2)) 49 | 50 | action_a = player_a(x) 51 | action_b = player_b(x) 52 | 53 | def rps(z): 54 | u = z[0] 55 | v = z[1] 56 | return u[:, 0] * v[:, 2] + u[:, 1] * v[:, 0] + u[:, 2] * v[:, 1] 57 | 58 | model_a = Model(x, merge([action_a, action_b], mode=rps, output_shape=lambda z: (z[0][0], 1))) 59 | model_b = Model(x, merge([action_b, action_a], mode=rps, output_shape=lambda z: (z[0][0], 1))) 60 | 61 | adversarial_model = AdversarialModel(player_models=[model_a, model_b], 62 | player_params=[[player_a.W], [player_b.W]], 63 | player_names=["a", "b"]) 64 | adversarial_model.adversarial_compile(opt, 65 | player_optimizers=[SGD(1), SGD(1)], 66 | loss="mean_absolute_error") 67 | param_model = Model(x, [action_a, action_b]) 68 | 69 | def print_params(epoch, logs): 70 | params = param_model.predict(np.ones((1, 1))) 71 | a = params[0].ravel() 72 | b = params[1].ravel() 73 | print("Epoch: {}, A: {}, B: {}".format(epoch, a, b)) 74 | imgpath = os.path.join(path, "epoch-{:03d}.png".format(epoch)) 75 | rps_chart(imgpath, a, b) 76 | 77 | cb = LambdaCallback(on_epoch_begin=print_params) 78 | batch_count = 5 79 | adversarial_model.fit(np.ones((batch_count, 1)), 80 | [np.ones((batch_count, 1)), np.ones((batch_count, 1))], 81 | nb_epoch=120, callbacks=[cb], verbose=0, batch_size=1) 82 | 83 | 84 | if __name__ == "__main__": 85 | 
experiment(AdversarialOptimizerSimultaneous(), "output/rock_paper_scissors/simultaneous") 86 | experiment(AdversarialOptimizerAlternating(), "output/rock_paper_scissors/alternating") 87 | experiment(UnrolledAdversarialOptimizer(depth_d=30, depth_g=30), "output/rock_paper_scissors/unrolled") 88 | experiment(UnrolledAdversarialOptimizer(depth_d=0, depth_g=30), "output/rock_paper_scissors/unrolled_player_a") 89 | -------------------------------------------------------------------------------- /examples/image_utils.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | import numpy as np 3 | from keras.layers import Input, Reshape 4 | 5 | 6 | def dim_ordering_fix(x): 7 | if K.image_dim_ordering() == 'th': 8 | return x 9 | else: 10 | return np.transpose(x, (0, 2, 3, 1)) 11 | 12 | 13 | def dim_ordering_unfix(x): 14 | if K.image_dim_ordering() == 'th': 15 | return x 16 | else: 17 | return np.transpose(x, (0, 3, 1, 2)) 18 | 19 | 20 | def dim_ordering_shape(input_shape): 21 | if K.image_dim_ordering() == 'th': 22 | return input_shape 23 | else: 24 | return (input_shape[1], input_shape[2], input_shape[0]) 25 | 26 | 27 | def dim_ordering_input(input_shape, name): 28 | if K.image_dim_ordering() == 'th': 29 | return Input(input_shape, name=name) 30 | else: 31 | return Input((input_shape[1], input_shape[2], input_shape[0]), name=name) 32 | 33 | 34 | def dim_ordering_reshape(k, w, **kwargs): 35 | if K.image_dim_ordering() == 'th': 36 | return Reshape((k, w, w), **kwargs) 37 | else: 38 | return Reshape((w, w, k), **kwargs) 39 | 40 | 41 | def channel_axis(): 42 | if K.image_dim_ordering() == 'th': 43 | return 1 44 | else: 45 | return 3 46 | -------------------------------------------------------------------------------- /examples/mnist_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras.datasets import mnist 3 | 4 | 5 | def mnist_process(x): 6 | x = x.astype(np.float32) / 255.0 7 | return x 8 | 9 | 10 | def mnist_data(): 11 | (xtrain, ytrain), (xtest, ytest) = mnist.load_data() 12 | return mnist_process(xtrain), mnist_process(xtest) 13 | -------------------------------------------------------------------------------- /keras_adversarial/__init__.py: -------------------------------------------------------------------------------- 1 | from .adversarial_model import AdversarialModel 2 | from .adversarial_optimizers import AdversarialOptimizerAlternating 3 | from .adversarial_optimizers import AdversarialOptimizerSimultaneous, AdversarialOptimizer 4 | from .adversarial_optimizers import AdversarialOptimizerScheduled 5 | from .adversarial_utils import gan_targets, build_gan, normal_latent_sampling, eliminate_z, fix_names, simple_gan 6 | from .adversarial_utils import n_choice, simple_bigan, gan_targets_hinge 7 | -------------------------------------------------------------------------------- /keras_adversarial/adversarial_model.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy as np 4 | from keras import backend as K 5 | from keras import optimizers 6 | from keras.models import Model 7 | 8 | from .adversarial_utils import fix_names, merge_updates 9 | from .legacy import keras_2 10 | 11 | 12 | class AdversarialModel(Model): 13 | """ 14 | Adversarial training for multi-player games. 15 | Given a base model with n targets and k players, create a model with n*k targets. 
16 | Each player optimizes loss on that player's targets. 17 | """ 18 | 19 | def __init__(self, player_params, base_model=None, player_models=None, player_names=None): 20 | """ 21 | Initialize adversarial model. Specify base_model or player_models, not both. 22 | :param player_params: list of player parameters for each player (shared variables) 23 | :param base_model: base model will be duplicated for each player to create player models 24 | :param player_models: model for each player 25 | :param player_names: names of each player (optional) 26 | """ 27 | 28 | assert (len(player_params) > 0) 29 | self.player_params = player_params 30 | self.player_count = len(self.player_params) 31 | if player_names is None: 32 | player_names = ["player_{}".format(i) for i in range(self.player_count)] 33 | assert (len(player_names) == self.player_count) 34 | self.player_names = player_names 35 | 36 | self.generator_optimizer = None 37 | self.discriminator_optimizer = None 38 | self.loss = None 39 | self.total_loss = None 40 | self.optimizer = None 41 | self._function_kwargs = None 42 | if base_model is None and player_models is None: 43 | raise ValueError("Please specify either base_model or player_models") 44 | if base_model is not None and player_models is not None: 45 | raise ValueError("Specify base_model or player_models, not both") 46 | if base_model is not None: 47 | self.layers = [] 48 | for i in range(self.player_count): 49 | # duplicate base model 50 | model = Model(base_model.inputs, 51 | fix_names(base_model(base_model.inputs), base_model.output_names)) 52 | # add model to list 53 | self.layers.append(model) 54 | if player_models is not None: 55 | assert (len(player_models) == self.player_count) 56 | self.layers = player_models 57 | 58 | def adversarial_compile(self, adversarial_optimizer, player_optimizers, loss, player_compile_kwargs=None, 59 | **kwargs): 60 | """ 61 | Configures the learning process. 
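Call this in place of the usual compile().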
62 | :param adversarial_optimizer: instance of AdversarialOptimizer 63 | :param player_optimizers: list of optimizers for each player 64 | :param loss: loss function or function name 65 | :param player_compile_kwargs: list of additional arguments to model compilation for each player 66 | :param kwargs: additional arguments to function compilation 67 | :return: 68 | """ 69 | self._function_kwargs = kwargs 70 | self.adversarial_optimizer = adversarial_optimizer 71 | assert (len(player_optimizers) == self.player_count) 72 | 73 | self.optimizers = [optimizers.get(optimizer) for optimizer in player_optimizers] 74 | self.loss = loss 75 | self.optimizer = None 76 | 77 | if player_compile_kwargs is None: 78 | player_compile_kwargs = [{} for _ in self.layers] 79 | 80 | # Build player models 81 | for opt, model, compile_kwargs in zip(self.optimizers, self.layers, player_compile_kwargs): 82 | model.compile(opt, loss=self.loss, **compile_kwargs) 83 | 84 | self.train_function = None 85 | self.test_function = None 86 | 87 | # Inputs are same for each model 88 | def filter_inputs(inputs): 89 | return inputs 90 | 91 | self.internal_input_shapes = filter_inputs(self.layers[0].internal_input_shapes) 92 | self.input_names = filter_inputs(self.layers[0].input_names) 93 | self.inputs = filter_inputs(self.layers[0].inputs) 94 | 95 | # Outputs are concatenated player models 96 | models = self.layers 97 | 98 | def collect(f): 99 | return list(itertools.chain.from_iterable(f(m) for m in models)) 100 | 101 | self.internal_output_shapes = collect(lambda m: m.internal_output_shapes) 102 | self.loss_functions = collect(lambda m: m.loss_functions) 103 | 104 | self.targets = collect(lambda m: m.targets) 105 | self.outputs = collect(lambda m: m.outputs) 106 | self.sample_weights = collect(lambda m: m.sample_weights) 107 | self.sample_weight_modes = collect(lambda m: m.sample_weight_modes) 108 | # for each target, output name is {player}_{target} 109 | self.output_names = [] 110 | for i in range(self.player_count): 111 | for name in models[i].output_names: 112 | self.output_names.append("{}_{}".format(self.player_names[i], name)) 113 | # for each metric, metric name is {player}_{metric} 114 | self.metrics_names = ["loss"] 115 | for i in range(self.player_count): 116 | for name in models[i].metrics_names: 117 | self.metrics_names.append("{}_{}".format(self.player_names[i], name)) 118 | 119 | # total loss is sum of losses 120 | self.total_loss = np.float32(0) 121 | for model in models: 122 | self.total_loss += model.total_loss 123 | 124 | # Keras-2 125 | self._feed_loss_fns = self.loss_functions 126 | self._feed_inputs = self.inputs 127 | self._feed_input_names = self.input_names 128 | self._feed_input_shapes = self.internal_input_shapes 129 | self._feed_outputs = self.outputs 130 | self._feed_output_names = self.output_names 131 | self._feed_output_shapes = self.internal_output_shapes 132 | self._feed_sample_weights = self.sample_weights 133 | self._feed_sample_weight_modes = self.sample_weight_modes 134 | 135 | @property 136 | def constraints(self): 137 | if keras_2: 138 | return [] 139 | else: 140 | return list(itertools.chain.from_iterable(model.constraints for model in self.layers)) 141 | 142 | @property 143 | def updates(self): 144 | return merge_updates(list(itertools.chain.from_iterable(model.updates for model in self.layers))) 145 | 146 | @property 147 | def regularizers(self): 148 | return list(itertools.chain.from_iterable(model.regularizers for model in self.layers)) 149 | 150 | def _make_train_function(self): 
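"""Lazily build the train function by delegating per-player updates to the adversarial optimizer."""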
151 | if not hasattr(self, 'train_function'): 152 | raise Exception('You must compile your model before using it.') 153 | if self.train_function is None: 154 | inputs = self.inputs + self.targets + self.sample_weights 155 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int): 156 | inputs += [K.learning_phase()] 157 | outputs = [self.total_loss] 158 | outputs += list(itertools.chain.from_iterable( 159 | [model.total_loss] + model.metrics_tensors 160 | for model in self.layers)) 161 | 162 | # returns loss and metrics. Updates weights at each call. 163 | constraints = [{} for model in self.layers] if keras_2 else [model.constraints for model in self.layers] 164 | self.train_function = self.adversarial_optimizer.make_train_function(inputs, outputs, 165 | [model.total_loss for model in 166 | self.layers], 167 | self.player_params, 168 | self.optimizers, 169 | constraints, 170 | self.updates, 171 | self._function_kwargs) 172 | 173 | def _make_test_function(self): 174 | if not hasattr(self, 'test_function'): 175 | raise Exception('You must compile your model before using it.') 176 | if self.test_function is None: 177 | inputs = self.inputs + self.targets + self.sample_weights 178 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int): 179 | inputs += [K.learning_phase()] 180 | outputs = [self.total_loss] 181 | outputs += list(itertools.chain.from_iterable( 182 | [model.total_loss] + model.metrics_tensors 183 | for model in self.layers)) 184 | self.test_function = K.function(inputs, 185 | outputs, 186 | updates=self.state_updates, 187 | **self._function_kwargs) 188 | -------------------------------------------------------------------------------- /keras_adversarial/adversarial_optimizers.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import keras.backend as K 4 | 5 | from .legacy import get_updates 6 | 7 | 8 | class AdversarialOptimizer(object): 9 | __metaclass__ = ABCMeta 10 | 11 | @abstractmethod 12 | def make_train_function(self, inputs, outputs, losses, params, optimizers, constraints, model_updates, 13 | function_kwargs): 14 | """ 15 | Construct function that updates weights and returns losses. 16 | :param inputs: function inputs 17 | :param outputs: function outputs 18 | :param losses: player losses 19 | :param params: player parameters 20 | :param optimizers: player optimizers 21 | :param constraints: player constraints 22 | :param function_kwargs: function kwargs 23 | :return: 24 | """ 25 | pass 26 | 27 | 28 | class AdversarialOptimizerSimultaneous(object): 29 | """ 30 | Perform simultaneous updates for each player in the game. 31 | """ 32 | 33 | def make_train_function(self, inputs, outputs, losses, params, optimizers, constraints, model_updates, 34 | function_kwargs): 35 | return K.function(inputs, 36 | outputs, 37 | updates=self.call(losses, params, optimizers, constraints) + model_updates, 38 | **function_kwargs) 39 | 40 | def call(self, losses, params, optimizers, constraints): 41 | updates = [] 42 | for loss, param, optimizer, constraint in zip(losses, params, optimizers, constraints): 43 | updates += optimizer.get_updates(param, constraint, loss) 44 | return updates 45 | 46 | 47 | class AdversarialOptimizerAlternating(object): 48 | """ 49 | Perform round-robin updates for each player in the game. Each player takes a turn. 50 | Take each batch and run that batch through each of the models. All models are trained on each batch. 
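(AdversarialOptimizerScheduled, below, instead trains only the scheduled player on each batch.)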
51 | """ 52 | 53 | def __init__(self, reverse=False): 54 | """ 55 | Initialize optimizer. 56 | :param reverse: players take turns in reverse order 57 | """ 58 | self.reverse = reverse 59 | 60 | def make_train_function(self, inputs, outputs, losses, params, optimizers, constraints, model_updates, 61 | function_kwargs): 62 | funcs = [] 63 | for loss, param, optimizer, constraint in zip(losses, params, optimizers, constraints): 64 | updates = optimizer.get_updates(param, constraint, loss) 65 | funcs.append(K.function(inputs, [], updates=updates, **function_kwargs)) 66 | output_func = K.function(inputs, outputs, updates=model_updates, **function_kwargs) 67 | if self.reverse: 68 | funcs = funcs.reverse() 69 | 70 | def train(_inputs): 71 | # update each player 72 | for func in funcs: 73 | func(_inputs) 74 | # return output 75 | return output_func(_inputs) 76 | 77 | return train 78 | 79 | 80 | class AdversarialOptimizerScheduled(object): 81 | """ 82 | Perform updates according to a schedule. 83 | For example, [0,0,1] will train player 0 on batches 0,1,3,4,6,7... and player 1 on batches 2,5,8... 84 | """ 85 | 86 | def __init__(self, schedule): 87 | """ 88 | Initialize optimizer. 89 | :param schedule: Schedule of updates 90 | """ 91 | assert len(schedule) > 0 92 | self.schedule = schedule 93 | self.iter = 0 94 | 95 | def make_train_function(self, inputs, outputs, losses, params, optimizers, constraints, model_updates, 96 | function_kwargs): 97 | funcs = [] 98 | for loss, param, optimizer, constraint in zip(losses, params, optimizers, constraints): 99 | updates = get_updates(optimizer=optimizer, params=param, constraints=constraint, loss=loss) 100 | funcs.append(K.function(inputs, outputs, updates=updates + model_updates, **function_kwargs)) 101 | 102 | def train(_inputs): 103 | self.iter += 1 104 | if self.iter == len(self.schedule): 105 | self.iter = 0 106 | func = funcs[self.schedule[self.iter]] 107 | return func(_inputs) 108 | 109 | return train 110 | -------------------------------------------------------------------------------- /keras_adversarial/adversarial_utils.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | import numpy as np 3 | from keras.layers import Activation, Lambda 4 | from keras.models import Model 5 | from six import iteritems 6 | 7 | from .backend import unpack_assignment, variable_key 8 | 9 | 10 | def build_gan(generator, discriminator, name="gan"): 11 | """ 12 | Build GAN from generator and discriminator 13 | Model is (z, x) -> (yfake, yreal) 14 | :param generator: Model (z -> x) 15 | :param discriminator: Model (x -> y) 16 | :return: GAN model 17 | """ 18 | yfake = Activation("linear", name="yfake")(discriminator(generator(generator.inputs))) 19 | yreal = Activation("linear", name="yreal")(discriminator(discriminator.inputs)) 20 | model = Model(generator.inputs + discriminator.inputs, [yfake, yreal], name=name) 21 | return model 22 | 23 | 24 | def eliminate_z(gan, latent_sampling): 25 | """ 26 | Eliminate z from GAN using latent_sampling 27 | :param gan: model with 2 inputs: z, x 28 | :param latent_sampling: layer that samples z with same batch size as x 29 | :return: Model x -> gan(latent_sampling(x), x) 30 | """ 31 | x = gan.inputs[1] 32 | z = latent_sampling(x) 33 | model = Model(x, fix_names(gan([z, x]), gan.output_names), name=gan.name) 34 | return model 35 | 36 | 37 | def simple_gan(generator, discriminator, latent_sampling): 38 | # build basic gan 39 | gan = build_gan(generator, discriminator) 40 | 
# generate z on gpu, eliminate one input 41 | if latent_sampling is None: 42 | return gan 43 | else: 44 | return eliminate_z(gan, latent_sampling) 45 | 46 | 47 | def simple_bigan(generator, encoder, discriminator, latent_sampling=None): 48 | """ 49 | Construct BiGAN x -> yfake, yreal 50 | :param generator: model z->x 51 | :param encoder: model x->z 52 | :param discriminator: model z,x->y (z must be first) 53 | :param latent_sampling: layer for sampling from latent space 54 | :return: 55 | """ 56 | if latent_sampling is None: 57 | zfake = generator.inputs[0] 58 | else: 59 | zfake = latent_sampling(discriminator.inputs[1]) 60 | xreal = discriminator.inputs[1] 61 | xfake = generator(zfake) 62 | zreal = encoder(xreal) 63 | yfake = discriminator([zfake, xfake]) 64 | yreal = discriminator([zreal, xreal]) 65 | if latent_sampling is None: 66 | inputs = [zfake, xreal] 67 | else: 68 | inputs = [xreal] 69 | return Model(inputs, fix_names([yfake, yreal], ["yfake", "yreal"]), name="bigan") 70 | 71 | 72 | def fix_names(outputs, names): 73 | if not isinstance(outputs, list): 74 | outputs = [outputs] 75 | if not isinstance(names, list): 76 | names = [names] 77 | return [Activation('linear', name=name)(output) for output, name in zip(outputs, names)] 78 | 79 | 80 | def gan_targets(n): 81 | """ 82 | Standard training targets 83 | [generator_fake, generator_real, discriminator_fake, discriminator_real] = [1, 0, 0, 1] 84 | :param n: number of samples 85 | :return: array of targets 86 | """ 87 | generator_fake = np.ones((n, 1)) 88 | generator_real = np.zeros((n, 1)) 89 | discriminator_fake = np.zeros((n, 1)) 90 | discriminator_real = np.ones((n, 1)) 91 | return [generator_fake, generator_real, discriminator_fake, discriminator_real] 92 | 93 | 94 | def gan_targets_hinge(n): 95 | """ 96 | Standard training targets for hinge loss 97 | [generator_fake, generator_real, discriminator_fake, discriminator_real] = [1, -1, -1, 1] 98 | :param n: number of samples 99 | :return: array of targets 100 | """ 101 | generator_fake = np.ones((n, 1)) 102 | generator_real = np.ones((n, 1)) * -1 103 | discriminator_fake = np.ones((n, 1)) * -1 104 | discriminator_real = np.ones((n, 1)) 105 | return [generator_fake, generator_real, discriminator_fake, discriminator_real] 106 | 107 | 108 | def normal_latent_sampling(latent_shape): 109 | """ 110 | Sample from normal distribution 111 | :param latent_shape: shape of latent samples (excluding the batch dimension) 112 | :return: normal samples, shape=(n,)+latent_shape 113 | """ 114 | return Lambda(lambda x: K.random_normal((K.shape(x)[0],) + latent_shape), 115 | output_shape=lambda x: ((x[0],) + latent_shape)) 116 | 117 | 118 | def uniform_latent_sampling(latent_shape, low=0.0, high=1.0): 119 | """ 120 | Sample from uniform distribution 121 | :param latent_shape: shape of latent samples (excluding the batch dimension) 122 | :return: uniform samples, shape=(n,)+latent_shape 123 | """ 124 | return Lambda(lambda x: K.random_uniform((K.shape(x)[0],) + latent_shape, low, high), 125 | output_shape=lambda x: ((x[0],) + latent_shape)) 126 | 127 | 128 | def n_choice(x, n): 129 | return x[np.random.choice(x.shape[0], size=n, replace=False)] 130 | 131 | 132 | def merge_updates(updates): 133 | """Average repeated updates of the same variable""" 134 | merged_updates = {} 135 | for update in updates: 136 | variable, value = unpack_assignment(update) 137 | key = variable_key(variable) 138 | if key not in merged_updates: 139 | merged_updates[key] = [variable, []] 140 | merged_updates[key][1].append(value) 141 | ret = [] 142 | for k, v in iteritems(merged_updates): 143 | variable = v[0] 144 |
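# every value assigned to this variable; a single update is kept as-is, duplicates are averaged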
values = v[1] 145 | n = len(values) 146 | if n == 1: 147 | ret.append(K.update(variable, values[0])) 148 | else: 149 | ret.append(K.update(variable, sum(values) / n)) 150 | return ret 151 | -------------------------------------------------------------------------------- /keras_adversarial/backend/__init__.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | 3 | if K.backend() == "tensorflow": 4 | from .tensorflow_backend import unpack_assignment, clone_replace, map_params, variable_key 5 | else: 6 | from .theano_backend import unpack_assignment, clone_replace, map_params, variable_key 7 | 8 | 9 | def unpack_assignments(assignments): 10 | return [unpack_assignment(a) for a in assignments] 11 | -------------------------------------------------------------------------------- /keras_adversarial/backend/tensorflow_backend.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from six import iterkeys 3 | from tensorflow.contrib.graph_editor import select 4 | from tensorflow.contrib.graph_editor import util 5 | from tensorflow.python.framework import ops as tf_ops 6 | 7 | 8 | def unpack_assignment(a): 9 | if isinstance(a, (list, tuple)): 10 | assert (len(a) == 2) 11 | return a 12 | elif isinstance(a, tf.Tensor): 13 | assert (a.op.type in ['Assign', 'AssignAdd', 'AssignSub']) 14 | if a.op.type == 'Assign': 15 | return a.op.inputs[0], a.op.inputs[1] 16 | elif a.op.type == 'AssignAdd': 17 | return a.op.inputs[0], a.op.inputs[0] + a.op.inputs[1] 18 | elif a.op.type == 'AssignSub': 19 | return a.op.inputs[0], a.op.inputs[0] - a.op.inputs[1] 20 | else: 21 | raise ValueError("Unsupported operation: {}".format(a.op.type)) 22 | else: 23 | raise ValueError("Unsupported assignment object type: {}".format(type(a))) 24 | 25 | 26 | def map_params(params): 27 | return [x.op.outputs[0] for x in params] 28 | 29 | 30 | def clone_replace(f, replace): 31 | flatten_target_ts = util.flatten_tree(f) 32 | graph = util.get_unique_graph(flatten_target_ts, check_types=(tf_ops.Tensor)) 33 | control_ios = util.ControlOutputs(graph) 34 | ops = select.get_walks_intersection_ops(list(iterkeys(replace)), 35 | flatten_target_ts, 36 | control_ios=control_ios) 37 | if not ops: 38 | # this happens with disconnected inputs 39 | return f 40 | else: 41 | return tf.contrib.graph_editor.graph_replace(f, replace) 42 | 43 | 44 | def variable_key(a): 45 | if hasattr(a, "op"): 46 | return a.op 47 | else: 48 | return a 49 | -------------------------------------------------------------------------------- /keras_adversarial/backend/tensorflow_monkeypatch.py: -------------------------------------------------------------------------------- 1 | import keras.backend 2 | 3 | """ 4 | Import this file to monkeypatch Keras's TensorFlow backend so that updates stay as (variable, value) tuples 5 | and are only converted to tf.assign inside K.function. 6 | Makes unpacking and inspecting updates in tensorflow much cleaner. 7 | """ 8 | 9 | def update(x, new_x): 10 | return (x, new_x) 11 | 12 | 13 | def update_add(x, increment): 14 | return (x, x + increment) 15 | 16 | 17 | def update_sub(x, decrement): 18 | return (x, x - decrement) 19 | 20 | 21 | def moving_average_update(variable, value, momentum): 22 | return (variable, variable * momentum + value * (1.
- momentum)) 23 | 24 | 25 | keras.backend.update = update 26 | keras.backend.update_add = update_add 27 | keras.backend.update_sub = update_sub 28 | keras.backend.moving_average_update = moving_average_update 29 | -------------------------------------------------------------------------------- /keras_adversarial/backend/theano_backend.py: -------------------------------------------------------------------------------- 1 | from theano import clone 2 | 3 | 4 | def unpack_assignment(a): 5 | return a 6 | 7 | 8 | def map_params(params): 9 | return params 10 | 11 | 12 | def clone_replace(f, replace): 13 | return clone(f, replace=replace) 14 | 15 | 16 | def variable_key(a): 17 | return a 18 | -------------------------------------------------------------------------------- /keras_adversarial/image_grid.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from matplotlib import pyplot as plt, gridspec 4 | 5 | 6 | def write_image_grid(filepath, imgs, figsize=None, cmap='gray'): 7 | directory = os.path.dirname(os.path.abspath(filepath)) 8 | if not os.path.exists(directory): 9 | os.makedirs(directory) 10 | fig = create_image_grid(imgs, figsize, cmap=cmap) 11 | fig.savefig(filepath) 12 | plt.close(fig) 13 | 14 | 15 | def create_image_grid(imgs, figsize=None, cmap='gray'): 16 | n = imgs.shape[0] 17 | m = imgs.shape[1] 18 | if figsize is None: 19 | figsize = (n, m) 20 | fig = plt.figure(figsize=figsize) 21 | gs1 = gridspec.GridSpec(n, m) 22 | gs1.update(wspace=0.025, hspace=0.025) # set the spacing between axes. 23 | for i in range(n): 24 | for j in range(m): 25 | ax = plt.subplot(gs1[i, j]) 26 | img = imgs[i, j, :] 27 | ax.imshow(img, cmap=cmap) 28 | ax.axis('off') 29 | return fig 30 | -------------------------------------------------------------------------------- /keras_adversarial/image_grid_callback.py: -------------------------------------------------------------------------------- 1 | from keras.callbacks import Callback 2 | 3 | from .image_grid import write_image_grid 4 | 5 | 6 | class ImageGridCallback(Callback): 7 | def __init__(self, image_path, generator, cmap='gray'): 8 | self.image_path = image_path 9 | self.generator = generator 10 | self.cmap = cmap 11 | 12 | def on_epoch_end(self, epoch, logs=None): 13 | xsamples = self.generator() 14 | image_path = self.image_path.format(epoch) 15 | write_image_grid(image_path, xsamples, cmap=self.cmap) 16 | -------------------------------------------------------------------------------- /keras_adversarial/legacy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions to avoid warnings while testing both Keras 1 and 2.
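Each helper checks the installed Keras major version once (the keras_2 flag below) and dispatches to the matching API.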
3 | """ 4 | import keras 5 | 6 | keras_2 = int(keras.__version__.split(".")[0]) > 1 # Keras > 1 7 | 8 | 9 | def fit_generator(model, generator, epochs, steps_per_epoch): 10 | if keras_2: 11 | model.fit_generator(generator, epochs=epochs, steps_per_epoch=steps_per_epoch) 12 | else: 13 | model.fit_generator(generator, nb_epoch=epochs, samples_per_epoch=steps_per_epoch) 14 | 15 | 16 | def fit(model, x, y, nb_epoch=10, *args, **kwargs): 17 | if keras_2: 18 | return model.fit(x, y, *args, epochs=nb_epoch, **kwargs) 19 | else: 20 | return model.fit(x, y, *args, nb_epoch=nb_epoch, **kwargs) 21 | 22 | 23 | def l1l2(l1=0, l2=0): 24 | if keras_2: 25 | return keras.regularizers.L1L2(l1, l2) 26 | else: 27 | return keras.regularizers.l1l2(l1, l2) 28 | 29 | 30 | def Dense(units, W_regularizer=None, W_initializer='glorot_uniform', **kwargs): 31 | if keras_2: 32 | return keras.layers.Dense(units, kernel_regularizer=W_regularizer, kernel_initializer=W_initializer, **kwargs) 33 | else: 34 | return keras.layers.Dense(units, W_regularizer=W_regularizer, init=W_initializer, **kwargs) 35 | 36 | 37 | def BatchNormalization(mode=0, **kwargs): 38 | if keras_2: 39 | return keras.layers.BatchNormalization(**kwargs) 40 | else: 41 | return keras.layers.BatchNormalization(mode=mode, **kwargs) 42 | 43 | 44 | def Convolution2D(units, w, h, W_regularizer=None, W_initializer='glorot_uniform', border_mode='same', **kwargs): 45 | if keras_2: 46 | return keras.layers.Convolution2D(units, (w, h), padding=border_mode, kernel_regularizer=W_regularizer, 47 | kernel_initializer=W_initializer, 48 | **kwargs) 49 | else: 50 | return keras.layers.Convolution2D(units, w, h, border_mode=border_mode, W_regularizer=W_regularizer, 51 | init=W_initializer, 52 | **kwargs) 53 | 54 | 55 | def AveragePooling2D(pool_size, border_mode='valid', **kwargs): 56 | if keras_2: 57 | return keras.layers.AveragePooling2D(pool_size=pool_size, padding=border_mode, **kwargs) 58 | else: 59 | return keras.layers.AveragePooling2D(pool_size=pool_size, border_mode=border_mode, **kwargs) 60 | 61 | 62 | def get_updates(optimizer, params, constraints, loss): 63 | if keras_2: 64 | return optimizer.get_updates(params, constraints, loss) 65 | else: 66 | return optimizer.get_updates(params=params, loss=loss) 67 | -------------------------------------------------------------------------------- /keras_adversarial/unrolled_optimizer.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | 3 | from .adversarial_optimizers import AdversarialOptimizerSimultaneous 4 | from .backend import unpack_assignments, clone_replace 5 | 6 | 7 | def unroll(updates, uupdates, depth): 8 | replace = {k: v for k, v in unpack_assignments(uupdates)} 9 | updates_t = unpack_assignments(updates) 10 | for i in range(depth): 11 | updates_t = [(k, clone_replace(v, replace)) for k, v in updates_t] 12 | return [K.update(a, b) for a, b in updates_t] 13 | 14 | 15 | class UnrolledAdversarialOptimizer(AdversarialOptimizerSimultaneous): 16 | def __init__(self, depth_g, depth_d): 17 | """ 18 | :param depth_g: Depth to unroll discriminator when updating generator 19 | :param depth_d: Depth to unroll generator when updating discriminator 20 | """ 21 | self.depth_g = depth_g 22 | self.depth_d = depth_d 23 | 24 | def call(self, losses, params, optimizers, constraints): 25 | # Players should be [generator, discriminator] 26 | assert (len(optimizers) == 2) 27 | 28 | updates = [o.get_updates(p, c, l) for o, p, c, l in zip(optimizers, params, constraints, 
losses)] 29 | 30 | gupdates = unroll(updates[0], updates[1], self.depth_g) 31 | dupdates = unroll(updates[1], updates[0], self.depth_d) 32 | 33 | return gupdates + dupdates 34 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | # Configuration of py.test 2 | [pytest] 3 | addopts=-v 4 | --durations=10 5 | 6 | python_functions = *_test 7 | # Do not run tests in the build folder 8 | norecursedirs= build 9 | 10 | # PEP-8 The following are ignored: 11 | # E501 line too long (82 > 79 characters) 12 | # E402 module level import not at top of file - temporary measure to continue adding ros python packaged in sys.path 13 | # E731 do not assign a lambda expression, use a def 14 | 15 | pep8ignore=* E501 \ 16 | * E402 \ 17 | * E731 \ 18 | 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | keras>=1.1.2 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.rst 3 | [bdist_wheel] 4 | universal=1 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | long_description = open('README.rst').read() 4 | version = '0.0.3' 5 | 6 | setup(name='keras-adversarial', 7 | version=version, 8 | description='Adversarial models and optimizers for Keras', 9 | url='https://github.com/bstriner/keras-adversarial', 10 | download_url='https://github.com/bstriner/keras-adversarial/tarball/v{}'.format(version), 11 | author='Ben Striner', 12 | author_email='bstriner@gmail.com', 13 | packages=find_packages(), 14 | install_requires=['Keras'], 15 | keywords=['keras', 'gan', 'adversarial', 'multiplayer'], 16 | license='MIT', 17 | long_description=long_description, 18 | classifiers=[ 19 | # Indicate who your project is intended for 20 | 'Intended Audience :: Developers', 21 | # Pick your license as you wish (should match "license" above) 22 | 'License :: OSI Approved :: MIT License', 23 | 24 | # Specify the Python versions you support here. In particular, ensure 25 | # that you indicate whether you support Python 2, Python 3 or both. 
26 | 'Programming Language :: Python :: 2', 27 | 'Programming Language :: Python :: 3' 28 | ]) 29 | -------------------------------------------------------------------------------- /tests/integration/gan_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from keras.layers import LeakyReLU, Activation 4 | from keras.models import Sequential 5 | from keras.optimizers import Adam 6 | from keras_adversarial import AdversarialModel, simple_gan, gan_targets 7 | from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling 8 | from keras_adversarial.legacy import fit, Dense 9 | 10 | 11 | def model_generator(latent_dim, input_dim, hidden_dim=256): 12 | return Sequential([ 13 | Dense(hidden_dim, name="generator_h1", input_dim=latent_dim), 14 | LeakyReLU(0.2), 15 | Dense(hidden_dim, name="generator_h2"), 16 | LeakyReLU(0.2), 17 | Dense(hidden_dim, name="generator_h3"), 18 | LeakyReLU(0.2), 19 | Dense(input_dim, name="generator_x_flat")], 20 | name="generator") 21 | 22 | 23 | def model_discriminator(input_dim, hidden_dim=256): 24 | return Sequential([ 25 | Dense(hidden_dim, name="discriminator_h1", input_dim=input_dim), 26 | LeakyReLU(0.2), 27 | Dense(hidden_dim, name="discriminator_h2"), 28 | LeakyReLU(0.2), 29 | Dense(hidden_dim, name="discriminator_h3"), 30 | LeakyReLU(0.2), 31 | Dense(1, name="discriminator_y"), 32 | Activation('sigmoid')], 33 | name="discriminator") 34 | 35 | 36 | def gan_model_test(): 37 | latent_dim = 10 38 | input_dim = 5 39 | generator = model_generator(input_dim=input_dim, latent_dim=latent_dim) 40 | discriminator = model_discriminator(input_dim=input_dim) 41 | gan = simple_gan(generator, discriminator, normal_latent_sampling((latent_dim,))) 42 | 43 | # build adversarial model 44 | model = AdversarialModel(base_model=gan, 45 | player_params=[generator.trainable_weights, discriminator.trainable_weights], 46 | player_names=["generator", "discriminator"]) 47 | adversarial_optimizer = AdversarialOptimizerSimultaneous() 48 | opt_g = Adam(1e-4) 49 | opt_d = Adam(1e-3) 50 | loss = 'binary_crossentropy' 51 | model.adversarial_compile(adversarial_optimizer=adversarial_optimizer, 52 | player_optimizers=[opt_g, opt_d], 53 | loss=loss) 54 | 55 | # train model 56 | batch_size = 32 57 | n = batch_size * 8 58 | x = np.random.random((n, input_dim)) 59 | y = gan_targets(n) 60 | fit(model, x, y, nb_epoch=3, batch_size=batch_size) 61 | 62 | 63 | if __name__ == "__main__": 64 | pytest.main([__file__]) 65 | --------------------------------------------------------------------------------
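A hedged sketch (not part of the repository) of a custom optimizer plugging into the AdversarialOptimizer interface defined above; the class name and per-player loss weights are illustrative assumptions, and it mirrors AdversarialOptimizerSimultaneous while reusing the legacy get_updates shim:

# hypothetical extension point: simultaneous descent with one scalar loss weight per player
import keras.backend as K
from keras_adversarial import AdversarialOptimizer
from keras_adversarial.legacy import get_updates


class WeightedAdversarialOptimizer(AdversarialOptimizer):
    """Hypothetical: like AdversarialOptimizerSimultaneous, but each player's loss is scaled."""

    def __init__(self, weights):
        self.weights = weights  # e.g. [1.0, 1.0] reproduces the simultaneous optimizer

    def make_train_function(self, inputs, outputs, losses, params, optimizers, constraints, model_updates,
                            function_kwargs):
        updates = []
        for weight, loss, param, optimizer, constraint in zip(self.weights, losses, params, optimizers, constraints):
            # each player descends its own (scaled) loss, as in the simultaneous optimizer
            updates += get_updates(optimizer=optimizer, params=param, constraints=constraint, loss=weight * loss)
        return K.function(inputs, outputs, updates=updates + model_updates, **function_kwargs)

Passed to adversarial_compile as the adversarial_optimizer argument, this trains every player on each batch, exactly like the simultaneous optimizer when all weights are 1.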