├── Figures ├── figure_1.PNG ├── figure_2.PNG ├── figure_4.PNG └── figure_5.PNG ├── LICENSE ├── README.md ├── VGG16.py ├── dfc_vae_model.py ├── outputs ├── test │ ├── gen-1.png │ ├── random-2.png │ └── real-1.png └── test_interpolated │ ├── interpolate0.gif │ └── interpolate1.gif ├── train_dfc_vae.py └── util.py /Figures/figure_1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbavon/Deep-Feature-Consistent-Variational-AutoEncoder-in-Tensorflow/8db93292f5b4e99fab88f8708f2469bafd253822/Figures/figure_1.PNG -------------------------------------------------------------------------------- /Figures/figure_2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbavon/Deep-Feature-Consistent-Variational-AutoEncoder-in-Tensorflow/8db93292f5b4e99fab88f8708f2469bafd253822/Figures/figure_2.PNG -------------------------------------------------------------------------------- /Figures/figure_4.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbavon/Deep-Feature-Consistent-Variational-AutoEncoder-in-Tensorflow/8db93292f5b4e99fab88f8708f2469bafd253822/Figures/figure_4.PNG -------------------------------------------------------------------------------- /Figures/figure_5.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbavon/Deep-Feature-Consistent-Variational-AutoEncoder-in-Tensorflow/8db93292f5b4e99fab88f8708f2469bafd253822/Figures/figure_5.PNG -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 | 
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 | 
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 | 
176 | END OF TERMS AND CONDITIONS
177 | 
178 | APPENDIX: How to apply the Apache License to your work.
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Feature Consistent Variational Autoencoder in Tensorflow
2 | 
3 | This repository implements the Deep Feature Consistent Variational Autoencoder (DFC-VAE) described in [Deep Feature Consistent Variational Autoencoder](https://arxiv.org/abs/1610.00291).
4 | TensorFlow and Python 3 are used for development, and the pre-trained VGG16 is adapted from [VGG in TensorFlow](https://www.cs.toronto.edu/~frossard/post/vgg16/). The training data is the [CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).
5 | 
6 | To follow this note, I recommend being familiar with the concepts of Variational Autoencoders and generative models.
7 | 
8 | ## Results
9 | 
10 | ![Generated Image](Figures/figure_1.PNG)
11 | ![Random Image](Figures/figure_2.PNG)
12 | ![Interpolated Image](outputs/test_interpolated/interpolate1.gif)
13 | 
14 | Figure 3: Interpolated image
15 | 
16 | ## Problem Statement
17 | 
18 | One well-known problem of the plain Variational Autoencoder (plain VAE) is that the images it generates are blurry.
19 | This is because the plain model's loss function is a pixel-wise comparison between the input images and the generated images.
20 | As a consequence, it is hard to optimize the model to a good result, because slightly shifting or distorting an image can produce a very high loss. In other words, two images that look nearly identical to the human eye can still be treated as very different by the computer!
21 | 
22 | ![distorted image](Figures/figure_4.PNG)
23 | 
24 | The DFC-VAE, however, leverages the perceptual loss used in [Neural Style Transfer](https://github.com/sbavon/Neural-Style-Transfer-in-Tensorflow).
25 | According to [this paper](https://arxiv.org/abs/1508.06576), the internal representations of a convolutional neural network capture the content of the input image. This finding leads to the concept of perceptual loss, which compares the content (hidden representations) of two images instead of computing the Euclidean distance between their pixels.
26 | 
27 | ![model architecture](Figures/figure_5.PNG)
28 | 
29 | ## Implementation
30 | 
31 | The solution contains four files:
32 | 
33 | | File Name | Description |
34 | | ------------- | ------------- |
35 | | dfc_vae_model.py | builds the VAE model, including the encoder, decoder, VGG network, loss function, and optimizer |
36 | | train_dfc_vae.py | trains the DFC-VAE model and tests interpolation |
37 | | VGG16.py | builds the pre-trained VGG16 model |
38 | | util.py | contains supporting functions, such as data preprocessing |
39 | 
40 | ### Step-by-step execution
41 | 
42 | #### Download and preprocess data
43 | 1. Download the pre-trained VGG weights from [VGG in TensorFlow](https://www.cs.toronto.edu/~frossard/post/vgg16/)
44 | 2. Download the CelebA dataset from [CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
45 | 3. Keep the images compressed in a Zip archive (the preprocessing reads them directly from the Zip file)
46 | 4. Process the images (crop and resize) and convert them to TFRecord format (refer to the *write_tfrecord()* function in *util.py*; a short usage sketch follows the *util.py* source below)
47 | 
48 | #### Train the model
49 | 5. Run train_dfc_vae.py
50 | 
51 | ## Dependencies
52 | - scipy.misc
53 | - zipfile (used for reading content inside the Zip file)
54 | - imageio (used for generating *.gif* files)
55 | 
56 | ## Tips
57 | - The beta value is extremely significant.
You need to adjust the value to make sure the model produce a great result 58 | - Save file in *.png* format for a better quality image 59 | 60 | ## References 61 | - [Deep Feature Consistent Variational Autoencoder](https://arxiv.org/abs/1610.00291) 62 | - [A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576) 63 | - [VGG in TensorFlow](https://www.cs.toronto.edu/~frossard/post/vgg16/) 64 | 65 | -------------------------------------------------------------------------------- /VGG16.py: -------------------------------------------------------------------------------- 1 | ######################################################################################## 2 | # Davi Frossard, 2016 # 3 | # VGG16 implementation in TensorFlow # 4 | # Details: # 5 | # http://www.cs.toronto.edu/~frossard/post/vgg16/ # 6 | # # 7 | # Model from https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-readme-md # 8 | # Weights from Caffe converted using https://github.com/ethereon/caffe-tensorflow # 9 | ######################################################################################## 10 | 11 | import tensorflow as tf 12 | import numpy as np 13 | from scipy.misc import imread, imresize 14 | 15 | 16 | class vgg16: 17 | def __init__(self, imgs, weights=None, sess=None): 18 | self.imgs = imgs 19 | self.convlayers() 20 | self.fc_layers() 21 | self.probs = tf.nn.softmax(self.fc3l) 22 | #if weights is not None and sess is not None: 23 | # self.load_weights(weights, sess) 24 | ### add this function to return internal representations 25 | 26 | 27 | def get_layers(self): 28 | return self.conv1_1, self.conv2_1, self.conv3_1 29 | 30 | def convlayers(self): 31 | self.parameters = [] 32 | 33 | # zero-mean input 34 | #with tf.name_scope('preprocess') as scope: 35 | # mean = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean') 36 | # images = self.imgs-mean 37 | images = self.imgs 38 | 39 | # conv1_1 40 | with tf.name_scope('conv1_1') as scope: 41 | kernel = tf.Variable(tf.truncated_normal([3, 3, 3, 64], dtype=tf.float32, 42 | stddev=1e-1), name='weights') 43 | conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME') 44 | biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32), 45 | trainable=True, name='biases') 46 | out = tf.nn.bias_add(conv, biases) 47 | self.conv1_1 = tf.nn.relu(out, name=scope) 48 | self.parameters += [kernel, biases] 49 | 50 | # conv1_2 51 | with tf.name_scope('conv1_2') as scope: 52 | kernel = tf.Variable(tf.truncated_normal([3, 3, 64, 64], dtype=tf.float32, 53 | stddev=1e-1), name='weights') 54 | conv = tf.nn.conv2d(self.conv1_1, kernel, [1, 1, 1, 1], padding='SAME') 55 | biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32), 56 | trainable=True, name='biases') 57 | out = tf.nn.bias_add(conv, biases) 58 | self.conv1_2 = tf.nn.relu(out, name=scope) 59 | self.parameters += [kernel, biases] 60 | 61 | # pool1 62 | self.pool1 = tf.nn.max_pool(self.conv1_2, 63 | ksize=[1, 2, 2, 1], 64 | strides=[1, 2, 2, 1], 65 | padding='SAME', 66 | name='pool1') 67 | 68 | # conv2_1 69 | with tf.name_scope('conv2_1') as scope: 70 | kernel = tf.Variable(tf.truncated_normal([3, 3, 64, 128], dtype=tf.float32, 71 | stddev=1e-1), name='weights') 72 | conv = tf.nn.conv2d(self.pool1, kernel, [1, 1, 1, 1], padding='SAME') 73 | biases = tf.Variable(tf.constant(0.0, shape=[128], dtype=tf.float32), 74 | trainable=True, name='biases') 75 | out = tf.nn.bias_add(conv, biases) 76 | self.conv2_1 = tf.nn.relu(out, name=scope) 77 | 
self.parameters += [kernel, biases] 78 | 79 | # conv2_2 80 | with tf.name_scope('conv2_2') as scope: 81 | kernel = tf.Variable(tf.truncated_normal([3, 3, 128, 128], dtype=tf.float32, 82 | stddev=1e-1), name='weights') 83 | conv = tf.nn.conv2d(self.conv2_1, kernel, [1, 1, 1, 1], padding='SAME') 84 | biases = tf.Variable(tf.constant(0.0, shape=[128], dtype=tf.float32), 85 | trainable=True, name='biases') 86 | out = tf.nn.bias_add(conv, biases) 87 | self.conv2_2 = tf.nn.relu(out, name=scope) 88 | self.parameters += [kernel, biases] 89 | 90 | # pool2 91 | self.pool2 = tf.nn.max_pool(self.conv2_2, 92 | ksize=[1, 2, 2, 1], 93 | strides=[1, 2, 2, 1], 94 | padding='SAME', 95 | name='pool2') 96 | 97 | # conv3_1 98 | with tf.name_scope('conv3_1') as scope: 99 | kernel = tf.Variable(tf.truncated_normal([3, 3, 128, 256], dtype=tf.float32, 100 | stddev=1e-1), name='weights') 101 | conv = tf.nn.conv2d(self.pool2, kernel, [1, 1, 1, 1], padding='SAME') 102 | biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32), 103 | trainable=True, name='biases') 104 | out = tf.nn.bias_add(conv, biases) 105 | self.conv3_1 = tf.nn.relu(out, name=scope) 106 | self.parameters += [kernel, biases] 107 | 108 | # conv3_2 109 | with tf.name_scope('conv3_2') as scope: 110 | kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype=tf.float32, 111 | stddev=1e-1), name='weights') 112 | conv = tf.nn.conv2d(self.conv3_1, kernel, [1, 1, 1, 1], padding='SAME') 113 | biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32), 114 | trainable=True, name='biases') 115 | out = tf.nn.bias_add(conv, biases) 116 | self.conv3_2 = tf.nn.relu(out, name=scope) 117 | self.parameters += [kernel, biases] 118 | 119 | # conv3_3 120 | with tf.name_scope('conv3_3') as scope: 121 | kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype=tf.float32, 122 | stddev=1e-1), name='weights') 123 | conv = tf.nn.conv2d(self.conv3_2, kernel, [1, 1, 1, 1], padding='SAME') 124 | biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32), 125 | trainable=True, name='biases') 126 | out = tf.nn.bias_add(conv, biases) 127 | self.conv3_3 = tf.nn.relu(out, name=scope) 128 | self.parameters += [kernel, biases] 129 | 130 | # pool3 131 | self.pool3 = tf.nn.max_pool(self.conv3_3, 132 | ksize=[1, 2, 2, 1], 133 | strides=[1, 2, 2, 1], 134 | padding='SAME', 135 | name='pool3') 136 | 137 | # conv4_1 138 | with tf.name_scope('conv4_1') as scope: 139 | kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 512], dtype=tf.float32, 140 | stddev=1e-1), name='weights') 141 | conv = tf.nn.conv2d(self.pool3, kernel, [1, 1, 1, 1], padding='SAME') 142 | biases = tf.Variable(tf.constant(0.0, shape=[512], dtype=tf.float32), 143 | trainable=True, name='biases') 144 | out = tf.nn.bias_add(conv, biases) 145 | self.conv4_1 = tf.nn.relu(out, name=scope) 146 | self.parameters += [kernel, biases] 147 | 148 | # conv4_2 149 | with tf.name_scope('conv4_2') as scope: 150 | kernel = tf.Variable(tf.truncated_normal([3, 3, 512, 512], dtype=tf.float32, 151 | stddev=1e-1), name='weights') 152 | conv = tf.nn.conv2d(self.conv4_1, kernel, [1, 1, 1, 1], padding='SAME') 153 | biases = tf.Variable(tf.constant(0.0, shape=[512], dtype=tf.float32), 154 | trainable=True, name='biases') 155 | out = tf.nn.bias_add(conv, biases) 156 | self.conv4_2 = tf.nn.relu(out, name=scope) 157 | self.parameters += [kernel, biases] 158 | 159 | # conv4_3 160 | with tf.name_scope('conv4_3') as scope: 161 | kernel = tf.Variable(tf.truncated_normal([3, 3, 512, 512], dtype=tf.float32, 162 | 
stddev=1e-1), name='weights') 163 | conv = tf.nn.conv2d(self.conv4_2, kernel, [1, 1, 1, 1], padding='SAME') 164 | biases = tf.Variable(tf.constant(0.0, shape=[512], dtype=tf.float32), 165 | trainable=True, name='biases') 166 | out = tf.nn.bias_add(conv, biases) 167 | self.conv4_3 = tf.nn.relu(out, name=scope) 168 | self.parameters += [kernel, biases] 169 | 170 | # pool4 171 | self.pool4 = tf.nn.max_pool(self.conv4_3, 172 | ksize=[1, 2, 2, 1], 173 | strides=[1, 2, 2, 1], 174 | padding='SAME', 175 | name='pool4') 176 | 177 | # conv5_1 178 | with tf.name_scope('conv5_1') as scope: 179 | kernel = tf.Variable(tf.truncated_normal([3, 3, 512, 512], dtype=tf.float32, 180 | stddev=1e-1), name='weights') 181 | conv = tf.nn.conv2d(self.pool4, kernel, [1, 1, 1, 1], padding='SAME') 182 | biases = tf.Variable(tf.constant(0.0, shape=[512], dtype=tf.float32), 183 | trainable=True, name='biases') 184 | out = tf.nn.bias_add(conv, biases) 185 | self.conv5_1 = tf.nn.relu(out, name=scope) 186 | self.parameters += [kernel, biases] 187 | 188 | # conv5_2 189 | with tf.name_scope('conv5_2') as scope: 190 | kernel = tf.Variable(tf.truncated_normal([3, 3, 512, 512], dtype=tf.float32, 191 | stddev=1e-1), name='weights') 192 | conv = tf.nn.conv2d(self.conv5_1, kernel, [1, 1, 1, 1], padding='SAME') 193 | biases = tf.Variable(tf.constant(0.0, shape=[512], dtype=tf.float32), 194 | trainable=True, name='biases') 195 | out = tf.nn.bias_add(conv, biases) 196 | self.conv5_2 = tf.nn.relu(out, name=scope) 197 | self.parameters += [kernel, biases] 198 | 199 | # conv5_3 200 | with tf.name_scope('conv5_3') as scope: 201 | kernel = tf.Variable(tf.truncated_normal([3, 3, 512, 512], dtype=tf.float32, 202 | stddev=1e-1), name='weights') 203 | conv = tf.nn.conv2d(self.conv5_2, kernel, [1, 1, 1, 1], padding='SAME') 204 | biases = tf.Variable(tf.constant(0.0, shape=[512], dtype=tf.float32), 205 | trainable=True, name='biases') 206 | out = tf.nn.bias_add(conv, biases) 207 | self.conv5_3 = tf.nn.relu(out, name=scope) 208 | self.parameters += [kernel, biases] 209 | 210 | # pool5 211 | self.pool5 = tf.nn.max_pool(self.conv5_3, 212 | ksize=[1, 2, 2, 1], 213 | strides=[1, 2, 2, 1], 214 | padding='SAME', 215 | name='pool4') 216 | 217 | def fc_layers(self): 218 | # fc1 219 | with tf.name_scope('fc1') as scope: 220 | shape = int(np.prod(self.pool5.get_shape()[1:])) 221 | fc1w = tf.Variable(tf.truncated_normal([shape, 4096], 222 | dtype=tf.float32, 223 | stddev=1e-1), name='weights') 224 | fc1b = tf.Variable(tf.constant(1.0, shape=[4096], dtype=tf.float32), 225 | trainable=True, name='biases') 226 | pool5_flat = tf.reshape(self.pool5, [-1, shape]) 227 | fc1l = tf.nn.bias_add(tf.matmul(pool5_flat, fc1w), fc1b) 228 | self.fc1 = tf.nn.relu(fc1l) 229 | self.parameters += [fc1w, fc1b] 230 | 231 | # fc2 232 | with tf.name_scope('fc2') as scope: 233 | fc2w = tf.Variable(tf.truncated_normal([4096, 4096], 234 | dtype=tf.float32, 235 | stddev=1e-1), name='weights') 236 | fc2b = tf.Variable(tf.constant(1.0, shape=[4096], dtype=tf.float32), 237 | trainable=True, name='biases') 238 | fc2l = tf.nn.bias_add(tf.matmul(self.fc1, fc2w), fc2b) 239 | self.fc2 = tf.nn.relu(fc2l) 240 | self.parameters += [fc2w, fc2b] 241 | 242 | # fc3 243 | with tf.name_scope('fc3') as scope: 244 | fc3w = tf.Variable(tf.truncated_normal([4096, 1000], 245 | dtype=tf.float32, 246 | stddev=1e-1), name='weights') 247 | fc3b = tf.Variable(tf.constant(1.0, shape=[1000], dtype=tf.float32), 248 | trainable=True, name='biases') 249 | self.fc3l = tf.nn.bias_add(tf.matmul(self.fc2, fc3w), fc3b) 
250 | self.parameters += [fc3w, fc3b] 251 | 252 | def load_weights(self, weight_file, sess): 253 | weights = np.load(weight_file) 254 | keys = sorted(weights.keys()) 255 | for i, k in enumerate(keys): 256 | print i, k, np.shape(weights[k]) 257 | sess.run(self.parameters[i].assign(weights[k])) 258 | 259 | ''' 260 | if __name__ == '__main__': 261 | sess = tf.Session() 262 | imgs = tf.placeholder(tf.float32, [None, 224, 224, 3]) 263 | vgg = vgg16(imgs, 'vgg16_weights.npz', sess) 264 | 265 | img1 = imread('laska.png', mode='RGB') 266 | img1 = imresize(img1, (224, 224)) 267 | 268 | prob = sess.run(vgg.probs, feed_dict={vgg.imgs: [img1]})[0] 269 | preds = (np.argsort(prob)[::-1])[0:5] 270 | for p in preds: 271 | print class_names[p], prob[p] 272 | ''' -------------------------------------------------------------------------------- /dfc_vae_model.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import tensorflow as tf 4 | from vgg16 import vgg16 5 | 6 | class dfc_vae_model(object): 7 | 8 | def __init__(self, shape, inputs, alpha = 1, beta = 0.5, vgg_layers = [], learning_rate = 0.0005): 9 | self.shape = shape 10 | self.img_input = inputs 11 | self.alpha = alpha 12 | self.beta = beta 13 | self.gstep = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step') 14 | self.vgg_layers = vgg_layers 15 | self.learning_rate = learning_rate 16 | 17 | def _get_weights(self, name, shape): 18 | with tf.variable_scope("weights", reuse=tf.AUTO_REUSE) as scope: 19 | w = tf.get_variable(name=name + '_W', 20 | shape=shape, 21 | initializer=tf.truncated_normal_initializer(stddev=0.1)) 22 | return w 23 | 24 | def _get_biases(self, name, shape): 25 | with tf.variable_scope("biases", reuse=tf.AUTO_REUSE) as scope: 26 | b = tf.get_variable(name=name + '_b', 27 | shape=shape, 28 | initializer=tf.truncated_normal_initializer(stddev=0.1)) 29 | return b 30 | 31 | def _conv2d_bn_relu(self, inputs, name, kernel_size, in_channel, out_channel, stride, activation=True,bn=True): 32 | with tf.variable_scope(name) as scope: 33 | 34 | ### setup weights and biases 35 | filters = self._get_weights(name, shape=[kernel_size, kernel_size, in_channel, out_channel]) 36 | biases = self._get_biases(name, shape=[out_channel]) 37 | 38 | ### convolutional neural network 39 | conv2d = tf.nn.conv2d(input=inputs, 40 | filter=filters, 41 | strides=[1,stride,stride,1], 42 | padding='SAME', 43 | name=name + '_conv') 44 | conv2d = tf.nn.bias_add(conv2d, biases, name=name+'_add') 45 | 46 | ### in case of batch normalization 47 | if bn == True: 48 | conv2d = tf.contrib.layers.batch_norm(conv2d, 49 | center=True, scale=True, 50 | is_training=True, 51 | scope='bn') 52 | 53 | ### in case of leaky relu activation 54 | if activation == True: 55 | conv2d = tf.nn.leaky_relu(conv2d, alpha=0.1, name=name) 56 | 57 | return conv2d 58 | 59 | def encoder(self, reuse=False): 60 | 61 | with tf.variable_scope("encoder", reuse = reuse): 62 | ### Conv2d_bn_relu Layer 1 63 | conv1 = self._conv2d_bn_relu(self.img_input, 64 | name="conv1", 65 | kernel_size=4, 66 | in_channel=3, 67 | out_channel=32, 68 | stride=2) 69 | 70 | ### Conv2d_bn_relu Layer 2 71 | conv2 = self._conv2d_bn_relu(conv1, 72 | name="conv2", 73 | kernel_size=4, 74 | in_channel=32, 75 | out_channel=64, 76 | stride=2) 77 | 78 | ### Conv2d_bn_relu Layer 3 79 | conv3 = self._conv2d_bn_relu(conv2, 80 | name="conv3", 81 | kernel_size=4, 82 | in_channel=64, 83 | out_channel=128, 84 | stride=2) 85 | 86 | ### Conv2d_bn_relu Layer 4 87 | 
conv4 = self._conv2d_bn_relu(conv3, 88 | name="conv4", 89 | kernel_size=4, 90 | in_channel=128, 91 | out_channel=256, 92 | stride=2) 93 | 94 | ### flatten the output 95 | conv4_flat = tf.reshape(conv4, [-1, 256*4*4]) 96 | 97 | ### FC Layer for mean 98 | fcmean = tf.layers.dense(inputs=conv4_flat, 99 | units=100, 100 | activation=None, 101 | name="fcmean") 102 | 103 | ### FC Layer for standard deviation 104 | fcstd = tf.layers.dense(inputs=conv4_flat, 105 | units=100, 106 | activation=None, 107 | name="fcstd") 108 | 109 | ### fcmean and fcstd will be used for sample z value (latent variables) 110 | return fcmean, fcstd + 1e-6 111 | 112 | def decoder(self,inputs, reuse=False): 113 | 114 | with tf.variable_scope("decoder", reuse = reuse): 115 | ### FC Layer for z 116 | fc = tf.layers.dense(inputs=inputs, 117 | units = 4096, 118 | activation = None) 119 | fc = tf.reshape(fc, [-1, 4, 4, 256]) 120 | 121 | ### Layer 1 122 | deconv1 = tf.image.resize_nearest_neighbor(fc, size=(8,8)) 123 | deconv1 = self._conv2d_bn_relu(deconv1, 124 | name="deconv1", 125 | kernel_size=3, 126 | in_channel=256, 127 | out_channel=128, 128 | stride=1) 129 | 130 | ### Layer 2 131 | deconv2 = tf.image.resize_nearest_neighbor(deconv1, size=(16,16)) 132 | deconv2 = self._conv2d_bn_relu(deconv2, 133 | name="deconv2", 134 | kernel_size=3, 135 | in_channel=128, 136 | out_channel=64, 137 | stride=1) 138 | 139 | ### Layer 3 140 | deconv3 = tf.image.resize_nearest_neighbor(deconv2, size=(32,32)) 141 | deconv3 = self._conv2d_bn_relu(deconv3, 142 | name="deconv3", 143 | kernel_size=3, 144 | in_channel=64, 145 | out_channel=32, 146 | stride=1) 147 | 148 | ### Layer 4 149 | deconv4 = tf.image.resize_nearest_neighbor(deconv3, size=(64,64)) 150 | deconv4 = self._conv2d_bn_relu(deconv4, 151 | name="deconv4", 152 | kernel_size=3, 153 | in_channel=32, 154 | out_channel=3, 155 | stride=1, 156 | activation=False, 157 | bn=False) 158 | 159 | return deconv4 160 | 161 | def load_vgg(self): 162 | 163 | ### pass the input image to VGG model 164 | #self.resize_input_img = tf.image.resize_images(self.img_input, [224,224]) 165 | #self.vgg_input = VGG(self.resize_input_img) 166 | #self.l1_r, self.l2_r, self.l3_r = self.vgg_input.load(reuse=False) 167 | 168 | ### pass the generated image to VGG model 169 | #self.resize_gen_img = tf.image.resize_images(self.gen_img, [224,224]) 170 | #self.vgg_gen = VGG(self.resize_gen_img) 171 | #self.l1_g, self.l2_g, self.l3_g = self.vgg_gen.load(reuse=True) 172 | self.resize_input_img = tf.image.resize_images(self.img_input, [224,224]) 173 | self.vgg_real = vgg16(self.resize_input_img, 'vgg16_weights.npz') 174 | self.l1_r, self.l2_r, self.l3_r = self.vgg_real.get_layers() 175 | 176 | self.resize_gen_img = tf.image.resize_images(self.gen_img, [224,224]) 177 | self.vgg_gen = vgg16(self.resize_gen_img, 'vgg16_weights.npz') 178 | self.l1_g, self.l2_g, self.l3_g = self.vgg_gen.get_layers() 179 | 180 | def calculate_loss(self): 181 | 182 | ### calculate perception loss 183 | #l1_loss = (tf.reduce_sum(tf.square(self.l1_r-self.l1_g)))/tf.cast(tf.size(self.l1_r), tf.float32) 184 | #l2_loss = (tf.reduce_sum(tf.square(self.l2_r-self.l2_g)))/tf.cast(tf.size(self.l2_r), tf.float32) 185 | #l3_loss = (tf.reduce_sum(tf.square(self.l3_r-self.l3_g)))/tf.cast(tf.size(self.l3_r), tf.float32) 186 | l1_loss = tf.reduce_sum(tf.square(self.l1_r-self.l1_g), [1,2,3]) 187 | l2_loss = tf.reduce_sum(tf.square(self.l2_r-self.l2_g), [1,2,3]) 188 | l3_loss = tf.reduce_sum(tf.square(self.l3_r-self.l3_g), [1,2,3]) 189 | self.pct_loss = 
tf.reduce_mean(l1_loss + l2_loss + l3_loss) 190 | 191 | ### calculate KL loss 192 | self.kl_loss = tf.reduce_mean(-0.5*tf.reduce_sum( 193 | 1 + self.std - tf.square(self.mean) - tf.exp(self.std), 1)) 194 | 195 | ### calculate total loss 196 | self.loss = tf.add(self.beta*self.pct_loss,self.alpha*self.kl_loss) 197 | 198 | def optimize(self): 199 | 200 | ### create optimizer 201 | var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='encoder') + tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='decoder') 202 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss,global_step=self.gstep, var_list=var_list) 203 | 204 | def build_model(self,reuse=tf.AUTO_REUSE): 205 | 206 | ### get mean and std from encoder 207 | self.mean, self.std = self.encoder(reuse) 208 | 209 | ### sampling z and use reparameterization trick 210 | epsilon = tf.random_normal((tf.shape(self.mean)[0],100), mean = 0.0, stddev=1.0) 211 | self.z = self.mean + epsilon * tf.exp(.5*self.std) 212 | 213 | ### decode to get a generated image 214 | self.gen_img = self.decoder(self.z,reuse) 215 | 216 | ### load vgg 217 | self.load_vgg() 218 | 219 | ### calculate loss 220 | self.calculate_loss() 221 | 222 | ### setup optimizer 223 | self.optimize() 224 | 225 | ### generate random latent variable for random images 226 | self.random_latent = tf.random_normal((tf.shape(self.mean)[0], 100)) 227 | self.ran_img = self.decoder(self.random_latent,reuse) 228 | 229 | ### load VGG weight 230 | def load_vgg_weight(self, weight_file, sess): 231 | self.vgg_real.load_weights(weight_file,sess) 232 | self.vgg_gen.load_weights(weight_file,sess) 233 | 234 | 235 | -------------------------------------------------------------------------------- /outputs/test/gen-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbavon/Deep-Feature-Consistent-Variational-AutoEncoder-in-Tensorflow/8db93292f5b4e99fab88f8708f2469bafd253822/outputs/test/gen-1.png -------------------------------------------------------------------------------- /outputs/test/random-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbavon/Deep-Feature-Consistent-Variational-AutoEncoder-in-Tensorflow/8db93292f5b4e99fab88f8708f2469bafd253822/outputs/test/random-2.png -------------------------------------------------------------------------------- /outputs/test/real-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbavon/Deep-Feature-Consistent-Variational-AutoEncoder-in-Tensorflow/8db93292f5b4e99fab88f8708f2469bafd253822/outputs/test/real-1.png -------------------------------------------------------------------------------- /outputs/test_interpolated/interpolate0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbavon/Deep-Feature-Consistent-Variational-AutoEncoder-in-Tensorflow/8db93292f5b4e99fab88f8708f2469bafd253822/outputs/test_interpolated/interpolate0.gif -------------------------------------------------------------------------------- /outputs/test_interpolated/interpolate1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbavon/Deep-Feature-Consistent-Variational-AutoEncoder-in-Tensorflow/8db93292f5b4e99fab88f8708f2469bafd253822/outputs/test_interpolated/interpolate1.gif 
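A brief note on the objective built in dfc_vae_model.py above: the tensor the code calls `std` is effectively used as a log-variance (the sampling step multiplies by `exp(0.5 * std)` and the KL term uses `exp(std)`), and the total loss is `beta * perceptual + alpha * KL`. Below is a minimal NumPy sketch of those pieces in isolation; the shapes and the random stand-ins for the VGG conv1_1/conv2_1/conv3_1 activations are illustrative assumptions only, not the actual graph.

```python
# Minimal NumPy sketch (not the model graph) of the DFC-VAE objective:
# reparameterization trick, KL divergence, and a one-layer perceptual term.
import numpy as np

rng = np.random.RandomState(0)
batch, latent_dim = 4, 100

# Encoder outputs: the model's `mean` and `std`; `std` behaves as log(sigma^2).
mean = rng.randn(batch, latent_dim) * 0.1
log_var = rng.randn(batch, latent_dim) * 0.1

# Reparameterization trick: z = mu + eps * sigma, with sigma = exp(0.5 * log_var).
eps = rng.randn(batch, latent_dim)
z = mean + eps * np.exp(0.5 * log_var)
print("sampled z shape:", z.shape)

# KL divergence between N(mu, sigma^2) and N(0, I), averaged over the batch.
kl_loss = np.mean(-0.5 * np.sum(1.0 + log_var - mean**2 - np.exp(log_var), axis=1))

# Perceptual term: squared differences of VGG activations, summed per image,
# then averaged over the batch. A single random "layer" stands in for conv1_1.
feat_real = rng.randn(batch, 56, 56, 64).astype(np.float32)
feat_gen = rng.randn(batch, 56, 56, 64).astype(np.float32)
pct_loss = np.mean(np.sum((feat_real - feat_gen) ** 2, axis=(1, 2, 3)))

alpha, beta = 1.0, 8e-6   # defaults from train_dfc_vae.py
total_loss = beta * pct_loss + alpha * kl_loss
print("kl: %.4f  perceptual: %.1f  total: %.4f" % (kl_loss, pct_loss, total_loss))
```

With the training defaults (alpha = 1, beta = 8e-6), the perceptual term is scaled down heavily before being added to the KL term, which is why the README's Tips section stresses tuning beta.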
-------------------------------------------------------------------------------- /train_dfc_vae.py: -------------------------------------------------------------------------------- 1 | 2 | import dfc_vae_model as dfc 3 | import tensorflow as tf 4 | import numpy as np 5 | import util 6 | import matplotlib.pyplot as plt 7 | import os 8 | import sys 9 | import imageio 10 | 11 | ############# Hyper-Parameters ################ 12 | ### Adjust parameters in this part 13 | BATCH_SIZE = 32 14 | NUM_EPOCH = 10 15 | VGG_LAYERS = ['conv1_1','conv2_1','conv3_1'] 16 | ALPHA = 1 17 | BETA = 8e-6 18 | LEARNING_RATE = 0.0001 19 | IMG_HEIGHT = 64 20 | IMG_WIDTH = 64 21 | TRAINING_DATA = 'celeb_data_tfrecord' 22 | IMG_MEAN = np.array([134.10714722, 102.52040863, 87.15436554]) 23 | IMG_STDDEV = np.sqrt(np.array([3941.30175781, 2856.94287109, 2519.35791016])) 24 | ############################################### 25 | 26 | ### restore checkpoint from "Checkpoint" folder 27 | def _restore_checkpoint(saver, sess): 28 | 29 | ckpt_path = os.path.dirname(os.path.join(os.getcwd(),'checkpoint/')) 30 | ckpt = tf.train.get_checkpoint_state(ckpt_path) 31 | if ckpt and ckpt.model_checkpoint_path: 32 | saver.restore(sess, ckpt.model_checkpoint_path) 33 | print("get checkpoint") 34 | return ckpt_path 35 | 36 | ### create dataset and return iterator and dataset 37 | def _get_data(training_data_tfrecord, batch_size): 38 | 39 | dataset = util.read_tfrecord(training_data_tfrecord) 40 | dataset = dataset.batch(batch_size) 41 | iterator = dataset.make_initializable_iterator() 42 | return iterator, dataset 43 | 44 | def train_dfc_vae(): 45 | 46 | ### setup hyper-parameter 47 | batch_size = BATCH_SIZE 48 | epoch = NUM_EPOCH 49 | vgg_layers = VGG_LAYERS 50 | alpha = ALPHA 51 | beta = BETA 52 | learning_rate = LEARNING_RATE 53 | img_height = IMG_HEIGHT 54 | img_width = IMG_WIDTH 55 | training_data_tfrecord = TRAINING_DATA 56 | 57 | ### get training data 58 | iterator, _ = _get_data(training_data_tfrecord, batch_size) 59 | 60 | ### create iterator's initializer for training data 61 | iterator_init = iterator.initializer 62 | data = iterator.get_next() 63 | 64 | ### define input data 65 | img_input = (tf.reshape(data, shape=[-1, img_height, img_width, 3])-IMG_MEAN) / IMG_STDDEV 66 | 67 | ### build model graph 68 | model = dfc.dfc_vae_model([img_height,img_width], img_input, alpha, beta, vgg_layers, learning_rate) 69 | model.build_model(tf.AUTO_REUSE) 70 | 71 | ### create saver for restoring and saving variables 72 | saver = tf.train.Saver() 73 | 74 | with tf.Session() as sess: 75 | 76 | ### initialize global variable 77 | sess.run(tf.global_variables_initializer()) 78 | 79 | ### restore checkpoint 80 | ckpt_path = _restore_checkpoint(saver,sess) 81 | 82 | ### load pre-trained vgg weights 83 | model.load_vgg_weight('vgg16_weights.npz', sess) 84 | 85 | ### lists of losses, used for tracking 86 | kl_loss = [] 87 | pct_loss = [] 88 | total_loss = [] 89 | iteration = [] 90 | 91 | ### count how many training iteration (different from epoch) 92 | iteration_count = 0 93 | 94 | for i in range(epoch): 95 | 96 | ### initialize iterator 97 | sess.run(iterator_init) 98 | 99 | try: 100 | while True: 101 | sys.stdout.write('\r' + 'Iteration: ' + str(iteration_count)) 102 | sys.stdout.flush() 103 | 104 | ### every 100 iteration, print losses 105 | if iteration_count % 100 == 0: 106 | pct, kl, loss, tempmean, tempstd = sess.run([model.pct_loss, model.kl_loss, model.loss, model.mean, model.std]) 107 | pct = pct * beta 108 | print("\nperceptual loss: {}, 
kl loss: {}, total loss: {}".format(pct,kl,loss)) 109 | print(tempmean[0,0:5]) 110 | print(tempstd[0,0:5]) 111 | 112 | iteration.append(iteration_count) 113 | kl_loss.append(kl) 114 | pct_loss.append(pct) 115 | total_loss.append(loss) 116 | 117 | ### every 500 iteration, save images 118 | if iteration_count % 500 == 0: 119 | 120 | ### get images from the dfc_vae_model 121 | original_img, gen_img, ran_img = sess.run([model.img_input,model.gen_img,model.ran_img]) 122 | 123 | ### denomalize 124 | original_img = original_img*IMG_STDDEV + IMG_MEAN 125 | gen_img = gen_img*IMG_STDDEV + IMG_MEAN 126 | ran_img = ran_img*IMG_STDDEV + IMG_MEAN 127 | 128 | ### clip values to be in RGB range and transform to 0..1 float 129 | original_img = np.clip(original_img,0.,255.).astype('float32')/255. 130 | gen_img = np.clip(gen_img,0.,255.).astype('float32')/255. 131 | ran_img = np.clip(ran_img,0.,255.).astype('float32')/255. 132 | 133 | ### save images in 5x5 grid 134 | util.save_grid_img(original_img, os.path.join(os.getcwd(), 'outputs', str(i) + '_' + str(iteration_count) + '_orginal' + '.png'), img_height, img_width, 5, 5) 135 | util.save_grid_img(gen_img, os.path.join(os.getcwd(), 'outputs', str(i) + '_' + str(iteration_count) + '_generated' + '.png'), img_height, img_width, 5, 5) 136 | util.save_grid_img(ran_img, os.path.join(os.getcwd(), 'outputs', str(i) + '_' + str(iteration_count) + '_random' + '.png'), img_height, img_width, 5, 5) 137 | 138 | ### plot losses 139 | plt.figure() 140 | plt.plot(iteration, kl_loss) 141 | plt.plot(iteration, pct_loss) 142 | plt.plot(iteration, total_loss) 143 | plt.legend(['kl loss', 'perceptual loss', 'total loss'], bbox_to_anchor=(1.05, 1), loc=2) 144 | plt.title('Loss per iteration') 145 | plt.show() 146 | 147 | ### run optimizer 148 | sess.run(model.optimizer) 149 | iteration_count += 1 150 | 151 | except tf.errors.OutOfRangeError: 152 | pass 153 | 154 | ### save session for each epoch 155 | ### recommend to change to save encoder, decoder, VGG's variables separately. 156 | print("\nepoch: {}, loss: {}".format(i, loss)) 157 | saver.save(sess, os.path.join(ckpt_path,"Face_Vae"), global_step = iteration_count + model.gstep) 158 | print("checkpoint saved") 159 | 160 | 161 | def test_gen_img(model, sess, i): 162 | 163 | real_img, gen_img, ran_img = sess.run([model.img_input, model.gen_img, model.ran_img]) 164 | 165 | ran_img = ran_img*IMG_STDDEV + IMG_MEAN 166 | real_img = real_img*IMG_STDDEV + IMG_MEAN 167 | gen_img = gen_img*IMG_STDDEV + IMG_MEAN 168 | 169 | ran_img = ran_img/255. 170 | real_img = real_img/255. 171 | gen_img = gen_img/255. 172 | 173 | util.save_grid_img(ran_img, os.path.join(os.getcwd(), 'outputs', 'test' , 'random-' + str(i) + '.png'), 64,64,8,8) 174 | util.save_grid_img(real_img, os.path.join(os.getcwd(), 'outputs', 'test' , 'real-' + str(i) + '.png'), 64,64,8,8) 175 | util.save_grid_img(gen_img, os.path.join(os.getcwd(), 'outputs', 'test' , 'gen-' + str(i) + '.png'), 64,64,8,8) 176 | 177 | def test_interpolation(model, sess, i): 178 | 179 | z1 = sess.run(model.z) 180 | z2 = sess.run(model.z) 181 | 182 | print(z1.shape) 183 | print(z2.shape) 184 | 185 | print(z1[0,:5]) 186 | print(z2[0,:5]) 187 | 188 | z = tf.Variable(np.zeros(z1.shape).astype(np.float32)) 189 | gen_img = model.decoder(z, tf.AUTO_REUSE) 190 | 191 | interpolated_img_list = [] 192 | 193 | for j in range(31): 194 | interpolated_z = z1 * (30-j)/30. + z2 * j/30. 
195 | sess.run(z.assign(interpolated_z)) 196 | interpolated_img = sess.run(gen_img) 197 | interpolated_img = interpolated_img*IMG_STDDEV + IMG_MEAN 198 | interpolated_img = interpolated_img/255. 199 | interpolated_img = util.build_grid_img(interpolated_img, interpolated_img.shape[1], interpolated_img.shape[2],8,8) 200 | interpolated_img_list.append(interpolated_img) 201 | 202 | 203 | for j in range(31): 204 | imageio.mimsave(os.path.join(os.getcwd(), 'outputs', 'test_interpolated' , 'interpolate' + str(i) + '.gif'), interpolated_img_list) 205 | 206 | return interpolated_img_list 207 | 208 | 209 | def test(): 210 | 211 | ### setup hyper-parameter 212 | num_test_set = 2 213 | batch_size = 64 214 | vgg_layers = VGG_LAYERS 215 | img_height = IMG_HEIGHT 216 | img_width = IMG_WIDTH 217 | training_data_tfrecord = TRAINING_DATA 218 | 219 | ### get training data 220 | iterator, _ = _get_data(training_data_tfrecord, batch_size) 221 | 222 | ### create iterator's initializer for training data 223 | iterator_init = iterator.initializer 224 | data = iterator.get_next() 225 | 226 | ### define input data 227 | img_input = (tf.reshape(data, shape=[-1, img_height, img_width, 3]) - IMG_MEAN)/IMG_STDDEV 228 | 229 | ### build model graph 230 | model = dfc.dfc_vae_model([img_height,img_width], img_input) 231 | model.build_model(tf.AUTO_REUSE) 232 | 233 | ### create saver for restoring and saving variables 234 | saver = tf.train.Saver() 235 | 236 | with tf.Session() as sess: 237 | 238 | ### initialize global variable 239 | sess.run(tf.global_variables_initializer()) 240 | 241 | ### restore checkpoint 242 | ckpt_path = _restore_checkpoint(saver,sess) 243 | 244 | sess.run(iterator_init) 245 | 246 | for i in range(num_test_set): 247 | 248 | test_gen_img(model, sess, i) 249 | x = test_interpolation(model, sess, i) 250 | 251 | 252 | if __name__ == '__main__': 253 | tf.reset_default_graph() 254 | train_dfc_vae() 255 | test() 256 | 257 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import cv2 3 | import numpy as np 4 | import os 5 | import matplotlib.pyplot as plt 6 | import scipy.misc 7 | import urllib 8 | from zipfile import ZipFile 9 | from PIL import Image 10 | 11 | ### get the image 12 | def crop_center_image(img): 13 | width_start = int(img.shape[1]/2 - 150/2) 14 | height_start = int(img.shape[0]/2 - 150/2) 15 | cropped_img = img[height_start: height_start+150, width_start: width_start+150, :] 16 | #print(cropped_img.shape) 17 | return cropped_img 18 | 19 | ### download according to address provided and perform cropping 20 | def load_and_crop_image(img, img_width, img_height): 21 | img = scipy.misc.imread(img_addr) 22 | img = crop_center_image(img) 23 | img = scipy.misc.imresize(img, [img_width,img_height]) 24 | return img 25 | 26 | def register_extension(id, extension): 27 | Image.EXTENSION[extension.lower()] = id.upper() 28 | 29 | def register_extensions(id, extensions): 30 | for extension in extensions: register_extension(id, extension) 31 | 32 | ### create grid_img 33 | ### the image inputs will be 4 dimensions, which 0 dimention is the number of example 34 | def build_grid_img(inputs, img_height, img_width, n_row, n_col): 35 | grid_img = np.zeros((img_height*n_row, img_width*n_col, 3)) 36 | print(inputs.shape) 37 | count = 0 38 | for i in range(n_col): 39 | for j in range(n_row): 40 | grid_img[i*img_height:(i+1)*img_height, 
j*img_width:(j+1)*img_width,:] = inputs[count] 41 | count += 1 42 | return grid_img 43 | 44 | ### save images as a grid 45 | def save_grid_img(inputs, path, img_height, img_width, n_row, n_col): 46 | 47 | Image.register_extension = register_extension 48 | Image.register_extensions = register_extensions 49 | grid_img = build_grid_img(inputs, img_height, img_width, n_row, n_col) 50 | scipy.misc.imsave(path, grid_img) 51 | 52 | def _int64_feature(value): 53 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 54 | 55 | def _bytes_feature(value): 56 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 57 | 58 | ### convert image into binary format 59 | def get_image_binary(img): 60 | shape = np.array(img.shape, np.int32) 61 | img = np.asarray(img,np.uint8) 62 | return img.tobytes(), shape.tobytes() 63 | 64 | ### write data into tf record file format (images are stored in zip file) 65 | def write_tfrecord(tfrecord_filename, zipFileName, img_height, img_width): 66 | 67 | ### images counter 68 | count = 0 69 | 70 | ### create a writer 71 | writer = tf.python_io.TFRecordWriter(tfrecord_filename) 72 | 73 | with ZipFile(zipFileName) as archive: 74 | 75 | for entry in archive.infolist(): 76 | 77 | # skip the folder content 78 | if entry.filename == 'content/': 79 | continue 80 | 81 | with archive.open(entry) as file: 82 | 83 | sys.stdout.write('\r'+str(count)) 84 | 85 | ### pre-process data 86 | img = np.asarray(Image.open(file)) 87 | img = crop_center_image(img) 88 | img = scipy.misc.imresize(img, [img_height,img_width]) 89 | img, shape = get_image_binary(img) 90 | 91 | ### create features 92 | feature = {'image': _bytes_feature(img), 93 | 'shape':_bytes_feature(shape)} 94 | features = tf.train.Features(feature=feature) 95 | 96 | ### create example 97 | example = tf.train.Example(features=features) 98 | 99 | ### write example 100 | writer.write(example.SerializeToString()) 101 | sys.stdout.flush() 102 | 103 | count += 1 104 | 105 | writer.close() 106 | 107 | ### parse serialized data back into the usable form 108 | def _parse(serialized_data): 109 | features = {'image': tf.FixedLenFeature([], tf.string), 110 | 'shape': tf.FixedLenFeature([], tf.string)} 111 | features = tf.parse_single_example(serialized_data, 112 | features) 113 | img = tf.cast(tf.decode_raw(features['image'],tf.uint8), tf.float32) 114 | shape = tf.decode_raw(features['shape'],tf.int32) 115 | img = tf.reshape(img, shape) 116 | 117 | return img 118 | 119 | ### read tf record 120 | def read_tfrecord(tfrecord_filename): 121 | 122 | ### create dataset 123 | dataset = tf.data.TFRecordDataset(tfrecord_filename) 124 | dataset = dataset.map(_parse) 125 | return dataset 126 | 127 | def download(url, file_path): 128 | if os.path.exists(file_path): 129 | print("the file is already existed") 130 | return 131 | else: 132 | print("downloading file...") 133 | urllib.request.urlretrieve(url, file_path) 134 | print("downloading done") --------------------------------------------------------------------------------
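As a closing note (referenced from the README's preprocessing steps), here is a hypothetical usage sketch of the TFRecord helpers in util.py above. The Zip file name is a placeholder, and running the conversion step also requires adding `import sys` to util.py, since `write_tfrecord()` writes its progress counter through `sys.stdout`.

```python
# Hypothetical end-to-end usage of util.py's TFRecord helpers; file names are placeholders.
import tensorflow as tf
import util

# One-off conversion: crop/resize the zipped CelebA images and write 64x64 TFRecords.
# (Requires `import sys` at the top of util.py, since write_tfrecord() prints via sys.stdout.)
# util.write_tfrecord('celeb_data_tfrecord', 'img_align_celeba.zip', 64, 64)

# Read the records back as a tf.data pipeline, the same way _get_data() in train_dfc_vae.py does.
dataset = util.read_tfrecord('celeb_data_tfrecord')
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
batch = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    imgs = sess.run(batch)   # float32 array of shape [32, 64, 64, 3], pixel values 0-255
    print(imgs.shape)
```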