├── ali ├── __init__.py ├── algorithms.py ├── utils.py ├── graphing.py ├── streams.py ├── datasets.py ├── bricks.py ├── mixture_viz.py └── conditional_bricks.py ├── paper ├── .gitignore ├── celeba_samples.png ├── mixture_plot.png ├── svhn_samples.png ├── cifar10_samples.png ├── celeba_interpolations.png ├── svhn_reconstructions.png ├── tiny_imagenet_samples.png ├── celeba_reconstructions.png ├── cifar10_reconstructions.png ├── celeba_conditional_sequence.png ├── tiny_imagenet_reconstructions.png ├── iclr2017_conference.sty ├── bibliography.bib └── fancyhdr.sty ├── theanorc ├── LICENSE ├── setup.py ├── scripts ├── sample ├── reconstruct ├── preprocess_representations ├── interpolate ├── generate_spiral_plots └── generate_mixture_plots ├── .gitignore ├── README.md └── experiments ├── semi_supervised_svhn.py ├── gan_mixture.py ├── ali_svhn.py ├── ali_celeba.py ├── ali_cifar10.py ├── ali_tiny_imagenet.py ├── ali_mixture.py └── ali_celeba_conditional.py /ali/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import algorithms, bricks, datasets, streams, utils 2 | -------------------------------------------------------------------------------- /paper/.gitignore: -------------------------------------------------------------------------------- 1 | *.aux 2 | *.bbl 3 | *.blg 4 | *.fdb_latexmk 5 | *.fls 6 | *.out 7 | *.pdf 8 | -------------------------------------------------------------------------------- /paper/celeba_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IshmaelBelghazi/ALI/HEAD/paper/celeba_samples.png -------------------------------------------------------------------------------- /paper/mixture_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IshmaelBelghazi/ALI/HEAD/paper/mixture_plot.png -------------------------------------------------------------------------------- /paper/svhn_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IshmaelBelghazi/ALI/HEAD/paper/svhn_samples.png -------------------------------------------------------------------------------- /paper/cifar10_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IshmaelBelghazi/ALI/HEAD/paper/cifar10_samples.png -------------------------------------------------------------------------------- /paper/celeba_interpolations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IshmaelBelghazi/ALI/HEAD/paper/celeba_interpolations.png -------------------------------------------------------------------------------- /paper/svhn_reconstructions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IshmaelBelghazi/ALI/HEAD/paper/svhn_reconstructions.png 
-------------------------------------------------------------------------------- /paper/tiny_imagenet_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IshmaelBelghazi/ALI/HEAD/paper/tiny_imagenet_samples.png -------------------------------------------------------------------------------- /paper/celeba_reconstructions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IshmaelBelghazi/ALI/HEAD/paper/celeba_reconstructions.png -------------------------------------------------------------------------------- /paper/cifar10_reconstructions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IshmaelBelghazi/ALI/HEAD/paper/cifar10_reconstructions.png -------------------------------------------------------------------------------- /paper/celeba_conditional_sequence.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IshmaelBelghazi/ALI/HEAD/paper/celeba_conditional_sequence.png -------------------------------------------------------------------------------- /paper/tiny_imagenet_reconstructions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IshmaelBelghazi/ALI/HEAD/paper/tiny_imagenet_reconstructions.png -------------------------------------------------------------------------------- /theanorc: -------------------------------------------------------------------------------- 1 | [global] 2 | device = gpu 3 | floatX = float32 4 | mode = FAST_RUN 5 | 6 | [lib] 7 | cnmem = 0.8 8 | 9 | [dnn.conv] 10 | algo_fwd = time_once 11 | algo_bwd_data = time_once 12 | algo_bwd_filter = time_once 13 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name='ali', 5 | version='0.1.0', 6 | description='Code for the "Adversarially Learned Inference" paper', 7 | long_description='Code for the "Adversarially Learned Inference" paper', 8 | url='https://github.com/IshmaelBelghazi/ALI', 9 | author='Vincent Dumoulin, Ishmael Belghazi', 10 | license='MIT', 11 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers 12 | classifiers=[ 13 | 'Development Status :: 3 - Alpha', 14 | 'Intended Audience :: Science/Research', 15 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 16 | 'Topic :: Scientific/Engineering :: Image Recognition', 17 | 'License :: OSI Approved :: MIT License', 18 | 'Programming Language :: Python :: 2', 19 | 'Programming Language :: Python :: 2.7', 20 | 'Programming Language :: Python :: 3', 21 | 'Programming Language :: Python :: 3.4', 22 | ], 23 | keywords='theano blocks machine learning neural networks deep learning', 24 | packages=find_packages(exclude=['scripts', 'experiments']), 25 | install_requires=['numpy', 'theano', 'blocks', 'fuel'], 26 | zip_safe=False) 27 | -------------------------------------------------------------------------------- /ali/algorithms.py: -------------------------------------------------------------------------------- 1 | """ALI-related training algorithms.""" 2 | from collections import OrderedDict 3 | 4 | import theano 5 | from blocks.algorithms import GradientDescent, CompositeRule, Restrict 6 | 7 | 8 | def ali_algorithm(discriminator_loss, discriminator_parameters, 9 | discriminator_step_rule, generator_loss, 10 | generator_parameters, generator_step_rule): 11 | """Instantiates a training algorithm for ALI. 12 | 13 | Parameters 14 | ---------- 15 | discriminator_loss : tensor variable 16 | Discriminator loss.
17 | discriminator_parameters : list 18 | Discriminator parameters. 19 | discriminator_step_rule : :class:`blocks.algorithms.StepRule` 20 | Discriminator step rule. 21 | generator_loss : tensor variable 22 | Generator loss. 23 | generator_parameters : list 24 | Generator parameters. 25 | generator_step_rule : :class:`blocks.algorithms.StepRule` 26 | Generator step rule. 27 | """ 28 | gradients = OrderedDict()  # parameter -> gradient; each player's loss is differentiated only w.r.t. its own parameters 29 | gradients.update( 30 | zip(discriminator_parameters, 31 | theano.grad(discriminator_loss, discriminator_parameters))) 32 | gradients.update( 33 | zip(generator_parameters, 34 | theano.grad(generator_loss, generator_parameters))) 35 | step_rule = CompositeRule([Restrict(discriminator_step_rule,  # each player's step rule is restricted to that player's parameters 36 | discriminator_parameters), 37 | Restrict(generator_step_rule, 38 | generator_parameters)]) 39 | return GradientDescent( 40 | cost=generator_loss + discriminator_loss,  # NOTE(review): gradients are supplied explicitly, so this summed cost is presumably only used for monitoring - confirm against blocks GradientDescent 41 | gradients=gradients, 42 | parameters=discriminator_parameters + generator_parameters, 43 | step_rule=step_rule) 44 | -------------------------------------------------------------------------------- /ali/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | import random 3 | import string 4 | 5 | import numpy 6 | import theano 7 | from blocks.bricks.bn import SpatialBatchNormalization 8 | from blocks.bricks.conv import Convolutional, ConvolutionalTranspose 9 | 10 | 11 | def name_generator(): 12 | """Returns a random 8-character name.""" 13 | return ''.join(random.choice(string.ascii_uppercase + string.digits) 14 | for _ in range(8)) 15 | 16 | 17 | def get_log_odds(raw_marginals): 18 | """Computes marginal log-odds.""" 19 | marginals = numpy.clip(raw_marginals.mean(axis=0), 1e-7, 1 - 1e-7)  # keep the per-column means strictly inside (0, 1) so the log below stays finite 20 | return numpy.log(marginals / (1 - marginals)).astype(theano.config.floatX) 21 | 22 | 23 | def conv_brick(filter_size, step, num_filters, border_mode='valid'): 24 | """Instantiates a ConvolutionalBrick.""" 25 | return
Convolutional(filter_size=(filter_size, filter_size), 26 | step=(step, step), 27 | border_mode=border_mode, 28 | num_filters=num_filters, 29 | name=name_generator()) 30 | 31 | 32 | def conv_transpose_brick(filter_size, step, num_filters, border_mode='valid'): 33 | """Instantiates a ConvolutionalTranspose brick.""" 34 | return ConvolutionalTranspose(filter_size=(filter_size, filter_size), 35 | step=(step, step), 36 | border_mode=border_mode, 37 | num_filters=num_filters, 38 | name=name_generator()) 39 | 40 | 41 | def bn_brick(): 42 | """Instantiates a SpatialBatchNormalization brick.""" 43 | return SpatialBatchNormalization(name=name_generator()) 44 | 45 | 46 | def as_array(obj, dtype=theano.config.floatX): 47 | """Converts to ndarray of specified dtype""" 48 | return numpy.asarray(obj, dtype=dtype) 49 | -------------------------------------------------------------------------------- /scripts/sample: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | 4 | import theano 5 | from blocks.serialization import load 6 | from matplotlib import cm, pyplot 7 | from mpl_toolkits.axes_grid1 import ImageGrid 8 | 9 | 10 | def main(main_loop, nrows, ncols, save_path=None): 11 | ali, = main_loop.model.top_bricks 12 | input_shape = ali.encoder.get_dim('output')  # shape of one latent code; sizes the prior draw below 13 | z = ali.theano_rng.normal(size=(nrows * ncols,) + input_shape)  # one standard-normal latent per grid cell 14 | x = ali.sample(z) 15 | samples = theano.function([], x)() 16 | 17 | figure = pyplot.figure() 18 | grid = ImageGrid(figure, 111, (nrows, ncols), axes_pad=0.1) 19 | 20 | for sample, axis in zip(samples, grid): 21 | axis.imshow(sample.transpose(1, 2, 0).squeeze(), 22 | cmap=cm.Greys_r, interpolation='nearest') 23 | axis.set_yticklabels(['' for _ in range(sample.shape[1])]) 24 | axis.set_xticklabels(['' for _ in range(sample.shape[2])]) 25 | axis.axis('off') 26 | 27 | if save_path is None: 28 | pyplot.show() 29 | else: 30 | pyplot.savefig(save_path, transparent=True,
bbox_inches='tight') 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser(description="Plot samples.") 35 | parser.add_argument("main_loop_path", type=str, 36 | help="path to the pickled main loop.") 37 | parser.add_argument("--nrows", type=int, default=10, 38 | help="number of rows of samples to display.") 39 | parser.add_argument("--ncols", type=int, default=10, 40 | help="number of columns of samples to display.") 41 | parser.add_argument("--save-path", type=str, default=None, 42 | help="where to save the generated samples.") 43 | args = parser.parse_args() 44 | 45 | with open(args.main_loop_path, 'rb') as src: 46 | main(load(src), args.nrows, args.ncols, args.save_path) 47 | -------------------------------------------------------------------------------- /scripts/reconstruct: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | 4 | import numpy 5 | import theano 6 | from blocks.serialization import load 7 | from matplotlib import cm, pyplot 8 | from mpl_toolkits.axes_grid1 import ImageGrid 9 | from theano import tensor 10 | 11 | from ali import streams 12 | 13 | 14 | def main(main_loop, data_stream, nrows, ncols, save_path): 15 | ali, = main_loop.model.top_bricks 16 | x = tensor.tensor4('features') 17 | examples, = next(data_stream.get_epoch_iterator()) 18 | reconstructions = theano.function([x], ali.reconstruct(x))(examples) 19 | 20 | figure = pyplot.figure() 21 | grid = ImageGrid(figure, 111, (nrows, 2 * ncols), axes_pad=0.1) 22 | images = numpy.empty( 23 | (2 * nrows * ncols,) + examples.shape[1:], dtype=examples.dtype) 24 | images[::2] = examples  # even slots hold originals, odd slots their reconstructions, so each pair sits side by side 25 | images[1::2] = reconstructions 26 | 27 | for image, axis in zip(images, grid): 28 | axis.imshow(image.transpose(1, 2, 0).squeeze(), 29 | cmap=cm.Greys_r, interpolation='nearest') 30 | axis.set_yticklabels(['' for _ in range(image.shape[1])]) 31 | axis.set_xticklabels(['' for _ in range(image.shape[2])]) 32
| axis.axis('off') 33 | 34 | if save_path is None: 35 | pyplot.show() 36 | else: 37 | pyplot.savefig(save_path, transparent=True, bbox_inches='tight') 38 | 39 | 40 | if __name__ == "__main__": 41 | stream_functions = { 42 | 'cifar10': streams.create_cifar10_data_streams, 43 | 'svhn': streams.create_svhn_data_streams, 44 | 'celeba': streams.create_celeba_data_streams, 45 | 'tiny_imagenet': streams.create_tiny_imagenet_data_streams} 46 | parser = argparse.ArgumentParser(description="Plot reconstructions.") 47 | parser.add_argument("which_dataset", type=str, 48 | choices=tuple(stream_functions.keys()), 49 | help="which dataset to compute reconstructions on.") 50 | parser.add_argument("main_loop_path", type=str, 51 | help="path to the pickled main loop.") 52 | parser.add_argument("--nrows", type=int, default=10, 53 | help="number of rows of samples to display.") 54 | parser.add_argument("--ncols", type=int, default=10, 55 | help="number of columns of samples to display.") 56 | parser.add_argument("--save-path", type=str, default=None, 57 | help="where to save the reconstructions.") 58 | args = parser.parse_args() 59 | 60 | with open(args.main_loop_path, 'rb') as src: 61 | main_loop = load(src) 62 | num_examples = args.nrows * args.ncols 63 | rng = numpy.random.RandomState() 64 | _1, _2, data_stream = stream_functions[args.which_dataset](num_examples,  # only the third returned stream is used 65 | num_examples, 66 | rng=rng) 67 | main(main_loop, data_stream, args.nrows, args.ncols, args.save_path) 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### https://raw.github.com/github/gitignore/dad599102e7097c796b794689606874983a77dc1/Python.gitignore 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 |
downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *,cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask instance folder 59 | instance/ 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # IPython Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | venv/ 84 | ENV/ 85 | 86 | # Spyder project settings 87 | .spyderproject 88 | 89 | 90 | ### https://raw.github.com/github/gitignore/dad599102e7097c796b794689606874983a77dc1/Global/Emacs.gitignore 91 | 92 | # -*- mode: gitignore; -*- 93 | *~ 94 | \#*\# 95 | /.emacs.desktop 96 | /.emacs.desktop.lock 97 | *.elc 98 | auto-save-list 99 | tramp 100 | .\#* 101 | 102 | # Org-mode 103 | .org-id-locations 104 | *_archive 105 | 106 | # flymake-mode 107 | *_flymake.* 108 | 109 | # eshell files 110 | /eshell/history 111 | /eshell/lastdir 112 | 113 | # elpa packages 114 | /elpa/ 115 | 116 | # reftex files 117 | *.rel 118 | 119 | # AUCTeX auto folder 120 | /auto/ 121 | 122 | # cask packages 123 | .cask/ 124 | 125 | # Flycheck 126 | flycheck_*.el 127 | 128 | # server auth directory 129 | /server/ 130 | 131 | # projectiles files 132 | .projectile 133 | 134 | ### 
https://raw.github.com/github/gitignore/dad599102e7097c796b794689606874983a77dc1/Global/Vim.gitignore 135 | 136 | # swap 137 | [._]*.s[a-w][a-z] 138 | [._]s[a-w][a-z] 139 | # session 140 | Session.vim 141 | # temporary 142 | .netrwhist 143 | *~ 144 | # auto-generated tag files 145 | tags 146 | 147 | 148 | ### https://raw.github.com/github/gitignore/dad599102e7097c796b794689606874983a77dc1/Global/SublimeText.gitignore 149 | 150 | # cache files for sublime text 151 | *.tmlanguage.cache 152 | *.tmPreferences.cache 153 | *.stTheme.cache 154 | 155 | # workspace files are user-specific 156 | *.sublime-workspace 157 | 158 | # project files should be checked into the repository, unless a significant 159 | # proportion of contributors will probably not be using SublimeText 160 | # *.sublime-project 161 | 162 | # sftp configuration file 163 | sftp-config.json 164 | 165 | 166 | -------------------------------------------------------------------------------- /scripts/preprocess_representations: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | 4 | import h5py 5 | import numpy 6 | import theano 7 | from blocks.graph import ComputationGraph 8 | from blocks.filter import VariableFilter 9 | from blocks.roles import OUTPUT 10 | from blocks.select import Selector 11 | from blocks.serialization import load 12 | from fuel.converters.base import fill_hdf5_file 13 | from fuel.datasets import SVHN 14 | from fuel.streams import DataStream 15 | from fuel.schemes import SequentialScheme 16 | from theano import tensor 17 | 18 | 19 | def preprocess_svhn(main_loop, save_path): 20 | h5file = h5py.File(save_path, mode='w')  # NOTE(review): opened before the feature extraction below and not closed if anything raises; consider a with-block 21 | 22 | ali, = Selector(main_loop.model.top_bricks).select('/ali').bricks 23 | x = tensor.tensor4('features') 24 | y = tensor.imatrix('targets') 25 | params = ali.encoder.apply(x) 26 | mu = params[:, :ali.encoder._nlat]  # first _nlat encoder outputs, used as the latent code 27 | acts = [] 28 | acts += [mu] 29 | acts += VariableFilter(  # plus the outputs of encoder layers [-9], [-6] and [-3] 30 |
bricks=[ali.encoder.layers[-9], ali.encoder.layers[-6], 31 | ali.encoder.layers[-3]], 32 | roles=[OUTPUT])(ComputationGraph([mu]).variables) 33 | output = tensor.concatenate([act.flatten(ndim=2) for act in acts], axis=1)  # one flat feature vector per example 34 | preprocess = theano.function([x, y], [output.flatten(ndim=2), y]) 35 | 36 | train_set = SVHN(2, which_sets=('train',), sources=('features', 'targets')) 37 | train_stream = DataStream.default_stream( 38 | train_set, 39 | iteration_scheme=SequentialScheme(train_set.num_examples, 100)) 40 | train_features, train_targets = map( 41 | numpy.vstack, 42 | list(zip(*[preprocess(*batch) for batch in 43 | train_stream.get_epoch_iterator()]))) 44 | 45 | test_set = SVHN(2, which_sets=('test',), sources=('features', 'targets')) 46 | test_stream = DataStream.default_stream( 47 | test_set, 48 | iteration_scheme=SequentialScheme(test_set.num_examples, 100)) 49 | test_features, test_targets = map( 50 | numpy.vstack, 51 | list(zip(*[preprocess(*batch) for batch in 52 | test_stream.get_epoch_iterator()]))) 53 | 54 | data = (('train', 'features', train_features), 55 | ('test', 'features', test_features), 56 | ('train', 'targets', train_targets), 57 | ('test', 'targets', test_targets)) 58 | fill_hdf5_file(h5file, data) 59 | for i, label in enumerate(('batch', 'feature')):  # name the axes of the stored datasets 60 | h5file['features'].dims[i].label = label 61 | for i, label in enumerate(('batch', 'index')): 62 | h5file['targets'].dims[i].label = label 63 | 64 | h5file.flush() 65 | h5file.close() 66 | 67 | 68 | if __name__ == "__main__": 69 | parser = argparse.ArgumentParser( 70 | description="Preprocess ALI latent representations on SVHN for " 71 | "semi-supervised learning") 72 | parser.add_argument("main_loop_path", type=str, 73 | help="path to the pickled main loop") 74 | parser.add_argument("save_path", type=str, 75 | help="where to save the preprocessed dataset") 76 | args = parser.parse_args() 77 | 78 | with open(args.main_loop_path, 'rb') as src: 79 | preprocess_svhn(load(src), args.save_path) 80
| -------------------------------------------------------------------------------- /scripts/interpolate: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | 4 | import numpy 5 | import theano 6 | from blocks.serialization import load 7 | from matplotlib import cm, pyplot 8 | from mpl_toolkits.axes_grid1 import ImageGrid 9 | from theano import tensor 10 | 11 | from ali import streams 12 | 13 | 14 | def main(main_loop, data_stream, num_pairs, num_steps, save_path): 15 | ali, = main_loop.model.top_bricks 16 | x = tensor.tensor4('features') 17 | z = tensor.tensor4('z') 18 | encode = theano.function([x], ali.encoder.apply(x)) 19 | decode = theano.function([z], ali.decoder.apply(z)) 20 | 21 | it = data_stream.get_epoch_iterator() 22 | (from_x,), (to_x,) = next(it), next(it)  # two consecutive batches give the start and end image of each pair 23 | from_z, to_z = encode(from_x), encode(to_x) 24 | 25 | from_to_tensor = to_z - from_z 26 | between_x_list = [from_x] 27 | for alpha in numpy.linspace(0, 1, num_steps + 1):  # num_steps + 1 mixing weights, both endpoints included 28 | between_z = from_z + alpha * from_to_tensor 29 | between_x_list.append(decode(between_z)) 30 | between_x_list.append(to_x) 31 | 32 | figure = pyplot.figure() 33 | grid = ImageGrid(figure, 111, (num_pairs, num_steps + 3), axes_pad=0.1)  # columns: start image, num_steps + 1 decodings, end image 34 | images = numpy.empty( 35 | (num_pairs * (num_steps + 3),) + between_x_list[0].shape[1:], 36 | dtype=between_x_list[0].dtype) 37 | for i, between_x in enumerate(between_x_list): 38 | images[i::num_steps + 3] = between_x  # stage i fills column i of every row 39 | 40 | for image, axis in zip(images, grid): 41 | axis.imshow(image.transpose(1, 2, 0).squeeze(), 42 | cmap=cm.Greys_r, interpolation='nearest') 43 | axis.set_yticklabels(['' for _ in range(image.shape[1])]) 44 | axis.set_xticklabels(['' for _ in range(image.shape[2])]) 45 | axis.axis('off') 46 | 47 | if save_path is None: 48 | pyplot.show() 49 | else: 50 | pyplot.savefig(save_path, transparent=True, bbox_inches='tight') 51 | 52 | 53 | if __name__ == "__main__": 54 | stream_functions = { 55 |
'cifar10': streams.create_cifar10_data_streams, 56 | 'svhn': streams.create_svhn_data_streams, 57 | 'celeba': streams.create_celeba_data_streams, 58 | 'tiny_imagenet': streams.create_tiny_imagenet_data_streams} 59 | parser = argparse.ArgumentParser(description="Plot interpolations.") 60 | parser.add_argument("which_dataset", type=str, 61 | choices=tuple(stream_functions.keys()), 62 | help="which dataset to compute interpolations on.") 63 | parser.add_argument("main_loop_path", type=str, 64 | help="path to the pickled main loop.") 65 | parser.add_argument("--num-pairs", type=int, default=10, 66 | help="number of pairs of samples to interpolate.") 67 | parser.add_argument("--num-steps", type=int, default=10, 68 | help="number of interpolation steps.") 69 | parser.add_argument("--save-path", type=str, default=None, 70 | help="where to save the interpolations.") 71 | args = parser.parse_args() 72 | 73 | with open(args.main_loop_path, 'rb') as src: 74 | main_loop = load(src) 75 | num_pairs = args.num_pairs 76 | rng = numpy.random.RandomState() 77 | _1, _2, data_stream = stream_functions[args.which_dataset](num_pairs,  # only the third returned stream is used 78 | num_pairs, 79 | rng=rng) 80 | main(main_loop, data_stream, num_pairs, args.num_steps, args.save_path) 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarially Learned Inference 2 | 3 | Code for the Adversarially Learned Inference paper.
4 | 5 | ## Compiling the paper locally 6 | 7 | From the repo's root directory: 8 | 9 | ``` bash 10 | $ cd paper 11 | $ latexmk --pdf adversarially_learned_inference 12 | ``` 13 | 14 | ## Requirements 15 | 16 | * [Blocks](https://blocks.readthedocs.org/en/latest/), development version 17 | * [Fuel](https://fuel.readthedocs.org/en/latest/), development version 18 | 19 | ## Setup 20 | 21 | Clone the repository, then install with 22 | 23 | ``` bash 24 | $ pip install -e ALI 25 | ``` 26 | 27 | ## Downloading and converting the datasets 28 | 29 | Set up your `~/.fuelrc` file: 30 | 31 | ``` bash 32 | $ echo "data_path: \"[data_path]\"" > ~/.fuelrc 33 | ``` 34 | 35 | Go to `[data_path]`: 36 | 37 | ``` bash 38 | $ cd [data_path] 39 | ``` 40 | 41 | Download the CIFAR-10 dataset: 42 | 43 | ``` bash 44 | $ fuel-download cifar10 45 | $ fuel-convert cifar10 46 | $ fuel-download cifar10 --clear 47 | ``` 48 | 49 | Download the SVHN format 2 dataset: 50 | 51 | ``` bash 52 | $ fuel-download svhn 2 53 | $ fuel-convert svhn 2 54 | $ fuel-download svhn 2 --clear 55 | ``` 56 | 57 | Download the CelebA dataset: 58 | 59 | ``` bash 60 | $ fuel-download celeba 64 61 | $ fuel-convert celeba 64 62 | $ fuel-download celeba 64 --clear 63 | ``` 64 | 65 | ## Training the models 66 | 67 | Make sure you're in the repo's root directory. 68 | 69 | ### CIFAR-10 70 | 71 | ``` bash 72 | $ THEANORC=theanorc python experiments/ali_cifar10.py 73 | ``` 74 | 75 | ### SVHN 76 | 77 | ``` bash 78 | $ THEANORC=theanorc python experiments/ali_svhn.py 79 | ``` 80 | 81 | ### CelebA 82 | 83 | ``` bash 84 | $ THEANORC=theanorc python experiments/ali_celeba.py 85 | ``` 86 | 87 | ### Toy task 88 | 89 | ``` bash 90 | $ THEANORC=theanorc python experiments/ali_mixture.py 91 | ``` 92 | 93 | ``` bash 94 | $ THEANORC=theanorc python experiments/gan_mixture.py 95 | ``` 96 | 97 | ## Evaluating the models 98 | 99 | ### Samples 100 | 101 | ``` bash 102 | $ THEANORC=theanorc scripts/sample [main_loop.tar] 103 | ``` 104 | 105 | e.g.
106 | 107 | ``` bash 108 | $ THEANORC=theanorc scripts/sample ali_cifar10.tar 109 | ``` 110 | 111 | ### Interpolations 112 | 113 | ``` bash 114 | $ THEANORC=theanorc scripts/interpolate [which_dataset] [main_loop.tar] 115 | ``` 116 | 117 | e.g. 118 | 119 | ``` bash 120 | $ THEANORC=theanorc scripts/interpolate celeba ali_celeba.tar 121 | ``` 122 | 123 | ### Reconstructions 124 | 125 | ``` bash 126 | $ THEANORC=theanorc scripts/reconstruct [which_dataset] [main_loop.tar] 127 | ``` 128 | 129 | e.g. 130 | 131 | ``` bash 132 | $ THEANORC=theanorc scripts/reconstruct cifar10 ali_cifar10.tar 133 | ``` 134 | 135 | ### Semi-supervised learning on SVHN 136 | 137 | First, preprocess the SVHN dataset with the learned ALI features: 138 | 139 | ``` bash 140 | $ THEANORC=theanorc scripts/preprocess_representations [main_loop.tar] [save_path.hdf5] 141 | ``` 142 | 143 | e.g. 144 | 145 | ``` bash 146 | $ THEANORC=theanorc scripts/preprocess_representations ali_svhn.tar ali_svhn_preprocessed.hdf5 147 | ``` 148 | 149 | Then, launch the semi-supervised script on the preprocessed dataset: 150 | 151 | ``` bash 152 | $ python experiments/semi_supervised_svhn.py [save_path.hdf5] 153 | ``` 154 | 155 | e.g. 156 | 157 | ``` bash 158 | $ python experiments/semi_supervised_svhn.py ali_svhn_preprocessed.hdf5 159 | 160 | [...] 161 | Validation error rate = ... +- ... 162 | Test error rate = ... +- ... 163 | ``` 164 | 165 | ### Toy task 166 | 167 | ``` bash 168 | $ THEANORC=theanorc scripts/generate_mixture_plots [ali_main_loop.tar] [gan_main_loop.tar] 169 | ``` 170 | 171 | e.g.
172 | 173 | ``` bash 174 | $ THEANORC=theanorc scripts/generate_mixture_plots ali_mixture.tar gan_mixture.tar 175 | ``` 176 | -------------------------------------------------------------------------------- /scripts/generate_spiral_plots: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | 4 | import theano 5 | from blocks.bricks.interfaces import Random 6 | from blocks.serialization import load 7 | from matplotlib import pyplot, rc 8 | from theano import tensor 9 | 10 | from ali.streams import create_spiral_data_streams 11 | 12 | rc('font', **{'family': 'serif', 'serif': 'Computer Modern Roman'}) 13 | rc('text', usetex=True) 14 | 15 | 16 | def main(ali_main_loop, gan_main_loop, save_path=None): 17 | ali, = ali_main_loop.model.top_bricks 18 | gan, = gan_main_loop.model.top_bricks 19 | random_brick = Random() 20 | 21 | _1, _2, stream = create_spiral_data_streams(1000, 1000)  # only the third returned stream is used 22 | 23 | x = tensor.as_tensor_variable(next(stream.get_epoch_iterator())[0]) 24 | z = random_brick.theano_rng.normal( 25 | size=(x.shape[0], ali.decoder.input_dim), dtype=x.dtype) 26 | params = ali.encoder.apply(x) 27 | latent_dim = ali.decoder.input_dim 28 | mu, log_sigma = params[:, :latent_dim], params[:, latent_dim:]  # encoder output is split into a mean half and a log-std half 29 | epsilon = random_brick.theano_rng.normal(size=mu.shape, dtype=mu.dtype) 30 | z_hat = mu + tensor.exp(log_sigma) * epsilon  # reparameterized sample from the encoder's Gaussian 31 | x_tilde = ali.decoder.apply(z) 32 | x_hat = ali.decoder.apply(z_hat) 33 | gan_x_tilde = gan.decoder.apply(z) 34 | 35 | samples = theano.function([], [x, x_tilde, gan_x_tilde, x_hat, z, z_hat])() 36 | x, x_tilde, gan_x_tilde, x_hat, z, z_hat = samples 37 | 38 | figure, axes = pyplot.subplots(nrows=2, ncols=3) 39 | for ax in axes.ravel(): 40 | ax.set_aspect('equal') 41 | ax.set_xticks([]) 42 | ax.set_yticks([]) 43 | for ax in axes[0]: 44 | ax.set_xlim([-2, 2]) 45 | ax.set_ylim([-2, 2]) 46 | ax.set_xlabel('$x_1$') 47 | ax.set_ylabel('$x_2$') 48 | for ax in axes[1]: 49 |
ax.set_xlim([-4, 4]) 50 | ax.set_ylim([-4, 4]) 51 | ax.set_xlabel('$z_1$') 52 | ax.set_ylabel('$z_2$') 53 | axes[0, 0].set_title('ALI reconstructions') 54 | axes[0, 0].scatter(x[:, 0], x[:, 1], marker='o', c='black', alpha=0.3) 55 | axes[0, 0].scatter(x_hat[:, 0], x_hat[:, 1], marker='o', c='blue', 56 | alpha=0.3) 57 | axes[0, 1].set_title('ALI samples') 58 | axes[0, 1].scatter(x[:, 0], x[:, 1], marker='o', c='black', alpha=0.3) 59 | axes[0, 1].scatter(x_tilde[:, 0], x_tilde[:, 1], marker='o', c='blue', 60 | alpha=0.3) 61 | axes[0, 2].set_title('GAN samples') 62 | axes[0, 2].scatter(x[:, 0], x[:, 1], marker='o', c='black', alpha=0.3) 63 | axes[0, 2].scatter(gan_x_tilde[:, 0], gan_x_tilde[:, 1], marker='o', 64 | c='blue', alpha=0.3) 65 | axes[1, 0].set_title('ALI encoding') 66 | axes[1, 0].scatter(z_hat[:, 0], z_hat[:, 1], marker='o', c='blue', 67 | alpha=0.3) 68 | axes[1, 1].set_title('Prior') 69 | axes[1, 1].scatter(z[:, 0], z[:, 1], marker='o', c='blue', alpha=0.3) 70 | axes[1, 2].set_title('Prior')  # same prior draws z that also fed the GAN decoder above 71 | axes[1, 2].scatter(z[:, 0], z[:, 1], marker='o', c='blue', alpha=0.3) 72 | 73 | pyplot.tight_layout() 74 | if save_path is None: 75 | pyplot.show() 76 | else: 77 | pyplot.savefig(save_path, transparent=True, bbox_inches='tight') 78 | 79 | 80 | if __name__ == "__main__": 81 | parser = argparse.ArgumentParser(description="Plot Spiral samples.") 82 | parser.add_argument("ali_main_loop_path", type=str, 83 | help="path to the pickled ALI main loop.") 84 | parser.add_argument("gan_main_loop_path", type=str, 85 | help="path to the pickled GAN main loop.") 86 | parser.add_argument("--save-path", type=str, default=None, 87 | help="where to save the generated samples.") 88 | args = parser.parse_args() 89 | 90 | with open(args.ali_main_loop_path, 'rb') as ali_src: 91 | with open(args.gan_main_loop_path, 'rb') as gan_src: 92 | main(load(ali_src), load(gan_src), args.save_path) 93 | --------------------------------------------------------------------------------
# ---- /scripts/generate_mixture_plots ----
#!/usr/bin/env python
"""Plot ALI and GAN results on the 2D mixture-of-Gaussians toy dataset."""
import argparse
import itertools

import numpy
import theano
from blocks.serialization import load
from matplotlib import pyplot, rc
from theano import tensor

from ali.datasets import GaussianMixture

rc('font', **{'family': 'serif', 'serif': 'Computer Modern Roman'})
rc('text', usetex=True)

# 5x5 grid of component means at even coordinates in [-4, 4], each with an
# isotropic standard deviation of 0.05; ``None`` priors mean uniform weights.
MEANS = [numpy.array([i, j]) for i, j in itertools.product(range(-4, 5, 2),
                                                           range(-4, 5, 2))]
VARIANCES = [0.05 ** 2 * numpy.eye(len(mean)) for mean in MEANS]
PRIORS = None


def _configure_axes(axes):
    """Set limits, ticks and labels: data space on top, latent space below."""
    for ax in axes.ravel():
        ax.set_aspect('equal')
    for ax in axes[0]:
        ax.set_xlim([-6, 6])
        ax.set_ylim([-6, 6])
        ax.set_xticks([-6, -4, -2, 0, 2, 4, 6])
        ax.set_yticks([-6, -4, -2, 0, 2, 4, 6])
        ax.set_xlabel('$x_1$')
        ax.set_ylabel('$x_2$')
    for ax in axes[1]:
        ax.set_xlim([-4, 4])
        ax.set_ylim([-4, 4])
        ax.set_xticks([-4, -2, 0, 2, 4])
        ax.set_yticks([-4, -2, 0, 2, 4])
        ax.set_xlabel('$z_1$')
        ax.set_ylabel('$z_2$')


def main(ali_main_loop, gan_main_loop, save_path=None):
    """Plot ALI/GAN samples, reconstructions and encodings on the mixture.

    Parameters
    ----------
    ali_main_loop, gan_main_loop : MainLoop
        Unpickled Blocks main loops whose models each have exactly one top
        brick (the ALI brick and the GAN brick respectively).
    save_path : str, optional
        Where to save the figure. If None, the figure is shown interactively.
    """
    ali, = ali_main_loop.model.top_bricks
    gan, = gan_main_loop.model.top_bricks

    dataset = GaussianMixture(num_examples=2500,
                              means=MEANS, variances=VARIANCES, priors=PRIORS,
                              rng=None, sources=('features', 'label'))
    features, targets = dataset.indexables
    # Real data, encodings and reconstructions are colored by mixture component.
    colors = targets.ravel()

    x = tensor.as_tensor_variable(features)
    z = ali.theano_rng.normal(
        size=(x.shape[0], ali.encoder.mapping.output_dim), dtype=x.dtype)
    z_hat = ali.encoder.apply(x)
    x_tilde = ali.decoder.apply(z)
    x_hat = ali.decoder.apply(z_hat)
    gan_x_tilde = gan.decoder.apply(z)

    samples = theano.function([], [x, x_tilde, x_hat, gan_x_tilde, z, z_hat])()
    x, x_tilde, x_hat, gan_x_tilde, z, z_hat = samples

    _, axes = pyplot.subplots(nrows=2, ncols=4, figsize=(8, 4.5))
    _configure_axes(axes)
    # Raw strings keep the LaTeX backslashes literal without relying on
    # Python's handling of invalid escape sequences.
    # ALI - q(x, z)
    axes[0, 0].set_title(r'$\mathbf{x} \sim q(\mathbf{x})$')
    axes[0, 0].scatter(x[:, 0], x[:, 1], marker='.', c=colors, alpha=0.3)
    axes[1, 1].set_title(
        r'$\hat{\mathbf{z}} \sim q(\mathbf{z} \mid \mathbf{x})$')
    axes[1, 1].scatter(z_hat[:, 0], z_hat[:, 1], marker='.', c=colors,
                       alpha=0.3)
    # ALI - p(x, z)
    axes[0, 2].set_title(
        r'$\tilde{\mathbf{x}} \sim p(\mathbf{x} \mid \mathbf{z})$')
    axes[0, 2].scatter(x_tilde[:, 0], x_tilde[:, 1], marker='.', c='black',
                       alpha=0.3)
    axes[1, 2].set_title(r'$\mathbf{z} \sim p(\mathbf{z})$')
    axes[1, 2].scatter(z[:, 0], z[:, 1], marker='.', c='black', alpha=0.3)
    # ALI - q(z) p(x | z) (reconstruction); the closing parenthesis belongs
    # inside math mode.
    axes[0, 1].set_title(
        r'$\hat{\mathbf{x}} \sim p(\mathbf{x} \mid \hat{\mathbf{z}})$')
    axes[0, 1].scatter(x_hat[:, 0], x_hat[:, 1], marker='.', c=colors,
                       alpha=0.3)
    # GAN - p(x)
    axes[0, 3].set_title(r'GAN $\mathbf{x} = G(\mathbf{z})$')
    axes[0, 3].scatter(gan_x_tilde[:, 0], gan_x_tilde[:, 1], marker='.',
                       c='black', alpha=0.3)
    axes[1, 0].axis('off')
    axes[1, 3].axis('off')

    pyplot.tight_layout()
    if save_path is None:
        pyplot.show()
    else:
        pyplot.savefig(save_path, transparent=True, bbox_inches='tight')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Plot samples.")
    parser.add_argument("ali_main_loop_path", type=str)
    parser.add_argument("gan_main_loop_path", type=str)
    parser.add_argument("--save-path", type=str, default=None)
    args = parser.parse_args()
    with open(args.ali_main_loop_path, 'rb') as ali_src:
        with open(args.gan_main_loop_path, 'rb') as gan_src:
            main(load(ali_src), load(gan_src), args.save_path)
/experiments/semi_supervised_svhn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy 4 | from fuel.datasets import H5PYDataset 5 | from fuel.schemes import ShuffledScheme, ShuffledExampleScheme 6 | from fuel.streams import DataStream 7 | from sklearn.svm import LinearSVC 8 | 9 | 10 | def main(dataset_path, use_c, log_min, log_max, num_steps): 11 | train_set = H5PYDataset( 12 | dataset_path, which_sets=('train',), sources=('features', 'targets'), 13 | subset=slice(0, 63257), load_in_memory=True) 14 | train_stream = DataStream.default_stream( 15 | train_set, 16 | iteration_scheme=ShuffledExampleScheme(train_set.num_examples)) 17 | 18 | def get_class_balanced_batch(iterator): 19 | train_features = [[] for _ in range(10)] 20 | train_targets = [[] for _ in range(10)] 21 | batch_size = 0 22 | while batch_size < 1000: 23 | f, t = next(iterator) 24 | t = t[0] 25 | if len(train_features[t]) < 100: 26 | train_features[t].append(f) 27 | train_targets[t].append(t) 28 | batch_size += 1 29 | train_features = numpy.vstack(sum(train_features, [])) 30 | train_targets = numpy.vstack(sum(train_targets, [])) 31 | return train_features, train_targets 32 | 33 | train_features, train_targets = get_class_balanced_batch( 34 | train_stream.get_epoch_iterator()) 35 | 36 | valid_set = H5PYDataset( 37 | dataset_path, which_sets=('train',), sources=('features', 'targets'), 38 | subset=slice(63257, 73257), load_in_memory=True) 39 | valid_features, valid_targets = valid_set.data_sources 40 | 41 | test_set = H5PYDataset( 42 | dataset_path, which_sets=('test',), sources=('features', 'targets'), 43 | load_in_memory=True) 44 | test_features, test_targets = test_set.data_sources 45 | 46 | if use_c is None: 47 | best_error_rate = 1.0 48 | best_C = None 49 | for log_C in numpy.linspace(log_min, log_max, num_steps): 50 | C = numpy.exp(log_C) 51 | svm = LinearSVC(C=C) 52 | svm.fit(train_features, train_targets.ravel()) 53 | 
error_rate = 1 - numpy.mean( 54 | [svm.score(valid_features[1000 * i: 1000 * (i + 1)], 55 | valid_targets[1000 * i: 1000 * (i + 1)].ravel()) 56 | for i in range(10)]) 57 | if error_rate < best_error_rate: 58 | best_error_rate = error_rate 59 | best_C = C 60 | print('C = {}, validation error rate = {} '.format(C, error_rate) + 61 | '(best is {}, {})'.format(best_C, best_error_rate)) 62 | else: 63 | best_C = use_c 64 | 65 | error_rates = [] 66 | for _ in range(10): 67 | train_features, train_targets = get_class_balanced_batch( 68 | train_stream.get_epoch_iterator()) 69 | svm = LinearSVC(C=best_C) 70 | svm.fit(train_features, train_targets.ravel()) 71 | error_rates.append(1 - numpy.mean( 72 | [svm.score(valid_features[1000 * i: 1000 * (i + 1)], 73 | valid_targets[1000 * i: 1000 * (i + 1)].ravel()) 74 | for i in range(10)])) 75 | 76 | print('Validation error rate = {} +- {} '.format(numpy.mean(error_rates), 77 | numpy.std(error_rates))) 78 | 79 | error_rates = [] 80 | for _ in range(100): 81 | train_features, train_targets = get_class_balanced_batch( 82 | train_stream.get_epoch_iterator()) 83 | svm = LinearSVC(C=best_C) 84 | svm.fit(train_features, train_targets.ravel()) 85 | s = 1000 * numpy.sum( 86 | [svm.score(test_features[1000 * i: 1000 * (i + 1)], 87 | test_targets[1000 * i: 1000 * (i + 1)].ravel()) 88 | for i in range(26)]) 89 | s += 32 * svm.score(test_features[-32:], test_targets[-32:].ravel()) 90 | s = s / 26032.0 91 | error_rates.append(1 - s) 92 | 93 | print('Test error rate = {} +- {} '.format(numpy.mean(error_rates), 94 | numpy.std(error_rates))) 95 | 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser(description="ALI-based semi-supervised " 99 | "training on SVHN") 100 | parser.add_argument("dataset_path", type=str, 101 | help="path to the saved main loop") 102 | parser.add_argument("--use-c", type=float, default=None, 103 | help="evaluate using a specific C value") 104 | parser.add_argument("--log-min", type=float, default=-20, 
105 | help="minimum C value in log-space") 106 | parser.add_argument("--log-max", type=float, default=20, 107 | help="maximum C value in log-space") 108 | parser.add_argument("--num-steps", type=int, default=50, 109 | help="number of values to try") 110 | args = parser.parse_args() 111 | main(args.dataset_path, args.use_c, args.log_min, args.log_max, 112 | args.num_steps) 113 | -------------------------------------------------------------------------------- /ali/graphing.py: -------------------------------------------------------------------------------- 1 | """ ALI related graphs """ 2 | from functools import partial 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import scipy 6 | 7 | from ali.datasets import GaussianMixtureDistribution 8 | from ali.utils import as_array 9 | 10 | 11 | def make_2D_latent_view(valid_data, 12 | samples_data, 13 | gradients_funs=None, 14 | densities_funs=None, 15 | epoch=None, 16 | save_path=None): 17 | """ 18 | 2D views of the latent and visible spaces 19 | Parameters 20 | ---------- 21 | valid_data: dictionary of numpy arrays 22 | Holds five keys: originals, labels, mu, sigma, encoding, reconstructions 23 | samples_data: dictionary of numpy arrays 24 | Holds two keys prior and samples 25 | gradients_funs: dict of functions 26 | Holds two keys: latent, for the gradients on the latent space w.r.p to Z and 27 | visible, for the gradients ob the visible space 28 | densities_fun: dictionary of functions 29 | Holds two keys: latent, for the probability density of the latent space, and 30 | visible, for the probability density on the latent space 31 | """ 32 | 33 | # Creating figure 34 | fig = plt.figure() 35 | # Getting Cmap 36 | cmap = plt.cm.get_cmap('Spectral', 5) 37 | # Adding visible subplot 38 | recons_visible_ax = fig.add_subplot(221, aspect='equal') 39 | # Train data 40 | recons_visible_ax.scatter(valid_data['originals'][:, 0], 41 | valid_data['originals'][:, 1], 42 | c=valid_data['labels'], 43 | marker='s', 
label='originals', 44 | alpha=0.3, cmap=cmap) 45 | 46 | recons_visible_ax.scatter(valid_data['reconstructions'][:, 0], 47 | valid_data['reconstructions'][:, 1], 48 | c=valid_data['labels'], 49 | marker='x', label='reconstructions', 50 | alpha=0.3, 51 | cmap=cmap) 52 | 53 | recons_visible_ax.set_title('Visible space. Epoch {}'.format(str(epoch))) 54 | samples_visible_ax = fig.add_subplot(222, aspect='equal', 55 | sharex=recons_visible_ax, 56 | sharey=recons_visible_ax) 57 | 58 | samples_visible_ax.scatter(valid_data['originals'][:, 0], 59 | valid_data['originals'][:, 1], 60 | c=valid_data['labels'], 61 | marker='s', label='originals', 62 | alpha=0.3, 63 | cmap=cmap) 64 | 65 | samples_visible_ax.scatter(samples_data['samples'][:, 0], 66 | samples_data['samples'][:, 1], 67 | marker='o', alpha=0.3, label='samples') 68 | samples_visible_ax.set_title('Visible space. Epoch {}'.format(str(epoch))) 69 | 70 | # plt.legend(loc="upper left", bbox_to_anchor=[0, 1], 71 | # shadow=True, title="Legend", fancybox=True) 72 | # visible_ax.get_legend() 73 | 74 | # Adding latent subplot 75 | recons_latent_ax = fig.add_subplot(223, aspect='equal') 76 | recons_latent_ax.scatter(valid_data['encodings'][:, 0], 77 | valid_data['encodings'][:, 1], 78 | c=valid_data['labels'], 79 | marker='x', label='encodings', 80 | alpha=0.3, cmap=cmap) 81 | recons_latent_ax.set_title('Latent space. Epoch {}'.format(str(epoch))) 82 | 83 | samples_latent_ax = fig.add_subplot(224, aspect='equal', 84 | sharex=recons_latent_ax, 85 | sharey=recons_latent_ax) 86 | samples_latent_ax.scatter(samples_data['noise'][:, 0], 87 | samples_data['noise'][:, 1], 88 | marker='o', label='noise', 89 | alpha=0.3) 90 | samples_latent_ax.set_title('Latent space. 
Epoch {}'.format(str(epoch))) 91 | 92 | # plt.legend(loc="upper left", bbox_to_anchor=[0, 1], 93 | # shadow=True, title="Legend", fancybox=True) 94 | # latent_ax.get_legend() 95 | plt.tight_layout() 96 | 97 | if save_path is None: 98 | plt.show() 99 | else: 100 | plt.savefig(save_path, transparent=True, bbox_inches='tight') 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | if __name__ == '__main__': 119 | means = map(lambda x: as_array(x), [[0, 0], 120 | [1, 1], 121 | [-1, -1], 122 | [1, -1], 123 | [-1, 1]]) 124 | std = 0.01 125 | variances = [np.eye(2) * std for _ in means] 126 | priors = [1.0/len(means) for _ in means] 127 | 128 | gaussian_mixture = GaussianMixtureDistribution(means=means, 129 | variances=variances, 130 | priors=priors) 131 | originals, labels = gaussian_mixture.sample(1000) 132 | reconstructions = originals * np.random.normal(size=originals.shape, 133 | scale=0.05) 134 | encodings = np.random.normal(size=(1000, 2)) 135 | train_data = {'originals': originals, 'labels': labels, 136 | 'encodings': encodings, 137 | 'reconstructions': reconstructions} 138 | valid_data = train_data 139 | 140 | noise = np.random.normal(size=(1000, 2)) 141 | samples = np.random.normal(size=(1000, 2), scale=0.3) 142 | samples_data = {'noise': noise, 143 | 'samples': samples} 144 | 145 | #make_2D_latent_view(train_data, valid_data, samples_data) 146 | make_assignement_plots(valid_data) 147 | 148 | 149 | -------------------------------------------------------------------------------- /ali/streams.py: -------------------------------------------------------------------------------- 1 | """Functions for creating data streams.""" 2 | from fuel.datasets import CIFAR10, SVHN, CelebA 3 | from fuel.datasets.toy import Spiral 4 | from fuel.schemes import ShuffledScheme 5 | from fuel.streams import DataStream 6 | 7 | from .datasets import TinyILSVRC2012, GaussianMixture 8 | 9 | 10 | def create_svhn_data_streams(batch_size, 
monitoring_batch_size, rng=None): 11 | train_set = SVHN(2, ('extra',), sources=('features',)) 12 | valid_set = SVHN(2, ('train',), sources=('features',)) 13 | main_loop_stream = DataStream.default_stream( 14 | train_set, 15 | iteration_scheme=ShuffledScheme( 16 | train_set.num_examples, batch_size, rng=rng)) 17 | train_monitor_stream = DataStream.default_stream( 18 | train_set, 19 | iteration_scheme=ShuffledScheme( 20 | 5000, monitoring_batch_size, rng=rng)) 21 | valid_monitor_stream = DataStream.default_stream( 22 | valid_set, 23 | iteration_scheme=ShuffledScheme( 24 | 5000, monitoring_batch_size, rng=rng)) 25 | return main_loop_stream, train_monitor_stream, valid_monitor_stream 26 | 27 | 28 | def create_cifar10_data_streams(batch_size, monitoring_batch_size, rng=None): 29 | train_set = CIFAR10( 30 | ('train',), sources=('features',), subset=slice(0, 45000)) 31 | valid_set = CIFAR10( 32 | ('train',), sources=('features',), subset=slice(45000, 50000)) 33 | main_loop_stream = DataStream.default_stream( 34 | train_set, 35 | iteration_scheme=ShuffledScheme( 36 | train_set.num_examples, batch_size, rng=rng)) 37 | train_monitor_stream = DataStream.default_stream( 38 | train_set, 39 | iteration_scheme=ShuffledScheme( 40 | 5000, monitoring_batch_size, rng=rng)) 41 | valid_monitor_stream = DataStream.default_stream( 42 | valid_set, 43 | iteration_scheme=ShuffledScheme( 44 | 5000, monitoring_batch_size, rng=rng)) 45 | return main_loop_stream, train_monitor_stream, valid_monitor_stream 46 | 47 | 48 | def create_celeba_data_streams(batch_size, monitoring_batch_size, 49 | sources=('features', ), rng=None): 50 | train_set = CelebA('64', ('train',), sources=sources) 51 | valid_set = CelebA('64', ('valid',), sources=sources) 52 | main_loop_stream = DataStream.default_stream( 53 | train_set, 54 | iteration_scheme=ShuffledScheme( 55 | train_set.num_examples, batch_size, rng=rng)) 56 | train_monitor_stream = DataStream.default_stream( 57 | train_set, 58 | 
iteration_scheme=ShuffledScheme( 59 | 5000, monitoring_batch_size, rng=rng)) 60 | valid_monitor_stream = DataStream.default_stream( 61 | valid_set, 62 | iteration_scheme=ShuffledScheme( 63 | 5000, monitoring_batch_size, rng=rng)) 64 | return main_loop_stream, train_monitor_stream, valid_monitor_stream 65 | 66 | 67 | def create_tiny_imagenet_data_streams(batch_size, monitoring_batch_size, 68 | rng=None): 69 | train_set = TinyILSVRC2012(('train',), sources=('features',)) 70 | valid_set = TinyILSVRC2012(('valid',), sources=('features',)) 71 | main_loop_stream = DataStream.default_stream( 72 | train_set, 73 | iteration_scheme=ShuffledScheme( 74 | train_set.num_examples, batch_size, rng=rng)) 75 | train_monitor_stream = DataStream.default_stream( 76 | train_set, 77 | iteration_scheme=ShuffledScheme( 78 | 4096, monitoring_batch_size, rng=rng)) 79 | valid_monitor_stream = DataStream.default_stream( 80 | valid_set, 81 | iteration_scheme=ShuffledScheme( 82 | 4096, monitoring_batch_size, rng=rng)) 83 | return main_loop_stream, train_monitor_stream, valid_monitor_stream 84 | 85 | 86 | def create_spiral_data_streams(batch_size, monitoring_batch_size, rng=None, 87 | num_examples=100000, classes=1, cycles=2, 88 | noise=0.1): 89 | train_set = Spiral(num_examples=num_examples, classes=classes, 90 | cycles=cycles, noise=noise, sources=('features',)) 91 | 92 | valid_set = Spiral(num_examples=num_examples, classes=classes, 93 | cycles=cycles, noise=noise, sources=('features',)) 94 | 95 | main_loop_stream = DataStream.default_stream( 96 | train_set, 97 | iteration_scheme=ShuffledScheme( 98 | train_set.num_examples, batch_size=batch_size, rng=rng)) 99 | 100 | train_monitor_stream = DataStream.default_stream( 101 | train_set, iteration_scheme=ShuffledScheme(5000, batch_size, rng=rng)) 102 | 103 | valid_monitor_stream = DataStream.default_stream( 104 | valid_set, iteration_scheme=ShuffledScheme(5000, batch_size, rng=rng)) 105 | 106 | return main_loop_stream, train_monitor_stream, 
valid_monitor_stream 107 | 108 | 109 | def create_gaussian_mixture_data_streams(batch_size, monitoring_batch_size, 110 | means=None, variances=None, priors=None, 111 | rng=None, num_examples=100000, 112 | sources=('features', )): 113 | train_set = GaussianMixture(num_examples=num_examples, means=means, 114 | variances=variances, priors=priors, 115 | rng=rng, sources=sources) 116 | 117 | valid_set = GaussianMixture(num_examples=num_examples, 118 | means=means, variances=variances, 119 | priors=priors, rng=rng, sources=sources) 120 | 121 | main_loop_stream = DataStream( 122 | train_set, 123 | iteration_scheme=ShuffledScheme( 124 | train_set.num_examples, batch_size=batch_size, rng=rng)) 125 | 126 | train_monitor_stream = DataStream( 127 | train_set, iteration_scheme=ShuffledScheme(5000, batch_size, rng=rng)) 128 | 129 | valid_monitor_stream = DataStream( 130 | valid_set, iteration_scheme=ShuffledScheme(5000, batch_size, rng=rng)) 131 | 132 | return main_loop_stream, train_monitor_stream, valid_monitor_stream 133 | -------------------------------------------------------------------------------- /ali/datasets.py: -------------------------------------------------------------------------------- 1 | """Additional dataset classes.""" 2 | from __future__ import (division, print_function, ) 3 | from collections import OrderedDict 4 | from scipy.stats import multivariate_normal 5 | 6 | import numpy as np 7 | import numpy.random as npr 8 | 9 | from fuel import config 10 | from fuel.datasets import H5PYDataset, IndexableDataset 11 | from fuel.transformers.defaults import uint8_pixels_to_floatX 12 | from fuel.utils import find_in_data_path 13 | 14 | from ali.utils import as_array 15 | 16 | 17 | class TinyILSVRC2012(H5PYDataset): 18 | """The Tiny ILSVRC2012 Dataset. 19 | 20 | Parameters 21 | ---------- 22 | which_sets : tuple of str 23 | Which split to load. Valid values are 'train' (1,281,167 examples) 24 | 'valid' (50,000 examples), and 'test' (100,000 examples). 
25 | 26 | """ 27 | filename = 'ilsvrc2012_tiny.hdf5' 28 | default_transformers = uint8_pixels_to_floatX(('features',)) 29 | 30 | def __init__(self, which_sets, **kwargs): 31 | kwargs.setdefault('load_in_memory', False) 32 | super(TinyILSVRC2012, self).__init__( 33 | file_or_path=find_in_data_path(self.filename), 34 | which_sets=which_sets, **kwargs) 35 | 36 | 37 | class GaussianMixture(IndexableDataset): 38 | """ Toy dataset containing points sampled from a gaussian mixture distribution. 39 | 40 | The dataset contains 3 sources: 41 | * features 42 | * label 43 | * densities 44 | 45 | """ 46 | def __init__(self, num_examples, means=None, variances=None, priors=None, 47 | **kwargs): 48 | rng = kwargs.pop('rng', None) 49 | if rng is None: 50 | seed = kwargs.pop('seed', config.default_seed) 51 | rng = np.random.RandomState(seed) 52 | 53 | gaussian_mixture = GaussianMixtureDistribution(means=means, 54 | variances=variances, 55 | priors=priors, 56 | rng=rng) 57 | self.means = gaussian_mixture.means 58 | self.variances = gaussian_mixture.variances 59 | self.priors = gaussian_mixture.priors 60 | 61 | features, labels = gaussian_mixture.sample(nsamples=num_examples) 62 | densities = gaussian_mixture.pdf(x=features) 63 | 64 | data = OrderedDict([ 65 | ('features', features), 66 | ('label', labels), 67 | ('density', densities) 68 | ]) 69 | 70 | super(GaussianMixture, self).__init__(data, **kwargs) 71 | 72 | 73 | class GaussianMixtureDistribution(object): 74 | """ Gaussian Mixture Distribution 75 | 76 | Parameters 77 | ---------- 78 | means : tuple of ndarray. 79 | Specifies the means for the gaussian components. 80 | variances : tuple of ndarray. 81 | Specifies the variances for the gaussian components. 82 | priors : tuple of ndarray 83 | Specifies the prior distribution of the components. 
84 | 85 | """ 86 | 87 | def __init__(self, means=None, variances=None, priors=None, rng=None, seed=None): 88 | 89 | if means is None: 90 | means = map(lambda x: 10.0 * as_array(x), [[0, 0], 91 | [1, 1], 92 | [-1, -1], 93 | [1, -1], 94 | [-1, 1]]) 95 | # Number of components 96 | self.ncomponents = len(means) 97 | self.dim = means[0].shape[0] 98 | self.means = means 99 | # If prior is not specified let prior be flat. 100 | if priors is None: 101 | priors = [1.0/self.ncomponents for _ in range(self.ncomponents)] 102 | self.priors = priors 103 | # If variances are not specified let variances be identity 104 | if variances is None: 105 | variances = [np.eye(self.dim) for _ in range(self.ncomponents)] 106 | self.variances = variances 107 | 108 | assert len(means) == len(variances), "Mean variances mismatch" 109 | assert len(variances) == len(priors), "prior mismatch" 110 | 111 | if rng is None: 112 | rng = npr.RandomState(seed=seed) 113 | self.rng = rng 114 | 115 | def _sample_prior(self, nsamples): 116 | return self.rng.choice(a=self.ncomponents, 117 | size=(nsamples, ), 118 | replace=True, 119 | p=self.priors) 120 | 121 | def sample(self, nsamples): 122 | # Sampling priors 123 | samples = [] 124 | fathers = self._sample_prior(nsamples=nsamples).tolist() 125 | for father in fathers: 126 | samples.append(self._sample_gaussian(self.means[father], 127 | self.variances[father])) 128 | return as_array(samples), as_array(fathers) 129 | 130 | def _sample_gaussian(self, mean, variance): 131 | # sampling unit gaussians 132 | epsilons = self.rng.normal(size=(self.dim, )) 133 | 134 | return mean + np.linalg.cholesky(variance).dot(epsilons) 135 | 136 | def _gaussian_pdf(self, x, mean, variance): 137 | return multivariate_normal.pdf(x, mean=mean, cov=variance) 138 | 139 | def pdf(self, x): 140 | "Evaluates the the probability density function at the given point x" 141 | pdfs = map(lambda m, v, p: p * self._gaussian_pdf(x, m, v), 142 | self.means, self.variances, self.priors) 143 | 
return reduce(lambda x, y: x + y, pdfs, 0.0) 144 | 145 | 146 | if __name__ == '__main__': 147 | means = map(lambda x: as_array(x), [[0, 0], 148 | [1, 1], 149 | [-1, -1], 150 | [1, -1], 151 | [-1, 1]]) 152 | std = 0.01 153 | variances = [np.eye(2) * std for _ in means] 154 | priors = [1.0/len(means) for _ in means] 155 | 156 | gaussian_mixture = GaussianMixtureDistribution(means=means, 157 | variances=variances, 158 | priors=priors) 159 | gmdset = GaussianMixture(1000, means, variances, priors, sources=('features', )) 160 | 161 | -------------------------------------------------------------------------------- /experiments/gan_mixture.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy 4 | from theano import tensor 5 | from blocks.algorithms import Adam 6 | from blocks.bricks import MLP, Rectifier, Identity, LinearMaxout, Linear 7 | from blocks.bricks.bn import BatchNormalization 8 | from blocks.bricks.sequences import Sequence 9 | from blocks.extensions import FinishAfter, Timing, Printing, ProgressBar 10 | from blocks.extensions.monitoring import DataStreamMonitoring 11 | from blocks.extensions.saveload import Checkpoint 12 | from blocks.graph import ComputationGraph, apply_dropout 13 | from blocks.graph.bn import (batch_normalization, 14 | get_batch_normalization_updates) 15 | from blocks.filter import VariableFilter 16 | from blocks.initialization import IsotropicGaussian, Constant 17 | from blocks.model import Model 18 | from blocks.main_loop import MainLoop 19 | from blocks.roles import INPUT 20 | 21 | from ali.algorithms import ali_algorithm 22 | from ali.streams import create_gaussian_mixture_data_streams 23 | from ali.bricks import GAN 24 | from ali.utils import as_array 25 | 26 | import logging 27 | import argparse 28 | 29 | INPUT_DIM = 2 30 | NLAT = 2 31 | GEN_HIDDEN = 400 32 | DISC_HIDDEN = 200 33 | GEN_ACTIVATION = Rectifier 34 | MAXOUT_PIECES = 5 35 | GAUSSIAN_INIT = 
IsotropicGaussian(std=0.02) 36 | ZERO_INIT = Constant(0.0) 37 | 38 | NUM_EPOCHS = 200 39 | LEARNING_RATE = 1e-4 40 | BETA1 = 0.8 41 | BATCH_SIZE = 100 42 | MONITORING_BATCH_SIZE = 500 43 | MEANS = [numpy.array([i, j]) for i, j in itertools.product(range(-4, 5, 2), 44 | range(-4, 5, 2))] 45 | VARIANCES = [0.05 ** 2 * numpy.eye(len(mean)) for mean in MEANS] 46 | PRIORS = None 47 | 48 | 49 | def create_model_brick(): 50 | decoder = MLP( 51 | dims=[NLAT, GEN_HIDDEN, GEN_HIDDEN, GEN_HIDDEN, GEN_HIDDEN, INPUT_DIM], 52 | activations=[Sequence([BatchNormalization(GEN_HIDDEN).apply, 53 | GEN_ACTIVATION().apply], 54 | name='decoder_h1'), 55 | Sequence([BatchNormalization(GEN_HIDDEN).apply, 56 | GEN_ACTIVATION().apply], 57 | name='decoder_h2'), 58 | Sequence([BatchNormalization(GEN_HIDDEN).apply, 59 | GEN_ACTIVATION().apply], 60 | name='decoder_h3'), 61 | Sequence([BatchNormalization(GEN_HIDDEN).apply, 62 | GEN_ACTIVATION().apply], 63 | name='decoder_h4'), 64 | Identity(name='decoder_out')], 65 | use_bias=False, 66 | name='decoder') 67 | 68 | discriminator = Sequence( 69 | application_methods=[ 70 | LinearMaxout( 71 | input_dim=INPUT_DIM, 72 | output_dim=DISC_HIDDEN, 73 | num_pieces=MAXOUT_PIECES, 74 | weights_init=GAUSSIAN_INIT, 75 | biases_init=ZERO_INIT, 76 | name='discriminator_h1').apply, 77 | LinearMaxout( 78 | input_dim=DISC_HIDDEN, 79 | output_dim=DISC_HIDDEN, 80 | num_pieces=MAXOUT_PIECES, 81 | weights_init=GAUSSIAN_INIT, 82 | biases_init=ZERO_INIT, 83 | name='discriminator_h2').apply, 84 | LinearMaxout( 85 | input_dim=DISC_HIDDEN, 86 | output_dim=DISC_HIDDEN, 87 | num_pieces=MAXOUT_PIECES, 88 | weights_init=GAUSSIAN_INIT, 89 | biases_init=ZERO_INIT, 90 | name='discriminator_h3').apply, 91 | Linear( 92 | input_dim=DISC_HIDDEN, 93 | output_dim=1, 94 | weights_init=GAUSSIAN_INIT, 95 | biases_init=ZERO_INIT, 96 | name='discriminator_out').apply], 97 | name='discriminator') 98 | 99 | gan = GAN(decoder=decoder, discriminator=discriminator, 100 | 
weights_init=GAUSSIAN_INIT, biases_init=ZERO_INIT, name='gan') 101 | gan.push_allocation_config() 102 | decoder.linear_transformations[-1].use_bias = True 103 | gan.initialize() 104 | 105 | return gan 106 | 107 | 108 | def create_models(): 109 | gan = create_model_brick() 110 | x = tensor.matrix('features') 111 | z = gan.theano_rng.normal(size=(x.shape[0], NLAT)) 112 | 113 | def _create_model(with_dropout): 114 | cg = ComputationGraph(gan.compute_losses(x, z)) 115 | if with_dropout: 116 | inputs = VariableFilter( 117 | bricks=gan.discriminator.children[1:], 118 | roles=[INPUT])(cg.variables) 119 | cg = apply_dropout(cg, inputs, 0.5) 120 | inputs = VariableFilter( 121 | bricks=[gan.discriminator], 122 | roles=[INPUT])(cg.variables) 123 | cg = apply_dropout(cg, inputs, 0.2) 124 | return Model(cg.outputs) 125 | 126 | model = _create_model(with_dropout=False) 127 | with batch_normalization(gan): 128 | bn_model = _create_model(with_dropout=False) 129 | 130 | pop_updates = list( 131 | set(get_batch_normalization_updates(bn_model, allow_duplicates=True))) 132 | bn_updates = [(p, m * 0.05 + p * 0.95) for p, m in pop_updates] 133 | 134 | return model, bn_model, bn_updates 135 | 136 | 137 | def create_main_loop(save_path): 138 | model, bn_model, bn_updates = create_models() 139 | gan, = bn_model.top_bricks 140 | discriminator_loss, generator_loss = bn_model.outputs 141 | step_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1) 142 | algorithm = ali_algorithm(discriminator_loss, gan.discriminator_parameters, 143 | step_rule, generator_loss, 144 | gan.generator_parameters, step_rule) 145 | algorithm.add_updates(bn_updates) 146 | streams = create_gaussian_mixture_data_streams( 147 | batch_size=BATCH_SIZE, monitoring_batch_size=MONITORING_BATCH_SIZE, 148 | means=MEANS, variances=VARIANCES, priors=PRIORS) 149 | main_loop_stream, train_monitor_stream, valid_monitor_stream = streams 150 | bn_monitored_variables = ( 151 | [v for v in bn_model.auxiliary_variables if 'norm' not in 
v.name] + 152 | bn_model.outputs) 153 | monitored_variables = ( 154 | [v for v in model.auxiliary_variables if 'norm' not in v.name] + 155 | model.outputs) 156 | extensions = [ 157 | Timing(), 158 | FinishAfter(after_n_epochs=NUM_EPOCHS), 159 | DataStreamMonitoring( 160 | bn_monitored_variables, train_monitor_stream, prefix="train", 161 | updates=bn_updates), 162 | DataStreamMonitoring( 163 | monitored_variables, valid_monitor_stream, prefix="valid"), 164 | Checkpoint(save_path, after_epoch=True, after_training=True, 165 | use_cpickle=True), 166 | ProgressBar(), 167 | Printing(), 168 | ] 169 | main_loop = MainLoop(model=bn_model, data_stream=main_loop_stream, 170 | algorithm=algorithm, extensions=extensions) 171 | return main_loop 172 | 173 | if __name__ == '__main__': 174 | logging.basicConfig(level=logging.INFO) 175 | parser = argparse.ArgumentParser(description='Train GAN on MOG') 176 | parser.add_argument("--save-path", type=str, 177 | default='gan_mixture_prime.tar', 178 | help="main loop save path") 179 | args = parser.parse_args() 180 | main_loop = create_main_loop(args.save_path) 181 | main_loop.run() 182 | -------------------------------------------------------------------------------- /experiments/ali_svhn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | from blocks.algorithms import Adam 5 | from blocks.bricks import LeakyRectifier, Logistic 6 | from blocks.bricks.conv import ConvolutionalSequence 7 | from blocks.extensions import FinishAfter, Timing, Printing, ProgressBar 8 | from blocks.extensions.monitoring import DataStreamMonitoring 9 | from blocks.extensions.saveload import Checkpoint 10 | from blocks.filter import VariableFilter 11 | from blocks.graph import ComputationGraph, apply_dropout 12 | from blocks.graph.bn import (batch_normalization, 13 | get_batch_normalization_updates) 14 | from blocks.initialization import IsotropicGaussian, Constant 15 | from 
blocks.main_loop import MainLoop 16 | from blocks.model import Model 17 | from blocks.roles import INPUT 18 | from theano import tensor 19 | 20 | from ali.algorithms import ali_algorithm 21 | from ali.bricks import (ALI, GaussianConditional, DeterministicConditional, 22 | XZJointDiscriminator) 23 | from ali.streams import create_svhn_data_streams 24 | from ali.utils import get_log_odds, conv_brick, conv_transpose_brick, bn_brick 25 | 26 | BATCH_SIZE = 100 27 | MONITORING_BATCH_SIZE = 500 28 | NUM_EPOCHS = 100 29 | IMAGE_SIZE = (32, 32) 30 | NUM_CHANNELS = 3 31 | NLAT = 256 32 | GAUSSIAN_INIT = IsotropicGaussian(std=0.01) 33 | ZERO_INIT = Constant(0) 34 | LEARNING_RATE = 1e-4 35 | BETA1 = 0.5 36 | 37 | 38 | def create_model_brick(): 39 | layers = [ 40 | conv_brick(5, 1, 32), bn_brick(), LeakyRectifier(), 41 | conv_brick(4, 2, 64), bn_brick(), LeakyRectifier(), 42 | conv_brick(4, 1, 128), bn_brick(), LeakyRectifier(), 43 | conv_brick(4, 2, 256), bn_brick(), LeakyRectifier(), 44 | conv_brick(4, 1, 512), bn_brick(), LeakyRectifier(), 45 | conv_brick(1, 1, 512), bn_brick(), LeakyRectifier(), 46 | conv_brick(1, 1, 2 * NLAT)] 47 | encoder_mapping = ConvolutionalSequence( 48 | layers=layers, num_channels=NUM_CHANNELS, image_size=IMAGE_SIZE, 49 | use_bias=False, name='encoder_mapping') 50 | encoder = GaussianConditional(encoder_mapping, name='encoder') 51 | 52 | layers = [ 53 | conv_transpose_brick(4, 1, 256), bn_brick(), LeakyRectifier(), 54 | conv_transpose_brick(4, 2, 128), bn_brick(), LeakyRectifier(), 55 | conv_transpose_brick(4, 1, 64), bn_brick(), LeakyRectifier(), 56 | conv_transpose_brick(4, 2, 32), bn_brick(), LeakyRectifier(), 57 | conv_transpose_brick(5, 1, 32), bn_brick(), LeakyRectifier(), 58 | conv_transpose_brick(1, 1, 32), bn_brick(), LeakyRectifier(), 59 | conv_brick(1, 1, NUM_CHANNELS), Logistic()] 60 | decoder_mapping = ConvolutionalSequence( 61 | layers=layers, num_channels=NLAT, image_size=(1, 1), use_bias=False, 62 | name='decoder_mapping') 63 | 
decoder = DeterministicConditional(decoder_mapping, name='decoder') 64 | 65 | layers = [ 66 | conv_brick(5, 1, 32), LeakyRectifier(), 67 | conv_brick(4, 2, 64), bn_brick(), LeakyRectifier(), 68 | conv_brick(4, 1, 128), bn_brick(), LeakyRectifier(), 69 | conv_brick(4, 2, 256), bn_brick(), LeakyRectifier(), 70 | conv_brick(4, 1, 512), bn_brick(), LeakyRectifier()] 71 | x_discriminator = ConvolutionalSequence( 72 | layers=layers, num_channels=NUM_CHANNELS, image_size=IMAGE_SIZE, 73 | use_bias=False, name='x_discriminator') 74 | x_discriminator.push_allocation_config() 75 | 76 | layers = [ 77 | conv_brick(1, 1, 512), LeakyRectifier(), 78 | conv_brick(1, 1, 512), LeakyRectifier()] 79 | z_discriminator = ConvolutionalSequence( 80 | layers=layers, num_channels=NLAT, image_size=(1, 1), use_bias=False, 81 | name='z_discriminator') 82 | z_discriminator.push_allocation_config() 83 | 84 | layers = [ 85 | conv_brick(1, 1, 1024), LeakyRectifier(), 86 | conv_brick(1, 1, 1024), LeakyRectifier(), 87 | conv_brick(1, 1, 1)] 88 | joint_discriminator = ConvolutionalSequence( 89 | layers=layers, 90 | num_channels=(x_discriminator.get_dim('output')[0] + 91 | z_discriminator.get_dim('output')[0]), 92 | image_size=(1, 1), 93 | name='joint_discriminator') 94 | 95 | discriminator = XZJointDiscriminator( 96 | x_discriminator, z_discriminator, joint_discriminator, 97 | name='discriminator') 98 | 99 | ali = ALI(encoder, decoder, discriminator, 100 | weights_init=GAUSSIAN_INIT, biases_init=ZERO_INIT, 101 | name='ali') 102 | ali.push_allocation_config() 103 | encoder_mapping.layers[-1].use_bias = True 104 | encoder_mapping.layers[-1].tied_biases = False 105 | decoder_mapping.layers[-2].use_bias = True 106 | decoder_mapping.layers[-2].tied_biases = False 107 | x_discriminator.layers[0].use_bias = True 108 | x_discriminator.layers[0].tied_biases = True 109 | ali.initialize() 110 | raw_marginals, = next( 111 | create_svhn_data_streams(500, 500)[0].get_epoch_iterator()) 112 | b_value = 
get_log_odds(raw_marginals) 113 | decoder_mapping.layers[-2].b.set_value(b_value) 114 | 115 | return ali 116 | 117 | 118 | def create_models(): 119 | ali = create_model_brick() 120 | x = tensor.tensor4('features') 121 | z = ali.theano_rng.normal(size=(x.shape[0], NLAT, 1, 1)) 122 | 123 | def _create_model(with_dropout): 124 | cg = ComputationGraph(ali.compute_losses(x, z)) 125 | if with_dropout: 126 | inputs = VariableFilter( 127 | bricks=([ali.discriminator.x_discriminator.layers[0]] + 128 | ali.discriminator.x_discriminator.layers[2::3] + 129 | ali.discriminator.z_discriminator.layers[::2] + 130 | ali.discriminator.joint_discriminator.layers[::2]), 131 | roles=[INPUT])(cg.variables) 132 | cg = apply_dropout(cg, inputs, 0.2) 133 | return Model(cg.outputs) 134 | 135 | model = _create_model(with_dropout=False) 136 | with batch_normalization(ali): 137 | bn_model = _create_model(with_dropout=True) 138 | 139 | pop_updates = list( 140 | set(get_batch_normalization_updates(bn_model, allow_duplicates=True))) 141 | bn_updates = [(p, m * 0.05 + p * 0.95) for p, m in pop_updates] 142 | 143 | return model, bn_model, bn_updates 144 | 145 | 146 | def create_main_loop(save_path): 147 | model, bn_model, bn_updates = create_models() 148 | ali, = bn_model.top_bricks 149 | discriminator_loss, generator_loss = bn_model.outputs 150 | 151 | step_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1) 152 | algorithm = ali_algorithm(discriminator_loss, ali.discriminator_parameters, 153 | step_rule, generator_loss, 154 | ali.generator_parameters, step_rule) 155 | algorithm.add_updates(bn_updates) 156 | streams = create_svhn_data_streams(BATCH_SIZE, MONITORING_BATCH_SIZE) 157 | main_loop_stream, train_monitor_stream, valid_monitor_stream = streams 158 | bn_monitored_variables = ( 159 | [v for v in bn_model.auxiliary_variables if 'norm' not in v.name] + 160 | bn_model.outputs) 161 | monitored_variables = ( 162 | [v for v in model.auxiliary_variables if 'norm' not in v.name] + 163 | 
model.outputs) 164 | extensions = [ 165 | Timing(), 166 | FinishAfter(after_n_epochs=NUM_EPOCHS), 167 | DataStreamMonitoring( 168 | bn_monitored_variables, train_monitor_stream, prefix="train", 169 | updates=bn_updates), 170 | DataStreamMonitoring( 171 | monitored_variables, valid_monitor_stream, prefix="valid"), 172 | Checkpoint(save_path, after_epoch=True, after_training=True, 173 | use_cpickle=True), 174 | ProgressBar(), 175 | Printing(), 176 | ] 177 | main_loop = MainLoop(model=bn_model, data_stream=main_loop_stream, 178 | algorithm=algorithm, extensions=extensions) 179 | return main_loop 180 | 181 | 182 | if __name__ == "__main__": 183 | logging.basicConfig(level=logging.INFO) 184 | parser = argparse.ArgumentParser(description="Train ALI on SVHN") 185 | parser.add_argument("--save-path", type=str, default='ali_svhn.tar', 186 | help="main loop save path") 187 | args = parser.parse_args() 188 | create_main_loop(args.save_path).run() 189 | -------------------------------------------------------------------------------- /experiments/ali_celeba.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | from blocks.algorithms import Adam 5 | from blocks.bricks import LeakyRectifier, Logistic 6 | from blocks.bricks.conv import ConvolutionalSequence 7 | from blocks.extensions import FinishAfter, Timing, Printing, ProgressBar 8 | from blocks.extensions.monitoring import DataStreamMonitoring 9 | from blocks.extensions.saveload import Checkpoint 10 | from blocks.filter import VariableFilter 11 | from blocks.graph import ComputationGraph, apply_dropout 12 | from blocks.graph.bn import (batch_normalization, 13 | get_batch_normalization_updates) 14 | from blocks.initialization import IsotropicGaussian, Constant 15 | from blocks.main_loop import MainLoop 16 | from blocks.model import Model 17 | from blocks.roles import INPUT 18 | from theano import tensor 19 | 20 | from ali.algorithms import 
ali_algorithm 21 | from ali.bricks import (ALI, GaussianConditional, DeterministicConditional, 22 | XZJointDiscriminator) 23 | from ali.streams import create_celeba_data_streams 24 | from ali.utils import get_log_odds, conv_brick, conv_transpose_brick, bn_brick 25 | 26 | BATCH_SIZE = 100 27 | MONITORING_BATCH_SIZE = 500 28 | NUM_EPOCHS = 123 29 | IMAGE_SIZE = (64, 64) 30 | NUM_CHANNELS = 3 31 | NLAT = 256 32 | GAUSSIAN_INIT = IsotropicGaussian(std=0.01) 33 | ZERO_INIT = Constant(0) 34 | LEARNING_RATE = 1e-4 35 | BETA1 = 0.5 36 | LEAK = 0.02 37 | 38 | 39 | def create_model_brick(): 40 | layers = [ 41 | conv_brick(2, 1, 64), bn_brick(), LeakyRectifier(leak=LEAK), 42 | conv_brick(7, 2, 128), bn_brick(), LeakyRectifier(leak=LEAK), 43 | conv_brick(5, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK), 44 | conv_brick(7, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK), 45 | conv_brick(4, 1, 512), bn_brick(), LeakyRectifier(leak=LEAK), 46 | conv_brick(1, 1, 2 * NLAT)] 47 | encoder_mapping = ConvolutionalSequence( 48 | layers=layers, num_channels=NUM_CHANNELS, image_size=IMAGE_SIZE, 49 | use_bias=False, name='encoder_mapping') 50 | encoder = GaussianConditional(encoder_mapping, name='encoder') 51 | 52 | layers = [ 53 | conv_transpose_brick(4, 1, 512), bn_brick(), LeakyRectifier(leak=LEAK), 54 | conv_transpose_brick(7, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK), 55 | conv_transpose_brick(5, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK), 56 | conv_transpose_brick(7, 2, 128), bn_brick(), LeakyRectifier(leak=LEAK), 57 | conv_transpose_brick(2, 1, 64), bn_brick(), LeakyRectifier(leak=LEAK), 58 | conv_brick(1, 1, NUM_CHANNELS), Logistic()] 59 | decoder_mapping = ConvolutionalSequence( 60 | layers=layers, num_channels=NLAT, image_size=(1, 1), use_bias=False, 61 | name='decoder_mapping') 62 | decoder = DeterministicConditional(decoder_mapping, name='decoder') 63 | 64 | layers = [ 65 | conv_brick(2, 1, 64), LeakyRectifier(leak=LEAK), 66 | conv_brick(7, 2, 128), bn_brick(), 
LeakyRectifier(leak=LEAK), 67 | conv_brick(5, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK), 68 | conv_brick(7, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK), 69 | conv_brick(4, 1, 512), bn_brick(), LeakyRectifier(leak=LEAK)] 70 | x_discriminator = ConvolutionalSequence( 71 | layers=layers, num_channels=NUM_CHANNELS, image_size=IMAGE_SIZE, 72 | use_bias=False, name='x_discriminator') 73 | x_discriminator.push_allocation_config() 74 | 75 | layers = [ 76 | conv_brick(1, 1, 1024), LeakyRectifier(leak=LEAK), 77 | conv_brick(1, 1, 1024), LeakyRectifier(leak=LEAK)] 78 | z_discriminator = ConvolutionalSequence( 79 | layers=layers, num_channels=NLAT, image_size=(1, 1), use_bias=False, 80 | name='z_discriminator') 81 | z_discriminator.push_allocation_config() 82 | 83 | layers = [ 84 | conv_brick(1, 1, 2048), LeakyRectifier(leak=LEAK), 85 | conv_brick(1, 1, 2048), LeakyRectifier(leak=LEAK), 86 | conv_brick(1, 1, 1)] 87 | joint_discriminator = ConvolutionalSequence( 88 | layers=layers, 89 | num_channels=(x_discriminator.get_dim('output')[0] + 90 | z_discriminator.get_dim('output')[0]), 91 | image_size=(1, 1), 92 | name='joint_discriminator') 93 | 94 | discriminator = XZJointDiscriminator( 95 | x_discriminator, z_discriminator, joint_discriminator, 96 | name='discriminator') 97 | 98 | ali = ALI(encoder, decoder, discriminator, 99 | weights_init=GAUSSIAN_INIT, biases_init=ZERO_INIT, 100 | name='ali') 101 | ali.push_allocation_config() 102 | encoder_mapping.layers[-1].use_bias = True 103 | encoder_mapping.layers[-1].tied_biases = False 104 | decoder_mapping.layers[-2].use_bias = True 105 | decoder_mapping.layers[-2].tied_biases = False 106 | x_discriminator.layers[0].use_bias = True 107 | x_discriminator.layers[0].tied_biases = True 108 | ali.initialize() 109 | raw_marginals, = next( 110 | create_celeba_data_streams(500, 500)[0].get_epoch_iterator()) 111 | b_value = get_log_odds(raw_marginals) 112 | decoder_mapping.layers[-2].b.set_value(b_value) 113 | 114 | return ali 115 | 
116 | 117 | def create_models(): 118 | ali = create_model_brick() 119 | x = tensor.tensor4('features') 120 | z = ali.theano_rng.normal(size=(x.shape[0], NLAT, 1, 1)) 121 | 122 | def _create_model(with_dropout): 123 | cg = ComputationGraph(ali.compute_losses(x, z)) 124 | if with_dropout: 125 | inputs = VariableFilter( 126 | bricks=([ali.discriminator.x_discriminator.layers[0]] + 127 | ali.discriminator.x_discriminator.layers[2::3] + 128 | ali.discriminator.z_discriminator.layers[::2] + 129 | ali.discriminator.joint_discriminator.layers[::2]), 130 | roles=[INPUT])(cg.variables) 131 | cg = apply_dropout(cg, inputs, 0.2) 132 | return Model(cg.outputs) 133 | 134 | model = _create_model(with_dropout=False) 135 | with batch_normalization(ali): 136 | bn_model = _create_model(with_dropout=True) 137 | 138 | pop_updates = list( 139 | set(get_batch_normalization_updates(bn_model, allow_duplicates=True))) 140 | bn_updates = [(p, m * 0.05 + p * 0.95) for p, m in pop_updates] 141 | 142 | return model, bn_model, bn_updates 143 | 144 | 145 | def create_main_loop(save_path): 146 | model, bn_model, bn_updates = create_models() 147 | ali, = bn_model.top_bricks 148 | discriminator_loss, generator_loss = bn_model.outputs 149 | 150 | step_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1) 151 | algorithm = ali_algorithm(discriminator_loss, ali.discriminator_parameters, 152 | step_rule, generator_loss, 153 | ali.generator_parameters, step_rule) 154 | algorithm.add_updates(bn_updates) 155 | streams = create_celeba_data_streams(BATCH_SIZE, MONITORING_BATCH_SIZE) 156 | main_loop_stream, train_monitor_stream, valid_monitor_stream = streams 157 | bn_monitored_variables = ( 158 | [v for v in bn_model.auxiliary_variables if 'norm' not in v.name] + 159 | bn_model.outputs) 160 | monitored_variables = ( 161 | [v for v in model.auxiliary_variables if 'norm' not in v.name] + 162 | model.outputs) 163 | extensions = [ 164 | Timing(), 165 | FinishAfter(after_n_epochs=NUM_EPOCHS), 166 | 
DataStreamMonitoring( 167 | bn_monitored_variables, train_monitor_stream, prefix="train", 168 | updates=bn_updates), 169 | DataStreamMonitoring( 170 | monitored_variables, valid_monitor_stream, prefix="valid"), 171 | Checkpoint(save_path, after_epoch=True, after_training=True, 172 | use_cpickle=True), 173 | ProgressBar(), 174 | Printing(), 175 | ] 176 | main_loop = MainLoop(model=bn_model, data_stream=main_loop_stream, 177 | algorithm=algorithm, extensions=extensions) 178 | return main_loop 179 | 180 | 181 | if __name__ == "__main__": 182 | logging.basicConfig(level=logging.INFO) 183 | parser = argparse.ArgumentParser(description="Train ALI on CelebA") 184 | parser.add_argument("--save-path", type=str, default='ali_celeba.tar', 185 | help="main loop save path") 186 | args = parser.parse_args() 187 | create_main_loop(args.save_path).run() 188 | -------------------------------------------------------------------------------- /experiments/ali_cifar10.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | from blocks.algorithms import Adam 5 | from blocks.bricks import LeakyRectifier, Logistic 6 | from blocks.bricks.conv import ConvolutionalSequence 7 | from blocks.extensions import FinishAfter, Timing, Printing, ProgressBar 8 | from blocks.extensions.monitoring import DataStreamMonitoring 9 | from blocks.extensions.saveload import Checkpoint 10 | from blocks.filter import VariableFilter 11 | from blocks.graph import ComputationGraph, apply_dropout 12 | from blocks.graph.bn import (batch_normalization, 13 | get_batch_normalization_updates) 14 | from blocks.initialization import IsotropicGaussian, Constant 15 | from blocks.main_loop import MainLoop 16 | from blocks.model import Model 17 | from blocks.roles import INPUT 18 | from theano import tensor 19 | 20 | from ali.algorithms import ali_algorithm 21 | from ali.bricks import (ALI, GaussianConditional, DeterministicConditional, 22 | 
XZJointDiscriminator, ConvMaxout) 23 | from ali.streams import create_cifar10_data_streams 24 | from ali.utils import get_log_odds, conv_brick, conv_transpose_brick, bn_brick 25 | 26 | BATCH_SIZE = 100 27 | MONITORING_BATCH_SIZE = 500 28 | NUM_EPOCHS = 6475 29 | IMAGE_SIZE = (32, 32) 30 | NUM_CHANNELS = 3 31 | NLAT = 64 32 | GAUSSIAN_INIT = IsotropicGaussian(std=0.01) 33 | ZERO_INIT = Constant(0) 34 | LEAK = 0.1 35 | NUM_PIECES = 2 36 | LEARNING_RATE = 1e-4 37 | BETA1 = 0.5 38 | 39 | 40 | def create_model_brick(): 41 | layers = [ 42 | conv_brick(5, 1, 32), bn_brick(), LeakyRectifier(leak=LEAK), 43 | conv_brick(4, 2, 64), bn_brick(), LeakyRectifier(leak=LEAK), 44 | conv_brick(4, 1, 128), bn_brick(), LeakyRectifier(leak=LEAK), 45 | conv_brick(4, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK), 46 | conv_brick(4, 1, 512), bn_brick(), LeakyRectifier(leak=LEAK), 47 | conv_brick(1, 1, 512), bn_brick(), LeakyRectifier(leak=LEAK), 48 | conv_brick(1, 1, 2 * NLAT)] 49 | encoder_mapping = ConvolutionalSequence( 50 | layers=layers, num_channels=NUM_CHANNELS, image_size=IMAGE_SIZE, 51 | use_bias=False, name='encoder_mapping') 52 | encoder = GaussianConditional(encoder_mapping, name='encoder') 53 | 54 | layers = [ 55 | conv_transpose_brick(4, 1, 256), bn_brick(), LeakyRectifier(leak=LEAK), 56 | conv_transpose_brick(4, 2, 128), bn_brick(), LeakyRectifier(leak=LEAK), 57 | conv_transpose_brick(4, 1, 64), bn_brick(), LeakyRectifier(leak=LEAK), 58 | conv_transpose_brick(4, 2, 32), bn_brick(), LeakyRectifier(leak=LEAK), 59 | conv_transpose_brick(5, 1, 32), bn_brick(), LeakyRectifier(leak=LEAK), 60 | conv_transpose_brick(1, 1, 32), bn_brick(), LeakyRectifier(leak=LEAK), 61 | conv_brick(1, 1, NUM_CHANNELS), Logistic()] 62 | decoder_mapping = ConvolutionalSequence( 63 | layers=layers, num_channels=NLAT, image_size=(1, 1), use_bias=False, 64 | name='decoder_mapping') 65 | decoder = DeterministicConditional(decoder_mapping, name='decoder') 66 | 67 | layers = [ 68 | conv_brick(5, 1, 32), 
ConvMaxout(num_pieces=NUM_PIECES), 69 | conv_brick(4, 2, 64), ConvMaxout(num_pieces=NUM_PIECES), 70 | conv_brick(4, 1, 128), ConvMaxout(num_pieces=NUM_PIECES), 71 | conv_brick(4, 2, 256), ConvMaxout(num_pieces=NUM_PIECES), 72 | conv_brick(4, 1, 512), ConvMaxout(num_pieces=NUM_PIECES)] 73 | x_discriminator = ConvolutionalSequence( 74 | layers=layers, num_channels=NUM_CHANNELS, image_size=IMAGE_SIZE, 75 | name='x_discriminator') 76 | x_discriminator.push_allocation_config() 77 | 78 | layers = [ 79 | conv_brick(1, 1, 512), ConvMaxout(num_pieces=NUM_PIECES), 80 | conv_brick(1, 1, 512), ConvMaxout(num_pieces=NUM_PIECES)] 81 | z_discriminator = ConvolutionalSequence( 82 | layers=layers, num_channels=NLAT, image_size=(1, 1), use_bias=False, 83 | name='z_discriminator') 84 | z_discriminator.push_allocation_config() 85 | 86 | layers = [ 87 | conv_brick(1, 1, 1024), ConvMaxout(num_pieces=NUM_PIECES), 88 | conv_brick(1, 1, 1024), ConvMaxout(num_pieces=NUM_PIECES), 89 | conv_brick(1, 1, 1)] 90 | joint_discriminator = ConvolutionalSequence( 91 | layers=layers, 92 | num_channels=(x_discriminator.get_dim('output')[0] + 93 | z_discriminator.get_dim('output')[0]), 94 | image_size=(1, 1), 95 | name='joint_discriminator') 96 | 97 | discriminator = XZJointDiscriminator( 98 | x_discriminator, z_discriminator, joint_discriminator, 99 | name='discriminator') 100 | 101 | ali = ALI(encoder, decoder, discriminator, 102 | weights_init=GAUSSIAN_INIT, biases_init=ZERO_INIT, 103 | name='ali') 104 | ali.push_allocation_config() 105 | encoder_mapping.layers[-1].use_bias = True 106 | encoder_mapping.layers[-1].tied_biases = False 107 | decoder_mapping.layers[-2].use_bias = True 108 | decoder_mapping.layers[-2].tied_biases = False 109 | ali.initialize() 110 | raw_marginals, = next( 111 | create_cifar10_data_streams(500, 500)[0].get_epoch_iterator()) 112 | b_value = get_log_odds(raw_marginals) 113 | decoder_mapping.layers[-2].b.set_value(b_value) 114 | 115 | return ali 116 | 117 | 118 | def 
create_models(): 119 | ali = create_model_brick() 120 | x = tensor.tensor4('features') 121 | z = ali.theano_rng.normal(size=(x.shape[0], NLAT, 1, 1)) 122 | 123 | def _create_model(with_dropout): 124 | cg = ComputationGraph(ali.compute_losses(x, z)) 125 | if with_dropout: 126 | inputs = VariableFilter( 127 | bricks=([ali.discriminator.x_discriminator.layers[0], 128 | ali.discriminator.z_discriminator.layers[0]]), 129 | roles=[INPUT])(cg.variables) 130 | cg = apply_dropout(cg, inputs, 0.2) 131 | inputs = VariableFilter( 132 | bricks=(ali.discriminator.x_discriminator.layers[2::3] + 133 | ali.discriminator.z_discriminator.layers[2::2] + 134 | ali.discriminator.joint_discriminator.layers[::2]), 135 | roles=[INPUT])(cg.variables) 136 | cg = apply_dropout(cg, inputs, 0.5) 137 | return Model(cg.outputs) 138 | 139 | model = _create_model(with_dropout=False) 140 | with batch_normalization(ali): 141 | bn_model = _create_model(with_dropout=True) 142 | 143 | pop_updates = list( 144 | set(get_batch_normalization_updates(bn_model, allow_duplicates=True))) 145 | bn_updates = [(p, m * 0.05 + p * 0.95) for p, m in pop_updates] 146 | 147 | return model, bn_model, bn_updates 148 | 149 | 150 | def create_main_loop(save_path): 151 | 152 | model, bn_model, bn_updates = create_models() 153 | ali, = bn_model.top_bricks 154 | discriminator_loss, generator_loss = bn_model.outputs 155 | 156 | step_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1) 157 | algorithm = ali_algorithm(discriminator_loss, ali.discriminator_parameters, 158 | step_rule, generator_loss, 159 | ali.generator_parameters, step_rule) 160 | algorithm.add_updates(bn_updates) 161 | streams = create_cifar10_data_streams(BATCH_SIZE, MONITORING_BATCH_SIZE) 162 | main_loop_stream, train_monitor_stream, valid_monitor_stream = streams 163 | bn_monitored_variables = ( 164 | [v for v in bn_model.auxiliary_variables if 'norm' not in v.name] + 165 | bn_model.outputs) 166 | monitored_variables = ( 167 | [v for v in 
model.auxiliary_variables if 'norm' not in v.name] + 168 | model.outputs) 169 | extensions = [ 170 | Timing(), 171 | FinishAfter(after_n_epochs=NUM_EPOCHS), 172 | DataStreamMonitoring( 173 | bn_monitored_variables, train_monitor_stream, prefix="train", 174 | updates=bn_updates), 175 | DataStreamMonitoring( 176 | monitored_variables, valid_monitor_stream, prefix="valid"), 177 | Checkpoint(save_path, after_epoch=True, after_training=True, 178 | use_cpickle=True), 179 | ProgressBar(), 180 | Printing(), 181 | ] 182 | main_loop = MainLoop(model=bn_model, data_stream=main_loop_stream, 183 | algorithm=algorithm, extensions=extensions) 184 | return main_loop 185 | 186 | 187 | if __name__ == "__main__": 188 | logging.basicConfig(level=logging.INFO) 189 | parser = argparse.ArgumentParser(description="Train ALI on CIFAR10") 190 | parser.add_argument("--save-path", type=str, default='ali_cifar10.tar', 191 | help="main loop save path") 192 | args = parser.parse_args() 193 | create_main_loop(args.save_path).run() 194 | -------------------------------------------------------------------------------- /experiments/ali_tiny_imagenet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | from blocks.algorithms import Adam 5 | from blocks.bricks import LeakyRectifier, Logistic 6 | from blocks.bricks.conv import ConvolutionalSequence 7 | from blocks.extensions import FinishAfter, Timing, Printing, ProgressBar 8 | from blocks.extensions.monitoring import DataStreamMonitoring 9 | from blocks.extensions.saveload import Checkpoint 10 | from blocks.filter import VariableFilter 11 | from blocks.graph import ComputationGraph, apply_dropout 12 | from blocks.graph.bn import (batch_normalization, 13 | get_batch_normalization_updates) 14 | from blocks.initialization import IsotropicGaussian, Constant 15 | from blocks.main_loop import MainLoop 16 | from blocks.model import Model 17 | from blocks.roles import INPUT 18 | 
from theano import tensor 19 | 20 | from ali.algorithms import ali_algorithm 21 | from ali.bricks import (ALI, GaussianConditional, DeterministicConditional, 22 | XZJointDiscriminator) 23 | from ali.streams import create_tiny_imagenet_data_streams 24 | from ali.utils import get_log_odds, conv_brick, conv_transpose_brick, bn_brick 25 | 26 | BATCH_SIZE = 128 27 | MONITORING_BATCH_SIZE = 128 28 | NUM_EPOCHS = 1000 29 | IMAGE_SIZE = (64, 64) 30 | NUM_CHANNELS = 3 31 | NLAT = 256 32 | GAUSSIAN_INIT = IsotropicGaussian(std=0.01) 33 | ZERO_INIT = Constant(0) 34 | LEARNING_RATE = 1e-4 35 | BETA1 = 0.5 36 | 37 | 38 | def create_model_brick(): 39 | layers = [ 40 | conv_brick(4, 2, 64), bn_brick(), LeakyRectifier(), 41 | conv_brick(4, 1, 64), bn_brick(), LeakyRectifier(), 42 | conv_brick(4, 2, 128), bn_brick(), LeakyRectifier(), 43 | conv_brick(4, 1, 128), bn_brick(), LeakyRectifier(), 44 | conv_brick(4, 2, 256), bn_brick(), LeakyRectifier(), 45 | conv_brick(4, 1, 256), bn_brick(), LeakyRectifier(), 46 | conv_brick(1, 1, 2048), bn_brick(), LeakyRectifier(), 47 | conv_brick(1, 1, 2048), bn_brick(), LeakyRectifier(), 48 | conv_brick(1, 1, 2 * NLAT)] 49 | encoder_mapping = ConvolutionalSequence( 50 | layers=layers, num_channels=NUM_CHANNELS, image_size=IMAGE_SIZE, 51 | use_bias=False, name='encoder_mapping') 52 | encoder = GaussianConditional(encoder_mapping, name='encoder') 53 | 54 | layers = [ 55 | conv_brick(1, 1, 2048), bn_brick(), LeakyRectifier(), 56 | conv_brick(1, 1, 256), bn_brick(), LeakyRectifier(), 57 | conv_transpose_brick(4, 1, 256), bn_brick(), LeakyRectifier(), 58 | conv_transpose_brick(4, 2, 128), bn_brick(), LeakyRectifier(), 59 | conv_transpose_brick(4, 1, 128), bn_brick(), LeakyRectifier(), 60 | conv_transpose_brick(4, 2, 64), bn_brick(), LeakyRectifier(), 61 | conv_transpose_brick(4, 1, 64), bn_brick(), LeakyRectifier(), 62 | conv_transpose_brick(4, 2, 64), bn_brick(), LeakyRectifier(), 63 | conv_brick(1, 1, NUM_CHANNELS), Logistic()] 64 | decoder_mapping = 
ConvolutionalSequence( 65 | layers=layers, num_channels=NLAT, image_size=(1, 1), use_bias=False, 66 | name='decoder_mapping') 67 | decoder = DeterministicConditional(decoder_mapping, name='decoder') 68 | 69 | layers = [ 70 | conv_brick(4, 2, 64), LeakyRectifier(), 71 | conv_brick(4, 1, 64), bn_brick(), LeakyRectifier(), 72 | conv_brick(4, 2, 128), bn_brick(), LeakyRectifier(), 73 | conv_brick(4, 1, 128), bn_brick(), LeakyRectifier(), 74 | conv_brick(4, 2, 256), bn_brick(), LeakyRectifier(), 75 | conv_brick(4, 1, 256), bn_brick(), LeakyRectifier()] 76 | x_discriminator = ConvolutionalSequence( 77 | layers=layers, num_channels=NUM_CHANNELS, image_size=IMAGE_SIZE, 78 | use_bias=False, name='x_discriminator') 79 | x_discriminator.push_allocation_config() 80 | 81 | layers = [ 82 | conv_brick(1, 1, 2048), LeakyRectifier(), 83 | conv_brick(1, 1, 2048), LeakyRectifier()] 84 | z_discriminator = ConvolutionalSequence( 85 | layers=layers, num_channels=NLAT, image_size=(1, 1), 86 | name='z_discriminator') 87 | z_discriminator.push_allocation_config() 88 | 89 | layers = [ 90 | conv_brick(1, 1, 4096), LeakyRectifier(), 91 | conv_brick(1, 1, 4096), LeakyRectifier(), 92 | conv_brick(1, 1, 1)] 93 | joint_discriminator = ConvolutionalSequence( 94 | layers=layers, 95 | num_channels=(x_discriminator.get_dim('output')[0] + 96 | z_discriminator.get_dim('output')[0]), 97 | image_size=(1, 1), 98 | name='joint_discriminator') 99 | 100 | discriminator = XZJointDiscriminator( 101 | x_discriminator, z_discriminator, joint_discriminator, 102 | name='discriminator') 103 | 104 | ali = ALI(encoder, decoder, discriminator, 105 | weights_init=GAUSSIAN_INIT, biases_init=ZERO_INIT, 106 | name='ali') 107 | ali.push_allocation_config() 108 | encoder_mapping.layers[-1].use_bias = True 109 | encoder_mapping.layers[-1].tied_biases = False 110 | decoder_mapping.layers[-2].use_bias = True 111 | decoder_mapping.layers[-2].tied_biases = False 112 | x_discriminator.layers[0].use_bias = True 113 | 
x_discriminator.layers[0].tied_biases = True 114 | ali.initialize() 115 | raw_marginals, = next( 116 | create_tiny_imagenet_data_streams(500, 500)[0].get_epoch_iterator()) 117 | b_value = get_log_odds(raw_marginals) 118 | decoder_mapping.layers[-2].b.set_value(b_value) 119 | 120 | return ali 121 | 122 | 123 | def create_models(): 124 | ali = create_model_brick() 125 | x = tensor.tensor4('features') 126 | z = ali.theano_rng.normal(size=(x.shape[0], NLAT, 1, 1)) 127 | 128 | def _create_model(with_dropout): 129 | cg = ComputationGraph(ali.compute_losses(x, z)) 130 | if with_dropout: 131 | inputs = VariableFilter( 132 | bricks=([ali.discriminator.x_discriminator.layers[0]] + 133 | ali.discriminator.x_discriminator.layers[2::3] + 134 | ali.discriminator.z_discriminator.layers[::2] + 135 | ali.discriminator.joint_discriminator.layers[::2]), 136 | roles=[INPUT])(cg.variables) 137 | cg = apply_dropout(cg, inputs, 0.2) 138 | return Model(cg.outputs) 139 | 140 | model = _create_model(with_dropout=False) 141 | with batch_normalization(ali): 142 | bn_model = _create_model(with_dropout=True) 143 | 144 | pop_updates = list( 145 | set(get_batch_normalization_updates(bn_model, allow_duplicates=True))) 146 | bn_updates = [(p, m * 0.05 + p * 0.95) for p, m in pop_updates] 147 | 148 | return model, bn_model, bn_updates 149 | 150 | 151 | def create_main_loop(save_path): 152 | model, bn_model, bn_updates = create_models() 153 | ali, = bn_model.top_bricks 154 | discriminator_loss, generator_loss = bn_model.outputs 155 | 156 | step_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1) 157 | algorithm = ali_algorithm(discriminator_loss, ali.discriminator_parameters, 158 | step_rule, generator_loss, 159 | ali.generator_parameters, step_rule) 160 | algorithm.add_updates(bn_updates) 161 | streams = create_tiny_imagenet_data_streams(BATCH_SIZE, 162 | MONITORING_BATCH_SIZE) 163 | main_loop_stream, train_monitor_stream, valid_monitor_stream = streams 164 | bn_monitored_variables = ( 165 | [v 
for v in bn_model.auxiliary_variables if 'norm' not in v.name] + 166 | bn_model.outputs) 167 | monitored_variables = ( 168 | [v for v in model.auxiliary_variables if 'norm' not in v.name] + 169 | model.outputs) 170 | extensions = [ 171 | Timing(), 172 | FinishAfter(after_n_epochs=NUM_EPOCHS), 173 | DataStreamMonitoring( 174 | bn_monitored_variables, train_monitor_stream, prefix="train", 175 | updates=bn_updates), 176 | DataStreamMonitoring( 177 | monitored_variables, valid_monitor_stream, prefix="valid"), 178 | Checkpoint(save_path, after_epoch=True, after_training=True, 179 | use_cpickle=True), 180 | ProgressBar(), 181 | Printing(), 182 | ] 183 | main_loop = MainLoop(model=bn_model, data_stream=main_loop_stream, 184 | algorithm=algorithm, extensions=extensions) 185 | return main_loop 186 | 187 | 188 | if __name__ == "__main__": 189 | logging.basicConfig(level=logging.INFO) 190 | parser = argparse.ArgumentParser(description="Train ALI on Tiny ImageNet") 191 | parser.add_argument("--save-path", type=str, 192 | default='ali_tiny_imagenet.tar', 193 | help="main loop save path") 194 | args = parser.parse_args() 195 | create_main_loop(args.save_path).run() 196 | -------------------------------------------------------------------------------- /experiments/ali_mixture.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy 4 | from theano import tensor 5 | from blocks.algorithms import Adam 6 | from blocks.bricks import MLP, Rectifier, Identity, LinearMaxout, Linear 7 | from blocks.bricks.bn import BatchNormalization 8 | from blocks.bricks.sequences import Sequence 9 | from blocks.extensions import FinishAfter, Timing, Printing, ProgressBar 10 | from blocks.extensions.monitoring import DataStreamMonitoring 11 | from blocks.extensions.saveload import Checkpoint 12 | from blocks.graph import ComputationGraph, apply_dropout 13 | from blocks.graph.bn import (batch_normalization, 14 | 
get_batch_normalization_updates) 15 | from blocks.filter import VariableFilter 16 | from blocks.initialization import IsotropicGaussian, Constant 17 | from blocks.model import Model 18 | from blocks.main_loop import MainLoop 19 | from blocks.roles import INPUT 20 | 21 | from ali.algorithms import ali_algorithm 22 | from ali.streams import create_gaussian_mixture_data_streams 23 | from ali.bricks import (ALI, COVConditional, DeterministicConditional, 24 | XZJointDiscriminator) 25 | from ali.utils import as_array 26 | 27 | import logging 28 | import argparse 29 | 30 | INPUT_DIM = 2 # dimensionality of each data point (the mixture means below are 2-D) 31 | NLAT = 2 # dimensionality of the latent code z 32 | GEN_HIDDEN = 400 33 | DISC_HIDDEN = 200 34 | GEN_ACTIVATION = Rectifier 35 | MAXOUT_PIECES = 5 36 | GAUSSIAN_INIT = IsotropicGaussian(std=0.02) 37 | ZERO_INIT = Constant(0.0) 38 | 39 | NUM_EPOCHS = 200 40 | LEARNING_RATE = 1e-4 41 | BETA1 = 0.8 42 | BATCH_SIZE = 100 43 | MONITORING_BATCH_SIZE = 500 44 | MEANS = [numpy.array([i, j]) for i, j in itertools.product(range(-4, 5, 2), 45 | range(-4, 5, 2))] # 25 component means on a 5x5 grid with coordinates in {-4, -2, 0, 2, 4} 46 | VARIANCES = [0.05 ** 2 * numpy.eye(len(mean)) for mean in MEANS] # isotropic covariance, standard deviation 0.05 per component 47 | PRIORS = None # forwarded to create_gaussian_mixture_data_streams; presumably None means equal component weights (confirm in ali.streams) 48 | 49 | 50 | def create_model_brick(): # Build and initialize the ALI brick: BN-MLP encoder and decoder plus a maxout joint discriminator on (x, z). 51 | encoder_mapping = MLP( 52 | dims=[2 * INPUT_DIM, GEN_HIDDEN, GEN_HIDDEN, NLAT], # NOTE(review): input width is 2 * INPUT_DIM, presumably x concatenated with same-sized noise inside COVConditional; confirm in ali.bricks 53 | activations=[Sequence([BatchNormalization(GEN_HIDDEN).apply, 54 | GEN_ACTIVATION().apply], 55 | name='encoder_h1'), 56 | Sequence([BatchNormalization(GEN_HIDDEN).apply, 57 | GEN_ACTIVATION().apply], 58 | name='encoder_h2'), 59 | Identity(name='encoder_out')], 60 | use_bias=False, # biases off (hidden layers are batch-normalized); the output layer's bias is re-enabled after push_allocation_config 61 | name='encoder_mapping') 62 | encoder = COVConditional(encoder_mapping, (INPUT_DIM,), name='encoder') 63 | 64 | decoder_mapping = MLP( 65 | dims=[NLAT, GEN_HIDDEN, GEN_HIDDEN, GEN_HIDDEN, GEN_HIDDEN, INPUT_DIM], # z -> x through four BatchNormalization + GEN_ACTIVATION hidden layers 66 | activations=[Sequence([BatchNormalization(GEN_HIDDEN).apply, 67 | GEN_ACTIVATION().apply], 68 | name='decoder_h1'), 69 | Sequence([BatchNormalization(GEN_HIDDEN).apply, 70 | GEN_ACTIVATION().apply], 71 | name='decoder_h2'), 72 |
Sequence([BatchNormalization(GEN_HIDDEN).apply, 73 | GEN_ACTIVATION().apply], 74 | name='decoder_h3'), 75 | Sequence([BatchNormalization(GEN_HIDDEN).apply, 76 | GEN_ACTIVATION().apply], 77 | name='decoder_h4'), 78 | Identity(name='decoder_out')], 79 | use_bias=False, 80 | name='decoder_mapping') 81 | decoder = DeterministicConditional(decoder_mapping, name='decoder') 82 | 83 | x_discriminator = Identity(name='x_discriminator') # x and z reach the joint discriminator unchanged 84 | z_discriminator = Identity(name='z_discriminator') 85 | joint_discriminator = Sequence( 86 | application_methods=[ 87 | LinearMaxout( 88 | input_dim=INPUT_DIM + NLAT, # x width plus z width (presumably concatenated by XZJointDiscriminator) 89 | output_dim=DISC_HIDDEN, 90 | num_pieces=MAXOUT_PIECES, 91 | weights_init=GAUSSIAN_INIT, 92 | biases_init=ZERO_INIT, 93 | name='discriminator_h1').apply, 94 | LinearMaxout( 95 | input_dim=DISC_HIDDEN, 96 | output_dim=DISC_HIDDEN, 97 | num_pieces=MAXOUT_PIECES, 98 | weights_init=GAUSSIAN_INIT, 99 | biases_init=ZERO_INIT, 100 | name='discriminator_h2').apply, 101 | LinearMaxout( 102 | input_dim=DISC_HIDDEN, 103 | output_dim=DISC_HIDDEN, 104 | num_pieces=MAXOUT_PIECES, 105 | weights_init=GAUSSIAN_INIT, 106 | biases_init=ZERO_INIT, 107 | name='discriminator_h3').apply, 108 | Linear( 109 | input_dim=DISC_HIDDEN, 110 | output_dim=1, # one scalar score per (x, z) pair 111 | weights_init=GAUSSIAN_INIT, 112 | biases_init=ZERO_INIT, 113 | name='discriminator_out').apply], 114 | name='joint_discriminator') 115 | discriminator = XZJointDiscriminator( 116 | x_discriminator, z_discriminator, joint_discriminator, 117 | name='discriminator') 118 | 119 | ali = ALI(encoder=encoder, decoder=decoder, discriminator=discriminator, 120 | weights_init=GAUSSIAN_INIT, biases_init=ZERO_INIT, name='ali') 121 | ali.push_allocation_config() # push allocation config to children before the per-layer use_bias overrides below 122 | encoder_mapping.linear_transformations[-1].use_bias = True # bias only on each mapping's output layer 123 | decoder_mapping.linear_transformations[-1].use_bias = True 124 | ali.initialize() 125 | 126 | return ali 127 | 128 | 129 | def create_models(): # Return (inference model, batch-normalized training model, BN population-statistics updates). 130 | ali = create_model_brick() 131 | x = tensor.matrix('features') # one data point per row 132 | z =
ali.theano_rng.normal(size=(x.shape[0], NLAT)) # prior samples: one NLAT-wide row per data point 133 | 134 | def _create_model(with_dropout): # Wrap the (discriminator, generator) losses in a Model, optionally with dropout inside the joint discriminator. 135 | cg = ComputationGraph(ali.compute_losses(x, z)) 136 | if with_dropout: 137 | inputs = VariableFilter( 138 | bricks=ali.discriminator.joint_discriminator.children[1:], 139 | roles=[INPUT])(cg.variables) 140 | cg = apply_dropout(cg, inputs, 0.5) # dropout 0.5 on inputs of joint_discriminator.children[1:] 141 | inputs = VariableFilter( 142 | bricks=[ali.discriminator.joint_discriminator], 143 | roles=[INPUT])(cg.variables) 144 | cg = apply_dropout(cg, inputs, 0.2) # dropout 0.2 on the joint discriminator's own input 145 | return Model(cg.outputs) 146 | 147 | model = _create_model(with_dropout=False) 148 | with batch_normalization(ali): 149 | bn_model = _create_model(with_dropout=False) # NOTE: dropout is disabled for both graphs in this script 150 | 151 | pop_updates = list( 152 | set(get_batch_normalization_updates(bn_model, allow_duplicates=True))) # set() removes duplicate update pairs 153 | bn_updates = [(p, m * 0.05 + p * 0.95) for p, m in pop_updates] # moving average: 95% old population statistic, 5% minibatch statistic 154 | 155 | return model, bn_model, bn_updates 156 | 157 | 158 | def create_main_loop(save_path): # Build the Adam-trained MainLoop with train/valid monitoring and checkpoints written to save_path. 159 | model, bn_model, bn_updates = create_models() 160 | ali, = bn_model.top_bricks 161 | discriminator_loss, generator_loss = bn_model.outputs 162 | step_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1) 163 | algorithm = ali_algorithm(discriminator_loss, ali.discriminator_parameters, 164 | step_rule, generator_loss, 165 | ali.generator_parameters, step_rule) 166 | algorithm.add_updates(bn_updates) # refresh BN population statistics on every training step 167 | streams = create_gaussian_mixture_data_streams( 168 | batch_size=BATCH_SIZE, monitoring_batch_size=MONITORING_BATCH_SIZE, 169 | means=MEANS, variances=VARIANCES, priors=PRIORS) 170 | main_loop_stream, train_monitor_stream, valid_monitor_stream = streams 171 | bn_monitored_variables = ( 172 | [v for v in bn_model.auxiliary_variables if 'norm' not in v.name] + # skip auxiliary variables whose name contains 'norm' 173 | bn_model.outputs) 174 | monitored_variables = ( 175 | [v for v in model.auxiliary_variables if 'norm' not in v.name] + 176 | model.outputs) 177 | extensions = [ 178 | Timing(), 179 | FinishAfter(after_n_epochs=NUM_EPOCHS), 180 | DataStreamMonitoring( 181 |
bn_monitored_variables, train_monitor_stream, prefix="train", 182 | updates=bn_updates), # train-stream monitoring on the batch-normalized graph, also applying bn_updates 183 | DataStreamMonitoring( 184 | monitored_variables, valid_monitor_stream, prefix="valid"), # valid-stream monitoring on the graph built outside batch_normalization() 185 | Checkpoint(save_path, after_epoch=True, after_training=True, 186 | use_cpickle=True), # checkpoint after every epoch and when training ends 187 | ProgressBar(), 188 | Printing(), 189 | ] 190 | main_loop = MainLoop(model=bn_model, data_stream=main_loop_stream, 191 | algorithm=algorithm, extensions=extensions) 192 | return main_loop 193 | 194 | if __name__ == '__main__': # command-line entry point: train ALI on the mixture of Gaussians 195 | logging.basicConfig(level=logging.INFO) 196 | parser = argparse.ArgumentParser(description='Train ALI on MOG') 197 | parser.add_argument("--save-path", type=str, 198 | default='ali_mixture_prime.tar', 199 | help="main loop save path") 200 | args = parser.parse_args() 201 | main_loop = create_main_loop(args.save_path) 202 | main_loop.run() 203 | -------------------------------------------------------------------------------- /experiments/ali_celeba_conditional.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | from blocks.algorithms import Adam 5 | from blocks.bricks import LeakyRectifier, Logistic 6 | from blocks.bricks.conv import ConvolutionalSequence 7 | from blocks.extensions import FinishAfter, Timing, Printing, ProgressBar 8 | from blocks.extensions.monitoring import DataStreamMonitoring 9 | from blocks.extensions.saveload import Checkpoint 10 | from blocks.filter import VariableFilter 11 | from blocks.graph import ComputationGraph, apply_dropout 12 | from blocks.graph.bn import (batch_normalization, 13 | get_batch_normalization_updates) 14 | from blocks.initialization import IsotropicGaussian, Constant 15 | from blocks.main_loop import MainLoop 16 | from blocks.model import Model 17 | from blocks.roles import INPUT 18 | from theano import tensor 19 | 20 | from ali.algorithms import ali_algorithm 21 | from ali.conditional_bricks import (EncoderMapping, Decoder, 22 |
GaussianConditional, XZYJointDiscriminator, 23 | ConditionalALI, ) 24 | from ali.streams import create_celeba_data_streams 25 | from ali.utils import get_log_odds, conv_brick, conv_transpose_brick, bn_brick 26 | 27 | BATCH_SIZE = 128 28 | MONITORING_BATCH_SIZE = 128 29 | NUM_EPOCHS = 123 30 | IMAGE_SIZE = (64, 64) 31 | NUM_CHANNELS = 3 32 | NLAT = 256 33 | NCLASSES = 40 # width of the conditioning vector y (passed as n_cond) 34 | NEMB = 256 # width of the conditioning embedding (passed as n_emb) 35 | 36 | GAUSSIAN_INIT = IsotropicGaussian(std=0.01) 37 | ZERO_INIT = Constant(0) 38 | LEARNING_RATE = 1e-4 39 | BETA1 = 0.5 40 | LEAK = 0.02 41 | 42 | 43 | def create_model_brick(): # Build and initialize the conditional ALI brick for 64x64 three-channel images conditioned on an NCLASSES-wide target vector. 44 | # Encoder 45 | enc_layers = [ 46 | conv_brick(2, 1, 64), bn_brick(), LeakyRectifier(leak=LEAK), 47 | conv_brick(7, 2, 128), bn_brick(), LeakyRectifier(leak=LEAK), 48 | conv_brick(5, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK), 49 | conv_brick(7, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK), 50 | conv_brick(4, 1, 512), bn_brick(), LeakyRectifier(leak=LEAK), 51 | conv_brick(1, 1, 2 * NLAT)] # 2 * NLAT output channels, presumably the Gaussian posterior parameters consumed by GaussianConditional (confirm) 52 | 53 | encoder_mapping = EncoderMapping(layers=enc_layers, 54 | num_channels=NUM_CHANNELS, 55 | n_emb=NEMB, 56 | image_size=IMAGE_SIZE, weights_init=GAUSSIAN_INIT, 57 | biases_init=ZERO_INIT, 58 | use_bias=False) # biases off; the last layer's bias is re-enabled after push_allocation_config 59 | 60 | encoder = GaussianConditional(encoder_mapping, name='encoder') 61 | # Decoder 62 | dec_layers = [ 63 | conv_transpose_brick(4, 1, 512), bn_brick(), LeakyRectifier(leak=LEAK), 64 | conv_transpose_brick(7, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK), 65 | conv_transpose_brick(5, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK), 66 | conv_transpose_brick(7, 2, 128), bn_brick(), LeakyRectifier(leak=LEAK), 67 | conv_transpose_brick(2, 1, 64), bn_brick(), LeakyRectifier(leak=LEAK), 68 | conv_brick(1, 1, NUM_CHANNELS), Logistic()] # Logistic squashes outputs into (0, 1) 69 | 70 | decoder = Decoder( 71 | layers=dec_layers, num_channels=NLAT + NEMB, image_size=(1, 1), use_bias=False, # input: z plus the NEMB-wide embedding as 1x1 feature maps (presumably stacked by ConditionalALI; confirm) 72 | name='decoder_mapping') 73 | # Discriminator 74 | layers = [ 75 | conv_brick(2, 1, 64), LeakyRectifier(leak=LEAK), 76 | conv_brick(7, 2, 128),
bn_brick(), LeakyRectifier(leak=LEAK), 77 | conv_brick(5, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK), 78 | conv_brick(7, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK), 79 | conv_brick(4, 1, 512), bn_brick(), LeakyRectifier(leak=LEAK)] 80 | x_discriminator = ConvolutionalSequence( 81 | layers=layers, num_channels=NUM_CHANNELS, image_size=IMAGE_SIZE, 82 | use_bias=False, name='x_discriminator') 83 | x_discriminator.push_allocation_config() # presumably needed so get_dim('output') can be queried when sizing joint_discriminator below 84 | 85 | layers = [ 86 | conv_brick(1, 1, 1024), LeakyRectifier(leak=LEAK), 87 | conv_brick(1, 1, 1024), LeakyRectifier(leak=LEAK)] 88 | z_discriminator = ConvolutionalSequence( 89 | layers=layers, num_channels=NLAT, image_size=(1, 1), use_bias=False, 90 | name='z_discriminator') 91 | z_discriminator.push_allocation_config() 92 | 93 | layers = [ 94 | conv_brick(1, 1, 2048), LeakyRectifier(leak=LEAK), 95 | conv_brick(1, 1, 2048), LeakyRectifier(leak=LEAK), 96 | conv_brick(1, 1, 1)] # one score per (x, z, y) triple 97 | joint_discriminator = ConvolutionalSequence( 98 | layers=layers, 99 | num_channels=(x_discriminator.get_dim('output')[0] + 100 | z_discriminator.get_dim('output')[0] + 101 | NEMB), # channels: x features + z features + conditioning embedding 102 | image_size=(1, 1), 103 | name='joint_discriminator') 104 | 105 | discriminator = XZYJointDiscriminator( 106 | x_discriminator, z_discriminator, joint_discriminator, 107 | name='discriminator') 108 | 109 | ali = ConditionalALI(encoder, decoder, discriminator, 110 | n_cond=NCLASSES, n_emb=NEMB, 111 | weights_init=GAUSSIAN_INIT, biases_init=ZERO_INIT, 112 | name='ali') 113 | ali.push_allocation_config() # push allocation config to children before the per-layer bias overrides below 114 | encoder_mapping.layers[-1].use_bias = True 115 | encoder_mapping.layers[-1].tied_biases = False 116 | decoder.layers[-2].use_bias = True # layers[-2] is the conv feeding Logistic(); its bias is overwritten with log-odds below 117 | decoder.layers[-2].tied_biases = False 118 | x_discriminator.layers[0].use_bias = True # the first x conv has no bn_brick after it, so it keeps a (tied) bias 119 | x_discriminator.layers[0].tied_biases = True 120 | ali.initialize() 121 | raw_marginals, = next( 122 | create_celeba_data_streams(500, 500)[0].get_epoch_iterator()) # first batch of the main training stream, built with batch size 500 123 | b_value = get_log_odds(raw_marginals) 124 |
decoder.layers[-2].b.set_value(b_value) # presumably initializes sigmoid(b) to the per-pixel data marginals (confirm get_log_odds in ali.utils) 125 | 126 | return ali 127 | 128 | 129 | def create_models(): # Return (inference model, batch-normalized training model with dropout, BN population-statistics updates). 130 | ali = create_model_brick() 131 | x = tensor.tensor4('features') 132 | y = tensor.matrix('targets') 133 | z = ali.theano_rng.normal(size=(x.shape[0], NLAT, 1, 1)) # prior samples shaped as NLAT-channel 1x1 feature maps 134 | 135 | def _create_model(with_dropout): # Wrap the conditional ALI losses in a Model, optionally with dropout 0.2 on discriminator conv inputs. 136 | cg = ComputationGraph(ali.compute_losses(x, z, y)) 137 | if with_dropout: 138 | inputs = VariableFilter( 139 | bricks=([ali.discriminator.x_discriminator.layers[0]] + 140 | ali.discriminator.x_discriminator.layers[2::3] + # indices 2, 5, 8, 11: the remaining x convs 141 | ali.discriminator.z_discriminator.layers[::2] + 142 | ali.discriminator.joint_discriminator.layers[::2]), # z and joint layers alternate conv / LeakyRectifier, so [::2] picks the convs 143 | roles=[INPUT])(cg.variables) 144 | cg = apply_dropout(cg, inputs, 0.2) 145 | return Model(cg.outputs) 146 | 147 | model = _create_model(with_dropout=False) 148 | with batch_normalization(ali): 149 | bn_model = _create_model(with_dropout=True) # dropout only in the batch-normalized training graph 150 | 151 | pop_updates = list( 152 | set(get_batch_normalization_updates(bn_model, allow_duplicates=True))) 153 | bn_updates = [(p, m * 0.05 + p * 0.95) for p, m in pop_updates] # moving average: 95% old population statistic, 5% minibatch statistic 154 | 155 | return model, bn_model, bn_updates 156 | 157 | 158 | def create_main_loop(save_path): # Build the Adam-trained MainLoop with train/valid monitoring and checkpoints written to save_path. 159 | model, bn_model, bn_updates = create_models() 160 | ali, = bn_model.top_bricks 161 | discriminator_loss, generator_loss = bn_model.outputs 162 | 163 | step_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1) 164 | algorithm = ali_algorithm(discriminator_loss, ali.discriminator_parameters, 165 | step_rule, generator_loss, 166 | ali.generator_parameters, step_rule) 167 | algorithm.add_updates(bn_updates) # refresh BN population statistics on every training step 168 | streams = create_celeba_data_streams(BATCH_SIZE, MONITORING_BATCH_SIZE, 169 | sources=('features', 'targets')) # images plus conditioning targets 170 | main_loop_stream, train_monitor_stream, valid_monitor_stream = streams 171 | bn_monitored_variables = ( 172 | [v for v in bn_model.auxiliary_variables if 'norm' not in v.name] + 173 | bn_model.outputs) 174 | monitored_variables = ( 175 | [v for v in model.auxiliary_variables if
'norm' not in v.name] + 176 | model.outputs) 177 | extensions = [ 178 | Timing(), 179 | FinishAfter(after_n_epochs=NUM_EPOCHS), 180 | DataStreamMonitoring( 181 | bn_monitored_variables, train_monitor_stream, prefix="train", 182 | updates=bn_updates), # train-stream monitoring on the batch-normalized graph, also applying bn_updates 183 | DataStreamMonitoring( 184 | monitored_variables, valid_monitor_stream, prefix="valid"), # valid-stream monitoring on the graph built outside batch_normalization() 185 | Checkpoint(save_path, after_epoch=True, after_training=True, 186 | use_cpickle=True), 187 | ProgressBar(), 188 | Printing(), 189 | ] 190 | main_loop = MainLoop(model=bn_model, data_stream=main_loop_stream, 191 | algorithm=algorithm, extensions=extensions) 192 | return main_loop 193 | 194 | 195 | if __name__ == "__main__": # command-line entry point: train conditional ALI on CelebA 196 | logging.basicConfig(level=logging.INFO) 197 | parser = argparse.ArgumentParser(description="Train ALI on CelebA") 198 | parser.add_argument("--save-path", type=str, default='ali_conditional_celeba.tar', 199 | help="main loop save path") 200 | args = parser.parse_args() 201 | create_main_loop(args.save_path).run() 202 | -------------------------------------------------------------------------------- /paper/iclr2017_conference.sty: -------------------------------------------------------------------------------- 1 | %%%% ICLR Macros (LaTex) 2 | %%%% Adapted by Hugo Larochelle from the NIPS stylefile Macros 3 | %%%% Style File 4 | %%%% Dec 12, 1990 Rev Aug 14, 1991; Sept, 1995; April, 1997; April, 1999; October 2014 5 | 6 | % This file can be used with Latex2e whether running in main mode, or 7 | % 2.09 compatibility mode. 8 | % 9 | % If using main mode, you need to include the commands 10 | % \documentclass{article} 11 | % \usepackage{iclr14submit_e,times} 12 | % 13 | 14 | % Change the overall width of the page. If these parameters are 15 | % changed, they will require corresponding changes in the 16 | % maketitle section.
17 | % 18 | \usepackage{eso-pic} % used by \AddToShipoutPicture 19 | \RequirePackage{fancyhdr} 20 | \RequirePackage{natbib} 21 | 22 | % modification to natbib citations 23 | \setcitestyle{authoryear,round,citesep={;},aysep={,},yysep={;}} 24 | 25 | \renewcommand{\topfraction}{0.95} % let figure take up nearly whole page 26 | \renewcommand{\textfraction}{0.05} % let figure take up nearly whole page 27 | 28 | % Define the \ificlrfinal switch (false by default); invoking \iclrfinalcopy sets it to true 29 | \newif\ificlrfinal 30 | \iclrfinalfalse 31 | \def\iclrfinalcopy{\iclrfinaltrue} 32 | \font\iclrtenhv = phvb at 8pt % 8pt font referenced only by the commented-out vertical ruler below 33 | 34 | % Specify the dimensions of each page 35 | 36 | \setlength{\paperheight}{11in} 37 | \setlength{\paperwidth}{8.5in} 38 | 39 | 40 | \oddsidemargin .5in % Note \oddsidemargin = \evensidemargin 41 | \evensidemargin .5in 42 | \marginparwidth 0.07 true in 43 | %\marginparwidth 0.75 true in 44 | %\topmargin 0 true pt % Nominal distance from top of page to top of 45 | %\topmargin 0.125in 46 | \topmargin -0.625in 47 | \addtolength{\headsep}{0.25in} 48 | \textheight 9.0 true in % Height of text (including footnotes & figures) 49 | \textwidth 5.5 true in % Width of text line. 50 | \widowpenalty=10000 51 | \clubpenalty=10000 52 | 53 | % \thispagestyle{empty} \pagestyle{empty} 54 | \flushbottom \sloppy 55 | 56 | % We're never going to need a table of contents, so just flush it to 57 | % save space --- suggested by drstrip@sandia-2 58 | \def\addcontentsline#1#2#3{} 59 | 60 | % Title stuff, taken from deproc.
61 | \def\maketitle{\par 62 | \begingroup 63 | \def\thefootnote{\fnsymbol{footnote}} 64 | \def\@makefnmark{\hbox to 0pt{$^{\@thefnmark}$\hss}} % for perfect author 65 | % name centering 66 | % The footnote-mark was overlapping the footnote-text, 67 | % added the following to fix this problem (MK) 68 | \long\def\@makefntext##1{\parindent 1em\noindent 69 | \hbox to1.8em{\hss $\m@th ^{\@thefnmark}$}##1} 70 | \@maketitle \@thanks 71 | \endgroup 72 | \setcounter{footnote}{0} 73 | \let\maketitle\relax \let\@maketitle\relax 74 | \gdef\@thanks{}\gdef\@author{}\gdef\@title{}\let\thanks\relax} 75 | 76 | % The toptitlebar has been raised to top-justify the first page 77 | 78 | \usepackage{fancyhdr} 79 | \pagestyle{fancy} 80 | \fancyhead{} 81 | 82 | % Title block; only the running header differs between final and review copies (both print \@author, the anonymous-author variant is commented out) 83 | \def\@maketitle{\vbox{\hsize\textwidth 84 | %\linewidth\hsize \vskip 0.1in \toptitlebar \centering 85 | {\LARGE\sc \@title\par} 86 | %\bottomtitlebar % \vskip 0.1in % minus 87 | \ificlrfinal % final copy gets the published header; otherwise the under-review header 88 | \lhead{Published as a conference paper at ICLR 2017} 89 | \def\And{\end{tabular}\hfil\linebreak[0]\hfil 90 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 91 | \def\AND{\end{tabular}\hfil\linebreak[4]\hfil 92 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 93 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\@author\end{tabular}% 94 | \else 95 | \lhead{Under review as a conference paper at ICLR 2017} 96 | \def\And{\end{tabular}\hfil\linebreak[0]\hfil 97 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 98 | \def\AND{\end{tabular}\hfil\linebreak[4]\hfil 99 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 100 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\@author\end{tabular}% 101 | % \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces 102 | %Anonymous Author(s) \\ 103 | %Affiliation \\ 104 | %Address \\ 105 | %\texttt{email} \\ 106 | %\end{tabular}% 107 | \fi 108 | \vskip 0.3in minus 0.1in}} 109 | 110 |
\renewenvironment{abstract}{\vskip.075in\centerline{\large\sc 111 | Abstract}\vspace{0.5ex}\begin{quote}}{\par\end{quote}\vskip 1ex} 112 | 113 | % sections with less space 114 | \def\section{\@startsection {section}{1}{\z@}{-2.0ex plus 115 | -0.5ex minus -.2ex}{1.5ex plus 0.3ex 116 | minus0.2ex}{\large\sc\raggedright}} 117 | 118 | \def\subsection{\@startsection{subsection}{2}{\z@}{-1.8ex plus 119 | -0.5ex minus -.2ex}{0.8ex plus .2ex}{\normalsize\sc\raggedright}} 120 | \def\subsubsection{\@startsection{subsubsection}{3}{\z@}{-1.5ex 121 | plus -0.5ex minus -.2ex}{0.5ex plus 122 | .2ex}{\normalsize\sc\raggedright}} 123 | \def\paragraph{\@startsection{paragraph}{4}{\z@}{1.5ex plus 124 | 0.5ex minus .2ex}{-1em}{\normalsize\bf}} 125 | \def\subparagraph{\@startsection{subparagraph}{5}{\z@}{1.5ex plus 126 | 0.5ex minus .2ex}{-1em}{\normalsize\sc}} 127 | \def\subsubsubsection{\vskip 128 | 5pt{\noindent\normalsize\rm\raggedright}} 129 | 130 | 131 | % Footnotes 132 | \footnotesep 6.65pt % 133 | \skip\footins 9pt plus 4pt minus 2pt 134 | \def\footnoterule{\kern-3pt \hrule width 12pc \kern 2.6pt } 135 | \setcounter{footnote}{0} 136 | 137 | % Lists and paragraphs 138 | \parindent 0pt 139 | \topsep 4pt plus 1pt minus 2pt 140 | \partopsep 1pt plus 0.5pt minus 0.5pt 141 | \itemsep 2pt plus 1pt minus 0.5pt 142 | \parsep 2pt plus 1pt minus 0.5pt 143 | \parskip .5pc 144 | 145 | 146 | %\leftmargin2em 147 | \leftmargin3pc 148 | \leftmargini\leftmargin \leftmarginii 2em 149 | \leftmarginiii 1.5em \leftmarginiv 1.0em \leftmarginv .5em 150 | 151 | %\labelsep \labelsep 5pt 152 | 153 | \def\@listi{\leftmargin\leftmargini} 154 | \def\@listii{\leftmargin\leftmarginii 155 | \labelwidth\leftmarginii\advance\labelwidth-\labelsep 156 | \topsep 2pt plus 1pt minus 0.5pt 157 | \parsep 1pt plus 0.5pt minus 0.5pt 158 | \itemsep \parsep} 159 | \def\@listiii{\leftmargin\leftmarginiii 160 | \labelwidth\leftmarginiii\advance\labelwidth-\labelsep 161 | \topsep 1pt plus 0.5pt minus 0.5pt 162 | \parsep \z@
\partopsep 0.5pt plus 0pt minus 0.5pt 163 | \itemsep \topsep} 164 | \def\@listiv{\leftmargin\leftmarginiv 165 | \labelwidth\leftmarginiv\advance\labelwidth-\labelsep} 166 | \def\@listv{\leftmargin\leftmarginv 167 | \labelwidth\leftmarginv\advance\labelwidth-\labelsep} 168 | \def\@listvi{\leftmargin\leftmarginvi 169 | \labelwidth\leftmarginvi\advance\labelwidth-\labelsep} 170 | 171 | \abovedisplayskip 7pt plus2pt minus5pt% 172 | \belowdisplayskip \abovedisplayskip 173 | \abovedisplayshortskip 0pt plus3pt% 174 | \belowdisplayshortskip 4pt plus3pt minus3pt% 175 | 176 | % Less leading in most fonts (due to the narrow columns) 177 | % The choices were between 1-pt and 1.5-pt leading 178 | %\def\@normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} % got rid of @ (MK) 179 | \def\normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} 180 | \def\small{\@setsize\small{10pt}\ixpt\@ixpt} 181 | \def\footnotesize{\@setsize\footnotesize{10pt}\ixpt\@ixpt} 182 | \def\scriptsize{\@setsize\scriptsize{8pt}\viipt\@viipt} 183 | \def\tiny{\@setsize\tiny{7pt}\vipt\@vipt} 184 | \def\large{\@setsize\large{14pt}\xiipt\@xiipt} 185 | \def\Large{\@setsize\Large{16pt}\xivpt\@xivpt} 186 | \def\LARGE{\@setsize\LARGE{20pt}\xviipt\@xviipt} 187 | \def\huge{\@setsize\huge{23pt}\xxpt\@xxpt} 188 | \def\Huge{\@setsize\Huge{28pt}\xxvpt\@xxvpt} 189 | 190 | \def\toptitlebar{\hrule height4pt\vskip .25in\vskip-\parskip} 191 | 192 | \def\bottomtitlebar{\vskip .29in\vskip-\parskip\hrule height1pt\vskip 193 | .09in} % 194 | %Reduced second vskip to compensate for adding the strut in \@author 195 | 196 | 197 | %% % Vertical Ruler (disabled: every line of this block is commented out) 198 | %% % This code is, largely, from the CVPR 2010 conference style file 199 | %% % ----- define vruler 200 | %% \makeatletter 201 | %% \newbox\iclrrulerbox 202 | %% \newcount\iclrrulercount 203 | %% \newdimen\iclrruleroffset 204 | %% \newdimen\cv@lineheight 205 | %% \newdimen\cv@boxheight 206 | %% \newbox\cv@tmpbox 207 | %% \newcount\cv@refno 208 | %% \newcount\cv@tot 209 | %% % NUMBER with left
flushed zeros \fillzeros[] 210 | %% \newcount\cv@tmpc@ \newcount\cv@tmpc 211 | %% \def\fillzeros[#1]#2{\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi 212 | %% \cv@tmpc=1 % 213 | %% \loop\ifnum\cv@tmpc@<10 \else \divide\cv@tmpc@ by 10 \advance\cv@tmpc by 1 \fi 214 | %% \ifnum\cv@tmpc@=10\relax\cv@tmpc@=11\relax\fi \ifnum\cv@tmpc@>10 \repeat 215 | %% \ifnum#2<0\advance\cv@tmpc1\relax-\fi 216 | %% \loop\ifnum\cv@tmpc<#1\relax0\advance\cv@tmpc1\relax\fi \ifnum\cv@tmpc<#1 \repeat 217 | %% \cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi \relax\the\cv@tmpc@}% 218 | %% % \makevruler[][][][][] 219 | %% \def\makevruler[#1][#2][#3][#4][#5]{\begingroup\offinterlineskip 220 | %% \textheight=#5\vbadness=10000\vfuzz=120ex\overfullrule=0pt% 221 | %% \global\setbox\iclrrulerbox=\vbox to \textheight{% 222 | %% {\parskip=0pt\hfuzz=150em\cv@boxheight=\textheight 223 | %% \cv@lineheight=#1\global\iclrrulercount=#2% 224 | %% \cv@tot\cv@boxheight\divide\cv@tot\cv@lineheight\advance\cv@tot2% 225 | %% \cv@refno1\vskip-\cv@lineheight\vskip1ex% 226 | %% \loop\setbox\cv@tmpbox=\hbox to0cm{{\iclrtenhv\hfil\fillzeros[#4]\iclrrulercount}}% 227 | %% \ht\cv@tmpbox\cv@lineheight\dp\cv@tmpbox0pt\box\cv@tmpbox\break 228 | %% \advance\cv@refno1\global\advance\iclrrulercount#3\relax 229 | %% \ifnum\cv@refno<\cv@tot\repeat}}\endgroup}% 230 | %% \makeatother 231 | %% % ----- end of vruler 232 | 233 | %% % \makevruler[][][][][] 234 | %% \def\iclrruler#1{\makevruler[12pt][#1][1][3][0.993\textheight]\usebox{\iclrrulerbox}} 235 | %% \AddToShipoutPicture{% 236 | %% \ificlrfinal\else 237 | %% \iclrruleroffset=\textheight 238 | %% \advance\iclrruleroffset by -3.7pt 239 | %% \color[rgb]{.7,.7,.7} 240 | %% \AtTextUpperLeft{% 241 | %% \put(\LenToUnit{-35pt},\LenToUnit{-\iclrruleroffset}){%left ruler 242 | %% \iclrruler{\iclrrulercount}} 243 | %% } 244 | %% \fi 245 | %% } 246 | %%% To add a vertical bar on the side 247 | %\AddToShipoutPicture{ 248 | %\AtTextLowerLeft{ 249 | %\hspace*{-1.8cm} 
250 | %\colorbox[rgb]{0.7,0.7,0.7}{\small \parbox[b][\textheight]{0.1cm}{}}} 251 | %} 252 | 253 | -------------------------------------------------------------------------------- /paper/bibliography.bib: -------------------------------------------------------------------------------- 1 | @article{lin1991divergence, 2 | title={Divergence measures based on the Shannon entropy}, 3 | author={Lin, Jianhua}, 4 | journal={Information Theory, IEEE Transactions on}, 5 | volume={37}, 6 | number={1}, 7 | pages={145--151}, 8 | year={1991}, 9 | publisher={IEEE} 10 | } 11 | 12 | @book{vapnik1998, 13 | title={Statistical Learning Theory}, 14 | author={Vapnik, Vladimir N.}, 15 | publisher={Wiley-Interscience}, 16 | year={1998}, 17 | } 18 | 19 | @misc{lecun1998mnist, 20 | title={The MNIST database of handwritten digits}, 21 | author={LeCun, Yann and Cortes, Corinna and Burges, Christopher JC}, 22 | year={1998} 23 | } 24 | 25 | @misc{krizhevsky2009learning, 26 | title={Learning multiple layers of features from tiny images}, 27 | author={Krizhevsky, Alex and Hinton, Geoffrey}, 28 | year={2009}, 29 | publisher={Citeseer} 30 | } 31 | 32 | @inproceedings{bergstra2010theano, 33 | title={Theano: a CPU and GPU math expression compiler}, 34 | author={Bergstra, James and Breuleux, Olivier and Bastien, Fr{\'e}d{\'e}ric 35 | and Lamblin, Pascal and Pascanu, Razvan and Desjardins, Guillaume and 36 | Turian, Joseph and Warde-Farley, David and Bengio, Yoshua}, 37 | booktitle={Proceedings of the Python for scientific computing conference 38 | (SciPy)}, 39 | volume={4}, 40 | pages={3}, 41 | year={2010}, 42 | organization={Austin, TX} 43 | } 44 | 45 | @inproceedings{netzer2011reading, 46 | title={Reading digits in natural images with unsupervised feature learning}, 47 | author={Netzer, Yuval and Wang, Tao and Coates, Adam and Bissacco, Alessandro 48 | and Wu, Bo and Ng, Andrew Y}, 49 | booktitle={NIPS workshop on deep learning and unsupervised feature learning}, 50 | volume={2011}, 51 | pages={4}, 
52 | year={2011}, 53 | organization={Granada, Spain} 54 | } 55 | 56 | @article{bastien2012theano, 57 | title={Theano: new features and speed improvements}, 58 | author={Bastien, Fr{\'e}d{\'e}ric and Lamblin, Pascal and Pascanu, Razvan and 59 | Bergstra, James and Goodfellow, Ian and Bergeron, Arnaud and Bouchard, 60 | Nicolas and Warde-Farley, David and Bengio, Yoshua}, 61 | journal={arXiv preprint arXiv:1211.5590}, 62 | year={2012} 63 | } 64 | 65 | @article{kingma2013fast, 66 | title={Fast gradient-based inference with continuous latent variable models in 67 | auxiliary form}, 68 | author={Kingma, Diederik P}, 69 | journal={arXiv preprint arXiv:1306.0733}, 70 | year={2013} 71 | } 72 | 73 | @article{bengio2013estimating, 74 | title={Estimating or propagating gradients through stochastic neurons for 75 | conditional computation}, 76 | author={Bengio, Yoshua and L{\'e}onard, Nicholas and Courville, Aaron}, 77 | journal={arXiv preprint arXiv:1308.3432}, 78 | year={2013} 79 | } 80 | 81 | @article{bengio2013deep, 82 | title={Deep generative stochastic networks trainable by backprop}, 83 | author={Bengio, Yoshua and Thibodeau-Laufer, Eric and Alain, Guillaume and 84 | Yosinski, Jason}, 85 | journal={arXiv preprint arXiv:1306.1091}, 86 | year={2013} 87 | } 88 | 89 | @article{kingma2013auto, 90 | title={Auto-encoding variational bayes}, 91 | author={Kingma, Diederik P and Welling, Max}, 92 | journal={arXiv preprint arXiv:1312.6114}, 93 | year={2013} 94 | } 95 | 96 | @article{goodfellow2013maxout, 97 | title={Maxout networks}, 98 | author={Goodfellow, Ian J and Warde-Farley, David and Mirza, Mehdi and 99 | Courville, Aaron and Bengio, Yoshua}, 100 | journal={arXiv preprint arXiv:1302.4389}, 101 | year={2013} 102 | } 103 | 104 | @article{rezende2014stochastic, 105 | title={Stochastic backpropagation and approximate inference in deep generative 106 | models}, 107 | author={Rezende, Danilo Jimenez and Mohamed, Shakir and Wierstra, Daan}, 108 | journal={arXiv preprint 
arXiv:1401.4082}, 109 | year={2014} 110 | } 111 | 112 | @inproceedings{goodfellow2014generative, 113 | title={Generative adversarial nets}, 114 | author={Goodfellow, Ian and Pouget-Abadie, Jean and Mirza, Mehdi and Xu, Bing 115 | and Warde-Farley, David and Ozair, Sherjil and Courville, Aaron and 116 | Bengio, Yoshua}, 117 | booktitle={Advances in Neural Information Processing Systems}, 118 | pages={2672--2680}, 119 | year={2014} 120 | } 121 | 122 | @article{salimans2014markov, 123 | title={Markov chain Monte Carlo and variational inference: Bridging the gap}, 124 | author={Salimans, Tim and Kingma, Diederik P and Welling, Max}, 125 | journal={arXiv preprint arXiv:1410.6460}, 126 | year={2014} 127 | } 128 | 129 | @article{Salimans2016gan, 130 | author = {Tim Salimans and 131 | Ian J. Goodfellow and 132 | Wojciech Zaremba and 133 | Vicki Cheung and 134 | Alec Radford and 135 | Xi Chen}, 136 | title = {Improved Techniques for Training GANs}, 137 | journal = {arXiv preprint arXiv:1606.03498}, 138 | year = {2016} 139 | } 140 | 141 | @inproceedings{kingma2014semi, 142 | title={Semi-supervised learning with deep generative models}, 143 | author={Kingma, Diederik P and Mohamed, Shakir and Rezende, Danilo Jimenez and 144 | Welling, Max}, 145 | booktitle={Advances in Neural Information Processing Systems}, 146 | pages={3581--3589}, 147 | year={2014} 148 | } 149 | 150 | @article{kingma2014adam, 151 | title={Adam: A method for stochastic optimization}, 152 | author={Kingma, Diederik and Ba, Jimmy}, 153 | journal={arXiv preprint arXiv:1412.6980}, 154 | year={2014} 155 | } 156 | 157 | @inproceedings{liu2015deep, 158 | title={Deep learning face attributes in the wild}, 159 | author={Liu, Ziwei and Luo, Ping and Wang, Xiaogang and Tang, Xiaoou}, 160 | booktitle={Proceedings of the IEEE International Conference on Computer Vision}, 161 | pages={3730--3738}, 162 | year={2015} 163 | } 164 | 165 | @inproceedings{denton2015deep, 166 | title={Deep Generative Image Models using a 
Laplacian Pyramid of Adversarial 167 | Networks}, 168 | author={Denton, Emily L and Chintala, Soumith and Fergus, Rob and others}, 169 | booktitle={Advances in Neural Information Processing Systems}, 170 | pages={1486--1494}, 171 | year={2015} 172 | } 173 | 174 | @article{gregor2015draw, 175 | title={DRAW: A recurrent neural network for image generation}, 176 | author={Gregor, Karol and Danihelka, Ivo and Graves, Alex and Wierstra, Daan}, 177 | journal={arXiv preprint arXiv:1502.04623}, 178 | year={2015} 179 | } 180 | 181 | @article{radford2015unsupervised, 182 | title={Unsupervised Representation Learning with Deep Convolutional Generative 183 | Adversarial Networks}, 184 | author={Radford, Alec and Metz, Luke and Chintala, Soumith}, 185 | journal={arXiv preprint arXiv:1511.06434}, 186 | year={2015} 187 | } 188 | 189 | @article{makhzani2015adversarial, 190 | title={Adversarial Autoencoders}, 191 | author={Makhzani, Alireza and Shlens, Jonathon and Jaitly, Navdeep and 192 | Goodfellow, Ian}, 193 | journal={arXiv preprint arXiv:1511.05644}, 194 | year={2015} 195 | } 196 | 197 | @article{larsen2015autoencoding, 198 | title={Autoencoding beyond pixels using a learned similarity metric}, 199 | author={Larsen, Anders Boesen Lindbo and S{\o}nderby, S{\o}ren Kaae and 200 | Winther, Ole}, 201 | journal={arXiv preprint arXiv:1512.09300}, 202 | year={2015} 203 | } 204 | 205 | @article{russakovsky2015imagenet, 206 | title={Imagenet large scale visual recognition challenge}, 207 | author={Russakovsky, Olga and Deng, Jia and Su, Hao and Krause, Jonathan and 208 | Satheesh, Sanjeev and Ma, Sean and Huang, Zhiheng and Karpathy, Andrej 209 | and Khosla, Aditya and Bernstein, Michael and others}, 210 | journal={International Journal of Computer Vision}, 211 | volume={115}, 212 | number={3}, 213 | pages={211--252}, 214 | year={2015}, 215 | publisher={Springer} 216 | } 217 | 218 | @article{van2015blocks, 219 | title={Blocks and fuel: Frameworks for deep learning}, 220 | author={van 
Merri{\"e}nboer, Bart and Bahdanau, Dzmitry and Dumoulin, Vincent 221 | and Serdyuk, Dmitriy and Warde-Farley, David and Chorowski, Jan and 222 | Bengio, Yoshua}, 223 | journal={arXiv preprint arXiv:1506.00619}, 224 | year={2015} 225 | } 226 | 227 | @article{ioffe2015batch, 228 | title={Batch normalization: Accelerating deep network training by reducing 229 | internal covariate shift}, 230 | author={Ioffe, Sergey and Szegedy, Christian}, 231 | journal={arXiv preprint arXiv:1502.03167}, 232 | year={2015} 233 | } 234 | 235 | @article{zhao2015stacked, 236 | title={Stacked what-where auto-encoders}, 237 | author={Zhao, Junbo and Mathieu, Michael and Goroshin, Ross and LeCun, Yann}, 238 | journal={arXiv preprint arXiv:1506.02351}, 239 | year={2015} 240 | } 241 | 242 | @article{lamb2016discriminative, 243 | title={Discriminative Regularization for Generative Models}, 244 | author={Lamb, Alex and Dumoulin, Vincent and Courville, Aaron}, 245 | journal={arXiv preprint arXiv:1602.03220}, 246 | year={2016} 247 | } 248 | 249 | @article{dumoulin2016guide, 250 | title={A guide to convolution arithmetic for deep learning}, 251 | author={Dumoulin, Vincent and Visin, Francesco}, 252 | journal={arXiv preprint arXiv:1603.07285}, 253 | year={2016} 254 | } 255 | 256 | @article{dosovitskiy2016generating, 257 | title={Generating images with perceptual similarity metrics based on deep 258 | networks}, 259 | author={Dosovitskiy, Alexey and Brox, Thomas}, 260 | journal={arXiv preprint arXiv:1602.02644}, 261 | year={2016} 262 | } 263 | 264 | @article{gregor2016towards, 265 | title={Towards Conceptual Compression}, 266 | author={Gregor, Karol and Besse, Frederic and Jimenez Rezende, Danilo and 267 | Danihelka, Ivo and Wierstra, Daan}, 268 | journal={arXiv preprint arXiv:1604.08772}, 269 | year={2016} 270 | } 271 | 272 | @article{theano2016theano, 273 | title={Theano: A {Python} framework for fast computation of mathematical 274 | expressions}, 275 | author={{Theano Development Team}}, 276 |
journal={arXiv preprint arXiv:1605.02688}, 277 | year={2016} 278 | } 279 | 280 | @inproceedings{rezende2015, 281 | Publisher={JMLR Workshop and Conference Proceedings}, 282 | title={Variational Inference with Normalizing Flows}, 283 | Author={Rezende, Danilo Jimenez and Mohamed, Shakir}, 284 | year={2015}, 285 | Booktitle={Proceedings of the 32nd International Conference on Machine 286 | Learning (ICML-15)}, 287 | Editor={David Blei and Francis Bach}, 288 | pages={1530--1538}, 289 | url={http://jmlr.org/proceedings/papers/v37/rezende15.pdf} 290 | } 291 | 292 | @article{Burda2015, 293 | author={Burda, Yuri and Grosse, Roger and Salakhutdinov, Ruslan}, 294 | title={Importance Weighted Autoencoders}, 295 | journal={Proceedings of the International Conference on Learning 296 | Representations}, 297 | volume={arXiv:1509.00519}, 298 | year= {2016}, 299 | url={http://arxiv.org/abs/1509.00519}, 300 | timestamp={Thu, 01 Oct 2015 14:28:48 +0200}, 301 | biburl={http://dblp.uni-trier.de/rec/bib/journals/corr/BurdaGS15}, 302 | bibsource={dblp computer science bibliography, http://dblp.org} 303 | } 304 | 305 | @article{Theis2015, 306 | author={Theis, Lucas and van den Oord, Aaron and Bethge, Matthias}, 307 | title={A note on the evaluation of generative models}, 308 | journal={arXiv preprint arXiv:1511.01844}, 309 | year={2015} 310 | } 311 | 312 | @article{maaloe2016auxiliary, 313 | title={Auxiliary Deep Generative Models}, 314 | author={Maal{\o}e, Lars and S{\o}nderby, Casper Kaae and S{\o}nderby, 315 | S{\o}ren Kaae and Winther, Ole}, 316 | journal={arXiv preprint arXiv:1602.05473}, 317 | year={2016} 318 | } 319 | 320 | @article{im2016generating, 321 | title={Generating images with recurrent adversarial networks}, 322 | author={Im, Daniel Jiwoong and Kim, Chris Dongjoo and Jiang, Hui and 323 | Memisevic, Roland}, 324 | journal={arXiv preprint arXiv:1602.05110}, 325 | year={2016} 326 | } 327 | 328 | @article{donahue2016adversarial, 329 | title={Adversarial Feature Learning}, 330 |
author={Donahue, Jeff and Kr{\"a}henb{\"u}hl, Philipp and Darrell, Trevor}, 331 | journal={arXiv preprint arXiv:1605.09782}, 332 | year={2016} 333 | } 334 | 335 | @article{van2016pixelcnn, 336 | title={Conditional Image Generation with PixelCNN Decoders}, 337 | author={van den Oord, Aaron and Kalchbrenner, Nal and 338 | Vinyals, Oriol and Espeholt, Lasse and Graves, Alex and 339 | Kavukcuoglu, Koray}, 340 | journal={arXiv preprint arXiv:1606.05328}, 341 | year={2016} 342 | } 343 | 344 | @article{van2016pixel, 345 | title={Pixel Recurrent Neural Networks}, 346 | author={van den Oord, Aaron and Kalchbrenner, Nal and Kavukcuoglu, Koray}, 347 | journal={arXiv preprint arXiv:1601.06759}, 348 | year={2016} 349 | } 350 | 351 | @article{van2016wavenet, 352 | author = {van den Oord, Aaron and Dieleman, Sander and Zen, Heiga and Simonyan, Karen and 353 | Vinyals, Oriol and Graves, Alex and Kalchbrenner, Nal and Senior, Andrew W. and 354 | Kavukcuoglu, Koray}, 355 | title = {WaveNet: {A} Generative Model for Raw Audio}, 356 | journal = {arXiv preprint arXiv:1609.03499}, 357 | year = {2016}, 358 | } 359 | 360 | @article{brock2016neural, 361 | title={Neural Photo Editing with Introspective Adversarial Networks}, 362 | author={Brock, Andrew and Lim, Theodore and Ritchie, J. M. and Weston, Nick}, 363 | journal={arXiv preprint arXiv:1609.07093}, 364 | year={2016} 365 | } 366 | 367 | @inproceedings{ladder2015, 368 | title={Semi-Supervised Learning with Ladder Networks}, 369 | author={Rasmus, Antti and Valpola, Harri and Honkala, Mikko and Berglund, Mathias and 370 | Raiko, Tapani}, 371 | booktitle={Advances in Neural Information Processing Systems}, 372 | year={2015} 373 | } 374 | 375 | @article{catgan2015, 376 | author={Springenberg, Jost Tobias}, 377 | title={Unsupervised and Semi-supervised Learning with Categorical Generative Adversarial Networks}, 378 | journal={arXiv preprint arXiv:1511.06390}, 379 | year={2015} 380 | } 381 | 382 | @article{shi2016deconvolution, 383 | title={Is
the deconvolution layer the same as a convolutional layer?}, 384 | author={Shi, Wenzhe and Caballero, Jose and Theis, Lucas and Husz{\'a}r, Ferenc and Aitken, Andrew and Ledig, Christian and Wang, Zehan}, 385 | journal={arXiv preprint arXiv:1609.07009}, 386 | year={2016} 387 | } 388 | 389 | @misc{odena2016deconvolution, 390 | author = {Odena, Augustus and Dumoulin, Vincent and Olah, Chris}, 391 | title = {Deconvolution and Checkerboard Artifacts}, 392 | year = {2016}, 393 | howpublished = {http://distill.pub/2016/deconv-checkerboard/} 394 | } 395 | 396 | @inproceedings{chen2016infogan, 397 | title={{InfoGAN}: Interpretable representation learning by information 398 | maximizing generative adversarial nets}, 399 | author={Chen, Xi and Duan, Yan and Houthooft, Rein and Schulman, John and 400 | Sutskever, Ilya and Abbeel, Pieter}, 401 | booktitle={Advances in Neural Information Processing Systems}, 402 | pages={2172--2180}, 403 | year={2016} 404 | } 405 | 406 | @article{kingma2016improving, 407 | title={Improving variational inference with inverse autoregressive flow}, 408 | author={Kingma, Diederik P and Salimans, Tim and Welling, Max}, 409 | journal={arXiv preprint arXiv:1606.04934}, 410 | year={2016} 411 | } 412 | -------------------------------------------------------------------------------- /ali/bricks.py: -------------------------------------------------------------------------------- 1 | """ALI-related bricks.""" 2 | from theano import tensor 3 | from blocks.bricks.base import Brick, application, lazy 4 | from blocks.bricks.conv import ConvolutionalSequence 5 | from blocks.bricks.interfaces import Initializable, Random 6 | from blocks.select import Selector 7 | 8 | 9 | class ALI(Initializable, Random): 10 | """Adversarial learned inference brick. 11 | 12 | Parameters 13 | ---------- 14 | encoder : :class:`blocks.bricks.Brick` 15 | Encoder network. 16 | decoder : :class:`blocks.bricks.Brick` 17 | Decoder network. 
18 | discriminator : :class:`blocks.bricks.Brick` 19 | Discriminator network taking :math:`x` and :math:`z` as input. 20 | 21 | """ 22 | def __init__(self, encoder, decoder, discriminator, **kwargs): 23 | self.encoder = encoder 24 | self.decoder = decoder 25 | self.discriminator = discriminator 26 | 27 | super(ALI, self).__init__(**kwargs) 28 | self.children.extend([self.encoder, self.decoder, self.discriminator]) 29 | 30 | @property 31 | def discriminator_parameters(self): 32 | return list( 33 | Selector([self.discriminator]).get_parameters().values()) 34 | 35 | @property 36 | def generator_parameters(self): 37 | return list( 38 | Selector([self.encoder, self.decoder]).get_parameters().values()) 39 | 40 | @application(inputs=['x', 'z_hat', 'x_tilde', 'z'], 41 | outputs=['data_preds', 'sample_preds']) 42 | def get_predictions(self, x, z_hat, x_tilde, z, application_call): 43 | # NOTE: the unbroadcasts act as a workaround for a weird broadcasting 44 | # bug when applying dropout 45 | input_x = tensor.unbroadcast( 46 | tensor.concatenate([x, x_tilde], axis=0), *range(x.ndim)) 47 | input_z = tensor.unbroadcast( 48 | tensor.concatenate([z_hat, z], axis=0), *range(x.ndim)) 49 | data_sample_preds = self.discriminator.apply(input_x, input_z) 50 | data_preds = data_sample_preds[:x.shape[0]] 51 | sample_preds = data_sample_preds[x.shape[0]:] 52 | 53 | application_call.add_auxiliary_variable( 54 | tensor.nnet.sigmoid(data_preds).mean(), name='data_accuracy') 55 | application_call.add_auxiliary_variable( 56 | (1 - tensor.nnet.sigmoid(sample_preds)).mean(), 57 | name='sample_accuracy') 58 | 59 | return data_preds, sample_preds 60 | 61 | @application(inputs=['x', 'z'], 62 | outputs=['discriminator_loss', 'generator_loss']) 63 | def compute_losses(self, x, z, application_call): 64 | z_hat = self.encoder.apply(x) 65 | x_tilde = self.decoder.apply(z) 66 | 67 | data_preds, sample_preds = self.get_predictions(x, z_hat, x_tilde, z) 68 | 69 | discriminator_loss = 
(tensor.nnet.softplus(-data_preds) + 70 | tensor.nnet.softplus(sample_preds)).mean() 71 | generator_loss = (tensor.nnet.softplus(data_preds) + 72 | tensor.nnet.softplus(-sample_preds)).mean() 73 | 74 | return discriminator_loss, generator_loss 75 | 76 | @application(inputs=['z'], outputs=['samples']) 77 | def sample(self, z): 78 | return self.decoder.apply(z) 79 | 80 | @application(inputs=['x'], outputs=['reconstructions']) 81 | def reconstruct(self, x): 82 | return self.decoder.apply(self.encoder.apply(x)) 83 | 84 | 85 | class COVConditional(Initializable, Random): 86 | """Change-of-variables conditional. 87 | 88 | Parameters 89 | ---------- 90 | mapping : :class:`blocks.bricks.Brick` 91 | Network mapping the concatenation of the input and a source of 92 | noise to the output. 93 | noise_shape : tuple of int 94 | Shape of the input noise. 95 | 96 | """ 97 | def __init__(self, mapping, noise_shape, **kwargs): 98 | self.mapping = mapping 99 | self.noise_shape = noise_shape 100 | 101 | super(COVConditional, self).__init__(**kwargs) 102 | self.children.extend([self.mapping]) 103 | 104 | def get_dim(self, name): 105 | if isinstance(self.mapping, ConvolutionalSequence): 106 | dim = self.mapping.get_dim(name) 107 | if name == 'input_': 108 | return (dim[0] - self.noise_shape[0],) + dim[1:] 109 | else: 110 | return dim 111 | else: 112 | if name == 'output': 113 | return self.mapping.output_dim 114 | elif name == 'input_': 115 | return self.mapping.input_dim - self.noise_shape[0] 116 | else: 117 | return self.mapping.get_dim(name) 118 | 119 | @application(inputs=['input_'], outputs=['output']) 120 | def apply(self, input_, application_call): 121 | epsilon = self.theano_rng.normal( 122 | size=(input_.shape[0],) + self.noise_shape) 123 | output = self.mapping.apply( 124 | tensor.concatenate([input_, epsilon], axis=1)) 125 | 126 | application_call.add_auxiliary_variable(output.mean(), name='avg') 127 | application_call.add_auxiliary_variable(output.std(), name='std') 128 | 
application_call.add_auxiliary_variable(output.min(), name='min') 129 | application_call.add_auxiliary_variable(output.max(), name='max') 130 | 131 | return output 132 | 133 | 134 | class GaussianConditional(Initializable, Random): 135 | """Gaussian conditional. 136 | 137 | Parameters 138 | ---------- 139 | mapping : :class:`blocks.bricks.Brick` 140 | Network predicting distribution parameters. It is expected to 141 | output a concatenation of :math:`\mu` and :math:`\log\sigma`. 142 | 143 | """ 144 | def __init__(self, mapping, **kwargs): 145 | self.mapping = mapping 146 | 147 | super(GaussianConditional, self).__init__(**kwargs) 148 | self.children.extend([self.mapping]) 149 | 150 | @property 151 | def _nlat(self): 152 | if isinstance(self.mapping, ConvolutionalSequence): 153 | return self.get_dim('output')[0] 154 | else: 155 | return self.get_dim('output') 156 | 157 | def get_dim(self, name): 158 | if isinstance(self.mapping, ConvolutionalSequence): 159 | dim = self.mapping.get_dim(name) 160 | if name == 'output': 161 | return (dim[0] // 2,) + dim[1:] 162 | else: 163 | return dim 164 | else: 165 | if name == 'output': 166 | return self.mapping.output_dim // 2 167 | elif name == 'input_': 168 | return self.mapping.input_dim 169 | else: 170 | return self.mapping.get_dim(name) 171 | 172 | @application(inputs=['input_'], outputs=['output']) 173 | def apply(self, input_, application_call): 174 | params = self.mapping.apply(input_) 175 | mu, log_sigma = params[:, :self._nlat], params[:, self._nlat:] 176 | sigma = tensor.exp(log_sigma) 177 | epsilon = self.theano_rng.normal(size=mu.shape) 178 | output = mu + sigma * epsilon 179 | 180 | application_call.add_auxiliary_variable(mu.mean(), name='mu_avg') 181 | application_call.add_auxiliary_variable(mu.std(), name='mu_std') 182 | application_call.add_auxiliary_variable(mu.min(), name='mu_min') 183 | application_call.add_auxiliary_variable(mu.max(), name='mu_max') 184 | 185 | 
application_call.add_auxiliary_variable(sigma.mean(), name='sigma_avg') 186 | application_call.add_auxiliary_variable(sigma.std(), name='sigma_std') 187 | application_call.add_auxiliary_variable(sigma.min(), name='sigma_min') 188 | application_call.add_auxiliary_variable(sigma.max(), name='sigma_max') 189 | 190 | return output 191 | 192 | 193 | class DeterministicConditional(Initializable, Random): 194 | """Deterministic conditional. 195 | 196 | Parameters 197 | ---------- 198 | mapping : :class:`blocks.bricks.Brick` 199 | Network producing the output of the conditional. 200 | 201 | """ 202 | def __init__(self, mapping, **kwargs): 203 | self.mapping = mapping 204 | 205 | super(DeterministicConditional, self).__init__(**kwargs) 206 | self.children.extend([self.mapping]) 207 | 208 | def get_dim(self, name): 209 | return self.mapping.get_dim(name) 210 | 211 | @application(inputs=['input_'], outputs=['output']) 212 | def apply(self, input_, application_call): 213 | output = self.mapping.apply(input_) 214 | 215 | application_call.add_auxiliary_variable(output.mean(), name='avg') 216 | application_call.add_auxiliary_variable(output.std(), name='std') 217 | application_call.add_auxiliary_variable(output.min(), name='min') 218 | application_call.add_auxiliary_variable(output.max(), name='max') 219 | 220 | return output 221 | 222 | 223 | class XZJointDiscriminator(Initializable): 224 | """Three-way discriminator. 225 | 226 | Parameters 227 | ---------- 228 | x_discriminator : :class:`blocks.bricks.Brick` 229 | Part of the discriminator taking :math:`x` as input. Its 230 | output will be concatenated with ``z_discriminator``'s output 231 | and fed to ``joint_discriminator``. 232 | z_discriminator : :class:`blocks.bricks.Brick` 233 | Part of the discriminator taking :math:`z` as input. Its 234 | output will be concatenated with ``x_discriminator``'s output 235 | and fed to ``joint_discriminator``. 
236 | joint_discriminator : :class:`blocks.bricks.Brick` 237 | Part of the discriminator taking the concatenation of 238 | ``x_discriminator``'s and output ``z_discriminator``'s output 239 | as input and computing :math:`D(x, z)`. 240 | 241 | """ 242 | def __init__(self, x_discriminator, z_discriminator, joint_discriminator, 243 | **kwargs): 244 | self.x_discriminator = x_discriminator 245 | self.z_discriminator = z_discriminator 246 | self.joint_discriminator = joint_discriminator 247 | 248 | super(XZJointDiscriminator, self).__init__(**kwargs) 249 | self.children.extend([self.x_discriminator, self.z_discriminator, 250 | self.joint_discriminator]) 251 | 252 | @application(inputs=['x', 'z'], outputs=['output']) 253 | def apply(self, x, z): 254 | # NOTE: the unbroadcasts act as a workaround for a weird broadcasting 255 | # bug when applying dropout 256 | input_ = tensor.unbroadcast( 257 | tensor.concatenate( 258 | [self.x_discriminator.apply(x), self.z_discriminator.apply(z)], 259 | axis=1), 260 | *range(x.ndim)) 261 | return self.joint_discriminator.apply(input_) 262 | 263 | 264 | class GAN(Initializable, Random): 265 | """Generative adversarial networks. 266 | 267 | Parameters 268 | ---------- 269 | decoder : :class:`blocks.bricks.Brick` 270 | Decoder network. 271 | discriminator : :class:`blocks.bricks.Brick` 272 | Discriminator network. 
273 | 274 | """ 275 | def __init__(self, decoder, discriminator, **kwargs): 276 | self.decoder = decoder 277 | self.discriminator = discriminator 278 | 279 | super(GAN, self).__init__(**kwargs) 280 | self.children.extend([self.decoder, self.discriminator]) 281 | 282 | @property 283 | def discriminator_parameters(self): 284 | return list( 285 | Selector([self.discriminator]).get_parameters().values()) 286 | 287 | @property 288 | def generator_parameters(self): 289 | return list( 290 | Selector([self.decoder]).get_parameters().values()) 291 | 292 | @application(inputs=['z'], outputs=['x_tilde']) 293 | def sample_x_tilde(self, z, application_call): 294 | x_tilde = self.decoder.apply(z) 295 | 296 | application_call.add_auxiliary_variable(x_tilde.mean(), name='avg') 297 | application_call.add_auxiliary_variable(x_tilde.std(), name='std') 298 | 299 | return x_tilde 300 | 301 | @application(inputs=['x', 'x_tilde'], 302 | outputs=['data_preds', 'sample_preds']) 303 | def get_predictions(self, x, x_tilde, application_call): 304 | # NOTE: the unbroadcasts act as a workaround for a weird broadcasting 305 | # bug when applying dropout 306 | data_sample_preds = self.discriminator.apply( 307 | tensor.unbroadcast(tensor.concatenate([x, x_tilde], axis=0), 308 | *range(x.ndim))) 309 | data_preds = data_sample_preds[:x.shape[0]] 310 | sample_preds = data_sample_preds[x.shape[0]:] 311 | 312 | application_call.add_auxiliary_variable( 313 | tensor.nnet.sigmoid(data_preds).mean(), name='data_accuracy') 314 | application_call.add_auxiliary_variable( 315 | (1 - tensor.nnet.sigmoid(sample_preds)).mean(), 316 | name='sample_accuracy') 317 | 318 | return data_preds, sample_preds 319 | 320 | @application(inputs=['x'], 321 | outputs=['discriminator_loss', 'generator_loss']) 322 | def compute_losses(self, x, z, application_call): 323 | x_tilde = self.sample_x_tilde(z) 324 | data_preds, sample_preds = self.get_predictions(x, x_tilde) 325 | 326 | discriminator_loss = 
(tensor.nnet.softplus(-data_preds) + 327 | tensor.nnet.softplus(sample_preds)).mean() 328 | generator_loss = (tensor.nnet.softplus(data_preds) + 329 | tensor.nnet.softplus(-sample_preds)).mean() 330 | 331 | return discriminator_loss, generator_loss 332 | 333 | @application(inputs=['z'], outputs=['samples']) 334 | def sample(self, z): 335 | return self.sample_x_tilde(z) 336 | 337 | 338 | class ConvMaxout(Brick): 339 | """Convolutional version of the Maxout activation. 340 | 341 | Parameters 342 | ---------- 343 | num_pieces : int 344 | Number of linear pieces. 345 | num_channels : int 346 | Number of input channels. 347 | image_size : (int, int), optional 348 | Input shape. Defaults to ``(None, None)``. 349 | 350 | """ 351 | @lazy(allocation=['num_pieces', 'num_channels']) 352 | def __init__(self, num_pieces, num_channels, image_size=(None, None), 353 | **kwargs): 354 | super(ConvMaxout, self).__init__(**kwargs) 355 | self.num_pieces = num_pieces 356 | self.num_channels = num_channels 357 | 358 | def get_dim(self, name): 359 | if name == 'input_': 360 | return (self.num_channels,) + self.image_size 361 | if name == 'output': 362 | return (self.num_filters,) + self.image_size 363 | return super(ConvMaxout, self).get_dim(name) 364 | 365 | @property 366 | def num_filters(self): 367 | return self.num_channels // self.num_pieces 368 | 369 | @property 370 | def num_output_channels(self): 371 | return self.num_filters 372 | 373 | @application(inputs=['input_'], outputs=['output']) 374 | def apply(self, input_): 375 | input_ = input_.dimshuffle(0, 2, 3, 1) 376 | new_shape = ([input_.shape[i] for i in range(input_.ndim - 1)] + 377 | [self.num_filters, self.num_pieces]) 378 | output = tensor.max(input_.reshape(new_shape, ndim=input_.ndim + 1), 379 | axis=input_.ndim) 380 | return output.dimshuffle(0, 3, 1, 2) 381 | -------------------------------------------------------------------------------- /ali/mixture_viz.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Visualization of 2D mixture learned with ALI 3 | """ 4 | from collections import OrderedDict 5 | from functools import partial 6 | 7 | import matplotlib 8 | print("Using Backend: ", matplotlib.get_backend()) 9 | 10 | import numpy as np 11 | import numpy.random as npr 12 | import matplotlib.pyplot as plt 13 | import theano 14 | from theano import tensor 15 | 16 | from blocks.serialization import load 17 | from blocks.bricks.interfaces import Random 18 | 19 | import matplotlib.pyplot as plt 20 | from ali.streams import create_gaussian_mixture_data_streams 21 | from ali.utils import (as_array, ) 22 | 23 | LABELS_CMAP = 'Spectral' 24 | PROB_CMAP = 'jet' 25 | GRADS_GRID_NPTS = 20 # NUmber of points in the gradients grid 26 | NGRADS = 1 # Number of gradients skipped in the quiver plot 27 | SCATTER_ALPHA = 0.5 # Scatter plots transparency 28 | MARKERSIZE = 8 29 | 30 | ############# 31 | ## Helpers ## 32 | ############# 33 | def get_key_from_val(dictionary, target_val): 34 | for key, val in dictionary.items(): 35 | if val == target_val: 36 | return key 37 | return None 38 | 39 | 40 | def get_data(main_loop, n_points=1000): 41 | means = main_loop.data_stream.dataset.means 42 | variances = main_loop.data_stream.dataset.variances 43 | priors = main_loop.data_stream.dataset.priors 44 | _, _, stream = create_gaussian_mixture_data_streams(n_points, n_points, 45 | sources=('features', 46 | 'label'), 47 | means=means, 48 | variances=variances, 49 | priors=priors) 50 | originals, labels = next(stream.get_epoch_iterator()) 51 | return {'originals': originals, 52 | 'labels': labels} 53 | 54 | 55 | def get_compiled_functions(main_loop): 56 | ali, = main_loop.model.top_bricks 57 | x = tensor.matrix('x') 58 | z = tensor.matrix('z') 59 | 60 | # Accuracies 61 | accuracies = ali.get_accuracies(x, z) 62 | accuracies_fun = theano.function([x, z], accuracies) 63 | # Encoding decoding 64 | encoding = 
ali.sample_z_hat(x) 65 | encoding_fun = theano.function([x], encoding) 66 | decoding = ali.sample_x_tilde(z) 67 | decoding_fun = theano.function([z], decoding) 68 | 69 | # losses and latent gradients 70 | disc_loss, gen_loss = ali.compute_losses(x, z) 71 | disc_loss_fun = theano.function([x, z], disc_loss) 72 | gen_loss_fun = theano.function([x, z], gen_loss) 73 | 74 | disc_loss_z_grads = tensor.grad(disc_loss, z) 75 | gen_loss_z_grads = tensor.grad(gen_loss, z) 76 | 77 | disc_loss_x_grads = tensor.grad(disc_loss, x) 78 | gen_loss_x_grads = tensor.grad(gen_loss, x) 79 | 80 | disc_grad_z_fun = theano.function([x, z], disc_loss_z_grads) 81 | gen_grad_z_fun = theano.function([x, z], gen_loss_z_grads) 82 | 83 | disc_grad_x_fun = theano.function([x, z], disc_loss_x_grads) 84 | gen_grad_x_fun = theano.function([x, z], gen_loss_x_grads) 85 | 86 | gradients_funs = {'discriminator': {'X_grads': disc_grad_x_fun, 87 | 'Z_grads': disc_grad_z_fun}, 88 | 'generator': {'X_grads': gen_grad_x_fun, 89 | 'Z_grads': gen_grad_z_fun}} 90 | 91 | return (gradients_funs, 92 | accuracies_fun, 93 | {'encode': encoding_fun, 94 | 'decode': decoding_fun}) 95 | 96 | 97 | def mouseevent_to_nparray(event): 98 | return as_array((event.xdata, event.ydata)) 99 | 100 | ################ 101 | ## Visualizer ## 102 | ################ 103 | class MixtureVisualizer(object): 104 | def __init__(self, main_loop, ngrid_pts): 105 | self.main_loop = main_loop 106 | self.ngrid_pts = ngrid_pts 107 | self.fig, axes = plt.subplots(nrows=2, ncols=3) 108 | self.axes = OrderedDict(zip(['X', 'Z', 'X_of_Z', 109 | 'Info','Z_grads', 'X_grads'], axes.ravel())) 110 | self.scatter_plots = OrderedDict(zip(self.axes.keys(), 111 | [None] * 6)) 112 | self.grads_plots = OrderedDict(zip(['X_grads', 'Z_grads'], 113 | [None] * 2)) 114 | self.prob_plots = OrderedDict(zip(['Z', 'X_of_Z'], 115 | [None] * 2)) 116 | # getting compiled functions 117 | comp_funs = get_compiled_functions(main_loop) 118 | self._get_grads = comp_funs[0] 119 | 
self._get_accuracies = comp_funs[1] 120 | self._get_mappings = comp_funs[2] 121 | 122 | # getting validation data 123 | self.data = get_data(self.main_loop) 124 | self.n_classes = len(self.main_loop.data_stream.dataset.priors) 125 | 126 | self.add_titles() 127 | self.add_scatters() 128 | selected_x = self.features[0] 129 | selected_z = self.codes[0] 130 | # self.update_gradients_field('Z_grads', selected_x) 131 | # self.update_gradients_field('X_grads', selected_z) 132 | 133 | # Adding initial probability Maps 134 | self.selected_id = {'X': {'base': None, 135 | 'target_prob': None, 136 | 'target_grad': None}, 137 | 'Z': {'base': None, 138 | 'target_prob': None, 139 | 'target_grad': None}} 140 | 141 | self.update_probability_map('Z', self.features[0]) 142 | self.update_probability_map('X_of_Z', self.codes[0]) 143 | self.finetune_axes() 144 | self.register_callbacks() 145 | 146 | @property 147 | def labels(self): 148 | return self.data['labels'] 149 | 150 | @property 151 | def features(self): 152 | return self.data['originals'] 153 | 154 | @property 155 | def codes(self): 156 | return self._get_mappings['encode'](self.features) 157 | 158 | @property 159 | def reconstructions(self): 160 | return self._get_mappings['decode'](self.codes) 161 | 162 | @property 163 | def current_epoch(self): 164 | return self.main_loop.status['epochs_done'] 165 | 166 | def add_scatter(self, name, datum, label): 167 | ax = self.axes[name].scatter(*(self._split_arr(datum)), 168 | c=self.labels, 169 | s=50, 170 | marker='o', 171 | label=label, 172 | alpha=SCATTER_ALPHA, 173 | cmap=plt.cm.get_cmap(LABELS_CMAP, 174 | self.n_classes)) 175 | self.scatter_plots[name] = ax 176 | 177 | def add_scatters(self): 178 | names = ['X', 'Z', 'X_of_Z', 'X_grads', 'Z_grads'] 179 | data = [self.features, self.codes, self.reconstructions, 180 | self.reconstructions, self.codes] 181 | labels = ['originals', 'encodings', 'reconstructions', 182 | 'reconstructions', 'encodings'] 183 | assert len(names) == 
len(data) == len(labels) 184 | 185 | for name, datum, label in zip(names, data, labels): 186 | self.add_scatter(name=name, datum=datum, label=label) 187 | 188 | def add_probability_map(self, name, accuracies): 189 | im = self.axes[name].imshow(accuracies, 190 | cmap=plt.cm.get_cmap(PROB_CMAP), 191 | extent=self._get_extent(name), 192 | vmin=0.0, vmax=1.0) 193 | self.prob_plots[name] = im 194 | 195 | def update_probability_map(self, name, selected): 196 | accuracies = self.get_accuracies(name, selected) 197 | # Annoying Qt4 bug forces redrawing the entire image, 198 | self.add_probability_map(name, accuracies) 199 | 200 | def update_gradients_field(self, name, selected): 201 | # Getting gradients and grid 202 | grad_x, grad_y, x, y = self.get_gradients(name, selected) 203 | # plot every n grads 204 | if self.grads_plots[name] is None: 205 | quiv = self.axes[name].quiver(x[::NGRADS], y[::NGRADS], 206 | grad_x[::NGRADS], grad_y[::NGRADS]) 207 | self.grads_plots[name] = quiv 208 | else: 209 | # assert self.grads_plots[name] is not None 210 | self.grads_plots[name].set_UVC(grad_x[::NGRADS], grad_y[::NGRADS]) 211 | 212 | def add_titles(self): 213 | self.fig.suptitle('ALI - Gaussian Mixture - Epoch: {}'.format( 214 | self.current_epoch) 215 | ) 216 | 217 | self.axes['X'].set_title('Validation') 218 | self.axes['Z'].set_title('Validation Encodings & Data Accuracies') 219 | self.axes['X_of_Z'].set_title('Reconstructions & Sample Accuracies') 220 | self.axes['X_grads'].set_title('Discriminator score w.r.p to x') 221 | self.axes['Z_grads'].set_title('Discriminator score w.r.p to z') 222 | 223 | def finetune_axes(self): 224 | # Forcing subplots to have 'box' aspect 225 | for ax in self.axes.values(): 226 | ax.set_aspect('equal', adjustable='box') 227 | ax.set_autoscale_on(False) 228 | 229 | # Setting ylim and xlim 230 | X_axes = ['X_grads', 'X_of_Z'] 231 | for ax_name in X_axes: 232 | self.axes[ax_name].set_xlim(self.axes['X'].get_xlim()) 233 | 
self.axes[ax_name].set_ylim(self.axes['X'].get_ylim()) 234 | self.axes['Z_grads'].set_xlim(self.axes['Z'].get_xlim()) 235 | self.axes['Z_grads'].set_ylim(self.axes['Z'].get_ylim()) 236 | 237 | # Adding colorbar 238 | self.fig.subplots_adjust(right=0.8) 239 | cbar_ax = self.fig.add_axes([0.85, 0.15, 0.05, 0.7]) 240 | self.fig.colorbar(self.prob_plots['Z'], cax=cbar_ax) 241 | 242 | def get_grid(self, ax, num=None): 243 | if num is None: 244 | num = self.ngrid_pts 245 | x = np.linspace(*ax.get_xlim(), num=num) 246 | y = np.linspace(*ax.get_ylim(), num=num) 247 | xx, yy = np.meshgrid(x, y) 248 | return xx, yy 249 | 250 | def get_accuracies(self, name, selected): 251 | assert name in ['X_of_Z', 'Z'] 252 | xx, yy = self.get_grid(self.axes[name]) 253 | grid = np.vstack([xx.flatten(order='F'), yy.flatten(order='F')]).T 254 | selected_grid = np.tile(selected, (grid.shape[0], 1)) 255 | if name == 'X_of_Z': 256 | input_grids = [grid, selected_grid] 257 | elif name == 'Z': 258 | input_grids = [selected_grid, grid] 259 | accuracies = self._get_accuracies(*input_grids).reshape(xx.shape, 260 | order='F') 261 | return accuracies 262 | 263 | def get_Z_gradients(self, selected_x): 264 | xx, yy = self.get_grid(self.axes['Z_grads'], GRADS_GRID_NPTS) 265 | assert xx.shape == yy.shape 266 | grads_shape = xx.shape 267 | grid = np.vstack([xx.flatten(order='F'), yy.flatten(order='F')]).T 268 | x0 = np.tile(selected_x, (grid.shape[0], 1)) 269 | grads = self._get_grads['discriminator']['Z_grads'](x0, grid) 270 | return [grad.reshape(grads_shape, order='F') 271 | for grad in self._split_arr(grads)] + [xx, yy] 272 | 273 | def get_X_gradients(self, selected_z): 274 | xx, yy = self.get_grid(self.axes['X_grads'], GRADS_GRID_NPTS) 275 | assert xx.shape == yy.shape 276 | grads_shape = xx.shape 277 | grid = np.vstack([xx.flatten(order='F'), yy.flatten(order='F')]).T 278 | z0 = np.tile(selected_z, (grid.shape[0], 1)) 279 | grads = self._get_grads['discriminator']['X_grads'](grid, z0) 280 | return 
[grad.reshape(grads_shape, order='F') 281 | for grad in self._split_arr(grads)] + [xx, yy] 282 | 283 | def get_gradients(self, name, selected): 284 | assert name in ['X_grads', 'Z_grads'] 285 | if name == 'X_grads': 286 | return self.get_X_gradients(selected) 287 | elif name == 'Z_grads': 288 | return self.get_Z_gradients(selected) 289 | 290 | def remove_previously_selected(self, name): 291 | for renderer in self.selected_id[name].values(): 292 | if renderer is not None: 293 | renderer.remove() 294 | 295 | def mark_selected(self, name, prob_target_name, grad_target_name, selected_val): 296 | mapping_name = 'encode' if name == 'X' else 'decode' 297 | mapped_val = self._get_mappings[mapping_name]( 298 | selected_val.reshape(1, selected_val.shape[0])).flatten() 299 | marker_style = '^r' if name == 'X' else '^b' 300 | 301 | # Adding selected val 302 | self.selected_id[name]['base'], = self.axes[name].plot( 303 | selected_val[0], selected_val[1], 304 | marker_style, markersize=MARKERSIZE) 305 | 306 | # Adding mapped val 307 | self.selected_id[name]['prob_target'], = self.axes[prob_target_name].plot( 308 | mapped_val[0], mapped_val[1], 309 | marker_style, markersize=MARKERSIZE 310 | ) 311 | self.selected_id[name]['grad_target'], = self.axes[grad_target_name].plot( 312 | mapped_val[0], mapped_val[1], 313 | marker_style, markersize=MARKERSIZE 314 | ) 315 | 316 | def click_event(self, event): 317 | # Getting current ax 318 | inax = event.inaxes 319 | # get current ax identity 320 | ax_name = get_key_from_val(self.axes, inax) 321 | 322 | isvalid_axis = ax_name in ['X', 'Z'] 323 | isvalid_pt = event.xdata is not None and event.ydata is not None 324 | if isvalid_axis and isvalid_pt: 325 | selected_val = mouseevent_to_nparray(event) 326 | prob_target_name = 'Z' if ax_name == 'X' else 'X_of_Z' 327 | self.update_probability_map(prob_target_name, 328 | selected_val) 329 | 330 | grad_target_name = 'Z_grads' if ax_name == 'X' else 'X_grads' 331 | 
self.update_gradients_field(grad_target_name, selected_val) 332 | 333 | self.remove_previously_selected(ax_name) 334 | self.mark_selected(ax_name, prob_target_name, grad_target_name, 335 | selected_val) 336 | # Updating figure 337 | plt.pause(0.0001) 338 | # self.fig.canvas.draw() 339 | 340 | def register_callbacks(self): 341 | self.fig.canvas.mpl_connect('button_press_event', self.click_event) 342 | 343 | def _split_arr(self, arr): 344 | return np.split(arr, 2, axis=1) 345 | 346 | def _get_extent(self, axis_name): 347 | "Returns (xmin, xmax, ymin, ymax)" 348 | self.axes['Z_grads'].set_ylim(self.axes['Z'].get_xlim()) 349 | return self.axes[axis_name].get_xlim() \ 350 | + self.axes[axis_name].get_ylim() 351 | 352 | def show(self): 353 | self.fig.show() 354 | 355 | if __name__ == '__main__': 356 | main_loop_path = "../experiments/ali_gm.tar" 357 | with open(main_loop_path, 'rb') as ali_src: 358 | main_loop = load(ali_src) 359 | 360 | # Initializing visualizer 361 | plt.ion() 362 | ngrid_pts = 200 363 | mixture_viz = MixtureVisualizer(main_loop, ngrid_pts=ngrid_pts) 364 | #mixture_viz.show() 365 | -------------------------------------------------------------------------------- /ali/conditional_bricks.py: -------------------------------------------------------------------------------- 1 | """ Conditional ALI related Bricks""" 2 | 3 | from theano import tensor 4 | from theano import (function, ) 5 | 6 | from blocks.bricks.base import Brick, application, lazy 7 | from blocks.bricks import LeakyRectifier, Logistic 8 | from blocks.bricks import (Linear, Sequence, ) 9 | from blocks.bricks.conv import (Convolutional, ConvolutionalSequence, ConvolutionalSequence, ) 10 | from blocks.bricks.interfaces import Initializable, Random 11 | 12 | from blocks.initialization import IsotropicGaussian, Constant 13 | 14 | from blocks.select import Selector 15 | 16 | from ali.bricks import ConvMaxout 17 | from ali.utils import get_log_odds, conv_brick, conv_transpose_brick, bn_brick 18 | 19 
class Embedder(Initializable):
    """Linear embedding brick.

    Parameters
    ----------
    dim_in : :class:`int`
        Dimensionality of the input.
    dim_out : :class:`int`
        Dimensionality of the output.
    output_type : :class:`str`
        ``'fc'`` returns the linear output unchanged. ``'conv'`` reshapes it
        to ``(-1, dim_out, 1, 1)``, i.e. one 1x1 feature map per output unit,
        so it can be concatenated with convolutional features along axis 1.

    Raises
    ------
    ValueError
        If ``output_type`` is neither ``'fc'`` nor ``'conv'``.

    """
    _OUTPUT_TYPES = ('fc', 'conv')

    def __init__(self, dim_in, dim_out, output_type='fc', **kwargs):
        # Fail fast: an unknown output_type used to make apply() and
        # get_dim() silently return None.
        if output_type not in self._OUTPUT_TYPES:
            raise ValueError(
                "output_type must be one of {}, got {!r}".format(
                    self._OUTPUT_TYPES, output_type))
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.output_type = output_type
        self.linear = Linear(dim_in, dim_out, name='embed_layer')
        children = [self.linear]
        kwargs.setdefault('children', []).extend(children)
        super(Embedder, self).__init__(**kwargs)

    @application(inputs=['y'], outputs=['outputs'])
    def apply(self, y):
        """Embed ``y`` linearly; reshape to 1x1 maps when ``output_type='conv'``."""
        embedding = self.linear.apply(y)
        if self.output_type == 'conv':
            return embedding.reshape((-1, embedding.shape[-1], 1, 1))
        return embedding

    def get_dim(self, name):
        """Delegate to the linear layer; add trailing ``(1, 1)`` for 'conv'."""
        dim = self.linear.get_dim(name)
        if self.output_type == 'conv':
            return (dim, 1, 1)
        return dim


class EncoderMapping(Initializable):
    """Convolutional encoder mapping conditioned on an embedding.

    ``layers[:-1]`` are applied to ``x``. The result is concatenated with the
    embedding ``y`` along axis 1 (channels) and passed through ``layers[-1]``.

    Parameters
    ----------
    layers : :class:`list`
        List of bricks.
    num_channels : :class:`int`
        Number of input channels.
    image_size : :class:`tuple`
        Input image size.
    n_emb : :class:`int`
        Dimensionality (number of channels) of the embedding.
    use_bias : :class:`bool`
        Whether the convolutional sequences use biases.

    """
    def __init__(self, layers, num_channels, image_size, n_emb, use_bias=False,
                 **kwargs):
        self.layers = layers
        self.num_channels = num_channels
        self.image_size = image_size
        self.n_emb = n_emb

        self.pre_encoder = ConvolutionalSequence(layers=layers[:-1],
                                                 num_channels=num_channels,
                                                 image_size=image_size,
                                                 use_bias=use_bias,
                                                 name='encoder_conv_mapping')
        # Allocate now so the pre-encoder's output channel count is known.
        # The post-encoder's input is those channels plus the embedding.
        self.pre_encoder.allocate()
        n_channels = n_emb + self.pre_encoder.get_dim('output')[0]
        # NOTE(review): image_size=(1, 1) assumes the pre-encoder reduces the
        # spatial size to 1x1, as the stack in __main__ does. Confirm for
        # other layer stacks.
        self.post_encoder = ConvolutionalSequence(layers=[layers[-1]],
                                                  num_channels=n_channels,
                                                  image_size=(1, 1),
                                                  use_bias=use_bias)
        children = [self.pre_encoder, self.post_encoder]
        kwargs.setdefault('children', []).extend(children)
        super(EncoderMapping, self).__init__(**kwargs)

    @application(inputs=['x', 'y'], outputs=['output'])
    def apply(self, x, y):
        """Return ``mu`` and ``log_sigma`` stacked along axis 1.

        GaussianConditional splits this output in half along axis 1.
        """
        # Convolutional features of x.
        pre_z = self.pre_encoder.apply(x)
        # Append the conditional embedding as extra channels.
        pre_z_embed_y = tensor.concatenate([pre_z, y], axis=1)
        # Propagate through the last layer.
        return self.post_encoder.apply(pre_z_embed_y)


class Decoder(Initializable):
    """Convolutional decoder conditioned on an embedding.

    ``z`` and the embedding ``y`` are concatenated along axis 1 and fed to a
    :class:`ConvolutionalSequence` built from ``layers``.

    Parameters
    ----------
    layers : :class:`list`
        List of bricks forming the decoder.
    num_channels : :class:`int`
        Input channels, i.e. channels of ``z`` plus channels of the embedding.
    image_size : :class:`tuple`
        Spatial size of the concatenated input.
    use_bias : :class:`bool`
        Whether the convolutional sequence uses biases.

    """
    def __init__(self, layers, num_channels, image_size, use_bias=False,
                 **kwargs):
        self.layers = layers
        self.num_channels = num_channels
        self.image_size = image_size

        self.mapping = ConvolutionalSequence(layers=layers,
                                             num_channels=num_channels,
                                             image_size=image_size,
                                             use_bias=use_bias,
                                             name='decoder_mapping')
        children = [self.mapping]
        kwargs.setdefault('children', []).extend(children)
        super(Decoder, self).__init__(**kwargs)

    @application(inputs=['z', 'y'], outputs=['outputs'])
    def apply(self, z, y, application_call):
        # Concatenate the conditional embedding with the latent code.
        z_y = tensor.concatenate([z, y], axis=1)
        return self.mapping.apply(z_y)


class GaussianConditional(Initializable, Random):
    """Sample ``mu + exp(log_sigma) * epsilon`` from a conditional mapping.

    ``mapping.apply(x, y)`` must return ``mu`` (first half) and ``log_sigma``
    (second half) concatenated along axis 1. ``epsilon`` is standard normal
    noise drawn from ``self.theano_rng``.

    Parameters
    ----------
    mapping : :class:`blocks.bricks.Brick`
        Brick mapping ``(x, y)`` to the Gaussian parameters, e.g.
        :class:`EncoderMapping`.

    """
    def __init__(self, mapping, **kwargs):
        self.mapping = mapping
        super(GaussianConditional, self).__init__(**kwargs)
        self.children.extend([mapping])

    @property
    def _nlat(self):
        """Latent channels: half the output channels of the mapping's last child."""
        # NOTE(review): requires mapping.children[-1].get_dim('output') to be
        # a tuple. This holds for EncoderMapping's post-encoder.
        return self.mapping.children[-1].get_dim('output')[0] // 2

    def get_dim(self, name):
        # NOTE(review): EncoderMapping is not a ConvolutionalSequence and has
        # no output_dim, so the else-branch raises AttributeError for
        # name == 'output' with that mapping. Confirm whether it is reachable.
        if isinstance(self.mapping, ConvolutionalSequence):
            dim = self.mapping.get_dim(name)
            if name == 'output':
                return (dim[0] // 2) + dim[1:]
            else:
                return dim
        else:
            if name == 'output':
                return self.mapping.output_dim // 2
            elif name == 'input_':
                return self.mapping.input_dim
            else:
                return self.mapping.get_dim(name)

    @application(inputs=['x', 'y'], outputs=['output'])
    def apply(self, x, y, application_call):
        params = self.mapping.apply(x, y)
        nlat = self._nlat
        mu, log_sigma = params[:, :nlat], params[:, nlat:]
        sigma = tensor.exp(log_sigma)
        epsilon = self.theano_rng.normal(size=mu.shape)
        return mu + sigma * epsilon


class XZYJointDiscriminator(Initializable):
    """Three-way discriminator over :math:`x`, :math:`z` and conditional data.

    Parameters
    ----------
    x_discriminator : :class:`blocks.bricks.Brick`
        Part of the discriminator taking :math:`x` as input.
    z_discriminator : :class:`blocks.bricks.Brick`
        Part of the discriminator taking :math:`z` as input.
    joint_discriminator : :class:`blocks.bricks.Brick`
        Part of the discriminator taking, as input, the concatenation along
        axis 1 of ``x_discriminator``'s output, ``z_discriminator``'s output
        and the conditional input ``y``. It computes :math:`D(x, z, y)`.

    """
    def __init__(self, x_discriminator, z_discriminator, joint_discriminator,
                 **kwargs):
        self.x_discriminator = x_discriminator
        self.z_discriminator = z_discriminator
        self.joint_discriminator = joint_discriminator

        super(XZYJointDiscriminator, self).__init__(**kwargs)
        self.children.extend([self.x_discriminator, self.z_discriminator,
                              self.joint_discriminator])

    @application(inputs=['x', 'z', 'y'], outputs=['output'])
    def apply(self, x, z, y):
        # NOTE: the unbroadcasts act as a workaround for a weird broadcasting
        # bug when applying dropout
        input_ = tensor.unbroadcast(
            tensor.concatenate(
                [self.x_discriminator.apply(x), self.z_discriminator.apply(z), y],
                axis=1),
            *range(x.ndim))
        return self.joint_discriminator.apply(input_)


class ConditionalALI(Initializable, Random):
    """Conditional adversarially learned inference brick.

    The conditional data ``y`` is embedded with an :class:`Embedder`
    (``output_type='conv'``). The embedding is passed to the encoder, the
    decoder and the discriminator.

    Parameters
    ----------
    encoder : :class:`blocks.bricks.Brick`
        Encoder network, applied as ``encoder.apply(x, embedding)``.
    decoder : :class:`blocks.bricks.Brick`
        Decoder network, applied as ``decoder.apply(z, embedding)``.
    discriminator : :class:`blocks.bricks.Brick`
        Discriminator network taking :math:`x`, :math:`z` and the embedding
        as input.
    n_cond : :class:`int`
        Dimensionality of the conditional data.
    n_emb : :class:`int`
        Dimensionality of the embedding.

    """
    def __init__(self, encoder, decoder, discriminator, n_cond, n_emb,
                 **kwargs):
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator
        self.n_cond = n_cond  # Features in conditional data
        self.n_emb = n_emb  # Features in embeddings
        self.embedder = Embedder(n_cond, n_emb, output_type='conv')

        super(ConditionalALI, self).__init__(**kwargs)
        self.children.extend([self.encoder, self.decoder, self.discriminator,
                              self.embedder])

    @property
    def discriminator_parameters(self):
        return list(
            Selector([self.discriminator]).get_parameters().values())

    @property
    def generator_parameters(self):
        return list(
            Selector([self.encoder, self.decoder]).get_parameters().values())

    @property
    def embedding_parameters(self):
        return list(
            Selector([self.embedder]).get_parameters().values())

    @application(inputs=['x', 'z_hat', 'x_tilde', 'z', 'y'],
                 outputs=['data_preds', 'sample_preds'])
    def get_predictions(self, x, z_hat, x_tilde, z, y, application_call):
        """Return discriminator logits for data pairs and sample pairs.

        The pairs ``(x, z_hat)`` and ``(x_tilde, z)`` are stacked along the
        batch axis and scored in one discriminator call, then split back.
        ``y`` is repeated for both halves.
        """
        # NOTE: the unbroadcasts act as a workaround for a weird broadcasting
        # bug when applying dropout. Each tensor is unbroadcast over its own
        # rank. Theano's Rebroadcast rejects axes a tensor does not have, so
        # reusing x.ndim for z and y would fail when they are lower-rank.
        input_x = tensor.unbroadcast(
            tensor.concatenate([x, x_tilde], axis=0), *range(x.ndim))
        input_z = tensor.unbroadcast(
            tensor.concatenate([z_hat, z], axis=0), *range(z.ndim))
        input_y = tensor.unbroadcast(
            tensor.concatenate([y, y], axis=0), *range(y.ndim))
        data_sample_preds = self.discriminator.apply(input_x, input_z, input_y)
        data_preds = data_sample_preds[:x.shape[0]]
        sample_preds = data_sample_preds[x.shape[0]:]

        application_call.add_auxiliary_variable(
            tensor.nnet.sigmoid(data_preds).mean(), name='data_accuracy')
        application_call.add_auxiliary_variable(
            (1 - tensor.nnet.sigmoid(sample_preds)).mean(),
            name='sample_accuracy')

        return data_preds, sample_preds

    @application(inputs=['x', 'z', 'y'],
                 outputs=['discriminator_loss', 'generator_loss'])
    def compute_losses(self, x, z, y, application_call):
        """Return ``(discriminator_loss, generator_loss)``.

        On logits, the discriminator minimizes
        ``softplus(-D(x, z_hat)) + softplus(D(x_tilde, z))``. The generator
        (encoder and decoder) minimizes the same terms with labels swapped.
        """
        embeddings = self.embedder.apply(y)
        z_hat = self.encoder.apply(x, embeddings)
        x_tilde = self.decoder.apply(z, embeddings)

        data_preds, sample_preds = self.get_predictions(x, z_hat, x_tilde, z,
                                                        embeddings)

        # To be modularized
        discriminator_loss = (tensor.nnet.softplus(-data_preds) +
                              tensor.nnet.softplus(sample_preds)).mean()
        generator_loss = (tensor.nnet.softplus(data_preds) +
                          tensor.nnet.softplus(-sample_preds)).mean()

        return discriminator_loss, generator_loss

    @application(inputs=['z', 'y'], outputs=['samples'])
    def sample(self, z, y):
        """Decode ``z`` conditioned on the embedding of ``y``."""
        return self.decoder.apply(z, self.embedder.apply(y))

    @application(inputs=['x', 'y'], outputs=['reconstructions'])
    def reconstruct(self, x, y):
        """Encode ``x`` then decode it, both conditioned on ``y``'s embedding."""
        embeddings = self.embedder.apply(y)
        return self.decoder.apply(self.encoder.apply(x, embeddings),
                                  embeddings)


if __name__ == '__main__':
    # Smoke test: build each brick and compile it on random data.
    import numpy as np
    import numpy.random as npr

    WEIGHTS_INIT = IsotropicGaussian(0.01)
    BIASES_INIT = Constant(0.)
    LEAK = 0.1
    NLAT = 64

    IMAGE_SIZE = (32, 32)
    NUM_CHANNELS = 3
    NUM_PIECES = 2

    NCLASSES = 10
    NEMB = 100
    # Testing embedder
    embedder = Embedder(NCLASSES, NEMB, output_type='conv',
                        weights_init=WEIGHTS_INIT, biases_init=BIASES_INIT)
    embedder.initialize()

    x = tensor.tensor4('x')
    y = tensor.matrix('y')

    embedder_test = function([y], embedder.apply(y))

    # One-hot labels for 5 examples. randint's upper bound is exclusive, so
    # NCLASSES is needed for every class (including the last) to be drawn.
    idx = npr.randint(0, NCLASSES, size=5)
    # Cast to the symbolic dtype (floatX). Theano functions refuse implicit
    # float64 -> float32 downcasts by default.
    test_labels = np.eye(NCLASSES)[idx].astype(y.dtype)
    embeddings = embedder_test(test_labels)
    print(embeddings)
    print(embeddings.shape)

    # Generate synthetic 4D tensor
    features = npr.random(
        size=(5, NUM_CHANNELS) + IMAGE_SIZE).astype(x.dtype)

    # Testing Encoder
    layers = [
        # 32 X 32 X 3
        conv_brick(5, 1, 32), bn_brick(), LeakyRectifier(leak=LEAK),
        # 28 X 28 X 32
        conv_brick(4, 2, 64), bn_brick(), LeakyRectifier(leak=LEAK),
        # 13 X 13 X 64
        conv_brick(4, 1, 128), bn_brick(), LeakyRectifier(leak=LEAK),
        # 10 X 10 X 128
        conv_brick(4, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK),
        # 4 X 4 X 256
        conv_brick(4, 1, 512), bn_brick(), LeakyRectifier(leak=LEAK),
        # 1 X 1 X 512
        conv_brick(1, 1, 512), bn_brick(), LeakyRectifier(leak=LEAK),
        # 1 X 1 X 512
        conv_brick(1, 1, 2 * NLAT)
        # 1 X 1 X 2 * NLAT
    ]

    # n_emb is a required argument of EncoderMapping; omitting it raised
    # TypeError here.
    encoder_mapping = EncoderMapping(layers=layers,
                                     num_channels=NUM_CHANNELS,
                                     image_size=IMAGE_SIZE,
                                     n_emb=NEMB,
                                     weights_init=WEIGHTS_INIT,
                                     biases_init=BIASES_INIT)
    encoder_mapping.initialize()

    # Symbolic embedding reused by every function compiled below.
    embeddings = embedder.apply(y)
    encoder_mapping_fun = function([x, y],
                                   encoder_mapping.apply(x, embeddings))
    out = encoder_mapping_fun(features, test_labels)
    print(out.shape)

    # Testing Gaussian encoder blocks
    encoder = GaussianConditional(mapping=encoder_mapping)
    encoder.initialize()
    encoder_fun = function([x, y], encoder.apply(x, embeddings))
    z_hat = encoder_fun(features, test_labels)
    print(z_hat)

    # Decoder
    z = tensor.tensor4('z')
    layers = [
        conv_transpose_brick(4, 1, 256), bn_brick(), LeakyRectifier(leak=LEAK),
        conv_transpose_brick(4, 2, 128), bn_brick(), LeakyRectifier(leak=LEAK),
        conv_transpose_brick(4, 1, 64), bn_brick(), LeakyRectifier(leak=LEAK),
        conv_transpose_brick(4, 2, 32), bn_brick(), LeakyRectifier(leak=LEAK),
        conv_transpose_brick(5, 1, 32), bn_brick(), LeakyRectifier(leak=LEAK),
        conv_transpose_brick(1, 1, 32), bn_brick(), LeakyRectifier(leak=LEAK),
        conv_brick(1, 1, NUM_CHANNELS), Logistic()]

    decoder = Decoder(layers=layers, num_channels=(NLAT + NEMB),
                      image_size=(1, 1),
                      weights_init=WEIGHTS_INIT, biases_init=BIASES_INIT)
    decoder.initialize()
    decoder_fun = function([z, y], decoder.apply(z, embeddings))
    out = decoder_fun(z_hat, test_labels)

    # Discriminator
    layers = [
        conv_brick(5, 1, 32), ConvMaxout(num_pieces=NUM_PIECES),
        conv_brick(4, 2, 64), ConvMaxout(num_pieces=NUM_PIECES),
        conv_brick(4, 1, 128), ConvMaxout(num_pieces=NUM_PIECES),
        conv_brick(4, 2, 256), ConvMaxout(num_pieces=NUM_PIECES),
        conv_brick(4, 1, 512), ConvMaxout(num_pieces=NUM_PIECES)]
    x_discriminator = ConvolutionalSequence(
        layers=layers, num_channels=NUM_CHANNELS, image_size=IMAGE_SIZE,
        name='x_discriminator')
    x_discriminator.push_allocation_config()

    layers = [
        conv_brick(1, 1, 512), ConvMaxout(num_pieces=NUM_PIECES),
        conv_brick(1, 1, 512), ConvMaxout(num_pieces=NUM_PIECES)]
    z_discriminator = ConvolutionalSequence(
        layers=layers, num_channels=NLAT, image_size=(1, 1), use_bias=False,
        name='z_discriminator')
z_discriminator.push_allocation_config() 409 | 410 | layers = [ 411 | conv_brick(1, 1, 1024), ConvMaxout(num_pieces=NUM_PIECES), 412 | conv_brick(1, 1, 1024), ConvMaxout(num_pieces=NUM_PIECES), 413 | conv_brick(1, 1, 1)] 414 | joint_discriminator = ConvolutionalSequence( 415 | layers=layers, 416 | num_channels=(x_discriminator.get_dim('output')[0] + 417 | z_discriminator.get_dim('output')[0] + 418 | NEMB), 419 | image_size=(1, 1), 420 | name='joint_discriminator') 421 | 422 | discriminator = XZYJointDiscriminator( 423 | x_discriminator, z_discriminator, joint_discriminator, 424 | name='discriminator') 425 | 426 | discriminator = XZYJointDiscriminator(x_discriminator, z_discriminator, joint_discriminator, 427 | name='discriminator', weights_init=WEIGHTS_INIT, 428 | biases_init=BIASES_INIT) 429 | discriminator.initialize() 430 | discriminator_fun = function([x, z, y], discriminator.apply(x, z, embeddings)) 431 | out = discriminator_fun(features, z_hat, test_labels) 432 | print(out.shape) 433 | 434 | 435 | # Initializing ALI 436 | ali = ConditionalALI(encoder=encoder, decoder=decoder, discriminator=discriminator, 437 | n_cond=NCLASSES, 438 | n_emb=NEMB, 439 | weights_init=WEIGHTS_INIT, 440 | biases_init=BIASES_INIT) 441 | ali.initialize() 442 | # Computing Loss 443 | loss = ali.compute_losses(x, z, y) 444 | loss_fun = function([x, z, y], loss) 445 | out = loss_fun(features, z_hat, test_labels) 446 | 447 | 448 | -------------------------------------------------------------------------------- /paper/fancyhdr.sty: -------------------------------------------------------------------------------- 1 | % fancyhdr.sty version 3.2 2 | % Fancy headers and footers for LaTeX. 3 | % Piet van Oostrum, 4 | % Dept of Computer and Information Sciences, University of Utrecht, 5 | % Padualaan 14, P.O. Box 80.089, 3508 TB Utrecht, The Netherlands 6 | % Telephone: +31 30 2532180. 
Email: piet@cs.uu.nl 7 | % ======================================================================== 8 | % LICENCE: 9 | % This file may be distributed under the terms of the LaTeX Project Public 10 | % License, as described in lppl.txt in the base LaTeX distribution. 11 | % Either version 1 or, at your option, any later version. 12 | % ======================================================================== 13 | % MODIFICATION HISTORY: 14 | % Sep 16, 1994 15 | % version 1.4: Correction for use with \reversemargin 16 | % Sep 29, 1994: 17 | % version 1.5: Added the \iftopfloat, \ifbotfloat and \iffloatpage commands 18 | % Oct 4, 1994: 19 | % version 1.6: Reset single spacing in headers/footers for use with 20 | % setspace.sty or doublespace.sty 21 | % Oct 4, 1994: 22 | % version 1.7: changed \let\@mkboth\markboth to 23 | % \def\@mkboth{\protect\markboth} to make it more robust 24 | % Dec 5, 1994: 25 | % version 1.8: corrections for amsbook/amsart: define \@chapapp and (more 26 | % importantly) use the \chapter/sectionmark definitions from ps@headings if 27 | % they exist (which should be true for all standard classes). 28 | % May 31, 1995: 29 | % version 1.9: The proposed \renewcommand{\headrulewidth}{\iffloatpage... 30 | % construction in the doc did not work properly with the fancyplain style. 31 | % June 1, 1995: 32 | % version 1.91: The definition of \@mkboth wasn't restored on subsequent 33 | % \pagestyle{fancy}'s. 34 | % June 1, 1995: 35 | % version 1.92: The sequence \pagestyle{fancyplain} \pagestyle{plain} 36 | % \pagestyle{fancy} would erroneously select the plain version. 37 | % June 1, 1995: 38 | % version 1.93: \fancypagestyle command added. 39 | % Dec 11, 1995: 40 | % version 1.94: suggested by Conrad Hughes 41 | % CJCH, Dec 11, 1995: added \footruleskip to allow control over footrule 42 | % position (old hardcoded value of .3\normalbaselineskip is far too high 43 | % when used with very small footer fonts). 
44 | % Jan 31, 1996: 45 | % version 1.95: call \@normalsize in the reset code if that is defined, 46 | % otherwise \normalsize. 47 | % this is to solve a problem with ucthesis.cls, as this doesn't 48 | % define \@currsize. Unfortunately for latex209 calling \normalsize doesn't 49 | % work as this is optimized to do very little, so there \@normalsize should 50 | % be called. Hopefully this code works for all versions of LaTeX known to 51 | % mankind. 52 | % April 25, 1996: 53 | % version 1.96: initialize \headwidth to a magic (negative) value to catch 54 | % the most common cases where people change it before calling \pagestyle{fancy}. 55 | % Note it can't be initialized when reading in this file, because 56 | % \textwidth could be changed afterwards. This is quite probable. 57 | % We also switch to \MakeUppercase rather than \uppercase and introduce a 58 | % \nouppercase command for use in headers and footers. 59 | % May 3, 1996: 60 | % version 1.97: Two changes: 61 | % 1. Undo the change in version 1.8 (using the pagestyle{headings} defaults 62 | % for the chapter and section marks). The current versions of amsbook and 63 | % amsart classes don't seem to need them anymore. Moreover the standard 64 | % latex classes don't use \markboth if twoside isn't selected, and this is 65 | % confusing as \leftmark doesn't work as expected. 66 | % 2. include a call to \ps@empty in ps@@fancy. This is to solve a problem 67 | % in the amsbook and amsart classes, that make global changes to \topskip, 68 | % which are reset in \ps@empty. Hopefully this doesn't break other things. 69 | % May 7, 1996: 70 | % version 1.98: 71 | % Added % after the line \def\nouppercase 72 | % May 7, 1996: 73 | % version 1.99: This is the alpha version of fancyhdr 2.0 74 | % Introduced the new commands \fancyhead, \fancyfoot, and \fancyhf.
75 | % Changed \headrulewidth, \footrulewidth, \footruleskip to 76 | % macros rather than length parameters. In this way they can be 77 | % conditionalized and they don't consume length registers. There is no need 78 | % to have them as length registers unless you want to do calculations with 79 | % them, which is unlikely. Note that this may make some uses of them 80 | % incompatible (i.e. if you have a file that uses \setlength or \xxxx=) 81 | % May 10, 1996: 82 | % version 1.99a: 83 | % Added a few more % signs 84 | % May 10, 1996: 85 | % version 1.99b: 86 | % Changed the syntax of \f@nfor to be resistant to catcode changes of := 87 | % Removed the [1] from the defs of \lhead etc. because the parameter is 88 | % consumed by the \@[xy]lhead etc. macros. 89 | % June 24, 1997: 90 | % version 1.99c: 91 | % corrected \nouppercase to also include the protected form of \MakeUppercase 92 | % \global added to manipulation of \headwidth. 93 | % \iffootnote command added. 94 | % Some comments added about \@fancyhead and \@fancyfoot. 95 | % Aug 24, 1998 96 | % version 1.99d 97 | % Changed the default \ps@empty to \ps@@empty in order to allow 98 | % \fancypagestyle{empty} redefinition. 99 | % Oct 11, 2000 100 | % version 2.0 101 | % Added LPPL license clause. 102 | % 103 | % A check for \headheight is added. An error message is given (once) if the 104 | % header is too large. Empty headers don't generate the error even if 105 | % \headheight is very small or even 0pt. 106 | % Warning added for the use of 'E' option when twoside option is not used. 107 | % In this case the 'E' fields will never be used. 108 | % 109 | % Mar 10, 2002 110 | % version 2.1beta 111 | % New command: \fancyhfoffset[place]{length} 112 | % defines offsets to be applied to the header/footer to let it stick into 113 | % the margins (if length > 0). 114 | % place is like in fancyhead, except that only E,O,L,R can be used.
115 | % This replaces the old calculation based on \headwidth and the marginpar 116 | % area. 117 | % \headwidth will be dynamically calculated in the headers/footers when 118 | % this is used. 119 | % 120 | % Mar 26, 2002 121 | % version 2.1beta2 122 | % \fancyhfoffset now also takes h,f as possible letters in the argument to 123 | % allow the header and footer widths to be different. 124 | % New commands \fancyheadoffset and \fancyfootoffset added comparable to 125 | % \fancyhead and \fancyfoot. 126 | % Errormessages and warnings have been made more informative. 127 | % 128 | % Dec 9, 2002 129 | % version 2.1 130 | % The defaults for \footrulewidth, \plainheadrulewidth and 131 | % \plainfootrulewidth are changed from \z@skip to 0pt. In this way when 132 | % someone inadvertantly uses \setlength to change any of these, the value 133 | % of \z@skip will not be changed, rather an errormessage will be given. 134 | 135 | % March 3, 2004 136 | % Release of version 3.0 137 | 138 | % Oct 7, 2004 139 | % version 3.1 140 | % Added '\endlinechar=13' to \fancy@reset to prevent problems with 141 | % includegraphics in header when verbatiminput is active. 142 | 143 | % March 22, 2005 144 | % version 3.2 145 | % reset \everypar (the real one) in \fancy@reset because spanish.ldf does 146 | % strange things with \everypar between << and >>. 147 | 148 | \def\ifancy@mpty#1{\def\temp@a{#1}\ifx\temp@a\@empty} 149 | 150 | \def\fancy@def#1#2{\ifancy@mpty{#2}\fancy@gbl\def#1{\leavevmode}\else 151 | \fancy@gbl\def#1{#2\strut}\fi} 152 | 153 | \let\fancy@gbl\global 154 | 155 | \def\@fancyerrmsg#1{% 156 | \ifx\PackageError\undefined 157 | \errmessage{#1}\else 158 | \PackageError{Fancyhdr}{#1}{}\fi} 159 | \def\@fancywarning#1{% 160 | \ifx\PackageWarning\undefined 161 | \errmessage{#1}\else 162 | \PackageWarning{Fancyhdr}{#1}{}\fi} 163 | 164 | % Usage: \@forc \var{charstring}{command to be executed for each char} 165 | % This is similar to LaTeX's \@tfor, but expands the charstring. 
166 | 167 | \def\@forc#1#2#3{\expandafter\f@rc\expandafter#1\expandafter{#2}{#3}} 168 | \def\f@rc#1#2#3{\def\temp@ty{#2}\ifx\@empty\temp@ty\else 169 | \f@@rc#1#2\f@@rc{#3}\fi} 170 | \def\f@@rc#1#2#3\f@@rc#4{\def#1{#2}#4\f@rc#1{#3}{#4}} 171 | 172 | % Usage: \f@nfor\name:=list\do{body} 173 | % Like LaTeX's \@for but an empty list is treated as a list with an empty 174 | % element 175 | 176 | \newcommand{\f@nfor}[3]{\edef\@fortmp{#2}% 177 | \expandafter\@forloop#2,\@nil,\@nil\@@#1{#3}} 178 | 179 | % Usage: \def@ult \cs{defaults}{argument} 180 | % sets \cs to the characters from defaults appearing in argument 181 | % or defaults if it would be empty. All characters are lowercased. 182 | 183 | \newcommand\def@ult[3]{% 184 | \edef\temp@a{\lowercase{\edef\noexpand\temp@a{#3}}}\temp@a 185 | \def#1{}% 186 | \@forc\tmpf@ra{#2}% 187 | {\expandafter\if@in\tmpf@ra\temp@a{\edef#1{#1\tmpf@ra}}{}}% 188 | \ifx\@empty#1\def#1{#2}\fi} 189 | % 190 | % \if@in 191 | % 192 | \newcommand{\if@in}[4]{% 193 | \edef\temp@a{#2}\def\temp@b##1#1##2\temp@b{\def\temp@b{##1}}% 194 | \expandafter\temp@b#2#1\temp@b\ifx\temp@a\temp@b #4\else #3\fi} 195 | 196 | \newcommand{\fancyhead}{\@ifnextchar[{\f@ncyhf\fancyhead h}% 197 | {\f@ncyhf\fancyhead h[]}} 198 | \newcommand{\fancyfoot}{\@ifnextchar[{\f@ncyhf\fancyfoot f}% 199 | {\f@ncyhf\fancyfoot f[]}} 200 | \newcommand{\fancyhf}{\@ifnextchar[{\f@ncyhf\fancyhf{}}% 201 | {\f@ncyhf\fancyhf{}[]}} 202 | 203 | % New commands for offsets added 204 | 205 | \newcommand{\fancyheadoffset}{\@ifnextchar[{\f@ncyhfoffs\fancyheadoffset h}% 206 | {\f@ncyhfoffs\fancyheadoffset h[]}} 207 | \newcommand{\fancyfootoffset}{\@ifnextchar[{\f@ncyhfoffs\fancyfootoffset f}% 208 | {\f@ncyhfoffs\fancyfootoffset f[]}} 209 | \newcommand{\fancyhfoffset}{\@ifnextchar[{\f@ncyhfoffs\fancyhfoffset{}}% 210 | {\f@ncyhfoffs\fancyhfoffset{}[]}} 211 | 212 | % The header and footer fields are stored in command sequences with 213 | % names of the form: \f@ncy with for [eo], from [lcr] 214 | % and 
from [hf]. 215 | 216 | \def\f@ncyhf#1#2[#3]#4{% 217 | \def\temp@c{}% 218 | \@forc\tmpf@ra{#3}% 219 | {\expandafter\if@in\tmpf@ra{eolcrhf,EOLCRHF}% 220 | {}{\edef\temp@c{\temp@c\tmpf@ra}}}% 221 | \ifx\@empty\temp@c\else 222 | \@fancyerrmsg{Illegal char `\temp@c' in \string#1 argument: 223 | [#3]}% 224 | \fi 225 | \f@nfor\temp@c{#3}% 226 | {\def@ult\f@@@eo{eo}\temp@c 227 | \if@twoside\else 228 | \if\f@@@eo e\@fancywarning 229 | {\string#1's `E' option without twoside option is useless}\fi\fi 230 | \def@ult\f@@@lcr{lcr}\temp@c 231 | \def@ult\f@@@hf{hf}{#2\temp@c}% 232 | \@forc\f@@eo\f@@@eo 233 | {\@forc\f@@lcr\f@@@lcr 234 | {\@forc\f@@hf\f@@@hf 235 | {\expandafter\fancy@def\csname 236 | f@ncy\f@@eo\f@@lcr\f@@hf\endcsname 237 | {#4}}}}}} 238 | 239 | \def\f@ncyhfoffs#1#2[#3]#4{% 240 | \def\temp@c{}% 241 | \@forc\tmpf@ra{#3}% 242 | {\expandafter\if@in\tmpf@ra{eolrhf,EOLRHF}% 243 | {}{\edef\temp@c{\temp@c\tmpf@ra}}}% 244 | \ifx\@empty\temp@c\else 245 | \@fancyerrmsg{Illegal char `\temp@c' in \string#1 argument: 246 | [#3]}% 247 | \fi 248 | \f@nfor\temp@c{#3}% 249 | {\def@ult\f@@@eo{eo}\temp@c 250 | \if@twoside\else 251 | \if\f@@@eo e\@fancywarning 252 | {\string#1's `E' option without twoside option is useless}\fi\fi 253 | \def@ult\f@@@lcr{lr}\temp@c 254 | \def@ult\f@@@hf{hf}{#2\temp@c}% 255 | \@forc\f@@eo\f@@@eo 256 | {\@forc\f@@lcr\f@@@lcr 257 | {\@forc\f@@hf\f@@@hf 258 | {\expandafter\setlength\csname 259 | f@ncyO@\f@@eo\f@@lcr\f@@hf\endcsname 260 | {#4}}}}}% 261 | \fancy@setoffs} 262 | 263 | % Fancyheadings version 1 commands. These are more or less deprecated, 264 | % but they continue to work. 
265 | 266 | \newcommand{\lhead}{\@ifnextchar[{\@xlhead}{\@ylhead}} 267 | \def\@xlhead[#1]#2{\fancy@def\f@ncyelh{#1}\fancy@def\f@ncyolh{#2}} 268 | \def\@ylhead#1{\fancy@def\f@ncyelh{#1}\fancy@def\f@ncyolh{#1}} 269 | 270 | \newcommand{\chead}{\@ifnextchar[{\@xchead}{\@ychead}} 271 | \def\@xchead[#1]#2{\fancy@def\f@ncyech{#1}\fancy@def\f@ncyoch{#2}} 272 | \def\@ychead#1{\fancy@def\f@ncyech{#1}\fancy@def\f@ncyoch{#1}} 273 | 274 | \newcommand{\rhead}{\@ifnextchar[{\@xrhead}{\@yrhead}} 275 | \def\@xrhead[#1]#2{\fancy@def\f@ncyerh{#1}\fancy@def\f@ncyorh{#2}} 276 | \def\@yrhead#1{\fancy@def\f@ncyerh{#1}\fancy@def\f@ncyorh{#1}} 277 | 278 | \newcommand{\lfoot}{\@ifnextchar[{\@xlfoot}{\@ylfoot}} 279 | \def\@xlfoot[#1]#2{\fancy@def\f@ncyelf{#1}\fancy@def\f@ncyolf{#2}} 280 | \def\@ylfoot#1{\fancy@def\f@ncyelf{#1}\fancy@def\f@ncyolf{#1}} 281 | 282 | \newcommand{\cfoot}{\@ifnextchar[{\@xcfoot}{\@ycfoot}} 283 | \def\@xcfoot[#1]#2{\fancy@def\f@ncyecf{#1}\fancy@def\f@ncyocf{#2}} 284 | \def\@ycfoot#1{\fancy@def\f@ncyecf{#1}\fancy@def\f@ncyocf{#1}} 285 | 286 | \newcommand{\rfoot}{\@ifnextchar[{\@xrfoot}{\@yrfoot}} 287 | \def\@xrfoot[#1]#2{\fancy@def\f@ncyerf{#1}\fancy@def\f@ncyorf{#2}} 288 | \def\@yrfoot#1{\fancy@def\f@ncyerf{#1}\fancy@def\f@ncyorf{#1}} 289 | 290 | \newlength{\fancy@headwidth} 291 | \let\headwidth\fancy@headwidth 292 | \newlength{\f@ncyO@elh} 293 | \newlength{\f@ncyO@erh} 294 | \newlength{\f@ncyO@olh} 295 | \newlength{\f@ncyO@orh} 296 | \newlength{\f@ncyO@elf} 297 | \newlength{\f@ncyO@erf} 298 | \newlength{\f@ncyO@olf} 299 | \newlength{\f@ncyO@orf} 300 | \newcommand{\headrulewidth}{0.4pt} 301 | \newcommand{\footrulewidth}{0pt} 302 | \newcommand{\footruleskip}{.3\normalbaselineskip} 303 | 304 | % Fancyplain stuff shouldn't be used anymore (rather 305 | % \fancypagestyle{plain} should be used), but it must be present for 306 | % compatibility reasons. 
307 | 308 | \newcommand{\plainheadrulewidth}{0pt} 309 | \newcommand{\plainfootrulewidth}{0pt} 310 | \newif\if@fancyplain \@fancyplainfalse 311 | \def\fancyplain#1#2{\if@fancyplain#1\else#2\fi} 312 | 313 | \headwidth=-123456789sp %magic constant 314 | 315 | % Command to reset various things in the headers: 316 | % a.o. single spacing (taken from setspace.sty) 317 | % and the catcode of ^^M (so that epsf files in the header work if a 318 | % verbatim crosses a page boundary) 319 | % It also defines a \nouppercase command that disables \uppercase and 320 | % \Makeuppercase. It can only be used in the headers and footers. 321 | \let\fnch@everypar\everypar% save real \everypar because of spanish.ldf 322 | \def\fancy@reset{\fnch@everypar{}\restorecr\endlinechar=13 323 | \def\baselinestretch{1}% 324 | \def\nouppercase##1{{\let\uppercase\relax\let\MakeUppercase\relax 325 | \expandafter\let\csname MakeUppercase \endcsname\relax##1}}% 326 | \ifx\undefined\@newbaseline% NFSS not present; 2.09 or 2e 327 | \ifx\@normalsize\undefined \normalsize % for ucthesis.cls 328 | \else \@normalsize \fi 329 | \else% NFSS (2.09) present 330 | \@newbaseline% 331 | \fi} 332 | 333 | % Initialization of the head and foot text. 334 | 335 | % The default values still contain \fancyplain for compatibility. 336 | \fancyhf{} % clear all 337 | % lefthead empty on ``plain'' pages, \rightmark on even, \leftmark on odd pages 338 | % evenhead empty on ``plain'' pages, \leftmark on even, \rightmark on odd pages 339 | \if@twoside 340 | \fancyhead[el,or]{\fancyplain{}{\sl\rightmark}} 341 | \fancyhead[er,ol]{\fancyplain{}{\sl\leftmark}} 342 | \else 343 | \fancyhead[l]{\fancyplain{}{\sl\rightmark}} 344 | \fancyhead[r]{\fancyplain{}{\sl\leftmark}} 345 | \fi 346 | \fancyfoot[c]{\rm\thepage} % page number 347 | 348 | % Use box 0 as a temp box and dimen 0 as temp dimen. 349 | % This can be done, because this code will always 350 | % be used inside another box, and therefore the changes are local. 
351 | 352 | \def\@fancyvbox#1#2{\setbox0\vbox{#2}\ifdim\ht0>#1\@fancywarning 353 | {\string#1 is too small (\the#1): ^^J Make it at least \the\ht0.^^J 354 | We now make it that large for the rest of the document.^^J 355 | This may cause the page layout to be inconsistent, however\@gobble}% 356 | \dimen0=#1\global\setlength{#1}{\ht0}\ht0=\dimen0\fi 357 | \box0} 358 | 359 | % Put together a header or footer given the left, center and 360 | % right text, fillers at left and right and a rule. 361 | % The \lap commands put the text into an hbox of zero size, 362 | % so overlapping text does not generate an errormessage. 363 | % These macros have 5 parameters: 364 | % 1. LEFTSIDE BEARING % This determines at which side the header will stick 365 | % out. When \fancyhfoffset is used this calculates \headwidth, otherwise 366 | % it is \hss or \relax (after expansion). 367 | % 2. \f@ncyolh, \f@ncyelh, \f@ncyolf or \f@ncyelf. This is the left component. 368 | % 3. \f@ncyoch, \f@ncyech, \f@ncyocf or \f@ncyecf. This is the middle comp. 369 | % 4. \f@ncyorh, \f@ncyerh, \f@ncyorf or \f@ncyerf. This is the right component. 370 | % 5. RIGHTSIDE BEARING. This is always \relax or \hss (after expansion). 
% Header/footer assembly and the `fancy' page styles.
%
% \@fancyhead{lbear}{left}{center}{right}{rbear}
%   Emits <lbear>, then an \hbox to \headwidth. Inside it, after \fancy@reset,
%   \@fancyvbox checks against \headheight a box containing:
%     - three bottom-aligned ([b]) \parboxes of width \headwidth: <left>
%       (\raggedright, in \rlap), <center> (\centering) and <right>
%       (\raggedleft, in \llap);
%     - \headrule underneath.
%   Finally it emits <rbear>.
% \@fancyfoot{...}
%   Same layout, with these differences: top-aligned ([t]) \parboxes,
%   \footrule placed above the text, and the height checked against \footskip.
% \headrule / \footrule
%   Draw an \hrule of width \headwidth inside a group. When \if@fancyplain is
%   true, the \plainheadrulewidth / \plainfootrulewidth values are used.
%   The \vskip amounts cancel the rule thickness (and, for \footrule,
%   \footruleskip), so neither rule adds net vertical space.
% \ps@fancy (first call only)
%   - supplies \@chapapp and \MakeUppercase fallbacks;
%   - installs default mark macros: \sectionmark/\subsectionmark when
%     \chapter is undefined, otherwise \chaptermark/\sectionmark;
%   - runs \ps@@fancy;
%   - globally redefines itself as \@fancyplainfalse\ps@@fancy, so this setup
%     runs only once;
%   - fixes up \headwidth if it is still negative.
%   NOTE(review): the fix-up adds 123456789sp and then \textwidth. This
%   presumably undoes a -123456789sp sentinel assigned to \headwidth elsewhere
%   in this file. Confirm before changing either constant.
% \ps@fancyplain
%   Runs \ps@fancy and then points \ps@plain at \ps@plain@fancy
%   (\@fancyplaintrue followed by \ps@@fancy).
% \ps@@fancy
%   - runs the saved \ps@empty (\ps@@empty);
%   - defines \@mkboth as \protect\markboth;
%   - wires \@oddhead/\@oddfoot/\@evenhead/\@evenfoot to
%     \@fancyhead/\@fancyfoot with the matching odd/even component macros and
%     bearing macros.
% \fancy@Oolh / \fancy@Oorh (compatibility defaults)
%   \fancy@Oolh is \hss under \if@reversemargin and \relax otherwise;
%   \fancy@Oorh is the opposite. Even pages swap the two, and \fancy@Oolf
%   aliases \fancy@Oolh so footers shift like headers.
371 | 372 | \def\@fancyhead#1#2#3#4#5{#1\hbox to\headwidth{\fancy@reset 373 | \@fancyvbox\headheight{\hbox 374 | {\rlap{\parbox[b]{\headwidth}{\raggedright#2}}\hfill 375 | \parbox[b]{\headwidth}{\centering#3}\hfill 376 | \llap{\parbox[b]{\headwidth}{\raggedleft#4}}}\headrule}}#5} 377 | 378 | \def\@fancyfoot#1#2#3#4#5{#1\hbox to\headwidth{\fancy@reset 379 | \@fancyvbox\footskip{\footrule 380 | \hbox{\rlap{\parbox[t]{\headwidth}{\raggedright#2}}\hfill 381 | \parbox[t]{\headwidth}{\centering#3}\hfill 382 | \llap{\parbox[t]{\headwidth}{\raggedleft#4}}}}}#5} 383 | 384 | \def\headrule{{\if@fancyplain\let\headrulewidth\plainheadrulewidth\fi 385 | \hrule\@height\headrulewidth\@width\headwidth \vskip-\headrulewidth}} 386 | 387 | \def\footrule{{\if@fancyplain\let\footrulewidth\plainfootrulewidth\fi 388 | \vskip-\footruleskip\vskip-\footrulewidth 389 | \hrule\@width\headwidth\@height\footrulewidth\vskip\footruleskip}} 390 | 391 | \def\ps@fancy{% 392 | \@ifundefined{@chapapp}{\let\@chapapp\chaptername}{}%for amsbook 393 | % 394 | % Define \MakeUppercase for old LaTeXen. 395 | % Note: we used \def rather than \let, so that \let\uppercase\relax (from 396 | % the version 1 documentation) will still work. 397 | % 398 | \@ifundefined{MakeUppercase}{\def\MakeUppercase{\uppercase}}{}% 399 | \@ifundefined{chapter}{\def\sectionmark##1{\markboth 400 | {\MakeUppercase{\ifnum \c@secnumdepth>\z@ 401 | \thesection\hskip 1em\relax \fi ##1}}{}}% 402 | \def\subsectionmark##1{\markright {\ifnum \c@secnumdepth >\@ne 403 | \thesubsection\hskip 1em\relax \fi ##1}}}% 404 | {\def\chaptermark##1{\markboth {\MakeUppercase{\ifnum \c@secnumdepth>\m@ne 405 | \@chapapp\ \thechapter. \ \fi ##1}}{}}% 406 | \def\sectionmark##1{\markright{\MakeUppercase{\ifnum \c@secnumdepth >\z@ 407 | \thesection.
\ \fi ##1}}}}% 408 | %\csname ps@headings\endcsname % use \ps@headings defaults if they exist 409 | \ps@@fancy 410 | \gdef\ps@fancy{\@fancyplainfalse\ps@@fancy}% 411 | % Initialize \headwidth if the user didn't 412 | % 413 | \ifdim\headwidth<0sp 414 | % 415 | % This catches the case that \headwidth hasn't been initialized and the 416 | % case that the user added something to \headwidth in the expectation that 417 | % it was initialized to \textwidth. We compensate this now. This loses if 418 | % the user intended to multiply it by a factor. But that case is more 419 | % likely done by saying something like \headwidth=1.2\textwidth. 420 | % The doc says you have to change \headwidth after the first call to 421 | % \pagestyle{fancy}. This code is just to catch the most common cases where 422 | % that requirement is violated. 423 | % 424 | \global\advance\headwidth123456789sp\global\advance\headwidth\textwidth 425 | \fi} 426 | \def\ps@fancyplain{\ps@fancy \let\ps@plain\ps@plain@fancy} 427 | \def\ps@plain@fancy{\@fancyplaintrue\ps@@fancy} 428 | \let\ps@@empty\ps@empty 429 | \def\ps@@fancy{% 430 | \ps@@empty % This is for amsbook/amsart, which do strange things with \topskip 431 | \def\@mkboth{\protect\markboth}% 432 | \def\@oddhead{\@fancyhead\fancy@Oolh\f@ncyolh\f@ncyoch\f@ncyorh\fancy@Oorh}% 433 | \def\@oddfoot{\@fancyfoot\fancy@Oolf\f@ncyolf\f@ncyocf\f@ncyorf\fancy@Oorf}% 434 | \def\@evenhead{\@fancyhead\fancy@Oelh\f@ncyelh\f@ncyech\f@ncyerh\fancy@Oerh}% 435 | \def\@evenfoot{\@fancyfoot\fancy@Oelf\f@ncyelf\f@ncyecf\f@ncyerf\fancy@Oerf}% 436 | } 437 | % Default definitions for compatibility mode: 438 | % These cause the header/footer to take the defined \headwidth as width 439 | % And to shift in the direction of the marginpar area 440 | 441 | \def\fancy@Oolh{\if@reversemargin\hss\else\relax\fi} 442 | \def\fancy@Oorh{\if@reversemargin\relax\else\hss\fi} 443 | \let\fancy@Oelh\fancy@Oorh 444 | \let\fancy@Oerh\fancy@Oolh 445 | 446 | \let\fancy@Oolf\fancy@Oolh 447 |
% Remaining compatibility aliases, the \fancyhfoffset machinery, and
% user-level page tests.
%
% Footer bearings (\fancy@Oorf, \fancy@Oelf, \fancy@Oerf)
%   Aliased to the matching header bearings.
% \fancy@offsolh / \fancy@offselh / \fancy@offsolf / \fancy@offself
%   Set \headwidth = \textwidth + left offset + right offset, then
%   \hskip back by the left offset, which shifts the box left by that amount.
%   NOTE(review): the \f@ncyO@.. offsets are defined elsewhere, presumably
%   as lengths set by \fancyhfoffset. TODO confirm.
% \fancy@setoffs
%   Rebinds \headwidth to \fancy@headwidth, the left bearings to the offset
%   macros above, and the right bearings to \hss. Each \let is prefixed with
%   \fancy@gbl, which \fancypagestyle sets to \relax; presumably it is
%   \global by default (defined elsewhere).
% \@makecol wrapper
%   Records footnote state (\iffootnote) and the top/bottom float lists
%   (\topfloat, \botfloat), then calls LaTeX's original \@makecol.
%   NOTE(review): \footnotetrue is set when \footins is void, i.e. when the
%   column has NO footnote material. That looks inverted relative to the name
%   \iffootnote. Confirm against the fancyhdr documentation before relying on
%   it.
% \iftopfloat{yes}{no} / \ifbotfloat{yes}{no}
%   Expand to <yes> unless \@toplist / \@botlist was \empty at the last
%   \@makecol.
% \iffloatpage{yes}{no}
%   Expands to <yes> when \if@fcolmade is true.
% \fancypagestyle{name}{settings}
%   Defines \ps@<name>, which does the following in order: sets \fancy@gbl to
%   \relax (so \fancy@gbl-prefixed assignments are not global), runs
%   <settings>, then runs \ps@fancy.
\let\fancy@Oorf\fancy@Oorh 448 | \let\fancy@Oelf\fancy@Oelh 449 | \let\fancy@Oerf\fancy@Oerh 450 | 451 | % New definitions for the use of \fancyhfoffset 452 | % These calculate the \headwidth from \textwidth and the specified offsets. 453 | 454 | \def\fancy@offsolh{\headwidth=\textwidth\advance\headwidth\f@ncyO@olh 455 | \advance\headwidth\f@ncyO@orh\hskip-\f@ncyO@olh} 456 | \def\fancy@offselh{\headwidth=\textwidth\advance\headwidth\f@ncyO@elh 457 | \advance\headwidth\f@ncyO@erh\hskip-\f@ncyO@elh} 458 | 459 | \def\fancy@offsolf{\headwidth=\textwidth\advance\headwidth\f@ncyO@olf 460 | \advance\headwidth\f@ncyO@orf\hskip-\f@ncyO@olf} 461 | \def\fancy@offself{\headwidth=\textwidth\advance\headwidth\f@ncyO@elf 462 | \advance\headwidth\f@ncyO@erf\hskip-\f@ncyO@elf} 463 | 464 | \def\fancy@setoffs{% 465 | % Just in case \let\headwidth\textwidth was used 466 | \fancy@gbl\let\headwidth\fancy@headwidth 467 | \fancy@gbl\let\fancy@Oolh\fancy@offsolh 468 | \fancy@gbl\let\fancy@Oelh\fancy@offselh 469 | \fancy@gbl\let\fancy@Oorh\hss 470 | \fancy@gbl\let\fancy@Oerh\hss 471 | \fancy@gbl\let\fancy@Oolf\fancy@offsolf 472 | \fancy@gbl\let\fancy@Oelf\fancy@offself 473 | \fancy@gbl\let\fancy@Oorf\hss 474 | \fancy@gbl\let\fancy@Oerf\hss} 475 | 476 | \newif\iffootnote 477 | \let\latex@makecol\@makecol 478 | \def\@makecol{\ifvoid\footins\footnotetrue\else\footnotefalse\fi 479 | \let\topfloat\@toplist\let\botfloat\@botlist\latex@makecol} 480 | \def\iftopfloat#1#2{\ifx\topfloat\empty #2\else #1\fi} 481 | \def\ifbotfloat#1#2{\ifx\botfloat\empty #2\else #1\fi} 482 | \def\iffloatpage#1#2{\if@fcolmade #1\else #2\fi} 483 | 484 | \newcommand{\fancypagestyle}[2]{% 485 | \@namedef{ps@#1}{\let\fancy@gbl\relax#2\relax\ps@fancy}} 486 | --------------------------------------------------------------------------------