├── images ├── img.jpg └── reconstructed.png ├── environment.yml ├── bvae ├── model_utils.py ├── ae.py ├── sample_layer.py └── models.py ├── LICENSE ├── .gitignore └── README.md /images/img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alecGraves/BVAE-tf/HEAD/images/img.jpg -------------------------------------------------------------------------------- /images/reconstructed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alecGraves/BVAE-tf/HEAD/images/reconstructed.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: bvae-tf 2 | channels: 3 | - defaults 4 | dependencies: 5 | - certifi=2018.1.18=py35_0 6 | - pip=9.0.1=py35_5 7 | - python=3.5.4=h1357f44_23 8 | - setuptools=39.0.1=py35_0 9 | - vc=14=h0510ff6_3 10 | - vs2015_runtime=14.0.25123=3 11 | - wheel=0.30.0=py35h38a90bc_1 12 | - wincertstore=0.2=py35hfebbdb8_0 13 | - pip: 14 | - bleach==1.5.0 15 | - enum34==1.1.6 16 | - html5lib==0.9999999 17 | - markdown==2.6.11 18 | - numpy==1.14.2 19 | - pillow==5.1.0 20 | - protobuf==3.5.2.post1 21 | - six==1.11.0 22 | - tensorflow-gpu==1.4.0 23 | - tensorflow-tensorboard==0.4.0 24 | - werkzeug==0.14.1 25 | -------------------------------------------------------------------------------- /bvae/model_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | model_utils.py 3 | contains custom blocks, etc. for building mdoels. 
# model_utils.py module docstring, continued:
#
# created by shadySource
#
# THE UNLICENSE
# '''
import tensorflow as tf
from tensorflow.python.keras.layers import (InputLayer, Conv2D, Conv2DTranspose,
        BatchNormalization, LeakyReLU, MaxPool2D, UpSampling2D,
        Reshape, GlobalAveragePooling2D, Layer)


class ConvBnLRelu(object):
    '''
    Callable that stacks Conv2D -> BatchNormalization -> LeakyReLU.

    The constructor only records the hyper-parameters. The Keras layers are
    instantiated inside __call__, so every call on the same instance creates
    a fresh set of layers.

    params:
    ---------
    filters : int
        number of convolution filters, forwarded to Conv2D
    kernelSize : int or tuple
        convolution kernel size, forwarded to Conv2D
    strides : int or tuple
        convolution strides, forwarded to Conv2D (default 1)
    '''
    def __init__(self, filters, kernelSize, strides=1):
        self.filters, self.kernelSize, self.strides = filters, kernelSize, strides

    def __call__(self, net, training=None):
        '''
        Applies conv + bn + leaky_relu to `net` and returns the result.
        `training` is forwarded to the BatchNormalization layer only.
        '''
        convolution = Conv2D(self.filters, self.kernelSize,
                             strides=self.strides, padding='same')
        convolved = convolution(net)

        batchNorm = BatchNormalization()
        normalized = batchNorm(convolved, training=training)

        activation = LeakyReLU()
        return activation(normalized)


# --------------------------------------------------------------------------------
# /LICENSE:
# --------------------------------------------------------------------------------
# This is free and unencumbered software released into the public domain.
#
# Anyone is free to copy, modify, publish, use, compile, sell, or
# distribute this software, either in source code form or as a compiled
# binary, for any purpose, commercial or non-commercial, and by any
# means.
#
# In jurisdictions that recognize copyright laws, the author or authors
# of this software dedicate any and all copyright interest in the
# software to the public domain. We make this dedication for the benefit
# of the public at large and to the detriment of our heirs and
# successors. We intend this dedication to be an overt act of
# relinquishment in perpetuity of all present and future rights to this
# software under copyright law.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to <https://unlicense.org> 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BVAE-tf 2 | Disentangled Variational Auto-Encoder in TensorFlow (Beta-VAE) 3 | #### :star: :boom: THE UNLICENSE :boom: :star: 4 | ## Example Reconstructed Image 5 | ![alt text](https://github.com/shadySource/BVAE-tf/raw/master/images/reconstructed.png) 6 | 7 | ## What has been done 8 | * darknet 19 (fully convolutional & fast) encoder and decoder 9 | * Custom keras sampling layer for sampling the distribution of variational autoencoders 10 | * Custom loss in sampling layer for latent space regularization 11 | * Options are no reg, vae reg (kl divergence), or bvae reg (beta*kl-divergence) 12 | * You can also set a target capacity for dimension usage of the latent space 13 | * Simple interface for setting up your own VAE or B-VAE 14 
| * See [the test function in ae.py](https://github.com/shadySource/BVAE-tf/blob/master/bvae/ae.py#L22) for usage information 15 | 16 | ## Environment Setup 17 | I am using conda to ensure the environment is easy to install 18 | 19 | 0. Install [Anaconda](https://www.anaconda.com/download/) or 20 | [Miniconda](https://conda.io/miniconda.html) (the python 3 version) for your platform 21 | 1. Recreate the conda environment from the yml: 22 | ``` conda env create -f environment.yml ``` 23 | 2. Activate the environment 24 | 1. Windows: go to cmd and ```activate bvae-tf``` 25 | 2. Linux: ```source activate bvae-tf``` 26 | 3. If you want to use CPU only, run ```pip uninstall tensorflow-gpu``` 27 | followed by ```pip install tensorflow==1.4.0``` after you activate the environment. 28 | 29 | If you do not want to / cannot use conda, I am using tensorflow 1.4.0; see the environment.yml for more package info. 30 | 31 | ## Demo 32 | For a simple overfitting demonstration, run ```ae.py``` in your terminal. This will cause the autoencoder to run on the included demo image. 33 | 34 | Note: The demo takes a few minutes on my 1060 6GB, so it will take a while on a CPU... 35 | -------------------------------------------------------------------------------- /bvae/ae.py: -------------------------------------------------------------------------------- 1 | ''' 2 | vae.py 3 | contains the setup for autoencoders. 
# ae.py module docstring, continued:
#
# created by shadySource
#
# THE UNLICENSE
# '''
import tensorflow as tf
from tensorflow.python.keras.models import Model
from tensorflow.python.keras import backend as K


class AutoEncoder(object):
    '''
    Joins an encoder architecture and a decoder architecture into a single
    Keras model.

    attributes:
    ---------
    encoder : the `model` attribute of encoderArchitecture
    decoder : the `model` attribute of decoderArchitecture
    ae : Model from the encoder inputs to the decoder applied to the
        encoder outputs
    '''
    def __init__(self, encoderArchitecture,
                 decoderArchitecture):
        encoderModel = encoderArchitecture.model
        decoderModel = decoderArchitecture.model

        self.encoder = encoderModel
        self.decoder = decoderModel

        # feed the encoder outputs straight into the decoder
        reconstruction = decoderModel(encoderModel.outputs)
        self.ae = Model(encoderModel.inputs, reconstruction)


def test():
    '''
    Overfitting demo on images/img.jpg.

    Builds a Darknet19 B-VAE, then loops forever: fit for 100 epochs, print
    the first latent vector, and show the first reconstructed image.
    The loop has no exit condition; interrupt the process to stop it.
    '''
    import os
    import numpy as np
    from PIL import Image
    from tensorflow.python.keras.preprocessing.image import load_img

    from models import Darknet19Encoder, Darknet19Decoder

    inputShape = (256, 256, 3)
    batchSize = 8
    latentSize = 100

    imagePath = os.path.join(os.path.dirname(__file__), '..','images', 'img.jpg')
    picture = load_img(imagePath, target_size=inputShape[:-1])
    picture.show()

    # rescale pixel values from [0, 255] to [-1, 1]
    scaled = np.array(picture, dtype=np.float32) * (2/255) - 1

    # make fake batches to improve GPU utilization
    batch = np.array([scaled]*batchSize)

    # This is how you build the autoencoder
    encoder = Darknet19Encoder(inputShape, latentSize=latentSize, latentConstraints='bvae', beta=69)
    decoder = Darknet19Decoder(inputShape, latentSize=latentSize)
    bvae = AutoEncoder(encoder, decoder)
    bvae.ae.compile(optimizer='adam', loss='mean_absolute_error')

    while True:
        bvae.ae.fit(batch, batch,
                    epochs=100,
                    batch_size=batchSize)

        # example retrieving the latent vector
        latentVec = bvae.encoder.predict(batch)[0]
        print(latentVec)

        # get the reconstructed image and convert to regular image values
        reconstructed = bvae.ae.predict(batch)
        pixels = np.uint8((reconstructed + 1)* 255/2)

        # display popup with the first image of the batch
        Image.fromarray(pixels[0]).show()


if __name__ == "__main__":
    test()
# --------------------------------------------------------------------------------
# /bvae/sample_layer.py:
# --------------------------------------------------------------------------------
'''
sample_layer.py
contains keras SampleLayer for bvae

created by shadySource

THE UNLICENSE
'''

from tensorflow.python.keras.layers import Layer
from tensorflow.python.keras import backend as K


class SampleLayer(Layer):
    '''
    Keras Layer to grab a random sample from a distribution
    (reparameterization trick).
    Computes "mean + exp(logvar/2) * normal" for the vae sampling operation
    (written for tf backend)

    Additionally,
    Applies regularization to the latent space representation.
    Can perform standard regularization or B-VAE regularization.

    call:
        pass in mean then logvar layers to sample from the distribution
        ex.
            sample = SampleLayer('bvae', 16)([mean, logvar])
    '''
    def __init__(self, latent_regularizer='bvae', beta=100., **kwargs):
        '''
        args:
        ------
        latent_regularizer : str
            Either 'bvae', 'vae', or 'no' (case-insensitive)
            Determines whether regularization is applied
                to the latent space representation.
            Any value other than 'bvae' / 'vae' disables regularization.
        beta : float
            beta > 1, used for 'bvae' latent_regularizer,
            (Unused if 'bvae' not selected; 'vae' always uses 1.)
        ------
        ex.
            sample = SampleLayer('bvae', 16)([mean, logvar])
        '''
        # Store the lower-cased name: the comparisons below (and in call)
        # are against lower-case literals, so keeping the caller's casing
        # would let e.g. 'BVAE' pass the membership test yet never set beta.
        regularizer = latent_regularizer.lower()
        if regularizer in ('bvae', 'vae'):
            self.reg = regularizer
        else:
            self.reg = None

        if self.reg == 'bvae':
            self.beta = beta
        elif self.reg == 'vae':
            self.beta = 1.

        super(SampleLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # nothing to create here; only defer to the base class
        super(SampleLayer, self).build(input_shape) # needed for layers

    def call(self, x, training=None):
        '''
        args:
        ------
        x : list
            [mean, logvar], each of rank 2: [batchSize, latentSize]
        training : Bool or None
            forwarded to K.in_train_phase

        returns:
            a sample "mean + exp(logvar/2) * epsilon" in the train phase,
            the mean otherwise

        raises:
            Exception if x does not hold exactly two rank-2 tensors
        '''
        if len(x) != 2:
            raise Exception('input layers must be a list: mean and logvar')
        if len(x[0].shape) != 2 or len(x[1].shape) != 2:
            raise Exception('input shape is not a vector [batchSize, latentSize]')

        mean = x[0]
        logvar = x[1]

        # trick to allow setting batch at train/eval time
        # NOTE(review): when the batch dimension is unknown this returns
        # before the latent loss is added and before any sampling happens,
        # so the layer passes the mean through unchanged in that case.
        if mean.shape[0].value is None or logvar.shape[0].value is None:
            return mean + 0*logvar # Keras needs the *0 so the gradient is not None

        if self.reg is not None:
            # kl divergence:
            latent_loss = -0.5 * (1 + logvar
                                - K.square(mean)
                                - K.exp(logvar))
            latent_loss = K.sum(latent_loss, axis=-1) # sum over latent dimension
            latent_loss = K.mean(latent_loss, axis=0) # avg over batch

            # use beta to force less usage of vector space:
            latent_loss = self.beta * latent_loss
            self.add_loss(latent_loss, x)

        def reparameterization_trick():
            epsilon = K.random_normal(shape=logvar.shape,
                              mean=0., stddev=1.)
            stddev = K.exp(logvar*0.5)
            return mean + stddev * epsilon

        return K.in_train_phase(reparameterization_trick, mean + 0*logvar, training=training) # TODO figure out why this is not working in the specified tf version???

    def compute_output_shape(self, input_shape):
        # output has the shape of the mean input
        return input_shape[0]


# --------------------------------------------------------------------------------
# /bvae/models.py:
# --------------------------------------------------------------------------------
# '''
# models.py
# contains models for use with the BVAE experiments.
# models.py module docstring, continued:
#
# created by shadySource
#
# THE UNLICENSE
# '''
import tensorflow as tf
from tensorflow.python.keras import Input
from tensorflow.python.keras.layers import (InputLayer, Conv2D, Conv2DTranspose,
        BatchNormalization, LeakyReLU, MaxPool2D, UpSampling2D,
        Reshape, GlobalAveragePooling2D)
from tensorflow.python.keras.models import Model

from model_utils import ConvBnLRelu
from sample_layer import SampleLayer


class Architecture(object):
    '''
    generic architecture template
    '''
    def __init__(self, inputShape=None, batchSize=None, latentSize=None):
        '''
        Stores the parameters, then immediately calls self.Build() and keeps
        its result in self.model. Subclasses must therefore set every
        attribute their Build reads before calling this constructor.

        params:
        ---------
        inputShape : tuple
            the shape of the input, expecting 3-dim images (h, w, 3)
        batchSize : int
            the number of samples in a batch
        latentSize : int
            the number of dimensions in the two output distribution vectors -
            mean and std-deviation
        '''
        self.inputShape = inputShape
        self.batchSize = batchSize
        self.latentSize = latentSize

        self.model = self.Build()

    def Build(self):
        raise NotImplementedError('architecture must implement Build function')


class Darknet19Encoder(Architecture):
    '''
    This encoder predicts distributions then randomly samples them.
    Regularization may be applied to the latent space output

    a simple, fully convolutional architecture inspired by
        pjreddie's darknet architecture
    https://github.com/pjreddie/darknet/blob/master/cfg/darknet19.cfg
    '''
    # (filters, kernelSize) of the 18 conv-bn-lrelu layers, grouped in stages.
    # A 2x2 max-pool follows every stage except the last one.
    _STAGES = (
        ((32, 3),),                                                  # 1
        ((64, 3),),                                                  # 2
        ((128, 3), (64, 1), (128, 3)),                               # 3-5
        ((256, 3), (128, 1), (256, 3)),                              # 6-8
        ((512, 3), (256, 1), (512, 3), (256, 1), (512, 3)),          # 9-13
        ((1024, 3), (512, 1), (1024, 3), (512, 1), (1024, 3)),       # 14-18
    )

    def __init__(self, inputShape=(256, 256, 3), batchSize=None,
                 latentSize=1000, latentConstraints='bvae', beta=100., training=None):
        '''
        params
        -------
        latentConstraints : str
            Either 'bvae', 'vae', or 'no'
            Determines whether regularization is applied
                to the latent space representation.
        beta : float
            beta > 1, used for 'bvae' latent_regularizer
            (Unused if 'bvae' not selected, default 100)
        training : Bool or None
            True forces resampling, False forces no resampling,
            None chooses based on K.learning_phase()
        '''
        # these must exist before the base constructor runs, because it calls Build()
        self.latentConstraints = latentConstraints
        self.beta = beta
        self.training=training
        super().__init__(inputShape, batchSize, latentSize)

    def Build(self):
        # create the input layer for feeding the network
        inLayer = Input(self.inputShape, self.batchSize)

        net = inLayer
        lastStage = len(self._STAGES) - 1
        for index, stage in enumerate(self._STAGES):
            for filters, kernelSize in stage:
                net = ConvBnLRelu(filters, kernelSize=kernelSize)(net, training=self.training)
            if index != lastStage:
                net = MaxPool2D((2, 2), strides=(2, 2))(net)

        # variational encoder output (distributions)
        mean = Conv2D(filters=self.latentSize, kernel_size=(1, 1),
                      padding='same')(net)
        mean = GlobalAveragePooling2D()(mean)
        logvar = Conv2D(filters=self.latentSize, kernel_size=(1, 1),
                        padding='same')(net)
        logvar = GlobalAveragePooling2D()(logvar)

        sample = SampleLayer(self.latentConstraints, self.beta)([mean, logvar], training=self.training)

        return Model(inputs=inLayer, outputs=sample)


class Darknet19Decoder(Architecture):
    '''
    Decoder counterpart of Darknet19Encoder: expands a flat latent vector
    back to an image of shape inputShape with a tanh output.
    '''
    # TODO try inverting num filter arrangement (e.g. 512, 1024, 512, 1024, 512)
    #      and also try (1, 3, 1, 3, 1) for the filter shape
    # (filters, kernelSize) of the conv-bn-lrelu layers, grouped in stages.
    # A 2x upsampling precedes every stage except the first one.
    _STAGES = (
        ((1024, 3), (512, 1), (1024, 3), (512, 1), (1024, 3)),
        ((512, 3), (256, 1), (512, 3), (256, 1), (512, 3)),
        ((256, 3), (128, 1), (256, 3)),
        ((128, 3), (64, 1), (128, 3)),
        ((64, 3),),
        ((32, 3), (64, 1)),
    )

    def __init__(self, inputShape=(256, 256, 3), batchSize=None, latentSize=1000, training=None):
        '''
        params
        -------
        training : Bool or None
            forwarded to every ConvBnLRelu block built by Build
        (inputShape, batchSize and latentSize are stored by Architecture)
        '''
        # must exist before the base constructor runs, because it calls Build()
        self.training=training
        super().__init__(inputShape, batchSize, latentSize)

    def Build(self):
        # input layer is from GlobalAveragePooling:
        inLayer = Input([self.latentSize], self.batchSize)
        # reexpand the input from flat:
        net = Reshape((1, 1, self.latentSize))(inLayer)
        # darknet downscales input by a factor of 32, so we upsample to the second to last output shape:
        # (integer division: the height and width are rounded down to a multiple of 32)
        net = UpSampling2D((self.inputShape[0]//32, self.inputShape[1]//32))(net)

        for index, stage in enumerate(self._STAGES):
            if index != 0:
                net = UpSampling2D((2, 2))(net)
            for filters, kernelSize in stage:
                net = ConvBnLRelu(filters, kernelSize=kernelSize)(net, training=self.training)

        # plain convolution (no batch norm) with tanh as the output layer
        net = Conv2D(filters=self.inputShape[-1], kernel_size=(1, 1),
                      padding='same', activation="tanh")(net)

        return Model(inLayer, net)


class Darknet53Encoder(Architecture):
    '''
    a larger, fully convolutional architecture inspired by
        pjreddie's darknet architecture
    https://github.com/pjreddie/darknet/blob/master/cfg/darknet19.cfg
    https://github.com/pjreddie/darknet/blob/master/cfg/yolov3.cfg
    https://pjreddie.com/media/files/papers/YOLOv3.pdf

    Not implemented yet: constructing it raises NotImplementedError.
    '''
    def __init__(self, inputShape=(None, None, None, None), name='darkent53_encoder'):
        '''
        input shape for the network and a name for the scope.
        '''
        # keep the name on the instance instead of passing it positionally,
        # where it would land in Architecture's batchSize parameter
        self.name = name
        # the base constructor already calls Build(), so no extra call is needed
        super().__init__(inputShape)

    def Build(self):
        '''
        builds darknet53 encoder network
        '''
        raise NotImplementedError('this architecture is not complete')

    def ConvBlock(self):
        '''
        adds a darknet conv block to the net
        '''
        raise NotImplementedError('this architecture is not complete')


def test():
    '''
    builds the Darknet19 encoder and decoder with default arguments and
    prints their summaries
    '''
    d19e = Darknet19Encoder()
    d19e.model.summary()
    d19d = Darknet19Decoder()
    d19d.model.summary()


if __name__ == '__main__':
    test()

# --------------------------------------------------------------------------------